Initial implementation of SwiftDBAI

Chat with any SQLite database using natural language. Built on
AnyLanguageModel (HuggingFace) for LLM-agnostic provider support
and GRDB for SQLite access.

Core features:
- Auto schema introspection from sqlite_master (zero config)
- NL → SQL generation via any AnyLanguageModel provider
- Three rendering modes: text summary, data table, Swift Charts
- Drop-in DataChatView (SwiftUI) and headless ChatEngine
- Operation allowlist with read-only default
- Mutation policy with per-table control
- ToolExecutionDelegate for destructive operation confirmation
- Multi-turn conversation context
- 352 tests across 24 suites, all passing

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Krishna Kumar
2026-04-04 09:30:56 -05:00
commit b1724fe7ca
55 changed files with 15506 additions and 0 deletions

View File

@@ -0,0 +1,254 @@
// BinarySizeTests.swift
// SwiftDBAI
//
// Validates that the SwiftDBAI package stays within its 2 MB binary size budget.
// This test suite uses source-level heuristics since we can't measure the actual
// compiled binary size in a unit test. The constraints ensure the package remains
// lightweight by checking:
// 1. Total source code size (proxy for compiled size)
// 2. No embedded binary assets or large resources
// 3. No unnecessary heavy dependencies
// 4. File count stays reasonable (no code bloat)
import Foundation
import Testing
@Suite("Binary Size Budget")
struct BinarySizeTests {
/// The maximum allowed total source code size in bytes.
/// At typical Swift optimized compilation ratios (2-4x), 500 KB of source
/// compiles to roughly 1-2 MB of binary. We set the source budget at 500 KB
/// to keep the compiled output well under 2 MB.
private static let maxSourceSizeBytes: Int = 500_000 // 500 KB
/// Maximum number of Swift source files allowed.
/// More files generally means more code and larger binaries.
private static let maxSourceFileCount: Int = 60
/// Maximum size for any single source file in bytes.
/// Large individual files often indicate code that should be split or
/// contains embedded data that bloats the binary.
private static let maxSingleFileSizeBytes: Int = 50_000 // 50 KB
/// Disallowed file extensions in the Sources directory that would bloat the binary.
private static let disallowedExtensions: Set<String> = [
"png", "jpg", "jpeg", "gif", "bmp", "tiff",
"mp3", "mp4", "wav", "mov",
"mlmodel", "mlmodelc", "mlpackage",
"sqlite", "db",
"zip", "tar", "gz",
"bin", "dat",
"framework", "dylib", "a"
]
// MARK: - Helper
/// Recursively finds all files in the Sources/SwiftDBAI directory.
private func findSourceFiles() throws -> [URL] {
let sourcesDir = findSourcesDirectory()
guard let sourcesDir else {
Issue.record("Could not locate Sources/SwiftDBAI directory")
return []
}
let fileManager = FileManager.default
guard let enumerator = fileManager.enumerator(
at: sourcesDir,
includingPropertiesForKeys: [.fileSizeKey, .isRegularFileKey],
options: [.skipsHiddenFiles]
) else {
Issue.record("Could not enumerate Sources/SwiftDBAI directory")
return []
}
var files: [URL] = []
for case let fileURL as URL in enumerator {
let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey])
if resourceValues.isRegularFile == true {
files.append(fileURL)
}
}
return files
}
/// Locates the Sources/SwiftDBAI directory by walking up from the test bundle.
private func findSourcesDirectory() -> URL? {
// Try common locations relative to the build directory
let fileManager = FileManager.default
// In SPM test runs, we can find the package root by checking known paths
var candidateURL = URL(fileURLWithPath: #filePath)
// Walk up from Tests/SwiftDBAITests/BinarySizeTests.swift to package root
for _ in 0..<3 {
candidateURL = candidateURL.deletingLastPathComponent()
}
let sourcesDir = candidateURL.appendingPathComponent("Sources/SwiftDBAI")
if fileManager.fileExists(atPath: sourcesDir.path) {
return sourcesDir
}
// Fallback: check current working directory
let cwdSources = URL(fileURLWithPath: fileManager.currentDirectoryPath)
.appendingPathComponent("Sources/SwiftDBAI")
if fileManager.fileExists(atPath: cwdSources.path) {
return cwdSources
}
return nil
}
// MARK: - Tests
@Test("Total source code size stays under 500 KB budget")
func totalSourceCodeSizeUnderBudget() throws {
let files = try findSourceFiles()
let swiftFiles = files.filter { $0.pathExtension == "swift" }
var totalSize: Int = 0
for file in swiftFiles {
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
let fileSize = attributes[.size] as? Int ?? 0
totalSize += fileSize
}
#expect(totalSize < Self.maxSourceSizeBytes,
"""
Total Swift source size (\(totalSize) bytes) exceeds \(Self.maxSourceSizeBytes) byte budget.
At typical 2-4x compilation ratio, this would produce a binary larger than 2 MB.
Consider removing unused code or splitting into optional sub-targets.
""")
// Log the actual size for visibility
let sizeKB = Double(totalSize) / 1024.0
let budgetKB = Double(Self.maxSourceSizeBytes) / 1024.0
print("📦 SwiftDBAI source size: \(String(format: "%.1f", sizeKB)) KB / \(String(format: "%.0f", budgetKB)) KB budget (\(String(format: "%.0f", (sizeKB / budgetKB) * 100))% used)")
}
@Test("Source file count stays reasonable")
func sourceFileCountUnderLimit() throws {
let files = try findSourceFiles()
let swiftFiles = files.filter { $0.pathExtension == "swift" }
#expect(swiftFiles.count <= Self.maxSourceFileCount,
"""
Swift source file count (\(swiftFiles.count)) exceeds limit of \(Self.maxSourceFileCount).
More files generally means more code and larger binaries.
""")
print("📦 SwiftDBAI file count: \(swiftFiles.count) / \(Self.maxSourceFileCount) max")
}
@Test("No individual source file exceeds 50 KB")
func noOversizedSourceFiles() throws {
let files = try findSourceFiles()
let swiftFiles = files.filter { $0.pathExtension == "swift" }
for file in swiftFiles {
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
let fileSize = attributes[.size] as? Int ?? 0
#expect(fileSize < Self.maxSingleFileSizeBytes,
"""
File \(file.lastPathComponent) is \(fileSize) bytes, exceeding the \(Self.maxSingleFileSizeBytes) byte limit.
Large files may contain embedded data or code that should be split.
""")
}
}
@Test("No binary assets or heavy resources in Sources directory")
func noBinaryAssetsInSources() throws {
let files = try findSourceFiles()
let disallowedFiles = files.filter { file in
Self.disallowedExtensions.contains(file.pathExtension.lowercased())
}
#expect(disallowedFiles.isEmpty,
"""
Found \(disallowedFiles.count) disallowed file(s) in Sources directory:
\(disallowedFiles.map(\.lastPathComponent).joined(separator: "\n"))
These file types bloat the binary. Remove them or move to a separate resource bundle.
""")
}
@Test("Package has no resource bundles that could bloat binary")
func noResourceBundles() throws {
let files = try findSourceFiles()
let resourceFiles = files.filter { file in
let ext = file.pathExtension.lowercased()
return ["xcassets", "storyboard", "xib", "nib", "xcdatamodeld"].contains(ext)
}
#expect(resourceFiles.isEmpty,
"""
Found resource bundle files that could bloat the binary:
\(resourceFiles.map(\.lastPathComponent).joined(separator: "\n"))
SwiftDBAI should be pure code — no bundled resources.
""")
}
@Test("Only expected dependencies declared (GRDB + AnyLanguageModel)")
func minimalDependencies() throws {
// Read Package.swift to verify we only have the expected dependencies
var packageURL = URL(fileURLWithPath: #filePath)
for _ in 0..<3 {
packageURL = packageURL.deletingLastPathComponent()
}
let packageSwiftURL = packageURL.appendingPathComponent("Package.swift")
guard FileManager.default.fileExists(atPath: packageSwiftURL.path) else {
// Skip if we can't find Package.swift (CI environments etc.)
return
}
let packageContents = try String(contentsOf: packageSwiftURL, encoding: .utf8)
// Count .package() declarations (dependencies)
let packageDeclarations = packageContents.components(separatedBy: ".package(")
.count - 1 // subtract 1 because the first segment is before any .package(
#expect(packageDeclarations <= 2,
"""
Found \(packageDeclarations) package dependencies, expected at most 2 (GRDB + AnyLanguageModel).
Additional dependencies increase binary size. Evaluate if they're truly needed.
""")
// Verify the expected dependencies are present
#expect(packageContents.contains("GRDB"), "Expected GRDB dependency")
#expect(packageContents.contains("AnyLanguageModel"), "Expected AnyLanguageModel dependency")
print("📦 SwiftDBAI dependencies: \(packageDeclarations) (GRDB + AnyLanguageModel)")
}
@Test("Estimated binary size under 2 MB")
func estimatedBinarySizeUnderLimit() throws {
let files = try findSourceFiles()
let swiftFiles = files.filter { $0.pathExtension == "swift" }
var totalSize: Int = 0
for file in swiftFiles {
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
let fileSize = attributes[.size] as? Int ?? 0
totalSize += fileSize
}
// Conservative estimate: optimized Swift binary is typically 2-4x source size.
// Use 4x as worst case multiplier for safety margin.
let worstCaseMultiplier = 4.0
let estimatedBinarySize = Double(totalSize) * worstCaseMultiplier
let maxBinarySize: Double = 2.0 * 1024.0 * 1024.0 // 2 MB
#expect(estimatedBinarySize < maxBinarySize,
"""
Estimated binary size (\(String(format: "%.1f", estimatedBinarySize / 1024.0)) KB) exceeds 2 MB limit.
Source: \(totalSize) bytes × \(worstCaseMultiplier)x multiplier = \(String(format: "%.1f", estimatedBinarySize / 1024.0)) KB
Note: This is the SwiftDBAI module only — excludes GRDB and AnyLanguageModel
which are existing dependencies the developer already includes.
""")
let estimatedMB = estimatedBinarySize / (1024.0 * 1024.0)
print("📦 Estimated SwiftDBAI binary size: \(String(format: "%.2f", estimatedMB)) MB / 2.00 MB limit (worst case \(worstCaseMultiplier)x)")
}
}

View File

