Chat with any SQLite database using natural language. Built on AnyLanguageModel (HuggingFace) for LLM-agnostic provider support and GRDB for SQLite access. Core features: - Auto schema introspection from sqlite_master (zero config) - NL → SQL generation via any AnyLanguageModel provider - Three rendering modes: text summary, data table, Swift Charts - Drop-in DataChatView (SwiftUI) and headless ChatEngine - Operation allowlist with read-only default - Mutation policy with per-table control - ToolExecutionDelegate for destructive operation confirmation - Multi-turn conversation context - 352 tests across 24 suites, all passing Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
618 lines
26 KiB
Swift
618 lines
26 KiB
Swift
// UnifiedProviderTestHarness.swift
|
|
// SwiftDBAI Tests
|
|
//
|
|
// A unified test harness that validates all seven provider types
|
|
// conform to the AnyLanguageModel protocol and produce consistent
|
|
// ChatEngine-compatible output. Covers: OpenAI, Anthropic, Gemini,
|
|
// OpenAI-Compatible, Ollama, llama.cpp, and on-device (MLX/CoreML).
|
|
|
|
import AnyLanguageModel
|
|
import Foundation
|
|
import GRDB
|
|
import Testing
|
|
|
|
@testable import SwiftDBAI
|
|
|
|
// MARK: - Provider-Simulating Mock Models
|
|
|
|
/// A mock that records which LanguageModel protocol methods were called,
|
|
/// the arguments passed, and returns configurable responses.
|
|
/// Used to validate that every provider path through ChatEngine
|
|
/// exercises the same protocol surface.
|
|
final class ProviderConformanceMock: LanguageModel, @unchecked Sendable {
|
|
typealias UnavailableReason = Never
|
|
|
|
/// Track calls to verify protocol conformance exercised fully.
|
|
struct CallRecord: Sendable {
|
|
let method: String
|
|
let promptDescription: String
|
|
let timestamp: Date
|
|
}
|
|
|
|
private let lock = NSLock()
|
|
private var _calls: [CallRecord] = []
|
|
private let _responses: [String]
|
|
private var _callIndex = 0
|
|
|
|
/// Label for diagnostics.
|
|
let providerName: String
|
|
|
|
var calls: [CallRecord] {
|
|
lock.lock()
|
|
defer { lock.unlock() }
|
|
return _calls
|
|
}
|
|
|
|
init(providerName: String, responses: [String]) {
|
|
self.providerName = providerName
|
|
self._responses = responses
|
|
}
|
|
|
|
private func nextResponse() -> String {
|
|
lock.lock()
|
|
defer { lock.unlock() }
|
|
let idx = _callIndex
|
|
_callIndex += 1
|
|
return idx < _responses.count ? _responses[idx] : "fallback response"
|
|
}
|
|
|
|
private func recordCall(method: String, prompt: String) {
|
|
lock.lock()
|
|
_calls.append(CallRecord(method: method, promptDescription: prompt, timestamp: Date()))
|
|
lock.unlock()
|
|
}
|
|
|
|
func respond<Content>(
|
|
within session: LanguageModelSession,
|
|
to prompt: Prompt,
|
|
generating type: Content.Type,
|
|
includeSchemaInPrompt: Bool,
|
|
options: GenerationOptions
|
|
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
|
|
recordCall(method: "respond", prompt: prompt.description)
|
|
let text = nextResponse()
|
|
let rawContent = GeneratedContent(kind: .string(text))
|
|
let content = try Content(rawContent)
|
|
return LanguageModelSession.Response(
|
|
content: content,
|
|
rawContent: rawContent,
|
|
transcriptEntries: [][...]
|
|
)
|
|
}
|
|
|
|
func streamResponse<Content>(
|
|
within session: LanguageModelSession,
|
|
to prompt: Prompt,
|
|
generating type: Content.Type,
|
|
includeSchemaInPrompt: Bool,
|
|
options: GenerationOptions
|
|
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
|
|
recordCall(method: "streamResponse", prompt: prompt.description)
|
|
let text = nextResponse()
|
|
let rawContent = GeneratedContent(kind: .string(text))
|
|
let content = try! Content(rawContent)
|
|
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
|
|
}
|
|
}
|
|
|
|
// MARK: - Test Database Helper
|
|
|
|
/// Creates a minimal in-memory database for provider integration tests.
|
|
private func makeProviderTestDatabase() throws -> DatabaseQueue {
|
|
let db = try DatabaseQueue(path: ":memory:")
|
|
try db.write { db in
|
|
try db.execute(sql: """
|
|
CREATE TABLE products (
|
|
id INTEGER PRIMARY KEY,
|
|
name TEXT NOT NULL,
|
|
price REAL NOT NULL,
|
|
category TEXT NOT NULL
|
|
)
|
|
""")
|
|
try db.execute(sql: """
|
|
INSERT INTO products (name, price, category) VALUES
|
|
('Widget', 9.99, 'tools'),
|
|
('Gadget', 24.99, 'electronics'),
|
|
('Doohickey', 4.50, 'tools')
|
|
""")
|
|
}
|
|
return db
|
|
}
|
|
|
|
// MARK: - Unified Provider Test Harness
|
|
|
|
@Suite("Unified Provider Test Harness")
|
|
struct UnifiedProviderTestHarness {
|
|
|
|
// MARK: - Provider Configuration Enumeration
|
|
|
|
/// All seven provider types that SwiftDBAI supports.
|
|
enum TestedProvider: String, CaseIterable {
|
|
case openAI
|
|
case anthropic
|
|
case gemini
|
|
case openAICompatible
|
|
case ollama
|
|
case llamaCpp
|
|
case onDevice
|
|
}
|
|
|
|
/// Creates a ProviderConformanceMock simulating each provider type.
|
|
private func makeMock(for provider: TestedProvider, responses: [String]) -> ProviderConformanceMock {
|
|
ProviderConformanceMock(providerName: provider.rawValue, responses: responses)
|
|
}
|
|
|
|
// MARK: - 1. Protocol Conformance — All Providers Are LanguageModel
|
|
|
|
@Test("All provider types produce instances conforming to LanguageModel protocol")
|
|
func allProvidersConformToLanguageModel() {
|
|
// Cloud providers via ProviderConfiguration.makeModel()
|
|
let openAI = ProviderConfiguration.openAI(apiKey: "test-key", model: "gpt-4o").makeModel()
|
|
let anthropic = ProviderConfiguration.anthropic(apiKey: "test-key", model: "claude-sonnet-4-20250514").makeModel()
|
|
let gemini = ProviderConfiguration.gemini(apiKey: "test-key", model: "gemini-2.0-flash").makeModel()
|
|
let openAICompatible = ProviderConfiguration.openAICompatible(
|
|
apiKey: "test-key",
|
|
model: "local-model",
|
|
baseURL: URL(string: "http://localhost:8080/v1/")!
|
|
).makeModel()
|
|
let ollama = ProviderConfiguration.ollama(model: "llama3.2").makeModel()
|
|
let llamaCpp = ProviderConfiguration.llamaCpp(model: "default").makeModel()
|
|
// On-device MLX (wraps as openAICompatible internally)
|
|
let onDeviceMLX = ProviderConfiguration.onDeviceMLX(
|
|
MLXProviderConfiguration(modelId: "test-model")
|
|
).makeModel()
|
|
|
|
// Verify all are LanguageModel
|
|
let models: [(String, any LanguageModel)] = [
|
|
("OpenAI", openAI),
|
|
("Anthropic", anthropic),
|
|
("Gemini", gemini),
|
|
("OpenAI-Compatible", openAICompatible),
|
|
("Ollama", ollama),
|
|
("llama.cpp", llamaCpp),
|
|
("On-Device MLX", onDeviceMLX),
|
|
]
|
|
|
|
for (name, model) in models {
|
|
// Protocol conformance is compile-time, but we verify isAvailable works
|
|
#expect(model.isAvailable, "\(name) model should report as available")
|
|
}
|
|
}
|
|
|
|
@Test("All provider configurations produce correct concrete model types")
|
|
func providerConfigurationsProduceCorrectTypes() {
|
|
let openAI = ProviderConfiguration.openAI(apiKey: "k", model: "m").makeModel()
|
|
#expect(openAI is OpenAILanguageModel, "OpenAI config should produce OpenAILanguageModel")
|
|
|
|
let anthropic = ProviderConfiguration.anthropic(apiKey: "k", model: "m").makeModel()
|
|
#expect(anthropic is AnthropicLanguageModel, "Anthropic config should produce AnthropicLanguageModel")
|
|
|
|
let gemini = ProviderConfiguration.gemini(apiKey: "k", model: "m").makeModel()
|
|
#expect(gemini is GeminiLanguageModel, "Gemini config should produce GeminiLanguageModel")
|
|
|
|
let openAICompat = ProviderConfiguration.openAICompatible(
|
|
apiKey: "k", model: "m", baseURL: URL(string: "http://localhost:1234")!
|
|
).makeModel()
|
|
#expect(openAICompat is OpenAILanguageModel, "OpenAI-Compatible config should produce OpenAILanguageModel")
|
|
|
|
let ollama = ProviderConfiguration.ollama(model: "m").makeModel()
|
|
#expect(ollama is OllamaLanguageModel, "Ollama config should produce OllamaLanguageModel")
|
|
|
|
let llamaCpp = ProviderConfiguration.llamaCpp(model: "m").makeModel()
|
|
#expect(llamaCpp is OpenAILanguageModel, "llama.cpp config should produce OpenAILanguageModel (OpenAI-compatible)")
|
|
|
|
// On-device uses OpenAILanguageModel internally as a wrapper
|
|
let onDevice = ProviderConfiguration.onDeviceMLX(
|
|
MLXProviderConfiguration(modelId: "test")
|
|
).makeModel()
|
|
#expect(onDevice is OpenAILanguageModel, "On-device MLX config should produce OpenAILanguageModel wrapper")
|
|
}
|
|
|
|
// MARK: - 2. Consistent ChatEngine-Compatible Output
|
|
|
|
@Test("Every provider mock produces valid ChatEngine responses for SELECT queries",
|
|
arguments: TestedProvider.allCases)
|
|
func providerProducesValidChatEngineResponse(provider: TestedProvider) async throws {
|
|
let db = try makeProviderTestDatabase()
|
|
let mock = makeMock(for: provider, responses: [
|
|
"SELECT COUNT(*) FROM products", // SQL generation
|
|
"There are 3 products in the database.", // Summary (fallback)
|
|
])
|
|
|
|
let engine = ChatEngine(database: db, model: mock)
|
|
let response = try await engine.send("How many products are there?")
|
|
|
|
// All providers must produce:
|
|
// 1. Non-empty summary
|
|
#expect(!response.summary.isEmpty, "\(provider.rawValue): summary must not be empty")
|
|
|
|
// 2. Valid SQL that was executed
|
|
#expect(response.sql == "SELECT COUNT(*) FROM products",
|
|
"\(provider.rawValue): SQL must match generated query")
|
|
|
|
// 3. A QueryResult with data
|
|
#expect(response.queryResult != nil, "\(provider.rawValue): queryResult must exist")
|
|
#expect(response.queryResult?.rowCount == 1, "\(provider.rawValue): should have 1 row for COUNT")
|
|
}
|
|
|
|
@Test("Every provider mock produces valid ChatEngine responses for multi-row SELECT",
|
|
arguments: TestedProvider.allCases)
|
|
func providerProducesMultiRowResponse(provider: TestedProvider) async throws {
|
|
let db = try makeProviderTestDatabase()
|
|
let mock = makeMock(for: provider, responses: [
|
|
"SELECT name, price FROM products ORDER BY price DESC",
|
|
"Here are the products sorted by price.",
|
|
])
|
|
|
|
let engine = ChatEngine(database: db, model: mock)
|
|
let response = try await engine.send("List products by price")
|
|
|
|
#expect(response.queryResult != nil, "\(provider.rawValue): queryResult must exist")
|
|
#expect(response.queryResult?.rowCount == 3, "\(provider.rawValue): should return all 3 products")
|
|
#expect(response.queryResult?.columns.contains("name") == true,
|
|
"\(provider.rawValue): columns must include 'name'")
|
|
#expect(response.queryResult?.columns.contains("price") == true,
|
|
"\(provider.rawValue): columns must include 'price'")
|
|
}
|
|
|
|
// MARK: - 3. Consistent LanguageModelSession Integration
|
|
|
|
@Test("Every provider mock works through LanguageModelSession.respond(to:)",
|
|
arguments: TestedProvider.allCases)
|
|
func providerWorksWithSession(provider: TestedProvider) async throws {
|
|
let mock = makeMock(for: provider, responses: [
|
|
"SELECT 1 AS test",
|
|
])
|
|
|
|
let session = LanguageModelSession(
|
|
model: mock,
|
|
instructions: "You are a SQL assistant."
|
|
)
|
|
|
|
let response = try await session.respond(to: "Generate a test query")
|
|
|
|
// Verify the response content is the expected string
|
|
#expect(response.content == "SELECT 1 AS test",
|
|
"\(provider.rawValue): session response should match mock output")
|
|
|
|
// Verify the mock received the call
|
|
#expect(mock.calls.count == 1, "\(provider.rawValue): should have exactly 1 call")
|
|
#expect(mock.calls.first?.method == "respond",
|
|
"\(provider.rawValue): should call respond method")
|
|
}
|
|
|
|
@Test("Every provider mock works through LanguageModelSession.streamResponse(to:)",
|
|
arguments: TestedProvider.allCases)
|
|
func providerWorksWithStreamSession(provider: TestedProvider) async throws {
|
|
let mock = makeMock(for: provider, responses: [
|
|
"SELECT 42 AS answer",
|
|
])
|
|
|
|
let session = LanguageModelSession(
|
|
model: mock,
|
|
instructions: "You are a SQL assistant."
|
|
)
|
|
|
|
let stream = session.streamResponse(to: "Give me a number")
|
|
let collected = try await stream.collect()
|
|
|
|
#expect(collected.content == "SELECT 42 AS answer",
|
|
"\(provider.rawValue): stream collected response should match mock output")
|
|
#expect(mock.calls.count == 1, "\(provider.rawValue): should have exactly 1 call")
|
|
#expect(mock.calls.first?.method == "streamResponse",
|
|
"\(provider.rawValue): should call streamResponse method")
|
|
}
|
|
|
|
// MARK: - 4. Schema Introspection Works Identically Across Providers
|
|
|
|
@Test("Schema introspection returns same schema regardless of provider",
|
|
arguments: TestedProvider.allCases)
|
|
func schemaIntrospectionIsProviderAgnostic(provider: TestedProvider) async throws {
|
|
let db = try makeProviderTestDatabase()
|
|
let mock = makeMock(for: provider, responses: ["SELECT 1"])
|
|
|
|
let engine = ChatEngine(database: db, model: mock)
|
|
let schema = try await engine.prepareSchema()
|
|
|
|
#expect(schema.tableNames.contains("products"),
|
|
"\(provider.rawValue): schema must include 'products' table")
|
|
#expect(schema.tableNames.count == 1,
|
|
"\(provider.rawValue): should have exactly 1 table")
|
|
|
|
let table = schema.tables["products"]
|
|
#expect(table != nil, "\(provider.rawValue): must find products table")
|
|
#expect(table?.columns.count == 4,
|
|
"\(provider.rawValue): products table must have 4 columns")
|
|
}
|
|
|
|
// MARK: - 5. Error Handling Consistency
|
|
|
|
@Test("All providers handle empty schema consistently",
|
|
arguments: TestedProvider.allCases)
|
|
func emptySchemaHandledConsistently(provider: TestedProvider) async throws {
|
|
let db = try DatabaseQueue(path: ":memory:")
|
|
let mock = makeMock(for: provider, responses: ["SELECT 1"])
|
|
|
|
let engine = ChatEngine(database: db, model: mock)
|
|
|
|
do {
|
|
_ = try await engine.send("Show me data")
|
|
Issue.record("\(provider.rawValue): should throw for empty schema")
|
|
} catch let error as SwiftDBAIError {
|
|
#expect(error == .emptySchema,
|
|
"\(provider.rawValue): must throw .emptySchema for database with no tables")
|
|
}
|
|
}
|
|
|
|
@Test("All providers reject disallowed SQL operations consistently",
|
|
arguments: TestedProvider.allCases)
|
|
func disallowedSQLRejectedConsistently(provider: TestedProvider) async throws {
|
|
let db = try makeProviderTestDatabase()
|
|
let mock = makeMock(for: provider, responses: [
|
|
"DELETE FROM products WHERE id = 1",
|
|
])
|
|
|
|
// Default allowlist is readOnly (SELECT only)
|
|
let engine = ChatEngine(database: db, model: mock)
|
|
|
|
do {
|
|
_ = try await engine.send("Delete the first product")
|
|
Issue.record("\(provider.rawValue): should reject DELETE when allowlist is readOnly")
|
|
} catch {
|
|
// All providers must trigger the same error path for disallowed operations
|
|
#expect(error is SwiftDBAIError,
|
|
"\(provider.rawValue): error must be SwiftDBAIError")
|
|
}
|
|
}
|
|
|
|
// MARK: - 6. Conversation History Consistency
|
|
|
|
@Test("Conversation history works identically for all providers",
|
|
arguments: TestedProvider.allCases)
|
|
func conversationHistoryConsistent(provider: TestedProvider) async throws {
|
|
let db = try makeProviderTestDatabase()
|
|
// ChatEngine calls LLM for SQL generation, then TextSummaryRenderer
|
|
// may call LLM for summarization. For aggregate queries (COUNT, AVG),
|
|
// TextSummaryRenderer uses a template and skips the LLM call.
|
|
// So the mock sequence is: SQL1, SQL2 (each followed by template summary).
|
|
let mock = makeMock(for: provider, responses: [
|
|
"SELECT COUNT(*) FROM products",
|
|
"SELECT AVG(price) FROM products",
|
|
])
|
|
|
|
let engine = ChatEngine(database: db, model: mock)
|
|
|
|
_ = try await engine.send("How many products?")
|
|
_ = try await engine.send("What is the average price?")
|
|
|
|
let messages = engine.messages
|
|
#expect(messages.count == 4,
|
|
"\(provider.rawValue): should have 4 messages (2 user + 2 assistant)")
|
|
#expect(messages[0].role == .user, "\(provider.rawValue): first message should be user")
|
|
#expect(messages[1].role == .assistant, "\(provider.rawValue): second message should be assistant")
|
|
#expect(messages[2].role == .user, "\(provider.rawValue): third message should be user")
|
|
#expect(messages[3].role == .assistant, "\(provider.rawValue): fourth message should be assistant")
|
|
|
|
// Both assistant messages must have SQL
|
|
#expect(messages[1].sql != nil, "\(provider.rawValue): first response must have SQL")
|
|
#expect(messages[3].sql != nil, "\(provider.rawValue): second response must have SQL")
|
|
}
|
|
|
|
// MARK: - 7. ProviderConfiguration Roundtrip
|
|
|
|
@Test("All cloud provider configurations roundtrip through makeModel()")
|
|
func allCloudProvidersRoundtrip() {
|
|
let configs: [(String, ProviderConfiguration)] = [
|
|
("OpenAI", .openAI(apiKey: "sk-test", model: "gpt-4o")),
|
|
("OpenAI Responses", .openAI(apiKey: "sk-test", model: "gpt-4o", variant: .responses)),
|
|
("Anthropic", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514")),
|
|
("Anthropic+version", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514", apiVersion: "2024-01-01")),
|
|
("Anthropic+betas", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514", betas: ["computer-use"])),
|
|
("Gemini", .gemini(apiKey: "AIza-test", model: "gemini-2.0-flash")),
|
|
("Gemini+version", .gemini(apiKey: "AIza-test", model: "gemini-2.0-flash", apiVersion: "v1")),
|
|
("OpenAI-Compatible", .openAICompatible(
|
|
apiKey: "key", model: "model", baseURL: URL(string: "http://localhost:1234")!
|
|
)),
|
|
("Ollama", .ollama(model: "llama3.2")),
|
|
("Ollama+custom URL", .ollama(model: "qwen2.5", baseURL: URL(string: "http://192.168.1.100:11434")!)),
|
|
("llama.cpp", .llamaCpp(model: "default")),
|
|
("llama.cpp+custom", .llamaCpp(model: "my-model", baseURL: URL(string: "http://localhost:9090")!)),
|
|
]
|
|
|
|
for (name, config) in configs {
|
|
let model = config.makeModel()
|
|
#expect(model.isAvailable, "\(name): model must be available after makeModel()")
|
|
}
|
|
}
|
|
|
|
@Test("On-device provider configurations produce valid models")
|
|
func onDeviceProvidersRoundtrip() {
|
|
let mlxConfigs: [MLXProviderConfiguration] = [
|
|
.llama3_2_3B(),
|
|
.qwen2_5_coder_3B(),
|
|
.phi3_5_mini(),
|
|
MLXProviderConfiguration(modelId: "custom-model", temperature: 0.2),
|
|
]
|
|
|
|
for mlxConfig in mlxConfigs {
|
|
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
|
|
let model = providerConfig.makeModel()
|
|
#expect(model.isAvailable, "MLX model '\(mlxConfig.modelId)' must be available")
|
|
}
|
|
}
|
|
|
|
// MARK: - 8. Write Operation Allowlist Consistency
|
|
|
|
@Test("Write operations require explicit opt-in for all providers",
|
|
arguments: TestedProvider.allCases)
|
|
func writeOperationsRequireOptIn(provider: TestedProvider) async throws {
|
|
let db = try makeProviderTestDatabase()
|
|
|
|
// Mock returns an INSERT statement
|
|
let mock = makeMock(for: provider, responses: [
|
|
"INSERT INTO products (name, price, category) VALUES ('New', 1.00, 'misc')",
|
|
])
|
|
|
|
// readOnly allowlist (default)
|
|
let readOnlyEngine = ChatEngine(database: db, model: mock)
|
|
|
|
do {
|
|
_ = try await readOnlyEngine.send("Add a new product")
|
|
Issue.record("\(provider.rawValue): INSERT should be rejected with readOnly allowlist")
|
|
} catch {
|
|
#expect(error is SwiftDBAIError,
|
|
"\(provider.rawValue): must throw SwiftDBAIError for disallowed INSERT")
|
|
}
|
|
}
|
|
|
|
@Test("Allowed write operations work for all providers",
|
|
arguments: TestedProvider.allCases)
|
|
func allowedWriteOperationsWork(provider: TestedProvider) async throws {
|
|
let db = try makeProviderTestDatabase()
|
|
|
|
let mock = makeMock(for: provider, responses: [
|
|
"INSERT INTO products (name, price, category) VALUES ('NewItem', 1.00, 'misc')",
|
|
"Successfully added 1 product.",
|
|
])
|
|
|
|
let engine = ChatEngine(
|
|
database: db,
|
|
model: mock,
|
|
allowlist: .standard
|
|
)
|
|
|
|
let response = try await engine.send("Add a product called NewItem")
|
|
#expect(response.sql?.uppercased().hasPrefix("INSERT") == true,
|
|
"\(provider.rawValue): SQL should be an INSERT")
|
|
}
|
|
|
|
// MARK: - 9. Response Format Consistency
|
|
|
|
@Test("ChatResponse structure is identical regardless of provider",
|
|
arguments: TestedProvider.allCases)
|
|
func responseStructureConsistent(provider: TestedProvider) async throws {
|
|
let db = try makeProviderTestDatabase()
|
|
let mock = makeMock(for: provider, responses: [
|
|
"SELECT name, price, category FROM products",
|
|
"Found 3 products across 2 categories.",
|
|
])
|
|
|
|
let engine = ChatEngine(database: db, model: mock)
|
|
let response = try await engine.send("Show all products")
|
|
|
|
// ChatResponse must always have these properties populated
|
|
#expect(response.summary.count > 0,
|
|
"\(provider.rawValue): summary must be non-empty")
|
|
#expect(response.sql != nil,
|
|
"\(provider.rawValue): sql must be present")
|
|
#expect(response.queryResult != nil,
|
|
"\(provider.rawValue): queryResult must be present")
|
|
|
|
// QueryResult structure must match the query
|
|
let qr = response.queryResult!
|
|
#expect(qr.columns == ["name", "price", "category"],
|
|
"\(provider.rawValue): columns must match SELECT clause")
|
|
#expect(qr.rowCount == 3,
|
|
"\(provider.rawValue): must return all rows")
|
|
#expect(qr.sql == "SELECT name, price, category FROM products",
|
|
"\(provider.rawValue): QueryResult.sql must match executed SQL")
|
|
#expect(qr.executionTime >= 0,
|
|
"\(provider.rawValue): execution time must be non-negative")
|
|
}
|
|
|
|
// MARK: - 10. Provider Enum Completeness
|
|
|
|
@Test("TestedProvider covers all ProviderConfiguration.Provider cases plus on-device")
|
|
func testedProviderCoversAllCases() {
|
|
// ProviderConfiguration.Provider has 6 cases
|
|
let configProviderCount = ProviderConfiguration.Provider.allCases.count
|
|
#expect(configProviderCount == 6, "ProviderConfiguration.Provider should have 6 cases")
|
|
|
|
// TestedProvider adds on-device for 7 total
|
|
#expect(TestedProvider.allCases.count == 7, "TestedProvider should cover all 7 provider types")
|
|
|
|
// Verify 1:1 mapping for the config providers
|
|
let configNames = Set(ProviderConfiguration.Provider.allCases.map(\.rawValue))
|
|
for tested in TestedProvider.allCases where tested != .onDevice {
|
|
#expect(configNames.contains(tested.rawValue),
|
|
"\(tested.rawValue) must map to a ProviderConfiguration.Provider case")
|
|
}
|
|
}
|
|
|
|
// MARK: - 11. ChatEngine Convenience Init Consistency
|
|
|
|
@Test("ChatEngine convenience init with ProviderConfiguration works for all cloud providers")
|
|
func chatEngineConvenienceInitWorks() throws {
|
|
let db = try makeProviderTestDatabase()
|
|
|
|
let configs: [ProviderConfiguration] = [
|
|
.openAI(apiKey: "test", model: "gpt-4o"),
|
|
.anthropic(apiKey: "test", model: "claude-sonnet-4-20250514"),
|
|
.gemini(apiKey: "test", model: "gemini-2.0-flash"),
|
|
.openAICompatible(apiKey: "test", model: "m", baseURL: URL(string: "http://localhost:1234")!),
|
|
.ollama(model: "llama3.2"),
|
|
.llamaCpp(model: "default"),
|
|
]
|
|
|
|
for config in configs {
|
|
// This should not throw — it only creates the engine, doesn't call the LLM
|
|
let engine = ChatEngine(database: db, provider: config)
|
|
#expect(engine.tableCount == nil, "tableCount should be nil before first query")
|
|
}
|
|
}
|
|
|
|
// MARK: - 12. Availability Reporting
|
|
|
|
@Test("All real provider models report available by default")
|
|
func allModelsReportAvailable() {
|
|
let models: [(String, any LanguageModel)] = [
|
|
("OpenAI", OpenAILanguageModel(apiKey: "k", model: "m")),
|
|
("Anthropic", AnthropicLanguageModel(apiKey: "k", model: "m")),
|
|
("Gemini", GeminiLanguageModel(apiKey: "k", model: "m")),
|
|
("Ollama", OllamaLanguageModel(model: "m")),
|
|
]
|
|
|
|
for (name, model) in models {
|
|
#expect(model.isAvailable, "\(name) should be available by default")
|
|
}
|
|
}
|
|
|
|
// MARK: - 13. On-Device Pipeline Status
|
|
|
|
@Test("On-device inference pipeline starts in notLoaded state")
|
|
func onDevicePipelineInitialState() {
|
|
let mlxPipeline = OnDeviceInferencePipeline(
|
|
mlxConfiguration: .llama3_2_3B()
|
|
)
|
|
#expect(mlxPipeline.status == .notLoaded)
|
|
#expect(mlxPipeline.providerType == .mlx)
|
|
|
|
let coreMLPipeline = OnDeviceInferencePipeline(
|
|
coreMLConfiguration: CoreMLProviderConfiguration(
|
|
modelURL: URL(fileURLWithPath: "/tmp/test.mlmodelc")
|
|
)
|
|
)
|
|
#expect(coreMLPipeline.status == .notLoaded)
|
|
#expect(coreMLPipeline.providerType == .coreML)
|
|
}
|
|
|
|
@Test("On-device SQL generation hints are populated for both provider types")
|
|
func onDeviceSQLHints() {
|
|
let mlxPipeline = OnDeviceInferencePipeline(mlxConfiguration: .llama3_2_3B())
|
|
let mlxHints = mlxPipeline.recommendedSQLGenerationHints
|
|
#expect(mlxHints.maxTokens > 0)
|
|
#expect(mlxHints.temperature >= 0)
|
|
#expect(!mlxHints.systemPromptSuffix.isEmpty)
|
|
|
|
let coreMLPipeline = OnDeviceInferencePipeline(
|
|
coreMLConfiguration: CoreMLProviderConfiguration(
|
|
modelURL: URL(fileURLWithPath: "/tmp/test.mlmodelc")
|
|
)
|
|
)
|
|
let coreMLHints = coreMLPipeline.recommendedSQLGenerationHints
|
|
#expect(coreMLHints.maxTokens > 0)
|
|
#expect(coreMLHints.temperature >= 0)
|
|
#expect(!coreMLHints.systemPromptSuffix.isEmpty)
|
|
}
|
|
}
|