Drop-in SwiftUI chat view, headless ChatEngine, LLM-agnostic via AnyLanguageModel. Read-only by default with configurable allowlists. Robust SQL parser with 63 tests. Includes demo app with GitHub stars dataset.
364 lines · 15 KiB · Swift
// MultiTurnContextTests.swift
// SwiftDBAI Tests
//
// Tests verifying multi-turn conversation context — follow-up queries
// correctly reference the prior query's table, columns, and results.

import AnyLanguageModel
|
|
import Foundation
|
|
import GRDB
|
|
import Testing
|
|
|
|
@testable import SwiftDBAI
|
|
|
|
@Suite("Multi-Turn Context Tests")
struct MultiTurnContextTests {

    // MARK: - Test Database Setup

    /// Creates an in-memory database with a `users` table (including `age`) and an
    /// `orders` table whose `user_id` references `users(id)`.
    ///
    /// Seed data the row-count expectations below rely on:
    /// - users: Alice 25 (New York), Bob 35 (San Francisco), Charlie 42 (New York),
    ///   Diana 28 (Chicago), Eve 55 (San Francisco). Average age is 37.
    /// - orders: five rows belonging to user ids 1–4.
    private func makeTestDatabase() throws -> DatabaseQueue {
        let db = try DatabaseQueue(path: ":memory:")
        try db.write { db in
            try db.execute(sql: """
                CREATE TABLE users (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL,
                    age INTEGER NOT NULL,
                    email TEXT NOT NULL,
                    city TEXT NOT NULL
                )
                """)
            try db.execute(sql: """
                INSERT INTO users (name, age, email, city) VALUES
                ('Alice', 25, 'alice@example.com', 'New York'),
                ('Bob', 35, 'bob@example.com', 'San Francisco'),
                ('Charlie', 42, 'charlie@example.com', 'New York'),
                ('Diana', 28, 'diana@example.com', 'Chicago'),
                ('Eve', 55, 'eve@example.com', 'San Francisco')
                """)
            try db.execute(sql: """
                CREATE TABLE orders (
                    id INTEGER PRIMARY KEY,
                    user_id INTEGER NOT NULL,
                    amount REAL NOT NULL,
                    status TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    FOREIGN KEY (user_id) REFERENCES users(id)
                )
                """)
            try db.execute(sql: """
                INSERT INTO orders (user_id, amount, status, created_at) VALUES
                (1, 99.99, 'completed', '2024-01-15'),
                (1, 49.50, 'pending', '2024-02-20'),
                (2, 150.00, 'completed', '2024-01-10'),
                (3, 200.00, 'completed', '2024-03-01'),
                (4, 75.00, 'cancelled', '2024-02-05')
                """)
        }
        return db
    }

    // MARK: - Multi-Turn Context Tests

    /// A follow-up using an implicit reference ("those") must be generated with the
    /// prior turn's question and SQL visible in the prompt.
    @Test("Follow-up 'filter those by age > 30' references prior 'show all users' context")
    func followUpFilterReferencesUsersTable() async throws {
        let db = try makeTestDatabase()

        // Responses are listed in the order the engine is expected to request them.
        // Turn 1: "show all users" → SELECT * FROM users (returns 5 rows, LLM summary needed)
        // Turn 2: "filter those by age > 30" → should reference users table from context
        let mock = PromptCapturingMockModel(responses: [
            "SELECT * FROM users",
            "Here are all 5 users in the database.",
            "SELECT * FROM users WHERE age > 30",
            "Found 3 users over 30: Bob (35), Charlie (42), and Eve (55)."
        ])

        let engine = ChatEngine(database: db, model: mock)

        // First turn: show all users
        let response1 = try await engine.send("show all users")
        #expect(response1.sql == "SELECT * FROM users")
        #expect(response1.queryResult?.rowCount == 5)

        // Second turn: follow-up with implicit reference
        let response2 = try await engine.send("filter those by age > 30")
        #expect(response2.sql == "SELECT * FROM users WHERE age > 30")
        #expect(response2.queryResult?.rowCount == 3)

        // Find the prompt for the second SQL generation (skip summary prompts).
        // `#require` stops the test here instead of silently skipping the checks below.
        let prompts = mock.capturedPrompts
        let followUpSQLPrompt = try #require(
            prompts.first(where: { prompt in
                prompt.contains("filter those by age > 30") && prompt.contains("CONVERSATION HISTORY")
            }),
            "Follow-up prompt should contain CONVERSATION HISTORY"
        )