@@ -0,0 +1,293 @@
// ChartDataDetectorTests.swift
// SwiftDBAITests
import Testing
@testable import SwiftDBAI
@Suite("ChartDataDetector")
struct ChartDataDetectorTests {
let detector = ChartDataDetector()
// MARK: - Helpers
private func makeQueryResult(
columns: [String],
rows: [[QueryResult.Value]],
sql: String = "SELECT *"
) -> QueryResult {
let rowDicts = rows.map { values in
Dictionary(uniqueKeysWithValues: zip(columns, values))
}
return QueryResult(
columns: columns,
rows: rowDicts,
sql: sql,
executionTime: 0.01
)
}
private func makeTable(
columns: [String],
rows: [[QueryResult.Value]],
sql: String = "SELECT *"
) -> DataTable {
DataTable(makeQueryResult(columns: columns, rows: rows, sql: sql))
}
// MARK: - Basic Eligibility
@Test("Returns nil for single-column results")
func singleColumn() {
let table = makeTable(
columns: ["count"],
rows: [[.integer(42)]]
)
#expect(detector.detect(table) == nil)
}
@Test("Returns nil for empty results")
func emptyResults() {
let table = makeTable(columns: ["name", "value"], rows: [])
#expect(detector.detect(table) == nil)
}
@Test("Returns nil for single row")
func singleRow() {
let table = makeTable(
columns: ["name", "count"],
rows: [[.text("A"), .integer(10)]]
)
#expect(detector.detect(table) == nil)
}
@Test("Returns nil for too many rows")
func tooManyRows() {
let rows = (0..<101).map { i in
[QueryResult.Value.text("cat\(i)"), .integer(Int64(i))]
}
let table = makeTable(columns: ["name", "count"], rows: rows)
#expect(detector.detect(table) == nil)
}
// MARK: - Bar Chart Detection
@Test("Recommends bar chart for categorical text + numeric")
func barChartCategorical() {
let table = makeTable(
columns: ["department", "headcount"],
rows: [
[.text("Engineering"), .integer(45)],
[.text("Marketing"), .integer(20)],
[.text("Sales"), .integer(30)],
[.text("HR"), .integer(10)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType == .bar)
#expect(rec?.categoryColumn == "department")
#expect(rec?.valueColumn == "headcount")
#expect(rec?.confidence ?? 0 > 0.5)
}
// MARK: - Pie Chart Detection
@Test("Recommends pie chart for small positive proportions")
func pieChartSmallCategories() {
let table = makeTable(
columns: ["status", "count"],
rows: [
[.text("Active"), .integer(50)],
[.text("Inactive"), .integer(30)],
[.text("Pending"), .integer(20)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType == .pie)
#expect(rec?.categoryColumn == "status")
#expect(rec?.valueColumn == "count")
}
@Test("Does not recommend pie with negative values")
func pieRejectsNegative() {
let table = makeTable(
columns: ["category", "change"],
rows: [
[.text("A"), .integer(50)],
[.text("B"), .integer(-10)],
[.text("C"), .integer(20)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
// Should NOT be pie since there's a negative value
#expect(rec?.chartType != .pie)
}
@Test("Does not recommend pie with too many slices")
func pieRejectsTooManySlices() {
let rows = (0..<10).map { i in
[QueryResult.Value.text("cat\(i)"), .integer(Int64(i + 1))]
}
let table = makeTable(columns: ["category", "value"], rows: rows)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType != .pie)
}
// MARK: - Line Chart Detection
@Test("Recommends line chart for time-series column names")
func lineChartTimeSeries() {
let table = makeTable(
columns: ["year", "revenue"],
rows: [
[.text("2020"), .real(1_000_000)],
[.text("2021"), .real(1_200_000)],
[.text("2022"), .real(1_500_000)],
[.text("2023"), .real(1_800_000)],
[.text("2024"), .real(2_100_000)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType == .line)
#expect(rec?.categoryColumn == "year")
#expect(rec?.valueColumn == "revenue")
}
@Test("Recommends line chart for date-formatted text values")
func lineChartDateValues() {
let table = makeTable(
columns: ["period", "sales"],
rows: [
[.text("2024-01"), .integer(100)],
[.text("2024-02"), .integer(120)],
[.text("2024-03"), .integer(90)],
[.text("2024-04"), .integer(150)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType == .line)
}
@Test("Recommends line chart for sequential numeric x-axis")
func lineChartSequential() {
let table = makeTable(
columns: ["step", "value"],
rows: [
[.integer(1), .real(2.5)],
[.integer(2), .real(3.1)],
[.integer(3), .real(4.0)],
[.integer(4), .real(3.8)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType == .line)
}
// MARK: - All Recommendations
@Test("Returns multiple recommendations sorted by confidence")
func allRecommendations() {
let table = makeTable(
columns: ["category", "amount"],
rows: [
[.text("A"), .integer(30)],
[.text("B"), .integer(50)],
[.text("C"), .integer(20)],
]
)
let recs = detector.allRecommendations(for: table)
#expect(!recs.isEmpty)
// Should be sorted by confidence descending
for i in 1..<recs.count {
#expect(recs[i - 1].confidence >= recs[i].confidence)
}
}
// MARK: - Two Numeric Columns Fallback
@Test("Uses first numeric as category when no text column exists")
func numericOnlyColumns() {
let table = makeTable(
columns: ["x", "y"],
rows: [
[.integer(1), .integer(10)],
[.integer(2), .integer(20)],
[.integer(3), .integer(30)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.categoryColumn == "x")
#expect(rec?.valueColumn == "y")
}
// MARK: - Confidence & Reason
@Test("Confidence is between 0 and 1")
func confidenceBounds() {
let table = makeTable(
columns: ["name", "score"],
rows: [
[.text("A"), .integer(10)],
[.text("B"), .integer(20)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec!.confidence >= 0.0)
#expect(rec!.confidence <= 1.0)
}
@Test("Reason is non-empty")
func reasonPresent() {
let table = makeTable(
columns: ["name", "score"],
rows: [
[.text("A"), .integer(10)],
[.text("B"), .integer(20)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(!rec!.reason.isEmpty)
}
// MARK: - Custom Configuration
@Test("Respects custom minimumRows")
func customMinRows() {
let strict = ChartDataDetector(minimumRows: 5)
let table = makeTable(
columns: ["name", "value"],
rows: [
[.text("A"), .integer(1)],
[.text("B"), .integer(2)],
[.text("C"), .integer(3)],
]
)
#expect(strict.detect(table) == nil)
}
@Test("Respects custom maxPieSlices")
func customMaxPieSlices() {
let narrow = ChartDataDetector(maxPieSlices: 2)
let table = makeTable(
columns: ["status", "count"],
rows: [
[.text("A"), .integer(50)],
[.text("B"), .integer(30)],
[.text("C"), .integer(20)],
]
)
let rec = narrow.detect(table)
// With maxPieSlices=2, 3 rows should not get pie
#expect(rec?.chartType != .pie)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,164 @@
// ChatViewTests.swift
// SwiftDBAITests
//
// Tests for ChatView, ChatViewModel, and MessageBubbleView integration
// with ScrollableDataTableView.
import Testing
import Foundation
@testable import SwiftDBAI
@Suite("SchemaReadiness Tests")
struct SchemaReadinessTests {
@Test("SchemaReadiness isReady returns true only for ready state")
func isReadyProperty() {
#expect(SchemaReadiness.idle.isReady == false)
#expect(SchemaReadiness.loading.isReady == false)
#expect(SchemaReadiness.ready(tableCount: 3).isReady == true)
#expect(SchemaReadiness.failed("error").isReady == false)
}
}
@Suite("ChatViewModel Tests")
struct ChatViewModelTests {
@Test("Messages with query results produce DataTable-compatible data")
func messageWithQueryResultHasTableData() {
// A ChatMessage with a queryResult should have the data needed
// for ScrollableDataTableView rendering
let result = QueryResult(
columns: ["id", "name", "score"],
rows: [
["id": .integer(1), "name": .text("Alice"), "score": .real(95.5)],
["id": .integer(2), "name": .text("Bob"), "score": .real(87.3)],
],
sql: "SELECT id, name, score FROM users",
executionTime: 0.01
)
let message = ChatMessage(
role: .assistant,
content: "Found 2 users.",
queryResult: result,
sql: "SELECT id, name, score FROM users"
)
// Verify queryResult is present and can be converted to DataTable
#expect(message.queryResult != nil)
#expect(message.queryResult!.columns.count == 3)
#expect(message.queryResult!.rows.count == 2)
// Verify DataTable conversion works (this is what MessageBubbleView does)
let dataTable = DataTable(message.queryResult!)
#expect(dataTable.columnCount == 3)
#expect(dataTable.rowCount == 2)
#expect(dataTable.columns[0].name == "id")
#expect(dataTable.columns[1].name == "name")
#expect(dataTable.columns[2].name == "score")
}
@Test("Messages without query results do not trigger table rendering")
func messageWithoutQueryResult() {
let message = ChatMessage(
role: .assistant,
content: "Hello! How can I help?",
queryResult: nil,
sql: nil
)
#expect(message.queryResult == nil)
}
@Test("Empty query results do not trigger table rendering")
func emptyQueryResult() {
let result = QueryResult(
columns: [],
rows: [],
sql: "SELECT * FROM empty_table",
executionTime: 0.001
)
let message = ChatMessage(
role: .assistant,
content: "No results found.",
queryResult: result,
sql: "SELECT * FROM empty_table"
)
// Even though queryResult exists, it has no columns/rows
// MessageBubbleView checks both conditions before showing the table
#expect(message.queryResult != nil)
#expect(message.queryResult!.columns.isEmpty)
#expect(message.queryResult!.rows.isEmpty)
}
@Test("Mutation results do not trigger table rendering")
func mutationQueryResult() {
let result = QueryResult(
columns: [],
rows: [],
sql: "INSERT INTO users (name) VALUES ('Charlie')",
executionTime: 0.005,
rowsAffected: 1
)
let message = ChatMessage(
role: .assistant,
content: "Successfully inserted 1 row.",
queryResult: result,
sql: "INSERT INTO users (name) VALUES ('Charlie')"
)
// Mutation results have empty columns no table shown
#expect(message.queryResult!.columns.isEmpty)
}
@Test("Error messages never have query results")
func errorMessageHasNoQueryResult() {
let message = ChatMessage(
role: .error,
content: "SELECT operations are not allowed."
)
#expect(message.queryResult == nil)
#expect(message.role == .error)
}
@Test("DataTable preserves column order from QueryResult")
func dataTableColumnOrder() {
let result = QueryResult(
columns: ["date", "revenue", "category"],
rows: [
["date": .text("2024-01-01"), "revenue": .real(1500.0), "category": .text("Electronics")],
],
sql: "SELECT date, revenue, category FROM sales",
executionTime: 0.02
)
let dataTable = DataTable(result)
#expect(dataTable.columnNames == ["date", "revenue", "category"])
}
@Test("Large result sets are renderable as DataTable")
func largeResultSet() {
var rows: [[String: QueryResult.Value]] = []
for i in 0..<500 {
rows.append([
"id": .integer(Int64(i)),
"value": .real(Double(i) * 1.5),
])
}
let result = QueryResult(
columns: ["id", "value"],
rows: rows,
sql: "SELECT id, value FROM big_table",
executionTime: 0.15
)
let dataTable = DataTable(result)
#expect(dataTable.rowCount == 500)
#expect(dataTable.columnCount == 2)
}
}

View File

@@ -0,0 +1,136 @@
// DataChatViewUsageTests.swift
// SwiftDBAITests
//
// Proves DataChatView works with minimal setup under 10 lines of code.
// A developer only needs a GRDB connection and a LanguageModel to get a
// full chat-with-database SwiftUI view.
import Testing
import Foundation
import GRDB
@testable import SwiftDBAI
// MARK: - Minimal Setup: DataChatView in Under 10 Lines
/// This test suite proves the "zero_config_reads" principle:
/// A developer with an existing SQLite database can create a fully functional
/// chat UI by providing only a GRDB connection and a language model instance.
/// No schema files, no annotations, no manual configuration required.
@Suite("DataChatView Minimal Setup")
struct DataChatViewMinimalSetupTests {
//
// USAGE EXAMPLE DataChatView in 6 lines of real code
//
// import SwiftDBAI
// import GRDB
//
// let db = try DatabaseQueue(path: "mydata.sqlite")
// let model = OllamaLanguageModel(model: "llama3")
//
// var body: some View {
// DataChatView(database: db, model: model)
// }
//
/// Creates a temporary in-memory database with sample data for tests.
private static func makeSampleDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue()
try db.write { db in
try db.execute(sql: """
CREATE TABLE products (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
price REAL NOT NULL,
category TEXT
);
INSERT INTO products (name, price, category) VALUES ('Widget', 9.99, 'Hardware');
INSERT INTO products (name, price, category) VALUES ('Gadget', 24.99, 'Electronics');
INSERT INTO products (name, price, category) VALUES ('Doohickey', 4.99, 'Hardware');
""")
}
return db
}
@Test("DataChatView initializes from database + model in 2 lines")
@MainActor
func dataChatViewMinimalInit() throws {
// LINE 1: Create (or receive) a GRDB connection
let db = try Self.makeSampleDatabase()
// LINE 2: Create the view that's it!
let _ = DataChatView(database: db, model: MockLanguageModel())
// The view is ready. No schema files, no annotations, no extra config.
}
@Test("DataChatView path-based init works in 1 line given a path and model")
@MainActor
func dataChatViewPathInit() throws {
// Create a temp database file
let tempDir = FileManager.default.temporaryDirectory
let dbPath = tempDir.appendingPathComponent("test_\(UUID().uuidString).sqlite").path
let db = try DatabaseQueue(path: dbPath)
try db.write { db in
try db.execute(sql: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
}
// ONE LINE to get a full chat UI:
let _ = DataChatView(databasePath: dbPath, model: MockLanguageModel())
// Cleanup
try? FileManager.default.removeItem(atPath: dbPath)
}
@Test("ChatEngine headless usage works in 3 lines")
func chatEngineMinimalUsage() async throws {
// LINE 1: Database
let db = try Self.makeSampleDatabase()
// LINE 2: Engine
let engine = ChatEngine(database: db, model: MockLanguageModel(responseText: "SELECT COUNT(*) AS total FROM products"))
// LINE 3: Schema preparation verifies auto-introspection works
let schema = try await engine.prepareSchema()
// The engine auto-discovered the schema no manual config needed
#expect(schema.tableNames.contains("products"))
#expect(schema.tableNames.count == 1)
}
@Test("ChatViewModel works with zero configuration beyond db + model")
@MainActor
func chatViewModelMinimalUsage() async throws {
let db = try Self.makeSampleDatabase()
let engine = ChatEngine(database: db, model: MockLanguageModel())
let viewModel = ChatViewModel(engine: engine)
// Prepare triggers auto-schema-introspection
await viewModel.prepare()
#expect(viewModel.schemaReadiness.isReady)
#expect(viewModel.messages.isEmpty) // Clean slate, ready to chat
}
@Test("Default configuration is read-only (safe by default)")
@MainActor
func defaultIsReadOnly() throws {
let db = try Self.makeSampleDatabase()
// No allowlist specified defaults to .readOnly
let _ = DataChatView(database: db, model: MockLanguageModel())
// This compiles and works. SELECT-only is the safe default.
// Developer must explicitly opt in to writes:
// DataChatView(database: db, model: model, allowlist: .standard)
}
@Test("Full DataChatView with all options still under 10 lines")
@MainActor
func dataChatViewFullConfig() throws {
let db = try Self.makeSampleDatabase() // 1
let model = MockLanguageModel() // 2
let _ = DataChatView( // 3-8
database: db,
model: model,
allowlist: .readOnly,
additionalContext: "Product catalog for an e-commerce store",
maxSummaryRows: 100
)
// Even with ALL options specified, it's under 10 lines of setup.
}
}

View File

@@ -0,0 +1,285 @@
// DataTableTests.swift
// SwiftDBAITests
import Foundation
import Testing
@testable import SwiftDBAI
@Suite("DataTable")
struct DataTableTests {
// MARK: - Helpers
private func makeQueryResult(
columns: [String],
rows: [[String: QueryResult.Value]],
sql: String = "SELECT * FROM test",
executionTime: TimeInterval = 0.01
) -> QueryResult {
QueryResult(
columns: columns,
rows: rows,
sql: sql,
executionTime: executionTime
)
}
// MARK: - Basic Construction
@Test("Converts QueryResult columns and rows correctly")
func basicConversion() {
let result = makeQueryResult(
columns: ["id", "name", "score"],
rows: [
["id": .integer(1), "name": .text("Alice"), "score": .real(95.5)],
["id": .integer(2), "name": .text("Bob"), "score": .real(87.0)],
]
)
let table = DataTable(result)
#expect(table.columnCount == 3)
#expect(table.rowCount == 2)
#expect(table.columnNames == ["id", "name", "score"])
#expect(table.sql == "SELECT * FROM test")
#expect(table.executionTime == 0.01)
}
@Test("Empty result produces empty table")
func emptyResult() {
let result = makeQueryResult(columns: ["id", "name"], rows: [])
let table = DataTable(result)
#expect(table.isEmpty)
#expect(table.rowCount == 0)
#expect(table.columnCount == 2)
#expect(table.columnNames == ["id", "name"])
}
// MARK: - Subscript Access
@Test("Subscript by row and column index")
func subscriptByIndex() {
let result = makeQueryResult(
columns: ["a", "b"],
rows: [
["a": .integer(10), "b": .text("hello")],
["a": .integer(20), "b": .text("world")],
]
)
let table = DataTable(result)
#expect(table[row: 0, column: 0] == .integer(10))
#expect(table[row: 0, column: 1] == .text("hello"))
#expect(table[row: 1, column: 0] == .integer(20))
#expect(table[row: 1, column: 1] == .text("world"))
}
@Test("Subscript by row index and column name")
func subscriptByName() {
let result = makeQueryResult(
columns: ["x", "y"],
rows: [["x": .real(1.5), "y": .real(2.5)]]
)
let table = DataTable(result)
#expect(table[row: 0, column: "x"] == .real(1.5))
#expect(table[row: 0, column: "y"] == .real(2.5))
#expect(table[row: 0, column: "z"] == .null) // non-existent column
}
// MARK: - Column Data Extraction
@Test("Extract column values by index")
func columnValuesByIndex() {
let result = makeQueryResult(
columns: ["val"],
rows: [
["val": .integer(1)],
["val": .integer(2)],
["val": .integer(3)],
]
)
let table = DataTable(result)
let values = table.columnValues(at: 0)
#expect(values == [.integer(1), .integer(2), .integer(3)])
}
@Test("Extract column values by name")
func columnValuesByName() {
let result = makeQueryResult(
columns: ["name"],
rows: [
["name": .text("A")],
["name": .text("B")],
]
)
let table = DataTable(result)
#expect(table.columnValues(named: "name") == [.text("A"), .text("B")])
#expect(table.columnValues(named: "missing").isEmpty)
}
@Test("numericValues extracts doubles from numeric column")
func numericValues() {
let result = makeQueryResult(
columns: ["score"],
rows: [
["score": .integer(10)],
["score": .real(20.5)],
["score": .null],
["score": .text("not a number")],
]
)
let table = DataTable(result)
let nums = table.numericValues(forColumn: "score")
#expect(nums.count == 2)
#expect(nums[0] == 10.0)
#expect(nums[1] == 20.5)
}
@Test("stringValues extracts non-null strings")
func stringValues() {
let result = makeQueryResult(
columns: ["label"],
rows: [
["label": .text("foo")],
["label": .null],
["label": .text("bar")],
]
)
let table = DataTable(result)
let strs = table.stringValues(forColumn: "label")
#expect(strs == ["foo", "bar"])
}
// MARK: - Type Inference
@Test("Infers integer type for all-integer column")
func inferInteger() {
let result = makeQueryResult(
columns: ["id"],
rows: [["id": .integer(1)], ["id": .integer(2)]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .integer)
}
@Test("Infers real type for all-real column")
func inferReal() {
let result = makeQueryResult(
columns: ["price"],
rows: [["price": .real(1.99)], ["price": .real(2.50)]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .real)
}
@Test("Infers text type for all-text column")
func inferText() {
let result = makeQueryResult(
columns: ["name"],
rows: [["name": .text("A")], ["name": .text("B")]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .text)
}
@Test("Promotes integer + real to real")
func inferNumericPromotion() {
let result = makeQueryResult(
columns: ["val"],
rows: [["val": .integer(1)], ["val": .real(2.5)]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .real)
}
@Test("Mixed types result in .mixed")
func inferMixed() {
let result = makeQueryResult(
columns: ["data"],
rows: [["data": .integer(1)], ["data": .text("hello")]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .mixed)
}
@Test("All-null column infers .null")
func inferNull() {
let result = makeQueryResult(
columns: ["empty"],
rows: [["empty": .null], ["empty": .null]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .null)
}
@Test("Null values are ignored during type inference")
func inferIgnoresNulls() {
let result = makeQueryResult(
columns: ["val"],
rows: [["val": .integer(1)], ["val": .null], ["val": .integer(3)]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .integer)
}
// MARK: - Missing Values
@Test("Missing dictionary keys become .null")
func missingKeysBecomNull() {
let result = makeQueryResult(
columns: ["a", "b"],
rows: [["a": .integer(1)]] // "b" is missing
)
let table = DataTable(result)
#expect(table[row: 0, column: 0] == .integer(1))
#expect(table[row: 0, column: 1] == .null)
}
// MARK: - Row Identity
@Test("Rows have sequential IDs")
func rowIdentity() {
let result = makeQueryResult(
columns: ["x"],
rows: [["x": .integer(1)], ["x": .integer(2)], ["x": .integer(3)]]
)
let table = DataTable(result)
#expect(table.rows[0].id == 0)
#expect(table.rows[1].id == 1)
#expect(table.rows[2].id == 2)
}
// MARK: - Column Identity
@Test("Columns are Identifiable by name")
func columnIdentity() {
let result = makeQueryResult(
columns: ["alpha", "beta"],
rows: [["alpha": .integer(1), "beta": .integer(2)]]
)
let table = DataTable(result)
#expect(table.columns[0].id == "alpha")
#expect(table.columns[1].id == "beta")
#expect(table.columns[0].index == 0)
#expect(table.columns[1].index == 1)
}
}

View File

@@ -0,0 +1,745 @@
// DestructiveOperationTests.swift
// SwiftDBAITests
//
// Tests verifying that destructive operations are blocked without confirmation
// and allowed when the delegate approves.
import AnyLanguageModel
import Foundation
import GRDB
import Testing
@testable import SwiftDBAI
// MARK: - Test Delegates
/// A delegate that always rejects destructive operations and tracks calls.
private final class RejectingTrackingDelegate: SwiftDBAI.ToolExecutionDelegate, @unchecked Sendable {
private let lock = NSLock()
private var _confirmCalls: [DestructiveOperationContext] = []
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
var confirmCalls: [DestructiveOperationContext] {
lock.withLock { _confirmCalls }
}
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
lock.withLock { _willExecuteCalls }
}
var didExecuteCalls: [(sql: String, success: Bool)] {
lock.withLock { _didExecuteCalls }
}
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
lock.withLock { _confirmCalls.append(context) }
return false
}
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
}
func didExecuteSQL(_ sql: String, success: Bool) async {
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
}
}
/// A delegate that always approves destructive operations and tracks calls.
private final class ApprovingTrackingDelegate: SwiftDBAI.ToolExecutionDelegate, @unchecked Sendable {
private let lock = NSLock()
private var _confirmCalls: [DestructiveOperationContext] = []
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
var confirmCalls: [DestructiveOperationContext] {
lock.withLock { _confirmCalls }
}
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
lock.withLock { _willExecuteCalls }
}
var didExecuteCalls: [(sql: String, success: Bool)] {
lock.withLock { _didExecuteCalls }
}
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
lock.withLock { _confirmCalls.append(context) }
return true
}
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
}
func didExecuteSQL(_ sql: String, success: Bool) async {
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
}
}
// MARK: - Helpers
/// Creates an in-memory database with test data for destructive operation tests.
/// Users 1 and 2 have orders; user 3 has no orders (safe to delete).
private func makeTestDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue(path: ":memory:")
try db.write { db in
// Disable FK enforcement for test flexibility, then re-enable
try db.execute(sql: "PRAGMA foreign_keys = OFF")
try db.execute(sql: """
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT NOT NULL
)
""")
try db.execute(sql: """
INSERT INTO users (name, email) VALUES
('Alice', 'alice@example.com'),
('Bob', 'bob@example.com'),
('Charlie', 'charlie@example.com')
""")
try db.execute(sql: """
CREATE TABLE orders (
id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
amount REAL NOT NULL
)
""")
try db.execute(sql: """
INSERT INTO orders (user_id, amount) VALUES
(1, 99.99),
(2, 150.00),
(3, 25.50)
""")
}
return db
}
/// A sequential mock model for tests. Returns responses in order.
private struct TestSequentialModel: LanguageModel {
typealias UnavailableReason = Never
let responses: [String]
private let callCounter = CallCounter()
private final class CallCounter: @unchecked Sendable {
var count = 0
let lock = NSLock()
func next() -> Int {
lock.lock()
defer { lock.unlock() }
let c = count
count += 1
return c
}
}
init(responses: [String]) {
self.responses = responses
}
func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
let idx = callCounter.next()
let text = idx < responses.count ? responses[idx] : "fallback response"
let rawContent = GeneratedContent(kind: .string(text))
let content = try Content(rawContent)
return LanguageModelSession.Response(
content: content,
rawContent: rawContent,
transcriptEntries: [][...]
)
}
func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
let idx = callCounter.next()
let text = idx < responses.count ? responses[idx] : "fallback response"
let rawContent = GeneratedContent(kind: .string(text))
let content = try! Content(rawContent)
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
}
}
// MARK: - Tests: Destructive Operations Blocked Without Confirmation
@Suite("Destructive Operations - Blocked Without Confirmation")
struct DestructiveOperationsBlockedTests {
@Test("DELETE is blocked when no delegate is provided")
func deleteBlockedWithoutDelegate() async throws {
let db = try makeTestDatabase()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 1"
])
// Unrestricted allowlist permits DELETE, but no delegate to confirm
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted
)
do {
_ = try await engine.send("Delete user 1")
Issue.record("Expected confirmationRequired error but send succeeded")
} catch let error as SwiftDBAIError {
guard case .confirmationRequired(let sql, let operation) = error else {
Issue.record("Expected confirmationRequired, got: \(error)")
return
}
#expect(sql.uppercased().contains("DELETE"))
#expect(operation == "delete")
}
// Verify the user was NOT deleted (data remains intact)
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
}
#expect(count == 1, "User should NOT have been deleted")
}
@Test("DELETE is blocked when delegate rejects")
func deleteBlockedWhenDelegateRejects() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 2"
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
do {
_ = try await engine.send("Delete user 2")
Issue.record("Expected confirmationRequired error but send succeeded")
} catch let error as SwiftDBAIError {
guard case .confirmationRequired(let sql, let operation) = error else {
Issue.record("Expected confirmationRequired, got: \(error)")
return
}
#expect(sql.uppercased().contains("DELETE"))
#expect(operation == "delete")
}
// Verify delegate was consulted
#expect(delegate.confirmCalls.count == 1)
#expect(delegate.confirmCalls[0].statementKind == .delete)
#expect(delegate.confirmCalls[0].sql.uppercased().contains("DELETE"))
// Verify the data was NOT modified
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 2")
}
#expect(count == 1, "User should NOT have been deleted")
// Verify no SQL was actually executed (no willExecute/didExecute calls)
#expect(delegate.willExecuteCalls.isEmpty, "No SQL should have been executed")
#expect(delegate.didExecuteCalls.isEmpty, "No SQL should have been executed")
}
@Test("DELETE is blocked with MutationPolicy and no delegate")
func deleteBlockedWithMutationPolicyNoDelegate() async throws {
let db = try makeTestDatabase()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 3"
])
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
requiresDestructiveConfirmation: true
)
let engine = ChatEngine(
database: db,
model: model,
mutationPolicy: policy
)
do {
_ = try await engine.send("Delete user 3")
Issue.record("Expected confirmationRequired error but send succeeded")
} catch let error as SwiftDBAIError {
guard case .confirmationRequired(let sql, let operation) = error else {
Issue.record("Expected confirmationRequired, got: \(error)")
return
}
#expect(sql.uppercased().contains("DELETE"))
#expect(operation == "delete")
}
// Data intact
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 3")
}
#expect(count == 1, "User should NOT have been deleted")
}
@Test("DELETE is blocked with MutationPolicy and rejecting delegate")
func deleteBlockedWithMutationPolicyRejectingDelegate() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM orders WHERE user_id = 1"
])
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
requiresDestructiveConfirmation: true
)
let engine = ChatEngine(
database: db,
model: model,
mutationPolicy: policy,
delegate: delegate
)
do {
_ = try await engine.send("Delete all orders for user 1")
Issue.record("Expected confirmationRequired error but send succeeded")
} catch let error as SwiftDBAIError {
guard case .confirmationRequired = error else {
Issue.record("Expected confirmationRequired, got: \(error)")
return
}
}
// Delegate was consulted and rejected
#expect(delegate.confirmCalls.count == 1)
#expect(delegate.confirmCalls[0].statementKind == .delete)
// Orders remain
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM orders WHERE user_id = 1")
}
#expect(count == 1, "Orders should NOT have been deleted")
}
@Test("Default delegate implementation rejects destructive operations")
func defaultDelegateRejectsDestructive() async {
struct DefaultDelegate: SwiftDBAI.ToolExecutionDelegate {}
let delegate = DefaultDelegate()
let context = DestructiveOperationContext(
sql: "DELETE FROM users WHERE id = 1",
statementKind: .delete,
classification: .destructive(.delete),
description: "Delete from users"
)
let approved = await delegate.confirmDestructiveOperation(context)
#expect(approved == false, "Default delegate should reject destructive operations")
}
@Test("DELETE not in readOnly allowlist is rejected before delegate is consulted")
func deleteNotInAllowlistRejectedEarly() async throws {
let db = try makeTestDatabase()
let delegate = ApprovingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 1"
])
// Read-only allowlist does NOT include DELETE
let engine = ChatEngine(
database: db,
model: model,
allowlist: .readOnly,
delegate: delegate
)
do {
_ = try await engine.send("Delete user 1")
Issue.record("Expected operationNotAllowed error")
} catch let error as SwiftDBAIError {
guard case .operationNotAllowed(let operation) = error else {
Issue.record("Expected operationNotAllowed, got: \(error)")
return
}
#expect(operation == "delete")
}
// Delegate should NOT have been consulted the allowlist rejects before delegation
#expect(delegate.confirmCalls.isEmpty, "Delegate should not be consulted when op is not in allowlist")
}
}
// MARK: - Tests: Destructive Operations Allowed When Delegate Approves
@Suite("Destructive Operations - Allowed When Delegate Approves")
struct DestructiveOperationsAllowedTests {
@Test("DELETE succeeds when delegate approves")
func deleteSucceedsWithApprovingDelegate() async throws {
let db = try makeTestDatabase()
let delegate = ApprovingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 1",
"Successfully deleted 1 user."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
let response = try await engine.send("Delete user 1")
// Delegate was consulted and approved
#expect(delegate.confirmCalls.count == 1)
#expect(delegate.confirmCalls[0].statementKind == .delete)
#expect(delegate.confirmCalls[0].sql.uppercased().contains("DELETE"))
#expect(delegate.confirmCalls[0].targetTable == "users")
// SQL was executed
#expect(delegate.willExecuteCalls.count == 1)
#expect(delegate.didExecuteCalls.count == 1)
#expect(delegate.didExecuteCalls[0].success == true)
// Verify the data was actually deleted
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
}
#expect(count == 0, "User should have been deleted")
// Response should contain meaningful content
#expect(response.sql?.uppercased().contains("DELETE") == true)
#expect(response.queryResult != nil)
}
@Test("DELETE with MutationPolicy succeeds when delegate approves")
func deleteWithPolicySucceedsWhenApproved() async throws {
let db = try makeTestDatabase()
let delegate = ApprovingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM orders WHERE user_id = 2",
"Deleted 1 order."
])
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
requiresDestructiveConfirmation: true
)
let engine = ChatEngine(
database: db,
model: model,
mutationPolicy: policy,
delegate: delegate
)
let response = try await engine.send("Delete all orders for user 2")
// Delegate approved
#expect(delegate.confirmCalls.count == 1)
// Data was actually deleted
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM orders WHERE user_id = 2")
}
#expect(count == 0, "Orders should have been deleted")
#expect(response.sql?.uppercased().contains("DELETE") == true)
}
@Test("AutoApproveDelegate allows DELETE without user interaction")
func autoApproveDelegateAllowsDelete() async throws {
let db = try makeTestDatabase()
let delegate = AutoApproveDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 3",
"Deleted 1 user."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
let response = try await engine.send("Delete user 3")
// Should succeed without error
#expect(response.sql?.uppercased().contains("DELETE") == true)
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 3")
}
#expect(count == 0, "User should have been deleted")
}
@Test("sendConfirmed bypasses delegate and executes directly")
func sendConfirmedBypassesDelegate() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"Deleted 1 user."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
// sendConfirmed should execute directly without consulting the delegate for confirmation
let response = try await engine.sendConfirmed(
"Delete user 1",
confirmedSQL: "DELETE FROM users WHERE id = 1"
)
// Delegate was NOT asked to confirm (sendConfirmed skips confirmation)
#expect(delegate.confirmCalls.isEmpty)
// But willExecute/didExecute hooks were still called
#expect(delegate.willExecuteCalls.count == 1)
#expect(delegate.didExecuteCalls.count == 1)
#expect(delegate.didExecuteCalls[0].success == true)
// Data was deleted
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
}
#expect(count == 0)
#expect(response.summary.contains("deleted") || response.summary.contains("Deleted") || response.summary.contains("1"))
}
}
// MARK: - Tests: Delegate Context Correctness
@Suite("Destructive Operations - Delegate Context")
struct DestructiveOperationContextTests {
@Test("Delegate receives correct context for DELETE on specific table")
func delegateReceivesCorrectContext() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM orders WHERE amount < 50"
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
do {
_ = try await engine.send("Delete cheap orders")
Issue.record("Expected confirmationRequired error")
} catch is SwiftDBAIError {
// Expected
}
#expect(delegate.confirmCalls.count == 1)
let ctx = delegate.confirmCalls[0]
#expect(ctx.statementKind == .delete)
#expect(ctx.classification == .destructive(.delete))
#expect(ctx.classification.requiresConfirmation == true)
#expect(ctx.sql.uppercased().contains("DELETE FROM ORDERS"))
#expect(ctx.targetTable == "orders")
#expect(!ctx.description.isEmpty)
}
@Test("Non-destructive operations do not consult delegate")
func selectDoesNotConsultDelegate() async throws {
let db = try makeTestDatabase()
let delegate = ApprovingTrackingDelegate()
let model = TestSequentialModel(responses: [
"SELECT COUNT(*) FROM users",
"There are 3 users."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
_ = try await engine.send("How many users?")
// Delegate should NOT have been asked to confirm (SELECT is not destructive)
#expect(delegate.confirmCalls.isEmpty)
// But willExecute/didExecute should still be called (observation hooks)
#expect(delegate.willExecuteCalls.count == 1)
#expect(delegate.didExecuteCalls.count == 1)
}
@Test("INSERT does not require confirmation even with delegate")
func insertDoesNotRequireConfirmation() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"INSERT INTO users (name, email) VALUES ('Dave', 'dave@example.com')",
"Inserted 1 row."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .standard,
delegate: delegate
)
let response = try await engine.send("Add user Dave")
// No confirmation needed for INSERT
#expect(delegate.confirmCalls.isEmpty)
#expect(response.sql?.uppercased().contains("INSERT") == true)
// Verify the insert happened
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE name = 'Dave'")
}
#expect(count == 1)
}
@Test("UPDATE does not require confirmation even with delegate")
func updateDoesNotRequireConfirmation() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"UPDATE users SET email = 'alice-new@example.com' WHERE id = 1",
"Updated 1 row."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .standard,
delegate: delegate
)
let response = try await engine.send("Update Alice's email")
// No confirmation needed for UPDATE
#expect(delegate.confirmCalls.isEmpty)
#expect(response.sql?.uppercased().contains("UPDATE") == true)
}
}
// MARK: - Tests: MutationPolicy Confirmation Flag
@Suite("Destructive Operations - MutationPolicy Confirmation Control")
struct MutationPolicyConfirmationTests {
@Test("DELETE skips confirmation when requiresDestructiveConfirmation is false")
func deleteSkipsConfirmationWhenDisabled() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 1",
"Deleted 1 user."
])
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
requiresDestructiveConfirmation: false // Explicitly disabled
)
let engine = ChatEngine(
database: db,
model: model,
mutationPolicy: policy,
delegate: delegate
)
// Should succeed without confirmation since the policy disables it
let response = try await engine.send("Delete user 1")
// Delegate should NOT have been consulted for confirmation
#expect(delegate.confirmCalls.isEmpty)
// But the SQL should have executed
#expect(response.sql?.uppercased().contains("DELETE") == true)
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
}
#expect(count == 0, "User should have been deleted without confirmation")
}
@Test("MutationPolicy.requiresConfirmation only triggers for DELETE")
func requiresConfirmationOnlyForDelete() {
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
requiresDestructiveConfirmation: true
)
#expect(policy.requiresConfirmation(for: .delete) == true)
#expect(policy.requiresConfirmation(for: .select) == false)
#expect(policy.requiresConfirmation(for: .insert) == false)
#expect(policy.requiresConfirmation(for: .update) == false)
}
@Test("MutationPolicy.readOnly never requires confirmation (no delete allowed)")
func readOnlyNeverRequiresConfirmation() {
let policy = MutationPolicy.readOnly
#expect(policy.requiresConfirmation(for: .select) == false)
#expect(policy.requiresConfirmation(for: .delete) == true) // Would require confirmation IF allowed
#expect(policy.isOperationAllowed(.delete) == false) // But it's not allowed at all
}
@Test("Table-restricted DELETE is blocked for disallowed tables")
func tableRestrictedDeleteBlocked() async throws {
let db = try makeTestDatabase()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 1"
])
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
allowedTables: ["orders"], // Only orders, NOT users
requiresDestructiveConfirmation: true
)
let engine = ChatEngine(
database: db,
model: model,
mutationPolicy: policy
)
do {
_ = try await engine.send("Delete user 1")
Issue.record("Expected tableNotAllowedForMutation error")
} catch let error as SwiftDBAIError {
guard case .tableNotAllowedForMutation(let tableName, let operation) = error else {
Issue.record("Expected tableNotAllowedForMutation, got: \(error)")
return
}
#expect(tableName == "users")
#expect(operation == "delete")
}
// User was not deleted
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
}
#expect(count == 1)
}
}

