Initial implementation of SwiftDBAI
Chat with any SQLite database using natural language. Built on AnyLanguageModel (HuggingFace) for LLM-agnostic provider support and GRDB for SQLite access. Core features: - Auto schema introspection from sqlite_master (zero config) - NL → SQL generation via any AnyLanguageModel provider - Three rendering modes: text summary, data table, Swift Charts - Drop-in DataChatView (SwiftUI) and headless ChatEngine - Operation allowlist with read-only default - Mutation policy with per-table control - ToolExecutionDelegate for destructive operation confirmation - Multi-turn conversation context - 352 tests across 24 suites, all passing Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
363
Tests/SwiftDBAITests/MultiTurnContextTests.swift
Normal file
363
Tests/SwiftDBAITests/MultiTurnContextTests.swift
Normal file
@@ -0,0 +1,363 @@
|
||||
// MultiTurnContextTests.swift
|
||||
// SwiftDBAI Tests
|
||||
//
|
||||
// Tests verifying multi-turn conversation context — follow-up queries
|
||||
// correctly reference the prior query's table, columns, and results.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
import GRDB
|
||||
import Testing
|
||||
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("Multi-Turn Context Tests")
|
||||
struct MultiTurnContextTests {
|
||||
|
||||
// MARK: - Test Database Setup
|
||||
|
||||
/// Creates an in-memory database with users (including age) and orders.
|
||||
private func makeTestDatabase() throws -> DatabaseQueue {
|
||||
let db = try DatabaseQueue(path: ":memory:")
|
||||
try db.write { db in
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
age INTEGER NOT NULL,
|
||||
email TEXT NOT NULL,
|
||||
city TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
try db.execute(sql: """
|
||||
INSERT INTO users (name, age, email, city) VALUES
|
||||
('Alice', 25, 'alice@example.com', 'New York'),
|
||||
('Bob', 35, 'bob@example.com', 'San Francisco'),
|
||||
('Charlie', 42, 'charlie@example.com', 'New York'),
|
||||
('Diana', 28, 'diana@example.com', 'Chicago'),
|
||||
('Eve', 55, 'eve@example.com', 'San Francisco')
|
||||
""")
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE orders (
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
amount REAL NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
)
|
||||
""")
|
||||
try db.execute(sql: """
|
||||
INSERT INTO orders (user_id, amount, status, created_at) VALUES
|
||||
(1, 99.99, 'completed', '2024-01-15'),
|
||||
(1, 49.50, 'pending', '2024-02-20'),
|
||||
(2, 150.00, 'completed', '2024-01-10'),
|
||||
(3, 200.00, 'completed', '2024-03-01'),
|
||||
(4, 75.00, 'cancelled', '2024-02-05')
|
||||
""")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// MARK: - Multi-Turn Context Tests
|
||||
|
||||
@Test("Follow-up 'filter those by age > 30' references prior 'show all users' context")
|
||||
func followUpFilterReferencesUsersTable() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// Turn 1: "show all users" → SELECT * FROM users (returns 5 rows, LLM summary needed)
|
||||
// Turn 2: "filter those by age > 30" → should reference users table from context
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT * FROM users",
|
||||
"Here are all 5 users in the database.",
|
||||
"SELECT * FROM users WHERE age > 30",
|
||||
"Found 3 users over 30: Bob (35), Charlie (42), and Eve (55)."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// First turn: show all users
|
||||
let response1 = try await engine.send("show all users")
|
||||
#expect(response1.sql == "SELECT * FROM users")
|
||||
#expect(response1.queryResult?.rowCount == 5)
|
||||
|
||||
// Second turn: follow-up with implicit reference
|
||||
let response2 = try await engine.send("filter those by age > 30")
|
||||
#expect(response2.sql == "SELECT * FROM users WHERE age > 30")
|
||||
#expect(response2.queryResult?.rowCount == 3)
|
||||
|
||||
// Verify the follow-up prompt includes conversation history
|
||||
let prompts = mock.capturedPrompts
|
||||
// Find the prompt for the second SQL generation (skip summary prompts)
|
||||
let followUpSQLPrompt = prompts.first { prompt in
|
||||
prompt.contains("filter those by age > 30") && prompt.contains("CONVERSATION HISTORY")
|
||||
}
|
||||
#expect(followUpSQLPrompt != nil, "Follow-up prompt should contain CONVERSATION HISTORY")
|
||||
|
||||
// The conversation history should include the prior query and its SQL
|
||||
if let prompt = followUpSQLPrompt {
|
||||
#expect(prompt.contains("show all users"), "History should contain prior user message")
|
||||
#expect(prompt.contains("SELECT * FROM users"), "History should contain prior SQL")
|
||||
#expect(prompt.contains("filter those by age > 30"), "Prompt should contain current question")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Follow-up correctly inherits table context across multiple turns")
|
||||
func multipleFollowUpsInheritContext() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// 3-turn conversation narrowing down results
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT * FROM users",
|
||||
"Here are all 5 users.",
|
||||
"SELECT * FROM users WHERE city = 'New York'",
|
||||
"Found 2 users in New York: Alice and Charlie.",
|
||||
"SELECT * FROM users WHERE city = 'New York' AND age > 30",
|
||||
"Charlie (42) is the only New York user over 30."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// Turn 1
|
||||
_ = try await engine.send("show all users")
|
||||
|
||||
// Turn 2 — narrows by city
|
||||
let response2 = try await engine.send("only those in New York")
|
||||
#expect(response2.sql == "SELECT * FROM users WHERE city = 'New York'")
|
||||
#expect(response2.queryResult?.rowCount == 2)
|
||||
|
||||
// Turn 3 — further narrows by age
|
||||
let response3 = try await engine.send("now filter by age over 30")
|
||||
#expect(response3.sql == "SELECT * FROM users WHERE city = 'New York' AND age > 30")
|
||||
#expect(response3.queryResult?.rowCount == 1)
|
||||
|
||||
// Verify third turn's prompt includes the full conversation history
|
||||
let prompts = mock.capturedPrompts
|
||||
let thirdTurnPrompt = prompts.last { prompt in
|
||||
prompt.contains("now filter by age over 30") && prompt.contains("CONVERSATION HISTORY")
|
||||
}
|
||||
#expect(thirdTurnPrompt != nil)
|
||||
|
||||
if let prompt = thirdTurnPrompt {
|
||||
// Should include both prior user messages
|
||||
#expect(prompt.contains("show all users"))
|
||||
#expect(prompt.contains("only those in New York"))
|
||||
// Should include prior SQL
|
||||
#expect(prompt.contains("SELECT * FROM users"))
|
||||
#expect(prompt.contains("SELECT * FROM users WHERE city = 'New York'"))
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Follow-up switching tables preserves cross-table context")
|
||||
func followUpSwitchesTableWithContext() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// Turn 1: query users, Turn 2: ask about their orders
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT name, age FROM users WHERE age > 30",
|
||||
"Found 3 users over 30.",
|
||||
"SELECT o.id, u.name, o.amount, o.status FROM orders o JOIN users u ON o.user_id = u.id WHERE u.age > 30",
|
||||
"Bob has a $150 completed order, Charlie has a $200 completed order."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// Turn 1: users over 30
|
||||
let response1 = try await engine.send("show users over 30")
|
||||
#expect(response1.queryResult?.rowCount == 3)
|
||||
|
||||
// Turn 2: their orders — references the previous result context
|
||||
let response2 = try await engine.send("show their orders")
|
||||
#expect(response2.sql?.contains("JOIN") == true)
|
||||
|
||||
// Verify the follow-up prompt contains the users context
|
||||
let prompts = mock.capturedPrompts
|
||||
let orderPrompt = prompts.first { prompt in
|
||||
prompt.contains("show their orders") && prompt.contains("CONVERSATION HISTORY")
|
||||
}
|
||||
#expect(orderPrompt != nil)
|
||||
|
||||
if let prompt = orderPrompt {
|
||||
#expect(prompt.contains("show users over 30"), "Should contain prior user message")
|
||||
#expect(prompt.contains("age > 30"), "Should contain prior SQL context for table reference")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Conversation history includes SQL from prior turns for context")
|
||||
func historyIncludesSQLFromPriorTurns() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// Both queries are aggregates → no LLM summarization needed
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT COUNT(*) FROM users",
|
||||
"SELECT COUNT(*) FROM users WHERE age > 30",
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// Turn 1
|
||||
let r1 = try await engine.send("how many users are there?")
|
||||
#expect(r1.sql == "SELECT COUNT(*) FROM users")
|
||||
|
||||
// Turn 2 — references "those" implicitly
|
||||
let r2 = try await engine.send("how many of those are over 30?")
|
||||
#expect(r2.sql == "SELECT COUNT(*) FROM users WHERE age > 30")
|
||||
|
||||
// Verify engine history has all 4 messages (2 user + 2 assistant)
|
||||
let messages = engine.messages
|
||||
#expect(messages.count == 4)
|
||||
#expect(messages[0].role == .user)
|
||||
#expect(messages[0].content == "how many users are there?")
|
||||
#expect(messages[1].role == .assistant)
|
||||
#expect(messages[1].sql == "SELECT COUNT(*) FROM users")
|
||||
#expect(messages[2].role == .user)
|
||||
#expect(messages[2].content == "how many of those are over 30?")
|
||||
#expect(messages[3].role == .assistant)
|
||||
#expect(messages[3].sql == "SELECT COUNT(*) FROM users WHERE age > 30")
|
||||
|
||||
// The second prompt should reference the first query SQL
|
||||
let prompts = mock.capturedPrompts
|
||||
#expect(prompts.count >= 2)
|
||||
let secondPrompt = prompts[1]
|
||||
#expect(secondPrompt.contains("CONVERSATION HISTORY"))
|
||||
#expect(secondPrompt.contains("SELECT COUNT(*) FROM users"))
|
||||
#expect(secondPrompt.contains("how many users are there?"))
|
||||
}
|
||||
|
||||
@Test("Follow-up after aggregate uses prior table context")
|
||||
func followUpAfterAggregateUsesTableContext() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// Turn 1: aggregate (no LLM summary needed)
|
||||
// Turn 2: follow-up referencing "those"
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT AVG(age) FROM users",
|
||||
"SELECT name, age FROM users WHERE age > 35",
|
||||
"Charlie (42) and Eve (55) are older than average."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// Turn 1: average age → aggregate, template summary
|
||||
let r1 = try await engine.send("what is the average age of users?")
|
||||
#expect(r1.sql == "SELECT AVG(age) FROM users")
|
||||
|
||||
// Turn 2: "who is above that?" — needs the avg context
|
||||
let r2 = try await engine.send("who is above average?")
|
||||
#expect(r2.queryResult?.rowCount == 2)
|
||||
|
||||
// Verify context passed
|
||||
let prompts = mock.capturedPrompts
|
||||
let followUp = prompts.first { prompt in
|
||||
prompt.contains("who is above average?") && prompt.contains("CONVERSATION HISTORY")
|
||||
}
|
||||
#expect(followUp != nil)
|
||||
if let prompt = followUp {
|
||||
#expect(prompt.contains("AVG(age)"), "Should include prior aggregate SQL for context")
|
||||
#expect(prompt.contains("users"), "Should include table reference from prior turn")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Context window limits how much history is visible in follow-ups")
|
||||
func contextWindowLimitsHistoryInFollowUps() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// 3 turns, but context window of 2 messages
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT COUNT(*) FROM users",
|
||||
"SELECT COUNT(*) FROM orders",
|
||||
"SELECT COUNT(*) FROM users WHERE age > 30",
|
||||
])
|
||||
|
||||
let config = ChatEngineConfiguration(
|
||||
queryTimeout: nil,
|
||||
contextWindowSize: 2
|
||||
)
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: mock,
|
||||
configuration: config
|
||||
)
|
||||
|
||||
_ = try await engine.send("how many users?")
|
||||
_ = try await engine.send("how many orders?")
|
||||
_ = try await engine.send("how many users over 30?")
|
||||
|
||||
// The third prompt should only have the last 2 messages from turn 2
|
||||
let prompts = mock.capturedPrompts
|
||||
#expect(prompts.count >= 3)
|
||||
|
||||
let thirdPrompt = prompts[2]
|
||||
#expect(thirdPrompt.contains("CONVERSATION HISTORY"))
|
||||
// Turn 2 context should be present
|
||||
#expect(thirdPrompt.contains("how many orders?"))
|
||||
#expect(thirdPrompt.contains("SELECT COUNT(*) FROM orders"))
|
||||
// Turn 1 context should be trimmed (window=2 means last 2 messages)
|
||||
#expect(!thirdPrompt.contains("how many users?\n"), "First turn should be trimmed from context window")
|
||||
}
|
||||
|
||||
@Test("clearHistory resets context so follow-ups have no prior history")
|
||||
func clearHistoryResetsFollowUpContext() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT * FROM users",
|
||||
"Here are the 5 users.",
|
||||
"SELECT COUNT(*) FROM users",
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// Turn 1
|
||||
_ = try await engine.send("show all users")
|
||||
#expect(engine.messages.count == 2)
|
||||
|
||||
// Clear history
|
||||
engine.clearHistory()
|
||||
#expect(engine.messages.isEmpty)
|
||||
|
||||
// Turn 2 after clear — should NOT have conversation history
|
||||
_ = try await engine.send("count all users")
|
||||
|
||||
let prompts = mock.capturedPrompts
|
||||
let lastPrompt = prompts.last!
|
||||
// After clearing, the prompt should NOT contain conversation history
|
||||
#expect(!lastPrompt.contains("CONVERSATION HISTORY"),
|
||||
"After clearHistory(), follow-up should not have prior context")
|
||||
#expect(!lastPrompt.contains("show all users"),
|
||||
"After clearHistory(), prior messages should be gone")
|
||||
}
|
||||
|
||||
@Test("Multi-turn with result data in context enables informed follow-ups")
|
||||
func resultDataInContextEnablesInformedFollowUps() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// Turn 1: list users → multi-row result, LLM summarizes
|
||||
// Turn 2: "sort those by age" → references same table
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT name, age, city FROM users",
|
||||
"Found 5 users: Alice (25, NY), Bob (35, SF), Charlie (42, NY), Diana (28, Chicago), Eve (55, SF).",
|
||||
"SELECT name, age, city FROM users ORDER BY age DESC",
|
||||
"Users sorted by age: Eve (55), Charlie (42), Bob (35), Diana (28), Alice (25)."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
let r1 = try await engine.send("list all users with their age and city")
|
||||
#expect(r1.queryResult?.rowCount == 5)
|
||||
#expect(r1.queryResult?.columns.contains("age") == true)
|
||||
#expect(r1.queryResult?.columns.contains("city") == true)
|
||||
|
||||
let r2 = try await engine.send("sort those by age descending")
|
||||
#expect(r2.sql == "SELECT name, age, city FROM users ORDER BY age DESC")
|
||||
|
||||
// Verify the assistant message in history includes the SQL
|
||||
let messages = engine.messages
|
||||
#expect(messages.count == 4)
|
||||
// First assistant message should have the SQL recorded
|
||||
#expect(messages[1].sql == "SELECT name, age, city FROM users")
|
||||
// Second assistant should have the sorted SQL
|
||||
#expect(messages[3].sql == "SELECT name, age, city FROM users ORDER BY age DESC")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user