        // The conversation history should include the prior query and its SQL
        #expect(followUpSQLPrompt.contains("show all users"), "History should contain prior user message")
        #expect(followUpSQLPrompt.contains("SELECT * FROM users"), "History should contain prior SQL")
        #expect(followUpSQLPrompt.contains("filter those by age > 30"), "Prompt should contain current question")
    }

    /// Each successive follow-up must see every earlier turn (question and SQL) in its prompt.
    @Test("Follow-up correctly inherits table context across multiple turns")
    func multipleFollowUpsInheritContext() async throws {
        let db = try makeTestDatabase()

        // 3-turn conversation narrowing down results; every turn returns rows, so
        // each SQL response is followed by a summary response.
        let mock = PromptCapturingMockModel(responses: [
            "SELECT * FROM users",
            "Here are all 5 users.",
            "SELECT * FROM users WHERE city = 'New York'",
            "Found 2 users in New York: Alice and Charlie.",
            "SELECT * FROM users WHERE city = 'New York' AND age > 30",
            "Charlie (42) is the only New York user over 30."
        ])

        let engine = ChatEngine(database: db, model: mock)

        // Turn 1
        _ = try await engine.send("show all users")

        // Turn 2 — narrows by city
        let response2 = try await engine.send("only those in New York")
        #expect(response2.sql == "SELECT * FROM users WHERE city = 'New York'")
        #expect(response2.queryResult?.rowCount == 2)

        // Turn 3 — further narrows by age
        let response3 = try await engine.send("now filter by age over 30")
        #expect(response3.sql == "SELECT * FROM users WHERE city = 'New York' AND age > 30")
        #expect(response3.queryResult?.rowCount == 1)

        // Verify third turn's prompt includes the full conversation history
        let prompts = mock.capturedPrompts
        let thirdTurnPrompt = try #require(
            prompts.last(where: { prompt in
                prompt.contains("now filter by age over 30") && prompt.contains("CONVERSATION HISTORY")
            })
        )

        // Should include both prior user messages
        #expect(thirdTurnPrompt.contains("show all users"))
        #expect(thirdTurnPrompt.contains("only those in New York"))
        // Should include prior SQL. Turn 1's SQL is a prefix of turn 2's, so the first
        // check is implied by the second; the user-message checks above pin turn 1.
        #expect(thirdTurnPrompt.contains("SELECT * FROM users"))
        #expect(thirdTurnPrompt.contains("SELECT * FROM users WHERE city = 'New York'"))
    }

    /// A follow-up that moves to a related table ("their orders") must still see the
    /// prior turn's question and SQL so the filter can be carried across the JOIN.
    @Test("Follow-up switching tables preserves cross-table context")
    func followUpSwitchesTableWithContext() async throws {
        let db = try makeTestDatabase()

        // Turn 1: query users, Turn 2: ask about their orders
        let mock = PromptCapturingMockModel(responses: [
            "SELECT name, age FROM users WHERE age > 30",
            "Found 3 users over 30.",
            "SELECT o.id, u.name, o.amount, o.status FROM orders o JOIN users u ON o.user_id = u.id WHERE u.age > 30",
            "Bob has a $150 completed order, Charlie has a $200 completed order."
        ])

        let engine = ChatEngine(database: db, model: mock)

        // Turn 1: users over 30
        let response1 = try await engine.send("show users over 30")
        #expect(response1.sql == "SELECT name, age FROM users WHERE age > 30")
        #expect(response1.queryResult?.rowCount == 3)

        // Turn 2: their orders — references the previous result context
        let response2 = try await engine.send("show their orders")
        #expect(response2.sql?.contains("JOIN") == true)

        // Verify the follow-up prompt contains the users context
        let prompts = mock.capturedPrompts
        let orderPrompt = try #require(
            prompts.first(where: { prompt in
                prompt.contains("show their orders") && prompt.contains("CONVERSATION HISTORY")
            })
        )

        #expect(orderPrompt.contains("show users over 30"), "Should contain prior user message")
        #expect(orderPrompt.contains("SELECT name, age FROM users WHERE age > 30"),
                "Should contain prior SQL context for table reference")
    }

    /// Assistant messages record their SQL, and that SQL appears in the next turn's prompt.
    @Test("Conversation history includes SQL from prior turns for context")
    func historyIncludesSQLFromPriorTurns() async throws {
        let db = try makeTestDatabase()

        // Both queries are aggregates → no LLM summarization needed
        let mock = PromptCapturingMockModel(responses: [
            "SELECT COUNT(*) FROM users",
            "SELECT COUNT(*) FROM users WHERE age > 30",
        ])

        let engine = ChatEngine(database: db, model: mock)

        // Turn 1
        let r1 = try await engine.send("how many users are there?")
        #expect(r1.sql == "SELECT COUNT(*) FROM users")

        // Turn 2 — references "those" implicitly
        let r2 = try await engine.send("how many of those are over 30?")
        #expect(r2.sql == "SELECT COUNT(*) FROM users WHERE age > 30")

        // Verify engine history has all 4 messages (2 user + 2 assistant).
        // `#require` guards the indexing below: a failed `#expect` would not stop
        // the test, and the out-of-bounds subscript would crash the test process.
        let messages = engine.messages
        try #require(messages.count == 4)
        #expect(messages[0].role == .user)
        #expect(messages[0].content == "how many users are there?")
        #expect(messages[1].role == .assistant)
        #expect(messages[1].sql == "SELECT COUNT(*) FROM users")
        #expect(messages[2].role == .user)
        #expect(messages[2].content == "how many of those are over 30?")
        #expect(messages[3].role == .assistant)
        #expect(messages[3].sql == "SELECT COUNT(*) FROM users WHERE age > 30")

        // The second prompt should reference the first query SQL
        let prompts = mock.capturedPrompts
        try #require(prompts.count >= 2)
        let secondPrompt = prompts[1]
        #expect(secondPrompt.contains("CONVERSATION HISTORY"))
        #expect(secondPrompt.contains("SELECT COUNT(*) FROM users"))
        #expect(secondPrompt.contains("how many users are there?"))
    }

    /// A follow-up after an aggregate ("who is above average?") must see the aggregate
    /// turn's question and SQL.
    @Test("Follow-up after aggregate uses prior table context")
    func followUpAfterAggregateUsesTableContext() async throws {
        let db = try makeTestDatabase()

        // Turn 1: aggregate (no LLM summary needed)
        // Turn 2: follow-up relative to the prior result. The seeded average age is 37.
        let mock = PromptCapturingMockModel(responses: [
            "SELECT AVG(age) FROM users",
            "SELECT name, age FROM users WHERE age > 37",
            "Charlie (42) and Eve (55) are older than average."
        ])

        let engine = ChatEngine(database: db, model: mock)

        // Turn 1: average age → aggregate, template summary
        let r1 = try await engine.send("what is the average age of users?")
        #expect(r1.sql == "SELECT AVG(age) FROM users")

        // Turn 2: "who is above average?" — needs the avg context
        let r2 = try await engine.send("who is above average?")
        #expect(r2.sql == "SELECT name, age FROM users WHERE age > 37")
        #expect(r2.queryResult?.rowCount == 2)

        // Verify context passed
        let prompts = mock.capturedPrompts
        let followUp = try #require(
            prompts.first(where: { prompt in
                prompt.contains("who is above average?") && prompt.contains("CONVERSATION HISTORY")
            })
        )
        #expect(followUp.contains("AVG(age)"), "Should include prior aggregate SQL for context")
        // The prior question names the table; a bare "users" check could be
        // satisfied by unrelated prompt text, so match the full prior message.
        #expect(followUp.contains("what is the average age of users?"),
                "Should include prior user message for table reference")
    }

    /// With `contextWindowSize: 2`, only the most recent turn's two messages are sent
    /// as history; older turns are trimmed.
    @Test("Context window limits how much history is visible in follow-ups")
    func contextWindowLimitsHistoryInFollowUps() async throws {
        let db = try makeTestDatabase()

        // 3 turns, but context window of 2 messages. All queries are aggregates, so
        // each turn issues exactly one prompt and prompts[2] belongs to turn 3.
        let mock = PromptCapturingMockModel(responses: [
            "SELECT COUNT(*) FROM users",
            "SELECT COUNT(*) FROM orders",
            "SELECT COUNT(*) FROM users WHERE age > 30",
        ])

        let config = ChatEngineConfiguration(
            queryTimeout: nil,
            contextWindowSize: 2
        )

        let engine = ChatEngine(
            database: db,
            model: mock,
            configuration: config
        )

        _ = try await engine.send("how many users?")
        _ = try await engine.send("how many orders?")
        _ = try await engine.send("how many users over 30?")

        // The third prompt should only have the last 2 messages from turn 2
        let prompts = mock.capturedPrompts
        try #require(prompts.count >= 3)

        let thirdPrompt = prompts[2]
        #expect(thirdPrompt.contains("CONVERSATION HISTORY"))
        // Turn 2 context should be present
        #expect(thirdPrompt.contains("how many orders?"))
        #expect(thirdPrompt.contains("SELECT COUNT(*) FROM orders"))
        // Turn 1 context should be trimmed (window=2 means last 2 messages).
        // The current question "how many users over 30?" does not contain this
        // exact substring, so no trailing delimiter is needed to tell them apart.
        #expect(!thirdPrompt.contains("how many users?"), "First turn should be trimmed from context window")
    }

    /// After `clearHistory()`, the next prompt must carry no history from earlier turns.
    @Test("clearHistory resets context so follow-ups have no prior history")
    func clearHistoryResetsFollowUpContext() async throws {
        let db = try makeTestDatabase()

        let mock = PromptCapturingMockModel(responses: [
            "SELECT * FROM users",
            "Here are the 5 users.",
            "SELECT COUNT(*) FROM users",
        ])

        let engine = ChatEngine(database: db, model: mock)

        // Turn 1
        _ = try await engine.send("show all users")
        #expect(engine.messages.count == 2)

        // Clear history
        engine.clearHistory()
        #expect(engine.messages.isEmpty)

        // Turn 2 after clear — should NOT have conversation history
        _ = try await engine.send("count all users")

        let lastPrompt = try #require(mock.capturedPrompts.last)
        // After clearing, the prompt should NOT contain conversation history
        #expect(!lastPrompt.contains("CONVERSATION HISTORY"),
                "After clearHistory(), follow-up should not have prior context")
        #expect(!lastPrompt.contains("show all users"),
                "After clearHistory(), prior messages should be gone")
    }

    /// Multi-row results keep their columns, and each assistant message records its SQL.
    @Test("Multi-turn with result data in context enables informed follow-ups")
    func resultDataInContextEnablesInformedFollowUps() async throws {
        let db = try makeTestDatabase()

        // Turn 1: list users → multi-row result, LLM summarizes
        // Turn 2: "sort those by age" → references same table
        let mock = PromptCapturingMockModel(responses: [
            "SELECT name, age, city FROM users",
            "Found 5 users: Alice (25, NY), Bob (35, SF), Charlie (42, NY), Diana (28, Chicago), Eve (55, SF).",
            "SELECT name, age, city FROM users ORDER BY age DESC",
            "Users sorted by age: Eve (55), Charlie (42), Bob (35), Diana (28), Alice (25)."
        ])

        let engine = ChatEngine(database: db, model: mock)

        let r1 = try await engine.send("list all users with their age and city")
        #expect(r1.queryResult?.rowCount == 5)
        #expect(r1.queryResult?.columns.contains("age") == true)
        #expect(r1.queryResult?.columns.contains("city") == true)

        let r2 = try await engine.send("sort those by age descending")
        #expect(r2.sql == "SELECT name, age, city FROM users ORDER BY age DESC")

        // Verify the assistant messages in history include the SQL.
        // `#require` guards the indexing below against an out-of-bounds crash.
        let messages = engine.messages
        try #require(messages.count == 4)
        // First assistant message should have the SQL recorded
        #expect(messages[1].sql == "SELECT name, age, city FROM users")
        // Second assistant should have the sorted SQL
        #expect(messages[3].sql == "SELECT name, age, city FROM users ORDER BY age DESC")
    }
}
|