View File

@@ -0,0 +1,49 @@
// MockLanguageModel.swift
// SwiftDBAI Tests
//
// A mock LanguageModel for unit tests that returns canned responses.
import AnyLanguageModel
import Foundation
/// A mock language model that returns a configurable canned response.
///
/// Used in tests to avoid hitting a real LLM provider.
struct MockLanguageModel: LanguageModel {
typealias UnavailableReason = Never
/// The text the mock will return from `respond(...)`.
let responseText: String
init(responseText: String = "Mock summary response.") {
self.responseText = responseText
}
func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
let rawContent = GeneratedContent(kind: .string(responseText))
let content = try Content(rawContent)
return LanguageModelSession.Response(
content: content,
rawContent: rawContent,
transcriptEntries: [][...]
)
}
func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
let rawContent = GeneratedContent(kind: .string(responseText))
let content = try! Content(rawContent)
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
}
}

View File

@@ -0,0 +1,337 @@
// LocalProviderConfigurationTests.swift
// SwiftDBAI Tests
//
// Tests for local/self-hosted provider configurations (Ollama, llama.cpp):
// factory methods, endpoint discovery, connection handling, and model creation.
import AnyLanguageModel
import Foundation
import GRDB
@testable import SwiftDBAI
import Testing
@Suite("Local Provider Configuration")
struct LocalProviderConfigurationTests {
// MARK: - Ollama Configuration
@Test("Ollama configuration stores provider and model")
func ollamaBasicConfiguration() {
let config = ProviderConfiguration.ollama(model: "llama3.2")
#expect(config.provider == .ollama)
#expect(config.model == "llama3.2")
#expect(config.baseURL == OllamaLanguageModel.defaultBaseURL)
}
@Test("Ollama configuration produces OllamaLanguageModel")
func ollamaMakeModel() {
let config = ProviderConfiguration.ollama(model: "qwen2.5")
let model = config.makeModel()
#expect(model is OllamaLanguageModel)
}
@Test("Ollama with custom base URL for remote instance")
func ollamaCustomBaseURL() {
let remoteURL = URL(string: "http://192.168.1.100:11434")!
let config = ProviderConfiguration.ollama(
model: "mistral",
baseURL: remoteURL
)
#expect(config.baseURL == remoteURL)
#expect(config.provider == .ollama)
let model = config.makeModel()
#expect(model is OllamaLanguageModel)
}
@Test("Ollama does not require an API key")
func ollamaNoAPIKey() {
let config = ProviderConfiguration.ollama(model: "llama3.2")
// Ollama doesn't need an API key, so the key is empty
#expect(config.apiKey == "")
// hasValidAPIKey returns false because key is empty, but that's expected
// for local providers they don't need authentication
#expect(!config.hasValidAPIKey)
}
@Test("Ollama model is available without API key")
func ollamaModelAvailable() {
let config = ProviderConfiguration.ollama(model: "llama3.2")
let model = config.makeModel()
#expect(model.isAvailable)
}
// MARK: - llama.cpp Configuration
@Test("llama.cpp configuration stores provider and model")
func llamaCppBasicConfiguration() {
let config = ProviderConfiguration.llamaCpp(model: "my-model")
#expect(config.provider == .llamaCpp)
#expect(config.model == "my-model")
#expect(config.baseURL == LocalProviderDiscovery.defaultLlamaCppURL)
}
@Test("llama.cpp uses 'default' model name by default")
func llamaCppDefaultModel() {
let config = ProviderConfiguration.llamaCpp()
#expect(config.model == "default")
}
@Test("llama.cpp configuration produces OpenAILanguageModel (compatible API)")
func llamaCppMakeModel() {
let config = ProviderConfiguration.llamaCpp(model: "my-gguf")
let model = config.makeModel()
// llama.cpp uses OpenAI-compatible API
#expect(model is OpenAILanguageModel)
}
@Test("llama.cpp with custom base URL")
func llamaCppCustomBaseURL() {
let customURL = URL(string: "http://localhost:9090")!
let config = ProviderConfiguration.llamaCpp(
model: "custom-model",
baseURL: customURL
)
#expect(config.baseURL == customURL)
let model = config.makeModel()
#expect(model is OpenAILanguageModel)
}
@Test("llama.cpp with API key authentication")
func llamaCppWithAPIKey() {
let config = ProviderConfiguration.llamaCpp(
model: "secured-model",
apiKey: "my-secret-key"
)
#expect(config.apiKey == "my-secret-key")
#expect(config.hasValidAPIKey)
}
@Test("llama.cpp without API key")
func llamaCppNoAPIKey() {
let config = ProviderConfiguration.llamaCpp(model: "open-model")
#expect(config.apiKey == "")
}
// MARK: - Provider Enum
@Test("Provider enum includes ollama and llamaCpp cases")
func providerEnumHasLocalCases() {
let cases = ProviderConfiguration.Provider.allCases
#expect(cases.contains(.ollama))
#expect(cases.contains(.llamaCpp))
// Total: openAI, anthropic, gemini, openAICompatible, ollama, llamaCpp
#expect(cases.count == 6)
}
// MARK: - fromEnvironment
@Test("fromEnvironment creates Ollama configuration")
func fromEnvironmentOllama() {
let config = ProviderConfiguration.fromEnvironment(
provider: .ollama,
environmentVariable: "NONEXISTENT_OLLAMA_KEY",
model: "llama3.2"
)
#expect(config.provider == .ollama)
#expect(config.model == "llama3.2")
}
@Test("fromEnvironment creates llama.cpp configuration")
func fromEnvironmentLlamaCpp() {
let config = ProviderConfiguration.fromEnvironment(
provider: .llamaCpp,
environmentVariable: "NONEXISTENT_LLAMACPP_KEY",
model: "default"
)
#expect(config.provider == .llamaCpp)
#expect(config.model == "default")
}
// MARK: - ChatEngine Convenience Init with Local Providers
@Test("ChatEngine can be created with Ollama provider")
func chatEngineWithOllama() throws {
let dbQueue = try GRDB.DatabaseQueue()
let config = ProviderConfiguration.ollama(model: "llama3.2")
let engine = ChatEngine(database: dbQueue, provider: config)
#expect(engine.tableCount == nil) // schema not yet introspected
}
@Test("ChatEngine can be created with llama.cpp provider")
func chatEngineWithLlamaCpp() throws {
let dbQueue = try GRDB.DatabaseQueue()
let config = ProviderConfiguration.llamaCpp()
let engine = ChatEngine(database: dbQueue, provider: config)
#expect(engine.tableCount == nil)
}
// MARK: - LocalProviderType
@Test("LocalProviderType has expected raw values")
func localProviderTypeRawValues() {
#expect(LocalProviderType.ollama.rawValue == "ollama")
#expect(LocalProviderType.llamaCpp.rawValue == "llama.cpp")
}
@Test("LocalProviderType CaseIterable includes both cases")
func localProviderTypeCases() {
let cases = LocalProviderType.allCases
#expect(cases.count == 2)
#expect(cases.contains(.ollama))
#expect(cases.contains(.llamaCpp))
}
// MARK: - LocalProviderEndpoint
@Test("LocalProviderEndpoint description includes status and model count")
func endpointDescription() {
let endpoint = LocalProviderEndpoint(
baseURL: URL(string: "http://localhost:11434")!,
providerType: .ollama,
isReachable: true,
availableModels: ["llama3.2", "qwen2.5"]
)
#expect(endpoint.description.contains("ollama"))
#expect(endpoint.description.contains("reachable"))
#expect(endpoint.description.contains("2 models"))
}
@Test("LocalProviderEndpoint shows unreachable when not connected")
func endpointUnreachableDescription() {
let endpoint = LocalProviderEndpoint(
baseURL: URL(string: "http://localhost:8080")!,
providerType: .llamaCpp,
isReachable: false,
availableModels: []
)
#expect(endpoint.description.contains("unreachable"))
#expect(endpoint.description.contains("0 models"))
}
@Test("LocalProviderEndpoint equality works correctly")
func endpointEquality() {
let a = LocalProviderEndpoint(
baseURL: URL(string: "http://localhost:11434")!,
providerType: .ollama,
isReachable: true,
availableModels: ["llama3.2"]
)
let b = LocalProviderEndpoint(
baseURL: URL(string: "http://localhost:11434")!,
providerType: .ollama,
isReachable: true,
availableModels: ["llama3.2"]
)
let c = LocalProviderEndpoint(
baseURL: URL(string: "http://localhost:11434")!,
providerType: .ollama,
isReachable: false,
availableModels: []
)
#expect(a == b)
#expect(a != c)
}
// MARK: - Discovery (No Local Server Running)
@Test("Discovery returns unreachable when no server is running")
func discoveryUnreachableEndpoint() async {
// Use a port that's almost certainly not running anything
let endpoint = await LocalProviderDiscovery.discover(
providerType: .ollama,
host: "127.0.0.1",
port: 59999,
timeout: 1
)
#expect(!endpoint.isReachable)
#expect(endpoint.availableModels.isEmpty)
#expect(endpoint.providerType == .ollama)
}
@Test("isOllamaRunning returns false for unreachable endpoint")
func ollamaNotRunning() async {
let unreachableURL = URL(string: "http://127.0.0.1:59998")!
let running = await LocalProviderDiscovery.isOllamaRunning(
at: unreachableURL,
timeout: 1
)
#expect(!running)
}
@Test("isLlamaCppRunning returns false for unreachable endpoint")
func llamaCppNotRunning() async {
let unreachableURL = URL(string: "http://127.0.0.1:59997")!
let running = await LocalProviderDiscovery.isLlamaCppRunning(
at: unreachableURL,
timeout: 1
)
#expect(!running)
}
@Test("listOllamaModels returns empty for unreachable endpoint")
func ollamaModelsUnreachable() async {
let unreachableURL = URL(string: "http://127.0.0.1:59996")!
let models = await LocalProviderDiscovery.listOllamaModels(
at: unreachableURL,
timeout: 1
)
#expect(models.isEmpty)
}
@Test("listLlamaCppModels returns empty for unreachable endpoint")
func llamaCppModelsUnreachable() async {
let unreachableURL = URL(string: "http://127.0.0.1:59995")!
let models = await LocalProviderDiscovery.listLlamaCppModels(
at: unreachableURL,
timeout: 1
)
#expect(models.isEmpty)
}
@Test("discoverAll returns endpoints for both provider types")
func discoverAllReturnsAllProviders() async {
// Use very short timeout since we likely don't have servers running
let endpoints = await LocalProviderDiscovery.discoverAll(timeout: 0.5)
// Should return exactly 2 endpoints (one per well-known provider)
#expect(endpoints.count == 2)
let types = Set(endpoints.map(\.providerType))
#expect(types.contains(.ollama))
#expect(types.contains(.llamaCpp))
}
// MARK: - Default URLs
@Test("Default Ollama URL is correct")
func defaultOllamaURL() {
#expect(LocalProviderDiscovery.defaultOllamaURL.absoluteString == "http://localhost:11434")
}
@Test("Default llama.cpp URL is correct")
func defaultLlamaCppURL() {
#expect(LocalProviderDiscovery.defaultLlamaCppURL.absoluteString == "http://localhost:8080")
}
}

View File

