SwiftDBAI: natural language queries for any SQLite database
Drop-in SwiftUI chat view, headless ChatEngine, LLM-agnostic via AnyLanguageModel. Read-only by default with configurable allowlists. Robust SQL parser with 63 tests. Includes demo app with GitHub stars dataset.
This commit is contained in:
254
Tests/SwiftDBAITests/BinarySizeTests.swift
Normal file
254
Tests/SwiftDBAITests/BinarySizeTests.swift
Normal file
@@ -0,0 +1,254 @@
|
||||
// BinarySizeTests.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Validates that the SwiftDBAI package stays within its 2 MB binary size budget.
|
||||
// This test suite uses source-level heuristics since we can't measure the actual
|
||||
// compiled binary size in a unit test. The constraints ensure the package remains
|
||||
// lightweight by checking:
|
||||
// 1. Total source code size (proxy for compiled size)
|
||||
// 2. No embedded binary assets or large resources
|
||||
// 3. No unnecessary heavy dependencies
|
||||
// 4. File count stays reasonable (no code bloat)
|
||||
|
||||
import Foundation
|
||||
import Testing
|
||||
|
||||
@Suite("Binary Size Budget")
|
||||
struct BinarySizeTests {
|
||||
|
||||
/// The maximum allowed total source code size in bytes.
|
||||
/// At typical Swift optimized compilation ratios (2-4x), 500 KB of source
|
||||
/// compiles to roughly 1-2 MB of binary. We set the source budget at 500 KB
|
||||
/// to keep the compiled output well under 2 MB.
|
||||
private static let maxSourceSizeBytes: Int = 500_000 // 500 KB
|
||||
|
||||
/// Maximum number of Swift source files allowed.
|
||||
/// More files generally means more code and larger binaries.
|
||||
private static let maxSourceFileCount: Int = 60
|
||||
|
||||
/// Maximum size for any single source file in bytes.
|
||||
/// Large individual files often indicate code that should be split or
|
||||
/// contains embedded data that bloats the binary.
|
||||
private static let maxSingleFileSizeBytes: Int = 50_000 // 50 KB
|
||||
|
||||
/// Disallowed file extensions in the Sources directory that would bloat the binary.
|
||||
private static let disallowedExtensions: Set<String> = [
|
||||
"png", "jpg", "jpeg", "gif", "bmp", "tiff",
|
||||
"mp3", "mp4", "wav", "mov",
|
||||
"mlmodel", "mlmodelc", "mlpackage",
|
||||
"sqlite", "db",
|
||||
"zip", "tar", "gz",
|
||||
"bin", "dat",
|
||||
"framework", "dylib", "a"
|
||||
]
|
||||
|
||||
// MARK: - Helper
|
||||
|
||||
/// Recursively finds all files in the Sources/SwiftDBAI directory.
|
||||
private func findSourceFiles() throws -> [URL] {
|
||||
let sourcesDir = findSourcesDirectory()
|
||||
guard let sourcesDir else {
|
||||
Issue.record("Could not locate Sources/SwiftDBAI directory")
|
||||
return []
|
||||
}
|
||||
|
||||
let fileManager = FileManager.default
|
||||
guard let enumerator = fileManager.enumerator(
|
||||
at: sourcesDir,
|
||||
includingPropertiesForKeys: [.fileSizeKey, .isRegularFileKey],
|
||||
options: [.skipsHiddenFiles]
|
||||
) else {
|
||||
Issue.record("Could not enumerate Sources/SwiftDBAI directory")
|
||||
return []
|
||||
}
|
||||
|
||||
var files: [URL] = []
|
||||
for case let fileURL as URL in enumerator {
|
||||
let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey])
|
||||
if resourceValues.isRegularFile == true {
|
||||
files.append(fileURL)
|
||||
}
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
/// Locates the Sources/SwiftDBAI directory by walking up from the test bundle.
|
||||
private func findSourcesDirectory() -> URL? {
|
||||
// Try common locations relative to the build directory
|
||||
let fileManager = FileManager.default
|
||||
|
||||
// In SPM test runs, we can find the package root by checking known paths
|
||||
var candidateURL = URL(fileURLWithPath: #filePath)
|
||||
// Walk up from Tests/SwiftDBAITests/BinarySizeTests.swift to package root
|
||||
for _ in 0..<3 {
|
||||
candidateURL = candidateURL.deletingLastPathComponent()
|
||||
}
|
||||
let sourcesDir = candidateURL.appendingPathComponent("Sources/SwiftDBAI")
|
||||
if fileManager.fileExists(atPath: sourcesDir.path) {
|
||||
return sourcesDir
|
||||
}
|
||||
|
||||
// Fallback: check current working directory
|
||||
let cwdSources = URL(fileURLWithPath: fileManager.currentDirectoryPath)
|
||||
.appendingPathComponent("Sources/SwiftDBAI")
|
||||
if fileManager.fileExists(atPath: cwdSources.path) {
|
||||
return cwdSources
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MARK: - Tests
|
||||
|
||||
@Test("Total source code size stays under 500 KB budget")
|
||||
func totalSourceCodeSizeUnderBudget() throws {
|
||||
let files = try findSourceFiles()
|
||||
let swiftFiles = files.filter { $0.pathExtension == "swift" }
|
||||
|
||||
var totalSize: Int = 0
|
||||
for file in swiftFiles {
|
||||
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
|
||||
let fileSize = attributes[.size] as? Int ?? 0
|
||||
totalSize += fileSize
|
||||
}
|
||||
|
||||
#expect(totalSize < Self.maxSourceSizeBytes,
|
||||
"""
|
||||
Total Swift source size (\(totalSize) bytes) exceeds \(Self.maxSourceSizeBytes) byte budget.
|
||||
At typical 2-4x compilation ratio, this would produce a binary larger than 2 MB.
|
||||
Consider removing unused code or splitting into optional sub-targets.
|
||||
""")
|
||||
|
||||
// Log the actual size for visibility
|
||||
let sizeKB = Double(totalSize) / 1024.0
|
||||
let budgetKB = Double(Self.maxSourceSizeBytes) / 1024.0
|
||||
print("📦 SwiftDBAI source size: \(String(format: "%.1f", sizeKB)) KB / \(String(format: "%.0f", budgetKB)) KB budget (\(String(format: "%.0f", (sizeKB / budgetKB) * 100))% used)")
|
||||
}
|
||||
|
||||
@Test("Source file count stays reasonable")
|
||||
func sourceFileCountUnderLimit() throws {
|
||||
let files = try findSourceFiles()
|
||||
let swiftFiles = files.filter { $0.pathExtension == "swift" }
|
||||
|
||||
#expect(swiftFiles.count <= Self.maxSourceFileCount,
|
||||
"""
|
||||
Swift source file count (\(swiftFiles.count)) exceeds limit of \(Self.maxSourceFileCount).
|
||||
More files generally means more code and larger binaries.
|
||||
""")
|
||||
|
||||
print("📦 SwiftDBAI file count: \(swiftFiles.count) / \(Self.maxSourceFileCount) max")
|
||||
}
|
||||
|
||||
@Test("No individual source file exceeds 50 KB")
|
||||
func noOversizedSourceFiles() throws {
|
||||
let files = try findSourceFiles()
|
||||
let swiftFiles = files.filter { $0.pathExtension == "swift" }
|
||||
|
||||
for file in swiftFiles {
|
||||
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
|
||||
let fileSize = attributes[.size] as? Int ?? 0
|
||||
|
||||
#expect(fileSize < Self.maxSingleFileSizeBytes,
|
||||
"""
|
||||
File \(file.lastPathComponent) is \(fileSize) bytes, exceeding the \(Self.maxSingleFileSizeBytes) byte limit.
|
||||
Large files may contain embedded data or code that should be split.
|
||||
""")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("No binary assets or heavy resources in Sources directory")
|
||||
func noBinaryAssetsInSources() throws {
|
||||
let files = try findSourceFiles()
|
||||
|
||||
let disallowedFiles = files.filter { file in
|
||||
Self.disallowedExtensions.contains(file.pathExtension.lowercased())
|
||||
}
|
||||
|
||||
#expect(disallowedFiles.isEmpty,
|
||||
"""
|
||||
Found \(disallowedFiles.count) disallowed file(s) in Sources directory:
|
||||
\(disallowedFiles.map(\.lastPathComponent).joined(separator: "\n"))
|
||||
These file types bloat the binary. Remove them or move to a separate resource bundle.
|
||||
""")
|
||||
}
|
||||
|
||||
@Test("Package has no resource bundles that could bloat binary")
|
||||
func noResourceBundles() throws {
|
||||
let files = try findSourceFiles()
|
||||
|
||||
let resourceFiles = files.filter { file in
|
||||
let ext = file.pathExtension.lowercased()
|
||||
return ["xcassets", "storyboard", "xib", "nib", "xcdatamodeld"].contains(ext)
|
||||
}
|
||||
|
||||
#expect(resourceFiles.isEmpty,
|
||||
"""
|
||||
Found resource bundle files that could bloat the binary:
|
||||
\(resourceFiles.map(\.lastPathComponent).joined(separator: "\n"))
|
||||
SwiftDBAI should be pure code — no bundled resources.
|
||||
""")
|
||||
}
|
||||
|
||||
@Test("Only expected dependencies declared (GRDB + AnyLanguageModel)")
|
||||
func minimalDependencies() throws {
|
||||
// Read Package.swift to verify we only have the expected dependencies
|
||||
var packageURL = URL(fileURLWithPath: #filePath)
|
||||
for _ in 0..<3 {
|
||||
packageURL = packageURL.deletingLastPathComponent()
|
||||
}
|
||||
let packageSwiftURL = packageURL.appendingPathComponent("Package.swift")
|
||||
|
||||
guard FileManager.default.fileExists(atPath: packageSwiftURL.path) else {
|
||||
// Skip if we can't find Package.swift (CI environments etc.)
|
||||
return
|
||||
}
|
||||
|
||||
let packageContents = try String(contentsOf: packageSwiftURL, encoding: .utf8)
|
||||
|
||||
// Count .package() declarations (dependencies)
|
||||
let packageDeclarations = packageContents.components(separatedBy: ".package(")
|
||||
.count - 1 // subtract 1 because the first segment is before any .package(
|
||||
|
||||
#expect(packageDeclarations <= 3,
|
||||
"""
|
||||
Found \(packageDeclarations) package dependencies, expected at most 4 (GRDB + AnyLanguageModel + ViewInspector for tests).
|
||||
Additional dependencies increase binary size. Evaluate if they're truly needed.
|
||||
""")
|
||||
|
||||
// Verify the expected dependencies are present
|
||||
#expect(packageContents.contains("GRDB"), "Expected GRDB dependency")
|
||||
#expect(packageContents.contains("AnyLanguageModel"), "Expected AnyLanguageModel dependency")
|
||||
|
||||
print("📦 SwiftDBAI dependencies: \(packageDeclarations) (GRDB + AnyLanguageModel)")
|
||||
}
|
||||
|
||||
@Test("Estimated binary size under 2 MB")
|
||||
func estimatedBinarySizeUnderLimit() throws {
|
||||
let files = try findSourceFiles()
|
||||
let swiftFiles = files.filter { $0.pathExtension == "swift" }
|
||||
|
||||
var totalSize: Int = 0
|
||||
for file in swiftFiles {
|
||||
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
|
||||
let fileSize = attributes[.size] as? Int ?? 0
|
||||
totalSize += fileSize
|
||||
}
|
||||
|
||||
// Conservative estimate: optimized Swift binary is typically 2-4x source size.
|
||||
// Use 4x as worst case multiplier for safety margin.
|
||||
let worstCaseMultiplier = 4.0
|
||||
let estimatedBinarySize = Double(totalSize) * worstCaseMultiplier
|
||||
let maxBinarySize: Double = 2.0 * 1024.0 * 1024.0 // 2 MB
|
||||
|
||||
#expect(estimatedBinarySize < maxBinarySize,
|
||||
"""
|
||||
Estimated binary size (\(String(format: "%.1f", estimatedBinarySize / 1024.0)) KB) exceeds 2 MB limit.
|
||||
Source: \(totalSize) bytes × \(worstCaseMultiplier)x multiplier = \(String(format: "%.1f", estimatedBinarySize / 1024.0)) KB
|
||||
Note: This is the SwiftDBAI module only — excludes GRDB and AnyLanguageModel
|
||||
which are existing dependencies the developer already includes.
|
||||
""")
|
||||
|
||||
let estimatedMB = estimatedBinarySize / (1024.0 * 1024.0)
|
||||
print("📦 Estimated SwiftDBAI binary size: \(String(format: "%.2f", estimatedMB)) MB / 2.00 MB limit (worst case \(worstCaseMultiplier)x)")
|
||||
}
|
||||
}
|
||||
293
Tests/SwiftDBAITests/ChartDataDetectorTests.swift
Normal file
293
Tests/SwiftDBAITests/ChartDataDetectorTests.swift
Normal file
@@ -0,0 +1,293 @@
|
||||
// ChartDataDetectorTests.swift
|
||||
// SwiftDBAITests
|
||||
|
||||
import Testing
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("ChartDataDetector")
|
||||
struct ChartDataDetectorTests {
|
||||
|
||||
let detector = ChartDataDetector()
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
private func makeQueryResult(
|
||||
columns: [String],
|
||||
rows: [[QueryResult.Value]],
|
||||
sql: String = "SELECT *"
|
||||
) -> QueryResult {
|
||||
let rowDicts = rows.map { values in
|
||||
Dictionary(uniqueKeysWithValues: zip(columns, values))
|
||||
}
|
||||
return QueryResult(
|
||||
columns: columns,
|
||||
rows: rowDicts,
|
||||
sql: sql,
|
||||
executionTime: 0.01
|
||||
)
|
||||
}
|
||||
|
||||
private func makeTable(
|
||||
columns: [String],
|
||||
rows: [[QueryResult.Value]],
|
||||
sql: String = "SELECT *"
|
||||
) -> DataTable {
|
||||
DataTable(makeQueryResult(columns: columns, rows: rows, sql: sql))
|
||||
}
|
||||
|
||||
// MARK: - Basic Eligibility
|
||||
|
||||
@Test("Returns nil for single-column results")
|
||||
func singleColumn() {
|
||||
let table = makeTable(
|
||||
columns: ["count"],
|
||||
rows: [[.integer(42)]]
|
||||
)
|
||||
#expect(detector.detect(table) == nil)
|
||||
}
|
||||
|
||||
@Test("Returns nil for empty results")
|
||||
func emptyResults() {
|
||||
let table = makeTable(columns: ["name", "value"], rows: [])
|
||||
#expect(detector.detect(table) == nil)
|
||||
}
|
||||
|
||||
@Test("Returns nil for single row")
|
||||
func singleRow() {
|
||||
let table = makeTable(
|
||||
columns: ["name", "count"],
|
||||
rows: [[.text("A"), .integer(10)]]
|
||||
)
|
||||
#expect(detector.detect(table) == nil)
|
||||
}
|
||||
|
||||
@Test("Returns nil for too many rows")
|
||||
func tooManyRows() {
|
||||
let rows = (0..<101).map { i in
|
||||
[QueryResult.Value.text("cat\(i)"), .integer(Int64(i))]
|
||||
}
|
||||
let table = makeTable(columns: ["name", "count"], rows: rows)
|
||||
#expect(detector.detect(table) == nil)
|
||||
}
|
||||
|
||||
// MARK: - Bar Chart Detection
|
||||
|
||||
@Test("Recommends bar chart for categorical text + numeric")
|
||||
func barChartCategorical() {
|
||||
let table = makeTable(
|
||||
columns: ["department", "headcount"],
|
||||
rows: [
|
||||
[.text("Engineering"), .integer(45)],
|
||||
[.text("Marketing"), .integer(20)],
|
||||
[.text("Sales"), .integer(30)],
|
||||
[.text("HR"), .integer(10)],
|
||||
]
|
||||
)
|
||||
let rec = detector.detect(table)
|
||||
#expect(rec != nil)
|
||||
#expect(rec?.chartType == .bar)
|
||||
#expect(rec?.categoryColumn == "department")
|
||||
#expect(rec?.valueColumn == "headcount")
|
||||
#expect(rec?.confidence ?? 0 > 0.5)
|
||||
}
|
||||
|
||||
// MARK: - Pie Chart Detection
|
||||
|
||||
@Test("Recommends pie chart for small positive proportions")
|
||||
func pieChartSmallCategories() {
|
||||
let table = makeTable(
|
||||
columns: ["status", "count"],
|
||||
rows: [
|
||||
[.text("Active"), .integer(50)],
|
||||
[.text("Inactive"), .integer(30)],
|
||||
[.text("Pending"), .integer(20)],
|
||||
]
|
||||
)
|
||||
let rec = detector.detect(table)
|
||||
#expect(rec != nil)
|
||||
#expect(rec?.chartType == .pie)
|
||||
#expect(rec?.categoryColumn == "status")
|
||||
#expect(rec?.valueColumn == "count")
|
||||
}
|
||||
|
||||
@Test("Does not recommend pie with negative values")
|
||||
func pieRejectsNegative() {
|
||||
let table = makeTable(
|
||||
columns: ["category", "change"],
|
||||
rows: [
|
||||
[.text("A"), .integer(50)],
|
||||
[.text("B"), .integer(-10)],
|
||||
[.text("C"), .integer(20)],
|
||||
]
|
||||
)
|
||||
let rec = detector.detect(table)
|
||||
#expect(rec != nil)
|
||||
// Should NOT be pie since there's a negative value
|
||||
#expect(rec?.chartType != .pie)
|
||||
}
|
||||
|
||||
@Test("Does not recommend pie with too many slices")
|
||||
func pieRejectsTooManySlices() {
|
||||
let rows = (0..<10).map { i in
|
||||
[QueryResult.Value.text("cat\(i)"), .integer(Int64(i + 1))]
|
||||
}
|
||||
let table = makeTable(columns: ["category", "value"], rows: rows)
|
||||
let rec = detector.detect(table)
|
||||
#expect(rec != nil)
|
||||
#expect(rec?.chartType != .pie)
|
||||
}
|
||||
|
||||
// MARK: - Line Chart Detection
|
||||
|
||||
@Test("Recommends line chart for time-series column names")
|
||||
func lineChartTimeSeries() {
|
||||
let table = makeTable(
|
||||
columns: ["year", "revenue"],
|
||||
rows: [
|
||||
[.text("2020"), .real(1_000_000)],
|
||||
[.text("2021"), .real(1_200_000)],
|
||||
[.text("2022"), .real(1_500_000)],
|
||||
[.text("2023"), .real(1_800_000)],
|
||||
[.text("2024"), .real(2_100_000)],
|
||||
]
|
||||
)
|
||||
let rec = detector.detect(table)
|
||||
#expect(rec != nil)
|
||||
#expect(rec?.chartType == .line)
|
||||
#expect(rec?.categoryColumn == "year")
|
||||
#expect(rec?.valueColumn == "revenue")
|
||||
}
|
||||
|
||||
@Test("Recommends line chart for date-formatted text values")
|
||||
func lineChartDateValues() {
|
||||
let table = makeTable(
|
||||
columns: ["period", "sales"],
|
||||
rows: [
|
||||
[.text("2024-01"), .integer(100)],
|
||||
[.text("2024-02"), .integer(120)],
|
||||
[.text("2024-03"), .integer(90)],
|
||||
[.text("2024-04"), .integer(150)],
|
||||
]
|
||||
)
|
||||
let rec = detector.detect(table)
|
||||
#expect(rec != nil)
|
||||
#expect(rec?.chartType == .line)
|
||||
}
|
||||
|
||||
@Test("Recommends line chart for sequential numeric x-axis")
|
||||
func lineChartSequential() {
|
||||
let table = makeTable(
|
||||
columns: ["step", "value"],
|
||||
rows: [
|
||||
[.integer(1), .real(2.5)],
|
||||
[.integer(2), .real(3.1)],
|
||||
[.integer(3), .real(4.0)],
|
||||
[.integer(4), .real(3.8)],
|
||||
]
|
||||
)
|
||||
let rec = detector.detect(table)
|
||||
#expect(rec != nil)
|
||||
#expect(rec?.chartType == .line)
|
||||
}
|
||||
|
||||
// MARK: - All Recommendations
|
||||
|
||||
@Test("Returns multiple recommendations sorted by confidence")
|
||||
func allRecommendations() {
|
||||
let table = makeTable(
|
||||
columns: ["category", "amount"],
|
||||
rows: [
|
||||
[.text("A"), .integer(30)],
|
||||
[.text("B"), .integer(50)],
|
||||
[.text("C"), .integer(20)],
|
||||
]
|
||||
)
|
||||
let recs = detector.allRecommendations(for: table)
|
||||
#expect(!recs.isEmpty)
|
||||
// Should be sorted by confidence descending
|
||||
for i in 1..<recs.count {
|
||||
#expect(recs[i - 1].confidence >= recs[i].confidence)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Two Numeric Columns Fallback
|
||||
|
||||
@Test("Uses first numeric as category when no text column exists")
|
||||
func numericOnlyColumns() {
|
||||
let table = makeTable(
|
||||
columns: ["x", "y"],
|
||||
rows: [
|
||||
[.integer(1), .integer(10)],
|
||||
[.integer(2), .integer(20)],
|
||||
[.integer(3), .integer(30)],
|
||||
]
|
||||
)
|
||||
let rec = detector.detect(table)
|
||||
#expect(rec != nil)
|
||||
#expect(rec?.categoryColumn == "x")
|
||||
#expect(rec?.valueColumn == "y")
|
||||
}
|
||||
|
||||
// MARK: - Confidence & Reason
|
||||
|
||||
@Test("Confidence is between 0 and 1")
|
||||
func confidenceBounds() {
|
||||
let table = makeTable(
|
||||
columns: ["name", "score"],
|
||||
rows: [
|
||||
[.text("A"), .integer(10)],
|
||||
[.text("B"), .integer(20)],
|
||||
]
|
||||
)
|
||||
let rec = detector.detect(table)
|
||||
#expect(rec != nil)
|
||||
#expect(rec!.confidence >= 0.0)
|
||||
#expect(rec!.confidence <= 1.0)
|
||||
}
|
||||
|
||||
@Test("Reason is non-empty")
|
||||
func reasonPresent() {
|
||||
let table = makeTable(
|
||||
columns: ["name", "score"],
|
||||
rows: [
|
||||
[.text("A"), .integer(10)],
|
||||
[.text("B"), .integer(20)],
|
||||
]
|
||||
)
|
||||
let rec = detector.detect(table)
|
||||
#expect(rec != nil)
|
||||
#expect(!rec!.reason.isEmpty)
|
||||
}
|
||||
|
||||
// MARK: - Custom Configuration
|
||||
|
||||
@Test("Respects custom minimumRows")
|
||||
func customMinRows() {
|
||||
let strict = ChartDataDetector(minimumRows: 5)
|
||||
let table = makeTable(
|
||||
columns: ["name", "value"],
|
||||
rows: [
|
||||
[.text("A"), .integer(1)],
|
||||
[.text("B"), .integer(2)],
|
||||
[.text("C"), .integer(3)],
|
||||
]
|
||||
)
|
||||
#expect(strict.detect(table) == nil)
|
||||
}
|
||||
|
||||
@Test("Respects custom maxPieSlices")
|
||||
func customMaxPieSlices() {
|
||||
let narrow = ChartDataDetector(maxPieSlices: 2)
|
||||
let table = makeTable(
|
||||
columns: ["status", "count"],
|
||||
rows: [
|
||||
[.text("A"), .integer(50)],
|
||||
[.text("B"), .integer(30)],
|
||||
[.text("C"), .integer(20)],
|
||||
]
|
||||
)
|
||||
let rec = narrow.detect(table)
|
||||
// With maxPieSlices=2, 3 rows should not get pie
|
||||
#expect(rec?.chartType != .pie)
|
||||
}
|
||||
}
|
||||
1091
Tests/SwiftDBAITests/ChatEngineTests.swift
Normal file
1091
Tests/SwiftDBAITests/ChatEngineTests.swift
Normal file
File diff suppressed because it is too large
Load Diff
170
Tests/SwiftDBAITests/ChatViewConfigurationTests.swift
Normal file
170
Tests/SwiftDBAITests/ChatViewConfigurationTests.swift
Normal file
@@ -0,0 +1,170 @@
|
||||
// ChatViewConfigurationTests.swift
|
||||
// SwiftDBAITests
|
||||
//
|
||||
// Tests for ChatViewConfiguration defaults, presets, and environment propagation.
|
||||
|
||||
import Testing
|
||||
import SwiftUI
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("ChatViewConfiguration Tests")
|
||||
struct ChatViewConfigurationTests {
|
||||
|
||||
// MARK: - Default Values
|
||||
|
||||
@Test("Default configuration has expected color values")
|
||||
func defaultColors() {
|
||||
let config = ChatViewConfiguration.default
|
||||
#expect(config.userBubbleColor == .accentColor)
|
||||
#expect(config.userTextColor == .white)
|
||||
#expect(config.assistantTextColor == .primary)
|
||||
#expect(config.backgroundColor == .clear)
|
||||
#expect(config.inputBarBackgroundColor == .clear)
|
||||
#expect(config.accentColor == .accentColor)
|
||||
#expect(config.errorColor == .red)
|
||||
}
|
||||
|
||||
@Test("Default configuration has expected typography values")
|
||||
func defaultTypography() {
|
||||
let config = ChatViewConfiguration.default
|
||||
#expect(config.messageFont == .body)
|
||||
#expect(config.summaryFont == .body)
|
||||
#expect(config.sqlFont == .system(.caption, design: .monospaced))
|
||||
#expect(config.inputFont == .body)
|
||||
}
|
||||
|
||||
@Test("Default configuration has expected layout values")
|
||||
func defaultLayout() {
|
||||
let config = ChatViewConfiguration.default
|
||||
#expect(config.messagePadding == 14)
|
||||
#expect(config.bubbleCornerRadius == 16)
|
||||
#expect(config.showTimestamps == false)
|
||||
#expect(config.showSQLDisclosure == true)
|
||||
#expect(config.inputPlaceholder == "Ask about your data\u{2026}")
|
||||
#expect(config.emptyStateTitle == "Ask a question about your data")
|
||||
#expect(config.emptyStateSubtitle == "Try something like \"How many records are in the database?\"")
|
||||
#expect(config.emptyStateIcon == "bubble.left.and.text.bubble.right")
|
||||
}
|
||||
|
||||
// MARK: - Compact Preset
|
||||
|
||||
@Test("Compact preset has smaller fonts and tighter padding")
|
||||
func compactPreset() {
|
||||
let config = ChatViewConfiguration.compact
|
||||
#expect(config.messageFont == .footnote)
|
||||
#expect(config.summaryFont == .footnote)
|
||||
#expect(config.sqlFont == .system(.caption2, design: .monospaced))
|
||||
#expect(config.inputFont == .footnote)
|
||||
#expect(config.messagePadding == 8)
|
||||
#expect(config.bubbleCornerRadius == 10)
|
||||
#expect(config.showTimestamps == false)
|
||||
#expect(config.showSQLDisclosure == false)
|
||||
}
|
||||
|
||||
// MARK: - Dark Preset
|
||||
|
||||
@Test("Dark preset has dark-themed colors")
|
||||
func darkPreset() {
|
||||
let config = ChatViewConfiguration.dark
|
||||
#expect(config.userBubbleColor == Color(white: 0.25))
|
||||
#expect(config.userTextColor == .white)
|
||||
#expect(config.assistantBubbleColor == Color(white: 0.15))
|
||||
#expect(config.assistantTextColor == Color(white: 0.9))
|
||||
#expect(config.backgroundColor == .black)
|
||||
#expect(config.inputBarBackgroundColor == Color(white: 0.1))
|
||||
#expect(config.accentColor == .blue)
|
||||
#expect(config.errorColor == Color(red: 1.0, green: 0.4, blue: 0.4))
|
||||
}
|
||||
|
||||
// MARK: - Mutability
|
||||
|
||||
@Test("Configuration properties can be mutated individually")
|
||||
func mutateProperties() {
|
||||
var config = ChatViewConfiguration.default
|
||||
config.userBubbleColor = .purple
|
||||
config.inputPlaceholder = "Ask about your recipes..."
|
||||
config.bubbleCornerRadius = 20
|
||||
config.showTimestamps = true
|
||||
|
||||
#expect(config.userBubbleColor == .purple)
|
||||
#expect(config.inputPlaceholder == "Ask about your recipes...")
|
||||
#expect(config.bubbleCornerRadius == 20)
|
||||
#expect(config.showTimestamps == true)
|
||||
// Other properties remain at defaults
|
||||
#expect(config.userTextColor == .white)
|
||||
#expect(config.messageFont == .body)
|
||||
}
|
||||
|
||||
// MARK: - All Public Properties Accessible
|
||||
|
||||
@Test("All public properties are readable and writable")
|
||||
func allPropertiesAccessible() {
|
||||
var config = ChatViewConfiguration.default
|
||||
|
||||
// Colors
|
||||
_ = config.userBubbleColor
|
||||
_ = config.userTextColor
|
||||
_ = config.assistantBubbleColor
|
||||
_ = config.assistantTextColor
|
||||
_ = config.backgroundColor
|
||||
_ = config.inputBarBackgroundColor
|
||||
_ = config.accentColor
|
||||
_ = config.errorColor
|
||||
|
||||
// Typography
|
||||
_ = config.messageFont
|
||||
_ = config.summaryFont
|
||||
_ = config.sqlFont
|
||||
_ = config.inputFont
|
||||
|
||||
// Layout
|
||||
_ = config.messagePadding
|
||||
_ = config.bubbleCornerRadius
|
||||
_ = config.showTimestamps
|
||||
_ = config.showSQLDisclosure
|
||||
_ = config.inputPlaceholder
|
||||
_ = config.emptyStateTitle
|
||||
_ = config.emptyStateSubtitle
|
||||
_ = config.emptyStateIcon
|
||||
|
||||
// Verify write access compiles (set and read back)
|
||||
config.userBubbleColor = .green
|
||||
#expect(config.userBubbleColor == .green)
|
||||
|
||||
config.emptyStateIcon = "star"
|
||||
#expect(config.emptyStateIcon == "star")
|
||||
}
|
||||
|
||||
// MARK: - Presets Are Static
|
||||
|
||||
@Test("Static presets are available as expected")
|
||||
func staticPresets() {
|
||||
let _ = ChatViewConfiguration.default
|
||||
let _ = ChatViewConfiguration.compact
|
||||
let _ = ChatViewConfiguration.dark
|
||||
}
|
||||
|
||||
// MARK: - Sendable Conformance
|
||||
|
||||
@Test("Configuration is Sendable")
|
||||
func sendableConformance() async {
|
||||
let config = ChatViewConfiguration.default
|
||||
// Verify Sendable by passing across isolation boundary
|
||||
let result: ChatViewConfiguration = await Task.detached {
|
||||
return config
|
||||
}.value
|
||||
#expect(result.bubbleCornerRadius == config.bubbleCornerRadius)
|
||||
}
|
||||
|
||||
// MARK: - Environment Propagation
|
||||
|
||||
@Test("Environment key default value matches ChatViewConfiguration.default")
|
||||
func environmentKeyDefault() {
|
||||
let defaultConfig = ChatViewConfiguration.default
|
||||
let envDefault = ChatViewConfigurationKey.defaultValue
|
||||
#expect(defaultConfig.bubbleCornerRadius == envDefault.bubbleCornerRadius)
|
||||
#expect(defaultConfig.messagePadding == envDefault.messagePadding)
|
||||
#expect(defaultConfig.showSQLDisclosure == envDefault.showSQLDisclosure)
|
||||
#expect(defaultConfig.inputPlaceholder == envDefault.inputPlaceholder)
|
||||
}
|
||||
}
|
||||
164
Tests/SwiftDBAITests/ChatViewTests.swift
Normal file
164
Tests/SwiftDBAITests/ChatViewTests.swift
Normal file
@@ -0,0 +1,164 @@
|
||||
// ChatViewTests.swift
|
||||
// SwiftDBAITests
|
||||
//
|
||||
// Tests for ChatView, ChatViewModel, and MessageBubbleView integration
|
||||
// with ScrollableDataTableView.
|
||||
|
||||
import Testing
|
||||
import Foundation
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("SchemaReadiness Tests")
|
||||
struct SchemaReadinessTests {
|
||||
|
||||
@Test("SchemaReadiness isReady returns true only for ready state")
|
||||
func isReadyProperty() {
|
||||
#expect(SchemaReadiness.idle.isReady == false)
|
||||
#expect(SchemaReadiness.loading.isReady == false)
|
||||
#expect(SchemaReadiness.ready(tableCount: 3).isReady == true)
|
||||
#expect(SchemaReadiness.failed("error").isReady == false)
|
||||
}
|
||||
}
|
||||
|
||||
@Suite("ChatViewModel Tests")
|
||||
struct ChatViewModelTests {
|
||||
|
||||
@Test("Messages with query results produce DataTable-compatible data")
|
||||
func messageWithQueryResultHasTableData() {
|
||||
// A ChatMessage with a queryResult should have the data needed
|
||||
// for ScrollableDataTableView rendering
|
||||
let result = QueryResult(
|
||||
columns: ["id", "name", "score"],
|
||||
rows: [
|
||||
["id": .integer(1), "name": .text("Alice"), "score": .real(95.5)],
|
||||
["id": .integer(2), "name": .text("Bob"), "score": .real(87.3)],
|
||||
],
|
||||
sql: "SELECT id, name, score FROM users",
|
||||
executionTime: 0.01
|
||||
)
|
||||
|
||||
let message = ChatMessage(
|
||||
role: .assistant,
|
||||
content: "Found 2 users.",
|
||||
queryResult: result,
|
||||
sql: "SELECT id, name, score FROM users"
|
||||
)
|
||||
|
||||
// Verify queryResult is present and can be converted to DataTable
|
||||
#expect(message.queryResult != nil)
|
||||
#expect(message.queryResult!.columns.count == 3)
|
||||
#expect(message.queryResult!.rows.count == 2)
|
||||
|
||||
// Verify DataTable conversion works (this is what MessageBubbleView does)
|
||||
let dataTable = DataTable(message.queryResult!)
|
||||
#expect(dataTable.columnCount == 3)
|
||||
#expect(dataTable.rowCount == 2)
|
||||
#expect(dataTable.columns[0].name == "id")
|
||||
#expect(dataTable.columns[1].name == "name")
|
||||
#expect(dataTable.columns[2].name == "score")
|
||||
}
|
||||
|
||||
@Test("Messages without query results do not trigger table rendering")
|
||||
func messageWithoutQueryResult() {
|
||||
let message = ChatMessage(
|
||||
role: .assistant,
|
||||
content: "Hello! How can I help?",
|
||||
queryResult: nil,
|
||||
sql: nil
|
||||
)
|
||||
|
||||
#expect(message.queryResult == nil)
|
||||
}
|
||||
|
||||
@Test("Empty query results do not trigger table rendering")
|
||||
func emptyQueryResult() {
|
||||
let result = QueryResult(
|
||||
columns: [],
|
||||
rows: [],
|
||||
sql: "SELECT * FROM empty_table",
|
||||
executionTime: 0.001
|
||||
)
|
||||
|
||||
let message = ChatMessage(
|
||||
role: .assistant,
|
||||
content: "No results found.",
|
||||
queryResult: result,
|
||||
sql: "SELECT * FROM empty_table"
|
||||
)
|
||||
|
||||
// Even though queryResult exists, it has no columns/rows
|
||||
// MessageBubbleView checks both conditions before showing the table
|
||||
#expect(message.queryResult != nil)
|
||||
#expect(message.queryResult!.columns.isEmpty)
|
||||
#expect(message.queryResult!.rows.isEmpty)
|
||||
}
|
||||
|
||||
@Test("Mutation results do not trigger table rendering")
|
||||
func mutationQueryResult() {
|
||||
let result = QueryResult(
|
||||
columns: [],
|
||||
rows: [],
|
||||
sql: "INSERT INTO users (name) VALUES ('Charlie')",
|
||||
executionTime: 0.005,
|
||||
rowsAffected: 1
|
||||
)
|
||||
|
||||
let message = ChatMessage(
|
||||
role: .assistant,
|
||||
content: "Successfully inserted 1 row.",
|
||||
queryResult: result,
|
||||
sql: "INSERT INTO users (name) VALUES ('Charlie')"
|
||||
)
|
||||
|
||||
// Mutation results have empty columns — no table shown
|
||||
#expect(message.queryResult!.columns.isEmpty)
|
||||
}
|
||||
|
||||
@Test("Error messages never have query results")
|
||||
func errorMessageHasNoQueryResult() {
|
||||
let message = ChatMessage(
|
||||
role: .error,
|
||||
content: "SELECT operations are not allowed."
|
||||
)
|
||||
|
||||
#expect(message.queryResult == nil)
|
||||
#expect(message.role == .error)
|
||||
}
|
||||
|
||||
@Test("DataTable preserves column order from QueryResult")
|
||||
func dataTableColumnOrder() {
|
||||
let result = QueryResult(
|
||||
columns: ["date", "revenue", "category"],
|
||||
rows: [
|
||||
["date": .text("2024-01-01"), "revenue": .real(1500.0), "category": .text("Electronics")],
|
||||
],
|
||||
sql: "SELECT date, revenue, category FROM sales",
|
||||
executionTime: 0.02
|
||||
)
|
||||
|
||||
let dataTable = DataTable(result)
|
||||
#expect(dataTable.columnNames == ["date", "revenue", "category"])
|
||||
}
|
||||
|
||||
@Test("Large result sets are renderable as DataTable")
|
||||
func largeResultSet() {
|
||||
var rows: [[String: QueryResult.Value]] = []
|
||||
for i in 0..<500 {
|
||||
rows.append([
|
||||
"id": .integer(Int64(i)),
|
||||
"value": .real(Double(i) * 1.5),
|
||||
])
|
||||
}
|
||||
|
||||
let result = QueryResult(
|
||||
columns: ["id", "value"],
|
||||
rows: rows,
|
||||
sql: "SELECT id, value FROM big_table",
|
||||
executionTime: 0.15
|
||||
)
|
||||
|
||||
let dataTable = DataTable(result)
|
||||
#expect(dataTable.rowCount == 500)
|
||||
#expect(dataTable.columnCount == 2)
|
||||
}
|
||||
}
|
||||
136
Tests/SwiftDBAITests/DataChatViewUsageTests.swift
Normal file
136
Tests/SwiftDBAITests/DataChatViewUsageTests.swift
Normal file
@@ -0,0 +1,136 @@
|
||||
// DataChatViewUsageTests.swift
|
||||
// SwiftDBAITests
|
||||
//
|
||||
// Proves DataChatView works with minimal setup — under 10 lines of code.
|
||||
// A developer only needs a GRDB connection and a LanguageModel to get a
|
||||
// full chat-with-database SwiftUI view.
|
||||
|
||||
import Testing
|
||||
import Foundation
|
||||
import GRDB
|
||||
@testable import SwiftDBAI
|
||||
|
||||
// MARK: - Minimal Setup: DataChatView in Under 10 Lines
|
||||
|
||||
/// This test suite proves the "zero_config_reads" principle:
|
||||
/// A developer with an existing SQLite database can create a fully functional
|
||||
/// chat UI by providing only a GRDB connection and a language model instance.
|
||||
/// No schema files, no annotations, no manual configuration required.
|
||||
@Suite("DataChatView Minimal Setup")
|
||||
struct DataChatViewMinimalSetupTests {
|
||||
|
||||
// ┌──────────────────────────────────────────────────────────┐
|
||||
// │ USAGE EXAMPLE — DataChatView in 6 lines of real code │
|
||||
// │ │
|
||||
// │ import SwiftDBAI │
|
||||
// │ import GRDB │
|
||||
// │ │
|
||||
// │ let db = try DatabaseQueue(path: "mydata.sqlite") │
|
||||
// │ let model = OllamaLanguageModel(model: "llama3") │
|
||||
// │ │
|
||||
// │ var body: some View { │
|
||||
// │ DataChatView(database: db, model: model) │
|
||||
// │ } │
|
||||
// └──────────────────────────────────────────────────────────┘
|
||||
|
||||
/// Creates a temporary in-memory database with sample data for tests.
|
||||
private static func makeSampleDatabase() throws -> DatabaseQueue {
|
||||
let db = try DatabaseQueue()
|
||||
try db.write { db in
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE products (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
category TEXT
|
||||
);
|
||||
INSERT INTO products (name, price, category) VALUES ('Widget', 9.99, 'Hardware');
|
||||
INSERT INTO products (name, price, category) VALUES ('Gadget', 24.99, 'Electronics');
|
||||
INSERT INTO products (name, price, category) VALUES ('Doohickey', 4.99, 'Hardware');
|
||||
""")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
@Test("DataChatView initializes from database + model in 2 lines")
|
||||
@MainActor
|
||||
func dataChatViewMinimalInit() throws {
|
||||
// LINE 1: Create (or receive) a GRDB connection
|
||||
let db = try Self.makeSampleDatabase()
|
||||
// LINE 2: Create the view — that's it!
|
||||
let _ = DataChatView(database: db, model: MockLanguageModel())
|
||||
// The view is ready. No schema files, no annotations, no extra config.
|
||||
}
|
||||
|
||||
@Test("DataChatView path-based init works in 1 line given a path and model")
|
||||
@MainActor
|
||||
func dataChatViewPathInit() throws {
|
||||
// Create a temp database file
|
||||
let tempDir = FileManager.default.temporaryDirectory
|
||||
let dbPath = tempDir.appendingPathComponent("test_\(UUID().uuidString).sqlite").path
|
||||
let db = try DatabaseQueue(path: dbPath)
|
||||
try db.write { db in
|
||||
try db.execute(sql: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
|
||||
}
|
||||
|
||||
// ONE LINE to get a full chat UI:
|
||||
let _ = DataChatView(databasePath: dbPath, model: MockLanguageModel())
|
||||
|
||||
// Cleanup
|
||||
try? FileManager.default.removeItem(atPath: dbPath)
|
||||
}
|
||||
|
||||
@Test("ChatEngine headless usage works in 3 lines")
|
||||
func chatEngineMinimalUsage() async throws {
|
||||
// LINE 1: Database
|
||||
let db = try Self.makeSampleDatabase()
|
||||
// LINE 2: Engine
|
||||
let engine = ChatEngine(database: db, model: MockLanguageModel(responseText: "SELECT COUNT(*) AS total FROM products"))
|
||||
// LINE 3: Schema preparation verifies auto-introspection works
|
||||
let schema = try await engine.prepareSchema()
|
||||
|
||||
// The engine auto-discovered the schema — no manual config needed
|
||||
#expect(schema.tableNames.contains("products"))
|
||||
#expect(schema.tableNames.count == 1)
|
||||
}
|
||||
|
||||
@Test("ChatViewModel works with zero configuration beyond db + model")
|
||||
@MainActor
|
||||
func chatViewModelMinimalUsage() async throws {
|
||||
let db = try Self.makeSampleDatabase()
|
||||
let engine = ChatEngine(database: db, model: MockLanguageModel())
|
||||
let viewModel = ChatViewModel(engine: engine)
|
||||
|
||||
// Prepare triggers auto-schema-introspection
|
||||
await viewModel.prepare()
|
||||
|
||||
#expect(viewModel.schemaReadiness.isReady)
|
||||
#expect(viewModel.messages.isEmpty) // Clean slate, ready to chat
|
||||
}
|
||||
|
||||
@Test("Default configuration is read-only (safe by default)")
|
||||
@MainActor
|
||||
func defaultIsReadOnly() throws {
|
||||
let db = try Self.makeSampleDatabase()
|
||||
// No allowlist specified — defaults to .readOnly
|
||||
let _ = DataChatView(database: db, model: MockLanguageModel())
|
||||
// This compiles and works. SELECT-only is the safe default.
|
||||
// Developer must explicitly opt in to writes:
|
||||
// DataChatView(database: db, model: model, allowlist: .standard)
|
||||
}
|
||||
|
||||
@Test("Full DataChatView with all options still under 10 lines")
|
||||
@MainActor
|
||||
func dataChatViewFullConfig() throws {
|
||||
let db = try Self.makeSampleDatabase() // 1
|
||||
let model = MockLanguageModel() // 2
|
||||
let _ = DataChatView( // 3-8
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .readOnly,
|
||||
additionalContext: "Product catalog for an e-commerce store",
|
||||
maxSummaryRows: 100
|
||||
)
|
||||
// Even with ALL options specified, it's under 10 lines of setup.
|
||||
}
|
||||
}
|
||||
285
Tests/SwiftDBAITests/DataTableTests.swift
Normal file
285
Tests/SwiftDBAITests/DataTableTests.swift
Normal file
@@ -0,0 +1,285 @@
|
||||
// DataTableTests.swift
|
||||
// SwiftDBAITests
|
||||
|
||||
import Foundation
|
||||
import Testing
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("DataTable")
|
||||
struct DataTableTests {
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
private func makeQueryResult(
|
||||
columns: [String],
|
||||
rows: [[String: QueryResult.Value]],
|
||||
sql: String = "SELECT * FROM test",
|
||||
executionTime: TimeInterval = 0.01
|
||||
) -> QueryResult {
|
||||
QueryResult(
|
||||
columns: columns,
|
||||
rows: rows,
|
||||
sql: sql,
|
||||
executionTime: executionTime
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Basic Construction
|
||||
|
||||
@Test("Converts QueryResult columns and rows correctly")
|
||||
func basicConversion() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["id", "name", "score"],
|
||||
rows: [
|
||||
["id": .integer(1), "name": .text("Alice"), "score": .real(95.5)],
|
||||
["id": .integer(2), "name": .text("Bob"), "score": .real(87.0)],
|
||||
]
|
||||
)
|
||||
|
||||
let table = DataTable(result)
|
||||
|
||||
#expect(table.columnCount == 3)
|
||||
#expect(table.rowCount == 2)
|
||||
#expect(table.columnNames == ["id", "name", "score"])
|
||||
#expect(table.sql == "SELECT * FROM test")
|
||||
#expect(table.executionTime == 0.01)
|
||||
}
|
||||
|
||||
@Test("Empty result produces empty table")
|
||||
func emptyResult() {
|
||||
let result = makeQueryResult(columns: ["id", "name"], rows: [])
|
||||
|
||||
let table = DataTable(result)
|
||||
|
||||
#expect(table.isEmpty)
|
||||
#expect(table.rowCount == 0)
|
||||
#expect(table.columnCount == 2)
|
||||
#expect(table.columnNames == ["id", "name"])
|
||||
}
|
||||
|
||||
// MARK: - Subscript Access
|
||||
|
||||
@Test("Subscript by row and column index")
|
||||
func subscriptByIndex() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["a", "b"],
|
||||
rows: [
|
||||
["a": .integer(10), "b": .text("hello")],
|
||||
["a": .integer(20), "b": .text("world")],
|
||||
]
|
||||
)
|
||||
|
||||
let table = DataTable(result)
|
||||
|
||||
#expect(table[row: 0, column: 0] == .integer(10))
|
||||
#expect(table[row: 0, column: 1] == .text("hello"))
|
||||
#expect(table[row: 1, column: 0] == .integer(20))
|
||||
#expect(table[row: 1, column: 1] == .text("world"))
|
||||
}
|
||||
|
||||
@Test("Subscript by row index and column name")
|
||||
func subscriptByName() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["x", "y"],
|
||||
rows: [["x": .real(1.5), "y": .real(2.5)]]
|
||||
)
|
||||
|
||||
let table = DataTable(result)
|
||||
|
||||
#expect(table[row: 0, column: "x"] == .real(1.5))
|
||||
#expect(table[row: 0, column: "y"] == .real(2.5))
|
||||
#expect(table[row: 0, column: "z"] == .null) // non-existent column
|
||||
}
|
||||
|
||||
// MARK: - Column Data Extraction
|
||||
|
||||
@Test("Extract column values by index")
|
||||
func columnValuesByIndex() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["val"],
|
||||
rows: [
|
||||
["val": .integer(1)],
|
||||
["val": .integer(2)],
|
||||
["val": .integer(3)],
|
||||
]
|
||||
)
|
||||
|
||||
let table = DataTable(result)
|
||||
let values = table.columnValues(at: 0)
|
||||
|
||||
#expect(values == [.integer(1), .integer(2), .integer(3)])
|
||||
}
|
||||
|
||||
@Test("Extract column values by name")
|
||||
func columnValuesByName() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["name"],
|
||||
rows: [
|
||||
["name": .text("A")],
|
||||
["name": .text("B")],
|
||||
]
|
||||
)
|
||||
|
||||
let table = DataTable(result)
|
||||
|
||||
#expect(table.columnValues(named: "name") == [.text("A"), .text("B")])
|
||||
#expect(table.columnValues(named: "missing").isEmpty)
|
||||
}
|
||||
|
||||
@Test("numericValues extracts doubles from numeric column")
|
||||
func numericValues() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["score"],
|
||||
rows: [
|
||||
["score": .integer(10)],
|
||||
["score": .real(20.5)],
|
||||
["score": .null],
|
||||
["score": .text("not a number")],
|
||||
]
|
||||
)
|
||||
|
||||
let table = DataTable(result)
|
||||
let nums = table.numericValues(forColumn: "score")
|
||||
|
||||
#expect(nums.count == 2)
|
||||
#expect(nums[0] == 10.0)
|
||||
#expect(nums[1] == 20.5)
|
||||
}
|
||||
|
||||
@Test("stringValues extracts non-null strings")
|
||||
func stringValues() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["label"],
|
||||
rows: [
|
||||
["label": .text("foo")],
|
||||
["label": .null],
|
||||
["label": .text("bar")],
|
||||
]
|
||||
)
|
||||
|
||||
let table = DataTable(result)
|
||||
let strs = table.stringValues(forColumn: "label")
|
||||
|
||||
#expect(strs == ["foo", "bar"])
|
||||
}
|
||||
|
||||
// MARK: - Type Inference
|
||||
|
||||
@Test("Infers integer type for all-integer column")
|
||||
func inferInteger() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["id"],
|
||||
rows: [["id": .integer(1)], ["id": .integer(2)]]
|
||||
)
|
||||
let table = DataTable(result)
|
||||
#expect(table.columns[0].inferredType == .integer)
|
||||
}
|
||||
|
||||
@Test("Infers real type for all-real column")
|
||||
func inferReal() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["price"],
|
||||
rows: [["price": .real(1.99)], ["price": .real(2.50)]]
|
||||
)
|
||||
let table = DataTable(result)
|
||||
#expect(table.columns[0].inferredType == .real)
|
||||
}
|
||||
|
||||
@Test("Infers text type for all-text column")
|
||||
func inferText() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["name"],
|
||||
rows: [["name": .text("A")], ["name": .text("B")]]
|
||||
)
|
||||
let table = DataTable(result)
|
||||
#expect(table.columns[0].inferredType == .text)
|
||||
}
|
||||
|
||||
@Test("Promotes integer + real to real")
|
||||
func inferNumericPromotion() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["val"],
|
||||
rows: [["val": .integer(1)], ["val": .real(2.5)]]
|
||||
)
|
||||
let table = DataTable(result)
|
||||
#expect(table.columns[0].inferredType == .real)
|
||||
}
|
||||
|
||||
@Test("Mixed types result in .mixed")
|
||||
func inferMixed() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["data"],
|
||||
rows: [["data": .integer(1)], ["data": .text("hello")]]
|
||||
)
|
||||
let table = DataTable(result)
|
||||
#expect(table.columns[0].inferredType == .mixed)
|
||||
}
|
||||
|
||||
@Test("All-null column infers .null")
|
||||
func inferNull() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["empty"],
|
||||
rows: [["empty": .null], ["empty": .null]]
|
||||
)
|
||||
let table = DataTable(result)
|
||||
#expect(table.columns[0].inferredType == .null)
|
||||
}
|
||||
|
||||
@Test("Null values are ignored during type inference")
|
||||
func inferIgnoresNulls() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["val"],
|
||||
rows: [["val": .integer(1)], ["val": .null], ["val": .integer(3)]]
|
||||
)
|
||||
let table = DataTable(result)
|
||||
#expect(table.columns[0].inferredType == .integer)
|
||||
}
|
||||
|
||||
// MARK: - Missing Values
|
||||
|
||||
@Test("Missing dictionary keys become .null")
|
||||
func missingKeysBecomNull() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["a", "b"],
|
||||
rows: [["a": .integer(1)]] // "b" is missing
|
||||
)
|
||||
|
||||
let table = DataTable(result)
|
||||
|
||||
#expect(table[row: 0, column: 0] == .integer(1))
|
||||
#expect(table[row: 0, column: 1] == .null)
|
||||
}
|
||||
|
||||
// MARK: - Row Identity
|
||||
|
||||
@Test("Rows have sequential IDs")
|
||||
func rowIdentity() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["x"],
|
||||
rows: [["x": .integer(1)], ["x": .integer(2)], ["x": .integer(3)]]
|
||||
)
|
||||
|
||||
let table = DataTable(result)
|
||||
|
||||
#expect(table.rows[0].id == 0)
|
||||
#expect(table.rows[1].id == 1)
|
||||
#expect(table.rows[2].id == 2)
|
||||
}
|
||||
|
||||
// MARK: - Column Identity
|
||||
|
||||
@Test("Columns are Identifiable by name")
|
||||
func columnIdentity() {
|
||||
let result = makeQueryResult(
|
||||
columns: ["alpha", "beta"],
|
||||
rows: [["alpha": .integer(1), "beta": .integer(2)]]
|
||||
)
|
||||
|
||||
let table = DataTable(result)
|
||||
|
||||
#expect(table.columns[0].id == "alpha")
|
||||
#expect(table.columns[1].id == "beta")
|
||||
#expect(table.columns[0].index == 0)
|
||||
#expect(table.columns[1].index == 1)
|
||||
}
|
||||
}
|
||||
317
Tests/SwiftDBAITests/DatabaseToolTests.swift
Normal file
317
Tests/SwiftDBAITests/DatabaseToolTests.swift
Normal file
@@ -0,0 +1,317 @@
|
||||
// DatabaseToolTests.swift
|
||||
// SwiftDBAI
|
||||
|
||||
import Testing
|
||||
import Foundation
|
||||
import GRDB
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("DatabaseTool")
|
||||
struct DatabaseToolTests {
|
||||
|
||||
// MARK: - Helper
|
||||
|
||||
/// Creates an in-memory database with sample data for testing.
|
||||
private func makeTestDatabase() throws -> DatabaseQueue {
|
||||
let db = try DatabaseQueue(configuration: {
|
||||
var config = Configuration()
|
||||
config.foreignKeysEnabled = true
|
||||
return config
|
||||
}())
|
||||
|
||||
try db.write { db in
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
email TEXT UNIQUE
|
||||
);
|
||||
""")
|
||||
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE posts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
title TEXT NOT NULL,
|
||||
body TEXT,
|
||||
created_at TEXT DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
""")
|
||||
|
||||
try db.execute(sql: """
|
||||
CREATE INDEX idx_posts_user ON posts(user_id);
|
||||
""")
|
||||
|
||||
// Insert sample data
|
||||
try db.execute(sql: "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')")
|
||||
try db.execute(sql: "INSERT INTO users (name, email) VALUES ('Bob', 'bob@example.com')")
|
||||
try db.execute(sql: "INSERT INTO posts (user_id, title, body) VALUES (1, 'Hello World', 'First post')")
|
||||
try db.execute(sql: "INSERT INTO posts (user_id, title, body) VALUES (1, 'Second Post', 'More content')")
|
||||
try db.execute(sql: "INSERT INTO posts (user_id, title, body) VALUES (2, 'Bob Post', 'Bob writes')")
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// MARK: - Creation
|
||||
|
||||
@Test("Creates tool from database connection")
|
||||
func testCreationFromDatabase() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
#expect(tool.name == "execute_sql")
|
||||
#expect(!tool.description.isEmpty)
|
||||
}
|
||||
|
||||
@Test("Creates tool from database path")
|
||||
func testCreationFromPath() async throws {
|
||||
let tmpDir = FileManager.default.temporaryDirectory
|
||||
let dbPath = tmpDir.appendingPathComponent("test_\(UUID().uuidString).sqlite").path
|
||||
defer { try? FileManager.default.removeItem(atPath: dbPath) }
|
||||
|
||||
// Create a database at the path
|
||||
let dbQueue = try DatabaseQueue(path: dbPath)
|
||||
try await dbQueue.write { db in
|
||||
try db.execute(sql: "CREATE TABLE test (id INTEGER PRIMARY KEY)")
|
||||
}
|
||||
|
||||
let tool = try await DatabaseTool(databasePath: dbPath)
|
||||
#expect(tool.name == "execute_sql")
|
||||
}
|
||||
|
||||
// MARK: - Schema Context
|
||||
|
||||
@Test("Schema context contains table info")
|
||||
func testSchemaContext() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let context = tool.schemaContext
|
||||
#expect(context.contains("users"))
|
||||
#expect(context.contains("posts"))
|
||||
#expect(context.contains("name"))
|
||||
#expect(context.contains("email"))
|
||||
}
|
||||
|
||||
@Test("System prompt snippet contains schema")
|
||||
func testSystemPromptSnippet() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let snippet = tool.systemPromptSnippet
|
||||
#expect(snippet.contains("users"))
|
||||
#expect(snippet.contains("posts"))
|
||||
#expect(snippet.contains("execute_sql"))
|
||||
#expect(snippet.contains("SELECT"))
|
||||
}
|
||||
|
||||
// MARK: - SQL Execution
|
||||
|
||||
@Test("Executes valid SELECT query")
|
||||
func testExecuteSelect() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let result = try tool.execute(sql: "SELECT name, email FROM users ORDER BY name")
|
||||
|
||||
#expect(result.rowCount == 2)
|
||||
#expect(result.columns == ["name", "email"])
|
||||
#expect(result.rows[0]["name"] == "Alice")
|
||||
#expect(result.rows[1]["name"] == "Bob")
|
||||
}
|
||||
|
||||
@Test("Executes query with JOIN")
|
||||
func testExecuteJoin() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let result = try tool.execute(sql: """
|
||||
SELECT u.name, COUNT(p.id) as post_count
|
||||
FROM users u
|
||||
JOIN posts p ON p.user_id = u.id
|
||||
GROUP BY u.name
|
||||
ORDER BY u.name
|
||||
""")
|
||||
|
||||
#expect(result.rowCount == 2)
|
||||
#expect(result.rows[0]["name"] == "Alice")
|
||||
#expect(result.rows[0]["post_count"] == "2")
|
||||
}
|
||||
|
||||
@Test("Rejects INSERT with read-only allowlist")
|
||||
func testRejectInsert() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db, allowlist: .readOnly)
|
||||
|
||||
#expect(throws: SQLParsingError.self) {
|
||||
try tool.execute(sql: "INSERT INTO users (name, email) VALUES ('Eve', 'eve@example.com')")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Rejects DELETE with read-only allowlist")
|
||||
func testRejectDelete() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db, allowlist: .readOnly)
|
||||
|
||||
#expect(throws: SQLParsingError.self) {
|
||||
try tool.execute(sql: "DELETE FROM users WHERE id = 1")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Rejects DROP as dangerous operation")
|
||||
func testRejectDrop() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db, allowlist: .unrestricted)
|
||||
|
||||
#expect(throws: SQLParsingError.self) {
|
||||
try tool.execute(sql: "DROP TABLE users")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Executes raw query returning QueryResult")
|
||||
func testExecuteRaw() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let result = try tool.executeRaw(sql: "SELECT COUNT(*) as cnt FROM users")
|
||||
|
||||
#expect(result.rowCount == 1)
|
||||
#expect(result.columns == ["cnt"])
|
||||
if case .integer(let count) = result.rows[0]["cnt"] {
|
||||
#expect(count == 2)
|
||||
} else {
|
||||
Issue.record("Expected integer value")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - ToolResult Formatting
|
||||
|
||||
@Test("ToolResult JSON serialization")
|
||||
func testToolResultJSON() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let result = try tool.execute(sql: "SELECT name FROM users ORDER BY name LIMIT 1")
|
||||
|
||||
let json = result.jsonString
|
||||
#expect(json.contains("\"columns\""))
|
||||
#expect(json.contains("\"rows\""))
|
||||
#expect(json.contains("\"row_count\""))
|
||||
#expect(json.contains("Alice"))
|
||||
|
||||
// Verify it is valid JSON
|
||||
let data = json.data(using: .utf8)!
|
||||
let parsed = try JSONSerialization.jsonObject(with: data) as! [String: Any]
|
||||
#expect(parsed["row_count"] as? Int == 1)
|
||||
}
|
||||
|
||||
@Test("ToolResult markdown table formatting")
|
||||
func testToolResultMarkdown() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let result = try tool.execute(sql: "SELECT name, email FROM users ORDER BY name")
|
||||
|
||||
let md = result.markdownTable
|
||||
#expect(md.contains("| name | email |"))
|
||||
#expect(md.contains("| --- | --- |"))
|
||||
#expect(md.contains("| Alice | alice@example.com |"))
|
||||
#expect(md.contains("| Bob | bob@example.com |"))
|
||||
}
|
||||
|
||||
@Test("ToolResult markdown table with empty result")
|
||||
func testToolResultMarkdownEmpty() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let result = try tool.execute(sql: "SELECT name FROM users WHERE name = 'Nobody'")
|
||||
|
||||
#expect(result.markdownTable == "_No results._")
|
||||
}
|
||||
|
||||
@Test("ToolResult text summary")
|
||||
func testToolResultTextSummary() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let result = try tool.execute(sql: "SELECT name FROM users")
|
||||
|
||||
let summary = result.textSummary
|
||||
#expect(summary.contains("2 rows"))
|
||||
#expect(summary.contains("name"))
|
||||
}
|
||||
|
||||
@Test("ToolResult text summary with empty result")
|
||||
func testToolResultTextSummaryEmpty() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let result = try tool.execute(sql: "SELECT name FROM users WHERE 1 = 0")
|
||||
|
||||
#expect(result.textSummary.contains("no results"))
|
||||
}
|
||||
|
||||
// MARK: - Parameters Schema
|
||||
|
||||
@Test("Parameters schema has correct structure")
|
||||
func testParametersSchema() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let schema = tool.parametersSchema
|
||||
#expect(schema["type"] as? String == "object")
|
||||
|
||||
let properties = schema["properties"] as? [String: Any]
|
||||
#expect(properties != nil)
|
||||
|
||||
let sqlProp = properties?["sql"] as? [String: Any]
|
||||
#expect(sqlProp?["type"] as? String == "string")
|
||||
|
||||
let required = schema["required"] as? [String]
|
||||
#expect(required == ["sql"])
|
||||
}
|
||||
|
||||
// MARK: - OpenAI Function Definition
|
||||
|
||||
@Test("OpenAI function definition has correct format")
|
||||
func testOpenAIFunctionDefinition() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let tool = try await DatabaseTool(database: db)
|
||||
|
||||
let def = tool.openAIFunctionDefinition
|
||||
#expect(def["type"] as? String == "function")
|
||||
|
||||
let function = def["function"] as? [String: Any]
|
||||
#expect(function?["name"] as? String == "execute_sql")
|
||||
#expect(function?["description"] as? String != nil)
|
||||
#expect(function?["parameters"] as? [String: Any] != nil)
|
||||
|
||||
// Verify it can be serialized to JSON
|
||||
let data = try JSONSerialization.data(withJSONObject: def)
|
||||
#expect(data.count > 0)
|
||||
}
|
||||
|
||||
// MARK: - ToolResult Codable
|
||||
|
||||
@Test("ToolResult is Codable")
|
||||
func testToolResultCodable() throws {
|
||||
let result = ToolResult(
|
||||
columns: ["name", "age"],
|
||||
rows: [["name": "Alice", "age": "30"]],
|
||||
rowCount: 1,
|
||||
executionTime: 0.005,
|
||||
sql: "SELECT name, age FROM users"
|
||||
)
|
||||
|
||||
let encoder = JSONEncoder()
|
||||
let data = try encoder.encode(result)
|
||||
let decoder = JSONDecoder()
|
||||
let decoded = try decoder.decode(ToolResult.self, from: data)
|
||||
|
||||
#expect(decoded.columns == result.columns)
|
||||
#expect(decoded.rows == result.rows)
|
||||
#expect(decoded.rowCount == result.rowCount)
|
||||
#expect(decoded.sql == result.sql)
|
||||
}
|
||||
}
|
||||
745
Tests/SwiftDBAITests/DestructiveOperationTests.swift
Normal file
745
Tests/SwiftDBAITests/DestructiveOperationTests.swift
Normal file
@@ -0,0 +1,745 @@
|
||||
// DestructiveOperationTests.swift
|
||||
// SwiftDBAITests
|
||||
//
|
||||
// Tests verifying that destructive operations are blocked without confirmation
|
||||
// and allowed when the delegate approves.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
import GRDB
|
||||
import Testing
|
||||
|
||||
@testable import SwiftDBAI
|
||||
|
||||
// MARK: - Test Delegates
|
||||
|
||||
/// A delegate that always rejects destructive operations and tracks calls.
|
||||
private final class RejectingTrackingDelegate: SwiftDBAI.ToolExecutionDelegate, @unchecked Sendable {
|
||||
private let lock = NSLock()
|
||||
private var _confirmCalls: [DestructiveOperationContext] = []
|
||||
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
|
||||
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
|
||||
|
||||
var confirmCalls: [DestructiveOperationContext] {
|
||||
lock.withLock { _confirmCalls }
|
||||
}
|
||||
|
||||
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
|
||||
lock.withLock { _willExecuteCalls }
|
||||
}
|
||||
|
||||
var didExecuteCalls: [(sql: String, success: Bool)] {
|
||||
lock.withLock { _didExecuteCalls }
|
||||
}
|
||||
|
||||
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
|
||||
lock.withLock { _confirmCalls.append(context) }
|
||||
return false
|
||||
}
|
||||
|
||||
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
|
||||
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
|
||||
}
|
||||
|
||||
func didExecuteSQL(_ sql: String, success: Bool) async {
|
||||
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
|
||||
}
|
||||
}
|
||||
|
||||
/// A delegate that always approves destructive operations and tracks calls.
|
||||
private final class ApprovingTrackingDelegate: SwiftDBAI.ToolExecutionDelegate, @unchecked Sendable {
|
||||
private let lock = NSLock()
|
||||
private var _confirmCalls: [DestructiveOperationContext] = []
|
||||
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
|
||||
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
|
||||
|
||||
var confirmCalls: [DestructiveOperationContext] {
|
||||
lock.withLock { _confirmCalls }
|
||||
}
|
||||
|
||||
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
|
||||
lock.withLock { _willExecuteCalls }
|
||||
}
|
||||
|
||||
var didExecuteCalls: [(sql: String, success: Bool)] {
|
||||
lock.withLock { _didExecuteCalls }
|
||||
}
|
||||
|
||||
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
|
||||
lock.withLock { _confirmCalls.append(context) }
|
||||
return true
|
||||
}
|
||||
|
||||
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
|
||||
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
|
||||
}
|
||||
|
||||
func didExecuteSQL(_ sql: String, success: Bool) async {
|
||||
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
/// Creates an in-memory database with test data for destructive operation tests.
|
||||
/// Users 1 and 2 have orders; user 3 has no orders (safe to delete).
|
||||
private func makeTestDatabase() throws -> DatabaseQueue {
|
||||
let db = try DatabaseQueue(path: ":memory:")
|
||||
try db.write { db in
|
||||
// Disable FK enforcement for test flexibility, then re-enable
|
||||
try db.execute(sql: "PRAGMA foreign_keys = OFF")
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
email TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
try db.execute(sql: """
|
||||
INSERT INTO users (name, email) VALUES
|
||||
('Alice', 'alice@example.com'),
|
||||
('Bob', 'bob@example.com'),
|
||||
('Charlie', 'charlie@example.com')
|
||||
""")
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE orders (
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
amount REAL NOT NULL
|
||||
)
|
||||
""")
|
||||
try db.execute(sql: """
|
||||
INSERT INTO orders (user_id, amount) VALUES
|
||||
(1, 99.99),
|
||||
(2, 150.00),
|
||||
(3, 25.50)
|
||||
""")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
/// A sequential mock model for tests. Returns responses in order.
|
||||
private struct TestSequentialModel: LanguageModel {
|
||||
typealias UnavailableReason = Never
|
||||
|
||||
let responses: [String]
|
||||
private let callCounter = CallCounter()
|
||||
|
||||
private final class CallCounter: @unchecked Sendable {
|
||||
var count = 0
|
||||
let lock = NSLock()
|
||||
|
||||
func next() -> Int {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
let c = count
|
||||
count += 1
|
||||
return c
|
||||
}
|
||||
}
|
||||
|
||||
init(responses: [String]) {
|
||||
self.responses = responses
|
||||
}
|
||||
|
||||
func respond<Content>(
|
||||
within session: LanguageModelSession,
|
||||
to prompt: Prompt,
|
||||
generating type: Content.Type,
|
||||
includeSchemaInPrompt: Bool,
|
||||
options: GenerationOptions
|
||||
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
|
||||
let idx = callCounter.next()
|
||||
let text = idx < responses.count ? responses[idx] : "fallback response"
|
||||
let rawContent = GeneratedContent(kind: .string(text))
|
||||
let content = try Content(rawContent)
|
||||
return LanguageModelSession.Response(
|
||||
content: content,
|
||||
rawContent: rawContent,
|
||||
transcriptEntries: [][...]
|
||||
)
|
||||
}
|
||||
|
||||
func streamResponse<Content>(
|
||||
within session: LanguageModelSession,
|
||||
to prompt: Prompt,
|
||||
generating type: Content.Type,
|
||||
includeSchemaInPrompt: Bool,
|
||||
options: GenerationOptions
|
||||
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
|
||||
let idx = callCounter.next()
|
||||
let text = idx < responses.count ? responses[idx] : "fallback response"
|
||||
let rawContent = GeneratedContent(kind: .string(text))
|
||||
let content = try! Content(rawContent)
|
||||
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Tests: Destructive Operations Blocked Without Confirmation
|
||||
|
||||
@Suite("Destructive Operations - Blocked Without Confirmation")
|
||||
struct DestructiveOperationsBlockedTests {
|
||||
|
||||
@Test("DELETE is blocked when no delegate is provided")
|
||||
func deleteBlockedWithoutDelegate() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM users WHERE id = 1"
|
||||
])
|
||||
|
||||
// Unrestricted allowlist permits DELETE, but no delegate to confirm
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .unrestricted
|
||||
)
|
||||
|
||||
do {
|
||||
_ = try await engine.send("Delete user 1")
|
||||
Issue.record("Expected confirmationRequired error but send succeeded")
|
||||
} catch let error as SwiftDBAIError {
|
||||
guard case .confirmationRequired(let sql, let operation) = error else {
|
||||
Issue.record("Expected confirmationRequired, got: \(error)")
|
||||
return
|
||||
}
|
||||
#expect(sql.uppercased().contains("DELETE"))
|
||||
#expect(operation == "delete")
|
||||
}
|
||||
|
||||
// Verify the user was NOT deleted (data remains intact)
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
|
||||
}
|
||||
#expect(count == 1, "User should NOT have been deleted")
|
||||
}
|
||||
|
||||
@Test("DELETE is blocked when delegate rejects")
|
||||
func deleteBlockedWhenDelegateRejects() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = RejectingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM users WHERE id = 2"
|
||||
])
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .unrestricted,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
do {
|
||||
_ = try await engine.send("Delete user 2")
|
||||
Issue.record("Expected confirmationRequired error but send succeeded")
|
||||
} catch let error as SwiftDBAIError {
|
||||
guard case .confirmationRequired(let sql, let operation) = error else {
|
||||
Issue.record("Expected confirmationRequired, got: \(error)")
|
||||
return
|
||||
}
|
||||
#expect(sql.uppercased().contains("DELETE"))
|
||||
#expect(operation == "delete")
|
||||
}
|
||||
|
||||
// Verify delegate was consulted
|
||||
#expect(delegate.confirmCalls.count == 1)
|
||||
#expect(delegate.confirmCalls[0].statementKind == .delete)
|
||||
#expect(delegate.confirmCalls[0].sql.uppercased().contains("DELETE"))
|
||||
|
||||
// Verify the data was NOT modified
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 2")
|
||||
}
|
||||
#expect(count == 1, "User should NOT have been deleted")
|
||||
|
||||
// Verify no SQL was actually executed (no willExecute/didExecute calls)
|
||||
#expect(delegate.willExecuteCalls.isEmpty, "No SQL should have been executed")
|
||||
#expect(delegate.didExecuteCalls.isEmpty, "No SQL should have been executed")
|
||||
}
|
||||
|
||||
@Test("DELETE is blocked with MutationPolicy and no delegate")
|
||||
func deleteBlockedWithMutationPolicyNoDelegate() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM users WHERE id = 3"
|
||||
])
|
||||
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update, .delete],
|
||||
requiresDestructiveConfirmation: true
|
||||
)
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
mutationPolicy: policy
|
||||
)
|
||||
|
||||
do {
|
||||
_ = try await engine.send("Delete user 3")
|
||||
Issue.record("Expected confirmationRequired error but send succeeded")
|
||||
} catch let error as SwiftDBAIError {
|
||||
guard case .confirmationRequired(let sql, let operation) = error else {
|
||||
Issue.record("Expected confirmationRequired, got: \(error)")
|
||||
return
|
||||
}
|
||||
#expect(sql.uppercased().contains("DELETE"))
|
||||
#expect(operation == "delete")
|
||||
}
|
||||
|
||||
// Data intact
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 3")
|
||||
}
|
||||
#expect(count == 1, "User should NOT have been deleted")
|
||||
}
|
||||
|
||||
@Test("DELETE is blocked with MutationPolicy and rejecting delegate")
|
||||
func deleteBlockedWithMutationPolicyRejectingDelegate() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = RejectingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM orders WHERE user_id = 1"
|
||||
])
|
||||
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update, .delete],
|
||||
requiresDestructiveConfirmation: true
|
||||
)
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
mutationPolicy: policy,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
do {
|
||||
_ = try await engine.send("Delete all orders for user 1")
|
||||
Issue.record("Expected confirmationRequired error but send succeeded")
|
||||
} catch let error as SwiftDBAIError {
|
||||
guard case .confirmationRequired = error else {
|
||||
Issue.record("Expected confirmationRequired, got: \(error)")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Delegate was consulted and rejected
|
||||
#expect(delegate.confirmCalls.count == 1)
|
||||
#expect(delegate.confirmCalls[0].statementKind == .delete)
|
||||
|
||||
// Orders remain
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM orders WHERE user_id = 1")
|
||||
}
|
||||
#expect(count == 1, "Orders should NOT have been deleted")
|
||||
}
|
||||
|
||||
@Test("Default delegate implementation rejects destructive operations")
|
||||
func defaultDelegateRejectsDestructive() async {
|
||||
struct DefaultDelegate: SwiftDBAI.ToolExecutionDelegate {}
|
||||
let delegate = DefaultDelegate()
|
||||
|
||||
let context = DestructiveOperationContext(
|
||||
sql: "DELETE FROM users WHERE id = 1",
|
||||
statementKind: .delete,
|
||||
classification: .destructive(.delete),
|
||||
description: "Delete from users"
|
||||
)
|
||||
|
||||
let approved = await delegate.confirmDestructiveOperation(context)
|
||||
#expect(approved == false, "Default delegate should reject destructive operations")
|
||||
}
|
||||
|
||||
@Test("DELETE not in readOnly allowlist is rejected before delegate is consulted")
|
||||
func deleteNotInAllowlistRejectedEarly() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = ApprovingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM users WHERE id = 1"
|
||||
])
|
||||
|
||||
// Read-only allowlist does NOT include DELETE
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .readOnly,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
do {
|
||||
_ = try await engine.send("Delete user 1")
|
||||
Issue.record("Expected operationNotAllowed error")
|
||||
} catch let error as SwiftDBAIError {
|
||||
guard case .operationNotAllowed(let operation) = error else {
|
||||
Issue.record("Expected operationNotAllowed, got: \(error)")
|
||||
return
|
||||
}
|
||||
#expect(operation == "delete")
|
||||
}
|
||||
|
||||
// Delegate should NOT have been consulted — the allowlist rejects before delegation
|
||||
#expect(delegate.confirmCalls.isEmpty, "Delegate should not be consulted when op is not in allowlist")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Tests: Destructive Operations Allowed When Delegate Approves
|
||||
|
||||
@Suite("Destructive Operations - Allowed When Delegate Approves")
|
||||
struct DestructiveOperationsAllowedTests {
|
||||
|
||||
@Test("DELETE succeeds when delegate approves")
|
||||
func deleteSucceedsWithApprovingDelegate() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = ApprovingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM users WHERE id = 1",
|
||||
"Successfully deleted 1 user."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .unrestricted,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
let response = try await engine.send("Delete user 1")
|
||||
|
||||
// Delegate was consulted and approved
|
||||
#expect(delegate.confirmCalls.count == 1)
|
||||
#expect(delegate.confirmCalls[0].statementKind == .delete)
|
||||
#expect(delegate.confirmCalls[0].sql.uppercased().contains("DELETE"))
|
||||
#expect(delegate.confirmCalls[0].targetTable == "users")
|
||||
|
||||
// SQL was executed
|
||||
#expect(delegate.willExecuteCalls.count == 1)
|
||||
#expect(delegate.didExecuteCalls.count == 1)
|
||||
#expect(delegate.didExecuteCalls[0].success == true)
|
||||
|
||||
// Verify the data was actually deleted
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
|
||||
}
|
||||
#expect(count == 0, "User should have been deleted")
|
||||
|
||||
// Response should contain meaningful content
|
||||
#expect(response.sql?.uppercased().contains("DELETE") == true)
|
||||
#expect(response.queryResult != nil)
|
||||
}
|
||||
|
||||
@Test("DELETE with MutationPolicy succeeds when delegate approves")
|
||||
func deleteWithPolicySucceedsWhenApproved() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = ApprovingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM orders WHERE user_id = 2",
|
||||
"Deleted 1 order."
|
||||
])
|
||||
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update, .delete],
|
||||
requiresDestructiveConfirmation: true
|
||||
)
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
mutationPolicy: policy,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
let response = try await engine.send("Delete all orders for user 2")
|
||||
|
||||
// Delegate approved
|
||||
#expect(delegate.confirmCalls.count == 1)
|
||||
|
||||
// Data was actually deleted
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM orders WHERE user_id = 2")
|
||||
}
|
||||
#expect(count == 0, "Orders should have been deleted")
|
||||
#expect(response.sql?.uppercased().contains("DELETE") == true)
|
||||
}
|
||||
|
||||
@Test("AutoApproveDelegate allows DELETE without user interaction")
|
||||
func autoApproveDelegateAllowsDelete() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = AutoApproveDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM users WHERE id = 3",
|
||||
"Deleted 1 user."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .unrestricted,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
let response = try await engine.send("Delete user 3")
|
||||
|
||||
// Should succeed without error
|
||||
#expect(response.sql?.uppercased().contains("DELETE") == true)
|
||||
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 3")
|
||||
}
|
||||
#expect(count == 0, "User should have been deleted")
|
||||
}
|
||||
|
||||
@Test("sendConfirmed bypasses delegate and executes directly")
|
||||
func sendConfirmedBypassesDelegate() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = RejectingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"Deleted 1 user."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .unrestricted,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
// sendConfirmed should execute directly without consulting the delegate for confirmation
|
||||
let response = try await engine.sendConfirmed(
|
||||
"Delete user 1",
|
||||
confirmedSQL: "DELETE FROM users WHERE id = 1"
|
||||
)
|
||||
|
||||
// Delegate was NOT asked to confirm (sendConfirmed skips confirmation)
|
||||
#expect(delegate.confirmCalls.isEmpty)
|
||||
|
||||
// But willExecute/didExecute hooks were still called
|
||||
#expect(delegate.willExecuteCalls.count == 1)
|
||||
#expect(delegate.didExecuteCalls.count == 1)
|
||||
#expect(delegate.didExecuteCalls[0].success == true)
|
||||
|
||||
// Data was deleted
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
|
||||
}
|
||||
#expect(count == 0)
|
||||
#expect(response.summary.contains("deleted") || response.summary.contains("Deleted") || response.summary.contains("1"))
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Tests: Delegate Context Correctness
|
||||
|
||||
@Suite("Destructive Operations - Delegate Context")
|
||||
struct DestructiveOperationContextTests {
|
||||
|
||||
@Test("Delegate receives correct context for DELETE on specific table")
|
||||
func delegateReceivesCorrectContext() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = RejectingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM orders WHERE amount < 50"
|
||||
])
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .unrestricted,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
do {
|
||||
_ = try await engine.send("Delete cheap orders")
|
||||
Issue.record("Expected confirmationRequired error")
|
||||
} catch is SwiftDBAIError {
|
||||
// Expected
|
||||
}
|
||||
|
||||
#expect(delegate.confirmCalls.count == 1)
|
||||
let ctx = delegate.confirmCalls[0]
|
||||
#expect(ctx.statementKind == .delete)
|
||||
#expect(ctx.classification == .destructive(.delete))
|
||||
#expect(ctx.classification.requiresConfirmation == true)
|
||||
#expect(ctx.sql.uppercased().contains("DELETE FROM ORDERS"))
|
||||
#expect(ctx.targetTable == "orders")
|
||||
#expect(!ctx.description.isEmpty)
|
||||
}
|
||||
|
||||
@Test("Non-destructive operations do not consult delegate")
|
||||
func selectDoesNotConsultDelegate() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = ApprovingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"SELECT COUNT(*) FROM users",
|
||||
"There are 3 users."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .unrestricted,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
_ = try await engine.send("How many users?")
|
||||
|
||||
// Delegate should NOT have been asked to confirm (SELECT is not destructive)
|
||||
#expect(delegate.confirmCalls.isEmpty)
|
||||
|
||||
// But willExecute/didExecute should still be called (observation hooks)
|
||||
#expect(delegate.willExecuteCalls.count == 1)
|
||||
#expect(delegate.didExecuteCalls.count == 1)
|
||||
}
|
||||
|
||||
@Test("INSERT does not require confirmation even with delegate")
|
||||
func insertDoesNotRequireConfirmation() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = RejectingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"INSERT INTO users (name, email) VALUES ('Dave', 'dave@example.com')",
|
||||
"Inserted 1 row."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .standard,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
let response = try await engine.send("Add user Dave")
|
||||
|
||||
// No confirmation needed for INSERT
|
||||
#expect(delegate.confirmCalls.isEmpty)
|
||||
#expect(response.sql?.uppercased().contains("INSERT") == true)
|
||||
|
||||
// Verify the insert happened
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE name = 'Dave'")
|
||||
}
|
||||
#expect(count == 1)
|
||||
}
|
||||
|
||||
@Test("UPDATE does not require confirmation even with delegate")
|
||||
func updateDoesNotRequireConfirmation() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = RejectingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"UPDATE users SET email = 'alice-new@example.com' WHERE id = 1",
|
||||
"Updated 1 row."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
allowlist: .standard,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
let response = try await engine.send("Update Alice's email")
|
||||
|
||||
// No confirmation needed for UPDATE
|
||||
#expect(delegate.confirmCalls.isEmpty)
|
||||
#expect(response.sql?.uppercased().contains("UPDATE") == true)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Tests: MutationPolicy Confirmation Flag
|
||||
|
||||
@Suite("Destructive Operations - MutationPolicy Confirmation Control")
|
||||
struct MutationPolicyConfirmationTests {
|
||||
|
||||
@Test("DELETE skips confirmation when requiresDestructiveConfirmation is false")
|
||||
func deleteSkipsConfirmationWhenDisabled() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let delegate = RejectingTrackingDelegate()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM users WHERE id = 1",
|
||||
"Deleted 1 user."
|
||||
])
|
||||
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update, .delete],
|
||||
requiresDestructiveConfirmation: false // Explicitly disabled
|
||||
)
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
mutationPolicy: policy,
|
||||
delegate: delegate
|
||||
)
|
||||
|
||||
// Should succeed without confirmation since the policy disables it
|
||||
let response = try await engine.send("Delete user 1")
|
||||
|
||||
// Delegate should NOT have been consulted for confirmation
|
||||
#expect(delegate.confirmCalls.isEmpty)
|
||||
|
||||
// But the SQL should have executed
|
||||
#expect(response.sql?.uppercased().contains("DELETE") == true)
|
||||
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
|
||||
}
|
||||
#expect(count == 0, "User should have been deleted without confirmation")
|
||||
}
|
||||
|
||||
@Test("MutationPolicy.requiresConfirmation only triggers for DELETE")
|
||||
func requiresConfirmationOnlyForDelete() {
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update, .delete],
|
||||
requiresDestructiveConfirmation: true
|
||||
)
|
||||
|
||||
#expect(policy.requiresConfirmation(for: .delete) == true)
|
||||
#expect(policy.requiresConfirmation(for: .select) == false)
|
||||
#expect(policy.requiresConfirmation(for: .insert) == false)
|
||||
#expect(policy.requiresConfirmation(for: .update) == false)
|
||||
}
|
||||
|
||||
@Test("MutationPolicy.readOnly never requires confirmation (no delete allowed)")
|
||||
func readOnlyNeverRequiresConfirmation() {
|
||||
let policy = MutationPolicy.readOnly
|
||||
|
||||
#expect(policy.requiresConfirmation(for: .select) == false)
|
||||
#expect(policy.requiresConfirmation(for: .delete) == true) // Would require confirmation IF allowed
|
||||
#expect(policy.isOperationAllowed(.delete) == false) // But it's not allowed at all
|
||||
}
|
||||
|
||||
@Test("Table-restricted DELETE is blocked for disallowed tables")
|
||||
func tableRestrictedDeleteBlocked() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let model = TestSequentialModel(responses: [
|
||||
"DELETE FROM users WHERE id = 1"
|
||||
])
|
||||
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update, .delete],
|
||||
allowedTables: ["orders"], // Only orders, NOT users
|
||||
requiresDestructiveConfirmation: true
|
||||
)
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: model,
|
||||
mutationPolicy: policy
|
||||
)
|
||||
|
||||
do {
|
||||
_ = try await engine.send("Delete user 1")
|
||||
Issue.record("Expected tableNotAllowedForMutation error")
|
||||
} catch let error as SwiftDBAIError {
|
||||
guard case .tableNotAllowedForMutation(let tableName, let operation) = error else {
|
||||
Issue.record("Expected tableNotAllowedForMutation, got: \(error)")
|
||||
return
|
||||
}
|
||||
#expect(tableName == "users")
|
||||
#expect(operation == "delete")
|
||||
}
|
||||
|
||||
// User was not deleted
|
||||
let count = try await db.read { db in
|
||||
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
|
||||
}
|
||||
#expect(count == 1)
|
||||
}
|
||||
}
|
||||
49
Tests/SwiftDBAITests/Helpers/MockLanguageModel.swift
Normal file
49
Tests/SwiftDBAITests/Helpers/MockLanguageModel.swift
Normal file
@@ -0,0 +1,49 @@
|
||||
// MockLanguageModel.swift
|
||||
// SwiftDBAI Tests
|
||||
//
|
||||
// A mock LanguageModel for unit tests that returns canned responses.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
|
||||
/// A mock language model that returns a configurable canned response.
|
||||
///
|
||||
/// Used in tests to avoid hitting a real LLM provider.
|
||||
struct MockLanguageModel: LanguageModel {
|
||||
typealias UnavailableReason = Never
|
||||
|
||||
/// The text the mock will return from `respond(...)`.
|
||||
let responseText: String
|
||||
|
||||
init(responseText: String = "Mock summary response.") {
|
||||
self.responseText = responseText
|
||||
}
|
||||
|
||||
func respond<Content>(
|
||||
within session: LanguageModelSession,
|
||||
to prompt: Prompt,
|
||||
generating type: Content.Type,
|
||||
includeSchemaInPrompt: Bool,
|
||||
options: GenerationOptions
|
||||
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
|
||||
let rawContent = GeneratedContent(kind: .string(responseText))
|
||||
let content = try Content(rawContent)
|
||||
return LanguageModelSession.Response(
|
||||
content: content,
|
||||
rawContent: rawContent,
|
||||
transcriptEntries: [][...]
|
||||
)
|
||||
}
|
||||
|
||||
func streamResponse<Content>(
|
||||
within session: LanguageModelSession,
|
||||
to prompt: Prompt,
|
||||
generating type: Content.Type,
|
||||
includeSchemaInPrompt: Bool,
|
||||
options: GenerationOptions
|
||||
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
|
||||
let rawContent = GeneratedContent(kind: .string(responseText))
|
||||
let content = try! Content(rawContent)
|
||||
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
|
||||
}
|
||||
}
|
||||
337
Tests/SwiftDBAITests/LocalProviderConfigurationTests.swift
Normal file
337
Tests/SwiftDBAITests/LocalProviderConfigurationTests.swift
Normal file
@@ -0,0 +1,337 @@
|
||||
// LocalProviderConfigurationTests.swift
|
||||
// SwiftDBAI Tests
|
||||
//
|
||||
// Tests for local/self-hosted provider configurations (Ollama, llama.cpp):
|
||||
// factory methods, endpoint discovery, connection handling, and model creation.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
import GRDB
|
||||
@testable import SwiftDBAI
|
||||
import Testing
|
||||
|
||||
@Suite("Local Provider Configuration")
|
||||
struct LocalProviderConfigurationTests {
|
||||
|
||||
// MARK: - Ollama Configuration
|
||||
|
||||
@Test("Ollama configuration stores provider and model")
|
||||
func ollamaBasicConfiguration() {
|
||||
let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||
|
||||
#expect(config.provider == .ollama)
|
||||
#expect(config.model == "llama3.2")
|
||||
#expect(config.baseURL == OllamaLanguageModel.defaultBaseURL)
|
||||
}
|
||||
|
||||
@Test("Ollama configuration produces OllamaLanguageModel")
|
||||
func ollamaMakeModel() {
|
||||
let config = ProviderConfiguration.ollama(model: "qwen2.5")
|
||||
|
||||
let model = config.makeModel()
|
||||
#expect(model is OllamaLanguageModel)
|
||||
}
|
||||
|
||||
@Test("Ollama with custom base URL for remote instance")
|
||||
func ollamaCustomBaseURL() {
|
||||
let remoteURL = URL(string: "http://192.168.1.100:11434")!
|
||||
let config = ProviderConfiguration.ollama(
|
||||
model: "mistral",
|
||||
baseURL: remoteURL
|
||||
)
|
||||
|
||||
#expect(config.baseURL == remoteURL)
|
||||
#expect(config.provider == .ollama)
|
||||
let model = config.makeModel()
|
||||
#expect(model is OllamaLanguageModel)
|
||||
}
|
||||
|
||||
@Test("Ollama does not require an API key")
|
||||
func ollamaNoAPIKey() {
|
||||
let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||
|
||||
// Ollama doesn't need an API key, so the key is empty
|
||||
#expect(config.apiKey == "")
|
||||
// hasValidAPIKey returns false because key is empty, but that's expected
|
||||
// for local providers — they don't need authentication
|
||||
#expect(!config.hasValidAPIKey)
|
||||
}
|
||||
|
||||
@Test("Ollama model is available without API key")
|
||||
func ollamaModelAvailable() {
|
||||
let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||
let model = config.makeModel()
|
||||
#expect(model.isAvailable)
|
||||
}
|
||||
|
||||
// MARK: - llama.cpp Configuration
|
||||
|
||||
@Test("llama.cpp configuration stores provider and model")
|
||||
func llamaCppBasicConfiguration() {
|
||||
let config = ProviderConfiguration.llamaCpp(model: "my-model")
|
||||
|
||||
#expect(config.provider == .llamaCpp)
|
||||
#expect(config.model == "my-model")
|
||||
#expect(config.baseURL == LocalProviderDiscovery.defaultLlamaCppURL)
|
||||
}
|
||||
|
||||
@Test("llama.cpp uses 'default' model name by default")
|
||||
func llamaCppDefaultModel() {
|
||||
let config = ProviderConfiguration.llamaCpp()
|
||||
|
||||
#expect(config.model == "default")
|
||||
}
|
||||
|
||||
@Test("llama.cpp configuration produces OpenAILanguageModel (compatible API)")
|
||||
func llamaCppMakeModel() {
|
||||
let config = ProviderConfiguration.llamaCpp(model: "my-gguf")
|
||||
|
||||
let model = config.makeModel()
|
||||
// llama.cpp uses OpenAI-compatible API
|
||||
#expect(model is OpenAILanguageModel)
|
||||
}
|
||||
|
||||
@Test("llama.cpp with custom base URL")
|
||||
func llamaCppCustomBaseURL() {
|
||||
let customURL = URL(string: "http://localhost:9090")!
|
||||
let config = ProviderConfiguration.llamaCpp(
|
||||
model: "custom-model",
|
||||
baseURL: customURL
|
||||
)
|
||||
|
||||
#expect(config.baseURL == customURL)
|
||||
let model = config.makeModel()
|
||||
#expect(model is OpenAILanguageModel)
|
||||
}
|
||||
|
||||
@Test("llama.cpp with API key authentication")
|
||||
func llamaCppWithAPIKey() {
|
||||
let config = ProviderConfiguration.llamaCpp(
|
||||
model: "secured-model",
|
||||
apiKey: "my-secret-key"
|
||||
)
|
||||
|
||||
#expect(config.apiKey == "my-secret-key")
|
||||
#expect(config.hasValidAPIKey)
|
||||
}
|
||||
|
||||
@Test("llama.cpp without API key")
|
||||
func llamaCppNoAPIKey() {
|
||||
let config = ProviderConfiguration.llamaCpp(model: "open-model")
|
||||
|
||||
#expect(config.apiKey == "")
|
||||
}
|
||||
|
||||
// MARK: - Provider Enum
|
||||
|
||||
@Test("Provider enum includes ollama and llamaCpp cases")
|
||||
func providerEnumHasLocalCases() {
|
||||
let cases = ProviderConfiguration.Provider.allCases
|
||||
#expect(cases.contains(.ollama))
|
||||
#expect(cases.contains(.llamaCpp))
|
||||
// Total: openAI, anthropic, gemini, openAICompatible, ollama, llamaCpp
|
||||
#expect(cases.count == 6)
|
||||
}
|
||||
|
||||
// MARK: - fromEnvironment
|
||||
|
||||
@Test("fromEnvironment creates Ollama configuration")
|
||||
func fromEnvironmentOllama() {
|
||||
let config = ProviderConfiguration.fromEnvironment(
|
||||
provider: .ollama,
|
||||
environmentVariable: "NONEXISTENT_OLLAMA_KEY",
|
||||
model: "llama3.2"
|
||||
)
|
||||
|
||||
#expect(config.provider == .ollama)
|
||||
#expect(config.model == "llama3.2")
|
||||
}
|
||||
|
||||
@Test("fromEnvironment creates llama.cpp configuration")
|
||||
func fromEnvironmentLlamaCpp() {
|
||||
let config = ProviderConfiguration.fromEnvironment(
|
||||
provider: .llamaCpp,
|
||||
environmentVariable: "NONEXISTENT_LLAMACPP_KEY",
|
||||
model: "default"
|
||||
)
|
||||
|
||||
#expect(config.provider == .llamaCpp)
|
||||
#expect(config.model == "default")
|
||||
}
|
||||
|
||||
// MARK: - ChatEngine Convenience Init with Local Providers
|
||||
|
||||
@Test("ChatEngine can be created with Ollama provider")
|
||||
func chatEngineWithOllama() throws {
|
||||
let dbQueue = try GRDB.DatabaseQueue()
|
||||
let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||
let engine = ChatEngine(database: dbQueue, provider: config)
|
||||
|
||||
#expect(engine.tableCount == nil) // schema not yet introspected
|
||||
}
|
||||
|
||||
@Test("ChatEngine can be created with llama.cpp provider")
|
||||
func chatEngineWithLlamaCpp() throws {
|
||||
let dbQueue = try GRDB.DatabaseQueue()
|
||||
let config = ProviderConfiguration.llamaCpp()
|
||||
let engine = ChatEngine(database: dbQueue, provider: config)
|
||||
|
||||
#expect(engine.tableCount == nil)
|
||||
}
|
||||
|
||||
// MARK: - LocalProviderType
|
||||
|
||||
@Test("LocalProviderType has expected raw values")
|
||||
func localProviderTypeRawValues() {
|
||||
#expect(LocalProviderType.ollama.rawValue == "ollama")
|
||||
#expect(LocalProviderType.llamaCpp.rawValue == "llama.cpp")
|
||||
}
|
||||
|
||||
@Test("LocalProviderType CaseIterable includes both cases")
|
||||
func localProviderTypeCases() {
|
||||
let cases = LocalProviderType.allCases
|
||||
#expect(cases.count == 2)
|
||||
#expect(cases.contains(.ollama))
|
||||
#expect(cases.contains(.llamaCpp))
|
||||
}
|
||||
|
||||
// MARK: - LocalProviderEndpoint
|
||||
|
||||
@Test("LocalProviderEndpoint description includes status and model count")
|
||||
func endpointDescription() {
|
||||
let endpoint = LocalProviderEndpoint(
|
||||
baseURL: URL(string: "http://localhost:11434")!,
|
||||
providerType: .ollama,
|
||||
isReachable: true,
|
||||
availableModels: ["llama3.2", "qwen2.5"]
|
||||
)
|
||||
|
||||
#expect(endpoint.description.contains("ollama"))
|
||||
#expect(endpoint.description.contains("reachable"))
|
||||
#expect(endpoint.description.contains("2 models"))
|
||||
}
|
||||
|
||||
@Test("LocalProviderEndpoint shows unreachable when not connected")
|
||||
func endpointUnreachableDescription() {
|
||||
let endpoint = LocalProviderEndpoint(
|
||||
baseURL: URL(string: "http://localhost:8080")!,
|
||||
providerType: .llamaCpp,
|
||||
isReachable: false,
|
||||
availableModels: []
|
||||
)
|
||||
|
||||
#expect(endpoint.description.contains("unreachable"))
|
||||
#expect(endpoint.description.contains("0 models"))
|
||||
}
|
||||
|
||||
@Test("LocalProviderEndpoint equality works correctly")
|
||||
func endpointEquality() {
|
||||
let a = LocalProviderEndpoint(
|
||||
baseURL: URL(string: "http://localhost:11434")!,
|
||||
providerType: .ollama,
|
||||
isReachable: true,
|
||||
availableModels: ["llama3.2"]
|
||||
)
|
||||
let b = LocalProviderEndpoint(
|
||||
baseURL: URL(string: "http://localhost:11434")!,
|
||||
providerType: .ollama,
|
||||
isReachable: true,
|
||||
availableModels: ["llama3.2"]
|
||||
)
|
||||
let c = LocalProviderEndpoint(
|
||||
baseURL: URL(string: "http://localhost:11434")!,
|
||||
providerType: .ollama,
|
||||
isReachable: false,
|
||||
availableModels: []
|
||||
)
|
||||
|
||||
#expect(a == b)
|
||||
#expect(a != c)
|
||||
}
|
||||
|
||||
// MARK: - Discovery (No Local Server Running)
|
||||
|
||||
@Test("Discovery returns unreachable when no server is running")
|
||||
func discoveryUnreachableEndpoint() async {
|
||||
// Use a port that's almost certainly not running anything
|
||||
let endpoint = await LocalProviderDiscovery.discover(
|
||||
providerType: .ollama,
|
||||
host: "127.0.0.1",
|
||||
port: 59999,
|
||||
timeout: 1
|
||||
)
|
||||
|
||||
#expect(!endpoint.isReachable)
|
||||
#expect(endpoint.availableModels.isEmpty)
|
||||
#expect(endpoint.providerType == .ollama)
|
||||
}
|
||||
|
||||
@Test("isOllamaRunning returns false for unreachable endpoint")
|
||||
func ollamaNotRunning() async {
|
||||
let unreachableURL = URL(string: "http://127.0.0.1:59998")!
|
||||
let running = await LocalProviderDiscovery.isOllamaRunning(
|
||||
at: unreachableURL,
|
||||
timeout: 1
|
||||
)
|
||||
|
||||
#expect(!running)
|
||||
}
|
||||
|
||||
@Test("isLlamaCppRunning returns false for unreachable endpoint")
|
||||
func llamaCppNotRunning() async {
|
||||
let unreachableURL = URL(string: "http://127.0.0.1:59997")!
|
||||
let running = await LocalProviderDiscovery.isLlamaCppRunning(
|
||||
at: unreachableURL,
|
||||
timeout: 1
|
||||
)
|
||||
|
||||
#expect(!running)
|
||||
}
|
||||
|
||||
@Test("listOllamaModels returns empty for unreachable endpoint")
|
||||
func ollamaModelsUnreachable() async {
|
||||
let unreachableURL = URL(string: "http://127.0.0.1:59996")!
|
||||
let models = await LocalProviderDiscovery.listOllamaModels(
|
||||
at: unreachableURL,
|
||||
timeout: 1
|
||||
)
|
||||
|
||||
#expect(models.isEmpty)
|
||||
}
|
||||
|
||||
@Test("listLlamaCppModels returns empty for unreachable endpoint")
|
||||
func llamaCppModelsUnreachable() async {
|
||||
let unreachableURL = URL(string: "http://127.0.0.1:59995")!
|
||||
let models = await LocalProviderDiscovery.listLlamaCppModels(
|
||||
at: unreachableURL,
|
||||
timeout: 1
|
||||
)
|
||||
|
||||
#expect(models.isEmpty)
|
||||
}
|
||||
|
||||
@Test("discoverAll returns endpoints for both provider types")
|
||||
func discoverAllReturnsAllProviders() async {
|
||||
// Use very short timeout since we likely don't have servers running
|
||||
let endpoints = await LocalProviderDiscovery.discoverAll(timeout: 0.5)
|
||||
|
||||
// Should return exactly 2 endpoints (one per well-known provider)
|
||||
#expect(endpoints.count == 2)
|
||||
|
||||
let types = Set(endpoints.map(\.providerType))
|
||||
#expect(types.contains(.ollama))
|
||||
#expect(types.contains(.llamaCpp))
|
||||
}
|
||||
|
||||
// MARK: - Default URLs
|
||||
|
||||
@Test("Default Ollama URL is correct")
|
||||
func defaultOllamaURL() {
|
||||
#expect(LocalProviderDiscovery.defaultOllamaURL.absoluteString == "http://localhost:11434")
|
||||
}
|
||||
|
||||
@Test("Default llama.cpp URL is correct")
|
||||
func defaultLlamaCppURL() {
|
||||
#expect(LocalProviderDiscovery.defaultLlamaCppURL.absoluteString == "http://localhost:8080")
|
||||
}
|
||||
}
|
||||
363
Tests/SwiftDBAITests/MultiTurnContextTests.swift
Normal file
363
Tests/SwiftDBAITests/MultiTurnContextTests.swift
Normal file
@@ -0,0 +1,363 @@
|
||||
// MultiTurnContextTests.swift
|
||||
// SwiftDBAI Tests
|
||||
//
|
||||
// Tests verifying multi-turn conversation context — follow-up queries
|
||||
// correctly reference the prior query's table, columns, and results.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
import GRDB
|
||||
import Testing
|
||||
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("Multi-Turn Context Tests")
|
||||
struct MultiTurnContextTests {
|
||||
|
||||
// MARK: - Test Database Setup
|
||||
|
||||
/// Creates an in-memory database with users (including age) and orders.
|
||||
private func makeTestDatabase() throws -> DatabaseQueue {
|
||||
let db = try DatabaseQueue(path: ":memory:")
|
||||
try db.write { db in
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
age INTEGER NOT NULL,
|
||||
email TEXT NOT NULL,
|
||||
city TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
try db.execute(sql: """
|
||||
INSERT INTO users (name, age, email, city) VALUES
|
||||
('Alice', 25, 'alice@example.com', 'New York'),
|
||||
('Bob', 35, 'bob@example.com', 'San Francisco'),
|
||||
('Charlie', 42, 'charlie@example.com', 'New York'),
|
||||
('Diana', 28, 'diana@example.com', 'Chicago'),
|
||||
('Eve', 55, 'eve@example.com', 'San Francisco')
|
||||
""")
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE orders (
|
||||
id INTEGER PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
amount REAL NOT NULL,
|
||||
status TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
)
|
||||
""")
|
||||
try db.execute(sql: """
|
||||
INSERT INTO orders (user_id, amount, status, created_at) VALUES
|
||||
(1, 99.99, 'completed', '2024-01-15'),
|
||||
(1, 49.50, 'pending', '2024-02-20'),
|
||||
(2, 150.00, 'completed', '2024-01-10'),
|
||||
(3, 200.00, 'completed', '2024-03-01'),
|
||||
(4, 75.00, 'cancelled', '2024-02-05')
|
||||
""")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// MARK: - Multi-Turn Context Tests
|
||||
|
||||
@Test("Follow-up 'filter those by age > 30' references prior 'show all users' context")
|
||||
func followUpFilterReferencesUsersTable() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// Turn 1: "show all users" → SELECT * FROM users (returns 5 rows, LLM summary needed)
|
||||
// Turn 2: "filter those by age > 30" → should reference users table from context
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT * FROM users",
|
||||
"Here are all 5 users in the database.",
|
||||
"SELECT * FROM users WHERE age > 30",
|
||||
"Found 3 users over 30: Bob (35), Charlie (42), and Eve (55)."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// First turn: show all users
|
||||
let response1 = try await engine.send("show all users")
|
||||
#expect(response1.sql == "SELECT * FROM users")
|
||||
#expect(response1.queryResult?.rowCount == 5)
|
||||
|
||||
// Second turn: follow-up with implicit reference
|
||||
let response2 = try await engine.send("filter those by age > 30")
|
||||
#expect(response2.sql == "SELECT * FROM users WHERE age > 30")
|
||||
#expect(response2.queryResult?.rowCount == 3)
|
||||
|
||||
// Verify the follow-up prompt includes conversation history
|
||||
let prompts = mock.capturedPrompts
|
||||
// Find the prompt for the second SQL generation (skip summary prompts)
|
||||
let followUpSQLPrompt = prompts.first { prompt in
|
||||
prompt.contains("filter those by age > 30") && prompt.contains("CONVERSATION HISTORY")
|
||||
}
|
||||
#expect(followUpSQLPrompt != nil, "Follow-up prompt should contain CONVERSATION HISTORY")
|
||||
|
||||
// The conversation history should include the prior query and its SQL
|
||||
if let prompt = followUpSQLPrompt {
|
||||
#expect(prompt.contains("show all users"), "History should contain prior user message")
|
||||
#expect(prompt.contains("SELECT * FROM users"), "History should contain prior SQL")
|
||||
#expect(prompt.contains("filter those by age > 30"), "Prompt should contain current question")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Follow-up correctly inherits table context across multiple turns")
|
||||
func multipleFollowUpsInheritContext() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// 3-turn conversation narrowing down results
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT * FROM users",
|
||||
"Here are all 5 users.",
|
||||
"SELECT * FROM users WHERE city = 'New York'",
|
||||
"Found 2 users in New York: Alice and Charlie.",
|
||||
"SELECT * FROM users WHERE city = 'New York' AND age > 30",
|
||||
"Charlie (42) is the only New York user over 30."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// Turn 1
|
||||
_ = try await engine.send("show all users")
|
||||
|
||||
// Turn 2 — narrows by city
|
||||
let response2 = try await engine.send("only those in New York")
|
||||
#expect(response2.sql == "SELECT * FROM users WHERE city = 'New York'")
|
||||
#expect(response2.queryResult?.rowCount == 2)
|
||||
|
||||
// Turn 3 — further narrows by age
|
||||
let response3 = try await engine.send("now filter by age over 30")
|
||||
#expect(response3.sql == "SELECT * FROM users WHERE city = 'New York' AND age > 30")
|
||||
#expect(response3.queryResult?.rowCount == 1)
|
||||
|
||||
// Verify third turn's prompt includes the full conversation history
|
||||
let prompts = mock.capturedPrompts
|
||||
let thirdTurnPrompt = prompts.last { prompt in
|
||||
prompt.contains("now filter by age over 30") && prompt.contains("CONVERSATION HISTORY")
|
||||
}
|
||||
#expect(thirdTurnPrompt != nil)
|
||||
|
||||
if let prompt = thirdTurnPrompt {
|
||||
// Should include both prior user messages
|
||||
#expect(prompt.contains("show all users"))
|
||||
#expect(prompt.contains("only those in New York"))
|
||||
// Should include prior SQL
|
||||
#expect(prompt.contains("SELECT * FROM users"))
|
||||
#expect(prompt.contains("SELECT * FROM users WHERE city = 'New York'"))
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Follow-up switching tables preserves cross-table context")
|
||||
func followUpSwitchesTableWithContext() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// Turn 1: query users, Turn 2: ask about their orders
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT name, age FROM users WHERE age > 30",
|
||||
"Found 3 users over 30.",
|
||||
"SELECT o.id, u.name, o.amount, o.status FROM orders o JOIN users u ON o.user_id = u.id WHERE u.age > 30",
|
||||
"Bob has a $150 completed order, Charlie has a $200 completed order."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// Turn 1: users over 30
|
||||
let response1 = try await engine.send("show users over 30")
|
||||
#expect(response1.queryResult?.rowCount == 3)
|
||||
|
||||
// Turn 2: their orders — references the previous result context
|
||||
let response2 = try await engine.send("show their orders")
|
||||
#expect(response2.sql?.contains("JOIN") == true)
|
||||
|
||||
// Verify the follow-up prompt contains the users context
|
||||
let prompts = mock.capturedPrompts
|
||||
let orderPrompt = prompts.first { prompt in
|
||||
prompt.contains("show their orders") && prompt.contains("CONVERSATION HISTORY")
|
||||
}
|
||||
#expect(orderPrompt != nil)
|
||||
|
||||
if let prompt = orderPrompt {
|
||||
#expect(prompt.contains("show users over 30"), "Should contain prior user message")
|
||||
#expect(prompt.contains("age > 30"), "Should contain prior SQL context for table reference")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Conversation history includes SQL from prior turns for context")
|
||||
func historyIncludesSQLFromPriorTurns() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// Both queries are aggregates → no LLM summarization needed
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT COUNT(*) FROM users",
|
||||
"SELECT COUNT(*) FROM users WHERE age > 30",
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// Turn 1
|
||||
let r1 = try await engine.send("how many users are there?")
|
||||
#expect(r1.sql == "SELECT COUNT(*) FROM users")
|
||||
|
||||
// Turn 2 — references "those" implicitly
|
||||
let r2 = try await engine.send("how many of those are over 30?")
|
||||
#expect(r2.sql == "SELECT COUNT(*) FROM users WHERE age > 30")
|
||||
|
||||
// Verify engine history has all 4 messages (2 user + 2 assistant)
|
||||
let messages = engine.messages
|
||||
#expect(messages.count == 4)
|
||||
#expect(messages[0].role == .user)
|
||||
#expect(messages[0].content == "how many users are there?")
|
||||
#expect(messages[1].role == .assistant)
|
||||
#expect(messages[1].sql == "SELECT COUNT(*) FROM users")
|
||||
#expect(messages[2].role == .user)
|
||||
#expect(messages[2].content == "how many of those are over 30?")
|
||||
#expect(messages[3].role == .assistant)
|
||||
#expect(messages[3].sql == "SELECT COUNT(*) FROM users WHERE age > 30")
|
||||
|
||||
// The second prompt should reference the first query SQL
|
||||
let prompts = mock.capturedPrompts
|
||||
#expect(prompts.count >= 2)
|
||||
let secondPrompt = prompts[1]
|
||||
#expect(secondPrompt.contains("CONVERSATION HISTORY"))
|
||||
#expect(secondPrompt.contains("SELECT COUNT(*) FROM users"))
|
||||
#expect(secondPrompt.contains("how many users are there?"))
|
||||
}
|
||||
|
||||
@Test("Follow-up after aggregate uses prior table context")
|
||||
func followUpAfterAggregateUsesTableContext() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// Turn 1: aggregate (no LLM summary needed)
|
||||
// Turn 2: follow-up referencing "those"
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT AVG(age) FROM users",
|
||||
"SELECT name, age FROM users WHERE age > 35",
|
||||
"Charlie (42) and Eve (55) are older than average."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// Turn 1: average age → aggregate, template summary
|
||||
let r1 = try await engine.send("what is the average age of users?")
|
||||
#expect(r1.sql == "SELECT AVG(age) FROM users")
|
||||
|
||||
// Turn 2: "who is above that?" — needs the avg context
|
||||
let r2 = try await engine.send("who is above average?")
|
||||
#expect(r2.queryResult?.rowCount == 2)
|
||||
|
||||
// Verify context passed
|
||||
let prompts = mock.capturedPrompts
|
||||
let followUp = prompts.first { prompt in
|
||||
prompt.contains("who is above average?") && prompt.contains("CONVERSATION HISTORY")
|
||||
}
|
||||
#expect(followUp != nil)
|
||||
if let prompt = followUp {
|
||||
#expect(prompt.contains("AVG(age)"), "Should include prior aggregate SQL for context")
|
||||
#expect(prompt.contains("users"), "Should include table reference from prior turn")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Context window limits how much history is visible in follow-ups")
|
||||
func contextWindowLimitsHistoryInFollowUps() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// 3 turns, but context window of 2 messages
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT COUNT(*) FROM users",
|
||||
"SELECT COUNT(*) FROM orders",
|
||||
"SELECT COUNT(*) FROM users WHERE age > 30",
|
||||
])
|
||||
|
||||
let config = ChatEngineConfiguration(
|
||||
queryTimeout: nil,
|
||||
contextWindowSize: 2
|
||||
)
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: mock,
|
||||
configuration: config
|
||||
)
|
||||
|
||||
_ = try await engine.send("how many users?")
|
||||
_ = try await engine.send("how many orders?")
|
||||
_ = try await engine.send("how many users over 30?")
|
||||
|
||||
// The third prompt should only have the last 2 messages from turn 2
|
||||
let prompts = mock.capturedPrompts
|
||||
#expect(prompts.count >= 3)
|
||||
|
||||
let thirdPrompt = prompts[2]
|
||||
#expect(thirdPrompt.contains("CONVERSATION HISTORY"))
|
||||
// Turn 2 context should be present
|
||||
#expect(thirdPrompt.contains("how many orders?"))
|
||||
#expect(thirdPrompt.contains("SELECT COUNT(*) FROM orders"))
|
||||
// Turn 1 context should be trimmed (window=2 means last 2 messages)
|
||||
#expect(!thirdPrompt.contains("how many users?\n"), "First turn should be trimmed from context window")
|
||||
}
|
||||
|
||||
@Test("clearHistory resets context so follow-ups have no prior history")
|
||||
func clearHistoryResetsFollowUpContext() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT * FROM users",
|
||||
"Here are the 5 users.",
|
||||
"SELECT COUNT(*) FROM users",
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
// Turn 1
|
||||
_ = try await engine.send("show all users")
|
||||
#expect(engine.messages.count == 2)
|
||||
|
||||
// Clear history
|
||||
engine.clearHistory()
|
||||
#expect(engine.messages.isEmpty)
|
||||
|
||||
// Turn 2 after clear — should NOT have conversation history
|
||||
_ = try await engine.send("count all users")
|
||||
|
||||
let prompts = mock.capturedPrompts
|
||||
let lastPrompt = prompts.last!
|
||||
// After clearing, the prompt should NOT contain conversation history
|
||||
#expect(!lastPrompt.contains("CONVERSATION HISTORY"),
|
||||
"After clearHistory(), follow-up should not have prior context")
|
||||
#expect(!lastPrompt.contains("show all users"),
|
||||
"After clearHistory(), prior messages should be gone")
|
||||
}
|
||||
|
||||
@Test("Multi-turn with result data in context enables informed follow-ups")
|
||||
func resultDataInContextEnablesInformedFollowUps() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
|
||||
// Turn 1: list users → multi-row result, LLM summarizes
|
||||
// Turn 2: "sort those by age" → references same table
|
||||
let mock = PromptCapturingMockModel(responses: [
|
||||
"SELECT name, age, city FROM users",
|
||||
"Found 5 users: Alice (25, NY), Bob (35, SF), Charlie (42, NY), Diana (28, Chicago), Eve (55, SF).",
|
||||
"SELECT name, age, city FROM users ORDER BY age DESC",
|
||||
"Users sorted by age: Eve (55), Charlie (42), Bob (35), Diana (28), Alice (25)."
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
let r1 = try await engine.send("list all users with their age and city")
|
||||
#expect(r1.queryResult?.rowCount == 5)
|
||||
#expect(r1.queryResult?.columns.contains("age") == true)
|
||||
#expect(r1.queryResult?.columns.contains("city") == true)
|
||||
|
||||
let r2 = try await engine.send("sort those by age descending")
|
||||
#expect(r2.sql == "SELECT name, age, city FROM users ORDER BY age DESC")
|
||||
|
||||
// Verify the assistant message in history includes the SQL
|
||||
let messages = engine.messages
|
||||
#expect(messages.count == 4)
|
||||
// First assistant message should have the SQL recorded
|
||||
#expect(messages[1].sql == "SELECT name, age, city FROM users")
|
||||
// Second assistant should have the sorted SQL
|
||||
#expect(messages[3].sql == "SELECT name, age, city FROM users ORDER BY age DESC")
|
||||
}
|
||||
}
|
||||
508
Tests/SwiftDBAITests/OnDeviceProviderConfigurationTests.swift
Normal file
508
Tests/SwiftDBAITests/OnDeviceProviderConfigurationTests.swift
Normal file
@@ -0,0 +1,508 @@
|
||||
// OnDeviceProviderConfigurationTests.swift
|
||||
// SwiftDBAI Tests
|
||||
//
|
||||
// Tests for on-device provider configurations (CoreML, MLX) including
|
||||
// configuration validation, inference pipeline setup, and system readiness.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
@testable import SwiftDBAI
|
||||
import Testing
|
||||
|
||||
@Suite("OnDeviceProviderConfiguration")
|
||||
struct OnDeviceProviderConfigurationTests {
|
||||
|
||||
// MARK: - OnDeviceProviderType
|
||||
|
||||
@Test("OnDeviceProviderType has CoreML and MLX cases")
|
||||
func providerTypeCases() {
|
||||
let cases = OnDeviceProviderType.allCases
|
||||
#expect(cases.count == 2)
|
||||
#expect(cases.contains(.coreML))
|
||||
#expect(cases.contains(.mlx))
|
||||
}
|
||||
|
||||
@Test("OnDeviceProviderType raw values are descriptive")
|
||||
func providerTypeRawValues() {
|
||||
#expect(OnDeviceProviderType.coreML.rawValue == "coreML")
|
||||
#expect(OnDeviceProviderType.mlx.rawValue == "mlx")
|
||||
}
|
||||
|
||||
// MARK: - CoreML Configuration
|
||||
|
||||
@Test("CoreML configuration stores all properties")
|
||||
func coreMLBasicConfiguration() {
|
||||
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
|
||||
let config = CoreMLProviderConfiguration(
|
||||
modelURL: url,
|
||||
computeUnits: .cpuAndGPU,
|
||||
maxResponseTokens: 1024,
|
||||
useSampling: true,
|
||||
temperature: 0.3
|
||||
)
|
||||
|
||||
#expect(config.modelURL == url)
|
||||
#expect(config.computeUnits == .cpuAndGPU)
|
||||
#expect(config.maxResponseTokens == 1024)
|
||||
#expect(config.useSampling == true)
|
||||
#expect(config.temperature == 0.3)
|
||||
}
|
||||
|
||||
@Test("CoreML configuration uses sensible defaults")
|
||||
func coreMLDefaultConfiguration() {
|
||||
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
|
||||
let config = CoreMLProviderConfiguration(modelURL: url)
|
||||
|
||||
#expect(config.computeUnits == .all)
|
||||
#expect(config.maxResponseTokens == 2048)
|
||||
#expect(config.useSampling == false)
|
||||
#expect(config.temperature == 0.1)
|
||||
}
|
||||
|
||||
@Test("CoreML validation fails for non-mlmodelc extension")
|
||||
func coreMLValidateWrongExtension() {
|
||||
let url = URL(fileURLWithPath: "/tmp/TestModel.onnx")
|
||||
let config = CoreMLProviderConfiguration(modelURL: url)
|
||||
|
||||
#expect(throws: OnDeviceProviderError.self) {
|
||||
try config.validate()
|
||||
}
|
||||
}
|
||||
|
||||
@Test("CoreML validation fails for missing model file")
|
||||
func coreMLValidateMissingFile() {
|
||||
let url = URL(fileURLWithPath: "/nonexistent/path/Model.mlmodelc")
|
||||
let config = CoreMLProviderConfiguration(modelURL: url)
|
||||
|
||||
#expect(throws: OnDeviceProviderError.self) {
|
||||
try config.validate()
|
||||
}
|
||||
}
|
||||
|
||||
@Test("CoreML configuration is Equatable")
|
||||
func coreMLEquatable() {
|
||||
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
|
||||
let a = CoreMLProviderConfiguration(modelURL: url, computeUnits: .all)
|
||||
let b = CoreMLProviderConfiguration(modelURL: url, computeUnits: .all)
|
||||
let c = CoreMLProviderConfiguration(modelURL: url, computeUnits: .cpuOnly)
|
||||
|
||||
#expect(a == b)
|
||||
#expect(a != c)
|
||||
}
|
||||
|
||||
// MARK: - ComputeUnitPreference
|
||||
|
||||
@Test("ComputeUnitPreference has all expected cases")
|
||||
func computeUnitCases() {
|
||||
let cases = ComputeUnitPreference.allCases
|
||||
#expect(cases.count == 4)
|
||||
#expect(cases.contains(.all))
|
||||
#expect(cases.contains(.cpuOnly))
|
||||
#expect(cases.contains(.cpuAndGPU))
|
||||
#expect(cases.contains(.cpuAndNeuralEngine))
|
||||
}
|
||||
|
||||
// MARK: - MLX Configuration
|
||||
|
||||
@Test("MLX configuration stores all properties")
|
||||
func mlxBasicConfiguration() {
|
||||
let dir = URL(fileURLWithPath: "/tmp/models/my-model")
|
||||
let config = MLXProviderConfiguration(
|
||||
modelId: "mlx-community/Test-Model-4bit",
|
||||
localDirectory: dir,
|
||||
gpuMemory: .minimal,
|
||||
maxResponseTokens: 512,
|
||||
temperature: 0.2,
|
||||
topP: 0.9,
|
||||
repetitionPenalty: 1.2
|
||||
)
|
||||
|
||||
#expect(config.modelId == "mlx-community/Test-Model-4bit")
|
||||
#expect(config.localDirectory == dir)
|
||||
#expect(config.gpuMemory == .minimal)
|
||||
#expect(config.maxResponseTokens == 512)
|
||||
#expect(config.temperature == 0.2)
|
||||
#expect(config.topP == 0.9)
|
||||
#expect(config.repetitionPenalty == 1.2)
|
||||
}
|
||||
|
||||
@Test("MLX configuration uses sensible defaults")
|
||||
func mlxDefaultConfiguration() {
|
||||
let config = MLXProviderConfiguration(modelId: "test-model")
|
||||
|
||||
#expect(config.localDirectory == nil)
|
||||
#expect(config.gpuMemory == .automatic)
|
||||
#expect(config.maxResponseTokens == 2048)
|
||||
#expect(config.temperature == 0.1)
|
||||
#expect(config.topP == 0.95)
|
||||
#expect(config.repetitionPenalty == 1.1)
|
||||
}
|
||||
|
||||
@Test("MLX validation fails for empty model ID")
|
||||
func mlxValidateEmptyModelId() {
|
||||
let config = MLXProviderConfiguration(modelId: "")
|
||||
|
||||
#expect(throws: OnDeviceProviderError.self) {
|
||||
try config.validate()
|
||||
}
|
||||
}
|
||||
|
||||
@Test("MLX validation fails for nonexistent local directory")
|
||||
func mlxValidateMissingDirectory() {
|
||||
let config = MLXProviderConfiguration(
|
||||
modelId: "test-model",
|
||||
localDirectory: URL(fileURLWithPath: "/nonexistent/directory")
|
||||
)
|
||||
|
||||
#expect(throws: OnDeviceProviderError.self) {
|
||||
try config.validate()
|
||||
}
|
||||
}
|
||||
|
||||
@Test("MLX validation fails for negative temperature")
|
||||
func mlxValidateNegativeTemperature() {
|
||||
let config = MLXProviderConfiguration(
|
||||
modelId: "test-model",
|
||||
temperature: -0.5
|
||||
)
|
||||
|
||||
#expect(throws: OnDeviceProviderError.self) {
|
||||
try config.validate()
|
||||
}
|
||||
}
|
||||
|
||||
@Test("MLX validation fails for topP out of range")
|
||||
func mlxValidateInvalidTopP() {
|
||||
let configZero = MLXProviderConfiguration(
|
||||
modelId: "test-model",
|
||||
topP: 0.0
|
||||
)
|
||||
|
||||
#expect(throws: OnDeviceProviderError.self) {
|
||||
try configZero.validate()
|
||||
}
|
||||
|
||||
let configOver = MLXProviderConfiguration(
|
||||
modelId: "test-model",
|
||||
topP: 1.5
|
||||
)
|
||||
|
||||
#expect(throws: OnDeviceProviderError.self) {
|
||||
try configOver.validate()
|
||||
}
|
||||
}
|
||||
|
||||
@Test("MLX validation fails for zero repetition penalty")
|
||||
func mlxValidateInvalidRepetitionPenalty() {
|
||||
let config = MLXProviderConfiguration(
|
||||
modelId: "test-model",
|
||||
repetitionPenalty: 0.0
|
||||
)
|
||||
|
||||
#expect(throws: OnDeviceProviderError.self) {
|
||||
try config.validate()
|
||||
}
|
||||
}
|
||||
|
||||
@Test("MLX validation succeeds for valid configuration")
|
||||
func mlxValidateSuccess() throws {
|
||||
let config = MLXProviderConfiguration(modelId: "test-model")
|
||||
// Should not throw (no local directory set, model ID is non-empty)
|
||||
try config.validate()
|
||||
}
|
||||
|
||||
@Test("MLX configuration is Equatable")
|
||||
func mlxEquatable() {
|
||||
let a = MLXProviderConfiguration(modelId: "model-a")
|
||||
let b = MLXProviderConfiguration(modelId: "model-a")
|
||||
let c = MLXProviderConfiguration(modelId: "model-b")
|
||||
|
||||
#expect(a == b)
|
||||
#expect(a != c)
|
||||
}
|
||||
|
||||
// MARK: - Well-Known MLX Models
|
||||
|
||||
@Test("Llama 3.2 3B preset has correct model ID")
|
||||
func llama3_2_3BPreset() {
|
||||
let config = MLXProviderConfiguration.llama3_2_3B()
|
||||
#expect(config.modelId == "mlx-community/Llama-3.2-3B-Instruct-4bit")
|
||||
#expect(config.temperature == 0.1)
|
||||
#expect(config.maxResponseTokens == 2048)
|
||||
}
|
||||
|
||||
@Test("Qwen 2.5 Coder 3B preset has correct model ID")
|
||||
func qwen2_5_coder3BPreset() {
|
||||
let config = MLXProviderConfiguration.qwen2_5_coder_3B()
|
||||
#expect(config.modelId == "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit")
|
||||
#expect(config.temperature == 0.05)
|
||||
}
|
||||
|
||||
@Test("Phi 3.5 Mini preset has correct model ID")
|
||||
func phi3_5_miniPreset() {
|
||||
let config = MLXProviderConfiguration.phi3_5_mini()
|
||||
#expect(config.modelId == "mlx-community/Phi-3.5-mini-instruct-4bit")
|
||||
#expect(config.temperature == 0.1)
|
||||
}
|
||||
|
||||
@Test("Well-known models accept custom GPU memory config")
|
||||
func wellKnownModelsCustomGPU() {
|
||||
let config = MLXProviderConfiguration.llama3_2_3B(
|
||||
gpuMemory: .minimal
|
||||
)
|
||||
#expect(config.gpuMemory == .minimal)
|
||||
}
|
||||
|
||||
// MARK: - GPU Memory Configuration
|
||||
|
||||
@Test("Automatic GPU memory config scales with RAM")
|
||||
func automaticGPUMemory() {
|
||||
let config = MLXGPUMemoryConfig.automatic
|
||||
#expect(config.activeCacheLimit > 0)
|
||||
#expect(config.idleCacheLimit == 50_000_000)
|
||||
#expect(config.clearCacheOnEviction == true)
|
||||
}
|
||||
|
||||
@Test("Minimal GPU memory config is conservative")
|
||||
func minimalGPUMemory() {
|
||||
let config = MLXGPUMemoryConfig.minimal
|
||||
#expect(config.activeCacheLimit == 64_000_000)
|
||||
#expect(config.idleCacheLimit == 16_000_000)
|
||||
#expect(config.clearCacheOnEviction == true)
|
||||
}
|
||||
|
||||
@Test("Unconstrained GPU memory config uses max values")
|
||||
func unconstrainedGPUMemory() {
|
||||
let config = MLXGPUMemoryConfig.unconstrained
|
||||
#expect(config.activeCacheLimit == Int.max)
|
||||
#expect(config.idleCacheLimit == Int.max)
|
||||
#expect(config.clearCacheOnEviction == false)
|
||||
}
|
||||
|
||||
@Test("GPU memory config is Equatable")
|
||||
func gpuMemoryEquatable() {
|
||||
#expect(MLXGPUMemoryConfig.minimal == MLXGPUMemoryConfig.minimal)
|
||||
#expect(MLXGPUMemoryConfig.minimal != MLXGPUMemoryConfig.unconstrained)
|
||||
}
|
||||
|
||||
// MARK: - On-Device Provider Errors
|
||||
|
||||
@Test("OnDeviceProviderError has descriptive messages")
|
||||
func errorDescriptions() {
|
||||
let errors: [OnDeviceProviderError] = [
|
||||
.modelNotFound(URL(fileURLWithPath: "/tmp/model")),
|
||||
.invalidModelFormat(expected: ".mlmodelc", actual: ".onnx"),
|
||||
.emptyModelId,
|
||||
.invalidParameter(name: "temperature", value: "-1", reason: "Must be non-negative"),
|
||||
.providerUnavailable(.mlx, reason: "MLX build flag not enabled"),
|
||||
.modelLoadFailed(reason: "Out of memory"),
|
||||
.inferenceFailed(reason: "Token limit exceeded"),
|
||||
]
|
||||
|
||||
for error in errors {
|
||||
#expect(error.errorDescription != nil)
|
||||
#expect(!error.errorDescription!.isEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("OnDeviceProviderError is Equatable")
|
||||
func errorEquatable() {
|
||||
let a = OnDeviceProviderError.emptyModelId
|
||||
let b = OnDeviceProviderError.emptyModelId
|
||||
let c = OnDeviceProviderError.modelLoadFailed(reason: "test")
|
||||
|
||||
#expect(a == b)
|
||||
#expect(a != c)
|
||||
}
|
||||
|
||||
// MARK: - Inference Pipeline
|
||||
|
||||
@Test("MLX inference pipeline initializes with correct type")
|
||||
func mlxPipelineInit() {
|
||||
let config = MLXProviderConfiguration.llama3_2_3B()
|
||||
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
|
||||
|
||||
#expect(pipeline.providerType == .mlx)
|
||||
#expect(pipeline.mlxConfiguration != nil)
|
||||
#expect(pipeline.coreMLConfiguration == nil)
|
||||
#expect(pipeline.status == .notLoaded)
|
||||
}
|
||||
|
||||
@Test("CoreML inference pipeline initializes with correct type")
|
||||
func coreMLPipelineInit() {
|
||||
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
|
||||
let config = CoreMLProviderConfiguration(modelURL: url)
|
||||
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
|
||||
|
||||
#expect(pipeline.providerType == .coreML)
|
||||
#expect(pipeline.coreMLConfiguration != nil)
|
||||
#expect(pipeline.mlxConfiguration == nil)
|
||||
#expect(pipeline.status == .notLoaded)
|
||||
}
|
||||
|
||||
@Test("Pipeline validates MLX configuration")
|
||||
func pipelineValidatesMLX() throws {
|
||||
let validConfig = MLXProviderConfiguration(modelId: "test-model")
|
||||
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: validConfig)
|
||||
try pipeline.validateConfiguration()
|
||||
|
||||
let invalidConfig = MLXProviderConfiguration(modelId: "")
|
||||
let invalidPipeline = OnDeviceInferencePipeline(mlxConfiguration: invalidConfig)
|
||||
#expect(throws: OnDeviceProviderError.self) {
|
||||
try invalidPipeline.validateConfiguration()
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Pipeline validates CoreML configuration")
|
||||
func pipelineValidatesCoreML() {
|
||||
let url = URL(fileURLWithPath: "/tmp/TestModel.onnx")
|
||||
let config = CoreMLProviderConfiguration(modelURL: url)
|
||||
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
|
||||
|
||||
#expect(throws: OnDeviceProviderError.self) {
|
||||
try pipeline.validateConfiguration()
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Pipeline provides SQL generation hints for MLX")
|
||||
func mlxSQLHints() {
|
||||
let config = MLXProviderConfiguration(
|
||||
modelId: "test-model",
|
||||
maxResponseTokens: 512,
|
||||
temperature: 0.2
|
||||
)
|
||||
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
|
||||
let hints = pipeline.recommendedSQLGenerationHints
|
||||
|
||||
#expect(hints.maxTokens == 512)
|
||||
#expect(hints.temperature == 0.2)
|
||||
#expect(hints.useSampling == true)
|
||||
#expect(hints.systemPromptSuffix.contains("MLX"))
|
||||
}
|
||||
|
||||
@Test("Pipeline provides SQL generation hints for CoreML")
|
||||
func coreMLSQLHints() {
|
||||
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
|
||||
let config = CoreMLProviderConfiguration(
|
||||
modelURL: url,
|
||||
maxResponseTokens: 1024,
|
||||
useSampling: false,
|
||||
temperature: 0.05
|
||||
)
|
||||
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
|
||||
let hints = pipeline.recommendedSQLGenerationHints
|
||||
|
||||
#expect(hints.maxTokens == 1024)
|
||||
#expect(hints.temperature == 0.05)
|
||||
#expect(hints.useSampling == false)
|
||||
#expect(hints.systemPromptSuffix.contains("SQL"))
|
||||
}
|
||||
|
||||
// MARK: - System Readiness
|
||||
|
||||
@Test("System capability check returns valid data")
|
||||
func systemCapability() {
|
||||
let capability = OnDeviceModelReadiness.checkSystemCapability()
|
||||
|
||||
#expect(capability.totalRAM > 0)
|
||||
// On any modern test machine, we should have at least some RAM
|
||||
#expect(capability.totalRAM > 1024 * 1024 * 1024) // > 1GB
|
||||
|
||||
// On Apple silicon Macs, this should be true
|
||||
#if arch(arm64)
|
||||
#expect(capability.hasNeuralEngine == true)
|
||||
#endif
|
||||
}
|
||||
|
||||
@Test("Suggested MLX model returns a valid configuration")
|
||||
func suggestedMLXModel() {
|
||||
let config = OnDeviceModelReadiness.suggestedMLXModel()
|
||||
#expect(!config.modelId.isEmpty)
|
||||
#expect(config.temperature >= 0)
|
||||
#expect(config.maxResponseTokens > 0)
|
||||
}
|
||||
|
||||
@Test("Recommended model size enum has correct raw values")
|
||||
func recommendedModelSizeRawValues() {
|
||||
#expect(OnDeviceModelReadiness.RecommendedModelSize.small.rawValue == "small")
|
||||
#expect(OnDeviceModelReadiness.RecommendedModelSize.medium.rawValue == "medium")
|
||||
#expect(OnDeviceModelReadiness.RecommendedModelSize.large.rawValue == "large")
|
||||
}
|
||||
|
||||
// MARK: - ProviderConfiguration Integration
|
||||
|
||||
@Test("onDeviceMLX creates a ProviderConfiguration")
|
||||
func onDeviceMLXProviderConfig() {
|
||||
let mlxConfig = MLXProviderConfiguration.llama3_2_3B()
|
||||
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
|
||||
|
||||
#expect(providerConfig.model == mlxConfig.modelId)
|
||||
#expect(!providerConfig.hasValidAPIKey) // No API key needed for on-device
|
||||
}
|
||||
|
||||
@Test("onDeviceCoreML creates a ProviderConfiguration")
|
||||
func onDeviceCoreMLProviderConfig() {
|
||||
let url = URL(fileURLWithPath: "/tmp/SQLModel.mlmodelc")
|
||||
let coreMLConfig = CoreMLProviderConfiguration(modelURL: url)
|
||||
let providerConfig = ProviderConfiguration.onDeviceCoreML(coreMLConfig)
|
||||
|
||||
#expect(providerConfig.model == "SQLModel.mlmodelc")
|
||||
#expect(!providerConfig.hasValidAPIKey)
|
||||
}
|
||||
|
||||
// MARK: - Pipeline Status
|
||||
|
||||
@Test("Pipeline status transitions")
|
||||
func pipelineStatusTransitions() {
|
||||
let config = MLXProviderConfiguration(modelId: "test-model")
|
||||
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
|
||||
|
||||
#expect(pipeline.status == .notLoaded)
|
||||
|
||||
pipeline.setStatus(.loading)
|
||||
#expect(pipeline.status == .loading)
|
||||
|
||||
pipeline.setStatus(.ready)
|
||||
#expect(pipeline.status == .ready)
|
||||
|
||||
pipeline.setStatus(.failed("Out of memory"))
|
||||
#expect(pipeline.status == .failed("Out of memory"))
|
||||
}
|
||||
|
||||
@Test("Pipeline Status is Equatable")
|
||||
func pipelineStatusEquatable() {
|
||||
#expect(OnDeviceInferencePipeline.Status.notLoaded == .notLoaded)
|
||||
#expect(OnDeviceInferencePipeline.Status.loading == .loading)
|
||||
#expect(OnDeviceInferencePipeline.Status.ready == .ready)
|
||||
#expect(OnDeviceInferencePipeline.Status.failed("a") == .failed("a"))
|
||||
#expect(OnDeviceInferencePipeline.Status.failed("a") != .failed("b"))
|
||||
#expect(OnDeviceInferencePipeline.Status.notLoaded != .ready)
|
||||
}
|
||||
|
||||
// MARK: - SQL Generation Hints
|
||||
|
||||
@Test("SQL generation hints are Equatable")
|
||||
func sqlHintsEquatable() {
|
||||
let a = OnDeviceSQLGenerationHints(
|
||||
maxTokens: 512,
|
||||
temperature: 0.1,
|
||||
systemPromptSuffix: "test",
|
||||
useSampling: true
|
||||
)
|
||||
let b = OnDeviceSQLGenerationHints(
|
||||
maxTokens: 512,
|
||||
temperature: 0.1,
|
||||
systemPromptSuffix: "test",
|
||||
useSampling: true
|
||||
)
|
||||
let c = OnDeviceSQLGenerationHints(
|
||||
maxTokens: 1024,
|
||||
temperature: 0.1,
|
||||
systemPromptSuffix: "test",
|
||||
useSampling: true
|
||||
)
|
||||
|
||||
#expect(a == b)
|
||||
#expect(a != c)
|
||||
}
|
||||
}
|
||||
247
Tests/SwiftDBAITests/PresentationTests.swift
Normal file
247
Tests/SwiftDBAITests/PresentationTests.swift
Normal file
@@ -0,0 +1,247 @@
|
||||
// PresentationTests.swift
|
||||
// SwiftDBAITests
|
||||
//
|
||||
// Tests for presentation modalities: DataChatSheet, DataChatViewController,
|
||||
// and view modifier helpers.
|
||||
|
||||
import SwiftUI
|
||||
import Testing
|
||||
import ViewInspector
|
||||
import GRDB
|
||||
@testable import SwiftDBAI
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
private func makeSampleDatabase() throws -> DatabaseQueue {
|
||||
let db = try DatabaseQueue()
|
||||
try db.write { db in
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE items (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO items (name) VALUES ('Alpha');
|
||||
""")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// MARK: - DataChatSheet Tests
|
||||
|
||||
@Suite("DataChatSheet Tests")
|
||||
struct DataChatSheetTests {
|
||||
|
||||
@Test("DataChatSheet renders NavigationStack with title")
|
||||
@MainActor
|
||||
func sheetRendersNavigationStackWithTitle() throws {
|
||||
let db = try makeSampleDatabase()
|
||||
let sheet = DataChatSheet(
|
||||
database: db,
|
||||
model: MockLanguageModel(),
|
||||
title: "Test Chat"
|
||||
)
|
||||
|
||||
let view = try sheet.inspect()
|
||||
// NavigationStack should be the root
|
||||
let navStack = try view.navigationStack()
|
||||
#expect(navStack != nil)
|
||||
}
|
||||
|
||||
@Test("DataChatSheet has Done button")
|
||||
@MainActor
|
||||
func sheetHasDoneButton() throws {
|
||||
let db = try makeSampleDatabase()
|
||||
let sheet = DataChatSheet(
|
||||
database: db,
|
||||
model: MockLanguageModel()
|
||||
)
|
||||
|
||||
let view = try sheet.inspect()
|
||||
// Find the Done button in the toolbar
|
||||
let button = try view.find(button: "Done")
|
||||
#expect(button != nil)
|
||||
}
|
||||
|
||||
@Test("DataChatSheet renders DataChatView inside")
|
||||
@MainActor
|
||||
func sheetContainsDataChatView() throws {
|
||||
let db = try makeSampleDatabase()
|
||||
let sheet = DataChatSheet(
|
||||
database: db,
|
||||
model: MockLanguageModel()
|
||||
)
|
||||
|
||||
let view = try sheet.inspect()
|
||||
// DataChatView should be present within the NavigationStack
|
||||
let dataChatView = try view.find(DataChatView.self)
|
||||
#expect(dataChatView != nil)
|
||||
}
|
||||
|
||||
@Test("DataChatSheet path-based init works")
|
||||
@MainActor
|
||||
func sheetPathInit() throws {
|
||||
let tempDir = FileManager.default.temporaryDirectory
|
||||
let dbPath = tempDir.appendingPathComponent("sheet_test_\(UUID().uuidString).sqlite").path
|
||||
let db = try DatabaseQueue(path: dbPath)
|
||||
try db.write { db in
|
||||
try db.execute(sql: "CREATE TABLE t (id INTEGER PRIMARY KEY)")
|
||||
}
|
||||
|
||||
let sheet = DataChatSheet(
|
||||
databasePath: dbPath,
|
||||
model: MockLanguageModel(),
|
||||
title: "Path Chat"
|
||||
)
|
||||
|
||||
let view = try sheet.inspect()
|
||||
let navStack = try view.navigationStack()
|
||||
#expect(navStack != nil)
|
||||
|
||||
try? FileManager.default.removeItem(atPath: dbPath)
|
||||
}
|
||||
|
||||
@Test("DataChatSheet uses custom title")
|
||||
@MainActor
|
||||
func sheetCustomTitle() throws {
|
||||
let db = try makeSampleDatabase()
|
||||
let sheet = DataChatSheet(
|
||||
database: db,
|
||||
model: MockLanguageModel(),
|
||||
title: "My Custom Title"
|
||||
)
|
||||
|
||||
// Verify the title property is set correctly
|
||||
#expect(sheet.title == "My Custom Title")
|
||||
}
|
||||
|
||||
@Test("DataChatSheet defaults to AI Chat title")
|
||||
@MainActor
|
||||
func sheetDefaultTitle() throws {
|
||||
let db = try makeSampleDatabase()
|
||||
let sheet = DataChatSheet(
|
||||
database: db,
|
||||
model: MockLanguageModel()
|
||||
)
|
||||
|
||||
#expect(sheet.title == "AI Chat")
|
||||
}
|
||||
|
||||
@Test("DataChatSheet defaults to read-only allowlist")
|
||||
@MainActor
|
||||
func sheetDefaultAllowlist() throws {
|
||||
let db = try makeSampleDatabase()
|
||||
let sheet = DataChatSheet(
|
||||
database: db,
|
||||
model: MockLanguageModel()
|
||||
)
|
||||
|
||||
#expect(sheet.allowlist == .readOnly)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - DataChatViewController Tests
|
||||
|
||||
#if canImport(UIKit) && !os(watchOS)
|
||||
@Suite("DataChatViewController Tests")
|
||||
struct DataChatViewControllerTests {
|
||||
|
||||
@Test("DataChatViewController can be instantiated with database path")
|
||||
@MainActor
|
||||
func viewControllerPathInit() throws {
|
||||
let tempDir = FileManager.default.temporaryDirectory
|
||||
let dbPath = tempDir.appendingPathComponent("vc_test_\(UUID().uuidString).sqlite").path
|
||||
let db = try DatabaseQueue(path: dbPath)
|
||||
try db.write { db in
|
||||
try db.execute(sql: "CREATE TABLE t (id INTEGER PRIMARY KEY)")
|
||||
}
|
||||
|
||||
let vc = DataChatViewController(
|
||||
databasePath: dbPath,
|
||||
model: MockLanguageModel()
|
||||
)
|
||||
|
||||
#expect(vc.modalPresentationStyle == .formSheet)
|
||||
|
||||
try? FileManager.default.removeItem(atPath: dbPath)
|
||||
}
|
||||
|
||||
@Test("DataChatViewController can be instantiated with database connection")
|
||||
@MainActor
|
||||
func viewControllerDatabaseInit() throws {
|
||||
let db = try makeSampleDatabase()
|
||||
|
||||
let vc = DataChatViewController(
|
||||
database: db,
|
||||
model: MockLanguageModel(),
|
||||
title: "VC Chat"
|
||||
)
|
||||
|
||||
#expect(vc.modalPresentationStyle == .formSheet)
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// MARK: - View Modifier Tests
|
||||
|
||||
@Suite("DataChatSheet Modifier Tests")
|
||||
struct DataChatSheetModifierTests {
|
||||
|
||||
@Test("dataChatSheet modifier creates sheet correctly")
|
||||
@MainActor
|
||||
func sheetModifierCreatesSheet() throws {
|
||||
let db = try makeSampleDatabase()
|
||||
|
||||
struct TestHost: View {
|
||||
@State var showChat = false
|
||||
let db: DatabaseQueue
|
||||
|
||||
var body: some View {
|
||||
Text("Hello")
|
||||
.dataChatSheet(
|
||||
isPresented: $showChat,
|
||||
database: db,
|
||||
model: MockLanguageModel(),
|
||||
title: "Modifier Chat"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
let host = TestHost(db: db)
|
||||
// Verify it compiles and can be inspected
|
||||
let view = try host.inspect()
|
||||
let text = try view.find(text: "Hello")
|
||||
#expect(text != nil)
|
||||
}
|
||||
|
||||
@Test("dataChatSheet path modifier creates sheet correctly")
|
||||
@MainActor
|
||||
func sheetPathModifierCreatesSheet() throws {
|
||||
let tempDir = FileManager.default.temporaryDirectory
|
||||
let dbPath = tempDir.appendingPathComponent("mod_test_\(UUID().uuidString).sqlite").path
|
||||
let db = try DatabaseQueue(path: dbPath)
|
||||
try db.write { db in
|
||||
try db.execute(sql: "CREATE TABLE t (id INTEGER PRIMARY KEY)")
|
||||
}
|
||||
|
||||
struct TestHost: View {
|
||||
@State var showChat = false
|
||||
let dbPath: String
|
||||
|
||||
var body: some View {
|
||||
Text("World")
|
||||
.dataChatSheet(
|
||||
isPresented: $showChat,
|
||||
databasePath: dbPath,
|
||||
model: MockLanguageModel()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
let host = TestHost(dbPath: dbPath)
|
||||
let view = try host.inspect()
|
||||
let text = try view.find(text: "World")
|
||||
#expect(text != nil)
|
||||
|
||||
try? FileManager.default.removeItem(atPath: dbPath)
|
||||
}
|
||||
}
|
||||
254
Tests/SwiftDBAITests/PromptBuilderTests.swift
Normal file
254
Tests/SwiftDBAITests/PromptBuilderTests.swift
Normal file
@@ -0,0 +1,254 @@
|
||||
// PromptBuilderTests.swift
|
||||
// SwiftDBAI
|
||||
|
||||
import Testing
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("PromptBuilder")
|
||||
struct PromptBuilderTests {
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
/// Creates a sample schema for testing.
|
||||
private func makeSampleSchema() -> DatabaseSchema {
|
||||
let usersTable = TableSchema(
|
||||
name: "users",
|
||||
columns: [
|
||||
ColumnSchema(cid: 0, name: "id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: true),
|
||||
ColumnSchema(cid: 1, name: "name", type: "TEXT", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
|
||||
ColumnSchema(cid: 2, name: "email", type: "TEXT", isNotNull: false, defaultValue: nil, isPrimaryKey: false),
|
||||
ColumnSchema(cid: 3, name: "created_at", type: "TEXT", isNotNull: false, defaultValue: "CURRENT_TIMESTAMP", isPrimaryKey: false),
|
||||
],
|
||||
primaryKey: ["id"],
|
||||
foreignKeys: [],
|
||||
indexes: [
|
||||
IndexSchema(name: "idx_users_email", isUnique: true, columns: ["email"])
|
||||
]
|
||||
)
|
||||
|
||||
let ordersTable = TableSchema(
|
||||
name: "orders",
|
||||
columns: [
|
||||
ColumnSchema(cid: 0, name: "id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: true),
|
||||
ColumnSchema(cid: 1, name: "user_id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
|
||||
ColumnSchema(cid: 2, name: "total", type: "REAL", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
|
||||
ColumnSchema(cid: 3, name: "status", type: "TEXT", isNotNull: true, defaultValue: "'pending'", isPrimaryKey: false),
|
||||
],
|
||||
primaryKey: ["id"],
|
||||
foreignKeys: [
|
||||
ForeignKeySchema(fromColumn: "user_id", toTable: "users", toColumn: "id", onUpdate: "NO ACTION", onDelete: "CASCADE")
|
||||
],
|
||||
indexes: []
|
||||
)
|
||||
|
||||
return DatabaseSchema(
|
||||
tables: ["users": usersTable, "orders": ordersTable],
|
||||
tableNames: ["users", "orders"]
|
||||
)
|
||||
}
|
||||
|
||||
private func makeEmptySchema() -> DatabaseSchema {
|
||||
DatabaseSchema(tables: [:], tableNames: [])
|
||||
}
|
||||
|
||||
// MARK: - System Instructions Tests
|
||||
|
||||
@Test("System instructions contain role section")
|
||||
func systemInstructionsContainRole() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("ROLE"))
|
||||
#expect(instructions.contains("SQL assistant"))
|
||||
#expect(instructions.contains("SQLite database"))
|
||||
}
|
||||
|
||||
@Test("System instructions contain schema")
|
||||
func systemInstructionsContainSchema() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("DATABASE SCHEMA"))
|
||||
#expect(instructions.contains("TABLE users"))
|
||||
#expect(instructions.contains("TABLE orders"))
|
||||
#expect(instructions.contains("name TEXT"))
|
||||
#expect(instructions.contains("email TEXT"))
|
||||
}
|
||||
|
||||
@Test("System instructions contain foreign keys from schema")
|
||||
func systemInstructionsContainForeignKeys() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("FOREIGN KEY"))
|
||||
#expect(instructions.contains("REFERENCES users(id)"))
|
||||
}
|
||||
|
||||
@Test("System instructions contain SQL generation rules")
|
||||
func systemInstructionsContainRules() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("SQL GENERATION RULES"))
|
||||
#expect(instructions.contains("Use ONLY the tables and columns"))
|
||||
#expect(instructions.contains("Never generate DDL"))
|
||||
}
|
||||
|
||||
@Test("System instructions contain output format section")
|
||||
func systemInstructionsContainOutputFormat() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("OUTPUT FORMAT"))
|
||||
}
|
||||
|
||||
@Test("Default allowlist is read-only")
|
||||
func defaultAllowlistIsReadOnly() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("ONLY generate SELECT queries"))
|
||||
#expect(instructions.contains("No data modifications"))
|
||||
}
|
||||
|
||||
@Test("Standard allowlist shows correct operations")
|
||||
func standardAllowlistInstructions() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: .standard)
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("INSERT"))
|
||||
#expect(instructions.contains("SELECT"))
|
||||
#expect(instructions.contains("UPDATE"))
|
||||
}
|
||||
|
||||
@Test("Unrestricted allowlist warns about DELETE")
|
||||
func unrestrictedAllowlistWarnsAboutDelete() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: .unrestricted)
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("DELETE"))
|
||||
#expect(instructions.contains("destructive"))
|
||||
#expect(instructions.contains("confirmation"))
|
||||
}
|
||||
|
||||
@Test("Additional context is appended")
|
||||
func additionalContextAppended() {
|
||||
let builder = PromptBuilder(
|
||||
schema: makeSampleSchema(),
|
||||
additionalContext: "All dates are stored in ISO 8601 format."
|
||||
)
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("ADDITIONAL CONTEXT"))
|
||||
#expect(instructions.contains("ISO 8601"))
|
||||
}
|
||||
|
||||
@Test("No additional context section when nil")
|
||||
func noAdditionalContextWhenNil() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(!instructions.contains("ADDITIONAL CONTEXT"))
|
||||
}
|
||||
|
||||
@Test("No additional context section when empty string")
|
||||
func noAdditionalContextWhenEmpty() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema(), additionalContext: "")
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(!instructions.contains("ADDITIONAL CONTEXT"))
|
||||
}
|
||||
|
||||
@Test("Empty schema produces valid instructions")
|
||||
func emptySchemaProducesValidInstructions() {
|
||||
let builder = PromptBuilder(schema: makeEmptySchema())
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("ROLE"))
|
||||
#expect(instructions.contains("SQL GENERATION RULES"))
|
||||
// Schema section should still be present, just empty
|
||||
#expect(instructions.contains("DATABASE SCHEMA"))
|
||||
}
|
||||
|
||||
// MARK: - User Prompt Tests
|
||||
|
||||
@Test("User prompt passes through question directly")
|
||||
func userPromptPassesThrough() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let prompt = builder.buildUserPrompt("How many users signed up this week?")
|
||||
|
||||
#expect(prompt == "How many users signed up this week?")
|
||||
}
|
||||
|
||||
// MARK: - Follow-up Prompt Tests
|
||||
|
||||
@Test("Follow-up prompt includes previous context")
|
||||
func followUpPromptIncludesPreviousContext() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let prompt = builder.buildFollowUpPrompt(
|
||||
"Now sort them by name",
|
||||
previousSQL: "SELECT * FROM users WHERE created_at > date('now', '-7 days')",
|
||||
previousResultSummary: "Found 42 users who signed up this week"
|
||||
)
|
||||
|
||||
#expect(prompt.contains("Previous query:"))
|
||||
#expect(prompt.contains("SELECT * FROM users"))
|
||||
#expect(prompt.contains("Previous result:"))
|
||||
#expect(prompt.contains("42 users"))
|
||||
#expect(prompt.contains("Follow-up question:"))
|
||||
#expect(prompt.contains("sort them by name"))
|
||||
}
|
||||
|
||||
// MARK: - Schema Description Quality
|
||||
|
||||
@Test("Schema includes column types and constraints")
|
||||
func schemaIncludesColumnDetails() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
// Should include type info
|
||||
#expect(instructions.contains("INTEGER"))
|
||||
#expect(instructions.contains("TEXT"))
|
||||
#expect(instructions.contains("REAL"))
|
||||
|
||||
// Should include constraints
|
||||
#expect(instructions.contains("NOT NULL"))
|
||||
#expect(instructions.contains("PRIMARY KEY"))
|
||||
}
|
||||
|
||||
@Test("Schema includes index information")
|
||||
func schemaIncludesIndexes() {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("INDEX"))
|
||||
#expect(instructions.contains("idx_users_email"))
|
||||
}
|
||||
|
||||
// MARK: - Sendable Conformance
|
||||
|
||||
@Test("PromptBuilder is Sendable")
|
||||
func promptBuilderIsSendable() async {
|
||||
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||
|
||||
// Verify it can be sent across concurrency boundaries
|
||||
let instructions = await Task.detached {
|
||||
builder.buildSystemInstructions()
|
||||
}.value
|
||||
|
||||
#expect(instructions.contains("ROLE"))
|
||||
}
|
||||
|
||||
// MARK: - Custom Allowlist
|
||||
|
||||
@Test("Custom allowlist with select and delete only")
|
||||
func customAllowlist() {
|
||||
let allowlist = OperationAllowlist([.select, .delete])
|
||||
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: allowlist)
|
||||
let instructions = builder.buildSystemInstructions()
|
||||
|
||||
#expect(instructions.contains("DELETE"))
|
||||
#expect(instructions.contains("SELECT"))
|
||||
#expect(instructions.contains("destructive"))
|
||||
}
|
||||
}
|
||||
325
Tests/SwiftDBAITests/ProviderConfigurationTests.swift
Normal file
325
Tests/SwiftDBAITests/ProviderConfigurationTests.swift
Normal file
@@ -0,0 +1,325 @@
|
||||
// ProviderConfigurationTests.swift
|
||||
// SwiftDBAI Tests
|
||||
//
|
||||
// Tests for ProviderConfiguration — verifying all cloud provider configurations
|
||||
// produce valid LanguageModel instances with correct settings.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
@testable import SwiftDBAI
|
||||
import Testing
|
||||
|
||||
@Suite("ProviderConfiguration")
|
||||
struct ProviderConfigurationTests {
|
||||
|
||||
// MARK: - OpenAI Configuration
|
||||
|
||||
@Test("OpenAI configuration stores provider and model")
|
||||
func openAIBasicConfiguration() {
|
||||
let config = ProviderConfiguration.openAI(
|
||||
apiKey: "sk-test-key-123",
|
||||
model: "gpt-4o"
|
||||
)
|
||||
|
||||
#expect(config.provider == .openAI)
|
||||
#expect(config.model == "gpt-4o")
|
||||
#expect(config.apiKey == "sk-test-key-123")
|
||||
#expect(config.hasValidAPIKey)
|
||||
}
|
||||
|
||||
@Test("OpenAI configuration produces a valid LanguageModel")
|
||||
func openAIMakeModel() {
|
||||
let config = ProviderConfiguration.openAI(
|
||||
apiKey: "sk-test-key",
|
||||
model: "gpt-4o-mini"
|
||||
)
|
||||
|
||||
let model = config.makeModel()
|
||||
#expect(model is OpenAILanguageModel)
|
||||
}
|
||||
|
||||
@Test("OpenAI with custom base URL for compatible services")
|
||||
func openAICustomBaseURL() {
|
||||
let customURL = URL(string: "https://my-proxy.example.com/v1/")!
|
||||
let config = ProviderConfiguration.openAI(
|
||||
apiKey: "sk-proxy-key",
|
||||
model: "gpt-4o",
|
||||
baseURL: customURL
|
||||
)
|
||||
|
||||
#expect(config.baseURL == customURL)
|
||||
let model = config.makeModel()
|
||||
#expect(model is OpenAILanguageModel)
|
||||
}
|
||||
|
||||
@Test("OpenAI with Responses API variant")
|
||||
func openAIResponsesVariant() {
|
||||
let config = ProviderConfiguration.openAI(
|
||||
apiKey: "sk-test",
|
||||
model: "gpt-4o",
|
||||
variant: .responses
|
||||
)
|
||||
|
||||
#expect(config.openAIVariant == .responses)
|
||||
let model = config.makeModel()
|
||||
#expect(model is OpenAILanguageModel)
|
||||
}
|
||||
|
||||
@Test("OpenAI with dynamic key provider captures key by reference")
|
||||
func openAIDynamicKeyProvider() {
|
||||
nonisolated(unsafe) var currentKey = "sk-initial"
|
||||
let config = ProviderConfiguration.openAI(
|
||||
apiKeyProvider: { currentKey },
|
||||
model: "gpt-4o"
|
||||
)
|
||||
|
||||
#expect(config.apiKey == "sk-initial")
|
||||
currentKey = "sk-rotated"
|
||||
#expect(config.apiKey == "sk-rotated")
|
||||
}
|
||||
|
||||
// MARK: - Anthropic Configuration
|
||||
|
||||
@Test("Anthropic configuration stores provider and model")
|
||||
func anthropicBasicConfiguration() {
|
||||
let config = ProviderConfiguration.anthropic(
|
||||
apiKey: "sk-ant-test-key",
|
||||
model: "claude-sonnet-4-20250514"
|
||||
)
|
||||
|
||||
#expect(config.provider == .anthropic)
|
||||
#expect(config.model == "claude-sonnet-4-20250514")
|
||||
#expect(config.apiKey == "sk-ant-test-key")
|
||||
#expect(config.hasValidAPIKey)
|
||||
}
|
||||
|
||||
@Test("Anthropic configuration produces a valid LanguageModel")
|
||||
func anthropicMakeModel() {
|
||||
let config = ProviderConfiguration.anthropic(
|
||||
apiKey: "sk-ant-test",
|
||||
model: "claude-sonnet-4-20250514"
|
||||
)
|
||||
|
||||
let model = config.makeModel()
|
||||
#expect(model is AnthropicLanguageModel)
|
||||
}
|
||||
|
||||
@Test("Anthropic with API version and betas")
|
||||
func anthropicWithVersionAndBetas() {
|
||||
let config = ProviderConfiguration.anthropic(
|
||||
apiKey: "sk-ant-test",
|
||||
model: "claude-sonnet-4-20250514",
|
||||
apiVersion: "2024-01-01",
|
||||
betas: ["computer-use"]
|
||||
)
|
||||
|
||||
#expect(config.apiVersion == "2024-01-01")
|
||||
#expect(config.betas == ["computer-use"])
|
||||
let model = config.makeModel()
|
||||
#expect(model is AnthropicLanguageModel)
|
||||
}
|
||||
|
||||
@Test("Anthropic with dynamic key provider captures key by reference")
|
||||
func anthropicDynamicKeyProvider() {
|
||||
nonisolated(unsafe) var currentKey = "sk-ant-initial"
|
||||
let config = ProviderConfiguration.anthropic(
|
||||
apiKeyProvider: { currentKey },
|
||||
model: "claude-sonnet-4-20250514"
|
||||
)
|
||||
|
||||
#expect(config.apiKey == "sk-ant-initial")
|
||||
currentKey = "sk-ant-rotated"
|
||||
#expect(config.apiKey == "sk-ant-rotated")
|
||||
}
|
||||
|
||||
// MARK: - Gemini Configuration
|
||||
|
||||
@Test("Gemini configuration stores provider and model")
|
||||
func geminiBasicConfiguration() {
|
||||
let config = ProviderConfiguration.gemini(
|
||||
apiKey: "AIzaSyTest123",
|
||||
model: "gemini-2.0-flash"
|
||||
)
|
||||
|
||||
#expect(config.provider == .gemini)
|
||||
#expect(config.model == "gemini-2.0-flash")
|
||||
#expect(config.apiKey == "AIzaSyTest123")
|
||||
#expect(config.hasValidAPIKey)
|
||||
}
|
||||
|
||||
@Test("Gemini configuration produces a valid LanguageModel")
|
||||
func geminiMakeModel() {
|
||||
let config = ProviderConfiguration.gemini(
|
||||
apiKey: "AIzaSyTest",
|
||||
model: "gemini-2.0-flash"
|
||||
)
|
||||
|
||||
let model = config.makeModel()
|
||||
#expect(model is GeminiLanguageModel)
|
||||
}
|
||||
|
||||
@Test("Gemini with custom API version")
|
||||
func geminiCustomVersion() {
|
||||
let config = ProviderConfiguration.gemini(
|
||||
apiKey: "AIzaSyTest",
|
||||
model: "gemini-2.0-flash",
|
||||
apiVersion: "v1"
|
||||
)
|
||||
|
||||
#expect(config.apiVersion == "v1")
|
||||
let model = config.makeModel()
|
||||
#expect(model is GeminiLanguageModel)
|
||||
}
|
||||
|
||||
@Test("Gemini with dynamic key provider captures key by reference")
|
||||
func geminiDynamicKeyProvider() {
|
||||
nonisolated(unsafe) var currentKey = "AIza-initial"
|
||||
let config = ProviderConfiguration.gemini(
|
||||
apiKeyProvider: { currentKey },
|
||||
model: "gemini-2.0-flash"
|
||||
)
|
||||
|
||||
#expect(config.apiKey == "AIza-initial")
|
||||
currentKey = "AIza-rotated"
|
||||
#expect(config.apiKey == "AIza-rotated")
|
||||
}
|
||||
|
||||
// MARK: - OpenAI-Compatible Configuration
|
||||
|
||||
@Test("OpenAI-compatible configuration with custom base URL")
|
||||
func openAICompatibleConfiguration() {
|
||||
let baseURL = URL(string: "https://api.together.xyz/v1/")!
|
||||
let config = ProviderConfiguration.openAICompatible(
|
||||
apiKey: "together-key",
|
||||
model: "meta-llama/Llama-3.1-70B",
|
||||
baseURL: baseURL
|
||||
)
|
||||
|
||||
#expect(config.provider == .openAICompatible)
|
||||
#expect(config.model == "meta-llama/Llama-3.1-70B")
|
||||
#expect(config.baseURL == baseURL)
|
||||
let model = config.makeModel()
|
||||
#expect(model is OpenAILanguageModel)
|
||||
}
|
||||
|
||||
@Test("OpenAI-compatible with dynamic key provider")
|
||||
func openAICompatibleDynamicKey() {
|
||||
let baseURL = URL(string: "http://localhost:1234/v1/")!
|
||||
nonisolated(unsafe) var currentKey = "local-key"
|
||||
let config = ProviderConfiguration.openAICompatible(
|
||||
apiKeyProvider: { currentKey },
|
||||
model: "local-model",
|
||||
baseURL: baseURL
|
||||
)
|
||||
|
||||
#expect(config.apiKey == "local-key")
|
||||
currentKey = "new-local-key"
|
||||
#expect(config.apiKey == "new-local-key")
|
||||
}
|
||||
|
||||
// MARK: - API Key Validation
|
||||
|
||||
@Test("Empty API key reports invalid")
|
||||
func emptyAPIKeyInvalid() {
|
||||
let config = ProviderConfiguration.openAI(
|
||||
apiKey: "",
|
||||
model: "gpt-4o"
|
||||
)
|
||||
|
||||
#expect(!config.hasValidAPIKey)
|
||||
}
|
||||
|
||||
@Test("Whitespace-only API key reports invalid")
|
||||
func whitespaceAPIKeyInvalid() {
|
||||
let config = ProviderConfiguration.openAI(
|
||||
apiKey: " \n\t ",
|
||||
model: "gpt-4o"
|
||||
)
|
||||
|
||||
#expect(!config.hasValidAPIKey)
|
||||
}
|
||||
|
||||
@Test("Non-empty API key reports valid")
|
||||
func nonEmptyAPIKeyValid() {
|
||||
let config = ProviderConfiguration.openAI(
|
||||
apiKey: "x",
|
||||
model: "gpt-4o"
|
||||
)
|
||||
|
||||
#expect(config.hasValidAPIKey)
|
||||
}
|
||||
|
||||
// MARK: - Environment Variable Configuration
|
||||
|
||||
@Test("fromEnvironment creates configuration for each provider")
|
||||
func fromEnvironmentCreatesConfig() {
|
||||
let openAI = ProviderConfiguration.fromEnvironment(
|
||||
provider: .openAI,
|
||||
environmentVariable: "SWIFTDAI_TEST_OPENAI_KEY",
|
||||
model: "gpt-4o"
|
||||
)
|
||||
#expect(openAI.provider == .openAI)
|
||||
#expect(openAI.model == "gpt-4o")
|
||||
|
||||
let anthropic = ProviderConfiguration.fromEnvironment(
|
||||
provider: .anthropic,
|
||||
environmentVariable: "SWIFTDAI_TEST_ANTHROPIC_KEY",
|
||||
model: "claude-sonnet-4-20250514"
|
||||
)
|
||||
#expect(anthropic.provider == .anthropic)
|
||||
|
||||
let gemini = ProviderConfiguration.fromEnvironment(
|
||||
provider: .gemini,
|
||||
environmentVariable: "SWIFTDAI_TEST_GEMINI_KEY",
|
||||
model: "gemini-2.0-flash"
|
||||
)
|
||||
#expect(gemini.provider == .gemini)
|
||||
}
|
||||
|
||||
@Test("fromEnvironment returns empty key when variable not set")
|
||||
func fromEnvironmentMissingVariable() {
|
||||
let config = ProviderConfiguration.fromEnvironment(
|
||||
provider: .openAI,
|
||||
environmentVariable: "NONEXISTENT_KEY_VAR_SWIFTDBAI_TEST",
|
||||
model: "gpt-4o"
|
||||
)
|
||||
|
||||
#expect(!config.hasValidAPIKey)
|
||||
#expect(config.apiKey == "")
|
||||
}
|
||||
|
||||
// MARK: - Provider Enum
|
||||
|
||||
@Test("Provider enum has all expected cases")
|
||||
func providerCases() {
|
||||
let cases = ProviderConfiguration.Provider.allCases
|
||||
#expect(cases.count == 6)
|
||||
#expect(cases.contains(.openAI))
|
||||
#expect(cases.contains(.anthropic))
|
||||
#expect(cases.contains(.gemini))
|
||||
#expect(cases.contains(.openAICompatible))
|
||||
#expect(cases.contains(.ollama))
|
||||
#expect(cases.contains(.llamaCpp))
|
||||
}
|
||||
|
||||
// MARK: - Cross-Provider Model Creation
|
||||
|
||||
@Test("All providers produce available models")
|
||||
func allProvidersCreateAvailableModels() {
|
||||
let configs: [ProviderConfiguration] = [
|
||||
.openAI(apiKey: "test", model: "gpt-4o"),
|
||||
.anthropic(apiKey: "test", model: "claude-sonnet-4-20250514"),
|
||||
.gemini(apiKey: "test", model: "gemini-2.0-flash"),
|
||||
.openAICompatible(
|
||||
apiKey: "test",
|
||||
model: "local",
|
||||
baseURL: URL(string: "http://localhost:8080/v1/")!
|
||||
),
|
||||
]
|
||||
|
||||
for config in configs {
|
||||
let model = config.makeModel()
|
||||
#expect(model.isAvailable, "Model for \(config.provider) should be available")
|
||||
}
|
||||
}
|
||||
}
|
||||
629
Tests/SwiftDBAITests/SQLQueryParserTests.swift
Normal file
629
Tests/SwiftDBAITests/SQLQueryParserTests.swift
Normal file
@@ -0,0 +1,629 @@
|
||||
// SQLQueryParserTests.swift
|
||||
// SwiftDBAITests
|
||||
|
||||
import Testing
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("SQLQueryParser")
|
||||
struct SQLQueryParserTests {
|
||||
|
||||
let readOnlyParser = SQLQueryParser(allowlist: .readOnly)
|
||||
let standardParser = SQLQueryParser(allowlist: .standard)
|
||||
let unrestrictedParser = SQLQueryParser(allowlist: .unrestricted)
|
||||
|
||||
// MARK: - Extraction from code blocks
|
||||
|
||||
@Test("Extracts SQL from markdown sql code block")
|
||||
func extractFromSQLCodeBlock() throws {
|
||||
let text = """
|
||||
Here's the query to find the top users:
|
||||
|
||||
```sql
|
||||
SELECT name, COUNT(*) as count FROM users GROUP BY name ORDER BY count DESC
|
||||
```
|
||||
|
||||
This will give you the results.
|
||||
"""
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT name, COUNT(*) as count FROM users GROUP BY name ORDER BY count DESC")
|
||||
#expect(result.operation == .select)
|
||||
#expect(result.requiresConfirmation == false)
|
||||
}
|
||||
|
||||
@Test("Extracts SQL from generic code block")
|
||||
func extractFromGenericCodeBlock() throws {
|
||||
let text = """
|
||||
Here you go:
|
||||
|
||||
```
|
||||
SELECT * FROM products WHERE price > 100
|
||||
```
|
||||
"""
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM products WHERE price > 100")
|
||||
}
|
||||
|
||||
@Test("Extracts SQL from labeled text")
|
||||
func extractFromLabel() throws {
|
||||
let text = """
|
||||
I can help with that.
|
||||
SQL: SELECT id, name FROM categories WHERE active = 1
|
||||
That should work.
|
||||
"""
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT id, name FROM categories WHERE active = 1")
|
||||
}
|
||||
|
||||
@Test("Extracts direct SQL from plain text")
|
||||
func extractDirectSQL() throws {
|
||||
let text = "SELECT COUNT(*) FROM orders WHERE status = 'shipped'"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT COUNT(*) FROM orders WHERE status = 'shipped'")
|
||||
}
|
||||
|
||||
@Test("Handles SQL with trailing semicolons")
|
||||
func trailingSemicolon() throws {
|
||||
let text = "```sql\nSELECT * FROM users;\n```"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Handles multiline SQL in code block")
|
||||
func multilineSQL() throws {
|
||||
let text = """
|
||||
```sql
|
||||
SELECT u.name, COUNT(o.id) as order_count
|
||||
FROM users u
|
||||
JOIN orders o ON u.id = o.user_id
|
||||
GROUP BY u.name
|
||||
ORDER BY order_count DESC
|
||||
LIMIT 10
|
||||
```
|
||||
"""
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql.contains("SELECT u.name"))
|
||||
#expect(result.sql.contains("LIMIT 10"))
|
||||
}
|
||||
|
||||
@Test("Handles WITH (CTE) queries as SELECT")
|
||||
func cteQuery() throws {
|
||||
let text = """
|
||||
```sql
|
||||
WITH top_users AS (
|
||||
SELECT user_id, COUNT(*) as cnt FROM orders GROUP BY user_id
|
||||
)
|
||||
SELECT * FROM top_users WHERE cnt > 5
|
||||
```
|
||||
"""
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.operation == .select)
|
||||
}
|
||||
|
||||
// MARK: - No SQL found
|
||||
|
||||
@Test("Throws noSQLFound for text without SQL")
|
||||
func noSQLFound() throws {
|
||||
let text = "I'm sorry, I can't help with that request."
|
||||
#expect(throws: SQLParsingError.noSQLFound) {
|
||||
try readOnlyParser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Throws noSQLFound for empty input")
|
||||
func emptyInput() throws {
|
||||
#expect(throws: SQLParsingError.noSQLFound) {
|
||||
try readOnlyParser.parse("")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Operation detection
|
||||
|
||||
@Test("Detects INSERT operation")
|
||||
func detectInsert() throws {
|
||||
let text = "```sql\nINSERT INTO users (name) VALUES ('Alice')\n```"
|
||||
let result = try standardParser.parse(text)
|
||||
#expect(result.operation == .insert)
|
||||
}
|
||||
|
||||
@Test("Detects UPDATE operation")
|
||||
func detectUpdate() throws {
|
||||
let text = "```sql\nUPDATE users SET name = 'Bob' WHERE id = 1\n```"
|
||||
let result = try standardParser.parse(text)
|
||||
#expect(result.operation == .update)
|
||||
}
|
||||
|
||||
@Test("Detects DELETE operation and requires confirmation")
|
||||
func detectDeleteRequiresConfirmation() throws {
|
||||
let text = "```sql\nDELETE FROM users WHERE id = 99\n```"
|
||||
let result = try unrestrictedParser.parse(text)
|
||||
#expect(result.operation == .delete)
|
||||
#expect(result.requiresConfirmation == true)
|
||||
}
|
||||
|
||||
// MARK: - Allowlist enforcement
|
||||
|
||||
@Test("Rejects INSERT on read-only allowlist")
|
||||
func rejectInsertOnReadOnly() throws {
|
||||
let text = "```sql\nINSERT INTO users (name) VALUES ('Mallory')\n```"
|
||||
#expect(throws: SQLParsingError.operationNotAllowed(.insert)) {
|
||||
try readOnlyParser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Rejects UPDATE on read-only allowlist")
|
||||
func rejectUpdateOnReadOnly() {
|
||||
let text = "```sql\nUPDATE users SET name = 'Eve' WHERE id = 1\n```"
|
||||
#expect(throws: SQLParsingError.operationNotAllowed(.update)) {
|
||||
try readOnlyParser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Rejects DELETE on standard allowlist")
|
||||
func rejectDeleteOnStandard() {
|
||||
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
|
||||
#expect(throws: SQLParsingError.operationNotAllowed(.delete)) {
|
||||
try standardParser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Dangerous operations
|
||||
|
||||
@Test("Rejects DROP TABLE")
|
||||
func rejectDrop() {
|
||||
let text = "```sql\nDROP TABLE users\n```"
|
||||
#expect(throws: SQLParsingError.dangerousOperation("DROP")) {
|
||||
try unrestrictedParser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Rejects ALTER TABLE")
|
||||
func rejectAlter() {
|
||||
let text = "```sql\nALTER TABLE users ADD COLUMN age INTEGER\n```"
|
||||
#expect(throws: SQLParsingError.dangerousOperation("ALTER")) {
|
||||
try unrestrictedParser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Rejects PRAGMA")
|
||||
func rejectPragma() {
|
||||
let text = "```sql\nPRAGMA table_info(users)\n```"
|
||||
#expect(throws: SQLParsingError.dangerousOperation("PRAGMA")) {
|
||||
try unrestrictedParser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Does not match dangerous keywords inside identifiers")
|
||||
func noFalsePositiveOnSubstring() throws {
|
||||
// "DROPDOWN" contains "DROP" as substring but is not the keyword
|
||||
let text = "SELECT dropdown_value FROM settings"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql.contains("dropdown_value"))
|
||||
}
|
||||
|
||||
// MARK: - Multiple statements
|
||||
|
||||
@Test("Rejects multiple statements separated by semicolons")
|
||||
func rejectMultipleStatements() {
|
||||
let text = "```sql\nSELECT * FROM users; SELECT * FROM orders\n```"
|
||||
#expect(throws: SQLParsingError.multipleStatements) {
|
||||
try readOnlyParser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Allows semicolons inside string literals")
|
||||
func allowSemicolonInString() throws {
|
||||
let text = "SELECT * FROM users WHERE bio = 'hello; world'"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql.contains("hello; world"))
|
||||
}
|
||||
|
||||
// MARK: - ParsedSQL equality
|
||||
|
||||
@Test("ParsedSQL equality works")
|
||||
func parsedSQLEquality() {
|
||||
let a = ParsedSQL(sql: "SELECT 1", operation: .select)
|
||||
let b = ParsedSQL(sql: "SELECT 1", operation: .select)
|
||||
#expect(a == b)
|
||||
}
|
||||
|
||||
// MARK: - Error descriptions
|
||||
|
||||
@Test("Error descriptions are meaningful")
|
||||
func errorDescriptions() {
|
||||
#expect(SQLParsingError.noSQLFound.description.contains("No SQL"))
|
||||
#expect(SQLParsingError.operationNotAllowed(.insert).description.contains("INSERT"))
|
||||
#expect(SQLParsingError.dangerousOperation("DROP").description.contains("DROP"))
|
||||
#expect(SQLParsingError.multipleStatements.description.contains("single"))
|
||||
}
|
||||
|
||||
// MARK: - MutationPolicy integration
|
||||
|
||||
@Test("MutationPolicy allows INSERT on permitted table")
|
||||
func mutationPolicyAllowsInsertOnPermittedTable() throws {
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update],
|
||||
allowedTables: ["orders", "order_items"]
|
||||
)
|
||||
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||
let text = "```sql\nINSERT INTO orders (product, qty) VALUES ('Widget', 3)\n```"
|
||||
let result = try parser.parse(text)
|
||||
#expect(result.operation == .insert)
|
||||
#expect(result.requiresConfirmation == false)
|
||||
}
|
||||
|
||||
@Test("MutationPolicy rejects INSERT on non-permitted table")
|
||||
func mutationPolicyRejectsInsertOnForbiddenTable() {
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update],
|
||||
allowedTables: ["orders"]
|
||||
)
|
||||
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||
let text = "```sql\nINSERT INTO users (name) VALUES ('Alice')\n```"
|
||||
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .insert)) {
|
||||
try parser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("MutationPolicy rejects UPDATE on non-permitted table")
|
||||
func mutationPolicyRejectsUpdateOnForbiddenTable() {
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update],
|
||||
allowedTables: ["orders"]
|
||||
)
|
||||
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||
let text = "```sql\nUPDATE users SET name = 'Bob' WHERE id = 1\n```"
|
||||
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .update)) {
|
||||
try parser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("MutationPolicy rejects DELETE on non-permitted table")
|
||||
func mutationPolicyRejectsDeleteOnForbiddenTable() {
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update, .delete],
|
||||
allowedTables: ["temp_data"]
|
||||
)
|
||||
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||
let text = "```sql\nDELETE FROM users WHERE id = 99\n```"
|
||||
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .delete)) {
|
||||
try parser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("MutationPolicy allows mutation on any table when allowedTables is nil")
|
||||
func mutationPolicyAllowsAllTablesWhenNil() throws {
|
||||
let policy = MutationPolicy(allowedOperations: [.insert, .update])
|
||||
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||
let text = "```sql\nINSERT INTO any_table (col) VALUES ('val')\n```"
|
||||
let result = try parser.parse(text)
|
||||
#expect(result.operation == .insert)
|
||||
}
|
||||
|
||||
@Test("MutationPolicy SELECT is never restricted by table allowlist")
|
||||
func mutationPolicySelectIgnoresTableRestrictions() throws {
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert],
|
||||
allowedTables: ["orders"]
|
||||
)
|
||||
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||
// SELECT from a table NOT in allowedTables — should still work
|
||||
let text = "```sql\nSELECT * FROM users\n```"
|
||||
let result = try parser.parse(text)
|
||||
#expect(result.operation == .select)
|
||||
#expect(result.requiresConfirmation == false)
|
||||
}
|
||||
|
||||
@Test("MutationPolicy DELETE requires confirmation by default")
|
||||
func mutationPolicyDeleteRequiresConfirmation() throws {
|
||||
let policy = MutationPolicy(allowedOperations: [.delete])
|
||||
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
|
||||
let result = try parser.parse(text)
|
||||
#expect(result.operation == .delete)
|
||||
#expect(result.requiresConfirmation == true)
|
||||
}
|
||||
|
||||
@Test("MutationPolicy DELETE skips confirmation when configured")
|
||||
func mutationPolicyDeleteNoConfirmation() throws {
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.delete],
|
||||
requiresDestructiveConfirmation: false
|
||||
)
|
||||
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
|
||||
let result = try parser.parse(text)
|
||||
#expect(result.operation == .delete)
|
||||
#expect(result.requiresConfirmation == false)
|
||||
}
|
||||
|
||||
@Test("MutationPolicy readOnly preset rejects all mutations")
|
||||
func mutationPolicyReadOnlyRejectsAll() {
|
||||
let parser = SQLQueryParser(mutationPolicy: .readOnly)
|
||||
#expect(throws: SQLParsingError.operationNotAllowed(.insert)) {
|
||||
try parser.parse("INSERT INTO t (a) VALUES (1)")
|
||||
}
|
||||
#expect(throws: SQLParsingError.operationNotAllowed(.update)) {
|
||||
try parser.parse("UPDATE t SET a = 1")
|
||||
}
|
||||
#expect(throws: SQLParsingError.operationNotAllowed(.delete)) {
|
||||
try parser.parse("DELETE FROM t WHERE id = 1")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("MutationPolicy table matching is case-insensitive")
|
||||
func mutationPolicyTableCaseInsensitive() throws {
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert],
|
||||
allowedTables: ["Orders"]
|
||||
)
|
||||
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||
let text = "INSERT INTO orders (product) VALUES ('Widget')"
|
||||
let result = try parser.parse(text)
|
||||
#expect(result.operation == .insert)
|
||||
}
|
||||
|
||||
@Test("MutationPolicy handles quoted table names")
|
||||
func mutationPolicyQuotedTableNames() throws {
|
||||
let policy = MutationPolicy(
|
||||
allowedOperations: [.insert, .update],
|
||||
allowedTables: ["order_items"]
|
||||
)
|
||||
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||
|
||||
// Backtick-quoted
|
||||
let backtick = "INSERT INTO `order_items` (qty) VALUES (5)"
|
||||
let r1 = try parser.parse(backtick)
|
||||
#expect(r1.operation == .insert)
|
||||
|
||||
// Double-quote-quoted
|
||||
let doubleQuote = "UPDATE \"order_items\" SET qty = 10 WHERE id = 1"
|
||||
let r2 = try parser.parse(doubleQuote)
|
||||
#expect(r2.operation == .update)
|
||||
}
|
||||
|
||||
@Test("Error description for tableNotAllowed is meaningful")
|
||||
func tableNotAllowedDescription() {
|
||||
let error = SQLParsingError.tableNotAllowed(table: "secret", operation: .delete)
|
||||
#expect(error.description.contains("secret"))
|
||||
#expect(error.description.contains("DELETE"))
|
||||
}
|
||||
|
||||
@Test("Error description for confirmationRequired is meaningful")
|
||||
func confirmationRequiredDescription() {
|
||||
let error = SQLParsingError.confirmationRequired(sql: "DELETE FROM x", operation: .delete)
|
||||
#expect(error.description.contains("DELETE"))
|
||||
#expect(error.description.contains("confirmation"))
|
||||
}
|
||||
|
||||
// MARK: - Robust extraction edge cases
|
||||
|
||||
@Test("Extracts plain SQL without any wrapping")
|
||||
func plainSQL() throws {
|
||||
let text = "SELECT * FROM users"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Extracts SQL from markdown sql code block")
|
||||
func markdownSQLBlock() throws {
|
||||
let text = "```sql\nSELECT * FROM users\n```"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Extracts SQL from generic code block")
|
||||
func genericCodeBlock() throws {
|
||||
let text = "```\nSELECT * FROM users\n```"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Strips trailing semicolons")
|
||||
func trailingSemicolonEdge() throws {
|
||||
let text = "SELECT * FROM users;"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Extracts SQL with preamble text")
|
||||
func preambleText() throws {
|
||||
let text = "Here's the query:\nSELECT * FROM users"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Handles trailing backticks only (no opening fence)")
|
||||
func trailingBackticksOnly() throws {
|
||||
let text = "SELECT * FROM users\n```"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Extracts SQL from single-line code block")
|
||||
func singleLineCodeBlock() throws {
|
||||
let text = "```sql SELECT * FROM users ```"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Handles no newline before closing fence")
|
||||
func noNewlineBeforeClosingFence() throws {
|
||||
let text = "```sql\nSELECT * FROM users```"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Extracts SQL inline with text prefix")
|
||||
func inlineWithText() throws {
|
||||
let text = "The SQL query is: SELECT * FROM users"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Handles extra whitespace around SQL")
|
||||
func extraWhitespace() throws {
|
||||
let text = "\n\nSELECT * FROM users\n\n"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Extracts SQL from chatty LLM response with preamble and postamble")
|
||||
func chattyLLMResponse() throws {
|
||||
let text = "Sure! Here's the SQL:\n\n```sql\nSELECT * FROM users\n```\n\nThis will return all users."
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Preserves SQL comments")
|
||||
func sqlWithComments() throws {
|
||||
let text = "SELECT * FROM users -- get all users"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql.contains("-- get all users"))
|
||||
}
|
||||
|
||||
@Test("Preserves backtick-quoted identifiers in SQL")
|
||||
func backtickQuotedIdentifiers() throws {
|
||||
let text = "SELECT `column name` FROM users"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql.contains("`column name`"))
|
||||
}
|
||||
|
||||
@Test("Strips think tags from Qwen-style models")
|
||||
func thinkTags() throws {
|
||||
let text = "<think>I need to query the users table</think>\nSELECT * FROM users"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
#expect(!result.sql.contains("think"))
|
||||
}
|
||||
|
||||
@Test("Handles 4 or 5 backtick fences")
|
||||
func extraBacktickFences() throws {
|
||||
let text4 = "````sql\nSELECT * FROM users\n````"
|
||||
let result4 = try readOnlyParser.parse(text4)
|
||||
#expect(result4.sql == "SELECT * FROM users")
|
||||
|
||||
let text5 = "`````\nSELECT * FROM users\n`````"
|
||||
let result5 = try readOnlyParser.parse(text5)
|
||||
#expect(result5.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Handles mixed case SQL keywords")
|
||||
func mixedCaseSQL() throws {
|
||||
let text = "select * from USERS"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "select * from USERS")
|
||||
}
|
||||
|
||||
@Test("Handles WITH clause (CTE) queries")
|
||||
func withClause() throws {
|
||||
let text = "WITH cte AS (SELECT id FROM orders) SELECT * FROM cte"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql.hasPrefix("WITH"))
|
||||
#expect(result.operation == .select)
|
||||
}
|
||||
|
||||
@Test("Handles WITH clause in code block")
|
||||
func withClauseInCodeBlock() throws {
|
||||
let text = "```sql\nWITH top AS (\n SELECT user_id, COUNT(*) as cnt FROM orders GROUP BY user_id\n)\nSELECT * FROM top WHERE cnt > 5\n```"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql.hasPrefix("WITH"))
|
||||
#expect(result.operation == .select)
|
||||
}
|
||||
|
||||
@Test("Multi-line SQL with JOINs and subqueries in code block")
|
||||
func multiLineJoinsAndSubqueries() throws {
|
||||
let text = """
|
||||
```sql
|
||||
SELECT u.name, o.total
|
||||
FROM users u
|
||||
INNER JOIN orders o ON u.id = o.user_id
|
||||
WHERE o.total > (SELECT AVG(total) FROM orders)
|
||||
ORDER BY o.total DESC
|
||||
```
|
||||
"""
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql.contains("INNER JOIN"))
|
||||
#expect(result.sql.contains("SELECT AVG(total)"))
|
||||
#expect(result.sql.contains("ORDER BY"))
|
||||
}
|
||||
|
||||
@Test("Handles response with both explanation text and SQL")
|
||||
func explanationAndSQL() throws {
|
||||
let text = """
|
||||
To find all active users, we need to query the users table
|
||||
and filter by the active column. Here's the query:
|
||||
|
||||
SELECT * FROM users WHERE active = 1
|
||||
|
||||
This should give you the results you're looking for.
|
||||
"""
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users WHERE active = 1")
|
||||
}
|
||||
|
||||
@Test("Throws noSQLFound for empty response")
|
||||
func emptyResponse() throws {
|
||||
#expect(throws: SQLParsingError.noSQLFound) {
|
||||
try readOnlyParser.parse("")
|
||||
}
|
||||
#expect(throws: SQLParsingError.noSQLFound) {
|
||||
try readOnlyParser.parse(" \n\n ")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Throws noSQLFound for response with no SQL at all")
|
||||
func noSQLAtAll() throws {
|
||||
#expect(throws: SQLParsingError.noSQLFound) {
|
||||
try readOnlyParser.parse("I cannot help with that question. Please try asking about your data.")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Handles response with multiple SQL statements in code block (rejects them)")
|
||||
func multipleStatementsInCodeBlock() throws {
|
||||
// When multiple statements are in a code block, the parser sees both and rejects
|
||||
let text = "```sql\nSELECT * FROM users; SELECT * FROM orders\n```"
|
||||
#expect(throws: SQLParsingError.multipleStatements) {
|
||||
try readOnlyParser.parse(text)
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Extracts first SQL statement from plain text with multiple statements")
|
||||
func multipleStatementsPlainText() throws {
|
||||
// In plain text, the direct extraction stops at the semicolon and extracts the first statement
|
||||
let text = "SELECT * FROM users; SELECT * FROM orders"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Preserves backtick identifiers inside code blocks")
|
||||
func backtickIdentifiersInCodeBlock() throws {
|
||||
let text = "```sql\nSELECT `first name`, `last name` FROM `user data`\n```"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql.contains("`first name`"))
|
||||
#expect(result.sql.contains("`last name`"))
|
||||
#expect(result.sql.contains("`user data`"))
|
||||
}
|
||||
|
||||
@Test("Strips think tags with multiline reasoning content")
|
||||
func multilineThinkTags() throws {
|
||||
let text = """
|
||||
<think>
|
||||
The user wants to find all users.
|
||||
I should use SELECT * FROM users.
|
||||
Let me think about which columns to include...
|
||||
</think>
|
||||
SELECT * FROM users
|
||||
"""
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql == "SELECT * FROM users")
|
||||
}
|
||||
|
||||
@Test("Handles mixed backtick styles in response")
|
||||
func mixedBacktickStyles() throws {
|
||||
// Code fences + backtick-quoted identifiers inside
|
||||
let text = "```sql\nSELECT `user name` FROM users WHERE `is active` = 1\n```"
|
||||
let result = try readOnlyParser.parse(text)
|
||||
#expect(result.sql.contains("`user name`"))
|
||||
#expect(result.sql.contains("`is active`"))
|
||||
}
|
||||
}
|
||||
234
Tests/SwiftDBAITests/SchemaIntrospectorTests.swift
Normal file
234
Tests/SwiftDBAITests/SchemaIntrospectorTests.swift
Normal file
@@ -0,0 +1,234 @@
|
||||
// SchemaIntrospectorTests.swift
|
||||
// SwiftDBAI
|
||||
|
||||
import Testing
|
||||
import GRDB
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("SchemaIntrospector")
|
||||
struct SchemaIntrospectorTests {
|
||||
|
||||
// MARK: - Helper
|
||||
|
||||
/// Creates an in-memory database with a sample schema for testing.
|
||||
private func makeTestDatabase() throws -> DatabaseQueue {
|
||||
let db = try DatabaseQueue(configuration: {
|
||||
var config = Configuration()
|
||||
config.foreignKeysEnabled = true
|
||||
return config
|
||||
}())
|
||||
|
||||
try db.write { db in
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE authors (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
email TEXT UNIQUE
|
||||
);
|
||||
""")
|
||||
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE books (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
title TEXT NOT NULL,
|
||||
author_id INTEGER NOT NULL REFERENCES authors(id) ON DELETE CASCADE,
|
||||
published_date TEXT,
|
||||
price REAL DEFAULT 9.99
|
||||
);
|
||||
""")
|
||||
|
||||
try db.execute(sql: """
|
||||
CREATE INDEX idx_books_author ON books(author_id);
|
||||
""")
|
||||
|
||||
try db.execute(sql: """
|
||||
CREATE INDEX idx_books_title ON books(title);
|
||||
""")
|
||||
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE reviews (
|
||||
id INTEGER PRIMARY KEY,
|
||||
book_id INTEGER NOT NULL REFERENCES books(id),
|
||||
rating INTEGER NOT NULL,
|
||||
comment TEXT
|
||||
);
|
||||
""")
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// MARK: - Tests
|
||||
|
||||
@Test("Discovers all user tables")
|
||||
func discoversAllTables() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
|
||||
#expect(schema.tableNames.count == 3)
|
||||
#expect(schema.tableNames.contains("authors"))
|
||||
#expect(schema.tableNames.contains("books"))
|
||||
#expect(schema.tableNames.contains("reviews"))
|
||||
}
|
||||
|
||||
@Test("Excludes sqlite_ internal tables")
|
||||
func excludesInternalTables() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
|
||||
for name in schema.tableNames {
|
||||
#expect(!name.hasPrefix("sqlite_"))
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Introspects column names and types")
|
||||
func introspectsColumns() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
|
||||
let books = try #require(schema.tables["books"])
|
||||
#expect(books.columns.count == 5)
|
||||
|
||||
let titleCol = try #require(books.columns.first { $0.name == "title" })
|
||||
#expect(titleCol.type == "TEXT")
|
||||
#expect(titleCol.isNotNull == true)
|
||||
#expect(titleCol.isPrimaryKey == false)
|
||||
|
||||
let priceCol = try #require(books.columns.first { $0.name == "price" })
|
||||
#expect(priceCol.type == "REAL")
|
||||
#expect(priceCol.defaultValue == "9.99")
|
||||
}
|
||||
|
||||
@Test("Detects primary keys")
|
||||
func detectsPrimaryKeys() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
|
||||
let authors = try #require(schema.tables["authors"])
|
||||
#expect(authors.primaryKey == ["id"])
|
||||
|
||||
let idCol = try #require(authors.columns.first { $0.name == "id" })
|
||||
#expect(idCol.isPrimaryKey == true)
|
||||
}
|
||||
|
||||
@Test("Detects foreign keys")
|
||||
func detectsForeignKeys() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
|
||||
let books = try #require(schema.tables["books"])
|
||||
#expect(books.foreignKeys.count == 1)
|
||||
|
||||
let fk = books.foreignKeys[0]
|
||||
#expect(fk.fromColumn == "author_id")
|
||||
#expect(fk.toTable == "authors")
|
||||
#expect(fk.toColumn == "id")
|
||||
#expect(fk.onDelete == "CASCADE")
|
||||
}
|
||||
|
||||
@Test("Detects indexes")
|
||||
func detectsIndexes() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
|
||||
let books = try #require(schema.tables["books"])
|
||||
let indexNames = books.indexes.map(\.name)
|
||||
#expect(indexNames.contains("idx_books_author"))
|
||||
#expect(indexNames.contains("idx_books_title"))
|
||||
}
|
||||
|
||||
@Test("Detects NOT NULL constraints")
|
||||
func detectsNotNull() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
|
||||
let reviews = try #require(schema.tables["reviews"])
|
||||
let ratingCol = try #require(reviews.columns.first { $0.name == "rating" })
|
||||
#expect(ratingCol.isNotNull == true)
|
||||
|
||||
let commentCol = try #require(reviews.columns.first { $0.name == "comment" })
|
||||
#expect(commentCol.isNotNull == false)
|
||||
}
|
||||
|
||||
@Test("Generates LLM-friendly schema description")
|
||||
func generatesSchemaDescription() async throws {
|
||||
let db = try makeTestDatabase()
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
|
||||
let description = schema.schemaDescription
|
||||
#expect(description.contains("TABLE authors"))
|
||||
#expect(description.contains("TABLE books"))
|
||||
#expect(description.contains("FOREIGN KEY"))
|
||||
#expect(description.contains("REFERENCES authors(id)"))
|
||||
#expect(description.contains("INDEX idx_books_author"))
|
||||
}
|
||||
|
||||
@Test("Handles empty database")
|
||||
func handlesEmptyDatabase() async throws {
|
||||
let db = try DatabaseQueue()
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
|
||||
#expect(schema.tables.isEmpty)
|
||||
#expect(schema.tableNames.isEmpty)
|
||||
#expect(schema.schemaDescription.isEmpty)
|
||||
}
|
||||
|
||||
@Test("Handles composite primary keys")
|
||||
func handlesCompositePrimaryKey() async throws {
|
||||
let db = try DatabaseQueue()
|
||||
try await db.write { db in
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE book_tags (
|
||||
book_id INTEGER NOT NULL,
|
||||
tag_id INTEGER NOT NULL,
|
||||
PRIMARY KEY (book_id, tag_id)
|
||||
);
|
||||
""")
|
||||
}
|
||||
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
let bookTags = try #require(schema.tables["book_tags"])
|
||||
#expect(bookTags.primaryKey.count == 2)
|
||||
#expect(bookTags.primaryKey.contains("book_id"))
|
||||
#expect(bookTags.primaryKey.contains("tag_id"))
|
||||
}
|
||||
|
||||
@Test("Handles tables with no explicit types (SQLite dynamic typing)")
|
||||
func handlesDynamicTyping() async throws {
|
||||
let db = try DatabaseQueue()
|
||||
try await db.write { db in
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE flexible (
|
||||
id INTEGER PRIMARY KEY,
|
||||
data,
|
||||
info BLOB
|
||||
);
|
||||
""")
|
||||
}
|
||||
|
||||
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||
let flexible = try #require(schema.tables["flexible"])
|
||||
|
||||
let dataCol = try #require(flexible.columns.first { $0.name == "data" })
|
||||
#expect(dataCol.type == "") // No declared type
|
||||
|
||||
let infoCol = try #require(flexible.columns.first { $0.name == "info" })
|
||||
#expect(infoCol.type == "BLOB")
|
||||
}
|
||||
|
||||
@Test("Synchronous introspection works within database access")
|
||||
func synchronousIntrospection() async throws {
|
||||
let db = try DatabaseQueue()
|
||||
try await db.write { db in
|
||||
try db.execute(sql: "CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT);")
|
||||
}
|
||||
|
||||
let schema = try await db.read { db in
|
||||
try SchemaIntrospector.introspect(db: db)
|
||||
}
|
||||
|
||||
#expect(schema.tableNames == ["test"])
|
||||
let table = try #require(schema.tables["test"])
|
||||
#expect(table.columns.count == 2)
|
||||
}
|
||||
}
|
||||
133
Tests/SwiftDBAITests/ScrollableDataTableViewTests.swift
Normal file
133
Tests/SwiftDBAITests/ScrollableDataTableViewTests.swift
Normal file
@@ -0,0 +1,133 @@
|
||||
// ScrollableDataTableViewTests.swift
|
||||
// SwiftDBAITests
|
||||
//
|
||||
// Tests for the ScrollableDataTableView component.
|
||||
|
||||
import Foundation
|
||||
import Testing
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("ScrollableDataTableView")
|
||||
@MainActor
|
||||
struct ScrollableDataTableViewTests {
|
||||
|
||||
// MARK: - Test Helpers
|
||||
|
||||
private func makeDataTable(
|
||||
columnNames: [String] = ["id", "name", "score"],
|
||||
inferredTypes: [DataTable.InferredType] = [.integer, .text, .real],
|
||||
rowCount: Int = 5
|
||||
) -> DataTable {
|
||||
let columns = columnNames.enumerated().map { idx, name in
|
||||
DataTable.Column(name: name, index: idx, inferredType: inferredTypes[idx])
|
||||
}
|
||||
let rows = (0..<rowCount).map { i in
|
||||
DataTable.Row(
|
||||
id: i,
|
||||
values: [
|
||||
.integer(Int64(i + 1)),
|
||||
.text("Item \(i + 1)"),
|
||||
.real(Double(i) * 10.5),
|
||||
],
|
||||
columnNames: columnNames
|
||||
)
|
||||
}
|
||||
return DataTable(columns: columns, rows: rows, sql: "SELECT * FROM test", executionTime: 0.015)
|
||||
}
|
||||
|
||||
private func makeEmptyDataTable() -> DataTable {
|
||||
DataTable(columns: [], rows: [], sql: "", executionTime: 0)
|
||||
}
|
||||
|
||||
// MARK: - Initialization Tests
|
||||
|
||||
@Test("Initializes with default parameters")
|
||||
func initWithDefaults() {
|
||||
let table = makeDataTable()
|
||||
let view = ScrollableDataTableView(dataTable: table)
|
||||
|
||||
#expect(view.minimumColumnWidth == 80)
|
||||
#expect(view.maximumColumnWidth == 250)
|
||||
#expect(view.showAlternatingRows == true)
|
||||
#expect(view.showFooter == true)
|
||||
}
|
||||
|
||||
@Test("Initializes with custom parameters")
|
||||
func initWithCustomParams() {
|
||||
let table = makeDataTable()
|
||||
let view = ScrollableDataTableView(
|
||||
dataTable: table,
|
||||
minimumColumnWidth: 100,
|
||||
maximumColumnWidth: 300,
|
||||
showAlternatingRows: false,
|
||||
showFooter: false
|
||||
)
|
||||
|
||||
#expect(view.minimumColumnWidth == 100)
|
||||
#expect(view.maximumColumnWidth == 300)
|
||||
#expect(view.showAlternatingRows == false)
|
||||
#expect(view.showFooter == false)
|
||||
}
|
||||
|
||||
@Test("Handles empty data table")
|
||||
func handlesEmptyTable() {
|
||||
let table = makeEmptyDataTable()
|
||||
let view = ScrollableDataTableView(dataTable: table)
|
||||
#expect(view.dataTable.isEmpty)
|
||||
}
|
||||
|
||||
@Test("Handles single row table")
|
||||
func handlesSingleRow() {
|
||||
let table = makeDataTable(rowCount: 1)
|
||||
let view = ScrollableDataTableView(dataTable: table)
|
||||
#expect(view.dataTable.rowCount == 1)
|
||||
#expect(view.dataTable.columnCount == 3)
|
||||
}
|
||||
|
||||
@Test("Handles single column table")
|
||||
func handlesSingleColumn() {
|
||||
let columns = [DataTable.Column(name: "count", index: 0, inferredType: .integer)]
|
||||
let rows = [
|
||||
DataTable.Row(id: 0, values: [.integer(42)], columnNames: ["count"])
|
||||
]
|
||||
let table = DataTable(columns: columns, rows: rows, sql: "SELECT count(*) FROM t", executionTime: 0.001)
|
||||
let view = ScrollableDataTableView(dataTable: table)
|
||||
#expect(view.dataTable.columnCount == 1)
|
||||
#expect(view.dataTable.rowCount == 1)
|
||||
}
|
||||
|
||||
@Test("Handles large number of rows")
|
||||
func handlesLargeRowCount() {
|
||||
let table = makeDataTable(rowCount: 1000)
|
||||
let view = ScrollableDataTableView(dataTable: table)
|
||||
#expect(view.dataTable.rowCount == 1000)
|
||||
}
|
||||
|
||||
@Test("Handles null values in cells")
|
||||
func handlesNullValues() {
|
||||
let columns = [
|
||||
DataTable.Column(name: "name", index: 0, inferredType: .text),
|
||||
DataTable.Column(name: "value", index: 1, inferredType: .null),
|
||||
]
|
||||
let rows = [
|
||||
DataTable.Row(id: 0, values: [.text("test"), .null], columnNames: ["name", "value"])
|
||||
]
|
||||
let table = DataTable(columns: columns, rows: rows)
|
||||
let view = ScrollableDataTableView(dataTable: table)
|
||||
#expect(view.dataTable.rows[0][1] == .null)
|
||||
}
|
||||
|
||||
@Test("Handles blob values in cells")
|
||||
func handlesBlobValues() {
|
||||
let columns = [
|
||||
DataTable.Column(name: "data", index: 0, inferredType: .blob),
|
||||
]
|
||||
let blobData = Data([0x00, 0xFF, 0xAB])
|
||||
let rows = [
|
||||
DataTable.Row(id: 0, values: [.blob(blobData)], columnNames: ["data"])
|
||||
]
|
||||
let table = DataTable(columns: columns, rows: rows)
|
||||
let view = ScrollableDataTableView(dataTable: table)
|
||||
#expect(view.dataTable.rows[0][0] == QueryResult.Value.blob(blobData))
|
||||
}
|
||||
}
|
||||
301
Tests/SwiftDBAITests/TextSummaryRendererTests.swift
Normal file
301
Tests/SwiftDBAITests/TextSummaryRendererTests.swift
Normal file
@@ -0,0 +1,301 @@
|
||||
// TextSummaryRendererTests.swift
|
||||
// SwiftDBAI
|
||||
|
||||
import AnyLanguageModel
|
||||
import Testing
|
||||
import Foundation
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("TextSummaryRenderer")
|
||||
struct TextSummaryRendererTests {
|
||||
|
||||
// MARK: - QueryResult.Value Tests
|
||||
|
||||
@Test("Value description renders correctly")
|
||||
func valueDescriptions() {
|
||||
#expect(QueryResult.Value.text("hello").description == "hello")
|
||||
#expect(QueryResult.Value.integer(42).description == "42")
|
||||
#expect(QueryResult.Value.real(3.14).description == "3.14")
|
||||
#expect(QueryResult.Value.null.description == "NULL")
|
||||
#expect(QueryResult.Value.blob(Data([0x01, 0x02])).description == "<2 bytes>")
|
||||
}
|
||||
|
||||
@Test("Value doubleValue extracts numeric values")
|
||||
func valueDoubleValues() {
|
||||
#expect(QueryResult.Value.integer(42).doubleValue == 42.0)
|
||||
#expect(QueryResult.Value.real(3.14).doubleValue == 3.14)
|
||||
#expect(QueryResult.Value.text("100").doubleValue == 100.0)
|
||||
#expect(QueryResult.Value.text("not a number").doubleValue == nil)
|
||||
#expect(QueryResult.Value.null.doubleValue == nil)
|
||||
#expect(QueryResult.Value.blob(Data()).doubleValue == nil)
|
||||
}
|
||||
|
||||
@Test("Value isNull works correctly")
|
||||
func valueIsNull() {
|
||||
#expect(QueryResult.Value.null.isNull == true)
|
||||
#expect(QueryResult.Value.text("").isNull == false)
|
||||
#expect(QueryResult.Value.integer(0).isNull == false)
|
||||
}
|
||||
|
||||
// MARK: - QueryResult Tests
|
||||
|
||||
@Test("Empty result has correct properties")
|
||||
func emptyResult() {
|
||||
let result = QueryResult(
|
||||
columns: ["id", "name"],
|
||||
rows: [],
|
||||
sql: "SELECT id, name FROM users",
|
||||
executionTime: 0.01
|
||||
)
|
||||
#expect(result.rowCount == 0)
|
||||
#expect(result.isAggregate == false)
|
||||
#expect(result.tabularDescription == "(empty result set)")
|
||||
}
|
||||
|
||||
@Test("Single aggregate result is detected")
|
||||
func aggregateDetection() {
|
||||
let result = QueryResult(
|
||||
columns: ["COUNT(*)"],
|
||||
rows: [["COUNT(*)": .integer(42)]],
|
||||
sql: "SELECT COUNT(*) FROM users",
|
||||
executionTime: 0.01
|
||||
)
|
||||
#expect(result.isAggregate == true)
|
||||
}
|
||||
|
||||
@Test("Multi-row result is not aggregate")
|
||||
func nonAggregateDetection() {
|
||||
let result = QueryResult(
|
||||
columns: ["name"],
|
||||
rows: [
|
||||
["name": .text("Alice")],
|
||||
["name": .text("Bob")],
|
||||
],
|
||||
sql: "SELECT name FROM users",
|
||||
executionTime: 0.01
|
||||
)
|
||||
#expect(result.isAggregate == false)
|
||||
}
|
||||
|
||||
@Test("Tabular description formats correctly")
|
||||
func tabularDescription() {
|
||||
let result = QueryResult(
|
||||
columns: ["id", "name"],
|
||||
rows: [
|
||||
["id": .integer(1), "name": .text("Alice")],
|
||||
["id": .integer(2), "name": .text("Bob")],
|
||||
],
|
||||
sql: "SELECT id, name FROM users",
|
||||
executionTime: 0.01
|
||||
)
|
||||
let desc = result.tabularDescription
|
||||
#expect(desc.contains("id | name"))
|
||||
#expect(desc.contains("1 | Alice"))
|
||||
#expect(desc.contains("2 | Bob"))
|
||||
}
|
||||
|
||||
@Test("values(forColumn:) extracts column values")
|
||||
func valuesForColumn() {
|
||||
let result = QueryResult(
|
||||
columns: ["name"],
|
||||
rows: [
|
||||
["name": .text("Alice")],
|
||||
["name": .text("Bob")],
|
||||
],
|
||||
sql: "SELECT name FROM users",
|
||||
executionTime: 0.01
|
||||
)
|
||||
let values = result.values(forColumn: "name")
|
||||
#expect(values.count == 2)
|
||||
#expect(values[0] == .text("Alice"))
|
||||
}
|
||||
|
||||
// MARK: - Local Summary Tests (no LLM required)
|
||||
|
||||
@Test("Local summary for empty result")
|
||||
func localSummaryEmpty() {
|
||||
let result = makeResult(columns: ["id"], rows: [])
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = renderer.localSummary(result: result, userQuestion: "Any users?")
|
||||
#expect(summary == "No results found for your query.")
|
||||
}
|
||||
|
||||
@Test("Local summary for single aggregate")
|
||||
func localSummarySingleAggregate() {
|
||||
let result = makeResult(
|
||||
columns: ["COUNT(*)"],
|
||||
rows: [["COUNT(*)": .integer(42)]]
|
||||
)
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = renderer.localSummary(result: result, userQuestion: "How many?")
|
||||
#expect(summary.contains("42"))
|
||||
}
|
||||
|
||||
@Test("Local summary for multiple aggregates")
|
||||
func localSummaryMultipleAggregates() {
|
||||
let result = makeResult(
|
||||
columns: ["COUNT(*)", "AVG(price)"],
|
||||
rows: [["COUNT(*)": .integer(10), "AVG(price)": .real(25.5)]]
|
||||
)
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = renderer.localSummary(result: result, userQuestion: "Stats?")
|
||||
#expect(summary.contains("count"))
|
||||
#expect(summary.contains("average price"))
|
||||
}
|
||||
|
||||
@Test("Local summary for single record")
|
||||
func localSummarySingleRecord() {
|
||||
let result = makeResult(
|
||||
columns: ["name", "email"],
|
||||
rows: [["name": .text("Alice"), "email": .text("alice@example.com")]]
|
||||
)
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = renderer.localSummary(result: result, userQuestion: "Who?")
|
||||
#expect(summary.contains("1 result"))
|
||||
#expect(summary.contains("Alice"))
|
||||
}
|
||||
|
||||
@Test("Local summary for multiple records with name column")
|
||||
func localSummaryMultipleWithNames() {
|
||||
let result = makeResult(
|
||||
columns: ["name", "age"],
|
||||
rows: [
|
||||
["name": .text("Alice"), "age": .integer(30)],
|
||||
["name": .text("Bob"), "age": .integer(25)],
|
||||
["name": .text("Charlie"), "age": .integer(35)],
|
||||
["name": .text("Diana"), "age": .integer(28)],
|
||||
]
|
||||
)
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = renderer.localSummary(result: result, userQuestion: "List users")
|
||||
#expect(summary.contains("4 results"))
|
||||
#expect(summary.contains("Alice"))
|
||||
#expect(summary.contains("1 more"))
|
||||
}
|
||||
|
||||
@Test("Local summary for mutation result")
|
||||
func localSummaryMutation() {
|
||||
let result = QueryResult(
|
||||
columns: [],
|
||||
rows: [],
|
||||
sql: "INSERT INTO users (name) VALUES ('Test')",
|
||||
executionTime: 0.01,
|
||||
rowsAffected: 1
|
||||
)
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = renderer.localSummary(result: result, userQuestion: "Add user")
|
||||
#expect(summary == "Successfully inserted 1 row.")
|
||||
}
|
||||
|
||||
@Test("Local summary for delete mutation")
|
||||
func localSummaryDelete() {
|
||||
let result = QueryResult(
|
||||
columns: [],
|
||||
rows: [],
|
||||
sql: "DELETE FROM users WHERE id = 5",
|
||||
executionTime: 0.01,
|
||||
rowsAffected: 3
|
||||
)
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = renderer.localSummary(result: result, userQuestion: "Delete old users")
|
||||
#expect(summary == "Successfully deleted 3 rows.")
|
||||
}
|
||||
|
||||
@Test("Local summary for update mutation")
|
||||
func localSummaryUpdate() {
|
||||
let result = QueryResult(
|
||||
columns: [],
|
||||
rows: [],
|
||||
sql: "UPDATE users SET active = 0 WHERE id = 1",
|
||||
executionTime: 0.01,
|
||||
rowsAffected: 1
|
||||
)
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = renderer.localSummary(result: result, userQuestion: "Deactivate user")
|
||||
#expect(summary == "Successfully updated 1 row.")
|
||||
}
|
||||
|
||||
// MARK: - LLM-based Summary Tests (using MockLanguageModel)
|
||||
|
||||
@Test("Summarize with LLM returns mock response for multi-row results")
|
||||
func summarizeWithLLM() async throws {
|
||||
let result = makeResult(
|
||||
columns: ["name", "age"],
|
||||
rows: [
|
||||
["name": .text("Alice"), "age": .integer(30)],
|
||||
["name": .text("Bob"), "age": .integer(25)],
|
||||
]
|
||||
)
|
||||
let mockModel = MockLanguageModel(responseText: "There are 2 users: Alice (30) and Bob (25).")
|
||||
let renderer = TextSummaryRenderer(model: mockModel)
|
||||
let summary = try await renderer.summarize(result: result, userQuestion: "List all users")
|
||||
#expect(summary == "There are 2 users: Alice (30) and Bob (25).")
|
||||
}
|
||||
|
||||
@Test("Summarize returns empty result message without calling LLM")
|
||||
func summarizeEmptyResult() async throws {
|
||||
let result = makeResult(columns: ["id"], rows: [])
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = try await renderer.summarize(result: result, userQuestion: "Find users")
|
||||
#expect(summary == "No results found for your query.")
|
||||
}
|
||||
|
||||
@Test("Summarize returns direct aggregate without calling LLM")
|
||||
func summarizeAggregate() async throws {
|
||||
let result = makeResult(
|
||||
columns: ["COUNT(*)"],
|
||||
rows: [["COUNT(*)": .integer(42)]]
|
||||
)
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = try await renderer.summarize(result: result, userQuestion: "How many?")
|
||||
#expect(summary.contains("42"))
|
||||
}
|
||||
|
||||
@Test("Summarize mutation returns template without calling LLM")
|
||||
func summarizeMutation() async throws {
|
||||
let result = QueryResult(
|
||||
columns: [],
|
||||
rows: [],
|
||||
sql: "UPDATE users SET name = 'Test' WHERE id = 1",
|
||||
executionTime: 0.01,
|
||||
rowsAffected: 1
|
||||
)
|
||||
let renderer = makeMockRenderer()
|
||||
let summary = try await renderer.summarize(result: result, userQuestion: "Update user")
|
||||
#expect(summary == "Successfully updated 1 row.")
|
||||
}
|
||||
|
||||
@Test("Summarize passes context to LLM prompt")
|
||||
func summarizeWithContext() async throws {
|
||||
let result = makeResult(
|
||||
columns: ["total"],
|
||||
rows: [
|
||||
["total": .real(100.0)],
|
||||
["total": .real(200.0)],
|
||||
]
|
||||
)
|
||||
let mockModel = MockLanguageModel(responseText: "The totals are 100 and 200.")
|
||||
let renderer = TextSummaryRenderer(model: mockModel)
|
||||
let summary = try await renderer.summarize(
|
||||
result: result,
|
||||
userQuestion: "Show totals",
|
||||
context: "Amounts are in USD"
|
||||
)
|
||||
#expect(summary == "The totals are 100 and 200.")
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
private func makeResult(
|
||||
columns: [String],
|
||||
rows: [[String: QueryResult.Value]],
|
||||
sql: String = "SELECT * FROM test"
|
||||
) -> QueryResult {
|
||||
QueryResult(columns: columns, rows: rows, sql: sql, executionTime: 0.01)
|
||||
}
|
||||
|
||||
/// Creates a renderer with a mock model (for localSummary tests that don't hit the LLM).
|
||||
private func makeMockRenderer() -> TextSummaryRenderer {
|
||||
TextSummaryRenderer(model: MockLanguageModel())
|
||||
}
|
||||
}
|
||||
246
Tests/SwiftDBAITests/ToolExecutionDelegateTests.swift
Normal file
246
Tests/SwiftDBAITests/ToolExecutionDelegateTests.swift
Normal file
@@ -0,0 +1,246 @@
|
||||
// ToolExecutionDelegateTests.swift
|
||||
// SwiftDBAITests
|
||||
|
||||
import Foundation
|
||||
import Testing
|
||||
@testable import SwiftDBAI
|
||||
|
||||
@Suite("DestructiveClassification")
|
||||
struct DestructiveClassificationTests {
|
||||
|
||||
// MARK: - Safe statements
|
||||
|
||||
@Test("SELECT is classified as safe")
|
||||
func selectIsSafe() {
|
||||
let result = classifySQL("SELECT * FROM users")
|
||||
#expect(result == .safe)
|
||||
#expect(!result.requiresConfirmation)
|
||||
#expect(!result.isMutating)
|
||||
}
|
||||
|
||||
@Test("WITH (CTE) is classified as safe")
|
||||
func withIsSafe() {
|
||||
let result = classifySQL("WITH cte AS (SELECT 1) SELECT * FROM cte")
|
||||
#expect(result == .safe)
|
||||
}
|
||||
|
||||
// MARK: - Mutation statements
|
||||
|
||||
@Test("INSERT is classified as mutation")
|
||||
func insertIsMutation() {
|
||||
let result = classifySQL("INSERT INTO users (name) VALUES ('Alice')")
|
||||
#expect(result == .mutation(.insert))
|
||||
#expect(!result.requiresConfirmation)
|
||||
#expect(result.isMutating)
|
||||
}
|
||||
|
||||
@Test("UPDATE is classified as mutation")
|
||||
func updateIsMutation() {
|
||||
let result = classifySQL("UPDATE users SET name = 'Bob' WHERE id = 1")
|
||||
#expect(result == .mutation(.update))
|
||||
#expect(!result.requiresConfirmation)
|
||||
#expect(result.isMutating)
|
||||
}
|
||||
|
||||
// MARK: - Destructive statements
|
||||
|
||||
@Test("DELETE is classified as destructive")
|
||||
func deleteIsDestructive() {
|
||||
let result = classifySQL("DELETE FROM users WHERE id = 1")
|
||||
#expect(result == .destructive(.delete))
|
||||
#expect(result.requiresConfirmation)
|
||||
#expect(result.isMutating)
|
||||
}
|
||||
|
||||
@Test("DROP is classified as destructive")
|
||||
func dropIsDestructive() {
|
||||
let result = classifySQL("DROP TABLE users")
|
||||
#expect(result == .destructive(.drop))
|
||||
#expect(result.requiresConfirmation)
|
||||
}
|
||||
|
||||
@Test("ALTER is classified as destructive")
|
||||
func alterIsDestructive() {
|
||||
let result = classifySQL("ALTER TABLE users ADD COLUMN age INTEGER")
|
||||
#expect(result == .destructive(.alter))
|
||||
#expect(result.requiresConfirmation)
|
||||
}
|
||||
|
||||
@Test("TRUNCATE is classified as destructive")
|
||||
func truncateIsDestructive() {
|
||||
let result = classifySQL("TRUNCATE TABLE users")
|
||||
#expect(result == .destructive(.truncate))
|
||||
#expect(result.requiresConfirmation)
|
||||
}
|
||||
|
||||
// MARK: - Case insensitivity
|
||||
|
||||
@Test("Classification is case-insensitive")
|
||||
func caseInsensitive() {
|
||||
#expect(classifySQL("delete from users") == .destructive(.delete))
|
||||
#expect(classifySQL("Drop Table foo") == .destructive(.drop))
|
||||
#expect(classifySQL("select 1") == .safe)
|
||||
#expect(classifySQL("INSERT into t values (1)") == .mutation(.insert))
|
||||
}
|
||||
|
||||
// MARK: - Leading whitespace
|
||||
|
||||
@Test("Classification ignores leading whitespace")
|
||||
func leadingWhitespace() {
|
||||
#expect(classifySQL(" \n DELETE FROM users") == .destructive(.delete))
|
||||
#expect(classifySQL("\t SELECT 1") == .safe)
|
||||
}
|
||||
|
||||
// MARK: - SQLStatementKind
|
||||
|
||||
@Test("Destructive kinds are correct")
|
||||
func destructiveKinds() {
|
||||
#expect(SQLStatementKind.delete.isDestructive)
|
||||
#expect(SQLStatementKind.drop.isDestructive)
|
||||
#expect(SQLStatementKind.alter.isDestructive)
|
||||
#expect(SQLStatementKind.truncate.isDestructive)
|
||||
#expect(!SQLStatementKind.select.isDestructive)
|
||||
#expect(!SQLStatementKind.insert.isDestructive)
|
||||
#expect(!SQLStatementKind.update.isDestructive)
|
||||
}
|
||||
|
||||
@Test("Mutation kinds are correct")
|
||||
func mutationKinds() {
|
||||
#expect(SQLStatementKind.insert.isMutation)
|
||||
#expect(SQLStatementKind.update.isMutation)
|
||||
#expect(!SQLStatementKind.select.isMutation)
|
||||
#expect(!SQLStatementKind.delete.isMutation)
|
||||
}
|
||||
}
|
||||
|
||||
@Suite("ToolExecutionDelegate")
|
||||
struct ToolExecutionDelegateProtocolTests {
|
||||
|
||||
@Test("AutoApproveDelegate approves all operations")
|
||||
func autoApprove() async {
|
||||
let delegate = AutoApproveDelegate()
|
||||
let context = DestructiveOperationContext(
|
||||
sql: "DELETE FROM users",
|
||||
statementKind: .delete,
|
||||
classification: .destructive(.delete),
|
||||
description: "Delete all rows from users"
|
||||
)
|
||||
let result = await delegate.confirmDestructiveOperation(context)
|
||||
#expect(result == true)
|
||||
}
|
||||
|
||||
@Test("RejectAllDelegate rejects all operations")
|
||||
func rejectAll() async {
|
||||
let delegate = RejectAllDelegate()
|
||||
let context = DestructiveOperationContext(
|
||||
sql: "DROP TABLE users",
|
||||
statementKind: .drop,
|
||||
classification: .destructive(.drop),
|
||||
description: "Drop the users table"
|
||||
)
|
||||
let result = await delegate.confirmDestructiveOperation(context)
|
||||
#expect(result == false)
|
||||
}
|
||||
|
||||
@Test("Default delegate implementation rejects destructive operations")
|
||||
func defaultRejects() async {
|
||||
struct EmptyDelegate: ToolExecutionDelegate {}
|
||||
let delegate = EmptyDelegate()
|
||||
let context = DestructiveOperationContext(
|
||||
sql: "DELETE FROM users",
|
||||
statementKind: .delete,
|
||||
classification: .destructive(.delete),
|
||||
description: "Delete rows"
|
||||
)
|
||||
let result = await delegate.confirmDestructiveOperation(context)
|
||||
#expect(result == false)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Tracking Delegate for Integration Tests
|
||||
|
||||
/// A delegate that records all calls for verification in tests.
|
||||
private final class TrackingDelegate: ToolExecutionDelegate, @unchecked Sendable {
|
||||
private let lock = NSLock()
|
||||
|
||||
private var _confirmCalls: [DestructiveOperationContext] = []
|
||||
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
|
||||
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
|
||||
private var _confirmResult: Bool
|
||||
|
||||
var confirmCalls: [DestructiveOperationContext] {
|
||||
lock.withLock { _confirmCalls }
|
||||
}
|
||||
|
||||
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
|
||||
lock.withLock { _willExecuteCalls }
|
||||
}
|
||||
|
||||
var didExecuteCalls: [(sql: String, success: Bool)] {
|
||||
lock.withLock { _didExecuteCalls }
|
||||
}
|
||||
|
||||
init(confirmResult: Bool) {
|
||||
self._confirmResult = confirmResult
|
||||
}
|
||||
|
||||
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
|
||||
lock.withLock { _confirmCalls.append(context) }
|
||||
return _confirmResult
|
||||
}
|
||||
|
||||
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
|
||||
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
|
||||
}
|
||||
|
||||
func didExecuteSQL(_ sql: String, success: Bool) async {
|
||||
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
|
||||
}
|
||||
}
|
||||
|
||||
@Suite("ToolExecutionDelegate - ChatEngine Integration")
|
||||
struct DelegateIntegrationTests {
|
||||
|
||||
@Test("DestructiveOperationContext captures target table")
|
||||
func contextCapturesTable() {
|
||||
let context = DestructiveOperationContext(
|
||||
sql: "DELETE FROM users WHERE id = 1",
|
||||
statementKind: .delete,
|
||||
classification: .destructive(.delete),
|
||||
description: "Delete from users",
|
||||
targetTable: "users"
|
||||
)
|
||||
#expect(context.targetTable == "users")
|
||||
#expect(context.statementKind == .delete)
|
||||
#expect(context.classification.requiresConfirmation)
|
||||
}
|
||||
|
||||
@Test("classifySQL returns destructive for DELETE")
|
||||
func classifySQLDestructive() {
|
||||
let result = classifySQL("DELETE FROM orders WHERE id = 5")
|
||||
#expect(result == .destructive(.delete))
|
||||
#expect(result.requiresConfirmation)
|
||||
}
|
||||
|
||||
@Test("classifySQL returns safe for SELECT")
|
||||
func classifySQLSafe() {
|
||||
let result = classifySQL("SELECT * FROM users")
|
||||
#expect(result == .safe)
|
||||
#expect(!result.requiresConfirmation)
|
||||
}
|
||||
|
||||
@Test("classifySQL returns mutation for INSERT")
|
||||
func classifySQLMutation() {
|
||||
let result = classifySQL("INSERT INTO users (name) VALUES ('test')")
|
||||
#expect(result == .mutation(.insert))
|
||||
#expect(!result.requiresConfirmation)
|
||||
}
|
||||
|
||||
@Test("DestructiveClassification.isMutating is true for mutations and destructive")
|
||||
func isMutatingCovers() {
|
||||
#expect(DestructiveClassification.mutation(.insert).isMutating)
|
||||
#expect(DestructiveClassification.mutation(.update).isMutating)
|
||||
#expect(DestructiveClassification.destructive(.delete).isMutating)
|
||||
#expect(!DestructiveClassification.safe.isMutating)
|
||||
}
|
||||
}
|
||||
617
Tests/SwiftDBAITests/UnifiedProviderTestHarness.swift
Normal file
617
Tests/SwiftDBAITests/UnifiedProviderTestHarness.swift
Normal file
@@ -0,0 +1,617 @@
|
||||
// UnifiedProviderTestHarness.swift
|
||||
// SwiftDBAI Tests
|
||||
//
|
||||
// A unified test harness that validates all seven provider types
|
||||
// conform to the AnyLanguageModel protocol and produce consistent
|
||||
// ChatEngine-compatible output. Covers: OpenAI, Anthropic, Gemini,
|
||||
// OpenAI-Compatible, Ollama, llama.cpp, and on-device (MLX/CoreML).
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
import GRDB
|
||||
import Testing
|
||||
|
||||
@testable import SwiftDBAI
|
||||
|
||||
// MARK: - Provider-Simulating Mock Models
|
||||
|
||||
/// A mock that records which LanguageModel protocol methods were called,
|
||||
/// the arguments passed, and returns configurable responses.
|
||||
/// Used to validate that every provider path through ChatEngine
|
||||
/// exercises the same protocol surface.
|
||||
final class ProviderConformanceMock: LanguageModel, @unchecked Sendable {
|
||||
typealias UnavailableReason = Never
|
||||
|
||||
/// Track calls to verify protocol conformance exercised fully.
|
||||
struct CallRecord: Sendable {
|
||||
let method: String
|
||||
let promptDescription: String
|
||||
let timestamp: Date
|
||||
}
|
||||
|
||||
private let lock = NSLock()
|
||||
private var _calls: [CallRecord] = []
|
||||
private let _responses: [String]
|
||||
private var _callIndex = 0
|
||||
|
||||
/// Label for diagnostics.
|
||||
let providerName: String
|
||||
|
||||
var calls: [CallRecord] {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
return _calls
|
||||
}
|
||||
|
||||
init(providerName: String, responses: [String]) {
|
||||
self.providerName = providerName
|
||||
self._responses = responses
|
||||
}
|
||||
|
||||
private func nextResponse() -> String {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
let idx = _callIndex
|
||||
_callIndex += 1
|
||||
return idx < _responses.count ? _responses[idx] : "fallback response"
|
||||
}
|
||||
|
||||
private func recordCall(method: String, prompt: String) {
|
||||
lock.lock()
|
||||
_calls.append(CallRecord(method: method, promptDescription: prompt, timestamp: Date()))
|
||||
lock.unlock()
|
||||
}
|
||||
|
||||
func respond<Content>(
|
||||
within session: LanguageModelSession,
|
||||
to prompt: Prompt,
|
||||
generating type: Content.Type,
|
||||
includeSchemaInPrompt: Bool,
|
||||
options: GenerationOptions
|
||||
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
|
||||
recordCall(method: "respond", prompt: prompt.description)
|
||||
let text = nextResponse()
|
||||
let rawContent = GeneratedContent(kind: .string(text))
|
||||
let content = try Content(rawContent)
|
||||
return LanguageModelSession.Response(
|
||||
content: content,
|
||||
rawContent: rawContent,
|
||||
transcriptEntries: [][...]
|
||||
)
|
||||
}
|
||||
|
||||
func streamResponse<Content>(
|
||||
within session: LanguageModelSession,
|
||||
to prompt: Prompt,
|
||||
generating type: Content.Type,
|
||||
includeSchemaInPrompt: Bool,
|
||||
options: GenerationOptions
|
||||
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
|
||||
recordCall(method: "streamResponse", prompt: prompt.description)
|
||||
let text = nextResponse()
|
||||
let rawContent = GeneratedContent(kind: .string(text))
|
||||
let content = try! Content(rawContent)
|
||||
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Test Database Helper
|
||||
|
||||
/// Creates a minimal in-memory database for provider integration tests.
|
||||
private func makeProviderTestDatabase() throws -> DatabaseQueue {
|
||||
let db = try DatabaseQueue(path: ":memory:")
|
||||
try db.write { db in
|
||||
try db.execute(sql: """
|
||||
CREATE TABLE products (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
category TEXT NOT NULL
|
||||
)
|
||||
""")
|
||||
try db.execute(sql: """
|
||||
INSERT INTO products (name, price, category) VALUES
|
||||
('Widget', 9.99, 'tools'),
|
||||
('Gadget', 24.99, 'electronics'),
|
||||
('Doohickey', 4.50, 'tools')
|
||||
""")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// MARK: - Unified Provider Test Harness
|
||||
|
||||
@Suite("Unified Provider Test Harness")
|
||||
struct UnifiedProviderTestHarness {
|
||||
|
||||
// MARK: - Provider Configuration Enumeration
|
||||
|
||||
/// All seven provider types that SwiftDBAI supports.
|
||||
enum TestedProvider: String, CaseIterable {
|
||||
case openAI
|
||||
case anthropic
|
||||
case gemini
|
||||
case openAICompatible
|
||||
case ollama
|
||||
case llamaCpp
|
||||
case onDevice
|
||||
}
|
||||
|
||||
/// Creates a ProviderConformanceMock simulating each provider type.
|
||||
private func makeMock(for provider: TestedProvider, responses: [String]) -> ProviderConformanceMock {
|
||||
ProviderConformanceMock(providerName: provider.rawValue, responses: responses)
|
||||
}
|
||||
|
||||
// MARK: - 1. Protocol Conformance — All Providers Are LanguageModel
|
||||
|
||||
@Test("All provider types produce instances conforming to LanguageModel protocol")
|
||||
func allProvidersConformToLanguageModel() {
|
||||
// Cloud providers via ProviderConfiguration.makeModel()
|
||||
let openAI = ProviderConfiguration.openAI(apiKey: "test-key", model: "gpt-4o").makeModel()
|
||||
let anthropic = ProviderConfiguration.anthropic(apiKey: "test-key", model: "claude-sonnet-4-20250514").makeModel()
|
||||
let gemini = ProviderConfiguration.gemini(apiKey: "test-key", model: "gemini-2.0-flash").makeModel()
|
||||
let openAICompatible = ProviderConfiguration.openAICompatible(
|
||||
apiKey: "test-key",
|
||||
model: "local-model",
|
||||
baseURL: URL(string: "http://localhost:8080/v1/")!
|
||||
).makeModel()
|
||||
let ollama = ProviderConfiguration.ollama(model: "llama3.2").makeModel()
|
||||
let llamaCpp = ProviderConfiguration.llamaCpp(model: "default").makeModel()
|
||||
// On-device MLX (wraps as openAICompatible internally)
|
||||
let onDeviceMLX = ProviderConfiguration.onDeviceMLX(
|
||||
MLXProviderConfiguration(modelId: "test-model")
|
||||
).makeModel()
|
||||
|
||||
// Verify all are LanguageModel
|
||||
let models: [(String, any LanguageModel)] = [
|
||||
("OpenAI", openAI),
|
||||
("Anthropic", anthropic),
|
||||
("Gemini", gemini),
|
||||
("OpenAI-Compatible", openAICompatible),
|
||||
("Ollama", ollama),
|
||||
("llama.cpp", llamaCpp),
|
||||
("On-Device MLX", onDeviceMLX),
|
||||
]
|
||||
|
||||
for (name, model) in models {
|
||||
// Protocol conformance is compile-time, but we verify isAvailable works
|
||||
#expect(model.isAvailable, "\(name) model should report as available")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("All provider configurations produce correct concrete model types")
|
||||
func providerConfigurationsProduceCorrectTypes() {
|
||||
let openAI = ProviderConfiguration.openAI(apiKey: "k", model: "m").makeModel()
|
||||
#expect(openAI is OpenAILanguageModel, "OpenAI config should produce OpenAILanguageModel")
|
||||
|
||||
let anthropic = ProviderConfiguration.anthropic(apiKey: "k", model: "m").makeModel()
|
||||
#expect(anthropic is AnthropicLanguageModel, "Anthropic config should produce AnthropicLanguageModel")
|
||||
|
||||
let gemini = ProviderConfiguration.gemini(apiKey: "k", model: "m").makeModel()
|
||||
#expect(gemini is GeminiLanguageModel, "Gemini config should produce GeminiLanguageModel")
|
||||
|
||||
let openAICompat = ProviderConfiguration.openAICompatible(
|
||||
apiKey: "k", model: "m", baseURL: URL(string: "http://localhost:1234")!
|
||||
).makeModel()
|
||||
#expect(openAICompat is OpenAILanguageModel, "OpenAI-Compatible config should produce OpenAILanguageModel")
|
||||
|
||||
let ollama = ProviderConfiguration.ollama(model: "m").makeModel()
|
||||
#expect(ollama is OllamaLanguageModel, "Ollama config should produce OllamaLanguageModel")
|
||||
|
||||
let llamaCpp = ProviderConfiguration.llamaCpp(model: "m").makeModel()
|
||||
#expect(llamaCpp is OpenAILanguageModel, "llama.cpp config should produce OpenAILanguageModel (OpenAI-compatible)")
|
||||
|
||||
// On-device uses OpenAILanguageModel internally as a wrapper
|
||||
let onDevice = ProviderConfiguration.onDeviceMLX(
|
||||
MLXProviderConfiguration(modelId: "test")
|
||||
).makeModel()
|
||||
#expect(onDevice is OpenAILanguageModel, "On-device MLX config should produce OpenAILanguageModel wrapper")
|
||||
}
|
||||
|
||||
// MARK: - 2. Consistent ChatEngine-Compatible Output
|
||||
|
||||
@Test("Every provider mock produces valid ChatEngine responses for SELECT queries",
|
||||
arguments: TestedProvider.allCases)
|
||||
func providerProducesValidChatEngineResponse(provider: TestedProvider) async throws {
|
||||
let db = try makeProviderTestDatabase()
|
||||
let mock = makeMock(for: provider, responses: [
|
||||
"SELECT COUNT(*) FROM products", // SQL generation
|
||||
"There are 3 products in the database.", // Summary (fallback)
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
let response = try await engine.send("How many products are there?")
|
||||
|
||||
// All providers must produce:
|
||||
// 1. Non-empty summary
|
||||
#expect(!response.summary.isEmpty, "\(provider.rawValue): summary must not be empty")
|
||||
|
||||
// 2. Valid SQL that was executed
|
||||
#expect(response.sql == "SELECT COUNT(*) FROM products",
|
||||
"\(provider.rawValue): SQL must match generated query")
|
||||
|
||||
// 3. A QueryResult with data
|
||||
#expect(response.queryResult != nil, "\(provider.rawValue): queryResult must exist")
|
||||
#expect(response.queryResult?.rowCount == 1, "\(provider.rawValue): should have 1 row for COUNT")
|
||||
}
|
||||
|
||||
@Test("Every provider mock produces valid ChatEngine responses for multi-row SELECT",
|
||||
arguments: TestedProvider.allCases)
|
||||
func providerProducesMultiRowResponse(provider: TestedProvider) async throws {
|
||||
let db = try makeProviderTestDatabase()
|
||||
let mock = makeMock(for: provider, responses: [
|
||||
"SELECT name, price FROM products ORDER BY price DESC",
|
||||
"Here are the products sorted by price.",
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
let response = try await engine.send("List products by price")
|
||||
|
||||
#expect(response.queryResult != nil, "\(provider.rawValue): queryResult must exist")
|
||||
#expect(response.queryResult?.rowCount == 3, "\(provider.rawValue): should return all 3 products")
|
||||
#expect(response.queryResult?.columns.contains("name") == true,
|
||||
"\(provider.rawValue): columns must include 'name'")
|
||||
#expect(response.queryResult?.columns.contains("price") == true,
|
||||
"\(provider.rawValue): columns must include 'price'")
|
||||
}
|
||||
|
||||
// MARK: - 3. Consistent LanguageModelSession Integration
|
||||
|
||||
@Test("Every provider mock works through LanguageModelSession.respond(to:)",
|
||||
arguments: TestedProvider.allCases)
|
||||
func providerWorksWithSession(provider: TestedProvider) async throws {
|
||||
let mock = makeMock(for: provider, responses: [
|
||||
"SELECT 1 AS test",
|
||||
])
|
||||
|
||||
let session = LanguageModelSession(
|
||||
model: mock,
|
||||
instructions: "You are a SQL assistant."
|
||||
)
|
||||
|
||||
let response = try await session.respond(to: "Generate a test query")
|
||||
|
||||
// Verify the response content is the expected string
|
||||
#expect(response.content == "SELECT 1 AS test",
|
||||
"\(provider.rawValue): session response should match mock output")
|
||||
|
||||
// Verify the mock received the call
|
||||
#expect(mock.calls.count == 1, "\(provider.rawValue): should have exactly 1 call")
|
||||
#expect(mock.calls.first?.method == "respond",
|
||||
"\(provider.rawValue): should call respond method")
|
||||
}
|
||||
|
||||
@Test("Every provider mock works through LanguageModelSession.streamResponse(to:)",
|
||||
arguments: TestedProvider.allCases)
|
||||
func providerWorksWithStreamSession(provider: TestedProvider) async throws {
|
||||
let mock = makeMock(for: provider, responses: [
|
||||
"SELECT 42 AS answer",
|
||||
])
|
||||
|
||||
let session = LanguageModelSession(
|
||||
model: mock,
|
||||
instructions: "You are a SQL assistant."
|
||||
)
|
||||
|
||||
let stream = session.streamResponse(to: "Give me a number")
|
||||
let collected = try await stream.collect()
|
||||
|
||||
#expect(collected.content == "SELECT 42 AS answer",
|
||||
"\(provider.rawValue): stream collected response should match mock output")
|
||||
#expect(mock.calls.count == 1, "\(provider.rawValue): should have exactly 1 call")
|
||||
#expect(mock.calls.first?.method == "streamResponse",
|
||||
"\(provider.rawValue): should call streamResponse method")
|
||||
}
|
||||
|
||||
// MARK: - 4. Schema Introspection Works Identically Across Providers
|
||||
|
||||
@Test("Schema introspection returns same schema regardless of provider",
|
||||
arguments: TestedProvider.allCases)
|
||||
func schemaIntrospectionIsProviderAgnostic(provider: TestedProvider) async throws {
|
||||
let db = try makeProviderTestDatabase()
|
||||
let mock = makeMock(for: provider, responses: ["SELECT 1"])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
let schema = try await engine.prepareSchema()
|
||||
|
||||
#expect(schema.tableNames.contains("products"),
|
||||
"\(provider.rawValue): schema must include 'products' table")
|
||||
#expect(schema.tableNames.count == 1,
|
||||
"\(provider.rawValue): should have exactly 1 table")
|
||||
|
||||
let table = schema.tables["products"]
|
||||
#expect(table != nil, "\(provider.rawValue): must find products table")
|
||||
#expect(table?.columns.count == 4,
|
||||
"\(provider.rawValue): products table must have 4 columns")
|
||||
}
|
||||
|
||||
// MARK: - 5. Error Handling Consistency
|
||||
|
||||
@Test("All providers handle empty schema consistently",
|
||||
arguments: TestedProvider.allCases)
|
||||
func emptySchemaHandledConsistently(provider: TestedProvider) async throws {
|
||||
let db = try DatabaseQueue(path: ":memory:")
|
||||
let mock = makeMock(for: provider, responses: ["SELECT 1"])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
do {
|
||||
_ = try await engine.send("Show me data")
|
||||
Issue.record("\(provider.rawValue): should throw for empty schema")
|
||||
} catch let error as SwiftDBAIError {
|
||||
#expect(error == .emptySchema,
|
||||
"\(provider.rawValue): must throw .emptySchema for database with no tables")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("All providers reject disallowed SQL operations consistently",
|
||||
arguments: TestedProvider.allCases)
|
||||
func disallowedSQLRejectedConsistently(provider: TestedProvider) async throws {
|
||||
let db = try makeProviderTestDatabase()
|
||||
let mock = makeMock(for: provider, responses: [
|
||||
"DELETE FROM products WHERE id = 1",
|
||||
])
|
||||
|
||||
// Default allowlist is readOnly (SELECT only)
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
do {
|
||||
_ = try await engine.send("Delete the first product")
|
||||
Issue.record("\(provider.rawValue): should reject DELETE when allowlist is readOnly")
|
||||
} catch {
|
||||
// All providers must trigger the same error path for disallowed operations
|
||||
#expect(error is SwiftDBAIError,
|
||||
"\(provider.rawValue): error must be SwiftDBAIError")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - 6. Conversation History Consistency
|
||||
|
||||
@Test("Conversation history works identically for all providers",
|
||||
arguments: TestedProvider.allCases)
|
||||
func conversationHistoryConsistent(provider: TestedProvider) async throws {
|
||||
let db = try makeProviderTestDatabase()
|
||||
// ChatEngine calls LLM for SQL generation, then TextSummaryRenderer
|
||||
// may call LLM for summarization. For aggregate queries (COUNT, AVG),
|
||||
// TextSummaryRenderer uses a template and skips the LLM call.
|
||||
// So the mock sequence is: SQL1, SQL2 (each followed by template summary).
|
||||
let mock = makeMock(for: provider, responses: [
|
||||
"SELECT COUNT(*) FROM products",
|
||||
"SELECT AVG(price) FROM products",
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
|
||||
_ = try await engine.send("How many products?")
|
||||
_ = try await engine.send("What is the average price?")
|
||||
|
||||
let messages = engine.messages
|
||||
#expect(messages.count == 4,
|
||||
"\(provider.rawValue): should have 4 messages (2 user + 2 assistant)")
|
||||
#expect(messages[0].role == .user, "\(provider.rawValue): first message should be user")
|
||||
#expect(messages[1].role == .assistant, "\(provider.rawValue): second message should be assistant")
|
||||
#expect(messages[2].role == .user, "\(provider.rawValue): third message should be user")
|
||||
#expect(messages[3].role == .assistant, "\(provider.rawValue): fourth message should be assistant")
|
||||
|
||||
// Both assistant messages must have SQL
|
||||
#expect(messages[1].sql != nil, "\(provider.rawValue): first response must have SQL")
|
||||
#expect(messages[3].sql != nil, "\(provider.rawValue): second response must have SQL")
|
||||
}
|
||||
|
||||
// MARK: - 7. ProviderConfiguration Roundtrip
|
||||
|
||||
@Test("All cloud provider configurations roundtrip through makeModel()")
|
||||
func allCloudProvidersRoundtrip() {
|
||||
let configs: [(String, ProviderConfiguration)] = [
|
||||
("OpenAI", .openAI(apiKey: "sk-test", model: "gpt-4o")),
|
||||
("OpenAI Responses", .openAI(apiKey: "sk-test", model: "gpt-4o", variant: .responses)),
|
||||
("Anthropic", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514")),
|
||||
("Anthropic+version", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514", apiVersion: "2024-01-01")),
|
||||
("Anthropic+betas", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514", betas: ["computer-use"])),
|
||||
("Gemini", .gemini(apiKey: "AIza-test", model: "gemini-2.0-flash")),
|
||||
("Gemini+version", .gemini(apiKey: "AIza-test", model: "gemini-2.0-flash", apiVersion: "v1")),
|
||||
("OpenAI-Compatible", .openAICompatible(
|
||||
apiKey: "key", model: "model", baseURL: URL(string: "http://localhost:1234")!
|
||||
)),
|
||||
("Ollama", .ollama(model: "llama3.2")),
|
||||
("Ollama+custom URL", .ollama(model: "qwen2.5", baseURL: URL(string: "http://192.168.1.100:11434")!)),
|
||||
("llama.cpp", .llamaCpp(model: "default")),
|
||||
("llama.cpp+custom", .llamaCpp(model: "my-model", baseURL: URL(string: "http://localhost:9090")!)),
|
||||
]
|
||||
|
||||
for (name, config) in configs {
|
||||
let model = config.makeModel()
|
||||
#expect(model.isAvailable, "\(name): model must be available after makeModel()")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("On-device provider configurations produce valid models")
|
||||
func onDeviceProvidersRoundtrip() {
|
||||
let mlxConfigs: [MLXProviderConfiguration] = [
|
||||
.llama3_2_3B(),
|
||||
.qwen2_5_coder_3B(),
|
||||
.phi3_5_mini(),
|
||||
MLXProviderConfiguration(modelId: "custom-model", temperature: 0.2),
|
||||
]
|
||||
|
||||
for mlxConfig in mlxConfigs {
|
||||
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
|
||||
let model = providerConfig.makeModel()
|
||||
#expect(model.isAvailable, "MLX model '\(mlxConfig.modelId)' must be available")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - 8. Write Operation Allowlist Consistency
|
||||
|
||||
@Test("Write operations require explicit opt-in for all providers",
|
||||
arguments: TestedProvider.allCases)
|
||||
func writeOperationsRequireOptIn(provider: TestedProvider) async throws {
|
||||
let db = try makeProviderTestDatabase()
|
||||
|
||||
// Mock returns an INSERT statement
|
||||
let mock = makeMock(for: provider, responses: [
|
||||
"INSERT INTO products (name, price, category) VALUES ('New', 1.00, 'misc')",
|
||||
])
|
||||
|
||||
// readOnly allowlist (default)
|
||||
let readOnlyEngine = ChatEngine(database: db, model: mock)
|
||||
|
||||
do {
|
||||
_ = try await readOnlyEngine.send("Add a new product")
|
||||
Issue.record("\(provider.rawValue): INSERT should be rejected with readOnly allowlist")
|
||||
} catch {
|
||||
#expect(error is SwiftDBAIError,
|
||||
"\(provider.rawValue): must throw SwiftDBAIError for disallowed INSERT")
|
||||
}
|
||||
}
|
||||
|
||||
@Test("Allowed write operations work for all providers",
|
||||
arguments: TestedProvider.allCases)
|
||||
func allowedWriteOperationsWork(provider: TestedProvider) async throws {
|
||||
let db = try makeProviderTestDatabase()
|
||||
|
||||
let mock = makeMock(for: provider, responses: [
|
||||
"INSERT INTO products (name, price, category) VALUES ('NewItem', 1.00, 'misc')",
|
||||
"Successfully added 1 product.",
|
||||
])
|
||||
|
||||
let engine = ChatEngine(
|
||||
database: db,
|
||||
model: mock,
|
||||
allowlist: .standard
|
||||
)
|
||||
|
||||
let response = try await engine.send("Add a product called NewItem")
|
||||
#expect(response.sql?.uppercased().hasPrefix("INSERT") == true,
|
||||
"\(provider.rawValue): SQL should be an INSERT")
|
||||
}
|
||||
|
||||
// MARK: - 9. Response Format Consistency
|
||||
|
||||
@Test("ChatResponse structure is identical regardless of provider",
|
||||
arguments: TestedProvider.allCases)
|
||||
func responseStructureConsistent(provider: TestedProvider) async throws {
|
||||
let db = try makeProviderTestDatabase()
|
||||
let mock = makeMock(for: provider, responses: [
|
||||
"SELECT name, price, category FROM products",
|
||||
"Found 3 products across 2 categories.",
|
||||
])
|
||||
|
||||
let engine = ChatEngine(database: db, model: mock)
|
||||
let response = try await engine.send("Show all products")
|
||||
|
||||
// ChatResponse must always have these properties populated
|
||||
#expect(response.summary.count > 0,
|
||||
"\(provider.rawValue): summary must be non-empty")
|
||||
#expect(response.sql != nil,
|
||||
"\(provider.rawValue): sql must be present")
|
||||
#expect(response.queryResult != nil,
|
||||
"\(provider.rawValue): queryResult must be present")
|
||||
|
||||
// QueryResult structure must match the query
|
||||
let qr = response.queryResult!
|
||||
#expect(qr.columns == ["name", "price", "category"],
|
||||
"\(provider.rawValue): columns must match SELECT clause")
|
||||
#expect(qr.rowCount == 3,
|
||||
"\(provider.rawValue): must return all rows")
|
||||
#expect(qr.sql == "SELECT name, price, category FROM products",
|
||||
"\(provider.rawValue): QueryResult.sql must match executed SQL")
|
||||
#expect(qr.executionTime >= 0,
|
||||
"\(provider.rawValue): execution time must be non-negative")
|
||||
}
|
||||
|
||||
// MARK: - 10. Provider Enum Completeness
|
||||
|
||||
@Test("TestedProvider covers all ProviderConfiguration.Provider cases plus on-device")
|
||||
func testedProviderCoversAllCases() {
|
||||
// ProviderConfiguration.Provider has 6 cases
|
||||
let configProviderCount = ProviderConfiguration.Provider.allCases.count
|
||||
#expect(configProviderCount == 6, "ProviderConfiguration.Provider should have 6 cases")
|
||||
|
||||
// TestedProvider adds on-device for 7 total
|
||||
#expect(TestedProvider.allCases.count == 7, "TestedProvider should cover all 7 provider types")
|
||||
|
||||
// Verify 1:1 mapping for the config providers
|
||||
let configNames = Set(ProviderConfiguration.Provider.allCases.map(\.rawValue))
|
||||
for tested in TestedProvider.allCases where tested != .onDevice {
|
||||
#expect(configNames.contains(tested.rawValue),
|
||||
"\(tested.rawValue) must map to a ProviderConfiguration.Provider case")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - 11. ChatEngine Convenience Init Consistency
|
||||
|
||||
@Test("ChatEngine convenience init with ProviderConfiguration works for all cloud providers")
|
||||
func chatEngineConvenienceInitWorks() throws {
|
||||
let db = try makeProviderTestDatabase()
|
||||
|
||||
let configs: [ProviderConfiguration] = [
|
||||
.openAI(apiKey: "test", model: "gpt-4o"),
|
||||
.anthropic(apiKey: "test", model: "claude-sonnet-4-20250514"),
|
||||
.gemini(apiKey: "test", model: "gemini-2.0-flash"),
|
||||
.openAICompatible(apiKey: "test", model: "m", baseURL: URL(string: "http://localhost:1234")!),
|
||||
.ollama(model: "llama3.2"),
|
||||
.llamaCpp(model: "default"),
|
||||
]
|
||||
|
||||
for config in configs {
|
||||
// This should not throw — it only creates the engine, doesn't call the LLM
|
||||
let engine = ChatEngine(database: db, provider: config)
|
||||
#expect(engine.tableCount == nil, "tableCount should be nil before first query")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - 12. Availability Reporting
|
||||
|
||||
@Test("All real provider models report available by default")
|
||||
func allModelsReportAvailable() {
|
||||
let models: [(String, any LanguageModel)] = [
|
||||
("OpenAI", OpenAILanguageModel(apiKey: "k", model: "m")),
|
||||
("Anthropic", AnthropicLanguageModel(apiKey: "k", model: "m")),
|
||||
("Gemini", GeminiLanguageModel(apiKey: "k", model: "m")),
|
||||
("Ollama", OllamaLanguageModel(model: "m")),
|
||||
]
|
||||
|
||||
for (name, model) in models {
|
||||
#expect(model.isAvailable, "\(name) should be available by default")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - 13. On-Device Pipeline Status
|
||||
|
||||
@Test("On-device inference pipeline starts in notLoaded state")
|
||||
func onDevicePipelineInitialState() {
|
||||
let mlxPipeline = OnDeviceInferencePipeline(
|
||||
mlxConfiguration: .llama3_2_3B()
|
||||
)
|
||||
#expect(mlxPipeline.status == .notLoaded)
|
||||
#expect(mlxPipeline.providerType == .mlx)
|
||||
|
||||
let coreMLPipeline = OnDeviceInferencePipeline(
|
||||
coreMLConfiguration: CoreMLProviderConfiguration(
|
||||
modelURL: URL(fileURLWithPath: "/tmp/test.mlmodelc")
|
||||
)
|
||||
)
|
||||
#expect(coreMLPipeline.status == .notLoaded)
|
||||
#expect(coreMLPipeline.providerType == .coreML)
|
||||
}
|
||||
|
||||
@Test("On-device SQL generation hints are populated for both provider types")
|
||||
func onDeviceSQLHints() {
|
||||
let mlxPipeline = OnDeviceInferencePipeline(mlxConfiguration: .llama3_2_3B())
|
||||
let mlxHints = mlxPipeline.recommendedSQLGenerationHints
|
||||
#expect(mlxHints.maxTokens > 0)
|
||||
#expect(mlxHints.temperature >= 0)
|
||||
#expect(!mlxHints.systemPromptSuffix.isEmpty)
|
||||
|
||||
let coreMLPipeline = OnDeviceInferencePipeline(
|
||||
coreMLConfiguration: CoreMLProviderConfiguration(
|
||||
modelURL: URL(fileURLWithPath: "/tmp/test.mlmodelc")
|
||||
)
|
||||
)
|
||||
let coreMLHints = coreMLPipeline.recommendedSQLGenerationHints
|
||||
#expect(coreMLHints.maxTokens > 0)
|
||||
#expect(coreMLHints.temperature >= 0)
|
||||
#expect(!coreMLHints.systemPromptSuffix.isEmpty)
|
||||
}
|
||||
}
|
||||
489
Tests/SwiftDBAITests/ViewInspectorTests.swift
Normal file
489
Tests/SwiftDBAITests/ViewInspectorTests.swift
Normal file
@@ -0,0 +1,489 @@
|
||||
// ViewInspectorTests.swift
|
||||
// SwiftDBAITests
|
||||
//
|
||||
// ViewInspector-based tests for SwiftDBAI's SwiftUI views.
|
||||
// Tests content and structure of MessageBubbleView, ErrorMessageView,
|
||||
// ScrollableDataTableView, ChatViewConfiguration, and BarChartView.
|
||||
|
||||
import SwiftUI
|
||||
import Testing
|
||||
import ViewInspector
|
||||
@testable import SwiftDBAI
|
||||
|
||||
// MARK: - Test Helpers
|
||||
|
||||
/// Helper to build a DataTable for tests.
|
||||
private func makeDataTable(
|
||||
columnNames: [String] = ["id", "name", "score"],
|
||||
inferredTypes: [DataTable.InferredType] = [.integer, .text, .real],
|
||||
rowCount: Int = 3
|
||||
) -> DataTable {
|
||||
let columns = columnNames.enumerated().map { idx, name in
|
||||
DataTable.Column(name: name, index: idx, inferredType: inferredTypes[idx])
|
||||
}
|
||||
let rows = (0..<rowCount).map { i in
|
||||
DataTable.Row(
|
||||
id: i,
|
||||
values: [
|
||||
.integer(Int64(i + 1)),
|
||||
.text("Item \(i + 1)"),
|
||||
.real(Double(i) * 10.5),
|
||||
],
|
||||
columnNames: columnNames
|
||||
)
|
||||
}
|
||||
return DataTable(columns: columns, rows: rows, sql: "SELECT * FROM test", executionTime: 0.015)
|
||||
}
|
||||
|
||||
/// Helper to build a QueryResult for tests.
|
||||
private func makeQueryResult(
|
||||
columns: [String] = ["id", "name"],
|
||||
rowCount: Int = 2
|
||||
) -> QueryResult {
|
||||
let rows: [[String: QueryResult.Value]] = (0..<rowCount).map { i in
|
||||
["id": .integer(Int64(i + 1)), "name": .text("User \(i + 1)")]
|
||||
}
|
||||
return QueryResult(
|
||||
columns: columns,
|
||||
rows: rows,
|
||||
sql: "SELECT id, name FROM users",
|
||||
executionTime: 0.01
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - MessageBubbleView Tests
|
||||
|
||||
@Suite("MessageBubbleView - ViewInspector")
|
||||
struct MessageBubbleViewInspectorTests {
|
||||
|
||||
@Test("User message bubble renders the user text")
|
||||
@MainActor
|
||||
func userMessageShowsText() throws {
|
||||
let message = ChatMessage(role: .user, content: "Show me all users")
|
||||
let view = MessageBubbleView(message: message)
|
||||
let inspected = try view.inspect()
|
||||
let found = try inspected.find(text: "Show me all users")
|
||||
#expect(try found.string() == "Show me all users")
|
||||
}
|
||||
|
||||
@Test("Assistant message renders summary text")
|
||||
@MainActor
|
||||
func assistantMessageShowsSummary() throws {
|
||||
let message = ChatMessage(
|
||||
role: .assistant,
|
||||
content: "Found 42 users in the database."
|
||||
)
|
||||
let view = MessageBubbleView(message: message)
|
||||
let inspected = try view.inspect()
|
||||
let found = try inspected.find(text: "Found 42 users in the database.")
|
||||
#expect(try found.string() == "Found 42 users in the database.")
|
||||
}
|
||||
|
||||
@Test("Assistant message with SQL shows disclosure group")
|
||||
@MainActor
|
||||
func assistantMessageWithSQLShowsDisclosure() throws {
|
||||
let message = ChatMessage(
|
||||
role: .assistant,
|
||||
content: "Here are the results.",
|
||||
sql: "SELECT * FROM users"
|
||||
)
|
||||
let view = MessageBubbleView(message: message)
|
||||
let inspected = try view.inspect()
|
||||
// The SQL disclosure contains "SQL Query" label text
|
||||
let sqlLabel = try inspected.find(text: "SQL Query")
|
||||
#expect(try sqlLabel.string() == "SQL Query")
|
||||
}
|
||||
|
||||
@Test("Error message renders error text")
|
||||
@MainActor
|
||||
func errorMessageShowsText() throws {
|
||||
let error = SwiftDBAIError.databaseError(reason: "connection lost")
|
||||
let message = ChatMessage(
|
||||
role: .error,
|
||||
content: error.localizedDescription,
|
||||
error: error
|
||||
)
|
||||
let view = MessageBubbleView(message: message)
|
||||
let inspected = try view.inspect()
|
||||
// The error message text should be present
|
||||
let found = try inspected.find(text: error.localizedDescription)
|
||||
#expect(try found.string() == error.localizedDescription)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - ErrorMessageView Tests
|
||||
|
||||
@Suite("ErrorMessageView - ViewInspector")
|
||||
struct ErrorMessageViewInspectorTests {
|
||||
|
||||
@Test("Safety error shows Operation Blocked title")
|
||||
@MainActor
|
||||
func safetyErrorShowsTitle() throws {
|
||||
let error = SwiftDBAIError.dangerousOperationBlocked(keyword: "DROP")
|
||||
let view = ErrorMessageView(error: error)
|
||||
let inspected = try view.inspect()
|
||||
let title = try inspected.find(text: "Operation Blocked")
|
||||
#expect(try title.string() == "Operation Blocked")
|
||||
}
|
||||
|
||||
@Test("Safety error shows error message")
|
||||
@MainActor
|
||||
func safetyErrorShowsMessage() throws {
|
||||
let error = SwiftDBAIError.operationNotAllowed(operation: "DELETE")
|
||||
let view = ErrorMessageView(error: error)
|
||||
let inspected = try view.inspect()
|
||||
let msg = try inspected.find(text: error.localizedDescription)
|
||||
#expect(try msg.string() == error.localizedDescription)
|
||||
}
|
||||
|
||||
@Test("LLM response unparseable error shows recovery hint")
|
||||
@MainActor
|
||||
func parsingErrorShowsRecoveryHint() throws {
|
||||
let error = SwiftDBAIError.llmResponseUnparseable(response: "gibberish")
|
||||
let view = ErrorMessageView(error: error)
|
||||
let inspected = try view.inspect()
|
||||
let hint = try inspected.find(text: "Try rephrasing your question.")
|
||||
#expect(try hint.string() == "Try rephrasing your question.")
|
||||
}
|
||||
|
||||
@Test("Database error shows Database Error title")
|
||||
@MainActor
|
||||
func databaseErrorShowsTitle() throws {
|
||||
let error = SwiftDBAIError.databaseError(reason: "disk full")
|
||||
let view = ErrorMessageView(error: error)
|
||||
let inspected = try view.inspect()
|
||||
let title = try inspected.find(text: "Database Error")
|
||||
#expect(try title.string() == "Database Error")
|
||||
}
|
||||
|
||||
@Test("LLM timeout shows AI Provider Error title and recovery hint")
|
||||
@MainActor
|
||||
func timeoutErrorShowsTitleAndHint() throws {
|
||||
let error = SwiftDBAIError.llmTimeout(seconds: 30)
|
||||
let view = ErrorMessageView(error: error)
|
||||
let inspected = try view.inspect()
|
||||
let title = try inspected.find(text: "AI Provider Error")
|
||||
#expect(try title.string() == "AI Provider Error")
|
||||
let hint = try inspected.find(text: "The AI took too long. Try a simpler question.")
|
||||
#expect(try hint.string() == "The AI took too long. Try a simpler question.")
|
||||
}
|
||||
|
||||
@Test("LLM failure error shows AI Provider Error title")
|
||||
@MainActor
|
||||
func llmFailureShowsTitle() throws {
|
||||
let error = SwiftDBAIError.llmFailure(reason: "rate limited")
|
||||
let view = ErrorMessageView(error: error)
|
||||
let inspected = try view.inspect()
|
||||
let title = try inspected.find(text: "AI Provider Error")
|
||||
#expect(try title.string() == "AI Provider Error")
|
||||
}
|
||||
|
||||
@Test("Generic error from plain string shows message text")
|
||||
@MainActor
|
||||
func genericStringErrorShowsMessage() throws {
|
||||
let view = ErrorMessageView(message: "Something went wrong")
|
||||
let inspected = try view.inspect()
|
||||
let msg = try inspected.find(text: "Something went wrong")
|
||||
#expect(try msg.string() == "Something went wrong")
|
||||
}
|
||||
|
||||
@Test("Recoverable error with retry shows retry button")
|
||||
@MainActor
|
||||
func recoverableErrorShowsRetryButton() throws {
|
||||
let error = SwiftDBAIError.noSQLGenerated
|
||||
let view = ErrorMessageView(error: error, onRetry: { })
|
||||
let inspected = try view.inspect()
|
||||
let button = try inspected.find(text: "Try Again")
|
||||
#expect(try button.string() == "Try Again")
|
||||
}
|
||||
|
||||
@Test("LLM error with retry shows Retry button")
|
||||
@MainActor
|
||||
func llmErrorShowsRetryButton() throws {
|
||||
let error = SwiftDBAIError.llmFailure(reason: "timeout")
|
||||
let view = ErrorMessageView(error: error, onRetry: { })
|
||||
let inspected = try view.inspect()
|
||||
let button = try inspected.find(text: "Retry")
|
||||
#expect(try button.string() == "Retry")
|
||||
}
|
||||
|
||||
@Test("Query timed out shows Database Error title and recovery hint")
|
||||
@MainActor
|
||||
func queryTimedOutShowsTitleAndHint() throws {
|
||||
let error = SwiftDBAIError.queryTimedOut(seconds: 10)
|
||||
let view = ErrorMessageView(error: error)
|
||||
let inspected = try view.inspect()
|
||||
let title = try inspected.find(text: "Database Error")
|
||||
#expect(try title.string() == "Database Error")
|
||||
let hint = try inspected.find(text: "Try a simpler query or add database indexes.")
|
||||
#expect(try hint.string() == "Try a simpler query or add database indexes.")
|
||||
}
|
||||
|
||||
@Test("Empty schema error shows Database Error title and recovery hint")
|
||||
@MainActor
|
||||
func emptySchemaShowsTitleAndHint() throws {
|
||||
let error = SwiftDBAIError.emptySchema
|
||||
let view = ErrorMessageView(error: error)
|
||||
let inspected = try view.inspect()
|
||||
let title = try inspected.find(text: "Database Error")
|
||||
#expect(try title.string() == "Database Error")
|
||||
let hint = try inspected.find(text: "Add some tables to your database first.")
|
||||
#expect(try hint.string() == "Add some tables to your database first.")
|
||||
}
|
||||
|
||||
@Test("Configuration error shows Configuration Error title")
|
||||
@MainActor
|
||||
func configurationErrorShowsTitle() throws {
|
||||
let error = SwiftDBAIError.configurationError(reason: "missing API key")
|
||||
let view = ErrorMessageView(error: error)
|
||||
let inspected = try view.inspect()
|
||||
let title = try inspected.find(text: "Configuration Error")
|
||||
#expect(try title.string() == "Configuration Error")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - ChatViewConfiguration Tests
|
||||
|
||||
@Suite("ChatViewConfiguration - ViewInspector")
|
||||
struct ChatViewConfigurationInspectorTests {
|
||||
|
||||
@Test("Dark configuration has expected color values")
|
||||
func darkConfigHasCorrectColors() {
|
||||
let dark = ChatViewConfiguration.dark
|
||||
#expect(dark.userTextColor == .white)
|
||||
#expect(dark.backgroundColor == .black)
|
||||
#expect(dark.accentColor == .blue)
|
||||
}
|
||||
|
||||
@Test("Default configuration has expected placeholder and empty state text")
|
||||
func defaultConfigHasExpectedText() {
|
||||
let config = ChatViewConfiguration.default
|
||||
#expect(config.inputPlaceholder == "Ask about your data\u{2026}")
|
||||
#expect(config.emptyStateTitle == "Ask a question about your data")
|
||||
#expect(config.emptyStateSubtitle == "Try something like \"How many records are in the database?\"")
|
||||
}
|
||||
|
||||
@Test("Custom inputPlaceholder propagates through environment")
|
||||
@MainActor
|
||||
func customPlaceholderInEnvironment() throws {
|
||||
var config = ChatViewConfiguration.default
|
||||
config.inputPlaceholder = "Ask about recipes..."
|
||||
config.emptyStateTitle = "Recipe Search"
|
||||
|
||||
// Verify the configuration values are set correctly
|
||||
#expect(config.inputPlaceholder == "Ask about recipes...")
|
||||
#expect(config.emptyStateTitle == "Recipe Search")
|
||||
}
|
||||
|
||||
@Test("Compact configuration has smaller padding and hidden SQL disclosure")
|
||||
func compactConfigProperties() {
|
||||
let compact = ChatViewConfiguration.compact
|
||||
#expect(compact.messagePadding == 8)
|
||||
#expect(compact.bubbleCornerRadius == 10)
|
||||
#expect(compact.showSQLDisclosure == false)
|
||||
#expect(compact.showTimestamps == false)
|
||||
}
|
||||
|
||||
@Test("Dark configuration userBubbleColor is dark gray")
|
||||
func darkConfigUserBubble() {
|
||||
let dark = ChatViewConfiguration.dark
|
||||
// Dark config uses Color(white: 0.25) for user bubble
|
||||
#expect(dark.userBubbleColor == Color(white: 0.25))
|
||||
#expect(dark.assistantBubbleColor == Color(white: 0.15))
|
||||
#expect(dark.inputBarBackgroundColor == Color(white: 0.1))
|
||||
}
|
||||
|
||||
@Test("ErrorMessageView uses environment config for database error color")
|
||||
@MainActor
|
||||
func errorViewUsesDarkConfig() throws {
|
||||
let error = SwiftDBAIError.databaseError(reason: "test error")
|
||||
let view = ErrorMessageView(error: error)
|
||||
.chatViewConfiguration(.dark)
|
||||
let inspected = try view.inspect()
|
||||
// Should still render the error message text
|
||||
let msg = try inspected.find(text: error.localizedDescription)
|
||||
#expect(try msg.string() == error.localizedDescription)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - ScrollableDataTableView Tests
|
||||
|
||||
@Suite("ScrollableDataTableView - ViewInspector")
|
||||
struct ScrollableDataTableViewInspectorTests {
|
||||
|
||||
@Test("Column headers appear in the view")
|
||||
@MainActor
|
||||
func columnHeadersAppear() throws {
|
||||
let table = makeDataTable()
|
||||
let view = ScrollableDataTableView(dataTable: table)
|
||||
let inspected = try view.inspect()
|
||||
|
||||
// Each column header should be present
|
||||
let idHeader = try inspected.find(text: "id")
|
||||
#expect(try idHeader.string() == "id")
|
||||
|
||||
let nameHeader = try inspected.find(text: "name")
|
||||
#expect(try nameHeader.string() == "name")
|
||||
|
||||
let scoreHeader = try inspected.find(text: "score")
|
||||
#expect(try scoreHeader.string() == "score")
|
||||
}
|
||||
|
||||
@Test("Row count text appears in footer")
|
||||
@MainActor
|
||||
func rowCountFooterAppears() throws {
|
||||
let table = makeDataTable(rowCount: 5)
|
||||
let view = ScrollableDataTableView(dataTable: table, showFooter: true)
|
||||
let inspected = try view.inspect()
|
||||
|
||||
let footer = try inspected.find(text: "5 rows")
|
||||
#expect(try footer.string() == "5 rows")
|
||||
}
|
||||
|
||||
@Test("Single row shows singular 'row' text")
|
||||
@MainActor
|
||||
func singleRowFooter() throws {
|
||||
let table = makeDataTable(rowCount: 1)
|
||||
let view = ScrollableDataTableView(dataTable: table, showFooter: true)
|
||||
let inspected = try view.inspect()
|
||||
|
||||
let footer = try inspected.find(text: "1 row")
|
||||
#expect(try footer.string() == "1 row")
|
||||
}
|
||||
|
||||
@Test("Empty table shows No results text")
|
||||
@MainActor
|
||||
func emptyTableShowsNoResults() throws {
|
||||
let table = DataTable(columns: [], rows: [], sql: "", executionTime: 0)
|
||||
let view = ScrollableDataTableView(dataTable: table)
|
||||
let inspected = try view.inspect()
|
||||
|
||||
let empty = try inspected.find(text: "No results")
|
||||
#expect(try empty.string() == "No results")
|
||||
}
|
||||
|
||||
@Test("Execution time appears in footer when > 0")
|
||||
@MainActor
|
||||
func executionTimeAppearsInFooter() throws {
|
||||
let columns = [DataTable.Column(name: "val", index: 0, inferredType: .integer)]
|
||||
let rows = [DataTable.Row(id: 0, values: [.integer(1)], columnNames: ["val"])]
|
||||
let table = DataTable(columns: columns, rows: rows, sql: "SELECT 1", executionTime: 0.023)
|
||||
let view = ScrollableDataTableView(dataTable: table, showFooter: true)
|
||||
let inspected = try view.inspect()
|
||||
|
||||
let timing = try inspected.find(text: "23.0 ms")
|
||||
#expect(try timing.string() == "23.0 ms")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - BarChartView Tests
|
||||
|
||||
@Suite("BarChartView - ViewInspector")
|
||||
struct BarChartViewInspectorTests {
|
||||
|
||||
@Test("BarChartView with title renders the title text")
|
||||
@MainActor
|
||||
func barChartShowsTitle() throws {
|
||||
let columns: [DataTable.Column] = [
|
||||
.init(name: "dept", index: 0, inferredType: .text),
|
||||
.init(name: "revenue", index: 1, inferredType: .real),
|
||||
]
|
||||
let rows: [DataTable.Row] = [
|
||||
.init(id: 0, values: [.text("Sales"), .real(100.0)], columnNames: ["dept", "revenue"]),
|
||||
.init(id: 1, values: [.text("Eng"), .real(200.0)], columnNames: ["dept", "revenue"]),
|
||||
]
|
||||
let table = DataTable(columns: columns, rows: rows)
|
||||
|
||||
let view = BarChartView(
|
||||
dataTable: table,
|
||||
categoryColumn: "dept",
|
||||
valueColumn: "revenue",
|
||||
title: "Revenue by Department"
|
||||
)
|
||||
let inspected = try view.inspect()
|
||||
let title = try inspected.find(text: "Revenue by Department")
|
||||
#expect(try title.string() == "Revenue by Department")
|
||||
}
|
||||
|
||||
@Test("BarChartView with empty data shows empty state")
|
||||
@MainActor
|
||||
func barChartEmptyState() throws {
|
||||
let table = DataTable(columns: [], rows: [])
|
||||
let view = BarChartView(
|
||||
dataTable: table,
|
||||
categoryColumn: "x",
|
||||
valueColumn: "y"
|
||||
)
|
||||
let inspected = try view.inspect()
|
||||
let empty = try inspected.find(text: "No chartable data")
|
||||
#expect(try empty.string() == "No chartable data")
|
||||
}
|
||||
|
||||
@Test("BarChartView with truncated data shows truncation notice")
|
||||
@MainActor
|
||||
func barChartTruncationNotice() throws {
|
||||
let columns: [DataTable.Column] = [
|
||||
.init(name: "cat", index: 0, inferredType: .text),
|
||||
.init(name: "val", index: 1, inferredType: .real),
|
||||
]
|
||||
// Create 10 rows but set maxBars to 3
|
||||
let rows: [DataTable.Row] = (0..<10).map { i in
|
||||
.init(id: i, values: [.text("Cat \(i)"), .real(Double(i) * 10)], columnNames: ["cat", "val"])
|
||||
}
|
||||
let table = DataTable(columns: columns, rows: rows)
|
||||
|
||||
let view = BarChartView(
|
||||
dataTable: table,
|
||||
categoryColumn: "cat",
|
||||
valueColumn: "val",
|
||||
maxBars: 3
|
||||
)
|
||||
let inspected = try view.inspect()
|
||||
let notice = try inspected.find(text: "Showing 3 of 10 categories")
|
||||
#expect(try notice.string() == "Showing 3 of 10 categories")
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - PieChartView Tests
|
||||
|
||||
@Suite("PieChartView - ViewInspector")
|
||||
struct PieChartViewInspectorTests {
|
||||
|
||||
@Test("PieChartView with title renders the title text")
|
||||
@MainActor
|
||||
func pieChartShowsTitle() throws {
|
||||
let columns: [DataTable.Column] = [
|
||||
.init(name: "status", index: 0, inferredType: .text),
|
||||
.init(name: "count", index: 1, inferredType: .integer),
|
||||
]
|
||||
let rows: [DataTable.Row] = [
|
||||
.init(id: 0, values: [.text("Active"), .integer(40)], columnNames: ["status", "count"]),
|
||||
.init(id: 1, values: [.text("Inactive"), .integer(10)], columnNames: ["status", "count"]),
|
||||
]
|
||||
let table = DataTable(columns: columns, rows: rows)
|
||||
|
||||
let view = PieChartView(
|
||||
dataTable: table,
|
||||
categoryColumn: "status",
|
||||
valueColumn: "count",
|
||||
title: "Users by Status"
|
||||
)
|
||||
let inspected = try view.inspect()
|
||||
let title = try inspected.find(text: "Users by Status")
|
||||
#expect(try title.string() == "Users by Status")
|
||||
}
|
||||
|
||||
@Test("PieChartView with empty data shows empty state")
|
||||
@MainActor
|
||||
func pieChartEmptyState() throws {
|
||||
let table = DataTable(columns: [], rows: [])
|
||||
let view = PieChartView(
|
||||
dataTable: table,
|
||||
categoryColumn: "x",
|
||||
valueColumn: "y"
|
||||
)
|
||||
let inspected = try view.inspect()
|
||||
let empty = try inspected.find(text: "No chartable data")
|
||||
#expect(try empty.string() == "No chartable data")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user