@@ -0,0 +1,363 @@
// MultiTurnContextTests.swift
// SwiftDBAI Tests
//
// Tests verifying multi-turn conversation context follow-up queries
// correctly reference the prior query's table, columns, and results.
import AnyLanguageModel
import Foundation
import GRDB
import Testing
@testable import SwiftDBAI
@Suite("Multi-Turn Context Tests")
struct MultiTurnContextTests {
// MARK: - Test Database Setup
/// Creates an in-memory database with users (including age) and orders.
private func makeTestDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue(path: ":memory:")
try db.write { db in
try db.execute(sql: """
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
age INTEGER NOT NULL,
email TEXT NOT NULL,
city TEXT NOT NULL
)
""")
try db.execute(sql: """
INSERT INTO users (name, age, email, city) VALUES
('Alice', 25, 'alice@example.com', 'New York'),
('Bob', 35, 'bob@example.com', 'San Francisco'),
('Charlie', 42, 'charlie@example.com', 'New York'),
('Diana', 28, 'diana@example.com', 'Chicago'),
('Eve', 55, 'eve@example.com', 'San Francisco')
""")
try db.execute(sql: """
CREATE TABLE orders (
id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
amount REAL NOT NULL,
status TEXT NOT NULL,
created_at TEXT NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id)
)
""")
try db.execute(sql: """
INSERT INTO orders (user_id, amount, status, created_at) VALUES
(1, 99.99, 'completed', '2024-01-15'),
(1, 49.50, 'pending', '2024-02-20'),
(2, 150.00, 'completed', '2024-01-10'),
(3, 200.00, 'completed', '2024-03-01'),
(4, 75.00, 'cancelled', '2024-02-05')
""")
}
return db
}
// MARK: - Multi-Turn Context Tests
@Test("Follow-up 'filter those by age > 30' references prior 'show all users' context")
func followUpFilterReferencesUsersTable() async throws {
let db = try makeTestDatabase()
// Turn 1: "show all users" SELECT * FROM users (returns 5 rows, LLM summary needed)
// Turn 2: "filter those by age > 30" should reference users table from context
let mock = PromptCapturingMockModel(responses: [
"SELECT * FROM users",
"Here are all 5 users in the database.",
"SELECT * FROM users WHERE age > 30",
"Found 3 users over 30: Bob (35), Charlie (42), and Eve (55)."
])
let engine = ChatEngine(database: db, model: mock)
// First turn: show all users
let response1 = try await engine.send("show all users")
#expect(response1.sql == "SELECT * FROM users")
#expect(response1.queryResult?.rowCount == 5)
// Second turn: follow-up with implicit reference
let response2 = try await engine.send("filter those by age > 30")
#expect(response2.sql == "SELECT * FROM users WHERE age > 30")
#expect(response2.queryResult?.rowCount == 3)
// Verify the follow-up prompt includes conversation history
let prompts = mock.capturedPrompts
// Find the prompt for the second SQL generation (skip summary prompts)
let followUpSQLPrompt = prompts.first { prompt in
prompt.contains("filter those by age > 30") && prompt.contains("CONVERSATION HISTORY")
}
#expect(followUpSQLPrompt != nil, "Follow-up prompt should contain CONVERSATION HISTORY")
// The conversation history should include the prior query and its SQL
if let prompt = followUpSQLPrompt {
#expect(prompt.contains("show all users"), "History should contain prior user message")
#expect(prompt.contains("SELECT * FROM users"), "History should contain prior SQL")
#expect(prompt.contains("filter those by age > 30"), "Prompt should contain current question")
}
}
@Test("Follow-up correctly inherits table context across multiple turns")
func multipleFollowUpsInheritContext() async throws {
let db = try makeTestDatabase()
// 3-turn conversation narrowing down results
let mock = PromptCapturingMockModel(responses: [
"SELECT * FROM users",
"Here are all 5 users.",
"SELECT * FROM users WHERE city = 'New York'",
"Found 2 users in New York: Alice and Charlie.",
"SELECT * FROM users WHERE city = 'New York' AND age > 30",
"Charlie (42) is the only New York user over 30."
])
let engine = ChatEngine(database: db, model: mock)
// Turn 1
_ = try await engine.send("show all users")
// Turn 2 narrows by city
let response2 = try await engine.send("only those in New York")
#expect(response2.sql == "SELECT * FROM users WHERE city = 'New York'")
#expect(response2.queryResult?.rowCount == 2)
// Turn 3 further narrows by age
let response3 = try await engine.send("now filter by age over 30")
#expect(response3.sql == "SELECT * FROM users WHERE city = 'New York' AND age > 30")
#expect(response3.queryResult?.rowCount == 1)
// Verify third turn's prompt includes the full conversation history
let prompts = mock.capturedPrompts
let thirdTurnPrompt = prompts.last { prompt in
prompt.contains("now filter by age over 30") && prompt.contains("CONVERSATION HISTORY")
}
#expect(thirdTurnPrompt != nil)
if let prompt = thirdTurnPrompt {
// Should include both prior user messages
#expect(prompt.contains("show all users"))
#expect(prompt.contains("only those in New York"))
// Should include prior SQL
#expect(prompt.contains("SELECT * FROM users"))
#expect(prompt.contains("SELECT * FROM users WHERE city = 'New York'"))
}
}
@Test("Follow-up switching tables preserves cross-table context")
func followUpSwitchesTableWithContext() async throws {
let db = try makeTestDatabase()
// Turn 1: query users, Turn 2: ask about their orders
let mock = PromptCapturingMockModel(responses: [
"SELECT name, age FROM users WHERE age > 30",
"Found 3 users over 30.",
"SELECT o.id, u.name, o.amount, o.status FROM orders o JOIN users u ON o.user_id = u.id WHERE u.age > 30",
"Bob has a $150 completed order, Charlie has a $200 completed order."
])
let engine = ChatEngine(database: db, model: mock)
// Turn 1: users over 30
let response1 = try await engine.send("show users over 30")
#expect(response1.queryResult?.rowCount == 3)
// Turn 2: their orders references the previous result context
let response2 = try await engine.send("show their orders")
#expect(response2.sql?.contains("JOIN") == true)
// Verify the follow-up prompt contains the users context
let prompts = mock.capturedPrompts
let orderPrompt = prompts.first { prompt in
prompt.contains("show their orders") && prompt.contains("CONVERSATION HISTORY")
}
#expect(orderPrompt != nil)
if let prompt = orderPrompt {
#expect(prompt.contains("show users over 30"), "Should contain prior user message")
#expect(prompt.contains("age > 30"), "Should contain prior SQL context for table reference")
}
}
@Test("Conversation history includes SQL from prior turns for context")
func historyIncludesSQLFromPriorTurns() async throws {
let db = try makeTestDatabase()
// Both queries are aggregates no LLM summarization needed
let mock = PromptCapturingMockModel(responses: [
"SELECT COUNT(*) FROM users",
"SELECT COUNT(*) FROM users WHERE age > 30",
])
let engine = ChatEngine(database: db, model: mock)
// Turn 1
let r1 = try await engine.send("how many users are there?")
#expect(r1.sql == "SELECT COUNT(*) FROM users")
// Turn 2 references "those" implicitly
let r2 = try await engine.send("how many of those are over 30?")
#expect(r2.sql == "SELECT COUNT(*) FROM users WHERE age > 30")
// Verify engine history has all 4 messages (2 user + 2 assistant)
let messages = engine.messages
#expect(messages.count == 4)
#expect(messages[0].role == .user)
#expect(messages[0].content == "how many users are there?")
#expect(messages[1].role == .assistant)
#expect(messages[1].sql == "SELECT COUNT(*) FROM users")
#expect(messages[2].role == .user)
#expect(messages[2].content == "how many of those are over 30?")
#expect(messages[3].role == .assistant)
#expect(messages[3].sql == "SELECT COUNT(*) FROM users WHERE age > 30")
// The second prompt should reference the first query SQL
let prompts = mock.capturedPrompts
#expect(prompts.count >= 2)
let secondPrompt = prompts[1]
#expect(secondPrompt.contains("CONVERSATION HISTORY"))
#expect(secondPrompt.contains("SELECT COUNT(*) FROM users"))
#expect(secondPrompt.contains("how many users are there?"))
}
@Test("Follow-up after aggregate uses prior table context")
func followUpAfterAggregateUsesTableContext() async throws {
let db = try makeTestDatabase()
// Turn 1: aggregate (no LLM summary needed)
// Turn 2: follow-up referencing "those"
let mock = PromptCapturingMockModel(responses: [
"SELECT AVG(age) FROM users",
"SELECT name, age FROM users WHERE age > 35",
"Charlie (42) and Eve (55) are older than average."
])
let engine = ChatEngine(database: db, model: mock)
// Turn 1: average age aggregate, template summary
let r1 = try await engine.send("what is the average age of users?")
#expect(r1.sql == "SELECT AVG(age) FROM users")
// Turn 2: "who is above that?" needs the avg context
let r2 = try await engine.send("who is above average?")
#expect(r2.queryResult?.rowCount == 2)
// Verify context passed
let prompts = mock.capturedPrompts
let followUp = prompts.first { prompt in
prompt.contains("who is above average?") && prompt.contains("CONVERSATION HISTORY")
}
#expect(followUp != nil)
if let prompt = followUp {
#expect(prompt.contains("AVG(age)"), "Should include prior aggregate SQL for context")
#expect(prompt.contains("users"), "Should include table reference from prior turn")
}
}
@Test("Context window limits how much history is visible in follow-ups")
func contextWindowLimitsHistoryInFollowUps() async throws {
let db = try makeTestDatabase()
// 3 turns, but context window of 2 messages
let mock = PromptCapturingMockModel(responses: [
"SELECT COUNT(*) FROM users",
"SELECT COUNT(*) FROM orders",
"SELECT COUNT(*) FROM users WHERE age > 30",
])
let config = ChatEngineConfiguration(
queryTimeout: nil,
contextWindowSize: 2
)
let engine = ChatEngine(
database: db,
model: mock,
configuration: config
)
_ = try await engine.send("how many users?")
_ = try await engine.send("how many orders?")
_ = try await engine.send("how many users over 30?")
// The third prompt should only have the last 2 messages from turn 2
let prompts = mock.capturedPrompts
#expect(prompts.count >= 3)
let thirdPrompt = prompts[2]
#expect(thirdPrompt.contains("CONVERSATION HISTORY"))
// Turn 2 context should be present
#expect(thirdPrompt.contains("how many orders?"))
#expect(thirdPrompt.contains("SELECT COUNT(*) FROM orders"))
// Turn 1 context should be trimmed (window=2 means last 2 messages)
#expect(!thirdPrompt.contains("how many users?\n"), "First turn should be trimmed from context window")
}
@Test("clearHistory resets context so follow-ups have no prior history")
func clearHistoryResetsFollowUpContext() async throws {
let db = try makeTestDatabase()
let mock = PromptCapturingMockModel(responses: [
"SELECT * FROM users",
"Here are the 5 users.",
"SELECT COUNT(*) FROM users",
])
let engine = ChatEngine(database: db, model: mock)
// Turn 1
_ = try await engine.send("show all users")
#expect(engine.messages.count == 2)
// Clear history
engine.clearHistory()
#expect(engine.messages.isEmpty)
// Turn 2 after clear should NOT have conversation history
_ = try await engine.send("count all users")
let prompts = mock.capturedPrompts
let lastPrompt = prompts.last!
// After clearing, the prompt should NOT contain conversation history
#expect(!lastPrompt.contains("CONVERSATION HISTORY"),
"After clearHistory(), follow-up should not have prior context")
#expect(!lastPrompt.contains("show all users"),
"After clearHistory(), prior messages should be gone")
}
@Test("Multi-turn with result data in context enables informed follow-ups")
func resultDataInContextEnablesInformedFollowUps() async throws {
let db = try makeTestDatabase()
// Turn 1: list users multi-row result, LLM summarizes
// Turn 2: "sort those by age" references same table
let mock = PromptCapturingMockModel(responses: [
"SELECT name, age, city FROM users",
"Found 5 users: Alice (25, NY), Bob (35, SF), Charlie (42, NY), Diana (28, Chicago), Eve (55, SF).",
"SELECT name, age, city FROM users ORDER BY age DESC",
"Users sorted by age: Eve (55), Charlie (42), Bob (35), Diana (28), Alice (25)."
])
let engine = ChatEngine(database: db, model: mock)
let r1 = try await engine.send("list all users with their age and city")
#expect(r1.queryResult?.rowCount == 5)
#expect(r1.queryResult?.columns.contains("age") == true)
#expect(r1.queryResult?.columns.contains("city") == true)
let r2 = try await engine.send("sort those by age descending")
#expect(r2.sql == "SELECT name, age, city FROM users ORDER BY age DESC")
// Verify the assistant message in history includes the SQL
let messages = engine.messages
#expect(messages.count == 4)
// First assistant message should have the SQL recorded
#expect(messages[1].sql == "SELECT name, age, city FROM users")
// Second assistant should have the sorted SQL
#expect(messages[3].sql == "SELECT name, age, city FROM users ORDER BY age DESC")
}
}

View File

@@ -0,0 +1,508 @@
// OnDeviceProviderConfigurationTests.swift
// SwiftDBAI Tests
//
// Tests for on-device provider configurations (CoreML, MLX) including
// configuration validation, inference pipeline setup, and system readiness.
import AnyLanguageModel
import Foundation
@testable import SwiftDBAI
import Testing
@Suite("OnDeviceProviderConfiguration")
struct OnDeviceProviderConfigurationTests {
// MARK: - OnDeviceProviderType
@Test("OnDeviceProviderType has CoreML and MLX cases")
func providerTypeCases() {
let cases = OnDeviceProviderType.allCases
#expect(cases.count == 2)
#expect(cases.contains(.coreML))
#expect(cases.contains(.mlx))
}
@Test("OnDeviceProviderType raw values are descriptive")
func providerTypeRawValues() {
#expect(OnDeviceProviderType.coreML.rawValue == "coreML")
#expect(OnDeviceProviderType.mlx.rawValue == "mlx")
}
// MARK: - CoreML Configuration
@Test("CoreML configuration stores all properties")
func coreMLBasicConfiguration() {
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
let config = CoreMLProviderConfiguration(
modelURL: url,
computeUnits: .cpuAndGPU,
maxResponseTokens: 1024,
useSampling: true,
temperature: 0.3
)
#expect(config.modelURL == url)
#expect(config.computeUnits == .cpuAndGPU)
#expect(config.maxResponseTokens == 1024)
#expect(config.useSampling == true)
#expect(config.temperature == 0.3)
}
@Test("CoreML configuration uses sensible defaults")
func coreMLDefaultConfiguration() {
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
let config = CoreMLProviderConfiguration(modelURL: url)
#expect(config.computeUnits == .all)
#expect(config.maxResponseTokens == 2048)
#expect(config.useSampling == false)
#expect(config.temperature == 0.1)
}
@Test("CoreML validation fails for non-mlmodelc extension")
func coreMLValidateWrongExtension() {
let url = URL(fileURLWithPath: "/tmp/TestModel.onnx")
let config = CoreMLProviderConfiguration(modelURL: url)
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("CoreML validation fails for missing model file")
func coreMLValidateMissingFile() {
let url = URL(fileURLWithPath: "/nonexistent/path/Model.mlmodelc")
let config = CoreMLProviderConfiguration(modelURL: url)
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("CoreML configuration is Equatable")
func coreMLEquatable() {
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
let a = CoreMLProviderConfiguration(modelURL: url, computeUnits: .all)
let b = CoreMLProviderConfiguration(modelURL: url, computeUnits: .all)
let c = CoreMLProviderConfiguration(modelURL: url, computeUnits: .cpuOnly)
#expect(a == b)
#expect(a != c)
}
// MARK: - ComputeUnitPreference
@Test("ComputeUnitPreference has all expected cases")
func computeUnitCases() {
let cases = ComputeUnitPreference.allCases
#expect(cases.count == 4)
#expect(cases.contains(.all))
#expect(cases.contains(.cpuOnly))
#expect(cases.contains(.cpuAndGPU))
#expect(cases.contains(.cpuAndNeuralEngine))
}
// MARK: - MLX Configuration
@Test("MLX configuration stores all properties")
func mlxBasicConfiguration() {
let dir = URL(fileURLWithPath: "/tmp/models/my-model")
let config = MLXProviderConfiguration(
modelId: "mlx-community/Test-Model-4bit",
localDirectory: dir,
gpuMemory: .minimal,
maxResponseTokens: 512,
temperature: 0.2,
topP: 0.9,
repetitionPenalty: 1.2
)
#expect(config.modelId == "mlx-community/Test-Model-4bit")
#expect(config.localDirectory == dir)
#expect(config.gpuMemory == .minimal)
#expect(config.maxResponseTokens == 512)
#expect(config.temperature == 0.2)
#expect(config.topP == 0.9)
#expect(config.repetitionPenalty == 1.2)
}
@Test("MLX configuration uses sensible defaults")
func mlxDefaultConfiguration() {
let config = MLXProviderConfiguration(modelId: "test-model")
#expect(config.localDirectory == nil)
#expect(config.gpuMemory == .automatic)
#expect(config.maxResponseTokens == 2048)
#expect(config.temperature == 0.1)
#expect(config.topP == 0.95)
#expect(config.repetitionPenalty == 1.1)
}
@Test("MLX validation fails for empty model ID")
func mlxValidateEmptyModelId() {
let config = MLXProviderConfiguration(modelId: "")
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("MLX validation fails for nonexistent local directory")
func mlxValidateMissingDirectory() {
let config = MLXProviderConfiguration(
modelId: "test-model",
localDirectory: URL(fileURLWithPath: "/nonexistent/directory")
)
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("MLX validation fails for negative temperature")
func mlxValidateNegativeTemperature() {
let config = MLXProviderConfiguration(
modelId: "test-model",
temperature: -0.5
)
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("MLX validation fails for topP out of range")
func mlxValidateInvalidTopP() {
let configZero = MLXProviderConfiguration(
modelId: "test-model",
topP: 0.0
)
#expect(throws: OnDeviceProviderError.self) {
try configZero.validate()
}
let configOver = MLXProviderConfiguration(
modelId: "test-model",
topP: 1.5
)
#expect(throws: OnDeviceProviderError.self) {
try configOver.validate()
}
}
@Test("MLX validation fails for zero repetition penalty")
func mlxValidateInvalidRepetitionPenalty() {
let config = MLXProviderConfiguration(
modelId: "test-model",
repetitionPenalty: 0.0
)
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("MLX validation succeeds for valid configuration")
func mlxValidateSuccess() throws {
let config = MLXProviderConfiguration(modelId: "test-model")
// Should not throw (no local directory set, model ID is non-empty)
try config.validate()
}
@Test("MLX configuration is Equatable")
func mlxEquatable() {
let a = MLXProviderConfiguration(modelId: "model-a")
let b = MLXProviderConfiguration(modelId: "model-a")
let c = MLXProviderConfiguration(modelId: "model-b")
#expect(a == b)
#expect(a != c)
}
// MARK: - Well-Known MLX Models
@Test("Llama 3.2 3B preset has correct model ID")
func llama3_2_3BPreset() {
let config = MLXProviderConfiguration.llama3_2_3B()
#expect(config.modelId == "mlx-community/Llama-3.2-3B-Instruct-4bit")
#expect(config.temperature == 0.1)
#expect(config.maxResponseTokens == 2048)
}
@Test("Qwen 2.5 Coder 3B preset has correct model ID")
func qwen2_5_coder3BPreset() {
let config = MLXProviderConfiguration.qwen2_5_coder_3B()
#expect(config.modelId == "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit")
#expect(config.temperature == 0.05)
}
@Test("Phi 3.5 Mini preset has correct model ID")
func phi3_5_miniPreset() {
let config = MLXProviderConfiguration.phi3_5_mini()
#expect(config.modelId == "mlx-community/Phi-3.5-mini-instruct-4bit")
#expect(config.temperature == 0.1)
}
@Test("Well-known models accept custom GPU memory config")
func wellKnownModelsCustomGPU() {
let config = MLXProviderConfiguration.llama3_2_3B(
gpuMemory: .minimal
)
#expect(config.gpuMemory == .minimal)
}
// MARK: - GPU Memory Configuration
@Test("Automatic GPU memory config scales with RAM")
func automaticGPUMemory() {
let config = MLXGPUMemoryConfig.automatic
#expect(config.activeCacheLimit > 0)
#expect(config.idleCacheLimit == 50_000_000)
#expect(config.clearCacheOnEviction == true)
}
@Test("Minimal GPU memory config is conservative")
func minimalGPUMemory() {
let config = MLXGPUMemoryConfig.minimal
#expect(config.activeCacheLimit == 64_000_000)
#expect(config.idleCacheLimit == 16_000_000)
#expect(config.clearCacheOnEviction == true)
}
@Test("Unconstrained GPU memory config uses max values")
func unconstrainedGPUMemory() {
let config = MLXGPUMemoryConfig.unconstrained
#expect(config.activeCacheLimit == Int.max)
#expect(config.idleCacheLimit == Int.max)
#expect(config.clearCacheOnEviction == false)
}
@Test("GPU memory config is Equatable")
func gpuMemoryEquatable() {
#expect(MLXGPUMemoryConfig.minimal == MLXGPUMemoryConfig.minimal)
#expect(MLXGPUMemoryConfig.minimal != MLXGPUMemoryConfig.unconstrained)
}
// MARK: - On-Device Provider Errors
@Test("OnDeviceProviderError has descriptive messages")
func errorDescriptions() {
let errors: [OnDeviceProviderError] = [
.modelNotFound(URL(fileURLWithPath: "/tmp/model")),
.invalidModelFormat(expected: ".mlmodelc", actual: ".onnx"),
.emptyModelId,
.invalidParameter(name: "temperature", value: "-1", reason: "Must be non-negative"),
.providerUnavailable(.mlx, reason: "MLX build flag not enabled"),
.modelLoadFailed(reason: "Out of memory"),
.inferenceFailed(reason: "Token limit exceeded"),
]
for error in errors {
#expect(error.errorDescription != nil)
#expect(!error.errorDescription!.isEmpty)
}
}
@Test("OnDeviceProviderError is Equatable")
func errorEquatable() {
let a = OnDeviceProviderError.emptyModelId
let b = OnDeviceProviderError.emptyModelId
let c = OnDeviceProviderError.modelLoadFailed(reason: "test")
#expect(a == b)
#expect(a != c)
}
// MARK: - Inference Pipeline
@Test("MLX inference pipeline initializes with correct type")
func mlxPipelineInit() {
let config = MLXProviderConfiguration.llama3_2_3B()
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
#expect(pipeline.providerType == .mlx)
#expect(pipeline.mlxConfiguration != nil)
#expect(pipeline.coreMLConfiguration == nil)
#expect(pipeline.status == .notLoaded)
}
@Test("CoreML inference pipeline initializes with correct type")
func coreMLPipelineInit() {
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
let config = CoreMLProviderConfiguration(modelURL: url)
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
#expect(pipeline.providerType == .coreML)
#expect(pipeline.coreMLConfiguration != nil)
#expect(pipeline.mlxConfiguration == nil)
#expect(pipeline.status == .notLoaded)
}
@Test("Pipeline validates MLX configuration")
func pipelineValidatesMLX() throws {
let validConfig = MLXProviderConfiguration(modelId: "test-model")
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: validConfig)
try pipeline.validateConfiguration()
let invalidConfig = MLXProviderConfiguration(modelId: "")
let invalidPipeline = OnDeviceInferencePipeline(mlxConfiguration: invalidConfig)
#expect(throws: OnDeviceProviderError.self) {
try invalidPipeline.validateConfiguration()
}
}
@Test("Pipeline validates CoreML configuration")
func pipelineValidatesCoreML() {
let url = URL(fileURLWithPath: "/tmp/TestModel.onnx")
let config = CoreMLProviderConfiguration(modelURL: url)
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
#expect(throws: OnDeviceProviderError.self) {
try pipeline.validateConfiguration()
}
}
@Test("Pipeline provides SQL generation hints for MLX")
func mlxSQLHints() {
let config = MLXProviderConfiguration(
modelId: "test-model",
maxResponseTokens: 512,
temperature: 0.2
)
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
let hints = pipeline.recommendedSQLGenerationHints
#expect(hints.maxTokens == 512)
#expect(hints.temperature == 0.2)
#expect(hints.useSampling == true)
#expect(hints.systemPromptSuffix.contains("MLX"))
}
@Test("Pipeline provides SQL generation hints for CoreML")
func coreMLSQLHints() {
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
let config = CoreMLProviderConfiguration(
modelURL: url,
maxResponseTokens: 1024,
useSampling: false,
temperature: 0.05
)
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
let hints = pipeline.recommendedSQLGenerationHints
#expect(hints.maxTokens == 1024)
#expect(hints.temperature == 0.05)
#expect(hints.useSampling == false)
#expect(hints.systemPromptSuffix.contains("SQL"))
}
// MARK: - System Readiness
@Test("System capability check returns valid data")
func systemCapability() {
let capability = OnDeviceModelReadiness.checkSystemCapability()
#expect(capability.totalRAM > 0)
// On any modern test machine, we should have at least some RAM
#expect(capability.totalRAM > 1024 * 1024 * 1024) // > 1GB
// On Apple silicon Macs, this should be true
#if arch(arm64)
#expect(capability.hasNeuralEngine == true)
#endif
}
@Test("Suggested MLX model returns a valid configuration")
func suggestedMLXModel() {
let config = OnDeviceModelReadiness.suggestedMLXModel()
#expect(!config.modelId.isEmpty)
#expect(config.temperature >= 0)
#expect(config.maxResponseTokens > 0)
}
@Test("Recommended model size enum has correct raw values")
func recommendedModelSizeRawValues() {
#expect(OnDeviceModelReadiness.RecommendedModelSize.small.rawValue == "small")
#expect(OnDeviceModelReadiness.RecommendedModelSize.medium.rawValue == "medium")
#expect(OnDeviceModelReadiness.RecommendedModelSize.large.rawValue == "large")
}
// MARK: - ProviderConfiguration Integration
@Test("onDeviceMLX creates a ProviderConfiguration")
func onDeviceMLXProviderConfig() {
let mlxConfig = MLXProviderConfiguration.llama3_2_3B()
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
#expect(providerConfig.model == mlxConfig.modelId)
#expect(!providerConfig.hasValidAPIKey) // No API key needed for on-device
}
@Test("onDeviceCoreML creates a ProviderConfiguration")
func onDeviceCoreMLProviderConfig() {
let url = URL(fileURLWithPath: "/tmp/SQLModel.mlmodelc")
let coreMLConfig = CoreMLProviderConfiguration(modelURL: url)
let providerConfig = ProviderConfiguration.onDeviceCoreML(coreMLConfig)
#expect(providerConfig.model == "SQLModel.mlmodelc")
#expect(!providerConfig.hasValidAPIKey)
}
// MARK: - Pipeline Status
@Test("Pipeline status transitions")
func pipelineStatusTransitions() {
let config = MLXProviderConfiguration(modelId: "test-model")
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
#expect(pipeline.status == .notLoaded)
pipeline.setStatus(.loading)
#expect(pipeline.status == .loading)
pipeline.setStatus(.ready)
#expect(pipeline.status == .ready)
pipeline.setStatus(.failed("Out of memory"))
#expect(pipeline.status == .failed("Out of memory"))
}
@Test("Pipeline Status is Equatable")
func pipelineStatusEquatable() {
#expect(OnDeviceInferencePipeline.Status.notLoaded == .notLoaded)
#expect(OnDeviceInferencePipeline.Status.loading == .loading)
#expect(OnDeviceInferencePipeline.Status.ready == .ready)
#expect(OnDeviceInferencePipeline.Status.failed("a") == .failed("a"))
#expect(OnDeviceInferencePipeline.Status.failed("a") != .failed("b"))
#expect(OnDeviceInferencePipeline.Status.notLoaded != .ready)
}
// MARK: - SQL Generation Hints
@Test("SQL generation hints are Equatable")
func sqlHintsEquatable() {
let a = OnDeviceSQLGenerationHints(
maxTokens: 512,
temperature: 0.1,
systemPromptSuffix: "test",
useSampling: true
)
let b = OnDeviceSQLGenerationHints(
maxTokens: 512,
temperature: 0.1,
systemPromptSuffix: "test",
useSampling: true
)
let c = OnDeviceSQLGenerationHints(
maxTokens: 1024,
temperature: 0.1,
systemPromptSuffix: "test",
useSampling: true
)
#expect(a == b)
#expect(a != c)
}
}

View File

@@ -0,0 +1,254 @@
// PromptBuilderTests.swift
// SwiftDBAI
import Testing
@testable import SwiftDBAI
@Suite("PromptBuilder")
struct PromptBuilderTests {
// MARK: - Helpers
/// Creates a sample schema for testing.
private func makeSampleSchema() -> DatabaseSchema {
let usersTable = TableSchema(
name: "users",
columns: [
ColumnSchema(cid: 0, name: "id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: true),
ColumnSchema(cid: 1, name: "name", type: "TEXT", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
ColumnSchema(cid: 2, name: "email", type: "TEXT", isNotNull: false, defaultValue: nil, isPrimaryKey: false),
ColumnSchema(cid: 3, name: "created_at", type: "TEXT", isNotNull: false, defaultValue: "CURRENT_TIMESTAMP", isPrimaryKey: false),
],
primaryKey: ["id"],
foreignKeys: [],
indexes: [
IndexSchema(name: "idx_users_email", isUnique: true, columns: ["email"])
]
)
let ordersTable = TableSchema(
name: "orders",
columns: [
ColumnSchema(cid: 0, name: "id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: true),
ColumnSchema(cid: 1, name: "user_id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
ColumnSchema(cid: 2, name: "total", type: "REAL", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
ColumnSchema(cid: 3, name: "status", type: "TEXT", isNotNull: true, defaultValue: "'pending'", isPrimaryKey: false),
],
primaryKey: ["id"],
foreignKeys: [
ForeignKeySchema(fromColumn: "user_id", toTable: "users", toColumn: "id", onUpdate: "NO ACTION", onDelete: "CASCADE")
],
indexes: []
)
return DatabaseSchema(
tables: ["users": usersTable, "orders": ordersTable],
tableNames: ["users", "orders"]
)
}
private func makeEmptySchema() -> DatabaseSchema {
DatabaseSchema(tables: [:], tableNames: [])
}
// MARK: - System Instructions Tests
@Test("System instructions contain role section")
func systemInstructionsContainRole() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("ROLE"))
#expect(instructions.contains("SQL assistant"))
#expect(instructions.contains("SQLite database"))
}
@Test("System instructions contain schema")
func systemInstructionsContainSchema() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("DATABASE SCHEMA"))
#expect(instructions.contains("TABLE users"))
#expect(instructions.contains("TABLE orders"))
#expect(instructions.contains("name TEXT"))
#expect(instructions.contains("email TEXT"))
}
@Test("System instructions contain foreign keys from schema")
func systemInstructionsContainForeignKeys() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("FOREIGN KEY"))
#expect(instructions.contains("REFERENCES users(id)"))
}
@Test("System instructions contain SQL generation rules")
func systemInstructionsContainRules() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("SQL GENERATION RULES"))
#expect(instructions.contains("Use ONLY the tables and columns"))
#expect(instructions.contains("Never generate DDL"))
}
@Test("System instructions contain output format section")
func systemInstructionsContainOutputFormat() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("OUTPUT FORMAT"))
}
@Test("Default allowlist is read-only")
func defaultAllowlistIsReadOnly() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("ONLY generate SELECT queries"))
#expect(instructions.contains("No data modifications"))
}
@Test("Standard allowlist shows correct operations")
func standardAllowlistInstructions() {
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: .standard)
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("INSERT"))
#expect(instructions.contains("SELECT"))
#expect(instructions.contains("UPDATE"))
}
@Test("Unrestricted allowlist warns about DELETE")
func unrestrictedAllowlistWarnsAboutDelete() {
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: .unrestricted)
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("DELETE"))
#expect(instructions.contains("destructive"))
#expect(instructions.contains("confirmation"))
}
@Test("Additional context is appended")
func additionalContextAppended() {
let builder = PromptBuilder(
schema: makeSampleSchema(),
additionalContext: "All dates are stored in ISO 8601 format."
)
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("ADDITIONAL CONTEXT"))
#expect(instructions.contains("ISO 8601"))
}
@Test("No additional context section when nil")
func noAdditionalContextWhenNil() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(!instructions.contains("ADDITIONAL CONTEXT"))
}
@Test("No additional context section when empty string")
func noAdditionalContextWhenEmpty() {
let builder = PromptBuilder(schema: makeSampleSchema(), additionalContext: "")
let instructions = builder.buildSystemInstructions()
#expect(!instructions.contains("ADDITIONAL CONTEXT"))
}
@Test("Empty schema produces valid instructions")
func emptySchemaProducesValidInstructions() {
let builder = PromptBuilder(schema: makeEmptySchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("ROLE"))
#expect(instructions.contains("SQL GENERATION RULES"))
// Schema section should still be present, just empty
#expect(instructions.contains("DATABASE SCHEMA"))
}
// MARK: - User Prompt Tests
@Test("User prompt passes through question directly")
func userPromptPassesThrough() {
let builder = PromptBuilder(schema: makeSampleSchema())
let prompt = builder.buildUserPrompt("How many users signed up this week?")
#expect(prompt == "How many users signed up this week?")
}
// MARK: - Follow-up Prompt Tests
@Test("Follow-up prompt includes previous context")
func followUpPromptIncludesPreviousContext() {
let builder = PromptBuilder(schema: makeSampleSchema())
let prompt = builder.buildFollowUpPrompt(
"Now sort them by name",
previousSQL: "SELECT * FROM users WHERE created_at > date('now', '-7 days')",
previousResultSummary: "Found 42 users who signed up this week"
)
#expect(prompt.contains("Previous query:"))
#expect(prompt.contains("SELECT * FROM users"))
#expect(prompt.contains("Previous result:"))
#expect(prompt.contains("42 users"))
#expect(prompt.contains("Follow-up question:"))
#expect(prompt.contains("sort them by name"))
}
// MARK: - Schema Description Quality
@Test("Schema includes column types and constraints")
func schemaIncludesColumnDetails() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
// Should include type info
#expect(instructions.contains("INTEGER"))
#expect(instructions.contains("TEXT"))
#expect(instructions.contains("REAL"))
// Should include constraints
#expect(instructions.contains("NOT NULL"))
#expect(instructions.contains("PRIMARY KEY"))
}
@Test("Schema includes index information")
func schemaIncludesIndexes() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("INDEX"))
#expect(instructions.contains("idx_users_email"))
}
// MARK: - Sendable Conformance
@Test("PromptBuilder is Sendable")
func promptBuilderIsSendable() async {
let builder = PromptBuilder(schema: makeSampleSchema())
// Verify it can be sent across concurrency boundaries
let instructions = await Task.detached {
builder.buildSystemInstructions()
}.value
#expect(instructions.contains("ROLE"))
}
// MARK: - Custom Allowlist
@Test("Custom allowlist with select and delete only")
func customAllowlist() {
let allowlist = OperationAllowlist([.select, .delete])
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: allowlist)
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("DELETE"))
#expect(instructions.contains("SELECT"))
#expect(instructions.contains("destructive"))
}
}

View File

@@ -0,0 +1,325 @@
// ProviderConfigurationTests.swift
// SwiftDBAI Tests
//
// Tests for ProviderConfiguration verifying all cloud provider configurations
// produce valid LanguageModel instances with correct settings.
import AnyLanguageModel
import Foundation
@testable import SwiftDBAI
import Testing
@Suite("ProviderConfiguration")
struct ProviderConfigurationTests {
// MARK: - OpenAI Configuration
@Test("OpenAI configuration stores provider and model")
func openAIBasicConfiguration() {
let config = ProviderConfiguration.openAI(
apiKey: "sk-test-key-123",
model: "gpt-4o"
)
#expect(config.provider == .openAI)
#expect(config.model == "gpt-4o")
#expect(config.apiKey == "sk-test-key-123")
#expect(config.hasValidAPIKey)
}
@Test("OpenAI configuration produces a valid LanguageModel")
func openAIMakeModel() {
let config = ProviderConfiguration.openAI(
apiKey: "sk-test-key",
model: "gpt-4o-mini"
)
let model = config.makeModel()
#expect(model is OpenAILanguageModel)
}
@Test("OpenAI with custom base URL for compatible services")
func openAICustomBaseURL() {
let customURL = URL(string: "https://my-proxy.example.com/v1/")!
let config = ProviderConfiguration.openAI(
apiKey: "sk-proxy-key",
model: "gpt-4o",
baseURL: customURL
)
#expect(config.baseURL == customURL)
let model = config.makeModel()
#expect(model is OpenAILanguageModel)
}
@Test("OpenAI with Responses API variant")
func openAIResponsesVariant() {
let config = ProviderConfiguration.openAI(
apiKey: "sk-test",
model: "gpt-4o",
variant: .responses
)
#expect(config.openAIVariant == .responses)
let model = config.makeModel()
#expect(model is OpenAILanguageModel)
}
@Test("OpenAI with dynamic key provider captures key by reference")
func openAIDynamicKeyProvider() {
nonisolated(unsafe) var currentKey = "sk-initial"
let config = ProviderConfiguration.openAI(
apiKeyProvider: { currentKey },
model: "gpt-4o"
)
#expect(config.apiKey == "sk-initial")
currentKey = "sk-rotated"
#expect(config.apiKey == "sk-rotated")
}
// MARK: - Anthropic Configuration
@Test("Anthropic configuration stores provider and model")
func anthropicBasicConfiguration() {
let config = ProviderConfiguration.anthropic(
apiKey: "sk-ant-test-key",
model: "claude-sonnet-4-20250514"
)
#expect(config.provider == .anthropic)
#expect(config.model == "claude-sonnet-4-20250514")
#expect(config.apiKey == "sk-ant-test-key")
#expect(config.hasValidAPIKey)
}
@Test("Anthropic configuration produces a valid LanguageModel")
func anthropicMakeModel() {
let config = ProviderConfiguration.anthropic(
apiKey: "sk-ant-test",
model: "claude-sonnet-4-20250514"
)
let model = config.makeModel()
#expect(model is AnthropicLanguageModel)
}
@Test("Anthropic with API version and betas")
func anthropicWithVersionAndBetas() {
let config = ProviderConfiguration.anthropic(
apiKey: "sk-ant-test",
model: "claude-sonnet-4-20250514",
apiVersion: "2024-01-01",
betas: ["computer-use"]
)
#expect(config.apiVersion == "2024-01-01")
#expect(config.betas == ["computer-use"])
let model = config.makeModel()
#expect(model is AnthropicLanguageModel)
}
@Test("Anthropic with dynamic key provider captures key by reference")
func anthropicDynamicKeyProvider() {
nonisolated(unsafe) var currentKey = "sk-ant-initial"
let config = ProviderConfiguration.anthropic(
apiKeyProvider: { currentKey },
model: "claude-sonnet-4-20250514"
)
#expect(config.apiKey == "sk-ant-initial")
currentKey = "sk-ant-rotated"
#expect(config.apiKey == "sk-ant-rotated")
}
// MARK: - Gemini Configuration
@Test("Gemini configuration stores provider and model")
func geminiBasicConfiguration() {
let config = ProviderConfiguration.gemini(
apiKey: "AIzaSyTest123",
model: "gemini-2.0-flash"
)
#expect(config.provider == .gemini)
#expect(config.model == "gemini-2.0-flash")
#expect(config.apiKey == "AIzaSyTest123")
#expect(config.hasValidAPIKey)
}
@Test("Gemini configuration produces a valid LanguageModel")
func geminiMakeModel() {
let config = ProviderConfiguration.gemini(
apiKey: "AIzaSyTest",
model: "gemini-2.0-flash"
)
let model = config.makeModel()
#expect(model is GeminiLanguageModel)
}
@Test("Gemini with custom API version")
func geminiCustomVersion() {
let config = ProviderConfiguration.gemini(
apiKey: "AIzaSyTest",
model: "gemini-2.0-flash",
apiVersion: "v1"
)
#expect(config.apiVersion == "v1")
let model = config.makeModel()
#expect(model is GeminiLanguageModel)
}
@Test("Gemini with dynamic key provider captures key by reference")
func geminiDynamicKeyProvider() {
nonisolated(unsafe) var currentKey = "AIza-initial"
let config = ProviderConfiguration.gemini(
apiKeyProvider: { currentKey },
model: "gemini-2.0-flash"
)
#expect(config.apiKey == "AIza-initial")
currentKey = "AIza-rotated"
#expect(config.apiKey == "AIza-rotated")
}
// MARK: - OpenAI-Compatible Configuration
@Test("OpenAI-compatible configuration with custom base URL")
func openAICompatibleConfiguration() {
let baseURL = URL(string: "https://api.together.xyz/v1/")!
let config = ProviderConfiguration.openAICompatible(
apiKey: "together-key",
model: "meta-llama/Llama-3.1-70B",
baseURL: baseURL
)
#expect(config.provider == .openAICompatible)
#expect(config.model == "meta-llama/Llama-3.1-70B")
#expect(config.baseURL == baseURL)
let model = config.makeModel()
#expect(model is OpenAILanguageModel)
}
@Test("OpenAI-compatible with dynamic key provider")
func openAICompatibleDynamicKey() {
let baseURL = URL(string: "http://localhost:1234/v1/")!
nonisolated(unsafe) var currentKey = "local-key"
let config = ProviderConfiguration.openAICompatible(
apiKeyProvider: { currentKey },
model: "local-model",
baseURL: baseURL
)
#expect(config.apiKey == "local-key")
currentKey = "new-local-key"
#expect(config.apiKey == "new-local-key")
}
// MARK: - API Key Validation
@Test("Empty API key reports invalid")
func emptyAPIKeyInvalid() {
let config = ProviderConfiguration.openAI(
apiKey: "",
model: "gpt-4o"
)
#expect(!config.hasValidAPIKey)
}
@Test("Whitespace-only API key reports invalid")
func whitespaceAPIKeyInvalid() {
let config = ProviderConfiguration.openAI(
apiKey: " \n\t ",
model: "gpt-4o"
)
#expect(!config.hasValidAPIKey)
}
@Test("Non-empty API key reports valid")
func nonEmptyAPIKeyValid() {
let config = ProviderConfiguration.openAI(
apiKey: "x",
model: "gpt-4o"
)
#expect(config.hasValidAPIKey)
}
// MARK: - Environment Variable Configuration
@Test("fromEnvironment creates configuration for each provider")
func fromEnvironmentCreatesConfig() {
let openAI = ProviderConfiguration.fromEnvironment(
provider: .openAI,
environmentVariable: "SWIFTDAI_TEST_OPENAI_KEY",
model: "gpt-4o"
)
#expect(openAI.provider == .openAI)
#expect(openAI.model == "gpt-4o")
let anthropic = ProviderConfiguration.fromEnvironment(
provider: .anthropic,
environmentVariable: "SWIFTDAI_TEST_ANTHROPIC_KEY",
model: "claude-sonnet-4-20250514"
)
#expect(anthropic.provider == .anthropic)
let gemini = ProviderConfiguration.fromEnvironment(
provider: .gemini,
environmentVariable: "SWIFTDAI_TEST_GEMINI_KEY",
model: "gemini-2.0-flash"
)
#expect(gemini.provider == .gemini)
}
@Test("fromEnvironment returns empty key when variable not set")
func fromEnvironmentMissingVariable() {
let config = ProviderConfiguration.fromEnvironment(
provider: .openAI,
environmentVariable: "NONEXISTENT_KEY_VAR_SWIFTDBAI_TEST",
model: "gpt-4o"
)
#expect(!config.hasValidAPIKey)
#expect(config.apiKey == "")
}
// MARK: - Provider Enum
@Test("Provider enum has all expected cases")
func providerCases() {
let cases = ProviderConfiguration.Provider.allCases
#expect(cases.count == 6)
#expect(cases.contains(.openAI))
#expect(cases.contains(.anthropic))
#expect(cases.contains(.gemini))
#expect(cases.contains(.openAICompatible))
#expect(cases.contains(.ollama))
#expect(cases.contains(.llamaCpp))
}
// MARK: - Cross-Provider Model Creation
@Test("All providers produce available models")
func allProvidersCreateAvailableModels() {
let configs: [ProviderConfiguration] = [
.openAI(apiKey: "test", model: "gpt-4o"),
.anthropic(apiKey: "test", model: "claude-sonnet-4-20250514"),
.gemini(apiKey: "test", model: "gemini-2.0-flash"),
.openAICompatible(
apiKey: "test",
model: "local",
baseURL: URL(string: "http://localhost:8080/v1/")!
),
]
for config in configs {
let model = config.makeModel()
#expect(model.isAvailable, "Model for \(config.provider) should be available")
}
}
}

View File

@@ -0,0 +1,397 @@
// SQLQueryParserTests.swift
// SwiftDBAITests
import Testing
@testable import SwiftDBAI
@Suite("SQLQueryParser")
struct SQLQueryParserTests {
let readOnlyParser = SQLQueryParser(allowlist: .readOnly)
let standardParser = SQLQueryParser(allowlist: .standard)
let unrestrictedParser = SQLQueryParser(allowlist: .unrestricted)
// MARK: - Extraction from code blocks
@Test("Extracts SQL from markdown sql code block")
func extractFromSQLCodeBlock() throws {
let text = """
Here's the query to find the top users:
```sql
SELECT name, COUNT(*) as count FROM users GROUP BY name ORDER BY count DESC
```
This will give you the results.
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT name, COUNT(*) as count FROM users GROUP BY name ORDER BY count DESC")
#expect(result.operation == .select)
#expect(result.requiresConfirmation == false)
}
@Test("Extracts SQL from generic code block")
func extractFromGenericCodeBlock() throws {
let text = """
Here you go:
```
SELECT * FROM products WHERE price > 100
```
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM products WHERE price > 100")
}
@Test("Extracts SQL from labeled text")
func extractFromLabel() throws {
let text = """
I can help with that.
SQL: SELECT id, name FROM categories WHERE active = 1
That should work.
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT id, name FROM categories WHERE active = 1")
}
@Test("Extracts direct SQL from plain text")
func extractDirectSQL() throws {
let text = "SELECT COUNT(*) FROM orders WHERE status = 'shipped'"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT COUNT(*) FROM orders WHERE status = 'shipped'")
}
@Test("Handles SQL with trailing semicolons")
func trailingSemicolon() throws {
let text = "```sql\nSELECT * FROM users;\n```"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Handles multiline SQL in code block")
func multilineSQL() throws {
let text = """
```sql
SELECT u.name, COUNT(o.id) as order_count
FROM users u
JOIN orders o ON u.id = o.user_id
GROUP BY u.name
ORDER BY order_count DESC
LIMIT 10
```
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("SELECT u.name"))
#expect(result.sql.contains("LIMIT 10"))
}
@Test("Handles WITH (CTE) queries as SELECT")
func cteQuery() throws {
let text = """
```sql
WITH top_users AS (
SELECT user_id, COUNT(*) as cnt FROM orders GROUP BY user_id
)
SELECT * FROM top_users WHERE cnt > 5
```
"""
let result = try readOnlyParser.parse(text)
#expect(result.operation == .select)
}
// MARK: - No SQL found
@Test("Throws noSQLFound for text without SQL")
func noSQLFound() throws {
let text = "I'm sorry, I can't help with that request."
#expect(throws: SQLParsingError.noSQLFound) {
try readOnlyParser.parse(text)
}
}
@Test("Throws noSQLFound for empty input")
func emptyInput() throws {
#expect(throws: SQLParsingError.noSQLFound) {
try readOnlyParser.parse("")
}
}
// MARK: - Operation detection
@Test("Detects INSERT operation")
func detectInsert() throws {
let text = "```sql\nINSERT INTO users (name) VALUES ('Alice')\n```"
let result = try standardParser.parse(text)
#expect(result.operation == .insert)
}
@Test("Detects UPDATE operation")
func detectUpdate() throws {
let text = "```sql\nUPDATE users SET name = 'Bob' WHERE id = 1\n```"
let result = try standardParser.parse(text)
#expect(result.operation == .update)
}
@Test("Detects DELETE operation and requires confirmation")
func detectDeleteRequiresConfirmation() throws {
let text = "```sql\nDELETE FROM users WHERE id = 99\n```"
let result = try unrestrictedParser.parse(text)
#expect(result.operation == .delete)
#expect(result.requiresConfirmation == true)
}
// MARK: - Allowlist enforcement
@Test("Rejects INSERT on read-only allowlist")
func rejectInsertOnReadOnly() throws {
let text = "```sql\nINSERT INTO users (name) VALUES ('Mallory')\n```"
#expect(throws: SQLParsingError.operationNotAllowed(.insert)) {
try readOnlyParser.parse(text)
}
}
@Test("Rejects UPDATE on read-only allowlist")
func rejectUpdateOnReadOnly() {
let text = "```sql\nUPDATE users SET name = 'Eve' WHERE id = 1\n```"
#expect(throws: SQLParsingError.operationNotAllowed(.update)) {
try readOnlyParser.parse(text)
}
}
@Test("Rejects DELETE on standard allowlist")
func rejectDeleteOnStandard() {
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
#expect(throws: SQLParsingError.operationNotAllowed(.delete)) {
try standardParser.parse(text)
}
}
// MARK: - Dangerous operations
@Test("Rejects DROP TABLE")
func rejectDrop() {
let text = "```sql\nDROP TABLE users\n```"
#expect(throws: SQLParsingError.dangerousOperation("DROP")) {
try unrestrictedParser.parse(text)
}
}
@Test("Rejects ALTER TABLE")
func rejectAlter() {
let text = "```sql\nALTER TABLE users ADD COLUMN age INTEGER\n```"
#expect(throws: SQLParsingError.dangerousOperation("ALTER")) {
try unrestrictedParser.parse(text)
}
}
@Test("Rejects PRAGMA")
func rejectPragma() {
let text = "```sql\nPRAGMA table_info(users)\n```"
#expect(throws: SQLParsingError.dangerousOperation("PRAGMA")) {
try unrestrictedParser.parse(text)
}
}
@Test("Does not match dangerous keywords inside identifiers")
func noFalsePositiveOnSubstring() throws {
// "DROPDOWN" contains "DROP" as substring but is not the keyword
let text = "SELECT dropdown_value FROM settings"
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("dropdown_value"))
}
// MARK: - Multiple statements
@Test("Rejects multiple statements separated by semicolons")
func rejectMultipleStatements() {
let text = "```sql\nSELECT * FROM users; SELECT * FROM orders\n```"
#expect(throws: SQLParsingError.multipleStatements) {
try readOnlyParser.parse(text)
}
}
@Test("Allows semicolons inside string literals")
func allowSemicolonInString() throws {
let text = "SELECT * FROM users WHERE bio = 'hello; world'"
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("hello; world"))
}
// MARK: - ParsedSQL equality
@Test("ParsedSQL equality works")
func parsedSQLEquality() {
let a = ParsedSQL(sql: "SELECT 1", operation: .select)
let b = ParsedSQL(sql: "SELECT 1", operation: .select)
#expect(a == b)
}
// MARK: - Error descriptions
@Test("Error descriptions are meaningful")
func errorDescriptions() {
#expect(SQLParsingError.noSQLFound.description.contains("No SQL"))
#expect(SQLParsingError.operationNotAllowed(.insert).description.contains("INSERT"))
#expect(SQLParsingError.dangerousOperation("DROP").description.contains("DROP"))
#expect(SQLParsingError.multipleStatements.description.contains("single"))
}
// MARK: - MutationPolicy integration
@Test("MutationPolicy allows INSERT on permitted table")
func mutationPolicyAllowsInsertOnPermittedTable() throws {
let policy = MutationPolicy(
allowedOperations: [.insert, .update],
allowedTables: ["orders", "order_items"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nINSERT INTO orders (product, qty) VALUES ('Widget', 3)\n```"
let result = try parser.parse(text)
#expect(result.operation == .insert)
#expect(result.requiresConfirmation == false)
}
@Test("MutationPolicy rejects INSERT on non-permitted table")
func mutationPolicyRejectsInsertOnForbiddenTable() {
let policy = MutationPolicy(
allowedOperations: [.insert, .update],
allowedTables: ["orders"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nINSERT INTO users (name) VALUES ('Alice')\n```"
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .insert)) {
try parser.parse(text)
}
}
@Test("MutationPolicy rejects UPDATE on non-permitted table")
func mutationPolicyRejectsUpdateOnForbiddenTable() {
let policy = MutationPolicy(
allowedOperations: [.insert, .update],
allowedTables: ["orders"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nUPDATE users SET name = 'Bob' WHERE id = 1\n```"
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .update)) {
try parser.parse(text)
}
}
@Test("MutationPolicy rejects DELETE on non-permitted table")
func mutationPolicyRejectsDeleteOnForbiddenTable() {
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
allowedTables: ["temp_data"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nDELETE FROM users WHERE id = 99\n```"
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .delete)) {
try parser.parse(text)
}
}
@Test("MutationPolicy allows mutation on any table when allowedTables is nil")
func mutationPolicyAllowsAllTablesWhenNil() throws {
let policy = MutationPolicy(allowedOperations: [.insert, .update])
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nINSERT INTO any_table (col) VALUES ('val')\n```"
let result = try parser.parse(text)
#expect(result.operation == .insert)
}
@Test("MutationPolicy SELECT is never restricted by table allowlist")
func mutationPolicySelectIgnoresTableRestrictions() throws {
let policy = MutationPolicy(
allowedOperations: [.insert],
allowedTables: ["orders"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
// SELECT from a table NOT in allowedTables should still work
let text = "```sql\nSELECT * FROM users\n```"
let result = try parser.parse(text)
#expect(result.operation == .select)
#expect(result.requiresConfirmation == false)
}
@Test("MutationPolicy DELETE requires confirmation by default")
func mutationPolicyDeleteRequiresConfirmation() throws {
let policy = MutationPolicy(allowedOperations: [.delete])
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
let result = try parser.parse(text)
#expect(result.operation == .delete)
#expect(result.requiresConfirmation == true)
}
@Test("MutationPolicy DELETE skips confirmation when configured")
func mutationPolicyDeleteNoConfirmation() throws {
let policy = MutationPolicy(
allowedOperations: [.delete],
requiresDestructiveConfirmation: false
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
let result = try parser.parse(text)
#expect(result.operation == .delete)
#expect(result.requiresConfirmation == false)
}
@Test("MutationPolicy readOnly preset rejects all mutations")
func mutationPolicyReadOnlyRejectsAll() {
let parser = SQLQueryParser(mutationPolicy: .readOnly)
#expect(throws: SQLParsingError.operationNotAllowed(.insert)) {
try parser.parse("INSERT INTO t (a) VALUES (1)")
}
#expect(throws: SQLParsingError.operationNotAllowed(.update)) {
try parser.parse("UPDATE t SET a = 1")
}
#expect(throws: SQLParsingError.operationNotAllowed(.delete)) {
try parser.parse("DELETE FROM t WHERE id = 1")
}
}
@Test("MutationPolicy table matching is case-insensitive")
func mutationPolicyTableCaseInsensitive() throws {
let policy = MutationPolicy(
allowedOperations: [.insert],
allowedTables: ["Orders"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "INSERT INTO orders (product) VALUES ('Widget')"
let result = try parser.parse(text)
#expect(result.operation == .insert)
}
@Test("MutationPolicy handles quoted table names")
func mutationPolicyQuotedTableNames() throws {
let policy = MutationPolicy(
allowedOperations: [.insert, .update],
allowedTables: ["order_items"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
// Backtick-quoted
let backtick = "INSERT INTO `order_items` (qty) VALUES (5)"
let r1 = try parser.parse(backtick)
#expect(r1.operation == .insert)
// Double-quote-quoted
let doubleQuote = "UPDATE \"order_items\" SET qty = 10 WHERE id = 1"
let r2 = try parser.parse(doubleQuote)
#expect(r2.operation == .update)
}
@Test("Error description for tableNotAllowed is meaningful")
func tableNotAllowedDescription() {
let error = SQLParsingError.tableNotAllowed(table: "secret", operation: .delete)
#expect(error.description.contains("secret"))
#expect(error.description.contains("DELETE"))
}
@Test("Error description for confirmationRequired is meaningful")
func confirmationRequiredDescription() {
let error = SQLParsingError.confirmationRequired(sql: "DELETE FROM x", operation: .delete)
#expect(error.description.contains("DELETE"))
#expect(error.description.contains("confirmation"))
}
}

View File

@@ -0,0 +1,234 @@
// SchemaIntrospectorTests.swift
// SwiftDBAI
import Testing
import GRDB
@testable import SwiftDBAI
@Suite("SchemaIntrospector")
struct SchemaIntrospectorTests {
// MARK: - Helper
/// Creates an in-memory database with a sample schema for testing.
private func makeTestDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue(configuration: {
var config = Configuration()
config.foreignKeysEnabled = true
return config
}())
try db.write { db in
try db.execute(sql: """
CREATE TABLE authors (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT UNIQUE
);
""")
try db.execute(sql: """
CREATE TABLE books (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
author_id INTEGER NOT NULL REFERENCES authors(id) ON DELETE CASCADE,
published_date TEXT,
price REAL DEFAULT 9.99
);
""")
try db.execute(sql: """
CREATE INDEX idx_books_author ON books(author_id);
""")
try db.execute(sql: """
CREATE INDEX idx_books_title ON books(title);
""")
try db.execute(sql: """
CREATE TABLE reviews (
id INTEGER PRIMARY KEY,
book_id INTEGER NOT NULL REFERENCES books(id),
rating INTEGER NOT NULL,
comment TEXT
);
""")
}
return db
}
// MARK: - Tests
@Test("Discovers all user tables")
func discoversAllTables() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
#expect(schema.tableNames.count == 3)
#expect(schema.tableNames.contains("authors"))
#expect(schema.tableNames.contains("books"))
#expect(schema.tableNames.contains("reviews"))
}
@Test("Excludes sqlite_ internal tables")
func excludesInternalTables() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
for name in schema.tableNames {
#expect(!name.hasPrefix("sqlite_"))
}
}
@Test("Introspects column names and types")
func introspectsColumns() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let books = try #require(schema.tables["books"])
#expect(books.columns.count == 5)
let titleCol = try #require(books.columns.first { $0.name == "title" })
#expect(titleCol.type == "TEXT")
#expect(titleCol.isNotNull == true)
#expect(titleCol.isPrimaryKey == false)
let priceCol = try #require(books.columns.first { $0.name == "price" })
#expect(priceCol.type == "REAL")
#expect(priceCol.defaultValue == "9.99")
}
@Test("Detects primary keys")
func detectsPrimaryKeys() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let authors = try #require(schema.tables["authors"])
#expect(authors.primaryKey == ["id"])
let idCol = try #require(authors.columns.first { $0.name == "id" })
#expect(idCol.isPrimaryKey == true)
}
@Test("Detects foreign keys")
func detectsForeignKeys() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let books = try #require(schema.tables["books"])
#expect(books.foreignKeys.count == 1)
let fk = books.foreignKeys[0]
#expect(fk.fromColumn == "author_id")
#expect(fk.toTable == "authors")
#expect(fk.toColumn == "id")
#expect(fk.onDelete == "CASCADE")
}
@Test("Detects indexes")
func detectsIndexes() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let books = try #require(schema.tables["books"])
let indexNames = books.indexes.map(\.name)
#expect(indexNames.contains("idx_books_author"))
#expect(indexNames.contains("idx_books_title"))
}
@Test("Detects NOT NULL constraints")
func detectsNotNull() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let reviews = try #require(schema.tables["reviews"])
let ratingCol = try #require(reviews.columns.first { $0.name == "rating" })
#expect(ratingCol.isNotNull == true)
let commentCol = try #require(reviews.columns.first { $0.name == "comment" })
#expect(commentCol.isNotNull == false)
}
@Test("Generates LLM-friendly schema description")
func generatesSchemaDescription() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let description = schema.schemaDescription
#expect(description.contains("TABLE authors"))
#expect(description.contains("TABLE books"))
#expect(description.contains("FOREIGN KEY"))
#expect(description.contains("REFERENCES authors(id)"))
#expect(description.contains("INDEX idx_books_author"))
}
@Test("Handles empty database")
func handlesEmptyDatabase() async throws {
let db = try DatabaseQueue()
let schema = try await SchemaIntrospector.introspect(database: db)
#expect(schema.tables.isEmpty)
#expect(schema.tableNames.isEmpty)
#expect(schema.schemaDescription.isEmpty)
}
@Test("Handles composite primary keys")
func handlesCompositePrimaryKey() async throws {
let db = try DatabaseQueue()
try await db.write { db in
try db.execute(sql: """
CREATE TABLE book_tags (
book_id INTEGER NOT NULL,
tag_id INTEGER NOT NULL,
PRIMARY KEY (book_id, tag_id)
);
""")
}
let schema = try await SchemaIntrospector.introspect(database: db)
let bookTags = try #require(schema.tables["book_tags"])
#expect(bookTags.primaryKey.count == 2)
#expect(bookTags.primaryKey.contains("book_id"))
#expect(bookTags.primaryKey.contains("tag_id"))
}
@Test("Handles tables with no explicit types (SQLite dynamic typing)")
func handlesDynamicTyping() async throws {
let db = try DatabaseQueue()
try await db.write { db in
try db.execute(sql: """
CREATE TABLE flexible (
id INTEGER PRIMARY KEY,
data,
info BLOB
);
""")
}
let schema = try await SchemaIntrospector.introspect(database: db)
let flexible = try #require(schema.tables["flexible"])
let dataCol = try #require(flexible.columns.first { $0.name == "data" })
#expect(dataCol.type == "") // No declared type
let infoCol = try #require(flexible.columns.first { $0.name == "info" })
#expect(infoCol.type == "BLOB")
}
@Test("Synchronous introspection works within database access")
func synchronousIntrospection() async throws {
let db = try DatabaseQueue()
try await db.write { db in
try db.execute(sql: "CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT);")
}
let schema = try await db.read { db in
try SchemaIntrospector.introspect(db: db)
}
#expect(schema.tableNames == ["test"])
let table = try #require(schema.tables["test"])
#expect(table.columns.count == 2)
}
}

View File

@@ -0,0 +1,133 @@
// ScrollableDataTableViewTests.swift
// SwiftDBAITests
//
// Tests for the ScrollableDataTableView component.
import Foundation
import Testing
@testable import SwiftDBAI
@Suite("ScrollableDataTableView")
@MainActor
struct ScrollableDataTableViewTests {
// MARK: - Test Helpers
private func makeDataTable(
columnNames: [String] = ["id", "name", "score"],
inferredTypes: [DataTable.InferredType] = [.integer, .text, .real],
rowCount: Int = 5
) -> DataTable {
let columns = columnNames.enumerated().map { idx, name in
DataTable.Column(name: name, index: idx, inferredType: inferredTypes[idx])
}
let rows = (0..<rowCount).map { i in
DataTable.Row(
id: i,
values: [
.integer(Int64(i + 1)),
.text("Item \(i + 1)"),
.real(Double(i) * 10.5),
],
columnNames: columnNames
)
}
return DataTable(columns: columns, rows: rows, sql: "SELECT * FROM test", executionTime: 0.015)
}
private func makeEmptyDataTable() -> DataTable {
DataTable(columns: [], rows: [], sql: "", executionTime: 0)
}
// MARK: - Initialization Tests
@Test("Initializes with default parameters")
func initWithDefaults() {
let table = makeDataTable()
let view = ScrollableDataTableView(dataTable: table)
#expect(view.minimumColumnWidth == 80)
#expect(view.maximumColumnWidth == 250)
#expect(view.showAlternatingRows == true)
#expect(view.showFooter == true)
}
@Test("Initializes with custom parameters")
func initWithCustomParams() {
let table = makeDataTable()
let view = ScrollableDataTableView(
dataTable: table,
minimumColumnWidth: 100,
maximumColumnWidth: 300,
showAlternatingRows: false,
showFooter: false
)
#expect(view.minimumColumnWidth == 100)
#expect(view.maximumColumnWidth == 300)
#expect(view.showAlternatingRows == false)
#expect(view.showFooter == false)
}
@Test("Handles empty data table")
func handlesEmptyTable() {
let table = makeEmptyDataTable()
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.isEmpty)
}
@Test("Handles single row table")
func handlesSingleRow() {
let table = makeDataTable(rowCount: 1)
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.rowCount == 1)
#expect(view.dataTable.columnCount == 3)
}
@Test("Handles single column table")
func handlesSingleColumn() {
let columns = [DataTable.Column(name: "count", index: 0, inferredType: .integer)]
let rows = [
DataTable.Row(id: 0, values: [.integer(42)], columnNames: ["count"])
]
let table = DataTable(columns: columns, rows: rows, sql: "SELECT count(*) FROM t", executionTime: 0.001)
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.columnCount == 1)
#expect(view.dataTable.rowCount == 1)
}
@Test("Handles large number of rows")
func handlesLargeRowCount() {
let table = makeDataTable(rowCount: 1000)
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.rowCount == 1000)
}
@Test("Handles null values in cells")
func handlesNullValues() {
let columns = [
DataTable.Column(name: "name", index: 0, inferredType: .text),
DataTable.Column(name: "value", index: 1, inferredType: .null),
]
let rows = [
DataTable.Row(id: 0, values: [.text("test"), .null], columnNames: ["name", "value"])
]
let table = DataTable(columns: columns, rows: rows)
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.rows[0][1] == .null)
}
@Test("Handles blob values in cells")
func handlesBlobValues() {
let columns = [
DataTable.Column(name: "data", index: 0, inferredType: .blob),
]
let blobData = Data([0x00, 0xFF, 0xAB])
let rows = [
DataTable.Row(id: 0, values: [.blob(blobData)], columnNames: ["data"])
]
let table = DataTable(columns: columns, rows: rows)
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.rows[0][0] == QueryResult.Value.blob(blobData))
}
}

View File

@@ -0,0 +1,301 @@
// TextSummaryRendererTests.swift
// SwiftDBAI
import AnyLanguageModel
import Testing
import Foundation
@testable import SwiftDBAI
@Suite("TextSummaryRenderer")
struct TextSummaryRendererTests {
// MARK: - QueryResult.Value Tests
@Test("Value description renders correctly")
func valueDescriptions() {
#expect(QueryResult.Value.text("hello").description == "hello")
#expect(QueryResult.Value.integer(42).description == "42")
#expect(QueryResult.Value.real(3.14).description == "3.14")
#expect(QueryResult.Value.null.description == "NULL")
#expect(QueryResult.Value.blob(Data([0x01, 0x02])).description == "<2 bytes>")
}
@Test("Value doubleValue extracts numeric values")
func valueDoubleValues() {
#expect(QueryResult.Value.integer(42).doubleValue == 42.0)
#expect(QueryResult.Value.real(3.14).doubleValue == 3.14)
#expect(QueryResult.Value.text("100").doubleValue == 100.0)
#expect(QueryResult.Value.text("not a number").doubleValue == nil)
#expect(QueryResult.Value.null.doubleValue == nil)
#expect(QueryResult.Value.blob(Data()).doubleValue == nil)
}
@Test("Value isNull works correctly")
func valueIsNull() {
#expect(QueryResult.Value.null.isNull == true)
#expect(QueryResult.Value.text("").isNull == false)
#expect(QueryResult.Value.integer(0).isNull == false)
}
// MARK: - QueryResult Tests
@Test("Empty result has correct properties")
func emptyResult() {
let result = QueryResult(
columns: ["id", "name"],
rows: [],
sql: "SELECT id, name FROM users",
executionTime: 0.01
)
#expect(result.rowCount == 0)
#expect(result.isAggregate == false)
#expect(result.tabularDescription == "(empty result set)")
}
@Test("Single aggregate result is detected")
func aggregateDetection() {
let result = QueryResult(
columns: ["COUNT(*)"],
rows: [["COUNT(*)": .integer(42)]],
sql: "SELECT COUNT(*) FROM users",
executionTime: 0.01
)
#expect(result.isAggregate == true)
}
@Test("Multi-row result is not aggregate")
func nonAggregateDetection() {
let result = QueryResult(
columns: ["name"],
rows: [
["name": .text("Alice")],
["name": .text("Bob")],
],
sql: "SELECT name FROM users",
executionTime: 0.01
)
#expect(result.isAggregate == false)
}
@Test("Tabular description formats correctly")
func tabularDescription() {
let result = QueryResult(
columns: ["id", "name"],
rows: [
["id": .integer(1), "name": .text("Alice")],
["id": .integer(2), "name": .text("Bob")],
],
sql: "SELECT id, name FROM users",
executionTime: 0.01
)
let desc = result.tabularDescription
#expect(desc.contains("id | name"))
#expect(desc.contains("1 | Alice"))
#expect(desc.contains("2 | Bob"))
}
@Test("values(forColumn:) extracts column values")
func valuesForColumn() {
let result = QueryResult(
columns: ["name"],
rows: [
["name": .text("Alice")],
["name": .text("Bob")],
],
sql: "SELECT name FROM users",
executionTime: 0.01
)
let values = result.values(forColumn: "name")
#expect(values.count == 2)
#expect(values[0] == .text("Alice"))
}
// MARK: - Local Summary Tests (no LLM required)
@Test("Local summary for empty result")
func localSummaryEmpty() {
let result = makeResult(columns: ["id"], rows: [])
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Any users?")
#expect(summary == "No results found for your query.")
}
@Test("Local summary for single aggregate")
func localSummarySingleAggregate() {
let result = makeResult(
columns: ["COUNT(*)"],
rows: [["COUNT(*)": .integer(42)]]
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "How many?")
#expect(summary.contains("42"))
}
@Test("Local summary for multiple aggregates")
func localSummaryMultipleAggregates() {
let result = makeResult(
columns: ["COUNT(*)", "AVG(price)"],
rows: [["COUNT(*)": .integer(10), "AVG(price)": .real(25.5)]]
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Stats?")
#expect(summary.contains("count"))
#expect(summary.contains("average price"))
}
@Test("Local summary for single record")
func localSummarySingleRecord() {
let result = makeResult(
columns: ["name", "email"],
rows: [["name": .text("Alice"), "email": .text("alice@example.com")]]
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Who?")
#expect(summary.contains("1 result"))
#expect(summary.contains("Alice"))
}
@Test("Local summary for multiple records with name column")
func localSummaryMultipleWithNames() {
let result = makeResult(
columns: ["name", "age"],
rows: [
["name": .text("Alice"), "age": .integer(30)],
["name": .text("Bob"), "age": .integer(25)],
["name": .text("Charlie"), "age": .integer(35)],
["name": .text("Diana"), "age": .integer(28)],
]
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "List users")
#expect(summary.contains("4 results"))
#expect(summary.contains("Alice"))
#expect(summary.contains("1 more"))
}
@Test("Local summary for mutation result")
func localSummaryMutation() {
let result = QueryResult(
columns: [],
rows: [],
sql: "INSERT INTO users (name) VALUES ('Test')",
executionTime: 0.01,
rowsAffected: 1
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Add user")
#expect(summary == "Successfully inserted 1 row.")
}
@Test("Local summary for delete mutation")
func localSummaryDelete() {
let result = QueryResult(
columns: [],
rows: [],
sql: "DELETE FROM users WHERE id = 5",
executionTime: 0.01,
rowsAffected: 3
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Delete old users")
#expect(summary == "Successfully deleted 3 rows.")
}
@Test("Local summary for update mutation")
func localSummaryUpdate() {
let result = QueryResult(
columns: [],
rows: [],
sql: "UPDATE users SET active = 0 WHERE id = 1",
executionTime: 0.01,
rowsAffected: 1
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Deactivate user")
#expect(summary == "Successfully updated 1 row.")
}
// MARK: - LLM-based Summary Tests (using MockLanguageModel)
@Test("Summarize with LLM returns mock response for multi-row results")
func summarizeWithLLM() async throws {
let result = makeResult(
columns: ["name", "age"],
rows: [
["name": .text("Alice"), "age": .integer(30)],
["name": .text("Bob"), "age": .integer(25)],
]
)
let mockModel = MockLanguageModel(responseText: "There are 2 users: Alice (30) and Bob (25).")
let renderer = TextSummaryRenderer(model: mockModel)
let summary = try await renderer.summarize(result: result, userQuestion: "List all users")
#expect(summary == "There are 2 users: Alice (30) and Bob (25).")
}
@Test("Summarize returns empty result message without calling LLM")
func summarizeEmptyResult() async throws {
let result = makeResult(columns: ["id"], rows: [])
let renderer = makeMockRenderer()
let summary = try await renderer.summarize(result: result, userQuestion: "Find users")
#expect(summary == "No results found for your query.")
}
@Test("Summarize returns direct aggregate without calling LLM")
func summarizeAggregate() async throws {
let result = makeResult(
columns: ["COUNT(*)"],
rows: [["COUNT(*)": .integer(42)]]
)
let renderer = makeMockRenderer()
let summary = try await renderer.summarize(result: result, userQuestion: "How many?")
#expect(summary.contains("42"))
}
@Test("Summarize mutation returns template without calling LLM")
func summarizeMutation() async throws {
let result = QueryResult(
columns: [],
rows: [],
sql: "UPDATE users SET name = 'Test' WHERE id = 1",
executionTime: 0.01,
rowsAffected: 1
)
let renderer = makeMockRenderer()
let summary = try await renderer.summarize(result: result, userQuestion: "Update user")
#expect(summary == "Successfully updated 1 row.")
}
@Test("Summarize passes context to LLM prompt")
func summarizeWithContext() async throws {
let result = makeResult(
columns: ["total"],
rows: [
["total": .real(100.0)],
["total": .real(200.0)],
]
)
let mockModel = MockLanguageModel(responseText: "The totals are 100 and 200.")
let renderer = TextSummaryRenderer(model: mockModel)
let summary = try await renderer.summarize(
result: result,
userQuestion: "Show totals",
context: "Amounts are in USD"
)
#expect(summary == "The totals are 100 and 200.")
}
// MARK: - Helpers
private func makeResult(
columns: [String],
rows: [[String: QueryResult.Value]],
sql: String = "SELECT * FROM test"
) -> QueryResult {
QueryResult(columns: columns, rows: rows, sql: sql, executionTime: 0.01)
}
/// Creates a renderer with a mock model (for localSummary tests that don't hit the LLM).
private func makeMockRenderer() -> TextSummaryRenderer {
TextSummaryRenderer(model: MockLanguageModel())
}
}

View File

@@ -0,0 +1,246 @@
// ToolExecutionDelegateTests.swift
// SwiftDBAITests
import Foundation
import Testing
@testable import SwiftDBAI
@Suite("DestructiveClassification")
struct DestructiveClassificationTests {
// MARK: - Safe statements
@Test("SELECT is classified as safe")
func selectIsSafe() {
let result = classifySQL("SELECT * FROM users")
#expect(result == .safe)
#expect(!result.requiresConfirmation)
#expect(!result.isMutating)
}
@Test("WITH (CTE) is classified as safe")
func withIsSafe() {
let result = classifySQL("WITH cte AS (SELECT 1) SELECT * FROM cte")
#expect(result == .safe)
}
// MARK: - Mutation statements
@Test("INSERT is classified as mutation")
func insertIsMutation() {
let result = classifySQL("INSERT INTO users (name) VALUES ('Alice')")
#expect(result == .mutation(.insert))
#expect(!result.requiresConfirmation)
#expect(result.isMutating)
}
@Test("UPDATE is classified as mutation")
func updateIsMutation() {
let result = classifySQL("UPDATE users SET name = 'Bob' WHERE id = 1")
#expect(result == .mutation(.update))
#expect(!result.requiresConfirmation)
#expect(result.isMutating)
}
// MARK: - Destructive statements
@Test("DELETE is classified as destructive")
func deleteIsDestructive() {
let result = classifySQL("DELETE FROM users WHERE id = 1")
#expect(result == .destructive(.delete))
#expect(result.requiresConfirmation)
#expect(result.isMutating)
}
@Test("DROP is classified as destructive")
func dropIsDestructive() {
let result = classifySQL("DROP TABLE users")
#expect(result == .destructive(.drop))
#expect(result.requiresConfirmation)
}
@Test("ALTER is classified as destructive")
func alterIsDestructive() {
let result = classifySQL("ALTER TABLE users ADD COLUMN age INTEGER")
#expect(result == .destructive(.alter))
#expect(result.requiresConfirmation)
}
@Test("TRUNCATE is classified as destructive")
func truncateIsDestructive() {
let result = classifySQL("TRUNCATE TABLE users")
#expect(result == .destructive(.truncate))
#expect(result.requiresConfirmation)
}
// MARK: - Case insensitivity
@Test("Classification is case-insensitive")
func caseInsensitive() {
#expect(classifySQL("delete from users") == .destructive(.delete))
#expect(classifySQL("Drop Table foo") == .destructive(.drop))
#expect(classifySQL("select 1") == .safe)
#expect(classifySQL("INSERT into t values (1)") == .mutation(.insert))
}
// MARK: - Leading whitespace
@Test("Classification ignores leading whitespace")
func leadingWhitespace() {
#expect(classifySQL(" \n DELETE FROM users") == .destructive(.delete))
#expect(classifySQL("\t SELECT 1") == .safe)
}
// MARK: - SQLStatementKind
@Test("Destructive kinds are correct")
func destructiveKinds() {
#expect(SQLStatementKind.delete.isDestructive)
#expect(SQLStatementKind.drop.isDestructive)
#expect(SQLStatementKind.alter.isDestructive)
#expect(SQLStatementKind.truncate.isDestructive)
#expect(!SQLStatementKind.select.isDestructive)
#expect(!SQLStatementKind.insert.isDestructive)
#expect(!SQLStatementKind.update.isDestructive)
}
@Test("Mutation kinds are correct")
func mutationKinds() {
#expect(SQLStatementKind.insert.isMutation)
#expect(SQLStatementKind.update.isMutation)
#expect(!SQLStatementKind.select.isMutation)
#expect(!SQLStatementKind.delete.isMutation)
}
}
@Suite("ToolExecutionDelegate")
struct ToolExecutionDelegateProtocolTests {
@Test("AutoApproveDelegate approves all operations")
func autoApprove() async {
let delegate = AutoApproveDelegate()
let context = DestructiveOperationContext(
sql: "DELETE FROM users",
statementKind: .delete,
classification: .destructive(.delete),
description: "Delete all rows from users"
)
let result = await delegate.confirmDestructiveOperation(context)
#expect(result == true)
}
@Test("RejectAllDelegate rejects all operations")
func rejectAll() async {
let delegate = RejectAllDelegate()
let context = DestructiveOperationContext(
sql: "DROP TABLE users",
statementKind: .drop,
classification: .destructive(.drop),
description: "Drop the users table"
)
let result = await delegate.confirmDestructiveOperation(context)
#expect(result == false)
}
@Test("Default delegate implementation rejects destructive operations")
func defaultRejects() async {
struct EmptyDelegate: ToolExecutionDelegate {}
let delegate = EmptyDelegate()
let context = DestructiveOperationContext(
sql: "DELETE FROM users",
statementKind: .delete,
classification: .destructive(.delete),
description: "Delete rows"
)
let result = await delegate.confirmDestructiveOperation(context)
#expect(result == false)
}
}
// MARK: - Tracking Delegate for Integration Tests
/// A delegate that records all calls for verification in tests.
private final class TrackingDelegate: ToolExecutionDelegate, @unchecked Sendable {
private let lock = NSLock()
private var _confirmCalls: [DestructiveOperationContext] = []
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
private var _confirmResult: Bool
var confirmCalls: [DestructiveOperationContext] {
lock.withLock { _confirmCalls }
}
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
lock.withLock { _willExecuteCalls }
}
var didExecuteCalls: [(sql: String, success: Bool)] {
lock.withLock { _didExecuteCalls }
}
init(confirmResult: Bool) {
self._confirmResult = confirmResult
}
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
lock.withLock { _confirmCalls.append(context) }
return _confirmResult
}
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
}
func didExecuteSQL(_ sql: String, success: Bool) async {
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
}
}
@Suite("ToolExecutionDelegate - ChatEngine Integration")
struct DelegateIntegrationTests {
@Test("DestructiveOperationContext captures target table")
func contextCapturesTable() {
let context = DestructiveOperationContext(
sql: "DELETE FROM users WHERE id = 1",
statementKind: .delete,
classification: .destructive(.delete),
description: "Delete from users",
targetTable: "users"
)
#expect(context.targetTable == "users")
#expect(context.statementKind == .delete)
#expect(context.classification.requiresConfirmation)
}
@Test("classifySQL returns destructive for DELETE")
func classifySQLDestructive() {
let result = classifySQL("DELETE FROM orders WHERE id = 5")
#expect(result == .destructive(.delete))
#expect(result.requiresConfirmation)
}
@Test("classifySQL returns safe for SELECT")
func classifySQLSafe() {
let result = classifySQL("SELECT * FROM users")
#expect(result == .safe)
#expect(!result.requiresConfirmation)
}
@Test("classifySQL returns mutation for INSERT")
func classifySQLMutation() {
let result = classifySQL("INSERT INTO users (name) VALUES ('test')")
#expect(result == .mutation(.insert))
#expect(!result.requiresConfirmation)
}
@Test("DestructiveClassification.isMutating is true for mutations and destructive")
func isMutatingCovers() {
#expect(DestructiveClassification.mutation(.insert).isMutating)
#expect(DestructiveClassification.mutation(.update).isMutating)
#expect(DestructiveClassification.destructive(.delete).isMutating)
#expect(!DestructiveClassification.safe.isMutating)
}
}

View File

@@ -0,0 +1,617 @@
// UnifiedProviderTestHarness.swift
// SwiftDBAI Tests
//
// A unified test harness that validates all seven provider types
// conform to the AnyLanguageModel protocol and produce consistent
// ChatEngine-compatible output. Covers: OpenAI, Anthropic, Gemini,
// OpenAI-Compatible, Ollama, llama.cpp, and on-device (MLX/CoreML).
import AnyLanguageModel
import Foundation
import GRDB
import Testing
@testable import SwiftDBAI
// MARK: - Provider-Simulating Mock Models
/// A mock that records which LanguageModel protocol methods were called,
/// the arguments passed, and returns configurable responses.
/// Used to validate that every provider path through ChatEngine
/// exercises the same protocol surface.
final class ProviderConformanceMock: LanguageModel, @unchecked Sendable {
typealias UnavailableReason = Never
/// Track calls to verify protocol conformance exercised fully.
struct CallRecord: Sendable {
let method: String
let promptDescription: String
let timestamp: Date
}
private let lock = NSLock()
private var _calls: [CallRecord] = []
private let _responses: [String]
private var _callIndex = 0
/// Label for diagnostics.
let providerName: String
var calls: [CallRecord] {
lock.lock()
defer { lock.unlock() }
return _calls
}
init(providerName: String, responses: [String]) {
self.providerName = providerName
self._responses = responses
}
private func nextResponse() -> String {
lock.lock()
defer { lock.unlock() }
let idx = _callIndex
_callIndex += 1
return idx < _responses.count ? _responses[idx] : "fallback response"
}
private func recordCall(method: String, prompt: String) {
lock.lock()
_calls.append(CallRecord(method: method, promptDescription: prompt, timestamp: Date()))
lock.unlock()
}
func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
recordCall(method: "respond", prompt: prompt.description)
let text = nextResponse()
let rawContent = GeneratedContent(kind: .string(text))
let content = try Content(rawContent)
return LanguageModelSession.Response(
content: content,
rawContent: rawContent,
transcriptEntries: [][...]
)
}
func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
recordCall(method: "streamResponse", prompt: prompt.description)
let text = nextResponse()
let rawContent = GeneratedContent(kind: .string(text))
let content = try! Content(rawContent)
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
}
}
// MARK: - Test Database Helper
/// Creates a minimal in-memory database for provider integration tests.
private func makeProviderTestDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue(path: ":memory:")
try db.write { db in
try db.execute(sql: """
CREATE TABLE products (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
price REAL NOT NULL,
category TEXT NOT NULL
)
""")
try db.execute(sql: """
INSERT INTO products (name, price, category) VALUES
('Widget', 9.99, 'tools'),
('Gadget', 24.99, 'electronics'),
('Doohickey', 4.50, 'tools')
""")
}
return db
}
// MARK: - Unified Provider Test Harness
@Suite("Unified Provider Test Harness")
struct UnifiedProviderTestHarness {
// MARK: - Provider Configuration Enumeration
/// All seven provider types that SwiftDBAI supports.
enum TestedProvider: String, CaseIterable {
case openAI
case anthropic
case gemini
case openAICompatible
case ollama
case llamaCpp
case onDevice
}
/// Creates a ProviderConformanceMock simulating each provider type.
private func makeMock(for provider: TestedProvider, responses: [String]) -> ProviderConformanceMock {
ProviderConformanceMock(providerName: provider.rawValue, responses: responses)
}
// MARK: - 1. Protocol Conformance All Providers Are LanguageModel
@Test("All provider types produce instances conforming to LanguageModel protocol")
func allProvidersConformToLanguageModel() {
// Cloud providers via ProviderConfiguration.makeModel()
let openAI = ProviderConfiguration.openAI(apiKey: "test-key", model: "gpt-4o").makeModel()
let anthropic = ProviderConfiguration.anthropic(apiKey: "test-key", model: "claude-sonnet-4-20250514").makeModel()
let gemini = ProviderConfiguration.gemini(apiKey: "test-key", model: "gemini-2.0-flash").makeModel()
let openAICompatible = ProviderConfiguration.openAICompatible(
apiKey: "test-key",
model: "local-model",
baseURL: URL(string: "http://localhost:8080/v1/")!
).makeModel()
let ollama = ProviderConfiguration.ollama(model: "llama3.2").makeModel()
let llamaCpp = ProviderConfiguration.llamaCpp(model: "default").makeModel()
// On-device MLX (wraps as openAICompatible internally)
let onDeviceMLX = ProviderConfiguration.onDeviceMLX(
MLXProviderConfiguration(modelId: "test-model")
).makeModel()
// Verify all are LanguageModel
let models: [(String, any LanguageModel)] = [
("OpenAI", openAI),
("Anthropic", anthropic),
("Gemini", gemini),
("OpenAI-Compatible", openAICompatible),
("Ollama", ollama),
("llama.cpp", llamaCpp),
("On-Device MLX", onDeviceMLX),
]
for (name, model) in models {
// Protocol conformance is compile-time, but we verify isAvailable works
#expect(model.isAvailable, "\(name) model should report as available")
}
}
@Test("All provider configurations produce correct concrete model types")
func providerConfigurationsProduceCorrectTypes() {
let openAI = ProviderConfiguration.openAI(apiKey: "k", model: "m").makeModel()
#expect(openAI is OpenAILanguageModel, "OpenAI config should produce OpenAILanguageModel")
let anthropic = ProviderConfiguration.anthropic(apiKey: "k", model: "m").makeModel()
#expect(anthropic is AnthropicLanguageModel, "Anthropic config should produce AnthropicLanguageModel")
let gemini = ProviderConfiguration.gemini(apiKey: "k", model: "m").makeModel()
#expect(gemini is GeminiLanguageModel, "Gemini config should produce GeminiLanguageModel")
let openAICompat = ProviderConfiguration.openAICompatible(
apiKey: "k", model: "m", baseURL: URL(string: "http://localhost:1234")!
).makeModel()
#expect(openAICompat is OpenAILanguageModel, "OpenAI-Compatible config should produce OpenAILanguageModel")
let ollama = ProviderConfiguration.ollama(model: "m").makeModel()
#expect(ollama is OllamaLanguageModel, "Ollama config should produce OllamaLanguageModel")
let llamaCpp = ProviderConfiguration.llamaCpp(model: "m").makeModel()
#expect(llamaCpp is OpenAILanguageModel, "llama.cpp config should produce OpenAILanguageModel (OpenAI-compatible)")
// On-device uses OpenAILanguageModel internally as a wrapper
let onDevice = ProviderConfiguration.onDeviceMLX(
MLXProviderConfiguration(modelId: "test")
).makeModel()
#expect(onDevice is OpenAILanguageModel, "On-device MLX config should produce OpenAILanguageModel wrapper")
}
// MARK: - 2. Consistent ChatEngine-Compatible Output
@Test("Every provider mock produces valid ChatEngine responses for SELECT queries",
arguments: TestedProvider.allCases)
func providerProducesValidChatEngineResponse(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: [
"SELECT COUNT(*) FROM products", // SQL generation
"There are 3 products in the database.", // Summary (fallback)
])
let engine = ChatEngine(database: db, model: mock)
let response = try await engine.send("How many products are there?")
// All providers must produce:
// 1. Non-empty summary
#expect(!response.summary.isEmpty, "\(provider.rawValue): summary must not be empty")
// 2. Valid SQL that was executed
#expect(response.sql == "SELECT COUNT(*) FROM products",
"\(provider.rawValue): SQL must match generated query")
// 3. A QueryResult with data
#expect(response.queryResult != nil, "\(provider.rawValue): queryResult must exist")
#expect(response.queryResult?.rowCount == 1, "\(provider.rawValue): should have 1 row for COUNT")
}
@Test("Every provider mock produces valid ChatEngine responses for multi-row SELECT",
arguments: TestedProvider.allCases)
func providerProducesMultiRowResponse(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: [
"SELECT name, price FROM products ORDER BY price DESC",
"Here are the products sorted by price.",
])
let engine = ChatEngine(database: db, model: mock)
let response = try await engine.send("List products by price")
#expect(response.queryResult != nil, "\(provider.rawValue): queryResult must exist")
#expect(response.queryResult?.rowCount == 3, "\(provider.rawValue): should return all 3 products")
#expect(response.queryResult?.columns.contains("name") == true,
"\(provider.rawValue): columns must include 'name'")
#expect(response.queryResult?.columns.contains("price") == true,
"\(provider.rawValue): columns must include 'price'")
}
// MARK: - 3. Consistent LanguageModelSession Integration
@Test("Every provider mock works through LanguageModelSession.respond(to:)",
arguments: TestedProvider.allCases)
func providerWorksWithSession(provider: TestedProvider) async throws {
let mock = makeMock(for: provider, responses: [
"SELECT 1 AS test",
])
let session = LanguageModelSession(
model: mock,
instructions: "You are a SQL assistant."
)
let response = try await session.respond(to: "Generate a test query")
// Verify the response content is the expected string
#expect(response.content == "SELECT 1 AS test",
"\(provider.rawValue): session response should match mock output")
// Verify the mock received the call
#expect(mock.calls.count == 1, "\(provider.rawValue): should have exactly 1 call")
#expect(mock.calls.first?.method == "respond",
"\(provider.rawValue): should call respond method")
}
@Test("Every provider mock works through LanguageModelSession.streamResponse(to:)",
arguments: TestedProvider.allCases)
func providerWorksWithStreamSession(provider: TestedProvider) async throws {
let mock = makeMock(for: provider, responses: [
"SELECT 42 AS answer",
])
let session = LanguageModelSession(
model: mock,
instructions: "You are a SQL assistant."
)
let stream = session.streamResponse(to: "Give me a number")
let collected = try await stream.collect()
#expect(collected.content == "SELECT 42 AS answer",
"\(provider.rawValue): stream collected response should match mock output")
#expect(mock.calls.count == 1, "\(provider.rawValue): should have exactly 1 call")
#expect(mock.calls.first?.method == "streamResponse",
"\(provider.rawValue): should call streamResponse method")
}
// MARK: - 4. Schema Introspection Works Identically Across Providers
@Test("Schema introspection returns same schema regardless of provider",
arguments: TestedProvider.allCases)
func schemaIntrospectionIsProviderAgnostic(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: ["SELECT 1"])
let engine = ChatEngine(database: db, model: mock)
let schema = try await engine.prepareSchema()
#expect(schema.tableNames.contains("products"),
"\(provider.rawValue): schema must include 'products' table")
#expect(schema.tableNames.count == 1,
"\(provider.rawValue): should have exactly 1 table")
let table = schema.tables["products"]
#expect(table != nil, "\(provider.rawValue): must find products table")
#expect(table?.columns.count == 4,
"\(provider.rawValue): products table must have 4 columns")
}
// MARK: - 5. Error Handling Consistency
@Test("All providers handle empty schema consistently",
arguments: TestedProvider.allCases)
func emptySchemaHandledConsistently(provider: TestedProvider) async throws {
let db = try DatabaseQueue(path: ":memory:")
let mock = makeMock(for: provider, responses: ["SELECT 1"])
let engine = ChatEngine(database: db, model: mock)
do {
_ = try await engine.send("Show me data")
Issue.record("\(provider.rawValue): should throw for empty schema")
} catch let error as SwiftDBAIError {
#expect(error == .emptySchema,
"\(provider.rawValue): must throw .emptySchema for database with no tables")
}
}
@Test("All providers reject disallowed SQL operations consistently",
arguments: TestedProvider.allCases)
func disallowedSQLRejectedConsistently(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: [
"DELETE FROM products WHERE id = 1",
])
// Default allowlist is readOnly (SELECT only)
let engine = ChatEngine(database: db, model: mock)
do {
_ = try await engine.send("Delete the first product")
Issue.record("\(provider.rawValue): should reject DELETE when allowlist is readOnly")
} catch {
// All providers must trigger the same error path for disallowed operations
#expect(error is SwiftDBAIError,
"\(provider.rawValue): error must be SwiftDBAIError")
}
}
// MARK: - 6. Conversation History Consistency
@Test("Conversation history works identically for all providers",
arguments: TestedProvider.allCases)
func conversationHistoryConsistent(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
// ChatEngine calls LLM for SQL generation, then TextSummaryRenderer
// may call LLM for summarization. For aggregate queries (COUNT, AVG),
// TextSummaryRenderer uses a template and skips the LLM call.
// So the mock sequence is: SQL1, SQL2 (each followed by template summary).
let mock = makeMock(for: provider, responses: [
"SELECT COUNT(*) FROM products",
"SELECT AVG(price) FROM products",
])
let engine = ChatEngine(database: db, model: mock)
_ = try await engine.send("How many products?")
_ = try await engine.send("What is the average price?")
let messages = engine.messages
#expect(messages.count == 4,
"\(provider.rawValue): should have 4 messages (2 user + 2 assistant)")
#expect(messages[0].role == .user, "\(provider.rawValue): first message should be user")
#expect(messages[1].role == .assistant, "\(provider.rawValue): second message should be assistant")
#expect(messages[2].role == .user, "\(provider.rawValue): third message should be user")
#expect(messages[3].role == .assistant, "\(provider.rawValue): fourth message should be assistant")
// Both assistant messages must have SQL
#expect(messages[1].sql != nil, "\(provider.rawValue): first response must have SQL")
#expect(messages[3].sql != nil, "\(provider.rawValue): second response must have SQL")
}
// MARK: - 7. ProviderConfiguration Roundtrip
@Test("All cloud provider configurations roundtrip through makeModel()")
func allCloudProvidersRoundtrip() {
let configs: [(String, ProviderConfiguration)] = [
("OpenAI", .openAI(apiKey: "sk-test", model: "gpt-4o")),
("OpenAI Responses", .openAI(apiKey: "sk-test", model: "gpt-4o", variant: .responses)),
("Anthropic", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514")),
("Anthropic+version", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514", apiVersion: "2024-01-01")),
("Anthropic+betas", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514", betas: ["computer-use"])),
("Gemini", .gemini(apiKey: "AIza-test", model: "gemini-2.0-flash")),
("Gemini+version", .gemini(apiKey: "AIza-test", model: "gemini-2.0-flash", apiVersion: "v1")),
("OpenAI-Compatible", .openAICompatible(
apiKey: "key", model: "model", baseURL: URL(string: "http://localhost:1234")!
)),
("Ollama", .ollama(model: "llama3.2")),
("Ollama+custom URL", .ollama(model: "qwen2.5", baseURL: URL(string: "http://192.168.1.100:11434")!)),
("llama.cpp", .llamaCpp(model: "default")),
("llama.cpp+custom", .llamaCpp(model: "my-model", baseURL: URL(string: "http://localhost:9090")!)),
]
for (name, config) in configs {
let model = config.makeModel()
#expect(model.isAvailable, "\(name): model must be available after makeModel()")
}
}
@Test("On-device provider configurations produce valid models")
func onDeviceProvidersRoundtrip() {
let mlxConfigs: [MLXProviderConfiguration] = [
.llama3_2_3B(),
.qwen2_5_coder_3B(),
.phi3_5_mini(),
MLXProviderConfiguration(modelId: "custom-model", temperature: 0.2),
]
for mlxConfig in mlxConfigs {
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
let model = providerConfig.makeModel()
#expect(model.isAvailable, "MLX model '\(mlxConfig.modelId)' must be available")
}
}
// MARK: - 8. Write Operation Allowlist Consistency
@Test("Write operations require explicit opt-in for all providers",
arguments: TestedProvider.allCases)
func writeOperationsRequireOptIn(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
// Mock returns an INSERT statement
let mock = makeMock(for: provider, responses: [
"INSERT INTO products (name, price, category) VALUES ('New', 1.00, 'misc')",
])
// readOnly allowlist (default)
let readOnlyEngine = ChatEngine(database: db, model: mock)
do {
_ = try await readOnlyEngine.send("Add a new product")
Issue.record("\(provider.rawValue): INSERT should be rejected with readOnly allowlist")
} catch {
#expect(error is SwiftDBAIError,
"\(provider.rawValue): must throw SwiftDBAIError for disallowed INSERT")
}
}
@Test("Allowed write operations work for all providers",
arguments: TestedProvider.allCases)
func allowedWriteOperationsWork(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: [
"INSERT INTO products (name, price, category) VALUES ('NewItem', 1.00, 'misc')",
"Successfully added 1 product.",
])
let engine = ChatEngine(
database: db,
model: mock,
allowlist: .standard
)
let response = try await engine.send("Add a product called NewItem")
#expect(response.sql?.uppercased().hasPrefix("INSERT") == true,
"\(provider.rawValue): SQL should be an INSERT")
}
// MARK: - 9. Response Format Consistency
@Test("ChatResponse structure is identical regardless of provider",
arguments: TestedProvider.allCases)
func responseStructureConsistent(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: [
"SELECT name, price, category FROM products",
"Found 3 products across 2 categories.",
])
let engine = ChatEngine(database: db, model: mock)
let response = try await engine.send("Show all products")
// ChatResponse must always have these properties populated
#expect(response.summary.count > 0,
"\(provider.rawValue): summary must be non-empty")
#expect(response.sql != nil,
"\(provider.rawValue): sql must be present")
#expect(response.queryResult != nil,
"\(provider.rawValue): queryResult must be present")
// QueryResult structure must match the query
let qr = response.queryResult!
#expect(qr.columns == ["name", "price", "category"],
"\(provider.rawValue): columns must match SELECT clause")
#expect(qr.rowCount == 3,
"\(provider.rawValue): must return all rows")
#expect(qr.sql == "SELECT name, price, category FROM products",
"\(provider.rawValue): QueryResult.sql must match executed SQL")
#expect(qr.executionTime >= 0,
"\(provider.rawValue): execution time must be non-negative")
}
// MARK: - 10. Provider Enum Completeness
@Test("TestedProvider covers all ProviderConfiguration.Provider cases plus on-device")
func testedProviderCoversAllCases() {
// ProviderConfiguration.Provider has 6 cases
let configProviderCount = ProviderConfiguration.Provider.allCases.count
#expect(configProviderCount == 6, "ProviderConfiguration.Provider should have 6 cases")
// TestedProvider adds on-device for 7 total
#expect(TestedProvider.allCases.count == 7, "TestedProvider should cover all 7 provider types")
// Verify 1:1 mapping for the config providers
let configNames = Set(ProviderConfiguration.Provider.allCases.map(\.rawValue))
for tested in TestedProvider.allCases where tested != .onDevice {
#expect(configNames.contains(tested.rawValue),
"\(tested.rawValue) must map to a ProviderConfiguration.Provider case")
}
}
// MARK: - 11. ChatEngine Convenience Init Consistency
@Test("ChatEngine convenience init with ProviderConfiguration works for all cloud providers")
func chatEngineConvenienceInitWorks() throws {
let db = try makeProviderTestDatabase()
let configs: [ProviderConfiguration] = [
.openAI(apiKey: "test", model: "gpt-4o"),
.anthropic(apiKey: "test", model: "claude-sonnet-4-20250514"),
.gemini(apiKey: "test", model: "gemini-2.0-flash"),
.openAICompatible(apiKey: "test", model: "m", baseURL: URL(string: "http://localhost:1234")!),
.ollama(model: "llama3.2"),
.llamaCpp(model: "default"),
]
for config in configs {
// This should not throw it only creates the engine, doesn't call the LLM
let engine = ChatEngine(database: db, provider: config)
#expect(engine.tableCount == nil, "tableCount should be nil before first query")
}
}
// MARK: - 12. Availability Reporting
@Test("All real provider models report available by default")
func allModelsReportAvailable() {
let models: [(String, any LanguageModel)] = [
("OpenAI", OpenAILanguageModel(apiKey: "k", model: "m")),
("Anthropic", AnthropicLanguageModel(apiKey: "k", model: "m")),
("Gemini", GeminiLanguageModel(apiKey: "k", model: "m")),
("Ollama", OllamaLanguageModel(model: "m")),
]
for (name, model) in models {
#expect(model.isAvailable, "\(name) should be available by default")
}
}
// MARK: - 13. On-Device Pipeline Status
@Test("On-device inference pipeline starts in notLoaded state")
func onDevicePipelineInitialState() {
let mlxPipeline = OnDeviceInferencePipeline(
mlxConfiguration: .llama3_2_3B()
)
#expect(mlxPipeline.status == .notLoaded)
#expect(mlxPipeline.providerType == .mlx)
let coreMLPipeline = OnDeviceInferencePipeline(
coreMLConfiguration: CoreMLProviderConfiguration(
modelURL: URL(fileURLWithPath: "/tmp/test.mlmodelc")
)
)
#expect(coreMLPipeline.status == .notLoaded)
#expect(coreMLPipeline.providerType == .coreML)
}
@Test("On-device SQL generation hints are populated for both provider types")
func onDeviceSQLHints() {
let mlxPipeline = OnDeviceInferencePipeline(mlxConfiguration: .llama3_2_3B())
let mlxHints = mlxPipeline.recommendedSQLGenerationHints
#expect(mlxHints.maxTokens > 0)
#expect(mlxHints.temperature >= 0)
#expect(!mlxHints.systemPromptSuffix.isEmpty)
let coreMLPipeline = OnDeviceInferencePipeline(
coreMLConfiguration: CoreMLProviderConfiguration(
modelURL: URL(fileURLWithPath: "/tmp/test.mlmodelc")
)
)
let coreMLHints = coreMLPipeline.recommendedSQLGenerationHints
#expect(coreMLHints.maxTokens > 0)
#expect(coreMLHints.temperature >= 0)
#expect(!coreMLHints.systemPromptSuffix.isEmpty)
}
}