SwiftDBAI: natural language queries for any SQLite database

Drop-in SwiftUI chat view, headless ChatEngine, LLM-agnostic via
AnyLanguageModel. Read-only by default with configurable allowlists.
Robust SQL parser with 63 tests. Includes demo app with GitHub stars dataset.
This commit is contained in:
Krishna Kumar
2026-04-04 09:30:56 -05:00
commit fcd752466a
80 changed files with 18265 additions and 0 deletions

8
.gitignore vendored Normal file
View File

@@ -0,0 +1,8 @@
.build/
.swiftpm/
Package.resolved
*.xcodeproj/
xcuserdata/
DerivedData/
.DS_Store
.mcp.json

View File

@@ -0,0 +1,38 @@
// DatabaseSeeder.swift
// SwiftDBAIDemo
//
// Copies the bundled GitHub stars database to the Documents directory.
// The database contains real star counts for ~2000 top GitHub repos,
// fetched live from the GitHub API.
import Foundation
enum DatabaseSeeder {
/// Returns the path to the GitHub stars database, copying from bundle if needed.
static func seedIfNeeded() throws -> String {
let url = URL.documentsDirectory.appending(path: "github_stars.sqlite")
let path = url.path(percentEncoded: false)
// If the database already exists, just return the path.
if FileManager.default.fileExists(atPath: path) {
return path
}
// Copy bundled database to Documents
guard let bundledURL = Bundle.main.url(forResource: "github_stars", withExtension: "sqlite") else {
throw SeederError.bundledDatabaseNotFound
}
try FileManager.default.copyItem(at: bundledURL, to: url)
return path
}
enum SeederError: LocalizedError {
case bundledDatabaseNotFound
var errorDescription: String? {
"Could not find github_stars.sqlite in app bundle."
}
}
}

View File

@@ -0,0 +1,253 @@
// DemoLanguageModel.swift
// SwiftDBAIDemo
//
// A mock LanguageModel that returns canned SQL for common GitHub repo queries.
// Pattern-matches natural language questions about GitHub stars, languages,
// and repository metadata.
import AnyLanguageModel
import Foundation
struct DemoLanguageModel: LanguageModel {
typealias UnavailableReason = Never
func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
let promptText = prompt.description.lowercased()
let responseText: String
if promptText.contains("row") && (promptText.contains("column") || promptText.contains("|")) {
responseText = deriveSummary(from: prompt.description)
} else {
responseText = deriveSQL(from: promptText)
}
let rawContent = GeneratedContent(kind: .string(responseText))
let content = try Content(rawContent)
return LanguageModelSession.Response(
content: content,
rawContent: rawContent,
transcriptEntries: [][...]
)
}
func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
let rawContent = GeneratedContent(kind: .string("SELECT full_name, stars FROM repos ORDER BY stars DESC LIMIT 10"))
let content = try! Content(rawContent)
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
}
// MARK: - SQL Pattern Matching
private func deriveSQL(from prompt: String) -> String {
let q = extractLastQuestion(from: prompt)
// Specific repo lookups
if q.contains("react") && !q.contains("react-native") && !q.contains("react native") {
return "SELECT full_name, stars, forks, language, description FROM repos WHERE name = 'react' OR full_name LIKE '%/react' ORDER BY stars DESC LIMIT 5"
}
// How many stars does X have
if q.contains("how many stars") || q.contains("stars does") || q.contains("stars for") {
return "SELECT full_name, stars, forks, language FROM repos ORDER BY stars DESC LIMIT 10"
}
// Language breakdown MUST come before "most popular" to avoid collision
if q.contains("language") && (q.contains("breakdown") || q.contains("distribution") || q.contains("popular") || q.contains("most")) {
return """
SELECT language, COUNT(*) AS repo_count,
SUM(stars) AS total_stars,
ROUND(AVG(stars)) AS avg_stars
FROM repos WHERE language IS NOT NULL AND language != ''
GROUP BY language
ORDER BY total_stars DESC
LIMIT 15
"""
}
// Most starred / top repos
if q.contains("most starred") || q.contains("most popular") || q.contains("top repo") || q.contains("top 10") || q.contains("most stars") {
return """
SELECT full_name, stars, forks, language
FROM repos ORDER BY stars DESC LIMIT 10
"""
}
// Language-specific queries
if q.contains("python") && (q.contains("repo") || q.contains("project")) {
return """
SELECT full_name, stars, forks, description
FROM repos WHERE language = 'Python'
ORDER BY stars DESC LIMIT 10
"""
}
if q.contains("swift") && (q.contains("repo") || q.contains("project")) {
return """
SELECT full_name, stars, forks, description
FROM repos WHERE language = 'Swift'
ORDER BY stars DESC LIMIT 10
"""
}
if q.contains("rust") && (q.contains("repo") || q.contains("project")) {
return """
SELECT full_name, stars, forks, description
FROM repos WHERE language = 'Rust'
ORDER BY stars DESC LIMIT 10
"""
}
if q.contains("typescript") && (q.contains("repo") || q.contains("project")) {
return """
SELECT full_name, stars, forks, description
FROM repos WHERE language = 'TypeScript'
ORDER BY stars DESC LIMIT 10
"""
}
// Count queries
if q.contains("how many repo") || q.contains("how many project") || q.contains("total repo") {
return "SELECT COUNT(*) AS total_repos FROM repos"
}
if q.contains("how many language") {
return "SELECT COUNT(DISTINCT language) AS total_languages FROM repos WHERE language IS NOT NULL AND language != ''"
}
// Stars threshold queries
if q.contains("100k") || q.contains("100,000") || q.contains("100000") {
return """
SELECT full_name, stars, language
FROM repos WHERE stars > 100000
ORDER BY stars DESC
"""
}
// Forks
if q.contains("most forked") || q.contains("most forks") {
return """
SELECT full_name, forks, stars, language
FROM repos ORDER BY forks DESC LIMIT 10
"""
}
// Created / oldest / newest
if q.contains("oldest") || q.contains("first") {
return """
SELECT full_name, created_at, stars, language
FROM repos ORDER BY created_at ASC LIMIT 10
"""
}
if q.contains("newest") || q.contains("recent") || q.contains("latest") {
return """
SELECT full_name, created_at, stars, language
FROM repos ORDER BY created_at DESC LIMIT 10
"""
}
// Microsoft / Google / Meta specific
if q.contains("microsoft") {
return """
SELECT full_name, stars, forks, language
FROM repos WHERE owner = 'microsoft'
ORDER BY stars DESC
"""
}
if q.contains("google") {
return """
SELECT full_name, stars, forks, language
FROM repos WHERE owner = 'google'
ORDER BY stars DESC
"""
}
if q.contains("facebook") || q.contains("meta") {
return """
SELECT full_name, stars, forks, language
FROM repos WHERE owner = 'facebook'
ORDER BY stars DESC
"""
}
// Compare
if q.contains("vs") || q.contains("versus") || q.contains("compare") {
return """
SELECT full_name, stars, forks, language
FROM repos ORDER BY stars DESC LIMIT 20
"""
}
// Default
return """
SELECT full_name, stars, language
FROM repos ORDER BY stars DESC LIMIT 10
"""
}
private func extractLastQuestion(from prompt: String) -> String {
let lines = prompt.components(separatedBy: "\n")
// First pass: find lines ending with "?" (most likely user questions)
// Take the LAST one (most recent question)
var lastQuestion: String?
for line in lines {
let trimmed = line.trimmingCharacters(in: .whitespaces)
if trimmed.hasSuffix("?") && trimmed.count < 200 && trimmed.count > 5 {
lastQuestion = trimmed.lowercased()
}
}
if let q = lastQuestion { return q }
// Fallback: walk backwards looking for short non-SQL lines
for line in lines.reversed() {
let trimmed = line.trimmingCharacters(in: .whitespaces)
guard !trimmed.isEmpty, trimmed.count > 3, trimmed.count < 100 else { continue }
let lower = trimmed.lowercased()
if lower.hasPrefix("select ") || lower.hasPrefix("create ") { continue }
if lower.contains("integer") || lower.contains("text not") { continue }
if lower.contains("respond with only") { continue }
return lower
}
return prompt.lowercased()
}
// MARK: - Summary Generation
private func deriveSummary(from rawPrompt: String) -> String {
let lines = rawPrompt.components(separatedBy: "\n")
let dataLines = lines.filter { $0.contains("|") || $0.contains(",") }
let rowCount = max(dataLines.count - 1, 0)
let lower = rawPrompt.lowercased()
if lower.contains("total_repos") || lower.contains("total_languages") || lower.contains("count(") {
if let countLine = dataLines.last {
let num = countLine.trimmingCharacters(in: .whitespacesAndNewlines)
.components(separatedBy: "|").last?
.trimmingCharacters(in: .whitespacesAndNewlines) ?? "\(rowCount)"
return "The count is \(num)."
}
}
if lower.contains("avg_stars") || lower.contains("group by") || lower.contains("total_stars") {
return "Here's the breakdown across programming languages."
}
if lower.contains("forks") && lower.contains("order by forks") {
return "These are the most forked repositories on GitHub."
}
if rowCount == 0 {
return "No repositories matched your query."
}
if rowCount == 1 {
return "Here's what I found."
}
if rowCount <= 5 {
return "Found \(rowCount) repositories."
}
return "Here are the top \(rowCount) repositories."
}
}

View File

@@ -0,0 +1,70 @@
// OllamaWithSystemPrompt.swift
// SwiftDBAIDemo
//
// Wraps OllamaLanguageModel to prepend session instructions into the user
// prompt, working around AnyLanguageModel's Ollama adapter not forwarding
// system messages.
import AnyLanguageModel
import Foundation
/// Wrapper that injects session instructions into every Ollama request.
struct OllamaWithSystemPrompt: LanguageModel {
typealias UnavailableReason = Never
private let inner: OllamaLanguageModel
init(baseURL: URL = OllamaLanguageModel.defaultBaseURL, model: String) {
self.inner = OllamaLanguageModel(baseURL: baseURL, model: model)
}
func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
let userText = prompt.description
let instructionText = session.instructions?.description ?? ""
let combinedText: String
if instructionText.isEmpty {
combinedText = userText
} else {
combinedText = """
[System Instructions]
\(instructionText)
[User Message]
\(userText)
"""
}
let plainSession = LanguageModelSession(model: inner)
let combinedPrompt = Prompt(combinedText)
return try await inner.respond(
within: plainSession,
to: combinedPrompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}
func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
inner.streamResponse(
within: session,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}
}

View File

@@ -0,0 +1,261 @@
// SwiftDBAIDemoApp.swift
// SwiftDBAIDemo
//
// Showcase app demonstrating all SwiftDBAI UI variants and presentation modes.
import SwiftUI
import SwiftDBAI
@main
struct SwiftDBAIDemoApp: App {
@State private var databasePath: String?
@State private var setupError: String?
private let context = """
This is a database of the top ~2000 most-starred GitHub \
repositories. Each repo has: full_name (owner/name), stars, \
forks, language (programming language), description, \
open_issues, created_at date, and topics. \
Star counts are real and current as of April 2026.
"""
var body: some Scene {
WindowGroup {
Group {
if let path = databasePath {
ShowcaseTabView(databasePath: path, context: context)
} else if let error = setupError {
ContentUnavailableView(
"Database Setup Failed",
systemImage: "exclamationmark.triangle",
description: Text(error)
)
} else {
ProgressView("Setting up database...")
}
}
.task {
do {
let path = try DatabaseSeeder.seedIfNeeded()
databasePath = path
} catch {
setupError = error.localizedDescription
}
}
}
}
}
struct ShowcaseTabView: View {
let databasePath: String
let context: String
@State private var showSheet = false
@State private var showFullScreen = false
var body: some View {
TabView {
// Tab 1: Default theme
DataChatView(
databasePath: databasePath,
model: DemoLanguageModel(),
allowlist: .readOnly,
additionalContext: context
)
.tabItem { Label("Default", systemImage: "bubble.left.and.text.bubble.right") }
// Tab 2: Dark theme
DataChatView(
databasePath: databasePath,
model: DemoLanguageModel(),
allowlist: .readOnly,
additionalContext: context
)
.chatViewConfiguration(.dark)
.tabItem { Label("Dark", systemImage: "moon.fill") }
// Tab 3: Compact theme
DataChatView(
databasePath: databasePath,
model: DemoLanguageModel(),
allowlist: .readOnly,
additionalContext: context
)
.chatViewConfiguration(.compact)
.tabItem { Label("Compact", systemImage: "rectangle.compress.vertical") }
// Tab 4: Custom styling
DataChatView(
databasePath: databasePath,
model: DemoLanguageModel(),
allowlist: .readOnly,
additionalContext: context
)
.chatViewConfiguration(customConfig)
.tabItem { Label("Custom", systemImage: "paintbrush") }
// Tab 5: Presentation modes
PresentationShowcase(databasePath: databasePath, context: context)
.tabItem { Label("Present", systemImage: "rectangle.portrait.and.arrow.forward") }
// Tab 6: Tool calling API
ToolDemoView(databasePath: databasePath)
.tabItem { Label("Tool", systemImage: "wrench") }
}
}
private var customConfig: ChatViewConfiguration {
var config = ChatViewConfiguration.default
config.userBubbleColor = .purple
config.userTextColor = .white
config.accentColor = .purple
config.inputPlaceholder = "Search GitHub repos..."
config.emptyStateTitle = "Explore GitHub Data"
config.emptyStateSubtitle = "Ask about stars, forks, languages, and trends"
config.emptyStateIcon = "star.circle"
config.assistantAvatarIcon = "sparkles"
config.assistantAvatarColor = .purple
return config
}
}
struct PresentationShowcase: View {
let databasePath: String
let context: String
@State private var showSheet = false
@State private var showFullScreen = false
var body: some View {
NavigationStack {
List {
Section("Sheet Presentations") {
Button("Show as Sheet") {
showSheet = true
}
Button("Show Full Screen") {
showFullScreen = true
}
}
Section("Navigation") {
NavigationLink("Push DataChatView") {
DataChatView(
databasePath: databasePath,
model: DemoLanguageModel(),
allowlist: .readOnly,
additionalContext: context
)
.navigationTitle("Chat")
}
}
Section("Info") {
LabeledContent("DataChatSheet", value: "Nav + Done button")
LabeledContent("DataChatViewController", value: "UIKit bridge")
LabeledContent(".dataChatSheet()", value: "View modifier")
LabeledContent(".dataChatFullScreen()", value: "View modifier")
}
}
.navigationTitle("Presentation Modes")
}
.sheet(isPresented: $showSheet) {
DataChatSheet(
databasePath: databasePath,
model: DemoLanguageModel(),
additionalContext: context,
title: "GitHub Stars"
)
}
.dataChatFullScreen(
isPresented: $showFullScreen,
databasePath: databasePath,
model: DemoLanguageModel(),
additionalContext: context
)
}
}
struct ToolDemoView: View {
let databasePath: String
@State private var tool: DatabaseTool?
@State private var sqlInput = "SELECT full_name, stars FROM repos ORDER BY stars DESC LIMIT 5"
@State private var result: ToolResult?
@State private var error: String?
@State private var showSchema = false
var body: some View {
NavigationStack {
List {
if let tool {
Section("Schema") {
Button(showSchema ? "Hide Schema" : "Show Schema") {
showSchema.toggle()
}
if showSchema {
Text(tool.schemaContext)
.font(.caption2.monospaced())
}
}
Section("SQL Query") {
TextField("Enter SQL", text: $sqlInput, axis: .vertical)
.font(.footnote.monospaced())
.lineLimit(3...6)
Button("Execute") {
do {
result = try tool.execute(sql: sqlInput)
error = nil
} catch {
self.error = error.localizedDescription
result = nil
}
}
.disabled(sqlInput.isEmpty)
}
if let error {
Section("Error") {
Text(error)
.foregroundStyle(.red)
.font(.footnote)
}
}
if let result {
Section("Result (\(result.rowCount) rows, \(String(format: "%.1fms", result.executionTime * 1000)))") {
Text(result.markdownTable)
.font(.caption2.monospaced())
}
Section("JSON Response") {
Text(result.jsonString)
.font(.caption2.monospaced())
.lineLimit(15)
}
}
Section("OpenAI Tool Definition") {
Text(toolDefinitionJSON(tool))
.font(.caption2.monospaced())
.lineLimit(10)
}
} else {
ProgressView("Loading database...")
}
}
.navigationTitle("DatabaseTool API")
}
.task {
do {
tool = try await DatabaseTool(databasePath: databasePath)
} catch {
self.error = error.localizedDescription
}
}
}
private func toolDefinitionJSON(_ tool: DatabaseTool) -> String {
let def = tool.openAIFunctionDefinition
if let data = try? JSONSerialization.data(withJSONObject: def, options: [.prettyPrinted, .sortedKeys]),
let str = String(data: data, encoding: .utf8) {
return str
}
return "{}"
}
}

View File

@@ -0,0 +1,26 @@
name: SwiftDBAIDemo
options:
bundleIdPrefix: com.swiftdbai.demo
deploymentTarget:
iOS: "17.0"
xcodeVersion: "16.0"
createIntermediateGroups: true
packages:
SwiftDBAI:
path: ../..
targets:
SwiftDBAIDemo:
type: application
platform: iOS
sources:
- SwiftDBAIDemo
settings:
base:
PRODUCT_BUNDLE_IDENTIFIER: com.swiftdbai.demo
MARKETING_VERSION: "1.0"
CURRENT_PROJECT_VERSION: "1"
SWIFT_VERSION: "6.0"
INFOPLIST_GENERATION: true
GENERATE_INFOPLIST_FILE: true
dependencies:
- package: SwiftDBAI

42
Package.swift Normal file
View File

@@ -0,0 +1,42 @@
// swift-tools-version: 6.1
import PackageDescription
let package = Package(
name: "SwiftDBAI",
platforms: [
.iOS(.v17),
.macOS(.v14),
.visionOS(.v1),
],
products: [
.library(
name: "SwiftDBAI",
targets: ["SwiftDBAI"]
),
],
dependencies: [
.package(url: "https://github.com/groue/GRDB.swift.git", from: "7.0.0"),
.package(url: "https://github.com/huggingface/AnyLanguageModel.git", from: "0.8.0"),
.package(url: "https://github.com/nalexn/ViewInspector.git", from: "0.10.0"),
],
targets: [
.target(
name: "SwiftDBAI",
dependencies: [
.product(name: "GRDB", package: "GRDB.swift"),
.product(name: "AnyLanguageModel", package: "AnyLanguageModel"),
],
swiftSettings: [
.swiftLanguageMode(.v6),
]
),
.testTarget(
name: "SwiftDBAITests",
dependencies: ["SwiftDBAI", "ViewInspector"],
swiftSettings: [
.swiftLanguageMode(.v6),
]
),
]
)

346
README.md Normal file
View File

@@ -0,0 +1,346 @@
# SwiftDBAI
A Swift package that adds a natural language query interface to any SQLite database in your iOS, macOS, or visionOS app. Drop in one SwiftUI view and your users can ask questions about their data in plain English.
<!-- badges -->
![Swift 6.1+](https://img.shields.io/badge/Swift-6.1+-orange.svg)
![Platforms](https://img.shields.io/badge/Platforms-iOS%2017%20|%20macOS%2014%20|%20visionOS%201-blue.svg)
![License](https://img.shields.io/badge/License-MIT-green.svg)
## Demo
| iPhone | iPad |
|---|---|
| ![iPhone](screenshots/iphone-results.png) | ![iPad](screenshots/results-chart.png) |
| Custom theme | Sheet presentation |
|---|---|
| ![Custom theme](screenshots/custom-theme.png) | ![Sheet presentation](screenshots/sheet-presentation.png) |
The demo app is at `Example/SwiftDBAIDemo/`. It points SwiftDBAI at a real database of ~2,000 top GitHub repos with live star counts. Generate the Xcode project with [xcodegen](https://github.com/yonaskolb/XcodeGen):
```
cd Example/SwiftDBAIDemo && xcodegen generate
```
For a real-world integration, see [SwiftDBAI added to NetNewsWire](https://github.com/krishkumar/NetNewsWire) -- natural language queries against an RSS reader's article database.
## Features
- Drop-in SwiftUI chat view (`DataChatView`) -- one line to add a database chat UI
- Headless `ChatEngine` for programmatic / non-UI use
- LLM-agnostic via [AnyLanguageModel](https://github.com/huggingface/AnyLanguageModel) -- works with OpenAI, Anthropic, Gemini, Ollama, llama.cpp, or any OpenAI-compatible endpoint
- Automatic schema introspection -- no manual annotations required
- Safety-first: read-only by default, operation allowlists, table-level mutation policies, destructive operation confirmation delegate
- Configurable query timeouts, context windows, and custom validators
## Installation
Add SwiftDBAI via Swift Package Manager:
```swift
dependencies: [
.package(url: "https://github.com/krishkumar/SwiftDBAI.git", from: "1.0.0"),
]
```
Then add the dependency to your target:
```swift
.target(
name: "MyApp",
dependencies: ["SwiftDBAI"]
)
```
## Quick Start
Drop a full chat UI into any SwiftUI view with `DataChatView`:
```swift
import SwiftDBAI
import AnyLanguageModel
struct ContentView: View {
var body: some View {
DataChatView(
databasePath: "/path/to/mydata.sqlite",
model: OllamaLanguageModel(model: "llama3")
)
}
}
```
That's it. `DataChatView` opens the database, introspects the schema, and renders a chat interface. The default mode is **read-only** (SELECT only).
To pass an existing GRDB connection and customize behavior:
```swift
DataChatView(
database: myDatabasePool,
model: OpenAILanguageModel(apiKey: "sk-...", model: "gpt-4o"),
allowlist: .standard,
additionalContext: "This database stores a recipe app's data.",
maxSummaryRows: 100
)
```
## Presentation
`DataChatSheet` wraps `DataChatView` in a `NavigationStack` with a title and Done button, ready for any presentation context.
**SwiftUI sheet:**
```swift
.sheet(isPresented: $showChat) {
DataChatSheet(
databasePath: "/path/to/mydata.sqlite",
model: OllamaLanguageModel(model: "llama3")
)
}
// Or use the convenience modifier:
.dataChatSheet(isPresented: $showChat, databasePath: path, model: myLLM)
```
**SwiftUI full-screen cover:**
```swift
.fullScreenCover(isPresented: $showChat) {
DataChatSheet(databasePath: path, model: myLLM)
}
// Or use the convenience modifier:
.dataChatFullScreen(isPresented: $showChat, databasePath: path, model: myLLM)
```
**UIKit modal:**
```swift
let vc = DataChatViewController(databasePath: path, model: myLLM)
present(vc, animated: true)
```
**UIKit navigation push:**
```swift
let vc = DataChatViewController(databasePath: path, model: myLLM)
navigationController?.pushViewController(vc, animated: true)
```
All presentation wrappers accept the same parameters as `DataChatView` (`allowlist`, `additionalContext`, etc.) plus a `title` for the navigation bar.
## Tool Calling
If your app already has an LLM integration, use `DatabaseTool` to register SwiftDBAI as a tool the LLM can call. No extra LLM needed -- your existing one generates SQL, SwiftDBAI validates and executes it.
```swift
import SwiftDBAI
// 1. Create the tool
let tool = try await DatabaseTool(databasePath: "/path/to/mydata.sqlite")
// 2. Add schema context to your LLM's system prompt
let systemPrompt = "You are a helpful assistant.\n\n" + tool.systemPromptSnippet
// 3. Register with your LLM (OpenAI function calling example)
let functionDef = tool.openAIFunctionDefinition
// Pass to OpenAI's tools parameter...
// 4. When the LLM calls the tool
let result = try tool.execute(sql: "SELECT * FROM users WHERE active = 1")
result.jsonString // return to LLM as tool response
result.markdownTable // display to user
result.rowCount // 42
result.executionTime // 0.003
```
SQL is validated against a read-only allowlist before execution. INSERT, UPDATE, DELETE, and DROP are rejected.
![DatabaseTool API](screenshots/tool-api.png)
## Headless / Programmatic Use
Use `ChatEngine` directly when you don't need a UI:
```swift
import SwiftDBAI
import AnyLanguageModel
import GRDB
let pool = try DatabasePool(path: "/path/to/mydata.sqlite")
let engine = ChatEngine(
database: pool,
model: OpenAILanguageModel(apiKey: "sk-...", model: "gpt-4o")
)
let response = try await engine.send("How many users signed up this week?")
print(response.summary) // "There were 42 new signups this week."
print(response.sql) // Optional("SELECT COUNT(*) FROM users WHERE ...")
```
`ChatEngine` also accepts a `ProviderConfiguration` for convenience:
```swift
let engine = ChatEngine(
database: pool,
provider: .anthropic(apiKey: "sk-ant-...", model: "claude-sonnet-4-20250514")
)
```
For fine-grained control, pass a `ChatEngineConfiguration`:
```swift
var config = ChatEngineConfiguration(
queryTimeout: 10,
contextWindowSize: 20,
maxSummaryRows: 100,
additionalContext: "The 'status' column uses: 'active', 'inactive', 'suspended'."
)
let engine = ChatEngine(
database: pool,
model: model,
allowlist: .standard,
configuration: config
)
```
## Choosing a Provider
SwiftDBAI works with any provider supported by AnyLanguageModel. Use `ProviderConfiguration` factory methods or construct model instances directly.
```swift
// OpenAI
let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
// Anthropic
let config = ProviderConfiguration.anthropic(apiKey: "sk-ant-...", model: "claude-sonnet-4-20250514")
// Gemini
let config = ProviderConfiguration.gemini(apiKey: "AIza...", model: "gemini-2.0-flash")
// Ollama (local, no API key needed)
let config = ProviderConfiguration.ollama(model: "llama3.2")
// llama.cpp (local)
let config = ProviderConfiguration.llamaCpp(model: "default")
// Any OpenAI-compatible endpoint
let config = ProviderConfiguration.openAICompatible(
apiKey: "your-key",
model: "llama-3.1-70b",
baseURL: URL(string: "https://api.together.xyz/v1/")!
)
```
Use with ChatEngine:
```swift
let engine = ChatEngine(database: pool, provider: config)
// or
let engine = ChatEngine(database: pool, model: config.makeModel())
```
API keys can also come from environment variables:
```swift
let config = ProviderConfiguration.fromEnvironment(
provider: .openAI,
environmentVariable: "OPENAI_API_KEY",
model: "gpt-4o"
)
```
## Safety and Mutation Control
### Operation Allowlist
By default, only SELECT queries are allowed. Opt in to writes explicitly:
| Preset | Allowed Operations |
|---|---|
| `.readOnly` (default) | SELECT |
| `.standard` | SELECT, INSERT, UPDATE |
| `.unrestricted` | SELECT, INSERT, UPDATE, DELETE |
```swift
// Custom allowlist
let allowlist = OperationAllowlist([.select, .insert])
```
### Mutation Policy
For table-level control, use `MutationPolicy`:
```swift
// Allow INSERT and UPDATE only on specific tables
let policy = MutationPolicy(
allowedOperations: [.insert, .update],
allowedTables: ["orders", "order_items"]
)
let engine = ChatEngine(
database: pool,
model: model,
mutationPolicy: policy
)
```
Presets: `.readOnly`, `.readWrite`, `.unrestricted`.
### Confirmation Delegate
Destructive operations (DELETE, DROP, ALTER, TRUNCATE) require confirmation through a `ToolExecutionDelegate`:
```swift
struct MyDelegate: ToolExecutionDelegate {
func confirmDestructiveOperation(
_ context: DestructiveOperationContext
) async -> Bool {
// Present confirmation UI, return true to proceed
return await showConfirmationDialog(context.description)
}
}
let engine = ChatEngine(
database: pool,
model: model,
allowlist: .unrestricted,
delegate: MyDelegate()
)
```
Without a delegate, destructive operations throw `SwiftDBAIError.confirmationRequired` so you can handle confirmation in your own flow.
Built-in delegates: `AutoApproveDelegate` (testing only), `RejectAllDelegate` (safest).
## Architecture
```
User Question
|
v
ChatEngine
|-- SchemaIntrospector (auto-discovers tables, columns, keys, indexes)
|-- PromptBuilder (builds LLM system prompt with schema context)
|-- LanguageModel (generates SQL via AnyLanguageModel)
|-- SQLQueryParser (parses and validates against allowlist/policy)
|-- QueryValidator (optional custom validators)
|-- GRDB (executes SQL against SQLite)
|-- TextSummaryRenderer (summarizes results via LLM)
v
ChatResponse { summary, sql, queryResult }
```
`DataChatView` wraps this pipeline in a SwiftUI view with `ChatViewModel` managing state.
## Requirements
- iOS 17.0+ / macOS 14.0+ / visionOS 1.0+
- Swift 6.1+
- Xcode 16+
## License
MIT. See [LICENSE](LICENSE) for details.

View File

@@ -0,0 +1,113 @@
// ChatEngineConfiguration.swift
// SwiftDBAI
//
// Configurable settings for ChatEngine behavior timeouts, context window,
// summary limits, and custom query validation.
import Foundation
/// Configuration for ``ChatEngine`` behavior.
///
/// Use this to tune timeouts, conversation context windows, and attach
/// custom query validators.
///
/// ```swift
/// var config = ChatEngineConfiguration()
/// config.queryTimeout = 10 // 10-second SQL timeout
/// config.contextWindowSize = 20 // Keep last 20 messages for LLM context
/// config.maxSummaryRows = 100 // Summarize up to 100 rows
///
/// let engine = ChatEngine(
/// database: db,
/// model: model,
/// configuration: config
/// )
/// ```
public struct ChatEngineConfiguration: Sendable {
// MARK: - Query Execution
/// Maximum time (in seconds) to wait for a SQL query to execute.
///
/// If the query exceeds this duration, a ``ChatEngineError/queryTimedOut``
/// error is thrown. Set to `nil` to disable the timeout (not recommended
/// for user-facing apps). Defaults to 30 seconds.
public var queryTimeout: TimeInterval?
// MARK: - Conversation Context
/// Maximum number of conversation messages to include when building
/// LLM context for follow-up queries.
///
/// Only the most recent `contextWindowSize` messages are sent to the LLM.
/// Older messages are still retained in ``ChatEngine/messages`` for UI
/// display but do not consume LLM tokens.
///
/// Set to `nil` for unlimited context (all history is always sent).
/// Defaults to 50 messages.
public var contextWindowSize: Int?
// MARK: - Rendering
/// Maximum number of rows to include when generating text summaries.
/// Defaults to 50.
public var maxSummaryRows: Int
// MARK: - LLM Context
/// Optional extra instructions appended to the LLM system prompt.
///
/// Use this to provide business-specific terminology, query hints,
/// or domain constraints. For example:
/// ```swift
/// config.additionalContext = "The 'status' column uses: 'active', 'inactive', 'suspended'."
/// ```
public var additionalContext: String?
// MARK: - Validation
/// Custom query validators that run after the built-in allowlist check.
///
/// Use ``addValidator(_:)`` to add validators. They are executed in order;
/// the first validator to throw stops execution.
public private(set) var validators: [any QueryValidator] = []
// MARK: - Initialization
/// Creates a configuration with the given settings.
///
/// - Parameters:
/// - queryTimeout: SQL execution timeout in seconds. Defaults to 30.
/// - contextWindowSize: Max messages for LLM context. Defaults to 50.
/// - maxSummaryRows: Max rows for text summaries. Defaults to 50.
/// - additionalContext: Extra LLM system prompt instructions.
public init(
queryTimeout: TimeInterval? = 30,
contextWindowSize: Int? = 50,
maxSummaryRows: Int = 50,
additionalContext: String? = nil
) {
self.queryTimeout = queryTimeout
self.contextWindowSize = contextWindowSize
self.maxSummaryRows = maxSummaryRows
self.additionalContext = additionalContext
}
/// The default configuration: 30s timeout, 50-message context window,
/// 50-row summaries, no additional context, no custom validators.
public static let `default` = ChatEngineConfiguration()
// MARK: - Mutating Helpers
/// Appends a custom query validator.
///
/// Validators run after the built-in allowlist and dangerous-keyword checks.
/// They receive the parsed SQL and can throw to reject a query.
///
/// ```swift
/// config.addValidator(TableAllowlistValidator(allowedTables: ["users", "orders"]))
/// ```
public mutating func addValidator(_ validator: any QueryValidator) {
validators.append(validator)
}
}

View File

@@ -0,0 +1,336 @@
// LocalProviderConfiguration.swift
// SwiftDBAI
//
// Configuration and endpoint discovery for local/self-hosted LLM providers
// (Ollama, llama.cpp). Wraps AnyLanguageModel's OllamaLanguageModel and
// OpenAILanguageModel with convenient factory methods and health checking.
import AnyLanguageModel
import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif
// MARK: - Local Provider Endpoint
/// Represents a discovered local LLM endpoint with its connection status.
public struct LocalProviderEndpoint: Sendable, Equatable {
/// The base URL of the local provider.
public let baseURL: URL
/// The provider type (Ollama or llama.cpp).
public let providerType: LocalProviderType
/// Whether the endpoint was reachable at discovery time.
public let isReachable: Bool
/// The list of available models, if the endpoint supports model listing.
public let availableModels: [String]
/// Human-readable description of the endpoint.
public var description: String {
let status = isReachable ? "reachable" : "unreachable"
return "\(providerType.rawValue) at \(baseURL.absoluteString) (\(status), \(availableModels.count) models)"
}
}
/// The type of local LLM provider.
public enum LocalProviderType: String, Sendable, Hashable, CaseIterable {
/// Ollama runs models locally via `ollama serve`.
/// Default endpoint: http://localhost:11434
case ollama
/// llama.cpp server runs GGUF models via `llama-server`.
/// Default endpoint: http://localhost:8080
/// Exposes an OpenAI-compatible API.
case llamaCpp = "llama.cpp"
}
// MARK: - Local Provider Discovery
/// Discovers and validates local LLM provider endpoints.
///
/// Use `LocalProviderDiscovery` to automatically find running Ollama or llama.cpp
/// instances on the local machine, check their health, and list available models.
///
/// ```swift
/// // Check if Ollama is running
/// let isRunning = await LocalProviderDiscovery.isOllamaRunning()
///
/// // Discover all local providers
/// let endpoints = await LocalProviderDiscovery.discoverAll()
/// for endpoint in endpoints where endpoint.isReachable {
/// print("Found \(endpoint.description)")
/// }
///
/// // List models available on Ollama
/// let models = await LocalProviderDiscovery.listOllamaModels()
/// ```
public enum LocalProviderDiscovery {
/// Default Ollama endpoint.
public static let defaultOllamaURL = URL(string: "http://localhost:11434")!
/// Default llama.cpp server endpoint.
public static let defaultLlamaCppURL = URL(string: "http://localhost:8080")!
/// Well-known ports to probe for local providers.
/// Ollama: 11434, llama.cpp: 8080
private static let wellKnownEndpoints: [(URL, LocalProviderType)] = [
(defaultOllamaURL, .ollama),
(defaultLlamaCppURL, .llamaCpp),
]
// MARK: - Health Checks
/// Checks if an Ollama instance is reachable at the given URL.
///
/// Sends a GET request to the Ollama root endpoint and checks for a 200 response.
///
/// - Parameter baseURL: The Ollama base URL. Defaults to `http://localhost:11434`.
/// - Parameter timeout: Connection timeout in seconds. Defaults to 3.
/// - Returns: `true` if the Ollama server responded successfully.
public static func isOllamaRunning(
at baseURL: URL = defaultOllamaURL,
timeout: TimeInterval = 3
) async -> Bool {
await checkEndpointHealth(baseURL, timeout: timeout)
}
/// Checks if a llama.cpp server is reachable at the given URL.
///
/// Sends a GET request to the `/health` endpoint and checks for a 200 response.
///
/// - Parameter baseURL: The llama.cpp base URL. Defaults to `http://localhost:8080`.
/// - Parameter timeout: Connection timeout in seconds. Defaults to 3.
/// - Returns: `true` if the llama.cpp server responded successfully.
public static func isLlamaCppRunning(
at baseURL: URL = defaultLlamaCppURL,
timeout: TimeInterval = 3
) async -> Bool {
let healthURL = baseURL.appendingPathComponent("health")
return await checkEndpointHealth(healthURL, timeout: timeout)
}
/// Checks if any endpoint at the given URL responds to HTTP requests.
///
/// - Parameters:
/// - url: The URL to probe.
/// - timeout: Connection timeout in seconds.
/// - Returns: `true` if the endpoint returned an HTTP response with status 200.
private static func checkEndpointHealth(
_ url: URL,
timeout: TimeInterval
) async -> Bool {
let config = URLSessionConfiguration.ephemeral
config.timeoutIntervalForRequest = timeout
config.timeoutIntervalForResource = timeout
let session = URLSession(configuration: config)
do {
let (_, response) = try await session.data(from: url)
if let httpResponse = response as? HTTPURLResponse {
return httpResponse.statusCode == 200
}
return false
} catch {
return false
}
}
// MARK: - Model Listing
/// Lists models available on an Ollama instance.
///
/// Calls the Ollama `/api/tags` endpoint to retrieve the list of
/// locally installed models.
///
/// - Parameter baseURL: The Ollama base URL. Defaults to `http://localhost:11434`.
/// - Parameter timeout: Request timeout in seconds. Defaults to 5.
/// - Returns: An array of model name strings, or an empty array if unreachable.
public static func listOllamaModels(
at baseURL: URL = defaultOllamaURL,
timeout: TimeInterval = 5
) async -> [String] {
let tagsURL = baseURL.appendingPathComponent("api/tags")
let config = URLSessionConfiguration.ephemeral
config.timeoutIntervalForRequest = timeout
config.timeoutIntervalForResource = timeout
let session = URLSession(configuration: config)
do {
let (data, response) = try await session.data(from: tagsURL)
guard let httpResponse = response as? HTTPURLResponse,
httpResponse.statusCode == 200
else {
return []
}
let decoded = try JSONDecoder().decode(OllamaTagsResponse.self, from: data)
return decoded.models.map(\.name)
} catch {
return []
}
}
/// Lists models available on a llama.cpp server via its OpenAI-compatible endpoint.
///
/// Calls `/v1/models` which llama.cpp exposes when running with
/// `--api-key` or in default mode.
///
/// - Parameter baseURL: The llama.cpp base URL. Defaults to `http://localhost:8080`.
/// - Parameter timeout: Request timeout in seconds. Defaults to 5.
/// - Returns: An array of model ID strings, or an empty array if unreachable.
public static func listLlamaCppModels(
at baseURL: URL = defaultLlamaCppURL,
timeout: TimeInterval = 5
) async -> [String] {
let modelsURL = baseURL.appendingPathComponent("v1/models")
let config = URLSessionConfiguration.ephemeral
config.timeoutIntervalForRequest = timeout
config.timeoutIntervalForResource = timeout
let session = URLSession(configuration: config)
do {
let (data, response) = try await session.data(from: modelsURL)
guard let httpResponse = response as? HTTPURLResponse,
httpResponse.statusCode == 200
else {
return []
}
let decoded = try JSONDecoder().decode(OpenAIModelsResponse.self, from: data)
return decoded.data.map(\.id)
} catch {
return []
}
}
// MARK: - Full Discovery
/// Discovers all running local LLM providers by probing well-known endpoints.
///
/// Probes Ollama (port 11434) and llama.cpp (port 8080) concurrently,
/// returning their status and available models.
///
/// ```swift
/// let endpoints = await LocalProviderDiscovery.discoverAll()
/// for endpoint in endpoints where endpoint.isReachable {
/// print("Found: \(endpoint.description)")
/// }
/// ```
///
/// - Parameter timeout: Connection timeout per endpoint in seconds. Defaults to 3.
/// - Returns: An array of `LocalProviderEndpoint` for each probed location.
public static func discoverAll(
timeout: TimeInterval = 3
) async -> [LocalProviderEndpoint] {
await withTaskGroup(of: LocalProviderEndpoint.self, returning: [LocalProviderEndpoint].self) { group in
for (url, providerType) in wellKnownEndpoints {
group.addTask {
await discover(providerType: providerType, at: url, timeout: timeout)
}
}
var results: [LocalProviderEndpoint] = []
for await endpoint in group {
results.append(endpoint)
}
return results
}
}
/// Discovers a specific local provider at the given URL.
///
/// - Parameters:
/// - providerType: The type of provider to probe.
/// - baseURL: The base URL to check.
/// - timeout: Connection timeout in seconds.
/// - Returns: A `LocalProviderEndpoint` with reachability and model info.
public static func discover(
providerType: LocalProviderType,
at baseURL: URL,
timeout: TimeInterval = 3
) async -> LocalProviderEndpoint {
switch providerType {
case .ollama:
let reachable = await isOllamaRunning(at: baseURL, timeout: timeout)
let models = reachable ? await listOllamaModels(at: baseURL) : []
return LocalProviderEndpoint(
baseURL: baseURL,
providerType: .ollama,
isReachable: reachable,
availableModels: models
)
case .llamaCpp:
let reachable = await isLlamaCppRunning(at: baseURL, timeout: timeout)
let models = reachable ? await listLlamaCppModels(at: baseURL) : []
return LocalProviderEndpoint(
baseURL: baseURL,
providerType: .llamaCpp,
isReachable: reachable,
availableModels: models
)
}
}
/// Discovers a specific local provider at a custom URL and port.
///
/// Use this for non-standard configurations where Ollama or llama.cpp
/// is running on a custom host or port.
///
/// ```swift
/// let endpoint = await LocalProviderDiscovery.discover(
/// providerType: .ollama,
/// host: "192.168.1.100",
/// port: 11434
/// )
/// ```
///
/// - Parameters:
/// - providerType: The provider type.
/// - host: The hostname or IP address.
/// - port: The port number.
/// - timeout: Connection timeout in seconds. Defaults to 3.
/// - Returns: A `LocalProviderEndpoint` with reachability and model info.
public static func discover(
providerType: LocalProviderType,
host: String,
port: Int,
timeout: TimeInterval = 3
) async -> LocalProviderEndpoint {
guard let url = URL(string: "http://\(host):\(port)") else {
return LocalProviderEndpoint(
baseURL: URL(string: "http://\(host):\(port)")!,
providerType: providerType,
isReachable: false,
availableModels: []
)
}
return await discover(providerType: providerType, at: url, timeout: timeout)
}
}
// MARK: - JSON Response Types
/// Response from Ollama's `/api/tags` endpoint.
private struct OllamaTagsResponse: Decodable, Sendable {
let models: [OllamaModelInfo]
}
/// Individual model info from Ollama's tags endpoint.
private struct OllamaModelInfo: Decodable, Sendable {
let name: String
}
/// Response from the OpenAI-compatible `/v1/models` endpoint.
private struct OpenAIModelsResponse: Decodable, Sendable {
let data: [OpenAIModelInfo]
}
/// Individual model info from the OpenAI models endpoint.
private struct OpenAIModelInfo: Decodable, Sendable {
let id: String
}

View File

@@ -0,0 +1,148 @@
// MutationPolicy.swift
// SwiftDBAI
//
// Defines which mutation operations are permitted and optionally restricts
// them to specific tables. Wraps OperationAllowlist with table-level granularity.
import Foundation
/// Controls which SQL mutation operations the LLM may generate and,
/// optionally, which tables those mutations may target.
///
/// `MutationPolicy` builds on ``OperationAllowlist`` by adding per-table
/// restrictions. The default policy is **read-only** no mutations are
/// allowed on any table. Write operations require explicit opt-in.
///
/// ```swift
/// // Read-only (default) only SELECT is allowed
/// let readOnly = MutationPolicy.readOnly
///
/// // Allow INSERT and UPDATE on specific tables only
/// let restricted = MutationPolicy(
/// allowedOperations: [.insert, .update],
/// allowedTables: ["orders", "order_items"]
/// )
///
/// // Allow INSERT and UPDATE on all tables
/// let broad = MutationPolicy(allowedOperations: [.insert, .update])
///
/// // Full access including DELETE (requires confirmation)
/// let full = MutationPolicy.unrestricted
/// ```
public struct MutationPolicy: Sendable, Equatable {
// MARK: - Properties
/// The underlying operation allowlist (always includes SELECT).
public let operationAllowlist: OperationAllowlist
/// Optional set of table names that mutations may target.
///
/// When `nil`, mutations are allowed on all tables (subject to
/// ``operationAllowlist``). When non-nil, mutation operations
/// (INSERT, UPDATE, DELETE) are only permitted on the listed tables.
/// SELECT queries are never restricted by this property.
public let allowedMutationTables: Set<String>?
/// When `true`, destructive operations (DELETE) require explicit user
/// confirmation before execution, even when the operation is allowed.
/// Defaults to `true`.
public let requiresDestructiveConfirmation: Bool
// MARK: - Initialization
/// Creates a mutation policy with the given operations and optional table restrictions.
///
/// SELECT is always implicitly included you cannot create a policy
/// that disallows reads.
///
/// - Parameters:
/// - allowedOperations: The mutation operations to permit (INSERT, UPDATE, DELETE).
/// SELECT is always allowed regardless of this parameter.
/// - allowedTables: Optional set of table names mutations may target.
/// Pass `nil` to allow mutations on all tables. Defaults to `nil`.
/// - requiresDestructiveConfirmation: Whether DELETE requires user confirmation.
/// Defaults to `true`.
public init(
allowedOperations: Set<SQLOperation> = [],
allowedTables: Set<String>? = nil,
requiresDestructiveConfirmation: Bool = true
) {
// Always include SELECT
var ops = allowedOperations
ops.insert(.select)
self.operationAllowlist = OperationAllowlist(ops)
self.allowedMutationTables = allowedTables
self.requiresDestructiveConfirmation = requiresDestructiveConfirmation
}
// MARK: - Presets
/// Read-only policy: only SELECT queries are allowed. This is the default.
public static let readOnly = MutationPolicy()
/// Standard read-write: SELECT, INSERT, and UPDATE on all tables.
public static let readWrite = MutationPolicy(
allowedOperations: [.insert, .update]
)
/// Unrestricted: all operations including DELETE on all tables.
/// DELETE still requires confirmation by default.
public static let unrestricted = MutationPolicy(
allowedOperations: [.insert, .update, .delete]
)
// MARK: - Validation
/// Returns `true` if the given operation is permitted by this policy.
public func isOperationAllowed(_ operation: SQLOperation) -> Bool {
operationAllowlist.isAllowed(operation)
}
/// Returns `true` if the given mutation operation is permitted on the
/// specified table.
///
/// SELECT operations always return `true` regardless of table restrictions.
/// For mutation operations, this checks both the operation allowlist and
/// the table restrictions (if any).
///
/// - Parameters:
/// - operation: The SQL operation type.
/// - table: The target table name (case-insensitive comparison).
/// - Returns: Whether the operation is allowed on the given table.
public func isAllowed(operation: SQLOperation, on table: String) -> Bool {
// SELECT is always allowed
guard operation != .select else { return true }
// Check operation allowlist first
guard operationAllowlist.isAllowed(operation) else { return false }
// If no table restrictions, the operation is allowed
guard let allowedTables = allowedMutationTables else { return true }
// Case-insensitive table name check
let lowerTable = table.lowercased()
return allowedTables.contains { $0.lowercased() == lowerTable }
}
/// Returns `true` if the given operation requires user confirmation.
public func requiresConfirmation(for operation: SQLOperation) -> Bool {
operation == .delete && requiresDestructiveConfirmation
}
/// Returns a human-readable description for inclusion in the LLM system prompt.
func describeForLLM() -> String {
var desc = operationAllowlist.describeForLLM()
if let tables = allowedMutationTables, !tables.isEmpty {
let sorted = tables.sorted()
desc += " Mutations (INSERT/UPDATE/DELETE) are restricted to these tables only: \(sorted.joined(separator: ", "))."
}
if requiresDestructiveConfirmation && operationAllowlist.isAllowed(.delete) {
desc += " DELETE operations require user confirmation before execution."
}
return desc
}
}

View File

@@ -0,0 +1,866 @@
// OnDeviceProviderConfiguration.swift
// SwiftDBAI
//
// Configuration for on-device LLM providers (CoreML, MLX) that run models
// locally on Apple silicon. These providers enable fully offline,
// privacy-sensitive deployments where no data leaves the device.
//
// Both CoreML and MLX models are provided by AnyLanguageModel behind
// conditional compilation flags (#if CoreML, #if MLX). This configuration
// layer wraps their setup with convenient factory methods and integrates
// them into the SwiftDBAI ChatEngine pipeline.
import AnyLanguageModel
import Foundation
import GRDB
// MARK: - On-Device Provider Type
/// The type of on-device LLM provider.
public enum OnDeviceProviderType: String, Sendable, Hashable, CaseIterable {
/// CoreML runs compiled .mlmodelc models on-device using Apple's CoreML framework.
/// Requires pre-compiled models and supports CPU, GPU, and Neural Engine compute units.
case coreML
/// MLX runs HuggingFace models on Apple silicon using the MLX framework.
/// Models are automatically downloaded and cached. Supports quantized models
/// (e.g., 4-bit) for efficient memory usage.
case mlx
}
// MARK: - CoreML Configuration
/// Configuration for loading and running a CoreML language model on-device.
///
/// CoreML models must be pre-compiled to `.mlmodelc` format before use.
/// The model runs entirely on-device using CPU, GPU, and/or Neural Engine
/// depending on the `computeUnits` setting.
///
/// ```swift
/// let config = CoreMLProviderConfiguration(
/// modelURL: Bundle.main.url(forResource: "MyModel", withExtension: "mlmodelc")!,
/// computeUnits: .all
/// )
/// ```
///
/// - Note: CoreML models are available behind the `#if CoreML` flag in AnyLanguageModel.
/// Ensure your project enables the CoreML build condition.
public struct CoreMLProviderConfiguration: Sendable, Equatable {
/// The URL to the compiled CoreML model (`.mlmodelc`).
public let modelURL: URL
/// The compute units to use for inference.
///
/// - `.all`: Uses the best available hardware (Neural Engine, GPU, CPU).
/// - `.cpuOnly`: Forces CPU-only inference. Useful for debugging.
/// - `.cpuAndGPU`: Uses CPU and GPU but not the Neural Engine.
/// - `.cpuAndNeuralEngine`: Uses CPU and Neural Engine.
public let computeUnits: ComputeUnitPreference
/// Maximum number of tokens the model can generate per response.
/// Defaults to 2048.
public let maxResponseTokens: Int
/// Whether to use sampling (true) or greedy decoding (false).
/// Defaults to false (greedy) for more deterministic SQL generation.
public let useSampling: Bool
/// Temperature for sampling. Only used when `useSampling` is true.
/// Lower values produce more focused output. Defaults to 0.1.
public let temperature: Double
/// Creates a CoreML provider configuration.
///
/// - Parameters:
/// - modelURL: The URL to a compiled CoreML model (`.mlmodelc`).
/// - computeUnits: The compute units to use. Defaults to `.all`.
/// - maxResponseTokens: Maximum tokens per response. Defaults to 2048.
/// - useSampling: Whether to use sampling vs greedy decoding. Defaults to false.
/// - temperature: Sampling temperature. Defaults to 0.1.
public init(
modelURL: URL,
computeUnits: ComputeUnitPreference = .all,
maxResponseTokens: Int = 2048,
useSampling: Bool = false,
temperature: Double = 0.1
) {
self.modelURL = modelURL
self.computeUnits = computeUnits
self.maxResponseTokens = maxResponseTokens
self.useSampling = useSampling
self.temperature = temperature
}
/// Validates that the model URL points to a compiled CoreML model.
///
/// - Throws: ``OnDeviceProviderError`` if the URL is invalid.
public func validate() throws {
guard modelURL.pathExtension == "mlmodelc" else {
throw OnDeviceProviderError.invalidModelFormat(
expected: ".mlmodelc",
actual: modelURL.pathExtension
)
}
guard FileManager.default.fileExists(atPath: modelURL.path) else {
throw OnDeviceProviderError.modelNotFound(modelURL)
}
}
}
/// Compute unit preference for CoreML inference.
///
/// Maps to `MLComputeUnits` in the CoreML framework.
public enum ComputeUnitPreference: String, Sendable, Hashable, CaseIterable {
/// Use all available compute units (Neural Engine, GPU, CPU).
/// This is the recommended setting for production use.
case all
/// Force CPU-only execution. Useful for debugging or testing.
case cpuOnly
/// Use CPU and GPU, but not the Neural Engine.
case cpuAndGPU
/// Use CPU and Neural Engine, but not the GPU.
case cpuAndNeuralEngine
}
// MARK: - MLX Configuration
/// Configuration for loading and running an MLX language model on Apple silicon.
///
/// MLX models are loaded from HuggingFace Hub or a local directory. The MLX
/// framework provides efficient inference on Apple silicon with support for
/// quantized models (4-bit, 8-bit) for reduced memory usage.
///
/// ```swift
/// // From HuggingFace Hub (auto-downloaded)
/// let config = MLXProviderConfiguration(
/// modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit"
/// )
///
/// // From a local directory
/// let config = MLXProviderConfiguration(
/// modelId: "my-local-model",
/// localDirectory: URL(fileURLWithPath: "/path/to/model")
/// )
/// ```
///
/// - Note: MLX models are available behind the `#if MLX` flag in AnyLanguageModel.
/// Ensure your project enables the MLX build condition.
public struct MLXProviderConfiguration: Sendable, Equatable {
/// The HuggingFace model identifier (e.g., "mlx-community/Llama-3.2-3B-Instruct-4bit").
public let modelId: String
/// Optional local directory containing the model files.
/// When set, the model is loaded from this directory instead of downloading from Hub.
public let localDirectory: URL?
/// GPU memory management configuration.
public let gpuMemory: MLXGPUMemoryConfig
/// Maximum number of tokens the model can generate per response.
/// Defaults to 2048.
public let maxResponseTokens: Int
/// Temperature for text generation. Lower values produce more deterministic output.
/// Defaults to 0.1 for SQL generation accuracy.
public let temperature: Double
/// Top-P (nucleus) sampling threshold. Only tokens with cumulative probability
/// below this threshold are considered. Defaults to 0.95.
public let topP: Double
/// Repetition penalty to reduce repetitive output. Defaults to 1.1.
public let repetitionPenalty: Double
/// Creates an MLX provider configuration.
///
/// - Parameters:
/// - modelId: The HuggingFace model ID or local identifier.
/// - localDirectory: Optional path to a local model directory.
/// - gpuMemory: GPU memory configuration. Defaults to `.automatic`.
/// - maxResponseTokens: Maximum tokens per response. Defaults to 2048.
/// - temperature: Generation temperature. Defaults to 0.1.
/// - topP: Top-P sampling threshold. Defaults to 0.95.
/// - repetitionPenalty: Repetition penalty. Defaults to 1.1.
public init(
modelId: String,
localDirectory: URL? = nil,
gpuMemory: MLXGPUMemoryConfig = .automatic,
maxResponseTokens: Int = 2048,
temperature: Double = 0.1,
topP: Double = 0.95,
repetitionPenalty: Double = 1.1
) {
self.modelId = modelId
self.localDirectory = localDirectory
self.gpuMemory = gpuMemory
self.maxResponseTokens = maxResponseTokens
self.temperature = temperature
self.topP = topP
self.repetitionPenalty = repetitionPenalty
}
/// Validates the configuration parameters.
///
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
public func validate() throws {
guard !modelId.isEmpty else {
throw OnDeviceProviderError.emptyModelId
}
if let dir = localDirectory {
guard FileManager.default.fileExists(atPath: dir.path) else {
throw OnDeviceProviderError.modelNotFound(dir)
}
}
guard temperature >= 0 else {
throw OnDeviceProviderError.invalidParameter(
name: "temperature",
value: "\(temperature)",
reason: "Must be non-negative"
)
}
guard topP > 0, topP <= 1.0 else {
throw OnDeviceProviderError.invalidParameter(
name: "topP",
value: "\(topP)",
reason: "Must be between 0 (exclusive) and 1.0 (inclusive)"
)
}
guard repetitionPenalty > 0 else {
throw OnDeviceProviderError.invalidParameter(
name: "repetitionPenalty",
value: "\(repetitionPenalty)",
reason: "Must be positive"
)
}
}
// MARK: - Well-Known Models
/// Pre-configured for Llama 3.2 3B Instruct (4-bit quantized).
/// Good balance of quality and memory usage (~2GB RAM).
public static func llama3_2_3B(
localDirectory: URL? = nil,
gpuMemory: MLXGPUMemoryConfig = .automatic
) -> MLXProviderConfiguration {
MLXProviderConfiguration(
modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit",
localDirectory: localDirectory,
gpuMemory: gpuMemory,
maxResponseTokens: 2048,
temperature: 0.1
)
}
/// Pre-configured for Qwen 2.5 Coder 3B Instruct (4-bit quantized).
/// Optimized for code and SQL generation.
public static func qwen2_5_coder_3B(
localDirectory: URL? = nil,
gpuMemory: MLXGPUMemoryConfig = .automatic
) -> MLXProviderConfiguration {
MLXProviderConfiguration(
modelId: "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit",
localDirectory: localDirectory,
gpuMemory: gpuMemory,
maxResponseTokens: 2048,
temperature: 0.05
)
}
/// Pre-configured for Phi-3.5 Mini Instruct (4-bit quantized).
/// Compact model suitable for devices with limited memory (~1.5GB RAM).
public static func phi3_5_mini(
localDirectory: URL? = nil,
gpuMemory: MLXGPUMemoryConfig = .automatic
) -> MLXProviderConfiguration {
MLXProviderConfiguration(
modelId: "mlx-community/Phi-3.5-mini-instruct-4bit",
localDirectory: localDirectory,
gpuMemory: gpuMemory,
maxResponseTokens: 2048,
temperature: 0.1
)
}
}
/// GPU memory management configuration for MLX models.
///
/// Controls how aggressively the MLX runtime manages GPU buffer caches
/// during active generation and idle phases.
public struct MLXGPUMemoryConfig: Sendable, Equatable {
/// GPU cache limit (in bytes) during active generation.
public let activeCacheLimit: Int
/// GPU cache limit (in bytes) when idle.
public let idleCacheLimit: Int
/// Whether to clear cached GPU buffers when eviction is safe.
public let clearCacheOnEviction: Bool
/// Creates a GPU memory configuration.
///
/// - Parameters:
/// - activeCacheLimit: Cache limit during active generation (bytes).
/// - idleCacheLimit: Cache limit when idle (bytes).
/// - clearCacheOnEviction: Whether to clear cache on eviction.
public init(
activeCacheLimit: Int,
idleCacheLimit: Int,
clearCacheOnEviction: Bool = true
) {
self.activeCacheLimit = activeCacheLimit
self.idleCacheLimit = idleCacheLimit
self.clearCacheOnEviction = clearCacheOnEviction
}
/// Automatically determined based on device physical memory.
///
/// - Devices with <4GB RAM: 128MB active cache
/// - Devices with <6GB RAM: 256MB active cache
/// - Devices with <8GB RAM: 512MB active cache
/// - Devices with 8GB+ RAM: 768MB active cache
/// - Idle cache: 50MB for all devices
public static var automatic: MLXGPUMemoryConfig {
let ramBytes = ProcessInfo.processInfo.physicalMemory
let ramGB = ramBytes / (1024 * 1024 * 1024)
let active: Int
switch ramGB {
case ..<4:
active = 128_000_000
case ..<6:
active = 256_000_000
case ..<8:
active = 512_000_000
default:
active = 768_000_000
}
return .init(
activeCacheLimit: active,
idleCacheLimit: 50_000_000,
clearCacheOnEviction: true
)
}
/// Minimal memory configuration for constrained devices.
/// Uses 64MB active cache and 16MB idle cache.
public static var minimal: MLXGPUMemoryConfig {
.init(
activeCacheLimit: 64_000_000,
idleCacheLimit: 16_000_000,
clearCacheOnEviction: true
)
}
/// Unconstrained configuration for maximum performance.
/// Leaves GPU cache effectively unlimited. Use when your app
/// can afford maximum memory usage.
public static var unconstrained: MLXGPUMemoryConfig {
.init(
activeCacheLimit: Int.max,
idleCacheLimit: Int.max,
clearCacheOnEviction: false
)
}
}
// MARK: - On-Device Provider Errors
/// Errors specific to on-device provider configuration and model loading.
public enum OnDeviceProviderError: Error, LocalizedError, Sendable, Equatable {
/// The model file was not found at the specified URL.
case modelNotFound(URL)
/// The model file format is not what was expected.
case invalidModelFormat(expected: String, actual: String)
/// The model ID is empty.
case emptyModelId
/// A configuration parameter is invalid.
case invalidParameter(name: String, value: String, reason: String)
/// The on-device provider is not available on this platform.
/// CoreML requires macOS 15+ / iOS 18+. MLX requires the MLX build flag.
case providerUnavailable(OnDeviceProviderType, reason: String)
/// Model loading failed with an underlying error.
case modelLoadFailed(reason: String)
/// Model inference failed.
case inferenceFailed(reason: String)
public var errorDescription: String? {
switch self {
case .modelNotFound(let url):
return "On-device model not found at: \(url.path)"
case .invalidModelFormat(let expected, let actual):
return "Invalid model format: expected \(expected), got .\(actual)"
case .emptyModelId:
return "Model ID must not be empty"
case .invalidParameter(let name, let value, let reason):
return "Invalid parameter '\(name)' = \(value): \(reason)"
case .providerUnavailable(let type, let reason):
return "\(type.rawValue) provider unavailable: \(reason)"
case .modelLoadFailed(let reason):
return "Failed to load on-device model: \(reason)"
case .inferenceFailed(let reason):
return "On-device inference failed: \(reason)"
}
}
}
// MARK: - On-Device Inference Pipeline
/// Manages the on-device model inference pipeline.
///
/// `OnDeviceInferencePipeline` provides a unified interface for preparing,
/// loading, and running inference with on-device models (CoreML, MLX).
/// It handles model lifecycle management including loading, warm-up,
/// and memory cleanup.
///
/// ```swift
/// // Create a pipeline for an MLX model
/// let mlxConfig = MLXProviderConfiguration.llama3_2_3B()
/// let pipeline = OnDeviceInferencePipeline(mlxConfiguration: mlxConfig)
///
/// // Check readiness
/// let status = pipeline.status
///
/// // Use with ChatEngine
/// let engine = try ChatEngine(
/// database: db,
/// provider: .onDevice(mlx: mlxConfig)
/// )
/// ```
public final class OnDeviceInferencePipeline: @unchecked Sendable {
/// The current status of the on-device inference pipeline.
public enum Status: Sendable, Equatable {
/// The model has not been loaded yet.
case notLoaded
/// The model is currently being loaded/downloaded.
case loading
/// The model is loaded and ready for inference.
case ready
/// The model failed to load.
case failed(String)
}
/// The type of on-device provider this pipeline uses.
public let providerType: OnDeviceProviderType
/// The MLX configuration, if this is an MLX pipeline.
public let mlxConfiguration: MLXProviderConfiguration?
/// The CoreML configuration, if this is a CoreML pipeline.
public let coreMLConfiguration: CoreMLProviderConfiguration?
/// The current status of the pipeline.
private let _statusLock = NSLock()
private var _status: Status = .notLoaded
/// The current pipeline status.
public var status: Status {
_statusLock.lock()
defer { _statusLock.unlock() }
return _status
}
/// Creates an MLX inference pipeline.
///
/// - Parameter configuration: The MLX model configuration.
public init(mlxConfiguration: MLXProviderConfiguration) {
self.providerType = .mlx
self.mlxConfiguration = mlxConfiguration
self.coreMLConfiguration = nil
}
/// Creates a CoreML inference pipeline.
///
/// - Parameter configuration: The CoreML model configuration.
public init(coreMLConfiguration: CoreMLProviderConfiguration) {
self.providerType = .coreML
self.coreMLConfiguration = coreMLConfiguration
self.mlxConfiguration = nil
}
/// Validates the configuration before attempting to load.
///
/// Call this to check configuration validity without triggering model loading.
///
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
public func validateConfiguration() throws {
switch providerType {
case .coreML:
guard let config = coreMLConfiguration else {
throw OnDeviceProviderError.providerUnavailable(
.coreML,
reason: "No CoreML configuration provided"
)
}
try config.validate()
case .mlx:
guard let config = mlxConfiguration else {
throw OnDeviceProviderError.providerUnavailable(
.mlx,
reason: "No MLX configuration provided"
)
}
try config.validate()
}
}
/// Updates the pipeline status.
internal func setStatus(_ newStatus: Status) {
_statusLock.lock()
_status = newStatus
_statusLock.unlock()
}
/// Provides recommended generation options optimized for SQL generation
/// based on the pipeline's configuration.
///
/// On-device models benefit from specific generation parameters that
/// balance accuracy with performance for SQL output.
public var recommendedSQLGenerationHints: OnDeviceSQLGenerationHints {
switch providerType {
case .coreML:
let config = coreMLConfiguration ?? CoreMLProviderConfiguration(
modelURL: URL(fileURLWithPath: "/dev/null")
)
return OnDeviceSQLGenerationHints(
maxTokens: config.maxResponseTokens,
temperature: config.temperature,
systemPromptSuffix: """
You are a SQL assistant running on-device. Generate only valid SQLite SQL.
Be concise — output ONLY the SQL query with no explanation.
""",
useSampling: config.useSampling
)
case .mlx:
let config = mlxConfiguration ?? .llama3_2_3B()
return OnDeviceSQLGenerationHints(
maxTokens: config.maxResponseTokens,
temperature: config.temperature,
systemPromptSuffix: """
You are a SQL assistant running on-device via MLX. Generate only valid SQLite SQL.
Be concise — output ONLY the SQL query with no explanation.
""",
useSampling: true
)
}
}
}
/// Hints for optimizing SQL generation with on-device models.
///
/// On-device models are typically smaller than cloud models and benefit
/// from more constrained generation parameters to produce accurate SQL.
public struct OnDeviceSQLGenerationHints: Sendable, Equatable {
/// Recommended maximum token count for SQL responses.
public let maxTokens: Int
/// Recommended temperature for SQL generation.
public let temperature: Double
/// Additional system prompt text optimized for on-device SQL generation.
public let systemPromptSuffix: String
/// Whether to use sampling or greedy decoding.
public let useSampling: Bool
}
// MARK: - ProviderConfiguration Extension
extension ProviderConfiguration {
/// Creates a configuration for an on-device MLX model.
///
/// MLX models run entirely on Apple silicon using the MLX framework.
/// Models are automatically downloaded from HuggingFace Hub on first use.
///
/// ```swift
/// // Using a pre-configured model
/// let config = ProviderConfiguration.onDeviceMLX(.llama3_2_3B())
///
/// // Using a custom model
/// let config = ProviderConfiguration.onDeviceMLX(
/// MLXProviderConfiguration(
/// modelId: "mlx-community/Qwen2.5-7B-Instruct-4bit",
/// temperature: 0.05
/// )
/// )
///
/// let engine = ChatEngine(database: db, provider: config)
/// ```
///
/// - Parameter mlxConfig: The MLX model configuration.
/// - Returns: A configured `ProviderConfiguration` that wraps the MLX model.
///
/// - Note: The returned configuration uses `.openAICompatible` as the provider
/// type internally. The actual model is created via MLX APIs when `#if MLX` is
/// available. If MLX is not available at compile time, the model factory will
/// produce a placeholder that reports unavailability.
public static func onDeviceMLX(
_ mlxConfig: MLXProviderConfiguration
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAICompatible,
model: mlxConfig.modelId,
apiKeyProvider: { "" },
baseURL: nil,
apiVersion: nil,
betas: nil,
openAIVariant: nil
)
}
/// Creates a configuration for an on-device CoreML model.
///
/// CoreML models must be pre-compiled to `.mlmodelc` format.
/// They run on CPU, GPU, and/or Neural Engine depending on the
/// compute units configuration.
///
/// ```swift
/// let modelURL = Bundle.main.url(forResource: "SQLModel", withExtension: "mlmodelc")!
/// let config = ProviderConfiguration.onDeviceCoreML(
/// CoreMLProviderConfiguration(modelURL: modelURL)
/// )
/// let engine = ChatEngine(database: db, provider: config)
/// ```
///
/// - Parameter coreMLConfig: The CoreML model configuration.
/// - Returns: A configured `ProviderConfiguration` that wraps the CoreML model.
///
/// - Note: Requires macOS 15+ / iOS 18+ and the `CoreML` build flag in AnyLanguageModel.
public static func onDeviceCoreML(
_ coreMLConfig: CoreMLProviderConfiguration
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAICompatible,
model: coreMLConfig.modelURL.lastPathComponent,
apiKeyProvider: { "" },
baseURL: nil,
apiVersion: nil,
betas: nil,
openAIVariant: nil
)
}
}
// MARK: - ChatEngine On-Device Convenience
extension ChatEngine {
/// Creates a ChatEngine with an on-device MLX model.
///
/// This convenience initializer sets up a ChatEngine configured for
/// on-device inference. It validates the MLX configuration and creates
/// an inference pipeline.
///
/// ```swift
/// let engine = try ChatEngine.onDevice(
/// database: db,
/// mlx: .llama3_2_3B()
/// )
/// let response = try await engine.send("How many users are there?")
/// ```
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - mlx: The MLX model configuration.
/// - allowlist: SQL operations allowed. Defaults to read-only.
/// - configuration: Engine configuration.
/// - Returns: A configured `ChatEngine` instance.
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
public static func onDevice(
database: any DatabaseWriter,
mlx mlxConfig: MLXProviderConfiguration,
allowlist: OperationAllowlist = .readOnly,
configuration: ChatEngineConfiguration = .default
) throws -> ChatEngine {
// Validate configuration
try mlxConfig.validate()
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: mlxConfig)
// Build a ChatEngineConfiguration that includes on-device hints
var engineConfig = configuration
let hints = pipeline.recommendedSQLGenerationHints
if engineConfig.additionalContext == nil {
engineConfig.additionalContext = hints.systemPromptSuffix
} else {
engineConfig.additionalContext! += "\n\n" + hints.systemPromptSuffix
}
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
return ChatEngine(
database: database,
provider: providerConfig,
allowlist: allowlist,
configuration: engineConfig
)
}
/// Creates a ChatEngine with an on-device CoreML model.
///
/// ```swift
/// let modelURL = Bundle.main.url(forResource: "SQLModel", withExtension: "mlmodelc")!
/// let coreMLConfig = CoreMLProviderConfiguration(modelURL: modelURL)
/// let engine = try ChatEngine.onDevice(
/// database: db,
/// coreML: coreMLConfig
/// )
/// ```
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - coreML: The CoreML model configuration.
/// - allowlist: SQL operations allowed. Defaults to read-only.
/// - configuration: Engine configuration.
/// - Returns: A configured `ChatEngine` instance.
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
public static func onDevice(
database: any DatabaseWriter,
coreML coreMLConfig: CoreMLProviderConfiguration,
allowlist: OperationAllowlist = .readOnly,
configuration: ChatEngineConfiguration = .default
) throws -> ChatEngine {
// Validate configuration
try coreMLConfig.validate()
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: coreMLConfig)
var engineConfig = configuration
let hints = pipeline.recommendedSQLGenerationHints
if engineConfig.additionalContext == nil {
engineConfig.additionalContext = hints.systemPromptSuffix
} else {
engineConfig.additionalContext! += "\n\n" + hints.systemPromptSuffix
}
let providerConfig = ProviderConfiguration.onDeviceCoreML(coreMLConfig)
return ChatEngine(
database: database,
provider: providerConfig,
allowlist: allowlist,
configuration: engineConfig
)
}
}
// MARK: - Model Readiness Checker
/// Utility for checking on-device model availability and system capability.
public enum OnDeviceModelReadiness {
/// System capability information for on-device inference.
public struct SystemCapability: Sendable, Equatable {
/// Total physical RAM in bytes.
public let totalRAM: UInt64
/// Whether the device has sufficient RAM for typical on-device models.
/// Generally requires at least 4GB for 3B parameter models.
public let hasSufficientRAM: Bool
/// Whether Apple Neural Engine is likely available.
/// True on devices with Apple silicon.
public let hasNeuralEngine: Bool
/// Recommended model size category based on available RAM.
public let recommendedModelSize: RecommendedModelSize
}
/// Recommended model size based on device capabilities.
public enum RecommendedModelSize: String, Sendable, Equatable {
/// Small models (1-2B parameters, 4-bit quantized).
/// Suitable for devices with 4GB RAM.
case small
/// Medium models (3-4B parameters, 4-bit quantized).
/// Suitable for devices with 6-8GB RAM.
case medium
/// Large models (7-8B parameters, 4-bit quantized).
/// Suitable for devices with 16GB+ RAM.
case large
}
/// Checks the current device's capability for on-device inference.
///
/// ```swift
/// let capability = OnDeviceModelReadiness.checkSystemCapability()
/// if capability.hasSufficientRAM {
/// print("Recommended size: \(capability.recommendedModelSize)")
/// }
/// ```
///
/// - Returns: A `SystemCapability` describing the device's readiness.
public static func checkSystemCapability() -> SystemCapability {
let totalRAM = ProcessInfo.processInfo.physicalMemory
let ramGB = totalRAM / (1024 * 1024 * 1024)
let recommendedSize: RecommendedModelSize
switch ramGB {
case ..<4:
recommendedSize = .small
case ..<8:
recommendedSize = .medium
default:
recommendedSize = .large
}
return SystemCapability(
totalRAM: totalRAM,
hasSufficientRAM: ramGB >= 4,
hasNeuralEngine: hasAppleSilicon(),
recommendedModelSize: recommendedSize
)
}
/// Suggests an MLX model configuration based on system capabilities.
///
/// ```swift
/// let config = OnDeviceModelReadiness.suggestedMLXModel()
/// let engine = try ChatEngine.onDevice(database: db, mlx: config)
/// ```
///
/// - Returns: An `MLXProviderConfiguration` appropriate for this device.
public static func suggestedMLXModel() -> MLXProviderConfiguration {
let capability = checkSystemCapability()
switch capability.recommendedModelSize {
case .small:
return .phi3_5_mini()
case .medium:
return .llama3_2_3B()
case .large:
return .qwen2_5_coder_3B()
}
}
/// Checks if the current device uses Apple silicon.
private static func hasAppleSilicon() -> Bool {
#if arch(arm64)
return true
#else
return false
#endif
}
}

View File

@@ -0,0 +1,54 @@
/// Defines which SQL operations the LLM is permitted to generate.
///
/// The default is ``readOnly`` (SELECT only). Write operations require
/// explicit opt-in. This is the safety-by-default principle.
public struct OperationAllowlist: Sendable, Equatable {
/// The set of permitted SQL operation types.
public let allowedOperations: Set<SQLOperation>
/// Creates an allowlist from the given set of operations.
public init(_ operations: Set<SQLOperation>) {
self.allowedOperations = operations
}
/// Read-only: only SELECT queries are permitted. This is the default.
public static let readOnly = OperationAllowlist([.select])
/// Standard read-write: SELECT, INSERT, and UPDATE are permitted.
public static let standard = OperationAllowlist([.select, .insert, .update])
/// Unrestricted: all operations including DELETE are permitted.
/// DELETE still requires confirmation via `ToolExecutionDelegate`.
public static let unrestricted = OperationAllowlist([.select, .insert, .update, .delete])
/// Returns true if the given operation is allowed.
public func isAllowed(_ operation: SQLOperation) -> Bool {
allowedOperations.contains(operation)
}
/// Returns a human-readable description of what's allowed, for inclusion
/// in the LLM system prompt.
func describeForLLM() -> String {
if allowedOperations == [.select] {
return "You may ONLY generate SELECT queries. No data modifications are allowed."
}
let sorted = allowedOperations.sorted { $0.rawValue < $1.rawValue }
let names = sorted.map { $0.rawValue.uppercased() }
var desc = "Allowed SQL operations: \(names.joined(separator: ", "))."
if allowedOperations.contains(.delete) {
desc += " DELETE operations are destructive and require user confirmation before execution."
}
return desc
}
}
/// The types of SQL operations that can be controlled via the allowlist.
public enum SQLOperation: String, Sendable, Hashable, CaseIterable {
case select
case insert
case update
case delete
}

View File

@@ -0,0 +1,609 @@
// ProviderConfiguration.swift
// SwiftDBAI
//
// Unified provider configuration for cloud-based LLM providers.
// Wraps AnyLanguageModel provider types with convenient factory methods.
import AnyLanguageModel
import Foundation
import GRDB
/// Configuration for connecting to a cloud-based LLM provider.
///
/// `ProviderConfiguration` provides a unified way to configure any supported
/// LLM provider (OpenAI, Anthropic, Gemini, or OpenAI-compatible services).
/// Each configuration produces a properly configured `LanguageModel` instance
/// that works with ``ChatEngine`` and ``TextSummaryRenderer``.
///
/// ## Quick Start
///
/// ```swift
/// // OpenAI
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
///
/// // Anthropic
/// let config = ProviderConfiguration.anthropic(apiKey: "sk-ant-...", model: "claude-sonnet-4-20250514")
///
/// // Gemini
/// let config = ProviderConfiguration.gemini(apiKey: "AIza...", model: "gemini-2.0-flash")
///
/// // Use with ChatEngine
/// let engine = ChatEngine(database: db, model: config.makeModel())
/// ```
///
/// ## API Key Handling
///
/// API keys are stored as closures to support both static strings and
/// dynamic retrieval from keychains, environment variables, or secure storage:
///
/// ```swift
/// // Static key
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
///
/// // Dynamic key from environment
/// let config = ProviderConfiguration.openAI(
/// apiKeyProvider: { ProcessInfo.processInfo.environment["OPENAI_API_KEY"] ?? "" },
/// model: "gpt-4o"
/// )
/// ```
public struct ProviderConfiguration: Sendable {
/// The supported LLM provider types.
public enum Provider: String, Sendable, Hashable, CaseIterable {
/// OpenAI's GPT models via the Chat Completions or Responses API.
case openAI
/// Anthropic's Claude models.
case anthropic
/// Google's Gemini models.
case gemini
/// Any OpenAI-compatible API (e.g., local servers, third-party providers).
case openAICompatible
/// Ollama local models via `ollama serve`.
/// Default endpoint: http://localhost:11434
case ollama
/// llama.cpp server local GGUF models via `llama-server`.
/// Default endpoint: http://localhost:8080
/// Uses the OpenAI-compatible API.
case llamaCpp
}
/// The provider type for this configuration.
public let provider: Provider
/// The model identifier (e.g., "gpt-4o", "claude-sonnet-4-20250514", "gemini-2.0-flash").
public let model: String
/// A closure that provides the API key on demand.
///
/// Using a closure allows lazy evaluation and integration with secure
/// storage systems (Keychain, environment variables, etc.).
private let apiKeyProvider: @Sendable () -> String
/// Optional custom base URL for OpenAI-compatible providers.
public let baseURL: URL?
/// Optional API version override (used by Anthropic and Gemini).
public let apiVersion: String?
/// Optional beta headers (used by Anthropic).
public let betas: [String]?
/// The OpenAI API variant to use (Chat Completions or Responses).
public let openAIVariant: OpenAILanguageModel.APIVariant?
// MARK: - Internal Init
/// Internal memberwise initializer used by factory methods.
internal init(
provider: Provider,
model: String,
apiKeyProvider: @escaping @Sendable () -> String,
baseURL: URL?,
apiVersion: String?,
betas: [String]?,
openAIVariant: OpenAILanguageModel.APIVariant?
) {
self.provider = provider
self.model = model
self.apiKeyProvider = apiKeyProvider
self.baseURL = baseURL
self.apiVersion = apiVersion
self.betas = betas
self.openAIVariant = openAIVariant
}
// MARK: - Factory Methods
/// Creates a configuration for OpenAI's API.
///
/// - Parameters:
/// - apiKey: Your OpenAI API key (e.g., "sk-...").
/// - model: The model identifier (e.g., "gpt-4o", "gpt-4o-mini").
/// - variant: The API variant to use. Defaults to `.chatCompletions`.
/// - baseURL: Optional custom base URL. Defaults to OpenAI's API.
/// - Returns: A configured `ProviderConfiguration`.
public static func openAI(
apiKey: String,
model: String,
variant: OpenAILanguageModel.APIVariant = .chatCompletions,
baseURL: URL? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAI,
model: model,
apiKeyProvider: { apiKey },
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: variant
)
}
/// Creates a configuration for OpenAI's API with a dynamic key provider.
///
/// Use this when the API key comes from a keychain, environment variable,
/// or other dynamic source.
///
/// - Parameters:
/// - apiKeyProvider: A closure that returns the API key.
/// - model: The model identifier.
/// - variant: The API variant to use. Defaults to `.chatCompletions`.
/// - baseURL: Optional custom base URL.
/// - Returns: A configured `ProviderConfiguration`.
public static func openAI(
apiKeyProvider: @escaping @Sendable () -> String,
model: String,
variant: OpenAILanguageModel.APIVariant = .chatCompletions,
baseURL: URL? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAI,
model: model,
apiKeyProvider: apiKeyProvider,
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: variant
)
}
/// Creates a configuration for Anthropic's Claude API.
///
/// - Parameters:
/// - apiKey: Your Anthropic API key (e.g., "sk-ant-...").
/// - model: The model identifier (e.g., "claude-sonnet-4-20250514").
/// - apiVersion: Optional API version override.
/// - betas: Optional beta feature headers.
/// - Returns: A configured `ProviderConfiguration`.
public static func anthropic(
apiKey: String,
model: String,
apiVersion: String? = nil,
betas: [String]? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .anthropic,
model: model,
apiKeyProvider: { apiKey },
baseURL: nil,
apiVersion: apiVersion,
betas: betas,
openAIVariant: nil
)
}
/// Creates a configuration for Anthropic's Claude API with a dynamic key provider.
///
/// - Parameters:
/// - apiKeyProvider: A closure that returns the API key.
/// - model: The model identifier.
/// - apiVersion: Optional API version override.
/// - betas: Optional beta feature headers.
/// - Returns: A configured `ProviderConfiguration`.
public static func anthropic(
apiKeyProvider: @escaping @Sendable () -> String,
model: String,
apiVersion: String? = nil,
betas: [String]? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .anthropic,
model: model,
apiKeyProvider: apiKeyProvider,
baseURL: nil,
apiVersion: apiVersion,
betas: betas,
openAIVariant: nil
)
}
/// Creates a configuration for Google's Gemini API.
///
/// - Parameters:
/// - apiKey: Your Gemini API key (e.g., "AIza...").
/// - model: The model identifier (e.g., "gemini-2.0-flash").
/// - apiVersion: Optional API version override (defaults to "v1beta").
/// - Returns: A configured `ProviderConfiguration`.
public static func gemini(
apiKey: String,
model: String,
apiVersion: String? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .gemini,
model: model,
apiKeyProvider: { apiKey },
baseURL: nil,
apiVersion: apiVersion,
betas: nil,
openAIVariant: nil
)
}
/// Creates a configuration for Google's Gemini API with a dynamic key provider.
///
/// - Parameters:
/// - apiKeyProvider: A closure that returns the API key.
/// - model: The model identifier.
/// - apiVersion: Optional API version override.
/// - Returns: A configured `ProviderConfiguration`.
public static func gemini(
apiKeyProvider: @escaping @Sendable () -> String,
model: String,
apiVersion: String? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .gemini,
model: model,
apiKeyProvider: apiKeyProvider,
baseURL: nil,
apiVersion: apiVersion,
betas: nil,
openAIVariant: nil
)
}
/// Creates a configuration for any OpenAI-compatible API.
///
/// Use this for third-party services that implement the OpenAI Chat Completions
/// API (e.g., local LLM servers, Groq, Together AI, etc.).
///
/// ```swift
/// let config = ProviderConfiguration.openAICompatible(
/// apiKey: "your-key",
/// model: "llama-3.1-70b",
/// baseURL: URL(string: "https://api.together.xyz/v1/")!
/// )
/// ```
///
/// - Parameters:
/// - apiKey: The API key for the service.
/// - model: The model identifier.
/// - baseURL: The base URL of the compatible API.
/// - variant: The API variant. Defaults to `.chatCompletions`.
/// - Returns: A configured `ProviderConfiguration`.
public static func openAICompatible(
apiKey: String,
model: String,
baseURL: URL,
variant: OpenAILanguageModel.APIVariant = .chatCompletions
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAICompatible,
model: model,
apiKeyProvider: { apiKey },
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: variant
)
}
/// Creates a configuration for any OpenAI-compatible API with a dynamic key provider.
///
/// - Parameters:
/// - apiKeyProvider: A closure that returns the API key.
/// - model: The model identifier.
/// - baseURL: The base URL of the compatible API.
/// - variant: The API variant. Defaults to `.chatCompletions`.
/// - Returns: A configured `ProviderConfiguration`.
public static func openAICompatible(
apiKeyProvider: @escaping @Sendable () -> String,
model: String,
baseURL: URL,
variant: OpenAILanguageModel.APIVariant = .chatCompletions
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAICompatible,
model: model,
apiKeyProvider: apiKeyProvider,
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: variant
)
}
// MARK: - Local Provider Factory Methods
/// Creates a configuration for a local Ollama instance.
///
/// Ollama runs models locally and exposes a native API on port 11434.
/// No API key is required by default.
///
/// ```swift
/// // Default local Ollama
/// let config = ProviderConfiguration.ollama(model: "llama3.2")
///
/// // Ollama on a remote machine
/// let config = ProviderConfiguration.ollama(
/// model: "qwen2.5",
/// baseURL: URL(string: "http://192.168.1.100:11434")!
/// )
///
/// // Use with ChatEngine
/// let engine = ChatEngine(database: db, provider: config)
/// ```
///
/// - Parameters:
/// - model: The Ollama model name (e.g., "llama3.2", "qwen2.5", "mistral").
/// - baseURL: The Ollama server URL. Defaults to `http://localhost:11434`.
/// - Returns: A configured `ProviderConfiguration`.
public static func ollama(
model: String,
baseURL: URL = OllamaLanguageModel.defaultBaseURL
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .ollama,
model: model,
apiKeyProvider: { "" },
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: nil
)
}
/// Creates a configuration for a local llama.cpp server.
///
/// llama.cpp's `llama-server` exposes an OpenAI-compatible Chat Completions
/// API, typically on port 8080. No API key is required by default.
///
/// ```swift
/// // Default local llama.cpp
/// let config = ProviderConfiguration.llamaCpp(model: "default")
///
/// // llama.cpp on a custom port with API key
/// let config = ProviderConfiguration.llamaCpp(
/// model: "my-model",
/// baseURL: URL(string: "http://localhost:9090")!,
/// apiKey: "my-secret-key"
/// )
/// ```
///
/// - Parameters:
/// - model: The model identifier. Use "default" if llama-server loads a single model.
/// - baseURL: The llama.cpp server URL. Defaults to `http://localhost:8080`.
/// - apiKey: Optional API key if the server requires authentication.
/// - Returns: A configured `ProviderConfiguration`.
public static func llamaCpp(
model: String = "default",
baseURL: URL = LocalProviderDiscovery.defaultLlamaCppURL,
apiKey: String = ""
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .llamaCpp,
model: model,
apiKeyProvider: { apiKey },
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: .chatCompletions
)
}
// MARK: - Model Construction
/// Creates a configured `LanguageModel` instance for this provider.
///
/// This is the primary way to get a model from a configuration.
/// The returned model is ready to use with ``ChatEngine`` or
/// ``TextSummaryRenderer``.
///
/// ```swift
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
/// let engine = ChatEngine(database: db, model: config.makeModel())
/// ```
///
/// - Returns: A configured `LanguageModel` instance.
public func makeModel() -> any LanguageModel {
let key = apiKeyProvider
switch provider {
case .openAI:
if let baseURL {
return OpenAILanguageModel(
baseURL: baseURL,
apiKey: key(),
model: model,
apiVariant: openAIVariant ?? .chatCompletions
)
}
return OpenAILanguageModel(
apiKey: key(),
model: model,
apiVariant: openAIVariant ?? .chatCompletions
)
case .anthropic:
if let apiVersion {
return AnthropicLanguageModel(
apiKey: key(),
apiVersion: apiVersion,
betas: betas,
model: model
)
}
if let betas {
return AnthropicLanguageModel(
apiKey: key(),
betas: betas,
model: model
)
}
return AnthropicLanguageModel(
apiKey: key(),
model: model
)
case .gemini:
if let apiVersion {
return GeminiLanguageModel(
apiKey: key(),
apiVersion: apiVersion,
model: model
)
}
return GeminiLanguageModel(
apiKey: key(),
model: model
)
case .openAICompatible:
return OpenAILanguageModel(
baseURL: baseURL ?? OpenAILanguageModel.defaultBaseURL,
apiKey: key(),
model: model,
apiVariant: openAIVariant ?? .chatCompletions
)
case .ollama:
return OllamaLanguageModel(
baseURL: baseURL ?? OllamaLanguageModel.defaultBaseURL,
model: model
)
case .llamaCpp:
// llama.cpp exposes an OpenAI-compatible API
return OpenAILanguageModel(
baseURL: baseURL ?? LocalProviderDiscovery.defaultLlamaCppURL,
apiKey: key(),
model: model,
apiVariant: openAIVariant ?? .chatCompletions
)
}
}
// MARK: - API Key Access
/// Returns the current API key.
///
/// Useful for validation or debugging. In production, prefer using
/// ``makeModel()`` which handles key injection automatically.
public var apiKey: String {
apiKeyProvider()
}
/// Returns `true` if the API key is non-empty.
///
/// Use this to check configuration validity before creating an engine:
/// ```swift
/// guard config.hasValidAPIKey else {
/// // Show API key setup UI
/// return
/// }
/// ```
public var hasValidAPIKey: Bool {
!apiKeyProvider().trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
}
// MARK: - Environment Variable Helpers
/// Creates a configuration using an API key from an environment variable.
///
/// Falls back to an empty string if the environment variable is not set,
/// which will cause API calls to fail with an authentication error.
///
/// ```swift
/// let config = ProviderConfiguration.fromEnvironment(
/// provider: .openAI,
/// environmentVariable: "OPENAI_API_KEY",
/// model: "gpt-4o"
/// )
/// ```
///
/// - Parameters:
/// - provider: The LLM provider.
/// - environmentVariable: The name of the environment variable holding the API key.
/// - model: The model identifier.
/// - Returns: A configured `ProviderConfiguration`.
public static func fromEnvironment(
provider: Provider,
environmentVariable: String,
model: String
) -> ProviderConfiguration {
let keyProvider: @Sendable () -> String = {
ProcessInfo.processInfo.environment[environmentVariable] ?? ""
}
switch provider {
case .openAI:
return .openAI(apiKeyProvider: keyProvider, model: model)
case .anthropic:
return .anthropic(apiKeyProvider: keyProvider, model: model)
case .gemini:
return .gemini(apiKeyProvider: keyProvider, model: model)
case .openAICompatible:
return .openAICompatible(
apiKeyProvider: keyProvider,
model: model,
baseURL: OpenAILanguageModel.defaultBaseURL
)
case .ollama:
return .ollama(model: model)
case .llamaCpp:
return .llamaCpp(model: model)
}
}
}
// MARK: - ChatEngine Convenience Init
extension ChatEngine {
/// Creates a ChatEngine using a ``ProviderConfiguration``.
///
/// This is the most convenient way to set up a ChatEngine with a
/// cloud provider:
///
/// ```swift
/// let engine = ChatEngine(
/// database: myDB,
/// provider: .openAI(apiKey: "sk-...", model: "gpt-4o")
/// )
/// ```
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - provider: The provider configuration.
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only.
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
public convenience init(
database: any DatabaseWriter,
provider: ProviderConfiguration,
allowlist: OperationAllowlist = .readOnly,
configuration: ChatEngineConfiguration = .default
) {
self.init(
database: database,
model: provider.makeModel(),
allowlist: allowlist,
configuration: configuration
)
}
}

View File

@@ -0,0 +1,114 @@
// QueryValidator.swift
// SwiftDBAI
//
// Extensible query validation protocol for custom pre-execution checks.
import Foundation
/// A protocol for custom SQL query validation.
///
/// Implement this protocol to add domain-specific validation rules that run
/// after the built-in allowlist and safety checks. Validators receive the
/// parsed SQL string and its detected operation type.
///
/// Example restrict queries to specific tables:
/// ```swift
/// struct TableAllowlistValidator: QueryValidator {
/// let allowedTables: Set<String>
///
/// func validate(sql: String, operation: SQLOperation) throws {
/// let upper = sql.uppercased()
/// for table in allowedTables {
/// // Simple check real implementation might parse FROM/JOIN clauses
/// if upper.contains(table.uppercased()) { return }
/// }
/// throw QueryValidationError.rejected("Query references tables outside the allowlist.")
/// }
/// }
/// ```
public protocol QueryValidator: Sendable {
/// Validates a SQL query before execution.
///
/// - Parameters:
/// - sql: The cleaned SQL statement about to be executed.
/// - operation: The detected operation type (SELECT, INSERT, etc.).
/// - Throws: ``QueryValidationError`` or any `Error` to reject the query.
func validate(sql: String, operation: SQLOperation) throws
}
/// Errors thrown by custom ``QueryValidator`` implementations.
public enum QueryValidationError: Error, LocalizedError, Sendable, Equatable {
/// The query was rejected by a custom validator with the given reason.
case rejected(String)
public var errorDescription: String? {
switch self {
case .rejected(let reason):
return "Query rejected: \(reason)"
}
}
}
// MARK: - Built-in Validators
/// A validator that restricts queries to a specific set of table names.
///
/// This performs a simple keyword check it verifies that the SQL references
/// at least one of the allowed tables. This is a best-effort check, not a
/// full SQL parser.
public struct TableAllowlistValidator: QueryValidator {
/// The set of table names queries are allowed to reference.
public let allowedTables: Set<String>
/// Creates a validator with the given allowed table names.
public init(allowedTables: Set<String>) {
self.allowedTables = allowedTables
}
public func validate(sql: String, operation: SQLOperation) throws {
let upper = sql.uppercased()
let found = allowedTables.contains { table in
let pattern = table.uppercased()
return upper.contains(pattern)
}
guard found else {
throw QueryValidationError.rejected(
"Query does not reference any allowed tables: \(allowedTables.sorted().joined(separator: ", "))"
)
}
}
}
/// A validator that enforces a maximum row limit on SELECT queries
/// by checking for a LIMIT clause.
public struct MaxRowLimitValidator: QueryValidator {
/// The maximum number of rows allowed.
public let maxRows: Int
/// Creates a validator that requires SELECT queries to include a LIMIT clause
/// not exceeding `maxRows`.
public init(maxRows: Int) {
self.maxRows = maxRows
}
public func validate(sql: String, operation: SQLOperation) throws {
guard operation == .select else { return }
let upper = sql.uppercased()
// Check if LIMIT is present
guard let limitRange = upper.range(of: #"LIMIT\s+(\d+)"#, options: .regularExpression) else {
throw QueryValidationError.rejected(
"SELECT queries must include a LIMIT clause (max \(maxRows) rows)."
)
}
// Extract the limit value
let limitSubstring = upper[limitRange]
let digits = limitSubstring.components(separatedBy: .decimalDigits.inverted).joined()
if let value = Int(digits), value > maxRows {
throw QueryValidationError.rejected(
"LIMIT \(value) exceeds the maximum allowed (\(maxRows))."
)
}
}
}

View File

@@ -0,0 +1,677 @@
// ChatEngine.swift
// SwiftDBAI
//
// Orchestrates the conversation loop: user message SQL generation query
// execution result summarization response.
import AnyLanguageModel
import Foundation
import GRDB
/// A message in the chat conversation.
public struct ChatMessage: Sendable, Identifiable, Equatable {
public let id: UUID
public let role: Role
public let content: String
public let queryResult: QueryResult?
public let sql: String?
public let timestamp: Date
/// The typed error, if this is an error message.
public let error: SwiftDBAIError?
public enum Role: String, Sendable, Equatable {
case user
case assistant
case error
}
public init(
id: UUID = UUID(),
role: Role,
content: String,
queryResult: QueryResult? = nil,
sql: String? = nil,
timestamp: Date = Date(),
error: SwiftDBAIError? = nil
) {
self.id = id
self.role = role
self.content = content
self.queryResult = queryResult
self.sql = sql
self.timestamp = timestamp
self.error = error
}
}
/// The response returned by `ChatEngine.send(_:)`.
public struct ChatResponse: Sendable {
/// The natural language summary of the result.
public let summary: String
/// The SQL that was generated and executed, if any.
public let sql: String?
/// The raw query result, if a query was executed.
public let queryResult: QueryResult?
}
/// Headless engine that orchestrates the full chat-with-database pipeline.
///
/// The engine:
/// 1. Introspects the database schema (once, lazily)
/// 2. Builds a system prompt with schema context
/// 3. Sends the user's question to the LLM to generate SQL
/// 4. Validates the SQL against the operation allowlist
/// 5. Executes the SQL via GRDB
/// 6. Summarizes results using `TextSummaryRenderer`
/// 7. Returns the summary (and raw data) to the caller
///
/// Usage:
/// ```swift
/// let engine = ChatEngine(
/// database: myDatabasePool,
/// model: myLanguageModel
/// )
/// let response = try await engine.send("How many users signed up this week?")
/// print(response.summary) // "There were 42 new signups this week."
/// ```
public final class ChatEngine: @unchecked Sendable {
// MARK: - Dependencies
private let database: any DatabaseWriter
private let model: any LanguageModel
private let allowlist: OperationAllowlist
private let mutationPolicy: MutationPolicy?
private let configuration: ChatEngineConfiguration
private let summaryRenderer: TextSummaryRenderer
private let sqlParser: SQLQueryParser
/// Optional delegate for intercepting destructive operations and observing SQL execution.
private let delegate: (any ToolExecutionDelegate)?
// MARK: - State
private var schema: DatabaseSchema?
private var conversationHistory: [ChatMessage] = []
private let lock = NSLock()
// MARK: - Initialization
/// Creates a new ChatEngine with a full configuration object.
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - model: Any `AnyLanguageModel`-compatible language model.
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only (SELECT only).
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
/// - delegate: Optional delegate for confirming destructive operations and observing SQL execution.
public init(
database: any DatabaseWriter,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
configuration: ChatEngineConfiguration = .default,
delegate: (any ToolExecutionDelegate)? = nil
) {
self.database = database
self.model = model
self.allowlist = allowlist
self.mutationPolicy = nil
self.configuration = configuration
self.delegate = delegate
self.summaryRenderer = TextSummaryRenderer(
model: model,
maxRowsInPrompt: configuration.maxSummaryRows
)
self.sqlParser = SQLQueryParser(allowlist: allowlist)
}
/// Creates a new ChatEngine with a `MutationPolicy` for table-level control.
///
/// This initializer provides fine-grained control over which mutations are
/// allowed on which tables. The policy's operation allowlist is used for
/// SQL validation, and table-level restrictions are enforced during parsing.
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - model: Any `AnyLanguageModel`-compatible language model.
/// - mutationPolicy: Controls which operations are allowed on which tables.
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
/// - delegate: Optional delegate for confirming destructive operations and observing SQL execution.
public init(
database: any DatabaseWriter,
model: any LanguageModel,
mutationPolicy: MutationPolicy,
configuration: ChatEngineConfiguration = .default,
delegate: (any ToolExecutionDelegate)? = nil
) {
self.database = database
self.model = model
self.allowlist = mutationPolicy.operationAllowlist
self.mutationPolicy = mutationPolicy
self.configuration = configuration
self.delegate = delegate
self.summaryRenderer = TextSummaryRenderer(
model: model,
maxRowsInPrompt: configuration.maxSummaryRows
)
self.sqlParser = SQLQueryParser(mutationPolicy: mutationPolicy)
}
/// Creates a new ChatEngine with individual parameters (convenience).
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - model: Any `AnyLanguageModel`-compatible language model.
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only (SELECT only).
/// - additionalContext: Optional extra instructions for the LLM system prompt.
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
public convenience init(
database: any DatabaseWriter,
model: any LanguageModel,
allowlist: OperationAllowlist,
additionalContext: String?,
maxSummaryRows: Int = 50
) {
let config = ChatEngineConfiguration(
maxSummaryRows: maxSummaryRows,
additionalContext: additionalContext
)
self.init(
database: database,
model: model,
allowlist: allowlist,
configuration: config
)
}
// MARK: - Public API
/// Sends a natural language message and returns a summarized response.
///
/// This is the primary entry point. The engine will:
/// 1. Introspect the schema if not yet cached
/// 2. Ask the LLM to generate SQL
/// 3. Validate the SQL against the allowlist and custom validators
/// 4. Execute the SQL (with timeout if configured)
/// 5. Summarize the results using `TextSummaryRenderer`
///
/// All errors are caught and mapped to a distinct ``SwiftDBAIError`` case
/// so callers always receive a typed, user-friendly error with a localized
/// description suitable for display in a chat UI.
///
/// - Parameter message: The user's natural language question or command.
/// - Returns: A `ChatResponse` containing the summary, SQL, and raw result.
/// - Throws: ``SwiftDBAIError`` for every failure mode.
public func send(_ message: String) async throws -> ChatResponse {
// 1. Ensure schema is introspected
let schema: DatabaseSchema
do {
schema = try await ensureSchema()
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.schemaIntrospectionFailed(reason: error.localizedDescription)
}
// Check for empty schema
if schema.tableNames.isEmpty {
throw SwiftDBAIError.emptySchema
}
// 2. Build prompt and get raw LLM response
let promptBuilder = PromptBuilder(
schema: schema,
allowlist: allowlist,
additionalContext: configuration.additionalContext
)
let rawLLMResponse: String
do {
rawLLMResponse = try await generateRawResponse(
question: message,
promptBuilder: promptBuilder
)
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.llmFailure(reason: error.localizedDescription)
}
// 3. Parse and validate SQL through SQLQueryParser
let parsed: ParsedSQL
do {
parsed = try sqlParser.parse(rawLLMResponse)
} catch let error as SQLParsingError {
throw error.toSwiftDBAIError(rawResponse: rawLLMResponse)
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.invalidSQL(sql: rawLLMResponse, reason: error.localizedDescription)
}
// 4. Run custom validators
do {
try runCustomValidators(parsed: parsed)
} catch let error as QueryValidationError {
throw error
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.queryRejected(reason: error.localizedDescription)
}
// 5. Handle confirmation-required operations (DELETE, DROP, etc.)
if parsed.requiresConfirmation {
if let delegate = self.delegate {
// Build context for the delegate
let classification = classifySQL(parsed.sql)
let context = DestructiveOperationContext(
sql: parsed.sql,
statementKind: detectStatementKind(parsed.sql) ?? .delete,
classification: classification,
description: "Execute \(parsed.operation.rawValue.uppercased()) operation: \(parsed.sql)",
targetTable: extractTargetTableForDelegate(from: parsed.sql, operation: parsed.operation)
)
// Ask the delegate for approval
let approved = await delegate.confirmDestructiveOperation(context)
if !approved {
throw SwiftDBAIError.confirmationRequired(
sql: parsed.sql,
operation: parsed.operation.rawValue
)
}
// Delegate approved fall through to execution
} else {
// No delegate throw confirmation required so caller can handle it
throw SwiftDBAIError.confirmationRequired(
sql: parsed.sql,
operation: parsed.operation.rawValue
)
}
}
// 6. Execute the SQL (with timeout if configured)
let result: QueryResult
do {
let classification = classifySQL(parsed.sql)
await delegate?.willExecuteSQL(parsed.sql, classification: classification)
result = try await executeSQLWithTimeout(parsed.sql)
await delegate?.didExecuteSQL(parsed.sql, success: true)
} catch let error as SwiftDBAIError {
await delegate?.didExecuteSQL(parsed.sql, success: false)
throw error
} catch let error as ChatEngineError {
await delegate?.didExecuteSQL(parsed.sql, success: false)
// Map internal ChatEngineError (e.g. from timeout) to SwiftDBAIError
throw error.toSwiftDBAIError()
} catch {
await delegate?.didExecuteSQL(parsed.sql, success: false)
throw SwiftDBAIError.databaseError(reason: error.localizedDescription)
}
// 7. Summarize the result using TextSummaryRenderer
let summary: String
do {
summary = try await summaryRenderer.summarize(
result: result,
userQuestion: message
)
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.llmFailure(reason: "Summarization failed: \(error.localizedDescription)")
}
// 8. Record conversation history
let userMessage = ChatMessage(role: .user, content: message)
let assistantMessage = ChatMessage(
role: .assistant,
content: summary,
queryResult: result,
sql: parsed.sql
)
lock.withLock {
conversationHistory.append(userMessage)
conversationHistory.append(assistantMessage)
}
return ChatResponse(
summary: summary,
sql: parsed.sql,
queryResult: result
)
}
/// Sends a natural language message, executing a previously confirmed destructive operation.
///
/// Call this after receiving a `confirmationRequired` error and the user has confirmed.
///
/// - Parameters:
/// - message: The original user message (for history recording).
/// - confirmedSQL: The SQL that was confirmed by the user.
/// - Returns: A `ChatResponse` with the result.
public func sendConfirmed(_ message: String, confirmedSQL: String) async throws -> ChatResponse {
let result: QueryResult
do {
let classification = classifySQL(confirmedSQL)
await delegate?.willExecuteSQL(confirmedSQL, classification: classification)
result = try await executeSQLWithTimeout(confirmedSQL)
await delegate?.didExecuteSQL(confirmedSQL, success: true)
} catch let error as SwiftDBAIError {
await delegate?.didExecuteSQL(confirmedSQL, success: false)
throw error
} catch let error as ChatEngineError {
await delegate?.didExecuteSQL(confirmedSQL, success: false)
throw error.toSwiftDBAIError()
} catch {
await delegate?.didExecuteSQL(confirmedSQL, success: false)
throw SwiftDBAIError.databaseError(reason: error.localizedDescription)
}
let summary: String
do {
summary = try await summaryRenderer.summarize(
result: result,
userQuestion: message
)
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.llmFailure(reason: "Summarization failed: \(error.localizedDescription)")
}
let userMessage = ChatMessage(role: .user, content: message)
let assistantMessage = ChatMessage(
role: .assistant,
content: summary,
queryResult: result,
sql: confirmedSQL
)
lock.withLock {
conversationHistory.append(userMessage)
conversationHistory.append(assistantMessage)
}
return ChatResponse(
summary: summary,
sql: confirmedSQL,
queryResult: result
)
}
/// Returns the current conversation history.
public var messages: [ChatMessage] {
lock.lock()
defer { lock.unlock() }
return conversationHistory
}
/// Eagerly introspects the database schema so it's ready before the first query.
///
/// Call this at view-appear time to pre-warm the schema cache. If the schema
/// is already cached, this returns immediately. The returned `DatabaseSchema`
/// can be used to display table/column info in the UI.
///
/// - Returns: The introspected `DatabaseSchema`.
@discardableResult
public func prepareSchema() async throws -> DatabaseSchema {
try await ensureSchema()
}
/// The number of tables discovered during schema introspection.
/// Returns `nil` if the schema has not been introspected yet.
public var tableCount: Int? {
lock.withLock { schema?.tableNames.count }
}
/// The cached schema, if introspection has completed.
public var cachedSchema: DatabaseSchema? {
lock.withLock { schema }
}
/// Clears the conversation history and cached schema.
///
/// After calling this, the next `send(_:)` call will re-introspect the
/// schema. Use ``clearHistory()`` if you only want to reset the conversation
/// while keeping the cached schema.
public func reset() {
lock.withLock {
conversationHistory.removeAll()
schema = nil
}
}
/// Clears only the conversation history, keeping the cached schema.
///
/// This is useful when you want to start a fresh conversation thread
/// without re-introspecting the database. The schema cache remains valid
/// as long as the database structure hasn't changed.
public func clearHistory() {
lock.withLock {
conversationHistory.removeAll()
}
}
/// The current engine configuration.
public var currentConfiguration: ChatEngineConfiguration {
configuration
}
// MARK: - Internal Helpers (visible for testing)
/// Ensures the database schema is introspected and cached.
func ensureSchema() async throws -> DatabaseSchema {
if let cached = lock.withLock({ schema }) {
return cached
}
let introspected = try await SchemaIntrospector.introspect(database: database)
lock.withLock { schema = introspected }
return introspected
}
/// Asks the LLM to generate SQL from a natural language question.
/// Returns the raw LLM response text (before parsing).
///
/// Uses the configured ``ChatEngineConfiguration/contextWindowSize`` to limit
/// how many conversation messages are included as context for the LLM.
private func generateRawResponse(
question: String,
promptBuilder: PromptBuilder
) async throws -> String {
let instructions = promptBuilder.buildSystemInstructions()
// Build user prompt include full conversation history for follow-ups
// Respect context window: only use recent messages for context
let userPrompt: String
let historySlice = lock.withLock { () -> [ChatMessage] in
Array(contextWindowSlice())
}
if historySlice.isEmpty {
userPrompt = promptBuilder.buildUserPrompt(question)
} else {
userPrompt = promptBuilder.buildConversationPrompt(
question,
history: historySlice
)
}
let session = LanguageModelSession(
model: model,
instructions: instructions + "\n\nCRITICAL: Respond with ONLY the raw SQL query. Do NOT wrap in markdown code fences or backticks. Do NOT include any explanation, comments, or formatting. The output must be directly executable SQL and nothing else."
)
let response = try await session.respond(to: userPrompt)
return response.content.trimmingCharacters(in: .whitespacesAndNewlines)
}
/// Returns the conversation history slice within the configured context window.
/// Must be called within a `lock.withLock` closure.
private func contextWindowSlice() -> ArraySlice<ChatMessage> {
guard let windowSize = configuration.contextWindowSize else {
return conversationHistory[...]
}
let count = conversationHistory.count
let start = max(0, count - windowSize)
return conversationHistory[start...]
}
/// Runs all custom validators from the configuration against the parsed SQL.
private func runCustomValidators(parsed: ParsedSQL) throws {
for validator in configuration.validators {
try validator.validate(sql: parsed.sql, operation: parsed.operation)
}
}
/// Extracts the target table name from a SQL statement for delegate context.
private func extractTargetTableForDelegate(from sql: String, operation: SQLOperation) -> String? {
let pattern: String
switch operation {
case .insert:
pattern = #"INSERT\s+INTO\s+[`"\[]?(\w+)[`"\]]?"#
case .update:
pattern = #"UPDATE\s+[`"\[]?(\w+)[`"\]]?"#
case .delete:
pattern = #"DELETE\s+FROM\s+[`"\[]?(\w+)[`"\]]?"#
case .select:
return nil
}
guard let regex = try? NSRegularExpression(pattern: pattern, options: .caseInsensitive) else {
return nil
}
let range = NSRange(sql.startIndex..., in: sql)
guard let match = regex.firstMatch(in: sql, range: range),
match.numberOfRanges > 1,
let groupRange = Range(match.range(at: 1), in: sql) else {
return nil
}
return String(sql[groupRange])
}
/// Executes SQL with the configured timeout, if any.
private func executeSQLWithTimeout(_ sql: String) async throws -> QueryResult {
guard let timeout = configuration.queryTimeout else {
return try await executeSQL(sql)
}
return try await withThrowingTaskGroup(of: QueryResult.self) { group in
group.addTask {
try await self.executeSQL(sql)
}
group.addTask {
try await Task.sleep(for: .seconds(timeout))
throw ChatEngineError.queryTimedOut(seconds: timeout)
}
// Return whichever finishes first
let result = try await group.next()!
group.cancelAll()
return result
}
}
/// Executes SQL against the database and returns a `QueryResult`.
private func executeSQL(_ sql: String) async throws -> QueryResult {
let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
let isSelect = trimmed.hasPrefix("SELECT") || trimmed.hasPrefix("WITH")
let startTime = CFAbsoluteTimeGetCurrent()
if isSelect {
let result = try await database.read { db -> (columns: [String], rows: [[String: QueryResult.Value]]) in
let statement = try db.makeStatement(sql: sql)
let columnNames = statement.columnNames
var rows: [[String: QueryResult.Value]] = []
let cursor = try Row.fetchCursor(statement)
while let row = try cursor.next() {
var dict: [String: QueryResult.Value] = [:]
for col in columnNames {
dict[col] = Self.extractValue(row: row, column: col)
}
rows.append(dict)
}
return (columns: columnNames, rows: rows)
}
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
return QueryResult(
columns: result.columns,
rows: result.rows,
sql: sql,
executionTime: elapsed
)
} else {
// Mutation query
let affected = try await database.write { db -> Int in
try db.execute(sql: sql)
return db.changesCount
}
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
return QueryResult(
columns: [],
rows: [],
sql: sql,
executionTime: elapsed,
rowsAffected: affected
)
}
}
/// Extracts a `QueryResult.Value` from a GRDB `Row` for the given column.
private static func extractValue(row: Row, column: String) -> QueryResult.Value {
let dbValue: DatabaseValue = row[column]
switch dbValue.storage {
case .null:
return .null
case .int64(let i):
return .integer(i)
case .double(let d):
return .real(d)
case .string(let s):
return .text(s)
case .blob(let data):
return .blob(data)
}
}
}
// MARK: - Errors
/// Errors that can occur during ChatEngine operations.
public enum ChatEngineError: Error, LocalizedError, Sendable {
/// SQL parsing/extraction from LLM response failed.
case sqlParsingFailed(SQLParsingError)
/// A destructive operation requires user confirmation before execution.
case confirmationRequired(sql: String, operation: SQLOperation)
/// Schema introspection failed.
case schemaIntrospectionFailed(String)
/// The SQL query exceeded the configured timeout.
case queryTimedOut(seconds: TimeInterval)
/// A custom query validator rejected the query.
case validationFailed(String)
public var errorDescription: String? {
switch self {
case .sqlParsingFailed(let parsingError):
return "SQL parsing failed: \(parsingError.description)"
case .confirmationRequired(let sql, let op):
return "The \(op.rawValue.uppercased()) operation requires confirmation: \(sql)"
case .schemaIntrospectionFailed(let reason):
return "Failed to introspect database schema: \(reason)"
case .queryTimedOut(let seconds):
return "Query timed out after \(Int(seconds)) seconds."
case .validationFailed(let reason):
return "Query validation failed: \(reason)"
}
}
}

View File

@@ -0,0 +1,288 @@
// ToolExecutionDelegate.swift
// SwiftDBAI
//
// Delegate protocol for controlling SQL tool execution, including
// confirmation of destructive operations before they reach the database.
import Foundation
// MARK: - Destructive SQL Classification
/// Classifies SQL statements by their destructive potential.
///
/// A statement is considered **destructive** if it modifies or removes data
/// or schema objects. The classification drives the confirmation flow:
/// destructive statements require explicit user approval via
/// ``ToolExecutionDelegate/confirmDestructiveOperation(_:)``.
public enum DestructiveClassification: Sendable, Equatable {
/// The statement is read-only (e.g. SELECT). No confirmation needed.
case safe
/// The statement modifies existing data (INSERT, UPDATE).
case mutation(SQLStatementKind)
/// The statement deletes data or alters/drops schema objects.
/// These always require confirmation, even when the operation is allowed.
case destructive(SQLStatementKind)
/// Returns `true` when the statement requires user confirmation.
public var requiresConfirmation: Bool {
switch self {
case .safe:
return false
case .mutation:
return false
case .destructive:
return true
}
}
/// Returns `true` when the statement modifies data or schema in any way.
public var isMutating: Bool {
switch self {
case .safe:
return false
case .mutation, .destructive:
return true
}
}
}
/// The kind of SQL statement, used for classification and display.
public enum SQLStatementKind: String, Sendable, Hashable, CaseIterable {
case select = "SELECT"
case insert = "INSERT"
case update = "UPDATE"
case delete = "DELETE"
case drop = "DROP"
case alter = "ALTER"
case truncate = "TRUNCATE"
/// All kinds that are classified as destructive.
public static let destructiveKinds: Set<SQLStatementKind> = [
.delete, .drop, .alter, .truncate
]
/// All kinds that are classified as mutations (data-modifying but not destructive).
public static let mutationKinds: Set<SQLStatementKind> = [
.insert, .update
]
/// Whether this kind of statement is destructive.
public var isDestructive: Bool {
Self.destructiveKinds.contains(self)
}
/// Whether this kind of statement is a mutation (INSERT/UPDATE).
public var isMutation: Bool {
Self.mutationKinds.contains(self)
}
}
// MARK: - Classification Function
/// Classifies a SQL statement string by its destructive potential.
///
/// The classifier inspects the first keyword token of the statement
/// (case-insensitive) to determine the statement kind, then maps it
/// to a ``DestructiveClassification``.
///
/// - Parameter sql: The SQL statement to classify.
/// - Returns: The classification for the statement.
public func classifySQL(_ sql: String) -> DestructiveClassification {
guard let kind = detectStatementKind(sql) else {
return .safe
}
if kind.isDestructive {
return .destructive(kind)
} else if kind.isMutation {
return .mutation(kind)
} else {
return .safe
}
}
/// Detects the ``SQLStatementKind`` from the leading keyword of a SQL string.
///
/// - Parameter sql: The SQL statement to inspect.
/// - Returns: The detected kind, or `nil` if unrecognized.
public func detectStatementKind(_ sql: String) -> SQLStatementKind? {
let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
// Check each known statement kind against the first token
if trimmed.hasPrefix("SELECT") || trimmed.hasPrefix("WITH") {
return .select
} else if trimmed.hasPrefix("INSERT") {
return .insert
} else if trimmed.hasPrefix("UPDATE") {
return .update
} else if trimmed.hasPrefix("DELETE") {
return .delete
} else if trimmed.hasPrefix("DROP") {
return .drop
} else if trimmed.hasPrefix("ALTER") {
return .alter
} else if trimmed.hasPrefix("TRUNCATE") {
return .truncate
}
return nil
}
// MARK: - Destructive Operation Context
/// Context provided to the delegate when a destructive operation needs confirmation.
///
/// Contains all the information a UI or programmatic handler needs to
/// decide whether to allow the operation.
public struct DestructiveOperationContext: Sendable {
/// The SQL statement that would be executed.
public let sql: String
/// The detected kind of statement (DELETE, DROP, ALTER, TRUNCATE).
public let statementKind: SQLStatementKind
/// The classification result.
public let classification: DestructiveClassification
/// A human-readable description of what the operation will do.
public let description: String
/// The target table name, if detected.
public let targetTable: String?
public init(
sql: String,
statementKind: SQLStatementKind,
classification: DestructiveClassification,
description: String,
targetTable: String? = nil
) {
self.sql = sql
self.statementKind = statementKind
self.classification = classification
self.description = description
self.targetTable = targetTable
}
}
// MARK: - ToolExecutionDelegate Protocol
/// A delegate that controls execution of SQL operations, providing
/// confirmation gates for destructive statements.
///
/// Implement this protocol to intercept destructive SQL operations
/// (DELETE, DROP, ALTER, TRUNCATE) before they are executed. The
/// ``ChatEngine`` consults the delegate whenever it encounters a
/// statement classified as ``DestructiveClassification/destructive(_:)``.
///
/// ## Example
///
/// ```swift
/// struct MyDelegate: ToolExecutionDelegate {
/// func confirmDestructiveOperation(
/// _ context: DestructiveOperationContext
/// ) async -> Bool {
/// // Show a confirmation dialog to the user
/// return await showAlert(
/// "Confirm \(context.statementKind.rawValue)",
/// message: context.description
/// )
/// }
/// }
///
/// let engine = ChatEngine(
/// database: pool,
/// model: model,
/// delegate: MyDelegate()
/// )
/// ```
public protocol ToolExecutionDelegate: Sendable {
/// Called when a destructive SQL operation is about to be executed.
///
/// The delegate should present the operation details to the user and
/// return `true` to proceed or `false` to cancel.
///
/// - Parameter context: Details about the destructive operation.
/// - Returns: `true` to allow execution, `false` to reject it.
func confirmDestructiveOperation(
_ context: DestructiveOperationContext
) async -> Bool
/// Called before any SQL statement is executed.
///
/// This is an observation hook the engine does not wait for a
/// decision. Override to log, audit, or instrument queries.
///
/// - Parameters:
/// - sql: The SQL about to be executed.
/// - classification: The destructive classification of the statement.
func willExecuteSQL(
_ sql: String,
classification: DestructiveClassification
) async
/// Called after a SQL statement completes execution.
///
/// - Parameters:
/// - sql: The SQL that was executed.
/// - success: Whether execution succeeded.
func didExecuteSQL(
_ sql: String,
success: Bool
) async
}
// MARK: - Default Implementations
extension ToolExecutionDelegate {
/// Default: rejects all destructive operations.
public func confirmDestructiveOperation(
_ context: DestructiveOperationContext
) async -> Bool {
false
}
/// Default: no-op.
public func willExecuteSQL(
_ sql: String,
classification: DestructiveClassification
) async {}
/// Default: no-op.
public func didExecuteSQL(
_ sql: String,
success: Bool
) async {}
}
// MARK: - Built-in Delegates
/// A delegate that automatically approves all destructive operations.
///
/// Use this only in testing or trusted environments where confirmation
/// is not needed.
public struct AutoApproveDelegate: ToolExecutionDelegate {
public init() {}
public func confirmDestructiveOperation(
_ context: DestructiveOperationContext
) async -> Bool {
true
}
}
/// A delegate that always rejects destructive operations.
///
/// This is the safest option and matches the default behavior.
public struct RejectAllDelegate: ToolExecutionDelegate {
public init() {}
public func confirmDestructiveOperation(
_ context: DestructiveOperationContext
) async -> Bool {
false
}
}

View File

@@ -0,0 +1,143 @@
// ConversationHistory.swift
// SwiftDBAI
//
// Ordered chat message history with configurable context window.
import Foundation
/// Stores an ordered sequence of ``ChatMessage`` instances with a configurable
/// context window limit.
///
/// When the number of messages exceeds ``maxMessages``, the oldest messages are
/// trimmed to keep the history within budget. This prevents unbounded token
/// growth when building LLM prompts from conversation history.
///
/// Usage:
/// ```swift
/// var history = ConversationHistory(maxMessages: 20)
/// history.append(ChatMessage(role: .user, content: "How many users?"))
/// history.append(ChatMessage(role: .assistant, content: "42", sql: "SELECT COUNT(*) FROM users"))
/// print(history.promptText) // formatted for LLM context
/// ```
public struct ConversationHistory: Sendable {
/// The maximum number of messages to retain. `nil` means unlimited.
public let maxMessages: Int?
/// All messages in chronological order.
public private(set) var messages: [ChatMessage] = []
/// Creates a new conversation history.
///
/// - Parameter maxMessages: Maximum number of messages to keep in the
/// context window. Pass `nil` for unlimited history. Defaults to 50.
public init(maxMessages: Int? = 50) {
precondition(maxMessages == nil || maxMessages! > 0,
"maxMessages must be positive or nil")
self.maxMessages = maxMessages
}
/// The number of messages currently stored.
public var count: Int { messages.count }
/// Whether the history is empty.
public var isEmpty: Bool { messages.isEmpty }
// MARK: - Mutating Operations
/// Appends a message and trims the history if it exceeds the context window.
public mutating func append(_ message: ChatMessage) {
messages.append(message)
trimIfNeeded()
}
/// Appends multiple messages and trims once afterward.
public mutating func append(contentsOf newMessages: [ChatMessage]) {
messages.append(contentsOf: newMessages)
trimIfNeeded()
}
/// Removes all messages from the history.
public mutating func clear() {
messages.removeAll()
}
// MARK: - Context Window
/// Returns the most recent messages formatted for inclusion in an LLM prompt.
///
/// Each message is formatted as `[role] content`, with SQL and query results
/// included inline for assistant messages.
///
/// - Parameter limit: Optional override to further restrict the number of
/// messages returned. When `nil`, uses the full retained history.
/// - Returns: An array of prompt-formatted strings, one per message.
public func promptMessages(limit: Int? = nil) -> [String] {
let slice: ArraySlice<ChatMessage>
if let limit {
slice = messages.suffix(limit)
} else {
slice = messages[...]
}
return slice.map { message in
Self.formatForPrompt(message)
}
}
/// Returns the combined prompt text for all retained messages, separated by
/// double newlines.
public var promptText: String {
promptMessages().joined(separator: "\n\n")
}
// MARK: - Queries
/// Returns only user messages.
public var userMessages: [ChatMessage] {
messages.filter { $0.role == .user }
}
/// Returns only assistant messages.
public var assistantMessages: [ChatMessage] {
messages.filter { $0.role == .assistant }
}
/// Returns the last message, if any.
public var lastMessage: ChatMessage? {
messages.last
}
/// Returns the most recent user query text, if any.
public var lastUserQuery: String? {
messages.last(where: { $0.role == .user })?.content
}
/// Returns the most recent assistant message, if any.
public var lastAssistantMessage: ChatMessage? {
messages.last(where: { $0.role == .assistant })
}
// MARK: - Private
/// Formats a ``ChatMessage`` into a string suitable for LLM prompt context.
private static func formatForPrompt(_ message: ChatMessage) -> String {
var parts: [String] = ["[\(message.role.rawValue)] \(message.content)"]
if let sql = message.sql {
parts.append("SQL: \(sql)")
}
if let result = message.queryResult {
parts.append("Result:\n\(result.tabularDescription)")
}
return parts.joined(separator: "\n")
}
/// Trims the oldest messages to stay within the context window.
private mutating func trimIfNeeded() {
guard let max = maxMessages, messages.count > max else { return }
let overflow = messages.count - max
messages.removeFirst(overflow)
}
}

View File

@@ -0,0 +1,136 @@
// QueryResult.swift
// SwiftDBAI
//
// Structured result from SQL query execution.
import Foundation
/// Represents the result of executing a SQL query against the database.
///
/// Contains raw row data as dictionaries, column metadata, row count,
/// the original SQL string, and execution timing.
public struct QueryResult: Sendable, Equatable {
/// A single cell value from a query result.
///
/// Wraps SQLite's dynamic value types into a type-safe, Sendable enum.
public enum Value: Sendable, Equatable, CustomStringConvertible {
case text(String)
case integer(Int64)
case real(Double)
case blob(Data)
case null
public var description: String {
switch self {
case .text(let s): return s
case .integer(let i): return String(i)
case .real(let d):
if d == d.rounded() && abs(d) < 1e15 {
return String(format: "%.0f", d)
}
return String(d)
case .blob(let data): return "<\(data.count) bytes>"
case .null: return "NULL"
}
}
/// Returns the value as a `Double` if it is numeric, nil otherwise.
public var doubleValue: Double? {
switch self {
case .integer(let i): return Double(i)
case .real(let d): return d
case .text(let s): return Double(s)
default: return nil
}
}
/// Returns the value as a `String` (non-nil for all cases).
public var stringValue: String { description }
/// Returns `true` if this value is `.null`.
public var isNull: Bool {
if case .null = self { return true }
return false
}
}
/// Column names in the order they appear in the result set.
public let columns: [String]
/// Row data as an array of dictionaries mapping column name to value.
public let rows: [[String: Value]]
/// Total number of rows returned.
public var rowCount: Int { rows.count }
/// The SQL statement that was executed.
public let sql: String
/// Time taken to execute the query, in seconds.
public let executionTime: TimeInterval
/// Number of rows affected (for INSERT/UPDATE/DELETE). Nil for SELECT.
public let rowsAffected: Int?
public init(
columns: [String],
rows: [[String: Value]],
sql: String,
executionTime: TimeInterval,
rowsAffected: Int? = nil
) {
self.columns = columns
self.rows = rows
self.sql = sql
self.executionTime = executionTime
self.rowsAffected = rowsAffected
}
// MARK: - Convenience Accessors
/// Returns all values for a given column, in row order.
public func values(forColumn column: String) -> [Value] {
rows.compactMap { $0[column] }
}
/// Returns a compact tabular string representation of the results.
///
/// Useful for embedding query results into LLM prompts.
public var tabularDescription: String {
guard !rows.isEmpty else {
return "(empty result set)"
}
var lines: [String] = []
// Header
lines.append(columns.joined(separator: " | "))
lines.append(String(repeating: "-", count: lines[0].count))
// Rows (cap at 50 for prompt size)
let displayRows = rows.prefix(50)
for row in displayRows {
let vals = columns.map { col in
row[col]?.description ?? "NULL"
}
lines.append(vals.joined(separator: " | "))
}
if rows.count > 50 {
lines.append("... and \(rows.count - 50) more rows")
}
return lines.joined(separator: "\n")
}
/// Returns true if the result looks like a single aggregate value
/// (1 row, 1-3 columns, all numeric).
public var isAggregate: Bool {
guard rowCount == 1, columns.count <= 3 else { return false }
let firstRow = rows[0]
return columns.allSatisfy { col in
firstRow[col]?.doubleValue != nil
}
}
}

View File

@@ -0,0 +1,423 @@
// SQLQueryParser.swift
// SwiftDBAI
//
// Extracts and validates SQL statements from raw LLM response text.
import Foundation
/// Errors that can occur during SQL parsing and validation.
public enum SQLParsingError: Error, Sendable, Equatable, CustomStringConvertible {
/// No SQL statement could be found in the LLM response.
case noSQLFound
/// The SQL statement uses an operation not in the allowlist.
case operationNotAllowed(SQLOperation)
/// A destructive operation (DELETE) requires user confirmation.
case confirmationRequired(sql: String, operation: SQLOperation)
/// The mutation targets a table not in the allowed mutation tables.
case tableNotAllowed(table: String, operation: SQLOperation)
/// The SQL contains a disallowed keyword (e.g., DROP, ALTER, TRUNCATE).
case dangerousOperation(String)
/// Multiple SQL statements were found but only single-statement execution is supported.
case multipleStatements
public var description: String {
switch self {
case .noSQLFound:
return "No SQL statement found in the response."
case .operationNotAllowed(let op):
return "Operation '\(op.rawValue.uppercased())' is not allowed by the current configuration."
case .confirmationRequired(let sql, let op):
return "The \(op.rawValue.uppercased()) operation requires confirmation: \(sql)"
case .tableNotAllowed(let table, let op):
return "The \(op.rawValue.uppercased()) operation is not allowed on table '\(table)'."
case .dangerousOperation(let keyword):
return "Dangerous SQL operation '\(keyword)' is never allowed."
case .multipleStatements:
return "Only single SQL statements are supported."
}
}
}
/// Result of successfully parsing SQL from an LLM response.
public struct ParsedSQL: Sendable, Equatable {
/// The cleaned SQL statement ready for execution.
public let sql: String
/// The detected operation type.
public let operation: SQLOperation
/// Whether this operation requires user confirmation before execution.
public let requiresConfirmation: Bool
public init(sql: String, operation: SQLOperation, requiresConfirmation: Bool = false) {
self.sql = sql
self.operation = operation
self.requiresConfirmation = requiresConfirmation
}
}
/// Extracts SQL statements from raw LLM response text and validates them
/// against the configured ``OperationAllowlist``.
///
/// The parser handles common LLM output patterns:
/// - SQL in markdown code blocks (```sql ... ```)
/// - SQL in generic code blocks (``` ... ```)
/// - Raw SQL statements in plain text
/// - SQL prefixed with labels like "SQL:" or "Query:"
public struct SQLQueryParser: Sendable {
/// Keywords that are never allowed regardless of allowlist configuration.
private static let dangerousKeywords: Set<String> = [
"DROP", "ALTER", "TRUNCATE", "CREATE", "GRANT", "REVOKE",
"ATTACH", "DETACH", "PRAGMA", "VACUUM", "REINDEX"
]
/// The operation allowlist to validate against.
private let allowlist: OperationAllowlist
/// The mutation policy for table-level restrictions.
private let mutationPolicy: MutationPolicy?
/// Creates a parser with the given operation allowlist.
/// - Parameter allowlist: The set of permitted operations. Defaults to read-only.
public init(allowlist: OperationAllowlist = .readOnly) {
self.allowlist = allowlist
self.mutationPolicy = nil
}
/// Creates a parser with a mutation policy (preferred initializer).
/// - Parameter mutationPolicy: The mutation policy controlling operations and table access.
public init(mutationPolicy: MutationPolicy) {
self.allowlist = mutationPolicy.operationAllowlist
self.mutationPolicy = mutationPolicy
}
/// Extracts and validates a SQL statement from raw LLM response text.
///
/// - Parameter text: The raw text from the LLM response.
/// - Returns: A ``ParsedSQL`` containing the validated statement.
/// - Throws: ``SQLParsingError`` if extraction or validation fails.
public func parse(_ text: String) throws -> ParsedSQL {
let sql = try extractSQL(from: text)
return try validate(sql)
}
// MARK: - Extraction
/// Attempts to extract a SQL statement from the LLM response text.
/// Tries multiple strategies in order of confidence.
func extractSQL(from text: String) throws -> String {
// Pre-processing: strip <think>...</think> tags (Qwen-style models)
let preprocessed = stripThinkTags(text)
// Strategy 1: SQL in markdown fenced code block with sql language tag
if let sql = extractFromSQLCodeBlock(preprocessed) {
return sql
}
// Strategy 2: SQL in generic fenced code block
if let sql = extractFromGenericCodeBlock(preprocessed) {
return sql
}
// Strategy 3: SQL after a label like "SQL:" or "Query:"
if let sql = extractFromLabel(preprocessed) {
return sql
}
// Strategy 4: Direct SQL detection in plain text (includes WITH)
if let sql = extractDirectSQL(preprocessed) {
return sql
}
// Strategy 5: Strip markdown fence markers (3+ backticks with optional
// language tag) and retry. Only removes fences, not single backticks
// used for SQLite identifier quoting like `column name`.
let defenced = stripMarkdownFences(preprocessed)
if defenced != preprocessed, let sql = extractDirectSQL(defenced) {
return sql
}
throw SQLParsingError.noSQLFound
}
/// Strips `<think>...</think>` tags produced by Qwen-style reasoning models.
private func stripThinkTags(_ text: String) -> String {
text.replacingOccurrences(
of: #"<think>[\s\S]*?</think>"#,
with: "",
options: .regularExpression
).trimmingCharacters(in: .whitespacesAndNewlines)
}
/// Removes markdown fence markers (3+ backticks with optional language tag)
/// while preserving single backtick identifier quoting like `column name`.
private func stripMarkdownFences(_ text: String) -> String {
text.replacingOccurrences(
of: #"`{3,}\s*(?:sql|SQL)?\s*"#,
with: " ",
options: .regularExpression
).trimmingCharacters(in: .whitespacesAndNewlines)
}
/// Extracts SQL from a ```sql ... ``` code block.
/// Handles 3+ backticks, optional newline before closing fence,
/// and single-line code blocks like ```sql SELECT ... ```.
private func extractFromSQLCodeBlock(_ text: String) -> String? {
// Match 3+ backticks with sql tag, content, then 3+ closing backticks
let pattern = #"`{3,}sql\s*\n?([\s\S]*?)`{3,}"#
return firstMatch(pattern: pattern, in: text, group: 1, options: .caseInsensitive)?
.trimmingCharacters(in: .whitespacesAndNewlines)
.nonEmptyOrNil
}
/// Extracts SQL from a generic ``` ... ``` code block (no language tag).
/// Handles 3+ backticks and flexible whitespace.
private func extractFromGenericCodeBlock(_ text: String) -> String? {
let pattern = #"`{3,}\s*\n([\s\S]*?)`{3,}"#
guard let content = firstMatch(pattern: pattern, in: text, group: 1)?
.trimmingCharacters(in: .whitespacesAndNewlines) else {
return nil
}
// Only accept if it looks like SQL
guard looksLikeSQL(content) else { return nil }
return content.nonEmptyOrNil
}
/// Extracts SQL after labels like "SQL:", "Query:", "Here's the query:", "The SQL query is:"
private func extractFromLabel(_ text: String) -> String? {
// Match common label patterns followed by a SQL statement.
// The SQL ends at a double newline, a single newline followed by non-SQL text, or end-of-string.
let pattern = #"(?:SQL|Query|Statement|query is|SQL query is)\s*:\s*\n?\s*((?:SELECT|INSERT|UPDATE|DELETE|WITH)\b(?:[^;'\n]|'[^']*'|\n(?=\s*(?:SELECT|INSERT|UPDATE|DELETE|WITH|FROM|WHERE|JOIN|INNER|LEFT|RIGHT|OUTER|CROSS|ON|AND|OR|ORDER|GROUP|HAVING|LIMIT|OFFSET|UNION|EXCEPT|INTERSECT|AS|SET|INTO|VALUES)\b)|\n(?=\s))*;?)"#
guard let content = firstMatch(pattern: pattern, in: text, group: 1, options: [.caseInsensitive])?
.trimmingCharacters(in: .whitespacesAndNewlines) else {
return nil
}
guard looksLikeSQL(content) else { return nil }
return content.nonEmptyOrNil
}
/// Detects SQL directly in the text by matching known statement patterns.
/// Handles SELECT, INSERT, UPDATE, DELETE, and WITH (CTE) statements.
private func extractDirectSQL(_ text: String) -> String? {
// Match SQL statement starting with a keyword, allowing semicolons inside string literals.
// The WITH clause is included to support CTE queries.
let pattern = #"(?:^|\n)\s*((?:SELECT|INSERT|UPDATE|DELETE|WITH)\b(?:[^;']|'[^']*')*;?)"#
guard var content = firstMatch(pattern: pattern, in: text, group: 1, options: .caseInsensitive)?
.trimmingCharacters(in: .whitespacesAndNewlines) else {
return nil
}
// Strip any trailing markdown fence markers that got captured
content = content.replacingOccurrences(
of: #"\s*`{3,}\s*$"#,
with: "",
options: .regularExpression
).trimmingCharacters(in: .whitespacesAndNewlines)
return content.nonEmptyOrNil
}
// MARK: - Validation
/// Validates a SQL string against the allowlist and safety rules.
func validate(_ sql: String) throws -> ParsedSQL {
let cleaned = cleanSQL(sql)
guard !cleaned.isEmpty else {
throw SQLParsingError.noSQLFound
}
// Check for multiple statements (semicolons in non-trivial positions)
if containsMultipleStatements(cleaned) {
throw SQLParsingError.multipleStatements
}
// Check for dangerous operations first (before allowlist)
try checkDangerousKeywords(cleaned)
// Detect the operation type
let operation = detectOperation(cleaned)
// Check against the allowlist
guard allowlist.isAllowed(operation) else {
throw SQLParsingError.operationNotAllowed(operation)
}
// Check table-level restrictions for mutation operations
if let policy = mutationPolicy, operation != .select,
let targetTable = extractTargetTable(from: cleaned, operation: operation) {
guard policy.isAllowed(operation: operation, on: targetTable) else {
throw SQLParsingError.tableNotAllowed(table: targetTable, operation: operation)
}
}
// DELETE requires confirmation when policy says so, or always by default
let requiresConfirmation: Bool
if let policy = mutationPolicy {
requiresConfirmation = policy.requiresConfirmation(for: operation)
} else {
requiresConfirmation = operation == .delete
}
return ParsedSQL(
sql: cleaned,
operation: operation,
requiresConfirmation: requiresConfirmation
)
}
// MARK: - Helpers
/// Cleans a SQL string by removing trailing semicolons (outside string literals) and excess whitespace.
private func cleanSQL(_ sql: String) -> String {
var cleaned = sql.trimmingCharacters(in: .whitespacesAndNewlines)
// Remove trailing semicolons only if they're outside string literals
while cleaned.hasSuffix(";") && !isInsideStringLiteral(sql: cleaned, position: cleaned.index(before: cleaned.endIndex)) {
cleaned = String(cleaned.dropLast()).trimmingCharacters(in: .whitespacesAndNewlines)
}
// Collapse internal whitespace outside string literals
cleaned = collapseWhitespace(cleaned)
return cleaned
}
/// Collapses whitespace while preserving string literal contents.
private func collapseWhitespace(_ sql: String) -> String {
var result = ""
var inString = false
var prevWasSpace = false
for ch in sql {
if ch == "'" {
inString.toggle()
prevWasSpace = false
result.append(ch)
} else if inString {
result.append(ch)
} else if ch.isWhitespace {
if !prevWasSpace {
result.append(" ")
prevWasSpace = true
}
} else {
prevWasSpace = false
result.append(ch)
}
}
return result
}
/// Returns true if the character at the given position is inside a single-quoted string literal.
private func isInsideStringLiteral(sql: String, position: String.Index) -> Bool {
var inString = false
for idx in sql.indices {
if idx == position { return inString }
if sql[idx] == "'" { inString.toggle() }
}
return false
}
/// Checks whether cleaned SQL contains multiple statements.
private func containsMultipleStatements(_ sql: String) -> Bool {
// Remove string literals before checking for semicolons
var inString = false
for ch in sql {
if ch == "'" {
inString.toggle()
} else if ch == ";" && !inString {
return true
}
}
return false
}
/// Checks for dangerous SQL keywords that are never allowed.
private func checkDangerousKeywords(_ sql: String) throws {
let upper = sql.uppercased()
// Tokenize to avoid partial matches (e.g., "DROPDOWN" matching "DROP")
let tokens = upper.components(separatedBy: .alphanumerics.inverted)
.filter { !$0.isEmpty }
for keyword in Self.dangerousKeywords {
if tokens.contains(keyword) {
throw SQLParsingError.dangerousOperation(keyword)
}
}
}
/// Detects the SQL operation type from the first keyword.
private func detectOperation(_ sql: String) -> SQLOperation {
let upper = sql.uppercased().trimmingCharacters(in: .whitespaces)
if upper.hasPrefix("SELECT") || upper.hasPrefix("WITH") {
return .select
} else if upper.hasPrefix("INSERT") {
return .insert
} else if upper.hasPrefix("UPDATE") {
return .update
} else if upper.hasPrefix("DELETE") {
return .delete
}
// Default to select for unrecognized patterns (e.g. EXPLAIN)
return .select
}
/// Extracts the target table name from a mutation SQL statement.
///
/// Handles common patterns:
/// - `INSERT INTO table_name ...`
/// - `UPDATE table_name SET ...`
/// - `DELETE FROM table_name ...`
private func extractTargetTable(from sql: String, operation: SQLOperation) -> String? {
let pattern: String
switch operation {
case .insert:
pattern = #"INSERT\s+INTO\s+[`"\[]?(\w+)[`"\]]?"#
case .update:
pattern = #"UPDATE\s+[`"\[]?(\w+)[`"\]]?"#
case .delete:
pattern = #"DELETE\s+FROM\s+[`"\[]?(\w+)[`"\]]?"#
case .select:
return nil
}
return firstMatch(pattern: pattern, in: sql, group: 1, options: .caseInsensitive)
}
/// Returns true if the text looks like a SQL statement.
private func looksLikeSQL(_ text: String) -> Bool {
let upper = text.uppercased().trimmingCharacters(in: .whitespaces)
let sqlPrefixes = ["SELECT", "INSERT", "UPDATE", "DELETE", "WITH"]
return sqlPrefixes.contains { upper.hasPrefix($0) }
}
/// Extracts the first regex match group from the text.
private func firstMatch(
pattern: String,
in text: String,
group: Int,
options: NSRegularExpression.Options = []
) -> String? {
guard let regex = try? NSRegularExpression(pattern: pattern, options: options) else {
return nil
}
let range = NSRange(text.startIndex..., in: text)
guard let match = regex.firstMatch(in: text, range: range),
match.numberOfRanges > group,
let groupRange = Range(match.range(at: group), in: text) else {
return nil
}
return String(text[groupRange])
}
}
// MARK: - String Extension
private extension String {
/// Returns nil if the string is empty, otherwise returns self.
var nonEmptyOrNil: String? {
isEmpty ? nil : self
}
}

View File

@@ -0,0 +1,238 @@
/// Builds structured LLM prompts for SQL generation from a database schema
/// and natural language input.
///
/// `PromptBuilder` is the bridge between the introspected database schema and
/// the LLM. It produces two things:
/// 1. A **system instructions** string containing schema context and behavioral rules
/// 2. A **user prompt** string wrapping the natural language question
///
/// Usage:
/// ```swift
/// let builder = PromptBuilder(schema: mySchema, allowlist: .readOnly)
/// let instructions = builder.buildSystemInstructions()
/// let prompt = builder.buildUserPrompt("How many users signed up this week?")
/// ```
public struct PromptBuilder: Sendable {
/// The database schema to include as context.
public let schema: DatabaseSchema
/// Which SQL operations the LLM may generate.
public let allowlist: OperationAllowlist
/// Optional additional context to append to the system instructions
/// (e.g., business-specific terminology or query hints).
public let additionalContext: String?
/// Creates a prompt builder for the given schema and allowlist.
///
/// - Parameters:
/// - schema: The introspected database schema.
/// - allowlist: Permitted SQL operations. Defaults to ``OperationAllowlist/readOnly``.
/// - additionalContext: Extra instructions appended to the system prompt.
public init(
schema: DatabaseSchema,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil
) {
self.schema = schema
self.allowlist = allowlist
self.additionalContext = additionalContext
}
// MARK: - System Instructions
/// Builds the system instructions string that should be passed as the
/// `instructions` parameter when creating a `LanguageModelSession`.
///
/// The instructions include:
/// - Role definition
/// - The full database schema
/// - SQL generation rules and constraints
/// - The operation allowlist
/// - Output format requirements
public func buildSystemInstructions() -> String {
var sections: [String] = []
// 1. Role
sections.append(Self.roleSection)
// 2. Schema
sections.append(buildSchemaSection())
// 3. Operation permissions
sections.append(buildPermissionsSection())
// 4. SQL generation rules
sections.append(Self.sqlRulesSection)
// 5. Output format
sections.append(Self.outputFormatSection)
// 6. Additional context
if let additionalContext, !additionalContext.isEmpty {
sections.append("ADDITIONAL CONTEXT\n=================\n\(additionalContext)")
}
return sections.joined(separator: "\n\n")
}
// MARK: - User Prompt
/// Wraps a natural language question into a user prompt string.
///
/// - Parameter question: The user's natural language question.
/// - Returns: A formatted prompt string for the LLM.
public func buildUserPrompt(_ question: String) -> String {
question
}
/// Builds a follow-up prompt that includes prior SQL context for
/// multi-turn conversations.
///
/// - Parameters:
/// - question: The user's follow-up question.
/// - previousSQL: The SQL from the previous turn, for context.
/// - previousResultSummary: A brief summary of what the previous query returned.
/// - Returns: A formatted prompt string.
public func buildFollowUpPrompt(
_ question: String,
previousSQL: String,
previousResultSummary: String
) -> String {
"""
Previous query: \(previousSQL)
Previous result: \(previousResultSummary)
Follow-up question: \(question)
"""
}
/// Builds a prompt that includes the full conversation history within the
/// configured context window, enabling the LLM to resolve follow-up
/// references (pronouns, implicit table/column references, etc.).
///
/// - Parameters:
/// - question: The user's current question.
/// - history: The conversation history messages within the context window.
/// - Returns: A formatted prompt string with conversation context.
public func buildConversationPrompt(
_ question: String,
history: [ChatMessage]
) -> String {
guard !history.isEmpty else {
return buildUserPrompt(question)
}
var lines: [String] = []
lines.append("CONVERSATION HISTORY")
lines.append("====================")
for message in history {
switch message.role {
case .user:
lines.append("User: \(message.content)")
case .assistant:
if let sql = message.sql {
lines.append("Assistant SQL: \(sql)")
}
lines.append("Assistant: \(message.content)")
case .error:
lines.append("Error: \(message.content)")
}
}
lines.append("")
lines.append("CURRENT QUESTION")
lines.append("================")
lines.append(question)
return lines.joined(separator: "\n")
}
// MARK: - Private Sections
private func buildSchemaSection() -> String {
var lines: [String] = []
lines.append("DATABASE SCHEMA")
lines.append("===============")
lines.append("")
lines.append(schema.schemaDescription)
return lines.joined(separator: "\n")
}
private func buildPermissionsSection() -> String {
var lines: [String] = []
lines.append("PERMISSIONS")
lines.append("===========")
lines.append(allowlist.describeForLLM())
return lines.joined(separator: "\n")
}
// MARK: - Static Content
static let roleSection = """
ROLE
====
You are a SQL assistant for a SQLite database. Your job is to translate \
natural language questions into valid SQLite SQL queries based on the \
database schema provided below.
IMPORTANT:
- ONLY reference tables and columns that exist in the schema. Never \
fabricate table or column names.
- Interpret questions by INTENT, not literally. If a user asks \
"articles starting with the", they mean articles whose title begins \
with the word "the", NOT articles containing that exact phrase.
- If the schema does not have a column for what the user asks about, \
use the closest available column or return a query that explains \
what data is available.
"""
static let sqlRulesSection = """
SQL GENERATION RULES
====================
1. Use ONLY the tables and columns listed in the schema above.
2. Use SQLite-compatible syntax (e.g., || for string concatenation, \
IFNULL instead of COALESCE where needed).
3. Use appropriate JOINs when queries span multiple tables — reference \
the foreign key relationships in the schema.
4. For date/time operations, use SQLite date functions \
(date(), time(), datetime(), strftime()).
5. Use parameterized-style values where possible. For literal values \
from the user's question, embed them directly in the SQL.
6. Always include an ORDER BY clause when the user implies ordering.
7. Use LIMIT when the user asks for "top N" or "first N" results. \
Default to LIMIT 20 when no limit is specified and the result could be large.
8. For aggregate queries (count, sum, average, min, max), use the \
appropriate SQL aggregate functions.
9. When the user's question is ambiguous, prefer the simplest valid \
interpretation that returns useful results.
10. Never generate DDL statements (CREATE, ALTER, DROP TABLE).
11. If a column the user asks about does not exist, use the closest \
available column. Do NOT reference columns not in the schema.
EXAMPLES
--------
User: "articles starting with the"
SQL: SELECT title, url FROM articles WHERE title LIKE 'The %' ORDER BY datePublished DESC LIMIT 20
User: "most popular items"
SQL: SELECT name, COUNT(*) AS count FROM items GROUP BY name ORDER BY count DESC LIMIT 10
User: "anything from last week"
SQL: SELECT * FROM articles WHERE datePublished >= date('now', '-7 days') ORDER BY datePublished DESC
User: "how many per category"
SQL: SELECT category, COUNT(*) AS count FROM items GROUP BY category ORDER BY count DESC
"""
static let outputFormatSection = """
OUTPUT FORMAT
=============
Output ONLY the raw SQL query. \
Do NOT wrap the SQL in markdown code fences or backticks. \
Do NOT include any explanation, comments, or formatting before or after the SQL. \
Do NOT prefix with labels like "SQL:" or "Query:". \
The output should be directly executable SQL — nothing else.
"""
}

View File

@@ -0,0 +1,423 @@
// ChartDataDetector.swift
// SwiftDBAI
//
// Analyzes query results to determine chart eligibility and
// recommends appropriate chart types based on data shape.
import Foundation
/// Detects whether a `DataTable` is suitable for charting and
/// recommends the best chart type based on data shape heuristics.
///
/// The detector examines column types, row counts, and value distributions
/// to produce a `ChartRecommendation` that the rendering layer can use
/// to auto-select an appropriate Swift Charts visualization.
///
/// Usage:
/// ```swift
/// let detector = ChartDataDetector()
/// if let recommendation = detector.detect(table) {
/// switch recommendation.chartType {
/// case .bar: // render bar chart
/// case .line: // render line chart
/// case .pie: // render pie chart
/// }
/// }
/// ```
public struct ChartDataDetector: Sendable {
// MARK: - Chart Types
/// The type of chart recommended for the data.
public enum ChartType: String, Sendable, Equatable, CaseIterable {
/// Vertical bar chart best for categorical comparisons.
case bar
/// Line chart best for time series or ordered sequences.
case line
/// Pie/donut chart best for proportional breakdowns with few categories.
case pie
}
/// A recommendation for how to chart a `DataTable`.
public struct ChartRecommendation: Sendable, Equatable {
/// The recommended chart type.
public let chartType: ChartType
/// The column to use for the category axis (x-axis / labels).
public let categoryColumn: String
/// The column to use for the value axis (y-axis / sizes).
public let valueColumn: String
/// Confidence score from 0.0 (guess) to 1.0 (strong match).
public let confidence: Double
/// Human-readable reason for this recommendation.
public let reason: String
public init(
chartType: ChartType,
categoryColumn: String,
valueColumn: String,
confidence: Double,
reason: String
) {
self.chartType = chartType
self.categoryColumn = categoryColumn
self.valueColumn = valueColumn
self.confidence = confidence
self.reason = reason
}
}
// MARK: - Configuration
/// Minimum rows required to consider chart-eligible.
public let minimumRows: Int
/// Maximum rows for a pie chart (too many slices becomes unreadable).
public let maxPieSlices: Int
/// Maximum rows for any chart before it becomes cluttered.
public let maximumRows: Int
// MARK: - Initialization
/// Creates a detector with configurable thresholds.
///
/// - Parameters:
/// - minimumRows: Minimum rows for chart eligibility (default: 2).
/// - maxPieSlices: Maximum categories for pie charts (default: 8).
/// - maximumRows: Maximum rows for any chart (default: 100).
public init(
minimumRows: Int = 2,
maxPieSlices: Int = 8,
maximumRows: Int = 100
) {
self.minimumRows = minimumRows
self.maxPieSlices = maxPieSlices
self.maximumRows = maximumRows
}
// MARK: - Detection
/// Analyzes a `DataTable` and returns a chart recommendation, or `nil`
/// if the data is not suitable for charting.
///
/// - Parameter table: The data table to analyze.
/// - Returns: A recommendation, or `nil` if no chart type fits.
public func detect(_ table: DataTable) -> ChartRecommendation? {
// Must have at least 2 columns (category + value) and enough rows
guard table.columnCount >= 2,
table.rowCount >= minimumRows,
table.rowCount <= maximumRows else {
return nil
}
// Find candidate category and value columns
guard let (categoryCol, valueCol) = findCategoryValuePair(in: table) else {
return nil
}
let chartType = recommendChartType(
table: table,
categoryColumn: categoryCol,
valueColumn: valueCol
)
let confidence = computeConfidence(
table: table,
categoryColumn: categoryCol,
valueColumn: valueCol,
chartType: chartType
)
let reason = describeReason(
chartType: chartType,
categoryColumn: categoryCol,
valueColumn: valueCol,
table: table
)
return ChartRecommendation(
chartType: chartType,
categoryColumn: categoryCol.name,
valueColumn: valueCol.name,
confidence: confidence,
reason: reason
)
}
/// Returns all viable chart recommendations, ranked by confidence.
///
/// - Parameter table: The data table to analyze.
/// - Returns: An array of recommendations sorted by confidence (highest first).
public func allRecommendations(for table: DataTable) -> [ChartRecommendation] {
guard table.columnCount >= 2,
table.rowCount >= minimumRows,
table.rowCount <= maximumRows else {
return []
}
guard let (categoryCol, valueCol) = findCategoryValuePair(in: table) else {
return []
}
return ChartType.allCases.compactMap { chartType in
guard isViable(chartType, table: table, categoryColumn: categoryCol) else {
return nil
}
let confidence = computeConfidence(
table: table,
categoryColumn: categoryCol,
valueColumn: valueCol,
chartType: chartType
)
let reason = describeReason(
chartType: chartType,
categoryColumn: categoryCol,
valueColumn: valueCol,
table: table
)
return ChartRecommendation(
chartType: chartType,
categoryColumn: categoryCol.name,
valueColumn: valueCol.name,
confidence: confidence,
reason: reason
)
}
.sorted { $0.confidence > $1.confidence }
}
// MARK: - Private Helpers
/// Finds the best (category, value) column pair from the table.
private func findCategoryValuePair(
in table: DataTable
) -> (category: DataTable.Column, value: DataTable.Column)? {
let numericColumns = table.columns.filter { isNumeric($0) }
let categoryColumns = table.columns.filter { isCategory($0) }
// Prefer: first text/category column + first numeric column
if let cat = categoryColumns.first, let val = numericColumns.first {
return (cat, val)
}
// Fallback: if all columns are numeric, use first as category, second as value
if numericColumns.count >= 2 {
return (numericColumns[0], numericColumns[1])
}
return nil
}
/// Recommends the single best chart type for the data shape.
private func recommendChartType(
table: DataTable,
categoryColumn: DataTable.Column,
valueColumn: DataTable.Column
) -> ChartType {
// Line: time series or sequential numeric categories (check first strongest signal)
if isTimeSeries(categoryColumn, in: table) || isSequential(categoryColumn, in: table) {
return .line
}
// Pie: small number of categories with all-positive values
// Only when clearly categorical (text labels) and few rows
if table.rowCount <= maxPieSlices,
isCategory(categoryColumn),
isPieCandidate(table: table, valueColumn: valueColumn),
looksProportional(table: table, valueColumn: valueColumn) {
return .pie
}
// Default: bar chart for categorical comparisons
return .bar
}
/// Checks if a chart type is viable for the given data.
private func isViable(
_ chartType: ChartType,
table: DataTable,
categoryColumn: DataTable.Column
) -> Bool {
switch chartType {
case .pie:
return table.rowCount <= maxPieSlices
case .line:
return table.rowCount >= minimumRows
case .bar:
return true
}
}
/// Determines if a column holds numeric data.
private func isNumeric(_ column: DataTable.Column) -> Bool {
switch column.inferredType {
case .integer, .real:
return true
default:
return false
}
}
/// Determines if a column holds categorical (label) data.
private func isCategory(_ column: DataTable.Column) -> Bool {
switch column.inferredType {
case .text, .mixed:
return true
default:
return false
}
}
/// Checks if the value column contains all non-negative values,
/// making it a candidate for pie charts.
private func isPieCandidate(
table: DataTable,
valueColumn: DataTable.Column
) -> Bool {
let values = table.numericValues(forColumn: valueColumn.name)
guard !values.isEmpty else { return false }
// All values must be positive for a meaningful pie chart
return values.allSatisfy { $0 > 0 }
}
/// Heuristic: do values look like they represent parts of a whole?
///
/// Checks for aggregate-like column names (count, total, sum, amount, pct, etc.)
/// or if values sum to a round number suggesting percentages/proportions.
private func looksProportional(
table: DataTable,
valueColumn: DataTable.Column
) -> Bool {
let proportionalNames: Set<String> = ["count", "total", "sum", "amount", "pct",
"percent", "percentage", "share", "proportion",
"quantity", "qty", "num", "number"]
// Split on common separators and check for exact word matches
let lowerName = valueColumn.name.lowercased()
let words = Set(lowerName.split { $0 == "_" || $0 == "-" || $0 == " " }.map(String.init))
if !words.isDisjoint(with: proportionalNames) {
return true
}
// Check if values sum to ~100 (percentages)
let values = table.numericValues(forColumn: valueColumn.name)
let sum = values.reduce(0, +)
if abs(sum - 100.0) < 1.0 {
return true
}
return false
}
/// Heuristic: does the category column look like time-series data?
///
/// Checks for date-like patterns (YYYY, YYYY-MM, YYYY-MM-DD)
/// or common time-related column names.
private func isTimeSeries(_ column: DataTable.Column, in table: DataTable) -> Bool {
let timeNames = ["date", "time", "timestamp", "year", "month", "day",
"week", "quarter", "period", "created_at", "updated_at"]
let lowerName = column.name.lowercased()
if timeNames.contains(where: { lowerName.contains($0) }) {
return true
}
// Check if text values look like dates
if column.inferredType == .text {
let values = table.stringValues(forColumn: column.name)
let datePattern = #/^\d{4}(-\d{2}){0,2}$/#
let matchCount = values.prefix(5).filter { (try? datePattern.wholeMatch(in: $0)) != nil }.count
if matchCount >= 3 {
return true
}
}
return false
}
/// Heuristic: does the category column contain sequential numeric values?
private func isSequential(_ column: DataTable.Column, in table: DataTable) -> Bool {
guard isNumeric(column) else { return false }
let values = table.numericValues(forColumn: column.name)
guard values.count >= 3 else { return false }
// Check if values are monotonically increasing
for i in 1..<values.count {
if values[i] <= values[i - 1] {
return false
}
}
return true
}
/// Computes a confidence score for a specific chart type + data combination.
private func computeConfidence(
table: DataTable,
categoryColumn: DataTable.Column,
valueColumn: DataTable.Column,
chartType: ChartType
) -> Double {
var score = 0.5 // baseline
// Bonus: clear category/value split (text + numeric)
if isCategory(categoryColumn) && isNumeric(valueColumn) {
score += 0.2
}
// Bonus: reasonable row count for the chart type
switch chartType {
case .bar:
if table.rowCount >= 2 && table.rowCount <= 20 {
score += 0.15
}
case .line:
if isTimeSeries(categoryColumn, in: table) {
score += 0.2
} else if isSequential(categoryColumn, in: table) {
score += 0.1
}
case .pie:
if table.rowCount <= maxPieSlices && isPieCandidate(table: table, valueColumn: valueColumn) {
score += 0.2
}
// Penalty: too many slices
if table.rowCount > 5 {
score -= 0.1
}
}
// Bonus: no null values in key columns
let categoryNulls = table.columnValues(named: categoryColumn.name).filter(\.isNull).count
let valueNulls = table.columnValues(named: valueColumn.name).filter(\.isNull).count
if categoryNulls == 0 && valueNulls == 0 {
score += 0.1
}
return min(max(score, 0.0), 1.0)
}
/// Generates a human-readable reason for the recommendation.
private func describeReason(
chartType: ChartType,
categoryColumn: DataTable.Column,
valueColumn: DataTable.Column,
table: DataTable
) -> String {
switch chartType {
case .bar:
return "\(table.rowCount) categories comparing \(valueColumn.name) by \(categoryColumn.name)"
case .line:
if isTimeSeries(categoryColumn, in: table) {
return "\(valueColumn.name) over time (\(categoryColumn.name))"
}
return "\(valueColumn.name) trend across \(table.rowCount) points"
case .pie:
return "Proportional breakdown of \(valueColumn.name) by \(categoryColumn.name)"
}
}
}

View File

@@ -0,0 +1,255 @@
// DataTable.swift
// SwiftDBAI
//
// Structured table representation for rendering query results
// in SwiftUI table views and charts.
import Foundation
/// A structured, row-column table built from a `QueryResult`.
///
/// `DataTable` provides indexed access to rows and columns, typed column
/// metadata, and convenience methods for extracting data suitable for
/// SwiftUI `Table` views and Swift Charts.
///
/// Usage:
/// ```swift
/// let table = DataTable(queryResult)
/// print(table.columnCount) // 3
/// print(table[row: 0, column: 1]) // .text("Alice")
/// ```
public struct DataTable: Sendable, Equatable {
// MARK: - Column Metadata
/// Metadata for a single column in the data table.
public struct Column: Sendable, Equatable, Identifiable {
/// Stable identifier for the column (same as `name`).
public var id: String { name }
/// Column name from the query result set.
public let name: String
/// Index of this column in the table (0-based).
public let index: Int
/// Inferred data type based on the values in this column.
public let inferredType: InferredType
public init(name: String, index: Int, inferredType: InferredType) {
self.name = name
self.index = index
self.inferredType = inferredType
}
}
/// The inferred data type for a column, determined by inspecting its values.
public enum InferredType: Sendable, Equatable {
/// All non-null values are integers.
case integer
/// All non-null values are numeric (mix of integer and real).
case real
/// All non-null values are text.
case text
/// Values contain blob data.
case blob
/// Column contains only null values or is empty.
case null
/// Values are a mix of incompatible types.
case mixed
}
// MARK: - Row Type
/// A single row in the data table, providing indexed and named access.
public struct Row: Sendable, Equatable, Identifiable {
/// Row index (0-based), used as stable identity.
public let id: Int
/// Values in column order.
public let values: [QueryResult.Value]
/// Column names for named access.
private let columnNames: [String]
public init(id: Int, values: [QueryResult.Value], columnNames: [String]) {
self.id = id
self.values = values
self.columnNames = columnNames
}
/// Access a value by column index.
public subscript(columnIndex: Int) -> QueryResult.Value {
values[columnIndex]
}
/// Access a value by column name. Returns `.null` if the column doesn't exist.
public subscript(columnName: String) -> QueryResult.Value {
guard let idx = columnNames.firstIndex(of: columnName) else {
return .null
}
return values[idx]
}
}
// MARK: - Properties
/// Column metadata in order.
public let columns: [Column]
/// All rows in order.
public let rows: [Row]
/// The SQL that produced this table.
public let sql: String
/// Execution time of the underlying query.
public let executionTime: TimeInterval
/// Number of columns.
public var columnCount: Int { columns.count }
/// Number of rows.
public var rowCount: Int { rows.count }
/// Whether the table has no rows.
public var isEmpty: Bool { rows.isEmpty }
/// Column names in order.
public var columnNames: [String] { columns.map(\.name) }
// MARK: - Initialization
/// Creates a `DataTable` from a `QueryResult`.
///
/// Converts the dictionary-based row representation into an indexed
/// array representation and infers column types from the data.
///
/// - Parameter queryResult: The raw query result to convert.
public init(_ queryResult: QueryResult) {
let colNames = queryResult.columns
// Build indexed rows
let indexedRows: [Row] = queryResult.rows.enumerated().map { idx, rowDict in
let values = colNames.map { col in
rowDict[col] ?? .null
}
return Row(id: idx, values: values, columnNames: colNames)
}
// Infer column types
let inferredColumns: [Column] = colNames.enumerated().map { colIdx, name in
let type = Self.inferType(
from: indexedRows.map { $0.values[colIdx] }
)
return Column(name: name, index: colIdx, inferredType: type)
}
self.columns = inferredColumns
self.rows = indexedRows
self.sql = queryResult.sql
self.executionTime = queryResult.executionTime
}
/// Creates a `DataTable` directly from components (useful for testing).
public init(
columns: [Column],
rows: [Row],
sql: String = "",
executionTime: TimeInterval = 0
) {
self.columns = columns
self.rows = rows
self.sql = sql
self.executionTime = executionTime
}
// MARK: - Subscript Access
/// Access a cell by row and column index.
public subscript(row rowIndex: Int, column columnIndex: Int) -> QueryResult.Value {
rows[rowIndex].values[columnIndex]
}
/// Access a cell by row index and column name.
public subscript(row rowIndex: Int, column columnName: String) -> QueryResult.Value {
rows[rowIndex][columnName]
}
// MARK: - Column Data Extraction
/// Returns all values for a column by index, in row order.
public func columnValues(at index: Int) -> [QueryResult.Value] {
rows.map { $0.values[index] }
}
/// Returns all values for a column by name, in row order.
public func columnValues(named name: String) -> [QueryResult.Value] {
guard let col = columns.first(where: { $0.name == name }) else {
return []
}
return columnValues(at: col.index)
}
/// Returns all non-null `Double` values for a column (useful for charting).
public func numericValues(forColumn name: String) -> [Double] {
columnValues(named: name).compactMap(\.doubleValue)
}
/// Returns all non-null `String` values for a column (useful for labels).
public func stringValues(forColumn name: String) -> [String] {
columnValues(named: name).compactMap { value in
if case .null = value { return nil }
return value.stringValue
}
}
// MARK: - Type Inference
/// Infers the predominant type from an array of values.
static func inferType(from values: [QueryResult.Value]) -> InferredType {
var hasInteger = false
var hasReal = false
var hasText = false
var hasBlob = false
var hasNonNull = false
for value in values {
switch value {
case .integer:
hasInteger = true
hasNonNull = true
case .real:
hasReal = true
hasNonNull = true
case .text:
hasText = true
hasNonNull = true
case .blob:
hasBlob = true
hasNonNull = true
case .null:
break
}
}
guard hasNonNull else { return .null }
// Count how many distinct types are present
let typeCount = [hasInteger, hasReal, hasText, hasBlob].filter { $0 }.count
if typeCount == 1 {
if hasInteger { return .integer }
if hasReal { return .real }
if hasText { return .text }
if hasBlob { return .blob }
}
// Integer + real treat as real (numeric promotion)
if typeCount == 2, hasInteger, hasReal {
return .real
}
return .mixed
}
}

View File

@@ -0,0 +1,301 @@
// TextSummaryRenderer.swift
// SwiftDBAI
//
// Converts raw SQL query results into natural language text summaries
// using the LLM via AnyLanguageModel.
import AnyLanguageModel
import Foundation
/// Renders SQL query results as natural language text summaries.
///
/// The renderer takes a `QueryResult` and the user's original question,
/// sends them to the LLM for summarization, and returns a concise,
/// human-readable response.
///
/// Usage:
/// ```swift
/// let renderer = TextSummaryRenderer(model: myModel)
/// let summary = try await renderer.summarize(
/// result: queryResult,
/// userQuestion: "How many orders were placed last month?"
/// )
/// print(summary) // "There were 42 orders placed last month."
/// ```
public struct TextSummaryRenderer: Sendable {
/// The language model used to generate summaries.
private let model: any LanguageModel
/// Maximum number of rows to include in the LLM prompt.
///
/// Results larger than this are truncated with a note about total count.
public let maxRowsInPrompt: Int
/// Creates a new text summary renderer.
///
/// - Parameters:
/// - model: Any `AnyLanguageModel`-compatible language model.
/// - maxRowsInPrompt: Maximum rows to send to the LLM for summarization (default: 50).
public init(model: any LanguageModel, maxRowsInPrompt: Int = 50) {
self.model = model
self.maxRowsInPrompt = maxRowsInPrompt
}
/// Generates a natural language summary of query results.
///
/// - Parameters:
/// - result: The raw `QueryResult` from SQL execution.
/// - userQuestion: The original natural language question from the user.
/// - context: Optional additional context (e.g., table descriptions) to help the LLM.
/// - Returns: A natural language text summary of the results.
public func summarize(
result: QueryResult,
userQuestion: String,
context: String? = nil
) async throws -> String {
// For mutation results (INSERT/UPDATE/DELETE), use a simple template
if let affected = result.rowsAffected {
return summarizeMutation(result: result, affected: affected)
}
// For empty results, no need to call the LLM
if result.rows.isEmpty {
return "No results found for your query."
}
// For simple aggregates, produce a direct answer without LLM
if let directAnswer = tryDirectAggregateSummary(result: result, userQuestion: userQuestion) {
return directAnswer
}
// Build the prompt and ask the LLM to summarize
let prompt = buildSummarizationPrompt(
result: result,
userQuestion: userQuestion,
context: context
)
let session = LanguageModelSession(
model: model,
instructions: summaryInstructions
)
let response = try await session.respond(to: prompt)
return response.content.trimmingCharacters(in: .whitespacesAndNewlines)
}
/// Generates a summary without calling the LLM, using simple templates.
///
/// Useful when LLM access is unavailable, or for fast local rendering.
///
/// - Parameters:
/// - result: The raw `QueryResult` from SQL execution.
/// - userQuestion: The original natural language question.
/// - Returns: A template-based text summary.
public func localSummary(result: QueryResult, userQuestion: String) -> String {
if let affected = result.rowsAffected {
return summarizeMutation(result: result, affected: affected)
}
if result.rows.isEmpty {
return "No results found for your query."
}
if let directAnswer = tryDirectAggregateSummary(result: result, userQuestion: userQuestion) {
return directAnswer
}
return buildTemplateSummary(result: result)
}
// MARK: - Private Helpers
/// System instructions for the summarization session.
private var summaryInstructions: String {
"""
You are a data assistant that summarizes SQL query results in natural language.
Rules:
- Be concise and direct. Answer the user's question first, then add detail if helpful.
- Use natural language, not SQL or code.
- For numeric results, include the exact numbers.
- For lists of records, summarize the count and highlight notable items.
- If the data contains dates, format them in a readable way.
- Do not mention SQL, databases, tables, columns, or queries in your response.
- Do not include markdown formatting.
- Keep your response under 3 sentences for simple results, under 5 for complex ones.
"""
}
/// Builds the prompt sent to the LLM for summarization.
private func buildSummarizationPrompt(
result: QueryResult,
userQuestion: String,
context: String?
) -> String {
var parts: [String] = []
parts.append("User's question: \(userQuestion)")
if let context {
parts.append("Context: \(context)")
}
parts.append("Query returned \(result.rowCount) row(s) with columns: \(result.columns.joined(separator: ", "))")
// Include the result data (truncated if large)
let dataStr = formatResultData(result)
parts.append("Data:\n\(dataStr)")
parts.append("Summarize these results in natural language, directly answering the user's question.")
return parts.joined(separator: "\n\n")
}
/// Formats the query result data as a compact table for the LLM prompt.
private func formatResultData(_ result: QueryResult) -> String {
let rowsToInclude = Array(result.rows.prefix(maxRowsInPrompt))
var lines: [String] = []
// Header
lines.append(result.columns.joined(separator: " | "))
// Rows
for row in rowsToInclude {
let values = result.columns.map { col in
row[col]?.description ?? "NULL"
}
lines.append(values.joined(separator: " | "))
}
if result.rowCount > maxRowsInPrompt {
lines.append("(\(result.rowCount - maxRowsInPrompt) additional rows not shown)")
}
return lines.joined(separator: "\n")
}
/// Produces a direct answer for simple aggregate queries (1 row, few columns).
private func tryDirectAggregateSummary(result: QueryResult, userQuestion: String) -> String? {
guard result.isAggregate else { return nil }
let row = result.rows[0]
// Single numeric column e.g., "COUNT(*)" "42"
if result.columns.count == 1 {
let col = result.columns[0]
guard let value = row[col] else { return nil }
let formatted = formatNumber(value)
return "The result is \(formatted)."
}
// Multiple aggregate columns e.g., COUNT, AVG, SUM
let parts = result.columns.compactMap { col -> String? in
guard let value = row[col] else { return nil }
let label = humanizeColumnName(col)
let formatted = formatNumber(value)
return "\(label): \(formatted)"
}
return parts.joined(separator: ", ") + "."
}
/// Formats a numeric Value for display.
private func formatNumber(_ value: QueryResult.Value) -> String {
switch value {
case .integer(let i):
return NumberFormatter.localizedString(from: NSNumber(value: i), number: .decimal)
case .real(let d):
if d == d.rounded() && abs(d) < 1e12 {
return NumberFormatter.localizedString(from: NSNumber(value: Int64(d)), number: .decimal)
}
let formatter = NumberFormatter()
formatter.numberStyle = .decimal
formatter.maximumFractionDigits = 2
return formatter.string(from: NSNumber(value: d)) ?? String(d)
default:
return value.description
}
}
/// Converts a column name like "total_count" or "AVG(price)" into a readable label.
private func humanizeColumnName(_ name: String) -> String {
// Handle SQL function names: "COUNT(*)" "count", "AVG(price)" "average price"
let functionPatterns: [(pattern: String, label: String)] = [
("COUNT", "count"),
("SUM", "total"),
("AVG", "average"),
("MIN", "minimum"),
("MAX", "maximum"),
]
let upper = name.uppercased()
for (pattern, label) in functionPatterns {
if upper.hasPrefix(pattern + "(") {
// Extract the inner column name
let start = name.index(name.startIndex, offsetBy: pattern.count + 1)
let end = name.index(before: name.endIndex)
if start < end {
let inner = String(name[start..<end])
if inner == "*" { return label }
return "\(label) \(humanizeColumnName(inner))"
}
return label
}
}
// snake_case space-separated
return name
.replacingOccurrences(of: "_", with: " ")
.lowercased()
}
/// Produces a template-based summary without calling the LLM.
private func buildTemplateSummary(result: QueryResult) -> String {
let count = result.rowCount
if count == 1 {
// Single record list field values
let row = result.rows[0]
let details = result.columns.prefix(5).compactMap { col -> String? in
guard let val = row[col], !val.isNull else { return nil }
return "\(humanizeColumnName(col)): \(val.description)"
}
return "Found 1 result. \(details.joined(separator: ", "))."
}
// Multiple records
var summary = "Found \(count) results"
// If there's a clear "name" or "title" column, list first few
let nameColumns = ["name", "title", "label", "description"]
if let nameCol = result.columns.first(where: { nameColumns.contains($0.lowercased()) }) {
let names = result.rows.prefix(3).compactMap { $0[nameCol]?.description }
if !names.isEmpty {
summary += " including \(names.joined(separator: ", "))"
if count > 3 { summary += ", and \(count - 3) more" }
}
}
return summary + "."
}
/// Summarizes a mutation (INSERT/UPDATE/DELETE) result.
private func summarizeMutation(result: QueryResult, affected: Int) -> String {
let sql = result.sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
let operation: String
if sql.hasPrefix("INSERT") {
operation = "inserted"
} else if sql.hasPrefix("UPDATE") {
operation = "updated"
} else if sql.hasPrefix("DELETE") {
operation = "deleted"
} else {
operation = "affected"
}
let noun = affected == 1 ? "row" : "rows"
return "Successfully \(operation) \(affected) \(noun)."
}
}

View File

@@ -0,0 +1,164 @@
// DatabaseSchema.swift
// SwiftDBAI
//
// Auto-introspected SQLite schema model types.
import Foundation
/// Complete schema representation of an SQLite database.
public struct DatabaseSchema: Sendable, Equatable {
/// All tables in the database, keyed by table name.
public let tables: [String: TableSchema]
/// Ordered table names (preserves discovery order).
public let tableNames: [String]
/// Returns a compact text description suitable for LLM system prompts.
public var schemaDescription: String {
var lines: [String] = []
for name in tableNames {
guard let table = tables[name] else { continue }
lines.append(table.descriptionForLLM)
}
return lines.joined(separator: "\n\n")
}
/// Returns a description suitable for LLM system prompts.
/// Alias for `schemaDescription` for API compatibility.
public func describeForLLM() -> String {
schemaDescription
}
public init(tables: [String: TableSchema], tableNames: [String]) {
self.tables = tables
self.tableNames = tableNames
}
}
/// Schema for a single SQLite table.
public struct TableSchema: Sendable, Equatable {
public let name: String
public let columns: [ColumnSchema]
public let primaryKey: [String]
public let foreignKeys: [ForeignKeySchema]
public let indexes: [IndexSchema]
/// Text description for embedding in LLM prompts.
public var descriptionForLLM: String {
var parts: [String] = []
let colDefs = columns.map { col in
var def = " \(col.name) \(col.type)"
if col.isPrimaryKey { def += " PRIMARY KEY" }
if col.isNotNull { def += " NOT NULL" }
if let defaultValue = col.defaultValue { def += " DEFAULT \(defaultValue)" }
return def
}
parts.append("TABLE \(name) (\n\(colDefs.joined(separator: ",\n"))\n)")
if !foreignKeys.isEmpty {
let fkDescs = foreignKeys.map {
" FOREIGN KEY (\($0.fromColumn)) REFERENCES \($0.toTable)(\($0.toColumn))"
}
parts.append("FOREIGN KEYS:\n\(fkDescs.joined(separator: "\n"))")
}
if !indexes.isEmpty {
let idxDescs = indexes.map {
" INDEX \($0.name) ON (\($0.columns.joined(separator: ", ")))\($0.isUnique ? " UNIQUE" : "")"
}
parts.append("INDEXES:\n\(idxDescs.joined(separator: "\n"))")
}
return parts.joined(separator: "\n")
}
public init(
name: String,
columns: [ColumnSchema],
primaryKey: [String],
foreignKeys: [ForeignKeySchema],
indexes: [IndexSchema]
) {
self.name = name
self.columns = columns
self.primaryKey = primaryKey
self.foreignKeys = foreignKeys
self.indexes = indexes
}
}
/// Schema for a single column.
public struct ColumnSchema: Sendable, Equatable {
/// Column position (0-based).
public let cid: Int
/// Column name.
public let name: String
/// Declared SQLite type (e.g. "TEXT", "INTEGER", "REAL", "BLOB").
public let type: String
/// Whether the column has a NOT NULL constraint.
public let isNotNull: Bool
/// Default value expression, if any.
public let defaultValue: String?
/// Whether this column is part of the primary key.
public let isPrimaryKey: Bool
public init(
cid: Int,
name: String,
type: String,
isNotNull: Bool,
defaultValue: String?,
isPrimaryKey: Bool
) {
self.cid = cid
self.name = name
self.type = type
self.isNotNull = isNotNull
self.defaultValue = defaultValue
self.isPrimaryKey = isPrimaryKey
}
}
/// Schema for a foreign key relationship.
public struct ForeignKeySchema: Sendable, Equatable {
/// Column in the source table.
public let fromColumn: String
/// Referenced table name.
public let toTable: String
/// Referenced column name.
public let toColumn: String
/// ON UPDATE action (e.g. "CASCADE", "NO ACTION").
public let onUpdate: String
/// ON DELETE action.
public let onDelete: String
public init(
fromColumn: String,
toTable: String,
toColumn: String,
onUpdate: String,
onDelete: String
) {
self.fromColumn = fromColumn
self.toTable = toTable
self.toColumn = toColumn
self.onUpdate = onUpdate
self.onDelete = onDelete
}
}
/// Schema for a database index.
public struct IndexSchema: Sendable, Equatable {
/// Index name.
public let name: String
/// Whether the index enforces uniqueness.
public let isUnique: Bool
/// Columns included in the index, in order.
public let columns: [String]
public init(name: String, isUnique: Bool, columns: [String]) {
self.name = name
self.isUnique = isUnique
self.columns = columns
}
}

View File

@@ -0,0 +1,153 @@
// SchemaIntrospector.swift
// SwiftDBAI
//
// Auto-introspects SQLite database schema using GRDB.
import GRDB
/// Introspects an SQLite database schema by querying sqlite_master and PRAGMA statements.
///
/// Usage:
/// ```swift
/// let dbPool = try DatabasePool(path: "path/to/db.sqlite")
/// let schema = try await SchemaIntrospector.introspect(database: dbPool)
/// print(schema.schemaDescription)
/// ```
public struct SchemaIntrospector: Sendable {
// MARK: - Public API
/// Introspects the full schema of the given database.
///
/// Discovers all user tables (excluding sqlite_ internal tables),
/// their columns, primary keys, foreign keys, and indexes.
///
/// - Parameter database: A GRDB `DatabaseReader` (DatabasePool or DatabaseQueue).
/// - Returns: A complete `DatabaseSchema` representation.
public static func introspect(database: any DatabaseReader) async throws -> DatabaseSchema {
try await database.read { db in
try introspect(db: db)
}
}
/// Synchronous introspection within an existing database access context.
///
/// - Parameter db: A GRDB `Database` instance from within a read/write block.
/// - Returns: A complete `DatabaseSchema` representation.
public static func introspect(db: Database) throws -> DatabaseSchema {
let tableNames = try fetchTableNames(db: db)
var tables: [String: TableSchema] = [:]
for tableName in tableNames {
let columns = try fetchColumns(db: db, table: tableName)
let primaryKey = try fetchPrimaryKey(db: db, table: tableName)
let foreignKeys = try fetchForeignKeys(db: db, table: tableName)
let indexes = try fetchIndexes(db: db, table: tableName)
// Mark columns that are part of the primary key
let pkSet = Set(primaryKey)
let annotatedColumns = columns.map { col in
ColumnSchema(
cid: col.cid,
name: col.name,
type: col.type,
isNotNull: col.isNotNull,
defaultValue: col.defaultValue,
isPrimaryKey: pkSet.contains(col.name)
)
}
tables[tableName] = TableSchema(
name: tableName,
columns: annotatedColumns,
primaryKey: primaryKey,
foreignKeys: foreignKeys,
indexes: indexes
)
}
return DatabaseSchema(tables: tables, tableNames: tableNames)
}
// MARK: - Private Helpers
/// Fetches all user table names from sqlite_master.
private static func fetchTableNames(db: Database) throws -> [String] {
let sql = """
SELECT name FROM sqlite_master
WHERE type = 'table'
AND name NOT LIKE 'sqlite_%'
ORDER BY name
"""
return try String.fetchAll(db, sql: sql)
}
/// Fetches column metadata for a table using PRAGMA table_info.
private static func fetchColumns(db: Database, table: String) throws -> [ColumnSchema] {
let sql = "PRAGMA table_info(\(table.quotedDatabaseIdentifier))"
let rows = try Row.fetchAll(db, sql: sql)
return rows.map { row in
ColumnSchema(
cid: row["cid"],
name: row["name"],
type: (row["type"] as String?) ?? "",
isNotNull: row["notnull"] == 1,
defaultValue: row["dflt_value"],
isPrimaryKey: row["pk"] != 0
)
}
}
/// Fetches primary key columns for a table.
private static func fetchPrimaryKey(db: Database, table: String) throws -> [String] {
let sql = "PRAGMA table_info(\(table.quotedDatabaseIdentifier))"
let rows = try Row.fetchAll(db, sql: sql)
return rows
.filter { ($0["pk"] as Int) > 0 }
.sorted { ($0["pk"] as Int) < ($1["pk"] as Int) }
.map { $0["name"] }
}
/// Fetches foreign key relationships for a table.
private static func fetchForeignKeys(db: Database, table: String) throws -> [ForeignKeySchema] {
let sql = "PRAGMA foreign_key_list(\(table.quotedDatabaseIdentifier))"
let rows = try Row.fetchAll(db, sql: sql)
return rows.map { row in
ForeignKeySchema(
fromColumn: row["from"],
toTable: row["table"],
toColumn: row["to"],
onUpdate: row["on_update"] ?? "NO ACTION",
onDelete: row["on_delete"] ?? "NO ACTION"
)
}
}
/// Fetches indexes and their columns for a table.
private static func fetchIndexes(db: Database, table: String) throws -> [IndexSchema] {
let indexListSQL = "PRAGMA index_list(\(table.quotedDatabaseIdentifier))"
let indexRows = try Row.fetchAll(db, sql: indexListSQL)
var indexes: [IndexSchema] = []
for indexRow in indexRows {
let indexName: String = indexRow["name"]
let isUnique: Bool = indexRow["unique"] == 1
// Skip auto-generated indexes for primary keys
if indexName.hasPrefix("sqlite_autoindex_") { continue }
let infoSQL = "PRAGMA index_info(\(indexName.quotedDatabaseIdentifier))"
let infoRows = try Row.fetchAll(db, sql: infoSQL)
let columns: [String] = infoRows
.sorted { ($0["seqno"] as Int) < ($1["seqno"] as Int) }
.map { $0["name"] }
indexes.append(IndexSchema(
name: indexName,
isUnique: isUnique,
columns: columns
))
}
return indexes
}
}

View File

@@ -0,0 +1,215 @@
// SwiftDBAIError.swift
// SwiftDBAI
//
// Unified error type for the SwiftDBAI package.
import Foundation
/// The top-level error type for SwiftDBAI operations.
///
/// `SwiftDBAIError` provides a single, typed error surface that covers
/// every failure mode a consumer of SwiftDBAI may encounter from invalid
/// SQL and LLM failures to schema mismatches and safety violations.
///
/// Every case includes a user-friendly `localizedDescription` suitable for
/// displaying directly in a chat interface.
public enum SwiftDBAIError: Error, LocalizedError, Sendable, Equatable {
// MARK: - SQL Errors
/// No SQL statement could be extracted from the LLM response.
case noSQLGenerated
/// The generated SQL is syntactically invalid or failed execution.
case invalidSQL(sql: String, reason: String)
/// The SQL uses an operation (e.g. DELETE) not in the developer's allowlist.
case operationNotAllowed(operation: String)
/// Multiple SQL statements were generated but only single-statement execution is supported.
case multipleStatementsNotSupported
/// A dangerous SQL keyword (DROP, ALTER, TRUNCATE) was detected.
case dangerousOperationBlocked(keyword: String)
// MARK: - LLM Errors
/// The LLM failed to produce a response.
case llmFailure(reason: String)
/// The LLM response could not be parsed into an actionable result.
case llmResponseUnparseable(response: String)
/// The LLM request timed out.
case llmTimeout(seconds: TimeInterval)
// MARK: - Schema Errors
/// Schema introspection of the database failed.
case schemaIntrospectionFailed(reason: String)
/// The generated SQL references a table that does not exist in the schema.
case tableNotFound(tableName: String)
/// The generated SQL references a column that does not exist on the given table.
case columnNotFound(columnName: String, tableName: String)
/// The database schema is empty (no user tables found).
case emptySchema
// MARK: - Safety & Validation Errors
/// A destructive operation requires explicit user confirmation before execution.
case confirmationRequired(sql: String, operation: String)
/// A mutation targets a table not in the allowed mutation tables.
case tableNotAllowedForMutation(tableName: String, operation: String)
/// A custom query validator rejected the query.
case queryRejected(reason: String)
// MARK: - Database Errors
/// The underlying database operation failed.
case databaseError(reason: String)
/// The query exceeded the configured execution timeout.
case queryTimedOut(seconds: TimeInterval)
// MARK: - Configuration Errors
/// The engine has not been configured correctly.
case configurationError(reason: String)
// MARK: - Error Classification
/// Whether this error represents a safety/permissions issue (not a bug).
public var isSafetyError: Bool {
switch self {
case .operationNotAllowed, .dangerousOperationBlocked,
.confirmationRequired, .tableNotAllowedForMutation, .queryRejected:
return true
default:
return false
}
}
/// Whether this error is recoverable by rephrasing the user's question.
public var isRecoverable: Bool {
switch self {
case .noSQLGenerated, .llmResponseUnparseable, .invalidSQL,
.tableNotFound, .columnNotFound:
return true
default:
return false
}
}
/// Whether this error requires user action (e.g. confirmation).
public var requiresUserAction: Bool {
if case .confirmationRequired = self { return true }
return false
}
// MARK: - LocalizedError
public var errorDescription: String? {
switch self {
// SQL
case .noSQLGenerated:
return "I couldn't generate a SQL query from your request. Could you rephrase your question?"
case .invalidSQL(let sql, let reason):
return "The generated query is invalid — \(reason). Query: \(sql)"
case .operationNotAllowed(let operation):
return "The \(operation.uppercased()) operation is not allowed by the current configuration."
case .multipleStatementsNotSupported:
return "Only single SQL statements are supported. Please ask one question at a time."
case .dangerousOperationBlocked(let keyword):
return "The \(keyword.uppercased()) operation is blocked for safety. This operation is never allowed."
// LLM
case .llmFailure(let reason):
return "The language model encountered an error: \(reason)"
case .llmResponseUnparseable(let response):
return "I received a response but couldn't understand it. Raw response: \(response.prefix(200))"
case .llmTimeout(let seconds):
return "The language model did not respond within \(Int(seconds)) seconds. Please try again."
// Schema
case .schemaIntrospectionFailed(let reason):
return "Failed to read the database schema: \(reason)"
case .tableNotFound(let tableName):
return "The table '\(tableName)' does not exist in this database."
case .columnNotFound(let columnName, let tableName):
return "The column '\(columnName)' does not exist on table '\(tableName)'."
case .emptySchema:
return "This database has no tables. There's nothing to query yet."
// Safety
case .confirmationRequired(let sql, let operation):
return "The \(operation.uppercased()) operation requires your confirmation before running: \(sql)"
case .tableNotAllowedForMutation(let tableName, let operation):
return "The \(operation.uppercased()) operation is not allowed on table '\(tableName)'."
case .queryRejected(let reason):
return "Query rejected: \(reason)"
// Database
case .databaseError(let reason):
return "A database error occurred: \(reason)"
case .queryTimedOut(let seconds):
return "The query timed out after \(Int(seconds)) seconds. Try a simpler query."
// Configuration
case .configurationError(let reason):
return "Configuration error: \(reason)"
}
}
}
// MARK: - Conversion from SQLParsingError
extension SQLParsingError {
/// Maps a ``SQLParsingError`` to the corresponding ``SwiftDBAIError`` case.
///
/// - Parameter rawResponse: The raw LLM response text (used for context in `.noSQLFound`).
/// - Returns: A ``SwiftDBAIError`` with the same semantic meaning.
func toSwiftDBAIError(rawResponse: String = "") -> SwiftDBAIError {
switch self {
case .noSQLFound:
if rawResponse.isEmpty {
return .noSQLGenerated
}
return .llmResponseUnparseable(response: rawResponse)
case .operationNotAllowed(let op):
return .operationNotAllowed(operation: op.rawValue)
case .confirmationRequired(let sql, let op):
return .confirmationRequired(sql: sql, operation: op.rawValue)
case .tableNotAllowed(let table, let op):
return .tableNotAllowedForMutation(tableName: table, operation: op.rawValue)
case .dangerousOperation(let keyword):
return .dangerousOperationBlocked(keyword: keyword)
case .multipleStatements:
return .multipleStatementsNotSupported
}
}
}
// MARK: - Conversion from ChatEngineError
extension ChatEngineError {
/// Maps a ``ChatEngineError`` to the corresponding ``SwiftDBAIError`` case.
func toSwiftDBAIError() -> SwiftDBAIError {
switch self {
case .sqlParsingFailed(let parsingError):
return parsingError.toSwiftDBAIError()
case .confirmationRequired(let sql, let operation):
return .confirmationRequired(sql: sql, operation: operation.rawValue)
case .schemaIntrospectionFailed(let reason):
return .schemaIntrospectionFailed(reason: reason)
case .queryTimedOut(let seconds):
return .queryTimedOut(seconds: seconds)
case .validationFailed(let reason):
return .queryRejected(reason: reason)
}
}
}

View File

@@ -0,0 +1,230 @@
// DatabaseTool.swift
// SwiftDBAI
//
// A standalone tool calling API for integrating SwiftDBAI into
// existing LLM tool calling setups (OpenAI function calling,
// Anthropic tools, Apple Foundation Models, etc.).
import Foundation
import GRDB
/// A standalone database tool for LLM tool calling integrations.
///
/// Provides everything needed to register a "query database" tool with any LLM:
/// - Tool name, description, and parameter schema for registration
/// - Schema context for the LLM's system prompt
/// - SQL execution with allowlist validation
///
/// ## Usage
///
/// ```swift
/// // 1. Create the tool
/// let tool = try await DatabaseTool(databasePath: "path/to/db.sqlite")
///
/// // 2. Get the tool definition for your LLM
/// let definition = tool.openAIFunctionDefinition
/// // Register with your OpenAI/Anthropic/etc. client...
///
/// // 3. Include schema in system prompt
/// let systemPrompt = "You are a helpful assistant.\n\n" + tool.systemPromptSnippet
///
/// // 4. When the LLM calls the tool, execute it
/// let result = try tool.execute(sql: llmGeneratedSQL)
/// // Return result.jsonString back to the LLM as the tool response
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct DatabaseTool: Sendable {
private let database: any DatabaseWriter
private let allowlist: OperationAllowlist
private let schema: DatabaseSchema
// MARK: - Initialization
/// Creates a database tool from a file path.
///
/// - Parameters:
/// - databasePath: Path to the SQLite database file.
/// - allowlist: The set of permitted SQL operations. Defaults to read-only.
public init(databasePath: String, allowlist: OperationAllowlist = .readOnly) async throws {
let dbQueue = try DatabaseQueue(path: databasePath)
self.database = dbQueue
self.allowlist = allowlist
self.schema = try await SchemaIntrospector.introspect(database: dbQueue)
}
/// Creates a database tool from an existing GRDB database connection.
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabaseQueue or DatabasePool).
/// - allowlist: The set of permitted SQL operations. Defaults to read-only.
public init(database: any DatabaseWriter, allowlist: OperationAllowlist = .readOnly) async throws {
self.database = database
self.allowlist = allowlist
self.schema = try await SchemaIntrospector.introspect(database: database)
}
// MARK: - Tool Definition
/// The tool name for LLM function calling registration.
public var name: String { "execute_sql" }
/// The tool description for LLM function calling registration.
public var description: String {
"Execute a SQL query against a SQLite database. \(allowlist.describeForLLM())"
}
/// JSON Schema for the tool's parameters, compatible with OpenAI/Anthropic tool definitions.
public var parametersSchema: [String: Any] {
[
"type": "object",
"properties": [
"sql": [
"type": "string",
"description": "The SQL query to execute against the database.",
] as [String: Any],
] as [String: Any],
"required": ["sql"],
]
}
/// The database schema as a string, for including in the LLM's system prompt.
public var schemaContext: String {
schema.schemaDescription
}
/// A system prompt snippet that describes the database and how to use the tool.
///
/// Include this in your LLM's system prompt so it knows the database structure
/// and how to use the `execute_sql` tool.
public var systemPromptSnippet: String {
"""
You have access to a SQLite database with the following schema:
\(schema.schemaDescription)
\(allowlist.describeForLLM())
Use the `execute_sql` tool to query this database. Pass a single SQL statement as the `sql` parameter.
"""
}
// MARK: - Execution
/// Execute a SQL query, returning a structured ``ToolResult``.
///
/// Validates the SQL against the configured allowlist before execution.
/// This is the method to call when the LLM invokes the tool.
///
/// - Parameter sql: The SQL query to execute.
/// - Returns: A ``ToolResult`` with the query results.
/// - Throws: ``SQLParsingError`` if the SQL is not allowed, or a database error.
public func execute(sql: String) throws -> ToolResult {
let queryResult = try executeRaw(sql: sql)
return ToolResult(queryResult: queryResult)
}
/// Execute a SQL query and return the raw ``QueryResult``.
///
/// For advanced use cases where you need the full `QueryResult.Value` types
/// rather than the string-based ``ToolResult``.
///
/// - Parameter sql: The SQL query to execute.
/// - Returns: A ``QueryResult`` with typed values.
/// - Throws: ``SQLParsingError`` if the SQL is not allowed, or a database error.
public func executeRaw(sql: String) throws -> QueryResult {
// Validate against the allowlist
let parser = SQLQueryParser(allowlist: allowlist)
let parsed = try parser.validate(sql)
let startTime = CFAbsoluteTimeGetCurrent()
let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
let isSelect = trimmed.hasPrefix("SELECT") || trimmed.hasPrefix("WITH")
if isSelect {
let result = try database.read { db -> (columns: [String], rows: [[String: QueryResult.Value]]) in
let statement = try db.makeStatement(sql: parsed.sql)
let columnNames = statement.columnNames
var rows: [[String: QueryResult.Value]] = []
let cursor = try Row.fetchCursor(statement)
while let row = try cursor.next() {
var dict: [String: QueryResult.Value] = [:]
for col in columnNames {
dict[col] = Self.extractValue(row: row, column: col)
}
rows.append(dict)
}
return (columns: columnNames, rows: rows)
}
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
return QueryResult(
columns: result.columns,
rows: result.rows,
sql: parsed.sql,
executionTime: elapsed
)
} else {
let affected = try database.write { db -> Int in
try db.execute(sql: parsed.sql)
return db.changesCount
}
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
return QueryResult(
columns: [],
rows: [],
sql: parsed.sql,
executionTime: elapsed,
rowsAffected: affected
)
}
}
// MARK: - Private Helpers
private static func extractValue(row: Row, column: String) -> QueryResult.Value {
let dbValue: DatabaseValue = row[column]
switch dbValue.storage {
case .null:
return .null
case .int64(let i):
return .integer(i)
case .double(let d):
return .real(d)
case .string(let s):
return .text(s)
case .blob(let data):
return .blob(data)
}
}
}
// MARK: - OpenAI / Anthropic Compatibility
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
extension DatabaseTool {
/// Returns an OpenAI-compatible function definition dictionary.
///
/// This can be serialized to JSON and passed directly to the OpenAI API's
/// `tools` parameter, or adapted for Anthropic's tool definitions.
///
/// ```swift
/// let tool = try await DatabaseTool(databasePath: "db.sqlite")
/// let definition = tool.openAIFunctionDefinition
/// // Serialize to JSON for the API call
/// let data = try JSONSerialization.data(withJSONObject: definition)
/// ```
public var openAIFunctionDefinition: [String: Any] {
[
"type": "function",
"function": [
"name": name,
"description": description,
"parameters": parametersSchema,
] as [String: Any],
]
}
}

View File

@@ -0,0 +1,108 @@
// ToolResult.swift
// SwiftDBAI
//
// Structured result for tool calling responses.
import Foundation
/// A structured result from executing a SQL query via ``DatabaseTool``,
/// designed for returning to an LLM as a tool call response.
///
/// Provides multiple output formats:
/// - ``jsonString`` for returning to the LLM as a tool response
/// - ``markdownTable`` for display in UI
/// - ``textSummary`` for plain text output
public struct ToolResult: Sendable, Codable, Equatable {
/// Column names in the order they appear in the result set.
public let columns: [String]
/// Row data as an array of dictionaries mapping column name to string value.
/// All values are converted to strings for reliable JSON serialization.
public let rows: [[String: String]]
/// Total number of rows returned.
public let rowCount: Int
/// Time taken to execute the query, in seconds.
public let executionTime: TimeInterval
/// The SQL statement that was executed.
public let sql: String
public init(
columns: [String],
rows: [[String: String]],
rowCount: Int,
executionTime: TimeInterval,
sql: String
) {
self.columns = columns
self.rows = rows
self.rowCount = rowCount
self.executionTime = executionTime
self.sql = sql
}
/// Creates a ``ToolResult`` from a ``QueryResult``.
init(queryResult: QueryResult) {
self.columns = queryResult.columns
self.rows = queryResult.rows.map { row in
var stringRow: [String: String] = [:]
for (key, value) in row {
stringRow[key] = value.description
}
return stringRow
}
self.rowCount = queryResult.rowCount
self.executionTime = queryResult.executionTime
self.sql = queryResult.sql
}
// MARK: - Output Formats
/// Formats the result as a JSON string for returning to the LLM as a tool response.
public var jsonString: String {
let payload: [String: Any] = [
"columns": columns,
"rows": rows,
"row_count": rowCount,
"execution_time_seconds": executionTime,
"sql": sql,
]
guard let data = try? JSONSerialization.data(withJSONObject: payload, options: [.sortedKeys]),
let str = String(data: data, encoding: .utf8) else {
return "{\"error\": \"Failed to serialize result\"}"
}
return str
}
/// Formats the result as a markdown table for display.
public var markdownTable: String {
guard !rows.isEmpty else {
return "_No results._"
}
var lines: [String] = []
// Header
lines.append("| " + columns.joined(separator: " | ") + " |")
lines.append("| " + columns.map { _ in "---" }.joined(separator: " | ") + " |")
// Rows
for row in rows {
let vals = columns.map { row[$0] ?? "NULL" }
lines.append("| " + vals.joined(separator: " | ") + " |")
}
return lines.joined(separator: "\n")
}
/// Formats a plain text summary of the result.
public var textSummary: String {
if rows.isEmpty {
return "Query returned no results. (\(String(format: "%.3f", executionTime))s)"
}
return "Query returned \(rowCount) row\(rowCount == 1 ? "" : "s") with columns: \(columns.joined(separator: ", ")). (\(String(format: "%.3f", executionTime))s)"
}
}

View File

@@ -0,0 +1,182 @@
// BarChartView.swift
// SwiftDBAI
//
// A SwiftUI bar chart that renders DataTable values using Swift Charts.
// Best for categorical comparisons (e.g., sales by region, counts by status).
import SwiftUI
import Charts
/// A bar chart view that renders a `DataTable` column pair using Swift Charts.
///
/// Displays vertical bars with category labels on the x-axis and numeric
/// values on the y-axis. Automatically colors bars using the accent gradient
/// and supports scrolling when many categories are present.
///
/// Usage:
/// ```swift
/// BarChartView(
/// dataTable: table,
/// categoryColumn: "department",
/// valueColumn: "total_sales"
/// )
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct BarChartView: View {
/// The data to chart.
public let dataTable: DataTable
/// Column name for category labels (x-axis).
public let categoryColumn: String
/// Column name for numeric values (y-axis).
public let valueColumn: String
/// Optional chart title.
public var title: String?
/// Maximum number of bars to display before truncating.
public var maxBars: Int
public init(
dataTable: DataTable,
categoryColumn: String,
valueColumn: String,
title: String? = nil,
maxBars: Int = 30
) {
self.dataTable = dataTable
self.categoryColumn = categoryColumn
self.valueColumn = valueColumn
self.title = title
self.maxBars = maxBars
}
public var body: some View {
VStack(alignment: .leading, spacing: 8) {
if let title {
Text(title)
.font(.caption.weight(.semibold))
.foregroundStyle(.secondary)
}
if chartData.isEmpty {
emptyChartView
} else {
chartContent
}
}
}
// MARK: - Chart Content
@ViewBuilder
private var chartContent: some View {
Chart(chartData, id: \.label) { item in
BarMark(
x: .value(categoryColumn, item.label),
y: .value(valueColumn, item.value)
)
.foregroundStyle(
.linearGradient(
colors: [.accentColor, .accentColor.opacity(0.7)],
startPoint: .bottom,
endPoint: .top
)
)
.cornerRadius(4)
}
.chartXAxis {
AxisMarks(values: .automatic) { _ in
AxisValueLabel()
.font(.caption2)
}
}
.chartYAxis {
AxisMarks(position: .leading) { _ in
AxisGridLine(stroke: StrokeStyle(lineWidth: 0.5, dash: [4, 4]))
.foregroundStyle(.secondary.opacity(0.3))
AxisValueLabel()
.font(.caption2)
}
}
.frame(minHeight: 200)
if isTruncated {
truncationNotice
}
}
// MARK: - Empty State
@ViewBuilder
private var emptyChartView: some View {
VStack(spacing: 8) {
Image(systemName: "chart.bar")
.font(.title2)
.foregroundStyle(.secondary)
Text("No chartable data")
.font(.caption)
.foregroundStyle(.secondary)
}
.frame(maxWidth: .infinity, minHeight: 100)
}
// MARK: - Truncation Notice
@ViewBuilder
private var truncationNotice: some View {
Text("Showing \(maxBars) of \(dataTable.rowCount) categories")
.font(.caption2)
.foregroundStyle(.secondary)
}
// MARK: - Data Extraction
private var isTruncated: Bool {
dataTable.rowCount > maxBars
}
private var chartData: [ChartDataPoint] {
let labels = dataTable.stringValues(forColumn: categoryColumn)
let values = dataTable.numericValues(forColumn: valueColumn)
let count = min(labels.count, values.count, maxBars)
guard count > 0 else { return [] }
return (0..<count).map { i in
ChartDataPoint(label: labels[i], value: values[i])
}
}
}
// MARK: - Preview
#if DEBUG
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Bar Chart") {
let columns: [DataTable.Column] = [
.init(name: "department", index: 0, inferredType: .text),
.init(name: "revenue", index: 1, inferredType: .real),
]
let departments = ["Engineering", "Sales", "Marketing", "Support", "Design"]
let rows: [DataTable.Row] = departments.enumerated().map { i, dept in
DataTable.Row(
id: i,
values: [.text(dept), .real(Double.random(in: 50_000...200_000))],
columnNames: ["department", "revenue"]
)
}
let table = DataTable(columns: columns, rows: rows)
BarChartView(
dataTable: table,
categoryColumn: "department",
valueColumn: "revenue",
title: "Revenue by Department"
)
.padding()
.frame(height: 300)
}
#endif

View File

@@ -0,0 +1,21 @@
// ChartDataPoint.swift
// SwiftDBAI
//
// Shared data model used by all chart views.
import Foundation
/// A single data point for chart rendering.
///
/// Pairs a string label (category) with a numeric value.
/// Used as the common data format across BarChartView,
/// LineChartView, and PieChartView.
struct ChartDataPoint: Sendable, Identifiable {
var id: String { label }
/// The category label (x-axis or slice label).
let label: String
/// The numeric value (y-axis or slice size).
let value: Double
}

View File

@@ -0,0 +1,135 @@
// ChartResultView.swift
// SwiftDBAI
//
// Auto-selecting chart view that uses ChartDataDetector to pick the
// best chart type for a given DataTable.
import SwiftUI
import Charts
/// A chart view that automatically selects the best chart type for a `DataTable`.
///
/// Uses `ChartDataDetector` to analyze the data shape and renders the
/// appropriate chart (bar, line, or pie). If the data isn't suitable for
/// charting, the view renders nothing.
///
/// Usage:
/// ```swift
/// ChartResultView(dataTable: myTable)
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct ChartResultView: View {
/// The data table to chart.
public let dataTable: DataTable
/// Optional override: force a specific chart type.
public var chartType: ChartDataDetector.ChartType?
/// The detector used to analyze the data.
private let detector: ChartDataDetector
public init(
dataTable: DataTable,
chartType: ChartDataDetector.ChartType? = nil,
detector: ChartDataDetector = ChartDataDetector()
) {
self.dataTable = dataTable
self.chartType = chartType
self.detector = detector
}
public var body: some View {
if let recommendation = resolvedRecommendation {
VStack(alignment: .leading, spacing: 4) {
chartView(for: recommendation)
Text(recommendation.reason)
.font(.caption2)
.foregroundStyle(.tertiary)
}
}
}
// MARK: - Chart Selection
@ViewBuilder
private func chartView(
for recommendation: ChartDataDetector.ChartRecommendation
) -> some View {
switch recommendation.chartType {
case .bar:
BarChartView(
dataTable: dataTable,
categoryColumn: recommendation.categoryColumn,
valueColumn: recommendation.valueColumn
)
case .line:
LineChartView(
dataTable: dataTable,
categoryColumn: recommendation.categoryColumn,
valueColumn: recommendation.valueColumn
)
case .pie:
PieChartView(
dataTable: dataTable,
categoryColumn: recommendation.categoryColumn,
valueColumn: recommendation.valueColumn
)
}
}
// MARK: - Resolution
/// Resolves the chart recommendation, using the override type if provided.
private var resolvedRecommendation: ChartDataDetector.ChartRecommendation? {
if let override = chartType {
// Use forced chart type still need column pair from detector
let all = detector.allRecommendations(for: dataTable)
// Try to find recommendation for the forced type
if let match = all.first(where: { $0.chartType == override }) {
return match
}
// Fallback: use first recommendation and override its type
if let first = all.first {
return ChartDataDetector.ChartRecommendation(
chartType: override,
categoryColumn: first.categoryColumn,
valueColumn: first.valueColumn,
confidence: first.confidence * 0.8,
reason: first.reason
)
}
return nil
}
// Auto-detect best chart type
return detector.detect(dataTable)
}
}
// MARK: - Preview
#if DEBUG
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Auto Chart — Bar") {
let columns: [DataTable.Column] = [
.init(name: "city", index: 0, inferredType: .text),
.init(name: "population", index: 1, inferredType: .integer),
]
let cities = ["NYC", "LA", "Chicago", "Houston", "Phoenix"]
let pops: [Int64] = [8_336_817, 3_979_576, 2_693_976, 2_320_268, 1_680_992]
let rows: [DataTable.Row] = cities.enumerated().map { i, city in
DataTable.Row(
id: i,
values: [.text(city), .integer(pops[i])],
columnNames: ["city", "population"]
)
}
let table = DataTable(columns: columns, rows: rows)
ChartResultView(dataTable: table)
.padding()
.frame(height: 300)
}
#endif

View File

@@ -0,0 +1,206 @@
// LineChartView.swift
// SwiftDBAI
//
// A SwiftUI line chart that renders DataTable values using Swift Charts.
// Best for time series or sequential data (e.g., revenue over months).
import SwiftUI
import Charts
/// A line chart view that renders a `DataTable` column pair using Swift Charts.
///
/// Displays a connected line with optional area fill, point markers,
/// and smooth interpolation. Best suited for time series or sequential data.
///
/// Usage:
/// ```swift
/// LineChartView(
/// dataTable: table,
/// categoryColumn: "month",
/// valueColumn: "revenue"
/// )
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct LineChartView: View {
/// The data to chart.
public let dataTable: DataTable
/// Column name for category/time labels (x-axis).
public let categoryColumn: String
/// Column name for numeric values (y-axis).
public let valueColumn: String
/// Optional chart title.
public var title: String?
/// Whether to show an area fill below the line.
public var showAreaFill: Bool
/// Whether to show point markers at each data point.
public var showPoints: Bool
/// Maximum data points to display.
public var maxPoints: Int
public init(
dataTable: DataTable,
categoryColumn: String,
valueColumn: String,
title: String? = nil,
showAreaFill: Bool = true,
showPoints: Bool = true,
maxPoints: Int = 100
) {
self.dataTable = dataTable
self.categoryColumn = categoryColumn
self.valueColumn = valueColumn
self.title = title
self.showAreaFill = showAreaFill
self.showPoints = showPoints
self.maxPoints = maxPoints
}
public var body: some View {
VStack(alignment: .leading, spacing: 8) {
if let title {
Text(title)
.font(.caption.weight(.semibold))
.foregroundStyle(.secondary)
}
if chartData.isEmpty {
emptyChartView
} else {
chartContent
}
}
}
// MARK: - Chart Content
@ViewBuilder
private var chartContent: some View {
Chart(chartData, id: \.label) { item in
LineMark(
x: .value(categoryColumn, item.label),
y: .value(valueColumn, item.value)
)
.foregroundStyle(Color.accentColor)
.lineStyle(StrokeStyle(lineWidth: 2))
.interpolationMethod(.catmullRom)
if showAreaFill {
AreaMark(
x: .value(categoryColumn, item.label),
y: .value(valueColumn, item.value)
)
.foregroundStyle(
.linearGradient(
colors: [
Color.accentColor.opacity(0.2),
Color.accentColor.opacity(0.02),
],
startPoint: .top,
endPoint: .bottom
)
)
.interpolationMethod(.catmullRom)
}
if showPoints {
PointMark(
x: .value(categoryColumn, item.label),
y: .value(valueColumn, item.value)
)
.foregroundStyle(Color.accentColor)
.symbolSize(30)
}
}
.chartXAxis {
AxisMarks(values: .automatic) { _ in
AxisValueLabel()
.font(.caption2)
}
}
.chartYAxis {
AxisMarks(position: .leading) { _ in
AxisGridLine(stroke: StrokeStyle(lineWidth: 0.5, dash: [4, 4]))
.foregroundStyle(.secondary.opacity(0.3))
AxisValueLabel()
.font(.caption2)
}
}
.frame(minHeight: 200)
if isTruncated {
Text("Showing \(maxPoints) of \(dataTable.rowCount) data points")
.font(.caption2)
.foregroundStyle(.secondary)
}
}
// MARK: - Empty State
@ViewBuilder
private var emptyChartView: some View {
VStack(spacing: 8) {
Image(systemName: "chart.xyaxis.line")
.font(.title2)
.foregroundStyle(.secondary)
Text("No chartable data")
.font(.caption)
.foregroundStyle(.secondary)
}
.frame(maxWidth: .infinity, minHeight: 100)
}
// MARK: - Data Extraction
private var isTruncated: Bool {
dataTable.rowCount > maxPoints
}
private var chartData: [ChartDataPoint] {
let labels = dataTable.stringValues(forColumn: categoryColumn)
let values = dataTable.numericValues(forColumn: valueColumn)
let count = min(labels.count, values.count, maxPoints)
guard count > 0 else { return [] }
return (0..<count).map { i in
ChartDataPoint(label: labels[i], value: values[i])
}
}
}
// MARK: - Preview
#if DEBUG
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Line Chart") {
let columns: [DataTable.Column] = [
.init(name: "month", index: 0, inferredType: .text),
.init(name: "revenue", index: 1, inferredType: .real),
]
let months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun"]
let rows: [DataTable.Row] = months.enumerated().map { i, month in
DataTable.Row(
id: i,
values: [.text(month), .real(Double(i + 1) * 15_000 + Double.random(in: -3000...3000))],
columnNames: ["month", "revenue"]
)
}
let table = DataTable(columns: columns, rows: rows)
LineChartView(
dataTable: table,
categoryColumn: "month",
valueColumn: "revenue",
title: "Monthly Revenue"
)
.padding()
.frame(height: 300)
}
#endif

View File

@@ -0,0 +1,234 @@
// PieChartView.swift
// SwiftDBAI
//
// A SwiftUI pie/donut chart that renders DataTable values using Swift Charts.
// Best for proportional breakdowns with few categories (e.g., market share).
import SwiftUI
import Charts
/// A pie chart view that renders a `DataTable` column pair using Swift Charts.
///
/// Displays proportional slices with category labels. Each slice is
/// automatically colored from a curated palette and sized relative to
/// its proportion of the total. Best suited for data with few categories
/// ( 8) where all values are positive.
///
/// Usage:
/// ```swift
/// PieChartView(
/// dataTable: table,
/// categoryColumn: "status",
/// valueColumn: "count"
/// )
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct PieChartView: View {
/// The data to chart.
public let dataTable: DataTable
/// Column name for category labels (slice labels).
public let categoryColumn: String
/// Column name for numeric values (slice sizes).
public let valueColumn: String
/// Optional chart title.
public var title: String?
/// Inner radius ratio for donut style (0 = full pie, >0 = donut).
public var innerRadiusRatio: CGFloat
/// Maximum number of slices before grouping remaining into "Other".
public var maxSlices: Int
public init(
dataTable: DataTable,
categoryColumn: String,
valueColumn: String,
title: String? = nil,
innerRadiusRatio: CGFloat = 0.4,
maxSlices: Int = 8
) {
self.dataTable = dataTable
self.categoryColumn = categoryColumn
self.valueColumn = valueColumn
self.title = title
self.innerRadiusRatio = innerRadiusRatio
self.maxSlices = maxSlices
}
public var body: some View {
VStack(alignment: .leading, spacing: 8) {
if let title {
Text(title)
.font(.caption.weight(.semibold))
.foregroundStyle(.secondary)
}
if chartData.isEmpty {
emptyChartView
} else {
HStack(alignment: .center, spacing: 16) {
chartContent
legendView
}
}
}
}
// MARK: - Chart Content
@ViewBuilder
private var chartContent: some View {
Chart(chartData, id: \.label) { item in
SectorMark(
angle: .value(valueColumn, item.value),
innerRadius: .ratio(innerRadiusRatio),
angularInset: 1.5
)
.foregroundStyle(by: .value(categoryColumn, item.label))
.cornerRadius(3)
}
.chartForegroundStyleScale(
domain: chartData.map(\.label),
range: sliceColors
)
.chartLegend(.hidden)
.frame(minWidth: 150, minHeight: 150)
.aspectRatio(1, contentMode: .fit)
}
// MARK: - Legend
@ViewBuilder
private var legendView: some View {
VStack(alignment: .leading, spacing: 6) {
ForEach(Array(chartData.enumerated()), id: \.element.label) { index, item in
HStack(spacing: 8) {
Circle()
.fill(sliceColors[index % sliceColors.count])
.frame(width: 8, height: 8)
Text(item.label)
.font(.caption)
.foregroundStyle(.primary)
.lineLimit(1)
Spacer()
Text(percentageText(for: item.value))
.font(.caption)
.foregroundStyle(.secondary)
.monospacedDigit()
}
}
}
}
// MARK: - Empty State
@ViewBuilder
private var emptyChartView: some View {
VStack(spacing: 8) {
Image(systemName: "chart.pie")
.font(.title2)
.foregroundStyle(.secondary)
Text("No chartable data")
.font(.caption)
.foregroundStyle(.secondary)
}
.frame(maxWidth: .infinity, minHeight: 100)
}
// MARK: - Colors
/// Curated color palette for pie slices.
private var sliceColors: [Color] {
[
.blue,
.green,
.orange,
.purple,
.pink,
.cyan,
.yellow,
.indigo,
.mint,
.teal,
]
}
// MARK: - Helpers
private var total: Double {
chartData.reduce(0) { $0 + $1.value }
}
private func percentageText(for value: Double) -> String {
guard total > 0 else { return "0%" }
let pct = (value / total) * 100
if pct >= 10 {
return String(format: "%.0f%%", pct)
}
return String(format: "%.1f%%", pct)
}
// MARK: - Data Extraction
private var chartData: [ChartDataPoint] {
let labels = dataTable.stringValues(forColumn: categoryColumn)
let values = dataTable.numericValues(forColumn: valueColumn)
let count = min(labels.count, values.count)
guard count > 0 else { return [] }
// Build all points, sorted by value descending
var points = (0..<count).map { i in
ChartDataPoint(label: labels[i], value: values[i])
}
.filter { $0.value > 0 }
.sorted { $0.value > $1.value }
// Group excess slices into "Other"
if points.count > maxSlices {
let kept = Array(points.prefix(maxSlices - 1))
let otherValue = points.dropFirst(maxSlices - 1).reduce(0) { $0 + $1.value }
points = kept + [ChartDataPoint(label: "Other", value: otherValue)]
}
return points
}
}
// MARK: - Preview
#if DEBUG
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Pie Chart") {
let columns: [DataTable.Column] = [
.init(name: "status", index: 0, inferredType: .text),
.init(name: "count", index: 1, inferredType: .integer),
]
let statuses = ["Active", "Inactive", "Pending", "Archived"]
let counts: [Int64] = [45, 20, 15, 10]
let rows: [DataTable.Row] = statuses.enumerated().map { i, status in
DataTable.Row(
id: i,
values: [.text(status), .integer(counts[i])],
columnNames: ["status", "count"]
)
}
let table = DataTable(columns: columns, rows: rows)
PieChartView(
dataTable: table,
categoryColumn: "status",
valueColumn: "count",
title: "Users by Status"
)
.padding()
.frame(height: 250)
}
#endif

View File

@@ -0,0 +1,225 @@
// ChatView.swift
// SwiftDBAI
//
// Drop-in SwiftUI view for chatting with a SQLite database.
// Renders messages with automatic data table display for query results.
import SwiftUI
/// A drop-in SwiftUI chat interface for querying SQLite databases
/// with natural language.
///
/// `ChatView` renders the full conversation including:
/// - User messages (right-aligned, accent-colored)
/// - Assistant responses with text summaries
/// - **Automatic data tables** via `ScrollableDataTableView` when query results
/// contain tabular data (rows + columns)
/// - SQL query disclosure for transparency
/// - Error messages with red styling
/// - A loading indicator while the engine is processing
///
/// Usage:
/// ```swift
/// let engine = ChatEngine(database: myPool, model: myModel)
/// let viewModel = ChatViewModel(engine: engine)
///
/// ChatView(viewModel: viewModel)
/// ```
///
/// Or use the convenience initializer:
/// ```swift
/// ChatView(engine: myEngine)
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct ChatView: View {
@Bindable private var viewModel: ChatViewModel
@State private var inputText: String = ""
@FocusState private var isInputFocused: Bool
@Environment(\.chatViewConfiguration) private var config
/// Creates a ChatView with an existing view model.
///
/// - Parameter viewModel: The `ChatViewModel` driving this view.
public init(viewModel: ChatViewModel) {
self.viewModel = viewModel
}
/// Creates a ChatView with a `ChatEngine`, automatically creating
/// a `ChatViewModel`.
///
/// - Parameter engine: The `ChatEngine` to power the chat.
public init(engine: ChatEngine) {
self.viewModel = ChatViewModel(engine: engine)
}
public var body: some View {
VStack(spacing: 0) {
messageList
Divider()
inputBar
}
.background(config.backgroundColor)
.applyColorSchemeOverride(config.colorSchemeOverride)
}
// MARK: - Message List
@ViewBuilder
private var messageList: some View {
ScrollViewReader { proxy in
ScrollView {
LazyVStack(spacing: 12) {
if viewModel.messages.isEmpty {
emptyState
}
ForEach(viewModel.messages) { message in
messageBubble(for: message)
.id(message.id)
}
if viewModel.isLoading {
loadingIndicator
}
}
.padding(.horizontal, 16)
.padding(.vertical, 12)
}
.onChange(of: viewModel.messages.count) { _, _ in
if let lastMessage = viewModel.messages.last {
withAnimation(.easeOut(duration: 0.3)) {
proxy.scrollTo(lastMessage.id, anchor: .bottom)
}
}
}
}
}
// MARK: - Empty State
@ViewBuilder
private var emptyState: some View {
VStack(spacing: 12) {
Image(systemName: config.emptyStateIcon)
.font(.system(size: 40))
.foregroundStyle(.tertiary)
Text(config.emptyStateTitle)
.font(.headline)
.foregroundStyle(.secondary)
Text(config.emptyStateSubtitle)
.font(.subheadline)
.foregroundStyle(.tertiary)
.multilineTextAlignment(.center)
}
.frame(maxWidth: .infinity)
.padding(.vertical, 60)
}
// MARK: - Loading Indicator
@ViewBuilder
private var loadingIndicator: some View {
HStack(alignment: .top) {
HStack(spacing: 8) {
ProgressView()
.controlSize(.small)
Text("Querying…")
.font(.callout)
.foregroundStyle(.secondary)
}
.padding(.horizontal, 14)
.padding(.vertical, 10)
.background(
config.assistantBubbleColor,
in: RoundedRectangle(cornerRadius: config.bubbleCornerRadius, style: .continuous)
)
Spacer(minLength: 48)
}
.id("loading-indicator")
.transition(.opacity.combined(with: .move(edge: .bottom)))
}
// MARK: - Input Bar
@ViewBuilder
private var inputBar: some View {
HStack(spacing: 8) {
TextField(config.inputPlaceholder, text: $inputText, axis: .vertical)
.textFieldStyle(.plain)
.font(config.inputFont)
.lineLimit(1...5)
.focused($isInputFocused)
.onSubmit { sendMessage() }
.submitLabel(.send)
Button(action: sendMessage) {
Image(systemName: "arrow.up.circle.fill")
.font(.title2)
.foregroundStyle(canSend ? config.accentColor : Color.secondary)
}
.disabled(!canSend)
.keyboardShortcut(.return, modifiers: .command)
}
.padding(.horizontal, 16)
.padding(.vertical, 10)
.background(config.inputBarBackgroundColor)
}
// MARK: - Message Bubble
@ViewBuilder
private func messageBubble(for message: ChatMessage) -> some View {
if message.role == .error {
MessageBubbleView(
message: message,
onRetry: makeRetryAction(for: message)
)
} else {
MessageBubbleView(message: message)
}
}
private func makeRetryAction(for errorMessage: ChatMessage) -> @Sendable () async -> Void {
let vm = viewModel
let messageId = errorMessage.id
return { @MainActor [vm] in
let allMessages = await MainActor.run { vm.messages }
if let lastUserMessage = allMessages
.prefix(while: { $0.id != messageId })
.last(where: { $0.role == .user }) {
await vm.send(lastUserMessage.content)
}
}
}
// MARK: - Helpers
private var canSend: Bool {
!inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty && !viewModel.isLoading
}
private func sendMessage() {
guard canSend else { return }
let text = inputText
inputText = ""
Task {
await viewModel.send(text)
}
}
}
// MARK: - Color Scheme Override
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
private extension View {
@ViewBuilder
func applyColorSchemeOverride(_ scheme: ColorScheme?) -> some View {
if let scheme {
self.environment(\.colorScheme, scheme)
} else {
self
}
}
}

View File

@@ -0,0 +1,137 @@
// ChatViewModel.swift
// SwiftDBAI
//
// Observable view model that bridges ChatEngine with the SwiftUI ChatView.
import Foundation
import Observation
/// The readiness state of the schema introspection.
public enum SchemaReadiness: Sendable, Equatable {
/// Schema has not been loaded yet.
case idle
/// Schema introspection is in progress.
case loading
/// Schema is ready with the given number of tables.
case ready(tableCount: Int)
/// Schema introspection failed.
case failed(String)
public var isReady: Bool {
if case .ready = self { return true }
return false
}
}
/// Observable view model that drives the `ChatView`.
///
/// Wraps `ChatEngine` to provide reactive state updates for the SwiftUI layer.
/// Manages the message list, loading state, error presentation, and schema
/// readiness. Call ``prepare()`` at view-appear time to eagerly introspect the
/// database schema.
///
/// Usage:
/// ```swift
/// let viewModel = ChatViewModel(engine: myChatEngine)
/// ChatView(viewModel: viewModel)
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
@Observable
@MainActor
public final class ChatViewModel {
// MARK: - Public State
/// All messages in the conversation, in chronological order.
public private(set) var messages: [ChatMessage] = []
/// Whether the engine is currently processing a request.
public private(set) var isLoading: Bool = false
/// The most recent error message, if any. Cleared on next send.
public private(set) var errorMessage: String?
/// Current schema readiness state.
public private(set) var schemaReadiness: SchemaReadiness = .idle
// MARK: - Dependencies
private let engine: ChatEngine
// MARK: - Initialization
/// Creates a new ChatViewModel.
///
/// - Parameter engine: The `ChatEngine` to use for processing messages.
public init(engine: ChatEngine) {
self.engine = engine
}
// MARK: - Schema Preparation
/// Eagerly introspects the database schema so it's ready before the first query.
///
/// This should be called from a `.task` modifier on the view. It transitions
/// `schemaReadiness` through `.loading` `.ready` (or `.failed`).
/// If the schema is already cached, this completes immediately.
public func prepare() async {
// Don't re-prepare if already ready
if schemaReadiness.isReady { return }
schemaReadiness = .loading
do {
let schema = try await engine.prepareSchema()
schemaReadiness = .ready(tableCount: schema.tableNames.count)
} catch {
schemaReadiness = .failed(error.localizedDescription)
}
}
// MARK: - Public API
/// Sends a user message and appends the response to the conversation.
///
/// - Parameter text: The natural language message from the user.
public func send(_ text: String) async {
let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines)
guard !trimmed.isEmpty else { return }
errorMessage = nil
// Add user message immediately
let userMessage = ChatMessage(role: .user, content: trimmed)
messages.append(userMessage)
isLoading = true
defer { isLoading = false }
do {
let response = try await engine.send(trimmed)
let assistantMessage = ChatMessage(
role: .assistant,
content: response.summary,
queryResult: response.queryResult,
sql: response.sql
)
messages.append(assistantMessage)
} catch {
let typedError = (error as? SwiftDBAIError)
let errorMsg = ChatMessage(
role: .error,
content: error.localizedDescription,
error: typedError
)
messages.append(errorMsg)
errorMessage = error.localizedDescription
}
}
/// Clears the conversation and resets the engine state.
public func reset() {
messages.removeAll()
errorMessage = nil
engine.reset()
}
}

View File

@@ -0,0 +1,220 @@
// DataChatView.swift
// SwiftDBAI
//
// Zero-config SwiftUI view: provide a database path and a model, get a chat UI.
import AnyLanguageModel
import GRDB
import SwiftUI
/// A convenience SwiftUI view that wraps the full chat-with-database stack.
///
/// `DataChatView` is the simplest entry point into SwiftDBAI. It requires only
/// a database file path and a language model no schema files, no annotations,
/// no manual setup. The view creates a GRDB connection, a `ChatEngine`,
/// a `ChatViewModel`, and renders a fully functional `ChatView`.
///
/// Usage with just a path and model:
/// ```swift
/// DataChatView(
/// databasePath: "/path/to/mydata.sqlite",
/// model: OllamaLanguageModel(model: "llama3")
/// )
/// ```
///
/// Usage with additional configuration:
/// ```swift
/// DataChatView(
/// databasePath: documentsURL.appendingPathComponent("app.db").path,
/// model: OpenAILanguageModel(apiKey: key),
/// allowlist: .standard,
/// additionalContext: "This database stores a recipe app's data."
/// )
/// ```
///
/// If you already have a GRDB `DatabasePool` or `DatabaseQueue`, use
/// `ChatView` with a `ChatEngine` directly for full control.
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct DataChatView: View {
@State private var viewModel: ChatViewModel
@State private var loadError: DataChatError?
/// Creates a DataChatView from a database file path and language model.
///
/// This is the zero-config convenience initializer. It opens a GRDB
/// `DatabasePool` at the given path, creates a `ChatEngine` with
/// read-only defaults, and wires up the full chat UI.
///
/// - Parameters:
/// - databasePath: Absolute path to a SQLite database file.
/// - model: Any `AnyLanguageModel`-compatible language model instance.
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly` (SELECT only).
/// - additionalContext: Optional extra context about the database for the LLM system prompt
/// (e.g., "This database stores e-commerce orders and products.").
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
public init(
databasePath: String,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
maxSummaryRows: Int = 50
) {
do {
let pool = try DatabasePool(path: databasePath)
let engine = ChatEngine(
database: pool,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
maxSummaryRows: maxSummaryRows
)
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
self._loadError = State(initialValue: nil)
} catch {
// If the database can't be opened, create a placeholder engine
// and store the error to display in the UI.
let queue = try! DatabaseQueue()
let engine = ChatEngine(
database: queue,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
maxSummaryRows: maxSummaryRows
)
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
self._loadError = State(initialValue: DataChatError.databaseOpenFailed(
path: databasePath,
underlying: error
))
}
}
/// Creates a DataChatView from an existing GRDB database connection and language model.
///
/// Use this initializer when you already have a configured `DatabasePool` or
/// `DatabaseQueue` and want the convenience of `DataChatView` without
/// creating a `ChatEngine` yourself.
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (`DatabasePool` or `DatabaseQueue`).
/// - model: Any `AnyLanguageModel`-compatible language model instance.
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly`.
/// - additionalContext: Optional extra context about the database for the LLM.
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
public init(
database: any DatabaseWriter,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
maxSummaryRows: Int = 50
) {
let engine = ChatEngine(
database: database,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
maxSummaryRows: maxSummaryRows
)
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
self._loadError = State(initialValue: nil)
}
public var body: some View {
if let error = loadError {
errorView(error)
} else {
ChatView(viewModel: viewModel)
.task {
await viewModel.prepare()
}
.overlay {
if case .loading = viewModel.schemaReadiness {
schemaLoadingView
}
if case .failed(let reason) = viewModel.schemaReadiness {
schemaErrorView(reason)
}
}
}
}
// MARK: - Schema Loading View
@ViewBuilder
private var schemaLoadingView: some View {
VStack(spacing: 16) {
ProgressView()
.controlSize(.large)
Text("Introspecting database schema…")
.font(.subheadline)
.foregroundStyle(.secondary)
}
.frame(maxWidth: .infinity, maxHeight: .infinity)
.background(.ultraThinMaterial)
}
// MARK: - Schema Error View
@ViewBuilder
private func schemaErrorView(_ reason: String) -> some View {
VStack(spacing: 16) {
Image(systemName: "exclamationmark.triangle.fill")
.font(.system(size: 40))
.foregroundStyle(.orange)
Text("Schema Introspection Failed")
.font(.headline)
Text(reason)
.font(.subheadline)
.foregroundStyle(.secondary)
.multilineTextAlignment(.center)
.padding(.horizontal, 32)
Button("Retry") {
Task {
await viewModel.prepare()
}
}
.buttonStyle(.borderedProminent)
}
.frame(maxWidth: .infinity, maxHeight: .infinity)
.background(.ultraThinMaterial)
}
// MARK: - Database Open Error View
@ViewBuilder
private func errorView(_ error: DataChatError) -> some View {
VStack(spacing: 16) {
Image(systemName: "exclamationmark.triangle.fill")
.font(.system(size: 40))
.foregroundStyle(.red)
Text("Unable to Open Database")
.font(.headline)
Text(error.localizedDescription)
.font(.subheadline)
.foregroundStyle(.secondary)
.multilineTextAlignment(.center)
.padding(.horizontal, 32)
}
.frame(maxWidth: .infinity, maxHeight: .infinity)
}
}
// MARK: - Errors
/// Errors specific to `DataChatView` initialization.
public enum DataChatError: Error, LocalizedError, Sendable {
/// The database file could not be opened at the given path.
case databaseOpenFailed(path: String, underlying: any Error)
public var errorDescription: String? {
switch self {
case .databaseOpenFailed(let path, let underlying):
return "Could not open database at \"\(path)\": \(underlying.localizedDescription)"
}
}
}

View File

@@ -0,0 +1,362 @@
// ErrorMessageView.swift
// SwiftDBAI
//
// Reusable SwiftUI component that renders error messages with contextual
// icons, descriptions, and optional retry actions based on the error type.
import SwiftUI
/// A reusable SwiftUI component that renders a ``SwiftDBAIError`` with an
/// appropriate icon, human-readable message, and optional retry action.
///
/// The view automatically selects a visual treatment based on the error
/// category:
///
/// | Category | Icon | Color | Retry? |
/// |-------------------|-------------------------------|---------|--------|
/// | Safety / blocked | `shield.trianglebadge.excl` | Orange | No |
/// | Confirmation | `hand.raised.fill` | Yellow | Yes* |
/// | LLM failure | `brain` | Purple | Yes |
/// | Schema / DB | `cylinder.split.1x2` | Red | No |
/// | Recoverable SQL | `arrow.clockwise` | Blue | Yes |
/// | Generic | `exclamationmark.triangle` | Red | No |
///
/// *Confirmation retry triggers the confirm callback, not a standard retry.
///
/// Usage:
/// ```swift
/// ErrorMessageView(
/// error: .llmTimeout(seconds: 30),
/// onRetry: { /* resend the message */ }
/// )
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct ErrorMessageView: View {
@Environment(\.chatViewConfiguration) private var config
/// The error to display. When `nil`, the view falls back to the raw message.
private let error: SwiftDBAIError?
/// The raw error message string (used as fallback when error is nil).
private let message: String
/// Called when the user taps the retry button. `nil` hides the button.
private let onRetry: (@Sendable () async -> Void)?
/// Called when the user confirms a destructive operation.
private let onConfirm: (@Sendable () async -> Void)?
@State private var isRetrying = false
// MARK: - Initializers
/// Creates an ErrorMessageView from a typed ``SwiftDBAIError``.
///
/// - Parameters:
/// - error: The ``SwiftDBAIError`` to display.
/// - onRetry: An optional async closure invoked when the user taps retry.
/// - onConfirm: An optional async closure invoked when the user confirms
/// a destructive operation (only relevant for `.confirmationRequired`).
public init(
error: SwiftDBAIError,
onRetry: (@Sendable () async -> Void)? = nil,
onConfirm: (@Sendable () async -> Void)? = nil
) {
self.error = error
self.message = error.localizedDescription
self.onRetry = onRetry
self.onConfirm = onConfirm
}
/// Creates an ErrorMessageView from a ``ChatMessage``.
///
/// Extracts the typed error if available, otherwise falls back to the
/// message content string.
///
/// - Parameters:
/// - message: The chat message with role `.error`.
/// - onRetry: An optional async closure invoked when the user taps retry.
/// - onConfirm: An optional async closure invoked when the user confirms
/// a destructive operation.
public init(
chatMessage: ChatMessage,
onRetry: (@Sendable () async -> Void)? = nil,
onConfirm: (@Sendable () async -> Void)? = nil
) {
self.error = chatMessage.error
self.message = chatMessage.content
self.onRetry = onRetry
self.onConfirm = onConfirm
}
/// Creates an ErrorMessageView from a plain string (untyped fallback).
///
/// - Parameters:
/// - message: The error message string.
/// - onRetry: An optional async closure invoked when the user taps retry.
public init(
message: String,
onRetry: (@Sendable () async -> Void)? = nil
) {
self.error = nil
self.message = message
self.onRetry = onRetry
self.onConfirm = nil
}
// MARK: - Body
public var body: some View {
VStack(alignment: .leading, spacing: 10) {
// Icon + message row
HStack(alignment: .firstTextBaseline, spacing: 8) {
Image(systemName: iconName)
.foregroundStyle(iconColor)
.font(.callout)
.accessibilityHidden(true)
VStack(alignment: .leading, spacing: 4) {
if let title = errorTitle {
Text(title)
.font(.callout.weight(.semibold))
.foregroundStyle(iconColor)
}
Text(message)
.font(.body)
.foregroundStyle(.primary)
.textSelection(.enabled)
.fixedSize(horizontal: false, vertical: true)
if let hint = recoveryHint {
Text(hint)
.font(.caption)
.foregroundStyle(.secondary)
.fixedSize(horizontal: false, vertical: true)
}
}
}
// Action buttons
if showRetryButton || showConfirmButton {
HStack(spacing: 12) {
if showConfirmButton {
confirmButton
}
if showRetryButton {
retryButton
}
}
.padding(.leading, 26) // Align with text (icon width + spacing)
}
}
.accessibilityElement(children: .combine)
.accessibilityLabel(accessibilityDescription)
}
// MARK: - Action Buttons
@ViewBuilder
private var retryButton: some View {
Button {
guard !isRetrying else { return }
isRetrying = true
Task {
await onRetry?()
isRetrying = false
}
} label: {
HStack(spacing: 4) {
if isRetrying {
ProgressView()
.controlSize(.mini)
} else {
Image(systemName: "arrow.clockwise")
.font(.caption)
}
Text(retryButtonLabel)
.font(.caption.weight(.medium))
}
.padding(.horizontal, 10)
.padding(.vertical, 6)
.background(iconColor.opacity(0.12))
.foregroundStyle(iconColor)
.clipShape(Capsule())
}
.buttonStyle(.plain)
.disabled(isRetrying)
}
@ViewBuilder
private var confirmButton: some View {
Button {
Task {
await onConfirm?()
}
} label: {
HStack(spacing: 4) {
Image(systemName: "checkmark.circle")
.font(.caption)
Text("Confirm")
.font(.caption.weight(.medium))
}
.padding(.horizontal, 10)
.padding(.vertical, 6)
.background(Color.orange.opacity(0.12))
.foregroundStyle(.orange)
.clipShape(Capsule())
}
.buttonStyle(.plain)
}
// MARK: - Error Classification
private var errorCategory: ErrorCategory {
guard let error else { return .generic }
if error.requiresUserAction {
return .confirmation
}
if error.isSafetyError {
return .safety
}
if error.isRecoverable {
return .recoverable
}
switch error {
case .llmFailure, .llmResponseUnparseable, .llmTimeout:
return .llm
case .schemaIntrospectionFailed, .emptySchema, .databaseError, .queryTimedOut:
return .database
case .configurationError:
return .configuration
default:
return .generic
}
}
private enum ErrorCategory {
case safety
case confirmation
case llm
case database
case recoverable
case configuration
case generic
}
// MARK: - Visual Properties
private var iconName: String {
switch errorCategory {
case .safety:
return "shield.trianglebadge.exclamationmark.fill"
case .confirmation:
return "hand.raised.fill"
case .llm:
return "brain"
case .database:
return "cylinder.split.1x2"
case .recoverable:
return "arrow.clockwise"
case .configuration:
return "gearshape.triangle.fill"
case .generic:
return "exclamationmark.triangle.fill"
}
}
private var iconColor: Color {
switch errorCategory {
case .safety:
return .orange
case .confirmation:
return .yellow
case .llm:
return .purple
case .database:
return config.errorColor
case .recoverable:
return .blue
case .configuration:
return .gray
case .generic:
return config.errorColor
}
}
private var errorTitle: String? {
switch errorCategory {
case .safety:
return "Operation Blocked"
case .confirmation:
return "Confirmation Required"
case .llm:
return "AI Provider Error"
case .database:
return "Database Error"
case .recoverable:
return "Query Issue"
case .configuration:
return "Configuration Error"
case .generic:
return nil
}
}
private var recoveryHint: String? {
guard let error else { return nil }
switch error {
case .noSQLGenerated, .llmResponseUnparseable:
return "Try rephrasing your question."
case .tableNotFound:
return "Check that you're referring to an existing table."
case .columnNotFound:
return "Verify the column name matches your schema."
case .invalidSQL:
return "The AI generated an invalid query. Try asking differently."
case .llmTimeout:
return "The AI took too long. Try a simpler question."
case .llmFailure:
return "The AI service may be temporarily unavailable."
case .emptySchema:
return "Add some tables to your database first."
case .queryTimedOut:
return "Try a simpler query or add database indexes."
default:
return nil
}
}
// MARK: - Button Visibility
private var showRetryButton: Bool {
guard onRetry != nil else { return false }
return errorCategory == .recoverable || errorCategory == .llm
}
private var showConfirmButton: Bool {
guard onConfirm != nil else { return false }
return errorCategory == .confirmation
}
private var retryButtonLabel: String {
switch errorCategory {
case .llm:
return "Retry"
case .recoverable:
return "Try Again"
default:
return "Retry"
}
}
// MARK: - Accessibility
private var accessibilityDescription: String {
let prefix = errorTitle.map { "\($0): " } ?? "Error: "
return prefix + message
}
}

View File

@@ -0,0 +1,214 @@
// MessageBubbleView.swift
// SwiftDBAI
//
// Renders a single ChatMessage as a styled bubble with optional
// data table and SQL disclosure for query results.
import SwiftUI
import Charts
/// Renders a single `ChatMessage` in the chat conversation.
///
/// - **User messages** display right-aligned with an accent-colored background
/// and white text, using a continuous rounded rectangle shape.
/// - **Assistant messages** display left-aligned with a secondary background.
/// The natural language text summary is the primary content, rendered with
/// full `.body` font and `.primary` foreground for readability.
/// If the message contains a `queryResult` with tabular data, a
/// `ScrollableDataTableView` is automatically embedded below the summary.
/// An optional SQL disclosure group shows the generated query.
/// - **Error messages** display left-aligned with a red-tinted background
/// and an exclamation mark icon.
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
struct MessageBubbleView: View {
@Environment(\.chatViewConfiguration) private var config
let message: ChatMessage
/// Whether to show the SQL query in a disclosure group.
var showSQL: Bool = true
/// Maximum height for the data table before it scrolls.
var maxTableHeight: CGFloat = 300
/// Called when the user taps "Retry" on a recoverable error.
var onRetry: (@Sendable () async -> Void)?
/// Called when the user confirms a destructive operation.
var onConfirm: (@Sendable () async -> Void)?
var body: some View {
HStack(alignment: .top, spacing: 8) {
if message.role == .user { Spacer(minLength: 48) }
if message.role != .user, let icon = config.assistantAvatarIcon {
assistantAvatar(icon: icon)
}
bubbleContent
.padding(.horizontal, config.messagePadding)
.padding(.vertical, config.messagePadding * 10 / 14)
.background(bubbleBackground)
.clipShape(bubbleShape)
if message.role != .user { Spacer(minLength: 48) }
}
}
@ViewBuilder
private func assistantAvatar(icon: String) -> some View {
Image(systemName: icon)
.font(.system(size: 16))
.foregroundStyle(.white)
.frame(width: 32, height: 32)
.background(config.assistantAvatarColor)
.clipShape(Circle())
.padding(.top, 2)
}
// MARK: - Bubble Content
@ViewBuilder
private var bubbleContent: some View {
switch message.role {
case .user:
userContent
case .assistant:
assistantContent
case .error:
errorContent
}
}
// MARK: - User Content
@ViewBuilder
private var userContent: some View {
Text(message.content)
.font(config.messageFont)
.foregroundStyle(config.userTextColor)
.textSelection(.enabled)
}
// MARK: - Assistant Content (Text Summary + Data Table + SQL)
@ViewBuilder
private var assistantContent: some View {
VStack(alignment: .leading, spacing: 10) {
// Natural language text summary primary content
Text(message.content)
.font(config.summaryFont)
.foregroundStyle(config.assistantTextColor)
.textSelection(.enabled)
.fixedSize(horizontal: false, vertical: true)
// Data table automatically shown when queryResult has tabular data
if let queryResult = message.queryResult,
!queryResult.columns.isEmpty,
!queryResult.rows.isEmpty {
dataTableSection(for: queryResult)
}
// SQL disclosure collapsed by default for transparency
if config.showSQLDisclosure, let sql = message.sql {
sqlDisclosure(sql: sql)
}
}
}
// MARK: - Error Content
@ViewBuilder
private var errorContent: some View {
ErrorMessageView(
chatMessage: message,
onRetry: onRetry,
onConfirm: onConfirm
)
}
/// Maximum height for the chart section.
var maxChartHeight: CGFloat = 250
/// Whether to show auto-detected charts. Defaults to `true`.
var showCharts: Bool = true
// MARK: - Chart Detection
/// The shared detector used for chart eligibility checks.
private static let chartDetector = ChartDataDetector()
// MARK: - Data Table Section
@ViewBuilder
private func dataTableSection(for queryResult: QueryResult) -> some View {
let dataTable = DataTable(queryResult)
VStack(alignment: .leading, spacing: 8) {
// Chart automatically shown when ChartDataDetector finds eligible data
if showCharts {
chartSection(for: dataTable)
}
Divider()
ScrollableDataTableView(
dataTable: dataTable,
showAlternatingRows: true,
showFooter: true
)
.frame(maxHeight: maxTableHeight)
}
}
// MARK: - Chart Section
@ViewBuilder
private func chartSection(for dataTable: DataTable) -> some View {
let detector = Self.chartDetector
if detector.detect(dataTable) != nil {
VStack(alignment: .leading, spacing: 4) {
ChartResultView(dataTable: dataTable, detector: detector)
.frame(maxHeight: maxChartHeight)
}
}
}
// MARK: - SQL Disclosure
@ViewBuilder
private func sqlDisclosure(sql: String) -> some View {
DisclosureGroup {
Text(sql)
.font(config.sqlFont)
.foregroundStyle(.secondary)
.textSelection(.enabled)
.padding(8)
.frame(maxWidth: .infinity, alignment: .leading)
.background(Color.primary.opacity(0.04))
.clipShape(RoundedRectangle(cornerRadius: 6))
} label: {
Label("SQL Query", systemImage: "chevron.left.forwardslash.chevron.right")
.font(.caption)
.foregroundStyle(.secondary)
}
}
// MARK: - Styling Helpers
private var bubbleShape: RoundedRectangle {
RoundedRectangle(cornerRadius: config.bubbleCornerRadius, style: .continuous)
}
@ViewBuilder
private var bubbleBackground: some View {
switch message.role {
case .user:
config.userBubbleColor
case .assistant:
config.assistantBubbleColor
case .error:
config.errorColor.opacity(0.1)
}
}
}

View File

@@ -0,0 +1,114 @@
// DataChatSheet.swift
// SwiftDBAI
//
// SwiftUI wrapper that adds NavigationStack chrome around DataChatView.
// Designed for use with .sheet() and .fullScreenCover().
import AnyLanguageModel
import GRDB
import SwiftUI
/// A presentation-ready wrapper around ``DataChatView`` that adds a
/// `NavigationStack`, title, and **Done** button.
///
/// Use `DataChatSheet` with SwiftUI's `.sheet()` or `.fullScreenCover()`
/// modifiers so consumers get a fully navigable chat experience out of the box.
///
/// ```swift
/// .sheet(isPresented: $showChat) {
/// DataChatSheet(
/// databasePath: "/path/to/mydata.sqlite",
/// model: OllamaLanguageModel(model: "llama3")
/// )
/// }
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct DataChatSheet: View {
let databasePath: String?
let database: (any DatabaseWriter)?
let model: any LanguageModel
var allowlist: OperationAllowlist
var additionalContext: String?
var title: String
@Environment(\.dismiss) private var dismiss
/// Creates a DataChatSheet from a database file path and language model.
///
/// - Parameters:
/// - databasePath: Absolute path to a SQLite database file.
/// - model: Any `AnyLanguageModel`-compatible language model instance.
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly`.
/// - additionalContext: Optional extra context about the database for the LLM.
/// - title: Navigation bar title. Defaults to `"AI Chat"`.
public init(
databasePath: String,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
title: String = "AI Chat"
) {
self.databasePath = databasePath
self.database = nil
self.model = model
self.allowlist = allowlist
self.additionalContext = additionalContext
self.title = title
}
/// Creates a DataChatSheet from an existing GRDB database connection.
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (`DatabasePool` or `DatabaseQueue`).
/// - model: Any `AnyLanguageModel`-compatible language model instance.
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly`.
/// - additionalContext: Optional extra context about the database for the LLM.
/// - title: Navigation bar title. Defaults to `"AI Chat"`.
public init(
database: any DatabaseWriter,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
title: String = "AI Chat"
) {
self.databasePath = nil
self.database = database
self.model = model
self.allowlist = allowlist
self.additionalContext = additionalContext
self.title = title
}
public var body: some View {
NavigationStack {
dataChatView
.navigationTitle(title)
#if !os(macOS)
.navigationBarTitleDisplayMode(.inline)
#endif
.toolbar {
ToolbarItem(placement: .cancellationAction) {
Button("Done") { dismiss() }
}
}
}
}
@ViewBuilder
private var dataChatView: some View {
if let database {
DataChatView(
database: database,
model: model,
allowlist: allowlist,
additionalContext: additionalContext
)
} else if let databasePath {
DataChatView(
databasePath: databasePath,
model: model,
allowlist: allowlist,
additionalContext: additionalContext
)
}
}
}

View File

@@ -0,0 +1,212 @@
// DataChatSheetModifier.swift
// SwiftDBAI
//
// View modifiers for presenting DataChatSheet as a sheet or full-screen cover.
import AnyLanguageModel
import GRDB
import SwiftUI
// MARK: - Sheet Modifier
/// A view modifier that presents a ``DataChatSheet`` as a standard sheet.
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
struct DataChatSheetModifier: ViewModifier {
@Binding var isPresented: Bool
let databasePath: String
let model: any LanguageModel
var allowlist: OperationAllowlist
var additionalContext: String?
var title: String
func body(content: Content) -> some View {
content.sheet(isPresented: $isPresented) {
DataChatSheet(
databasePath: databasePath,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
title: title
)
}
}
}
/// A view modifier that presents a ``DataChatSheet`` as a sheet using
/// an existing GRDB database connection.
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
struct DataChatSheetDatabaseModifier: ViewModifier {
@Binding var isPresented: Bool
let database: any DatabaseWriter
let model: any LanguageModel
var allowlist: OperationAllowlist
var additionalContext: String?
var title: String
func body(content: Content) -> some View {
content.sheet(isPresented: $isPresented) {
DataChatSheet(
database: database,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
title: title
)
}
}
}
// MARK: - Full-Screen Modifier
#if os(iOS) || os(visionOS)
/// A view modifier that presents a ``DataChatSheet`` as a full-screen cover.
@available(iOS 17.0, visionOS 1.0, *)
struct DataChatFullScreenModifier: ViewModifier {
@Binding var isPresented: Bool
let databasePath: String
let model: any LanguageModel
var allowlist: OperationAllowlist
var additionalContext: String?
var title: String
func body(content: Content) -> some View {
content.fullScreenCover(isPresented: $isPresented) {
DataChatSheet(
databasePath: databasePath,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
title: title
)
}
}
}
/// A view modifier that presents a ``DataChatSheet`` as a full-screen cover
/// using an existing GRDB database connection.
@available(iOS 17.0, visionOS 1.0, *)
struct DataChatFullScreenDatabaseModifier: ViewModifier {
@Binding var isPresented: Bool
let database: any DatabaseWriter
let model: any LanguageModel
var allowlist: OperationAllowlist
var additionalContext: String?
var title: String
func body(content: Content) -> some View {
content.fullScreenCover(isPresented: $isPresented) {
DataChatSheet(
database: database,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
title: title
)
}
}
}
#endif
// MARK: - View Extensions
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public extension View {
/// Presents a database chat interface as a sheet.
///
/// ```swift
/// .dataChatSheet(
/// isPresented: $showChat,
/// databasePath: "/path/to/db.sqlite",
/// model: myLLM
/// )
/// ```
func dataChatSheet(
isPresented: Binding<Bool>,
databasePath: String,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
title: String = "AI Chat"
) -> some View {
modifier(DataChatSheetModifier(
isPresented: isPresented,
databasePath: databasePath,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
title: title
))
}
/// Presents a database chat interface as a sheet using an existing GRDB connection.
func dataChatSheet(
isPresented: Binding<Bool>,
database: any DatabaseWriter,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
title: String = "AI Chat"
) -> some View {
modifier(DataChatSheetDatabaseModifier(
isPresented: isPresented,
database: database,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
title: title
))
}
}
#if os(iOS) || os(visionOS)
@available(iOS 17.0, visionOS 1.0, *)
public extension View {
/// Presents a database chat interface as a full-screen cover.
///
/// ```swift
/// .dataChatFullScreen(
/// isPresented: $showChat,
/// databasePath: "/path/to/db.sqlite",
/// model: myLLM
/// )
/// ```
func dataChatFullScreen(
isPresented: Binding<Bool>,
databasePath: String,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
title: String = "AI Chat"
) -> some View {
modifier(DataChatFullScreenModifier(
isPresented: isPresented,
databasePath: databasePath,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
title: title
))
}
/// Presents a database chat interface as a full-screen cover using an existing GRDB connection.
func dataChatFullScreen(
isPresented: Binding<Bool>,
database: any DatabaseWriter,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
title: String = "AI Chat"
) -> some View {
modifier(DataChatFullScreenDatabaseModifier(
isPresented: isPresented,
database: database,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
title: title
))
}
}
#endif

View File

@@ -0,0 +1,81 @@
// DataChatViewController.swift
// SwiftDBAI
//
// UIKit bridge: a UIHostingController subclass for presenting DataChatSheet
// in UIKit-based apps via modal presentation or navigation push.
#if canImport(UIKit) && !os(watchOS)
import AnyLanguageModel
import GRDB
import SwiftUI
import UIKit
/// A `UIHostingController` subclass that wraps ``DataChatSheet`` for UIKit apps.
///
/// Present modally:
/// ```swift
/// let vc = DataChatViewController(databasePath: path, model: myLLM)
/// present(vc, animated: true)
/// ```
///
/// Or push onto a navigation stack:
/// ```swift
/// let vc = DataChatViewController(databasePath: path, model: myLLM)
/// navigationController?.pushViewController(vc, animated: true)
/// ```
@available(iOS 17.0, visionOS 1.0, *)
public final class DataChatViewController: UIHostingController<DataChatSheet> {
/// Creates a DataChatViewController from a database file path and language model.
///
/// - Parameters:
/// - databasePath: Absolute path to a SQLite database file.
/// - model: Any `AnyLanguageModel`-compatible language model instance.
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly`.
/// - additionalContext: Optional extra context about the database for the LLM.
/// - title: Navigation bar title. Defaults to `"AI Chat"`.
public convenience init(
databasePath: String,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
title: String = "AI Chat"
) {
let sheet = DataChatSheet(
databasePath: databasePath,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
title: title
)
self.init(rootView: sheet)
self.modalPresentationStyle = .formSheet
}
/// Creates a DataChatViewController from an existing GRDB database connection.
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (`DatabasePool` or `DatabaseQueue`).
/// - model: Any `AnyLanguageModel`-compatible language model instance.
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly`.
/// - additionalContext: Optional extra context about the database for the LLM.
/// - title: Navigation bar title. Defaults to `"AI Chat"`.
public convenience init(
database: any DatabaseWriter,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
title: String = "AI Chat"
) {
let sheet = DataChatSheet(
database: database,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
title: title
)
self.init(rootView: sheet)
self.modalPresentationStyle = .formSheet
}
}
#endif

View File

@@ -0,0 +1,270 @@
// ScrollableDataTableView.swift
// SwiftDBAI
//
// A SwiftUI view that renders a DataTable with horizontal and vertical
// scrolling, styled column headers, and row cells.
import SwiftUI
/// A scrollable table view that renders a `DataTable` with column headers
/// and row cells, supporting both horizontal and vertical scrolling.
///
/// Usage:
/// ```swift
/// ScrollableDataTableView(dataTable: myDataTable)
/// ```
///
/// The view automatically sizes columns based on content, highlights
/// alternating rows for readability, and right-aligns numeric columns.
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct ScrollableDataTableView: View {
@Environment(\.chatViewConfiguration) private var config
/// The data table to render.
public let dataTable: DataTable
/// Minimum width for each column in points.
public var minimumColumnWidth: CGFloat
/// Maximum width for each column in points.
public var maximumColumnWidth: CGFloat
/// Whether to show alternating row backgrounds.
public var showAlternatingRows: Bool
/// Whether to show the row count footer.
public var showFooter: Bool
public init(
dataTable: DataTable,
minimumColumnWidth: CGFloat = 80,
maximumColumnWidth: CGFloat = 250,
showAlternatingRows: Bool = true,
showFooter: Bool = true
) {
self.dataTable = dataTable
self.minimumColumnWidth = minimumColumnWidth
self.maximumColumnWidth = maximumColumnWidth
self.showAlternatingRows = showAlternatingRows
self.showFooter = showFooter
}
public var body: some View {
if dataTable.isEmpty {
emptyView
} else {
tableContent
}
}
// MARK: - Empty State
@ViewBuilder
private var emptyView: some View {
VStack(spacing: 8) {
Image(systemName: "tablecells")
.font(.largeTitle)
.foregroundStyle(.secondary)
Text("No results")
.font(.headline)
.foregroundStyle(.secondary)
}
.frame(maxWidth: .infinity, minHeight: 100)
}
// MARK: - Table Content
@ViewBuilder
private var tableContent: some View {
VStack(alignment: .leading, spacing: 0) {
ScrollView([.horizontal, .vertical]) {
LazyVStack(alignment: .leading, spacing: 0, pinnedViews: [.sectionHeaders]) {
Section {
ForEach(dataTable.rows) { row in
rowView(row)
}
} header: {
headerRow
}
}
}
if showFooter {
footerView
}
}
}
// MARK: - Header
@ViewBuilder
private var headerRow: some View {
HStack(spacing: 0) {
ForEach(dataTable.columns) { column in
Text(column.name)
.font(.caption.weight(.semibold))
.foregroundStyle(.primary)
.lineLimit(1)
.frame(
width: columnWidth(for: column),
alignment: alignment(for: column)
)
.padding(.horizontal, 8)
.padding(.vertical, 6)
if column.index < dataTable.columnCount - 1 {
Divider()
}
}
}
.background(.bar)
.overlay(alignment: .bottom) {
Divider()
}
}
// MARK: - Row
@ViewBuilder
private func rowView(_ row: DataTable.Row) -> some View {
HStack(spacing: 0) {
ForEach(dataTable.columns) { column in
cellView(value: row[column.index], column: column)
if column.index < dataTable.columnCount - 1 {
Divider()
}
}
}
.background(rowBackground(for: row))
.overlay(alignment: .bottom) {
Divider()
}
}
// MARK: - Cell
@ViewBuilder
private func cellView(value: QueryResult.Value, column: DataTable.Column) -> some View {
Group {
switch value {
case .null:
Text("NULL")
.foregroundStyle(.tertiary)
.italic()
case .blob(let data):
Text("<\(data.count) bytes>")
.foregroundStyle(.secondary)
default:
Text(value.stringValue)
.foregroundStyle(.primary)
}
}
.font(.caption)
.lineLimit(2)
.frame(
width: columnWidth(for: column),
alignment: alignment(for: column)
)
.padding(.horizontal, 8)
.padding(.vertical, 4)
}
// MARK: - Footer
@ViewBuilder
private var footerView: some View {
HStack {
Text("\(dataTable.rowCount) row\(dataTable.rowCount == 1 ? "" : "s")")
.font(.caption2)
.foregroundStyle(.secondary)
Spacer()
if dataTable.executionTime > 0 {
Text(String(format: "%.1f ms", dataTable.executionTime * 1000))
.font(.caption2)
.foregroundStyle(.secondary)
}
}
.padding(.horizontal, 8)
.padding(.vertical, 4)
.background(.bar)
}
// MARK: - Layout Helpers
/// Determines column width based on the column name length and type.
private func columnWidth(for column: DataTable.Column) -> CGFloat {
// Estimate based on header text length
let headerWidth = CGFloat(column.name.count) * 8 + 16
// Sample some row values to estimate content width
let sampleRows = dataTable.rows.prefix(20)
let maxContentWidth = sampleRows.reduce(CGFloat(0)) { maxWidth, row in
let value = row[column.index]
let textLength = CGFloat(value.stringValue.count) * 7
return max(maxWidth, textLength)
}
let estimatedWidth = max(headerWidth, maxContentWidth) + 16
return min(max(estimatedWidth, minimumColumnWidth), maximumColumnWidth)
}
/// Returns the alignment for a column based on its inferred type.
private func alignment(for column: DataTable.Column) -> Alignment {
switch column.inferredType {
case .integer, .real:
return .trailing
default:
return .leading
}
}
/// Returns the background color for alternating rows.
@ViewBuilder
private func rowBackground(for row: DataTable.Row) -> some View {
if showAlternatingRows && row.id.isMultiple(of: 2) {
Color.clear
} else if showAlternatingRows {
Color.primary.opacity(0.03)
} else {
Color.clear
}
}
}
// MARK: - Preview Support
#if DEBUG
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Data Table") {
let columns: [DataTable.Column] = [
.init(name: "id", index: 0, inferredType: .integer),
.init(name: "name", index: 1, inferredType: .text),
.init(name: "score", index: 2, inferredType: .real),
]
let rows: [DataTable.Row] = (0..<25).map { i in
DataTable.Row(
id: i,
values: [
.integer(Int64(i + 1)),
.text("Item \(i + 1)"),
.real(Double.random(in: 1.0...100.0)),
],
columnNames: ["id", "name", "score"]
)
}
let table = DataTable(columns: columns, rows: rows, sql: "SELECT * FROM items", executionTime: 0.023)
ScrollableDataTableView(dataTable: table)
.frame(height: 400)
.padding()
}
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Empty Table") {
let table = DataTable(columns: [], rows: [], sql: "", executionTime: 0)
ScrollableDataTableView(dataTable: table)
.frame(height: 200)
.padding()
}
#endif

View File

@@ -0,0 +1,200 @@
// ChatViewConfiguration.swift
// SwiftDBAI
//
// A configuration struct that controls the visual appearance of ChatView
// and its child views. Propagated via SwiftUI environment.
import SwiftUI
/// Controls the visual appearance of the chat interface.
///
/// Use the built-in presets (`.default`, `.compact`, `.dark`) or create
/// a custom configuration by mutating the default:
///
/// ```swift
/// var config = ChatViewConfiguration.default
/// config.userBubbleColor = .purple
/// config.inputPlaceholder = "Ask about your recipes..."
///
/// DataChatView(databasePath: path, model: myLLM)
/// .chatViewConfiguration(config)
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct ChatViewConfiguration: Sendable {
// MARK: - Colors
/// Background color for user message bubbles.
public var userBubbleColor: Color
/// Text color for user messages.
public var userTextColor: Color
/// Background color for assistant message bubbles.
public var assistantBubbleColor: Color
/// Text color for assistant messages.
public var assistantTextColor: Color
/// Background color for the overall chat view.
public var backgroundColor: Color
/// Background color for the input bar area.
public var inputBarBackgroundColor: Color
/// Accent color used for interactive elements (send button, etc.).
public var accentColor: Color
/// Color used for error-related UI elements.
public var errorColor: Color
// MARK: - Typography
/// Font for chat message text.
public var messageFont: Font
/// Font for the natural language summary text.
public var summaryFont: Font
/// Font for SQL query display.
public var sqlFont: Font
/// Font for the text input field.
public var inputFont: Font
// MARK: - Layout
/// Padding inside message bubbles.
public var messagePadding: CGFloat
/// Corner radius for message bubbles.
public var bubbleCornerRadius: CGFloat
/// Whether to show timestamps on messages.
public var showTimestamps: Bool
/// Whether to show the SQL disclosure group.
public var showSQLDisclosure: Bool
/// Placeholder text in the input field.
public var inputPlaceholder: String
/// Title text shown when the chat has no messages.
public var emptyStateTitle: String
/// Subtitle text shown when the chat has no messages.
public var emptyStateSubtitle: String
/// SF Symbol name for the empty state icon.
public var emptyStateIcon: String
/// Optional color scheme override. When set, forces the chat view to use
/// this color scheme regardless of the system setting.
public var colorSchemeOverride: ColorScheme?
// MARK: - Avatar
/// SF Symbol name for the assistant avatar. When set, shows a circular
/// avatar next to assistant messages (e.g. "person.crop.circle.fill",
/// "brain.head.profile", "sparkles").
public var assistantAvatarIcon: String?
/// Background color for the assistant avatar circle.
public var assistantAvatarColor: Color
// MARK: - Memberwise Initializer
/// Creates a fully custom configuration.
public init(
userBubbleColor: Color = .accentColor,
userTextColor: Color = .white,
assistantBubbleColor: Color = defaultAssistantBackgroundColor,
assistantTextColor: Color = .primary,
backgroundColor: Color = .clear,
inputBarBackgroundColor: Color = .clear,
accentColor: Color = .accentColor,
errorColor: Color = .red,
messageFont: Font = .body,
summaryFont: Font = .body,
sqlFont: Font = .system(.caption, design: .monospaced),
inputFont: Font = .body,
messagePadding: CGFloat = 14,
bubbleCornerRadius: CGFloat = 16,
showTimestamps: Bool = false,
showSQLDisclosure: Bool = true,
inputPlaceholder: String = "Ask about your data\u{2026}",
emptyStateTitle: String = "Ask a question about your data",
emptyStateSubtitle: String = "Try something like \"How many records are in the database?\"",
emptyStateIcon: String = "bubble.left.and.text.bubble.right",
colorSchemeOverride: ColorScheme? = nil,
assistantAvatarIcon: String? = nil,
assistantAvatarColor: Color = .accentColor
) {
self.userBubbleColor = userBubbleColor
self.userTextColor = userTextColor
self.assistantBubbleColor = assistantBubbleColor
self.assistantTextColor = assistantTextColor
self.backgroundColor = backgroundColor
self.inputBarBackgroundColor = inputBarBackgroundColor
self.accentColor = accentColor
self.errorColor = errorColor
self.messageFont = messageFont
self.summaryFont = summaryFont
self.sqlFont = sqlFont
self.inputFont = inputFont
self.messagePadding = messagePadding
self.bubbleCornerRadius = bubbleCornerRadius
self.showTimestamps = showTimestamps
self.showSQLDisclosure = showSQLDisclosure
self.inputPlaceholder = inputPlaceholder
self.emptyStateTitle = emptyStateTitle
self.emptyStateSubtitle = emptyStateSubtitle
self.emptyStateIcon = emptyStateIcon
self.colorSchemeOverride = colorSchemeOverride
self.assistantAvatarIcon = assistantAvatarIcon
self.assistantAvatarColor = assistantAvatarColor
}
// MARK: - Platform-Adaptive Defaults
/// The default assistant bubble background, matching the platform convention.
public static var defaultAssistantBackgroundColor: Color {
#if os(macOS)
Color(nsColor: .controlBackgroundColor)
#else
Color(uiColor: .secondarySystemGroupedBackground)
#endif
}
// MARK: - Presets
/// The default configuration, matching the original hardcoded ChatView styling.
public static let `default` = ChatViewConfiguration()
/// A compact configuration with smaller fonts, tighter padding, and minimal chrome.
public static let compact = ChatViewConfiguration(
messageFont: .footnote,
summaryFont: .footnote,
sqlFont: .system(.caption2, design: .monospaced),
inputFont: .footnote,
messagePadding: 8,
bubbleCornerRadius: 10,
showTimestamps: false,
showSQLDisclosure: false,
emptyStateTitle: "Ask a question",
emptyStateSubtitle: ""
)
/// A dark-themed configuration with muted colors suitable for dark backgrounds.
public static let dark = ChatViewConfiguration(
userBubbleColor: Color(white: 0.25),
userTextColor: .white,
assistantBubbleColor: Color(white: 0.15),
assistantTextColor: Color(white: 0.9),
backgroundColor: .black,
inputBarBackgroundColor: Color(white: 0.1),
accentColor: .blue,
errorColor: Color(red: 1.0, green: 0.4, blue: 0.4),
colorSchemeOverride: .dark
)
}

View File

@@ -0,0 +1,35 @@
// ChatViewConfigurationKey.swift
// SwiftDBAI
//
// SwiftUI environment key for propagating ChatViewConfiguration
// through the view hierarchy.
import SwiftUI
/// Environment key that stores the ``ChatViewConfiguration``.
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
struct ChatViewConfigurationKey: EnvironmentKey {
static let defaultValue = ChatViewConfiguration.default
}
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
extension EnvironmentValues {
/// The chat view configuration for the current environment.
var chatViewConfiguration: ChatViewConfiguration {
get { self[ChatViewConfigurationKey.self] }
set { self[ChatViewConfigurationKey.self] = newValue }
}
}
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
extension View {
/// Applies a ``ChatViewConfiguration`` to this view and its descendants.
///
/// ```swift
/// DataChatView(databasePath: path, model: myLLM)
/// .chatViewConfiguration(.dark)
/// ```
public func chatViewConfiguration(_ config: ChatViewConfiguration) -> some View {
environment(\.chatViewConfiguration, config)
}
}

131
THEMING.md Normal file
View File

@@ -0,0 +1,131 @@
# Theming
SwiftDBAI's chat interface is customizable through `ChatViewConfiguration`. Pass it via the `.chatViewConfiguration()` view modifier -- it propagates through the entire view hierarchy via SwiftUI environment.
## Built-in Presets
### Default
The standard look. Blue user bubbles, system-colored assistant bubbles, standard fonts.
```swift
DataChatView(databasePath: path, model: myLLM)
// .chatViewConfiguration(.default) is implicit
```
![Default](screenshots/results-chart.png)
### Dark
Muted colors for dark backgrounds. Dark gray bubbles, light text, black background.
```swift
DataChatView(databasePath: path, model: myLLM)
.chatViewConfiguration(.dark)
```
![Dark](screenshots/dark-theme.png)
### Compact
Smaller fonts, tighter padding, no SQL disclosure. Good for embedded or secondary views.
```swift
DataChatView(databasePath: path, model: myLLM)
.chatViewConfiguration(.compact)
```
![Compact](screenshots/compact-theme.png)
## Custom Configuration
Start from any preset and override what you need:
```swift
var config = ChatViewConfiguration.default
config.userBubbleColor = .purple
config.userTextColor = .white
config.accentColor = .purple
config.inputPlaceholder = "Search GitHub repos..."
config.emptyStateTitle = "Explore GitHub Data"
config.emptyStateSubtitle = "Ask about stars, forks, languages, and trends"
config.emptyStateIcon = "star.circle"
DataChatView(databasePath: path, model: myLLM)
.chatViewConfiguration(config)
```
| Custom empty state | Custom with results |
|---|---|
| ![Custom empty](screenshots/custom-theme.png) | ![Custom results](screenshots/custom-results.png) |
## Available Properties
### Colors
| Property | Default | Description |
|---|---|---|
| `userBubbleColor` | `.accentColor` | Background of user message bubbles |
| `userTextColor` | `.white` | Text color in user bubbles |
| `assistantBubbleColor` | System secondary | Background of assistant bubbles |
| `assistantTextColor` | `.primary` | Text color in assistant bubbles |
| `backgroundColor` | `.clear` | Overall chat view background |
| `inputBarBackgroundColor` | `.clear` | Input bar area background |
| `accentColor` | `.accentColor` | Send button and interactive elements |
| `errorColor` | `.red` | Error message icon and border |
### Typography
| Property | Default | Description |
|---|---|---|
| `messageFont` | `.body` | Chat message text |
| `summaryFont` | `.body` | Natural language summary |
| `sqlFont` | `.caption monospaced` | SQL query display |
| `inputFont` | `.body` | Text input field |
### Layout & Content
| Property | Default | Description |
|---|---|---|
| `messagePadding` | `14` | Padding inside message bubbles |
| `bubbleCornerRadius` | `16` | Corner radius of bubbles |
| `showTimestamps` | `false` | Show timestamps on messages |
| `showSQLDisclosure` | `true` | Show the "</> SQL Query" expandable section |
| `inputPlaceholder` | `"Ask about your data..."` | Placeholder text in input field |
| `emptyStateTitle` | `"Ask a question about your data"` | Title when no messages |
| `emptyStateSubtitle` | `"Try something like..."` | Subtitle when no messages |
| `emptyStateIcon` | `"bubble.left.and.text.bubble.right"` | SF Symbol for empty state |
### Avatar
| Property | Default | Description |
|---|---|---|
| `assistantAvatarIcon` | `nil` | SF Symbol for assistant avatar (e.g. `"sparkles"`, `"person.crop.circle.fill"`) |
| `assistantAvatarColor` | `.accentColor` | Background color of the avatar circle |
When set, a circular avatar appears next to every assistant message:
```swift
config.assistantAvatarIcon = "sparkles"
config.assistantAvatarColor = .purple
```
![Custom with avatar](screenshots/custom-results.png)
## Works with All Presentation Modes
The configuration propagates through sheets, navigation, and UIKit bridges:
```swift
// Sheet
.sheet(isPresented: $show) {
DataChatSheet(databasePath: path, model: myLLM)
.chatViewConfiguration(.dark)
}
// UIKit
let vc = DataChatViewController(databasePath: path, model: myLLM)
// Configuration can be set on the rootView before presenting
```
![Sheet presentation](screenshots/sheet-presentation.png)

View File

@@ -0,0 +1,254 @@
// BinarySizeTests.swift
// SwiftDBAI
//
// Validates that the SwiftDBAI package stays within its 2 MB binary size budget.
// This test suite uses source-level heuristics since we can't measure the actual
// compiled binary size in a unit test. The constraints ensure the package remains
// lightweight by checking:
// 1. Total source code size (proxy for compiled size)
// 2. No embedded binary assets or large resources
// 3. No unnecessary heavy dependencies
// 4. File count stays reasonable (no code bloat)
import Foundation
import Testing
@Suite("Binary Size Budget")
struct BinarySizeTests {
/// The maximum allowed total source code size in bytes.
/// At typical Swift optimized compilation ratios (2-4x), 500 KB of source
/// compiles to roughly 1-2 MB of binary. We set the source budget at 500 KB
/// to keep the compiled output well under 2 MB.
private static let maxSourceSizeBytes: Int = 500_000 // 500 KB
/// Maximum number of Swift source files allowed.
/// More files generally means more code and larger binaries.
private static let maxSourceFileCount: Int = 60
/// Maximum size for any single source file in bytes.
/// Large individual files often indicate code that should be split or
/// contains embedded data that bloats the binary.
private static let maxSingleFileSizeBytes: Int = 50_000 // 50 KB
/// Disallowed file extensions in the Sources directory that would bloat the binary.
private static let disallowedExtensions: Set<String> = [
"png", "jpg", "jpeg", "gif", "bmp", "tiff",
"mp3", "mp4", "wav", "mov",
"mlmodel", "mlmodelc", "mlpackage",
"sqlite", "db",
"zip", "tar", "gz",
"bin", "dat",
"framework", "dylib", "a"
]
// MARK: - Helper
/// Recursively finds all files in the Sources/SwiftDBAI directory.
private func findSourceFiles() throws -> [URL] {
let sourcesDir = findSourcesDirectory()
guard let sourcesDir else {
Issue.record("Could not locate Sources/SwiftDBAI directory")
return []
}
let fileManager = FileManager.default
guard let enumerator = fileManager.enumerator(
at: sourcesDir,
includingPropertiesForKeys: [.fileSizeKey, .isRegularFileKey],
options: [.skipsHiddenFiles]
) else {
Issue.record("Could not enumerate Sources/SwiftDBAI directory")
return []
}
var files: [URL] = []
for case let fileURL as URL in enumerator {
let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey])
if resourceValues.isRegularFile == true {
files.append(fileURL)
}
}
return files
}
/// Locates the Sources/SwiftDBAI directory by walking up from the test bundle.
private func findSourcesDirectory() -> URL? {
// Try common locations relative to the build directory
let fileManager = FileManager.default
// In SPM test runs, we can find the package root by checking known paths
var candidateURL = URL(fileURLWithPath: #filePath)
// Walk up from Tests/SwiftDBAITests/BinarySizeTests.swift to package root
for _ in 0..<3 {
candidateURL = candidateURL.deletingLastPathComponent()
}
let sourcesDir = candidateURL.appendingPathComponent("Sources/SwiftDBAI")
if fileManager.fileExists(atPath: sourcesDir.path) {
return sourcesDir
}
// Fallback: check current working directory
let cwdSources = URL(fileURLWithPath: fileManager.currentDirectoryPath)
.appendingPathComponent("Sources/SwiftDBAI")
if fileManager.fileExists(atPath: cwdSources.path) {
return cwdSources
}
return nil
}
// MARK: - Tests
@Test("Total source code size stays under 500 KB budget")
func totalSourceCodeSizeUnderBudget() throws {
let files = try findSourceFiles()
let swiftFiles = files.filter { $0.pathExtension == "swift" }
var totalSize: Int = 0
for file in swiftFiles {
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
let fileSize = attributes[.size] as? Int ?? 0
totalSize += fileSize
}
#expect(totalSize < Self.maxSourceSizeBytes,
"""
Total Swift source size (\(totalSize) bytes) exceeds \(Self.maxSourceSizeBytes) byte budget.
At typical 2-4x compilation ratio, this would produce a binary larger than 2 MB.
Consider removing unused code or splitting into optional sub-targets.
""")
// Log the actual size for visibility
let sizeKB = Double(totalSize) / 1024.0
let budgetKB = Double(Self.maxSourceSizeBytes) / 1024.0
print("📦 SwiftDBAI source size: \(String(format: "%.1f", sizeKB)) KB / \(String(format: "%.0f", budgetKB)) KB budget (\(String(format: "%.0f", (sizeKB / budgetKB) * 100))% used)")
}
@Test("Source file count stays reasonable")
func sourceFileCountUnderLimit() throws {
let files = try findSourceFiles()
let swiftFiles = files.filter { $0.pathExtension == "swift" }
#expect(swiftFiles.count <= Self.maxSourceFileCount,
"""
Swift source file count (\(swiftFiles.count)) exceeds limit of \(Self.maxSourceFileCount).
More files generally means more code and larger binaries.
""")
print("📦 SwiftDBAI file count: \(swiftFiles.count) / \(Self.maxSourceFileCount) max")
}
@Test("No individual source file exceeds 50 KB")
func noOversizedSourceFiles() throws {
let files = try findSourceFiles()
let swiftFiles = files.filter { $0.pathExtension == "swift" }
for file in swiftFiles {
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
let fileSize = attributes[.size] as? Int ?? 0
#expect(fileSize < Self.maxSingleFileSizeBytes,
"""
File \(file.lastPathComponent) is \(fileSize) bytes, exceeding the \(Self.maxSingleFileSizeBytes) byte limit.
Large files may contain embedded data or code that should be split.
""")
}
}
@Test("No binary assets or heavy resources in Sources directory")
func noBinaryAssetsInSources() throws {
let files = try findSourceFiles()
let disallowedFiles = files.filter { file in
Self.disallowedExtensions.contains(file.pathExtension.lowercased())
}
#expect(disallowedFiles.isEmpty,
"""
Found \(disallowedFiles.count) disallowed file(s) in Sources directory:
\(disallowedFiles.map(\.lastPathComponent).joined(separator: "\n"))
These file types bloat the binary. Remove them or move to a separate resource bundle.
""")
}
@Test("Package has no resource bundles that could bloat binary")
func noResourceBundles() throws {
let files = try findSourceFiles()
let resourceFiles = files.filter { file in
let ext = file.pathExtension.lowercased()
return ["xcassets", "storyboard", "xib", "nib", "xcdatamodeld"].contains(ext)
}
#expect(resourceFiles.isEmpty,
"""
Found resource bundle files that could bloat the binary:
\(resourceFiles.map(\.lastPathComponent).joined(separator: "\n"))
SwiftDBAI should be pure code — no bundled resources.
""")
}
@Test("Only expected dependencies declared (GRDB + AnyLanguageModel)")
func minimalDependencies() throws {
// Read Package.swift to verify we only have the expected dependencies
var packageURL = URL(fileURLWithPath: #filePath)
for _ in 0..<3 {
packageURL = packageURL.deletingLastPathComponent()
}
let packageSwiftURL = packageURL.appendingPathComponent("Package.swift")
guard FileManager.default.fileExists(atPath: packageSwiftURL.path) else {
// Skip if we can't find Package.swift (CI environments etc.)
return
}
let packageContents = try String(contentsOf: packageSwiftURL, encoding: .utf8)
// Count .package() declarations (dependencies)
let packageDeclarations = packageContents.components(separatedBy: ".package(")
.count - 1 // subtract 1 because the first segment is before any .package(
#expect(packageDeclarations <= 3,
"""
Found \(packageDeclarations) package dependencies, expected at most 4 (GRDB + AnyLanguageModel + ViewInspector for tests).
Additional dependencies increase binary size. Evaluate if they're truly needed.
""")
// Verify the expected dependencies are present
#expect(packageContents.contains("GRDB"), "Expected GRDB dependency")
#expect(packageContents.contains("AnyLanguageModel"), "Expected AnyLanguageModel dependency")
print("📦 SwiftDBAI dependencies: \(packageDeclarations) (GRDB + AnyLanguageModel)")
}
@Test("Estimated binary size under 2 MB")
func estimatedBinarySizeUnderLimit() throws {
let files = try findSourceFiles()
let swiftFiles = files.filter { $0.pathExtension == "swift" }
var totalSize: Int = 0
for file in swiftFiles {
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
let fileSize = attributes[.size] as? Int ?? 0
totalSize += fileSize
}
// Conservative estimate: optimized Swift binary is typically 2-4x source size.
// Use 4x as worst case multiplier for safety margin.
let worstCaseMultiplier = 4.0
let estimatedBinarySize = Double(totalSize) * worstCaseMultiplier
let maxBinarySize: Double = 2.0 * 1024.0 * 1024.0 // 2 MB
#expect(estimatedBinarySize < maxBinarySize,
"""
Estimated binary size (\(String(format: "%.1f", estimatedBinarySize / 1024.0)) KB) exceeds 2 MB limit.
Source: \(totalSize) bytes × \(worstCaseMultiplier)x multiplier = \(String(format: "%.1f", estimatedBinarySize / 1024.0)) KB
Note: This is the SwiftDBAI module only — excludes GRDB and AnyLanguageModel
which are existing dependencies the developer already includes.
""")
let estimatedMB = estimatedBinarySize / (1024.0 * 1024.0)
print("📦 Estimated SwiftDBAI binary size: \(String(format: "%.2f", estimatedMB)) MB / 2.00 MB limit (worst case \(worstCaseMultiplier)x)")
}
}

View File

@@ -0,0 +1,293 @@
// ChartDataDetectorTests.swift
// SwiftDBAITests
import Testing
@testable import SwiftDBAI
@Suite("ChartDataDetector")
struct ChartDataDetectorTests {
let detector = ChartDataDetector()
// MARK: - Helpers
private func makeQueryResult(
columns: [String],
rows: [[QueryResult.Value]],
sql: String = "SELECT *"
) -> QueryResult {
let rowDicts = rows.map { values in
Dictionary(uniqueKeysWithValues: zip(columns, values))
}
return QueryResult(
columns: columns,
rows: rowDicts,
sql: sql,
executionTime: 0.01
)
}
private func makeTable(
columns: [String],
rows: [[QueryResult.Value]],
sql: String = "SELECT *"
) -> DataTable {
DataTable(makeQueryResult(columns: columns, rows: rows, sql: sql))
}
// MARK: - Basic Eligibility
@Test("Returns nil for single-column results")
func singleColumn() {
let table = makeTable(
columns: ["count"],
rows: [[.integer(42)]]
)
#expect(detector.detect(table) == nil)
}
@Test("Returns nil for empty results")
func emptyResults() {
let table = makeTable(columns: ["name", "value"], rows: [])
#expect(detector.detect(table) == nil)
}
@Test("Returns nil for single row")
func singleRow() {
let table = makeTable(
columns: ["name", "count"],
rows: [[.text("A"), .integer(10)]]
)
#expect(detector.detect(table) == nil)
}
@Test("Returns nil for too many rows")
func tooManyRows() {
let rows = (0..<101).map { i in
[QueryResult.Value.text("cat\(i)"), .integer(Int64(i))]
}
let table = makeTable(columns: ["name", "count"], rows: rows)
#expect(detector.detect(table) == nil)
}
// MARK: - Bar Chart Detection
@Test("Recommends bar chart for categorical text + numeric")
func barChartCategorical() {
let table = makeTable(
columns: ["department", "headcount"],
rows: [
[.text("Engineering"), .integer(45)],
[.text("Marketing"), .integer(20)],
[.text("Sales"), .integer(30)],
[.text("HR"), .integer(10)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType == .bar)
#expect(rec?.categoryColumn == "department")
#expect(rec?.valueColumn == "headcount")
#expect(rec?.confidence ?? 0 > 0.5)
}
// MARK: - Pie Chart Detection
@Test("Recommends pie chart for small positive proportions")
func pieChartSmallCategories() {
let table = makeTable(
columns: ["status", "count"],
rows: [
[.text("Active"), .integer(50)],
[.text("Inactive"), .integer(30)],
[.text("Pending"), .integer(20)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType == .pie)
#expect(rec?.categoryColumn == "status")
#expect(rec?.valueColumn == "count")
}
@Test("Does not recommend pie with negative values")
func pieRejectsNegative() {
let table = makeTable(
columns: ["category", "change"],
rows: [
[.text("A"), .integer(50)],
[.text("B"), .integer(-10)],
[.text("C"), .integer(20)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
// Should NOT be pie since there's a negative value
#expect(rec?.chartType != .pie)
}
@Test("Does not recommend pie with too many slices")
func pieRejectsTooManySlices() {
let rows = (0..<10).map { i in
[QueryResult.Value.text("cat\(i)"), .integer(Int64(i + 1))]
}
let table = makeTable(columns: ["category", "value"], rows: rows)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType != .pie)
}
// MARK: - Line Chart Detection
@Test("Recommends line chart for time-series column names")
func lineChartTimeSeries() {
let table = makeTable(
columns: ["year", "revenue"],
rows: [
[.text("2020"), .real(1_000_000)],
[.text("2021"), .real(1_200_000)],
[.text("2022"), .real(1_500_000)],
[.text("2023"), .real(1_800_000)],
[.text("2024"), .real(2_100_000)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType == .line)
#expect(rec?.categoryColumn == "year")
#expect(rec?.valueColumn == "revenue")
}
@Test("Recommends line chart for date-formatted text values")
func lineChartDateValues() {
let table = makeTable(
columns: ["period", "sales"],
rows: [
[.text("2024-01"), .integer(100)],
[.text("2024-02"), .integer(120)],
[.text("2024-03"), .integer(90)],
[.text("2024-04"), .integer(150)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType == .line)
}
@Test("Recommends line chart for sequential numeric x-axis")
func lineChartSequential() {
let table = makeTable(
columns: ["step", "value"],
rows: [
[.integer(1), .real(2.5)],
[.integer(2), .real(3.1)],
[.integer(3), .real(4.0)],
[.integer(4), .real(3.8)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.chartType == .line)
}
// MARK: - All Recommendations
@Test("Returns multiple recommendations sorted by confidence")
func allRecommendations() {
let table = makeTable(
columns: ["category", "amount"],
rows: [
[.text("A"), .integer(30)],
[.text("B"), .integer(50)],
[.text("C"), .integer(20)],
]
)
let recs = detector.allRecommendations(for: table)
#expect(!recs.isEmpty)
// Should be sorted by confidence descending
for i in 1..<recs.count {
#expect(recs[i - 1].confidence >= recs[i].confidence)
}
}
// MARK: - Two Numeric Columns Fallback
@Test("Uses first numeric as category when no text column exists")
func numericOnlyColumns() {
let table = makeTable(
columns: ["x", "y"],
rows: [
[.integer(1), .integer(10)],
[.integer(2), .integer(20)],
[.integer(3), .integer(30)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec?.categoryColumn == "x")
#expect(rec?.valueColumn == "y")
}
// MARK: - Confidence & Reason
@Test("Confidence is between 0 and 1")
func confidenceBounds() {
let table = makeTable(
columns: ["name", "score"],
rows: [
[.text("A"), .integer(10)],
[.text("B"), .integer(20)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(rec!.confidence >= 0.0)
#expect(rec!.confidence <= 1.0)
}
@Test("Reason is non-empty")
func reasonPresent() {
let table = makeTable(
columns: ["name", "score"],
rows: [
[.text("A"), .integer(10)],
[.text("B"), .integer(20)],
]
)
let rec = detector.detect(table)
#expect(rec != nil)
#expect(!rec!.reason.isEmpty)
}
// MARK: - Custom Configuration
@Test("Respects custom minimumRows")
func customMinRows() {
let strict = ChartDataDetector(minimumRows: 5)
let table = makeTable(
columns: ["name", "value"],
rows: [
[.text("A"), .integer(1)],
[.text("B"), .integer(2)],
[.text("C"), .integer(3)],
]
)
#expect(strict.detect(table) == nil)
}
@Test("Respects custom maxPieSlices")
func customMaxPieSlices() {
let narrow = ChartDataDetector(maxPieSlices: 2)
let table = makeTable(
columns: ["status", "count"],
rows: [
[.text("A"), .integer(50)],
[.text("B"), .integer(30)],
[.text("C"), .integer(20)],
]
)
let rec = narrow.detect(table)
// With maxPieSlices=2, 3 rows should not get pie
#expect(rec?.chartType != .pie)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,170 @@
// ChatViewConfigurationTests.swift
// SwiftDBAITests
//
// Tests for ChatViewConfiguration defaults, presets, and environment propagation.
import Testing
import SwiftUI
@testable import SwiftDBAI
@Suite("ChatViewConfiguration Tests")
struct ChatViewConfigurationTests {
// MARK: - Default Values
@Test("Default configuration has expected color values")
func defaultColors() {
let config = ChatViewConfiguration.default
#expect(config.userBubbleColor == .accentColor)
#expect(config.userTextColor == .white)
#expect(config.assistantTextColor == .primary)
#expect(config.backgroundColor == .clear)
#expect(config.inputBarBackgroundColor == .clear)
#expect(config.accentColor == .accentColor)
#expect(config.errorColor == .red)
}
@Test("Default configuration has expected typography values")
func defaultTypography() {
let config = ChatViewConfiguration.default
#expect(config.messageFont == .body)
#expect(config.summaryFont == .body)
#expect(config.sqlFont == .system(.caption, design: .monospaced))
#expect(config.inputFont == .body)
}
@Test("Default configuration has expected layout values")
func defaultLayout() {
let config = ChatViewConfiguration.default
#expect(config.messagePadding == 14)
#expect(config.bubbleCornerRadius == 16)
#expect(config.showTimestamps == false)
#expect(config.showSQLDisclosure == true)
#expect(config.inputPlaceholder == "Ask about your data\u{2026}")
#expect(config.emptyStateTitle == "Ask a question about your data")
#expect(config.emptyStateSubtitle == "Try something like \"How many records are in the database?\"")
#expect(config.emptyStateIcon == "bubble.left.and.text.bubble.right")
}
// MARK: - Compact Preset
@Test("Compact preset has smaller fonts and tighter padding")
func compactPreset() {
let config = ChatViewConfiguration.compact
#expect(config.messageFont == .footnote)
#expect(config.summaryFont == .footnote)
#expect(config.sqlFont == .system(.caption2, design: .monospaced))
#expect(config.inputFont == .footnote)
#expect(config.messagePadding == 8)
#expect(config.bubbleCornerRadius == 10)
#expect(config.showTimestamps == false)
#expect(config.showSQLDisclosure == false)
}
// MARK: - Dark Preset
@Test("Dark preset has dark-themed colors")
func darkPreset() {
let config = ChatViewConfiguration.dark
#expect(config.userBubbleColor == Color(white: 0.25))
#expect(config.userTextColor == .white)
#expect(config.assistantBubbleColor == Color(white: 0.15))
#expect(config.assistantTextColor == Color(white: 0.9))
#expect(config.backgroundColor == .black)
#expect(config.inputBarBackgroundColor == Color(white: 0.1))
#expect(config.accentColor == .blue)
#expect(config.errorColor == Color(red: 1.0, green: 0.4, blue: 0.4))
}
// MARK: - Mutability
@Test("Configuration properties can be mutated individually")
func mutateProperties() {
var config = ChatViewConfiguration.default
config.userBubbleColor = .purple
config.inputPlaceholder = "Ask about your recipes..."
config.bubbleCornerRadius = 20
config.showTimestamps = true
#expect(config.userBubbleColor == .purple)
#expect(config.inputPlaceholder == "Ask about your recipes...")
#expect(config.bubbleCornerRadius == 20)
#expect(config.showTimestamps == true)
// Other properties remain at defaults
#expect(config.userTextColor == .white)
#expect(config.messageFont == .body)
}
// MARK: - All Public Properties Accessible
@Test("All public properties are readable and writable")
func allPropertiesAccessible() {
var config = ChatViewConfiguration.default
// Colors
_ = config.userBubbleColor
_ = config.userTextColor
_ = config.assistantBubbleColor
_ = config.assistantTextColor
_ = config.backgroundColor
_ = config.inputBarBackgroundColor
_ = config.accentColor
_ = config.errorColor
// Typography
_ = config.messageFont
_ = config.summaryFont
_ = config.sqlFont
_ = config.inputFont
// Layout
_ = config.messagePadding
_ = config.bubbleCornerRadius
_ = config.showTimestamps
_ = config.showSQLDisclosure
_ = config.inputPlaceholder
_ = config.emptyStateTitle
_ = config.emptyStateSubtitle
_ = config.emptyStateIcon
// Verify write access compiles (set and read back)
config.userBubbleColor = .green
#expect(config.userBubbleColor == .green)
config.emptyStateIcon = "star"
#expect(config.emptyStateIcon == "star")
}
// MARK: - Presets Are Static
@Test("Static presets are available as expected")
func staticPresets() {
let _ = ChatViewConfiguration.default
let _ = ChatViewConfiguration.compact
let _ = ChatViewConfiguration.dark
}
// MARK: - Sendable Conformance
@Test("Configuration is Sendable")
func sendableConformance() async {
let config = ChatViewConfiguration.default
// Verify Sendable by passing across isolation boundary
let result: ChatViewConfiguration = await Task.detached {
return config
}.value
#expect(result.bubbleCornerRadius == config.bubbleCornerRadius)
}
// MARK: - Environment Propagation
@Test("Environment key default value matches ChatViewConfiguration.default")
func environmentKeyDefault() {
let defaultConfig = ChatViewConfiguration.default
let envDefault = ChatViewConfigurationKey.defaultValue
#expect(defaultConfig.bubbleCornerRadius == envDefault.bubbleCornerRadius)
#expect(defaultConfig.messagePadding == envDefault.messagePadding)
#expect(defaultConfig.showSQLDisclosure == envDefault.showSQLDisclosure)
#expect(defaultConfig.inputPlaceholder == envDefault.inputPlaceholder)
}
}

View File

@@ -0,0 +1,164 @@
// ChatViewTests.swift
// SwiftDBAITests
//
// Tests for ChatView, ChatViewModel, and MessageBubbleView integration
// with ScrollableDataTableView.
import Testing
import Foundation
@testable import SwiftDBAI
@Suite("SchemaReadiness Tests")
struct SchemaReadinessTests {
@Test("SchemaReadiness isReady returns true only for ready state")
func isReadyProperty() {
#expect(SchemaReadiness.idle.isReady == false)
#expect(SchemaReadiness.loading.isReady == false)
#expect(SchemaReadiness.ready(tableCount: 3).isReady == true)
#expect(SchemaReadiness.failed("error").isReady == false)
}
}
@Suite("ChatViewModel Tests")
struct ChatViewModelTests {
@Test("Messages with query results produce DataTable-compatible data")
func messageWithQueryResultHasTableData() {
// A ChatMessage with a queryResult should have the data needed
// for ScrollableDataTableView rendering
let result = QueryResult(
columns: ["id", "name", "score"],
rows: [
["id": .integer(1), "name": .text("Alice"), "score": .real(95.5)],
["id": .integer(2), "name": .text("Bob"), "score": .real(87.3)],
],
sql: "SELECT id, name, score FROM users",
executionTime: 0.01
)
let message = ChatMessage(
role: .assistant,
content: "Found 2 users.",
queryResult: result,
sql: "SELECT id, name, score FROM users"
)
// Verify queryResult is present and can be converted to DataTable
#expect(message.queryResult != nil)
#expect(message.queryResult!.columns.count == 3)
#expect(message.queryResult!.rows.count == 2)
// Verify DataTable conversion works (this is what MessageBubbleView does)
let dataTable = DataTable(message.queryResult!)
#expect(dataTable.columnCount == 3)
#expect(dataTable.rowCount == 2)
#expect(dataTable.columns[0].name == "id")
#expect(dataTable.columns[1].name == "name")
#expect(dataTable.columns[2].name == "score")
}
@Test("Messages without query results do not trigger table rendering")
func messageWithoutQueryResult() {
let message = ChatMessage(
role: .assistant,
content: "Hello! How can I help?",
queryResult: nil,
sql: nil
)
#expect(message.queryResult == nil)
}
@Test("Empty query results do not trigger table rendering")
func emptyQueryResult() {
let result = QueryResult(
columns: [],
rows: [],
sql: "SELECT * FROM empty_table",
executionTime: 0.001
)
let message = ChatMessage(
role: .assistant,
content: "No results found.",
queryResult: result,
sql: "SELECT * FROM empty_table"
)
// Even though queryResult exists, it has no columns/rows
// MessageBubbleView checks both conditions before showing the table
#expect(message.queryResult != nil)
#expect(message.queryResult!.columns.isEmpty)
#expect(message.queryResult!.rows.isEmpty)
}
@Test("Mutation results do not trigger table rendering")
func mutationQueryResult() {
let result = QueryResult(
columns: [],
rows: [],
sql: "INSERT INTO users (name) VALUES ('Charlie')",
executionTime: 0.005,
rowsAffected: 1
)
let message = ChatMessage(
role: .assistant,
content: "Successfully inserted 1 row.",
queryResult: result,
sql: "INSERT INTO users (name) VALUES ('Charlie')"
)
// Mutation results have empty columns no table shown
#expect(message.queryResult!.columns.isEmpty)
}
@Test("Error messages never have query results")
func errorMessageHasNoQueryResult() {
let message = ChatMessage(
role: .error,
content: "SELECT operations are not allowed."
)
#expect(message.queryResult == nil)
#expect(message.role == .error)
}
@Test("DataTable preserves column order from QueryResult")
func dataTableColumnOrder() {
let result = QueryResult(
columns: ["date", "revenue", "category"],
rows: [
["date": .text("2024-01-01"), "revenue": .real(1500.0), "category": .text("Electronics")],
],
sql: "SELECT date, revenue, category FROM sales",
executionTime: 0.02
)
let dataTable = DataTable(result)
#expect(dataTable.columnNames == ["date", "revenue", "category"])
}
@Test("Large result sets are renderable as DataTable")
func largeResultSet() {
var rows: [[String: QueryResult.Value]] = []
for i in 0..<500 {
rows.append([
"id": .integer(Int64(i)),
"value": .real(Double(i) * 1.5),
])
}
let result = QueryResult(
columns: ["id", "value"],
rows: rows,
sql: "SELECT id, value FROM big_table",
executionTime: 0.15
)
let dataTable = DataTable(result)
#expect(dataTable.rowCount == 500)
#expect(dataTable.columnCount == 2)
}
}

View File

@@ -0,0 +1,136 @@
// DataChatViewUsageTests.swift
// SwiftDBAITests
//
// Proves DataChatView works with minimal setup under 10 lines of code.
// A developer only needs a GRDB connection and a LanguageModel to get a
// full chat-with-database SwiftUI view.
import Testing
import Foundation
import GRDB
@testable import SwiftDBAI
// MARK: - Minimal Setup: DataChatView in Under 10 Lines
/// This test suite proves the "zero_config_reads" principle:
/// A developer with an existing SQLite database can create a fully functional
/// chat UI by providing only a GRDB connection and a language model instance.
/// No schema files, no annotations, no manual configuration required.
@Suite("DataChatView Minimal Setup")
struct DataChatViewMinimalSetupTests {
//
// USAGE EXAMPLE DataChatView in 6 lines of real code
//
// import SwiftDBAI
// import GRDB
//
// let db = try DatabaseQueue(path: "mydata.sqlite")
// let model = OllamaLanguageModel(model: "llama3")
//
// var body: some View {
// DataChatView(database: db, model: model)
// }
//
/// Creates a temporary in-memory database with sample data for tests.
private static func makeSampleDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue()
try db.write { db in
try db.execute(sql: """
CREATE TABLE products (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
price REAL NOT NULL,
category TEXT
);
INSERT INTO products (name, price, category) VALUES ('Widget', 9.99, 'Hardware');
INSERT INTO products (name, price, category) VALUES ('Gadget', 24.99, 'Electronics');
INSERT INTO products (name, price, category) VALUES ('Doohickey', 4.99, 'Hardware');
""")
}
return db
}
@Test("DataChatView initializes from database + model in 2 lines")
@MainActor
func dataChatViewMinimalInit() throws {
// LINE 1: Create (or receive) a GRDB connection
let db = try Self.makeSampleDatabase()
// LINE 2: Create the view that's it!
let _ = DataChatView(database: db, model: MockLanguageModel())
// The view is ready. No schema files, no annotations, no extra config.
}
@Test("DataChatView path-based init works in 1 line given a path and model")
@MainActor
func dataChatViewPathInit() throws {
// Create a temp database file
let tempDir = FileManager.default.temporaryDirectory
let dbPath = tempDir.appendingPathComponent("test_\(UUID().uuidString).sqlite").path
let db = try DatabaseQueue(path: dbPath)
try db.write { db in
try db.execute(sql: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
}
// ONE LINE to get a full chat UI:
let _ = DataChatView(databasePath: dbPath, model: MockLanguageModel())
// Cleanup
try? FileManager.default.removeItem(atPath: dbPath)
}
@Test("ChatEngine headless usage works in 3 lines")
func chatEngineMinimalUsage() async throws {
// LINE 1: Database
let db = try Self.makeSampleDatabase()
// LINE 2: Engine
let engine = ChatEngine(database: db, model: MockLanguageModel(responseText: "SELECT COUNT(*) AS total FROM products"))
// LINE 3: Schema preparation verifies auto-introspection works
let schema = try await engine.prepareSchema()
// The engine auto-discovered the schema no manual config needed
#expect(schema.tableNames.contains("products"))
#expect(schema.tableNames.count == 1)
}
@Test("ChatViewModel works with zero configuration beyond db + model")
@MainActor
func chatViewModelMinimalUsage() async throws {
let db = try Self.makeSampleDatabase()
let engine = ChatEngine(database: db, model: MockLanguageModel())
let viewModel = ChatViewModel(engine: engine)
// Prepare triggers auto-schema-introspection
await viewModel.prepare()
#expect(viewModel.schemaReadiness.isReady)
#expect(viewModel.messages.isEmpty) // Clean slate, ready to chat
}
@Test("Default configuration is read-only (safe by default)")
@MainActor
func defaultIsReadOnly() throws {
let db = try Self.makeSampleDatabase()
// No allowlist specified defaults to .readOnly
let _ = DataChatView(database: db, model: MockLanguageModel())
// This compiles and works. SELECT-only is the safe default.
// Developer must explicitly opt in to writes:
// DataChatView(database: db, model: model, allowlist: .standard)
}
@Test("Full DataChatView with all options still under 10 lines")
@MainActor
func dataChatViewFullConfig() throws {
let db = try Self.makeSampleDatabase() // 1
let model = MockLanguageModel() // 2
let _ = DataChatView( // 3-8
database: db,
model: model,
allowlist: .readOnly,
additionalContext: "Product catalog for an e-commerce store",
maxSummaryRows: 100
)
// Even with ALL options specified, it's under 10 lines of setup.
}
}

View File

@@ -0,0 +1,285 @@
// DataTableTests.swift
// SwiftDBAITests
import Foundation
import Testing
@testable import SwiftDBAI
@Suite("DataTable")
struct DataTableTests {
// MARK: - Helpers
private func makeQueryResult(
columns: [String],
rows: [[String: QueryResult.Value]],
sql: String = "SELECT * FROM test",
executionTime: TimeInterval = 0.01
) -> QueryResult {
QueryResult(
columns: columns,
rows: rows,
sql: sql,
executionTime: executionTime
)
}
// MARK: - Basic Construction
@Test("Converts QueryResult columns and rows correctly")
func basicConversion() {
let result = makeQueryResult(
columns: ["id", "name", "score"],
rows: [
["id": .integer(1), "name": .text("Alice"), "score": .real(95.5)],
["id": .integer(2), "name": .text("Bob"), "score": .real(87.0)],
]
)
let table = DataTable(result)
#expect(table.columnCount == 3)
#expect(table.rowCount == 2)
#expect(table.columnNames == ["id", "name", "score"])
#expect(table.sql == "SELECT * FROM test")
#expect(table.executionTime == 0.01)
}
@Test("Empty result produces empty table")
func emptyResult() {
let result = makeQueryResult(columns: ["id", "name"], rows: [])
let table = DataTable(result)
#expect(table.isEmpty)
#expect(table.rowCount == 0)
#expect(table.columnCount == 2)
#expect(table.columnNames == ["id", "name"])
}
// MARK: - Subscript Access
@Test("Subscript by row and column index")
func subscriptByIndex() {
let result = makeQueryResult(
columns: ["a", "b"],
rows: [
["a": .integer(10), "b": .text("hello")],
["a": .integer(20), "b": .text("world")],
]
)
let table = DataTable(result)
#expect(table[row: 0, column: 0] == .integer(10))
#expect(table[row: 0, column: 1] == .text("hello"))
#expect(table[row: 1, column: 0] == .integer(20))
#expect(table[row: 1, column: 1] == .text("world"))
}
@Test("Subscript by row index and column name")
func subscriptByName() {
let result = makeQueryResult(
columns: ["x", "y"],
rows: [["x": .real(1.5), "y": .real(2.5)]]
)
let table = DataTable(result)
#expect(table[row: 0, column: "x"] == .real(1.5))
#expect(table[row: 0, column: "y"] == .real(2.5))
#expect(table[row: 0, column: "z"] == .null) // non-existent column
}
// MARK: - Column Data Extraction
@Test("Extract column values by index")
func columnValuesByIndex() {
let result = makeQueryResult(
columns: ["val"],
rows: [
["val": .integer(1)],
["val": .integer(2)],
["val": .integer(3)],
]
)
let table = DataTable(result)
let values = table.columnValues(at: 0)
#expect(values == [.integer(1), .integer(2), .integer(3)])
}
@Test("Extract column values by name")
func columnValuesByName() {
let result = makeQueryResult(
columns: ["name"],
rows: [
["name": .text("A")],
["name": .text("B")],
]
)
let table = DataTable(result)
#expect(table.columnValues(named: "name") == [.text("A"), .text("B")])
#expect(table.columnValues(named: "missing").isEmpty)
}
@Test("numericValues extracts doubles from numeric column")
func numericValues() {
let result = makeQueryResult(
columns: ["score"],
rows: [
["score": .integer(10)],
["score": .real(20.5)],
["score": .null],
["score": .text("not a number")],
]
)
let table = DataTable(result)
let nums = table.numericValues(forColumn: "score")
#expect(nums.count == 2)
#expect(nums[0] == 10.0)
#expect(nums[1] == 20.5)
}
@Test("stringValues extracts non-null strings")
func stringValues() {
let result = makeQueryResult(
columns: ["label"],
rows: [
["label": .text("foo")],
["label": .null],
["label": .text("bar")],
]
)
let table = DataTable(result)
let strs = table.stringValues(forColumn: "label")
#expect(strs == ["foo", "bar"])
}
// MARK: - Type Inference
@Test("Infers integer type for all-integer column")
func inferInteger() {
let result = makeQueryResult(
columns: ["id"],
rows: [["id": .integer(1)], ["id": .integer(2)]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .integer)
}
@Test("Infers real type for all-real column")
func inferReal() {
let result = makeQueryResult(
columns: ["price"],
rows: [["price": .real(1.99)], ["price": .real(2.50)]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .real)
}
@Test("Infers text type for all-text column")
func inferText() {
let result = makeQueryResult(
columns: ["name"],
rows: [["name": .text("A")], ["name": .text("B")]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .text)
}
@Test("Promotes integer + real to real")
func inferNumericPromotion() {
let result = makeQueryResult(
columns: ["val"],
rows: [["val": .integer(1)], ["val": .real(2.5)]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .real)
}
@Test("Mixed types result in .mixed")
func inferMixed() {
let result = makeQueryResult(
columns: ["data"],
rows: [["data": .integer(1)], ["data": .text("hello")]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .mixed)
}
@Test("All-null column infers .null")
func inferNull() {
let result = makeQueryResult(
columns: ["empty"],
rows: [["empty": .null], ["empty": .null]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .null)
}
@Test("Null values are ignored during type inference")
func inferIgnoresNulls() {
let result = makeQueryResult(
columns: ["val"],
rows: [["val": .integer(1)], ["val": .null], ["val": .integer(3)]]
)
let table = DataTable(result)
#expect(table.columns[0].inferredType == .integer)
}
// MARK: - Missing Values
@Test("Missing dictionary keys become .null")
func missingKeysBecomNull() {
let result = makeQueryResult(
columns: ["a", "b"],
rows: [["a": .integer(1)]] // "b" is missing
)
let table = DataTable(result)
#expect(table[row: 0, column: 0] == .integer(1))
#expect(table[row: 0, column: 1] == .null)
}
// MARK: - Row Identity
@Test("Rows have sequential IDs")
func rowIdentity() {
let result = makeQueryResult(
columns: ["x"],
rows: [["x": .integer(1)], ["x": .integer(2)], ["x": .integer(3)]]
)
let table = DataTable(result)
#expect(table.rows[0].id == 0)
#expect(table.rows[1].id == 1)
#expect(table.rows[2].id == 2)
}
// MARK: - Column Identity
@Test("Columns are Identifiable by name")
func columnIdentity() {
let result = makeQueryResult(
columns: ["alpha", "beta"],
rows: [["alpha": .integer(1), "beta": .integer(2)]]
)
let table = DataTable(result)
#expect(table.columns[0].id == "alpha")
#expect(table.columns[1].id == "beta")
#expect(table.columns[0].index == 0)
#expect(table.columns[1].index == 1)
}
}

View File

@@ -0,0 +1,317 @@
// DatabaseToolTests.swift
// SwiftDBAI
import Testing
import Foundation
import GRDB
@testable import SwiftDBAI
@Suite("DatabaseTool")
struct DatabaseToolTests {
// MARK: - Helper
/// Creates an in-memory database with sample data for testing.
private func makeTestDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue(configuration: {
var config = Configuration()
config.foreignKeysEnabled = true
return config
}())
try db.write { db in
try db.execute(sql: """
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT UNIQUE
);
""")
try db.execute(sql: """
CREATE TABLE posts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id),
title TEXT NOT NULL,
body TEXT,
created_at TEXT DEFAULT CURRENT_TIMESTAMP
);
""")
try db.execute(sql: """
CREATE INDEX idx_posts_user ON posts(user_id);
""")
// Insert sample data
try db.execute(sql: "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')")
try db.execute(sql: "INSERT INTO users (name, email) VALUES ('Bob', 'bob@example.com')")
try db.execute(sql: "INSERT INTO posts (user_id, title, body) VALUES (1, 'Hello World', 'First post')")
try db.execute(sql: "INSERT INTO posts (user_id, title, body) VALUES (1, 'Second Post', 'More content')")
try db.execute(sql: "INSERT INTO posts (user_id, title, body) VALUES (2, 'Bob Post', 'Bob writes')")
}
return db
}
// MARK: - Creation
@Test("Creates tool from database connection")
func testCreationFromDatabase() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
#expect(tool.name == "execute_sql")
#expect(!tool.description.isEmpty)
}
@Test("Creates tool from database path")
func testCreationFromPath() async throws {
let tmpDir = FileManager.default.temporaryDirectory
let dbPath = tmpDir.appendingPathComponent("test_\(UUID().uuidString).sqlite").path
defer { try? FileManager.default.removeItem(atPath: dbPath) }
// Create a database at the path
let dbQueue = try DatabaseQueue(path: dbPath)
try await dbQueue.write { db in
try db.execute(sql: "CREATE TABLE test (id INTEGER PRIMARY KEY)")
}
let tool = try await DatabaseTool(databasePath: dbPath)
#expect(tool.name == "execute_sql")
}
// MARK: - Schema Context
@Test("Schema context contains table info")
func testSchemaContext() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let context = tool.schemaContext
#expect(context.contains("users"))
#expect(context.contains("posts"))
#expect(context.contains("name"))
#expect(context.contains("email"))
}
@Test("System prompt snippet contains schema")
func testSystemPromptSnippet() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let snippet = tool.systemPromptSnippet
#expect(snippet.contains("users"))
#expect(snippet.contains("posts"))
#expect(snippet.contains("execute_sql"))
#expect(snippet.contains("SELECT"))
}
// MARK: - SQL Execution
@Test("Executes valid SELECT query")
func testExecuteSelect() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let result = try tool.execute(sql: "SELECT name, email FROM users ORDER BY name")
#expect(result.rowCount == 2)
#expect(result.columns == ["name", "email"])
#expect(result.rows[0]["name"] == "Alice")
#expect(result.rows[1]["name"] == "Bob")
}
@Test("Executes query with JOIN")
func testExecuteJoin() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let result = try tool.execute(sql: """
SELECT u.name, COUNT(p.id) as post_count
FROM users u
JOIN posts p ON p.user_id = u.id
GROUP BY u.name
ORDER BY u.name
""")
#expect(result.rowCount == 2)
#expect(result.rows[0]["name"] == "Alice")
#expect(result.rows[0]["post_count"] == "2")
}
@Test("Rejects INSERT with read-only allowlist")
func testRejectInsert() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db, allowlist: .readOnly)
#expect(throws: SQLParsingError.self) {
try tool.execute(sql: "INSERT INTO users (name, email) VALUES ('Eve', 'eve@example.com')")
}
}
@Test("Rejects DELETE with read-only allowlist")
func testRejectDelete() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db, allowlist: .readOnly)
#expect(throws: SQLParsingError.self) {
try tool.execute(sql: "DELETE FROM users WHERE id = 1")
}
}
@Test("Rejects DROP as dangerous operation")
func testRejectDrop() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db, allowlist: .unrestricted)
#expect(throws: SQLParsingError.self) {
try tool.execute(sql: "DROP TABLE users")
}
}
@Test("Executes raw query returning QueryResult")
func testExecuteRaw() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let result = try tool.executeRaw(sql: "SELECT COUNT(*) as cnt FROM users")
#expect(result.rowCount == 1)
#expect(result.columns == ["cnt"])
if case .integer(let count) = result.rows[0]["cnt"] {
#expect(count == 2)
} else {
Issue.record("Expected integer value")
}
}
// MARK: - ToolResult Formatting
@Test("ToolResult JSON serialization")
func testToolResultJSON() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let result = try tool.execute(sql: "SELECT name FROM users ORDER BY name LIMIT 1")
let json = result.jsonString
#expect(json.contains("\"columns\""))
#expect(json.contains("\"rows\""))
#expect(json.contains("\"row_count\""))
#expect(json.contains("Alice"))
// Verify it is valid JSON
let data = json.data(using: .utf8)!
let parsed = try JSONSerialization.jsonObject(with: data) as! [String: Any]
#expect(parsed["row_count"] as? Int == 1)
}
@Test("ToolResult markdown table formatting")
func testToolResultMarkdown() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let result = try tool.execute(sql: "SELECT name, email FROM users ORDER BY name")
let md = result.markdownTable
#expect(md.contains("| name | email |"))
#expect(md.contains("| --- | --- |"))
#expect(md.contains("| Alice | alice@example.com |"))
#expect(md.contains("| Bob | bob@example.com |"))
}
@Test("ToolResult markdown table with empty result")
func testToolResultMarkdownEmpty() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let result = try tool.execute(sql: "SELECT name FROM users WHERE name = 'Nobody'")
#expect(result.markdownTable == "_No results._")
}
@Test("ToolResult text summary")
func testToolResultTextSummary() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let result = try tool.execute(sql: "SELECT name FROM users")
let summary = result.textSummary
#expect(summary.contains("2 rows"))
#expect(summary.contains("name"))
}
@Test("ToolResult text summary with empty result")
func testToolResultTextSummaryEmpty() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let result = try tool.execute(sql: "SELECT name FROM users WHERE 1 = 0")
#expect(result.textSummary.contains("no results"))
}
// MARK: - Parameters Schema
@Test("Parameters schema has correct structure")
func testParametersSchema() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let schema = tool.parametersSchema
#expect(schema["type"] as? String == "object")
let properties = schema["properties"] as? [String: Any]
#expect(properties != nil)
let sqlProp = properties?["sql"] as? [String: Any]
#expect(sqlProp?["type"] as? String == "string")
let required = schema["required"] as? [String]
#expect(required == ["sql"])
}
// MARK: - OpenAI Function Definition
@Test("OpenAI function definition has correct format")
func testOpenAIFunctionDefinition() async throws {
let db = try makeTestDatabase()
let tool = try await DatabaseTool(database: db)
let def = tool.openAIFunctionDefinition
#expect(def["type"] as? String == "function")
let function = def["function"] as? [String: Any]
#expect(function?["name"] as? String == "execute_sql")
#expect(function?["description"] as? String != nil)
#expect(function?["parameters"] as? [String: Any] != nil)
// Verify it can be serialized to JSON
let data = try JSONSerialization.data(withJSONObject: def)
#expect(data.count > 0)
}
// MARK: - ToolResult Codable
@Test("ToolResult is Codable")
func testToolResultCodable() throws {
let result = ToolResult(
columns: ["name", "age"],
rows: [["name": "Alice", "age": "30"]],
rowCount: 1,
executionTime: 0.005,
sql: "SELECT name, age FROM users"
)
let encoder = JSONEncoder()
let data = try encoder.encode(result)
let decoder = JSONDecoder()
let decoded = try decoder.decode(ToolResult.self, from: data)
#expect(decoded.columns == result.columns)
#expect(decoded.rows == result.rows)
#expect(decoded.rowCount == result.rowCount)
#expect(decoded.sql == result.sql)
}
}

View File

@@ -0,0 +1,745 @@
// DestructiveOperationTests.swift
// SwiftDBAITests
//
// Tests verifying that destructive operations are blocked without confirmation
// and allowed when the delegate approves.
import AnyLanguageModel
import Foundation
import GRDB
import Testing
@testable import SwiftDBAI
// MARK: - Test Delegates
/// A delegate that always rejects destructive operations and tracks calls.
private final class RejectingTrackingDelegate: SwiftDBAI.ToolExecutionDelegate, @unchecked Sendable {
private let lock = NSLock()
private var _confirmCalls: [DestructiveOperationContext] = []
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
var confirmCalls: [DestructiveOperationContext] {
lock.withLock { _confirmCalls }
}
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
lock.withLock { _willExecuteCalls }
}
var didExecuteCalls: [(sql: String, success: Bool)] {
lock.withLock { _didExecuteCalls }
}
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
lock.withLock { _confirmCalls.append(context) }
return false
}
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
}
func didExecuteSQL(_ sql: String, success: Bool) async {
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
}
}
/// A delegate that always approves destructive operations and tracks calls.
private final class ApprovingTrackingDelegate: SwiftDBAI.ToolExecutionDelegate, @unchecked Sendable {
private let lock = NSLock()
private var _confirmCalls: [DestructiveOperationContext] = []
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
var confirmCalls: [DestructiveOperationContext] {
lock.withLock { _confirmCalls }
}
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
lock.withLock { _willExecuteCalls }
}
var didExecuteCalls: [(sql: String, success: Bool)] {
lock.withLock { _didExecuteCalls }
}
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
lock.withLock { _confirmCalls.append(context) }
return true
}
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
}
func didExecuteSQL(_ sql: String, success: Bool) async {
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
}
}
// MARK: - Helpers
/// Creates an in-memory database with test data for destructive operation tests.
/// Users 1 and 2 have orders; user 3 has no orders (safe to delete).
private func makeTestDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue(path: ":memory:")
try db.write { db in
// Disable FK enforcement for test flexibility, then re-enable
try db.execute(sql: "PRAGMA foreign_keys = OFF")
try db.execute(sql: """
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT NOT NULL
)
""")
try db.execute(sql: """
INSERT INTO users (name, email) VALUES
('Alice', 'alice@example.com'),
('Bob', 'bob@example.com'),
('Charlie', 'charlie@example.com')
""")
try db.execute(sql: """
CREATE TABLE orders (
id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
amount REAL NOT NULL
)
""")
try db.execute(sql: """
INSERT INTO orders (user_id, amount) VALUES
(1, 99.99),
(2, 150.00),
(3, 25.50)
""")
}
return db
}
/// A sequential mock model for tests. Returns responses in order.
private struct TestSequentialModel: LanguageModel {
typealias UnavailableReason = Never
let responses: [String]
private let callCounter = CallCounter()
private final class CallCounter: @unchecked Sendable {
var count = 0
let lock = NSLock()
func next() -> Int {
lock.lock()
defer { lock.unlock() }
let c = count
count += 1
return c
}
}
init(responses: [String]) {
self.responses = responses
}
func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
let idx = callCounter.next()
let text = idx < responses.count ? responses[idx] : "fallback response"
let rawContent = GeneratedContent(kind: .string(text))
let content = try Content(rawContent)
return LanguageModelSession.Response(
content: content,
rawContent: rawContent,
transcriptEntries: [][...]
)
}
func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
let idx = callCounter.next()
let text = idx < responses.count ? responses[idx] : "fallback response"
let rawContent = GeneratedContent(kind: .string(text))
let content = try! Content(rawContent)
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
}
}
// MARK: - Tests: Destructive Operations Blocked Without Confirmation
@Suite("Destructive Operations - Blocked Without Confirmation")
struct DestructiveOperationsBlockedTests {
@Test("DELETE is blocked when no delegate is provided")
func deleteBlockedWithoutDelegate() async throws {
let db = try makeTestDatabase()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 1"
])
// Unrestricted allowlist permits DELETE, but no delegate to confirm
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted
)
do {
_ = try await engine.send("Delete user 1")
Issue.record("Expected confirmationRequired error but send succeeded")
} catch let error as SwiftDBAIError {
guard case .confirmationRequired(let sql, let operation) = error else {
Issue.record("Expected confirmationRequired, got: \(error)")
return
}
#expect(sql.uppercased().contains("DELETE"))
#expect(operation == "delete")
}
// Verify the user was NOT deleted (data remains intact)
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
}
#expect(count == 1, "User should NOT have been deleted")
}
@Test("DELETE is blocked when delegate rejects")
func deleteBlockedWhenDelegateRejects() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 2"
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
do {
_ = try await engine.send("Delete user 2")
Issue.record("Expected confirmationRequired error but send succeeded")
} catch let error as SwiftDBAIError {
guard case .confirmationRequired(let sql, let operation) = error else {
Issue.record("Expected confirmationRequired, got: \(error)")
return
}
#expect(sql.uppercased().contains("DELETE"))
#expect(operation == "delete")
}
// Verify delegate was consulted
#expect(delegate.confirmCalls.count == 1)
#expect(delegate.confirmCalls[0].statementKind == .delete)
#expect(delegate.confirmCalls[0].sql.uppercased().contains("DELETE"))
// Verify the data was NOT modified
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 2")
}
#expect(count == 1, "User should NOT have been deleted")
// Verify no SQL was actually executed (no willExecute/didExecute calls)
#expect(delegate.willExecuteCalls.isEmpty, "No SQL should have been executed")
#expect(delegate.didExecuteCalls.isEmpty, "No SQL should have been executed")
}
@Test("DELETE is blocked with MutationPolicy and no delegate")
func deleteBlockedWithMutationPolicyNoDelegate() async throws {
let db = try makeTestDatabase()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 3"
])
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
requiresDestructiveConfirmation: true
)
let engine = ChatEngine(
database: db,
model: model,
mutationPolicy: policy
)
do {
_ = try await engine.send("Delete user 3")
Issue.record("Expected confirmationRequired error but send succeeded")
} catch let error as SwiftDBAIError {
guard case .confirmationRequired(let sql, let operation) = error else {
Issue.record("Expected confirmationRequired, got: \(error)")
return
}
#expect(sql.uppercased().contains("DELETE"))
#expect(operation == "delete")
}
// Data intact
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 3")
}
#expect(count == 1, "User should NOT have been deleted")
}
@Test("DELETE is blocked with MutationPolicy and rejecting delegate")
func deleteBlockedWithMutationPolicyRejectingDelegate() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM orders WHERE user_id = 1"
])
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
requiresDestructiveConfirmation: true
)
let engine = ChatEngine(
database: db,
model: model,
mutationPolicy: policy,
delegate: delegate
)
do {
_ = try await engine.send("Delete all orders for user 1")
Issue.record("Expected confirmationRequired error but send succeeded")
} catch let error as SwiftDBAIError {
guard case .confirmationRequired = error else {
Issue.record("Expected confirmationRequired, got: \(error)")
return
}
}
// Delegate was consulted and rejected
#expect(delegate.confirmCalls.count == 1)
#expect(delegate.confirmCalls[0].statementKind == .delete)
// Orders remain
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM orders WHERE user_id = 1")
}
#expect(count == 1, "Orders should NOT have been deleted")
}
@Test("Default delegate implementation rejects destructive operations")
func defaultDelegateRejectsDestructive() async {
struct DefaultDelegate: SwiftDBAI.ToolExecutionDelegate {}
let delegate = DefaultDelegate()
let context = DestructiveOperationContext(
sql: "DELETE FROM users WHERE id = 1",
statementKind: .delete,
classification: .destructive(.delete),
description: "Delete from users"
)
let approved = await delegate.confirmDestructiveOperation(context)
#expect(approved == false, "Default delegate should reject destructive operations")
}
@Test("DELETE not in readOnly allowlist is rejected before delegate is consulted")
func deleteNotInAllowlistRejectedEarly() async throws {
let db = try makeTestDatabase()
let delegate = ApprovingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 1"
])
// Read-only allowlist does NOT include DELETE
let engine = ChatEngine(
database: db,
model: model,
allowlist: .readOnly,
delegate: delegate
)
do {
_ = try await engine.send("Delete user 1")
Issue.record("Expected operationNotAllowed error")
} catch let error as SwiftDBAIError {
guard case .operationNotAllowed(let operation) = error else {
Issue.record("Expected operationNotAllowed, got: \(error)")
return
}
#expect(operation == "delete")
}
// Delegate should NOT have been consulted the allowlist rejects before delegation
#expect(delegate.confirmCalls.isEmpty, "Delegate should not be consulted when op is not in allowlist")
}
}
// MARK: - Tests: Destructive Operations Allowed When Delegate Approves
@Suite("Destructive Operations - Allowed When Delegate Approves")
struct DestructiveOperationsAllowedTests {
@Test("DELETE succeeds when delegate approves")
func deleteSucceedsWithApprovingDelegate() async throws {
let db = try makeTestDatabase()
let delegate = ApprovingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 1",
"Successfully deleted 1 user."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
let response = try await engine.send("Delete user 1")
// Delegate was consulted and approved
#expect(delegate.confirmCalls.count == 1)
#expect(delegate.confirmCalls[0].statementKind == .delete)
#expect(delegate.confirmCalls[0].sql.uppercased().contains("DELETE"))
#expect(delegate.confirmCalls[0].targetTable == "users")
// SQL was executed
#expect(delegate.willExecuteCalls.count == 1)
#expect(delegate.didExecuteCalls.count == 1)
#expect(delegate.didExecuteCalls[0].success == true)
// Verify the data was actually deleted
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
}
#expect(count == 0, "User should have been deleted")
// Response should contain meaningful content
#expect(response.sql?.uppercased().contains("DELETE") == true)
#expect(response.queryResult != nil)
}
@Test("DELETE with MutationPolicy succeeds when delegate approves")
func deleteWithPolicySucceedsWhenApproved() async throws {
let db = try makeTestDatabase()
let delegate = ApprovingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM orders WHERE user_id = 2",
"Deleted 1 order."
])
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
requiresDestructiveConfirmation: true
)
let engine = ChatEngine(
database: db,
model: model,
mutationPolicy: policy,
delegate: delegate
)
let response = try await engine.send("Delete all orders for user 2")
// Delegate approved
#expect(delegate.confirmCalls.count == 1)
// Data was actually deleted
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM orders WHERE user_id = 2")
}
#expect(count == 0, "Orders should have been deleted")
#expect(response.sql?.uppercased().contains("DELETE") == true)
}
@Test("AutoApproveDelegate allows DELETE without user interaction")
func autoApproveDelegateAllowsDelete() async throws {
let db = try makeTestDatabase()
let delegate = AutoApproveDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 3",
"Deleted 1 user."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
let response = try await engine.send("Delete user 3")
// Should succeed without error
#expect(response.sql?.uppercased().contains("DELETE") == true)
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 3")
}
#expect(count == 0, "User should have been deleted")
}
@Test("sendConfirmed bypasses delegate and executes directly")
func sendConfirmedBypassesDelegate() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"Deleted 1 user."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
// sendConfirmed should execute directly without consulting the delegate for confirmation
let response = try await engine.sendConfirmed(
"Delete user 1",
confirmedSQL: "DELETE FROM users WHERE id = 1"
)
// Delegate was NOT asked to confirm (sendConfirmed skips confirmation)
#expect(delegate.confirmCalls.isEmpty)
// But willExecute/didExecute hooks were still called
#expect(delegate.willExecuteCalls.count == 1)
#expect(delegate.didExecuteCalls.count == 1)
#expect(delegate.didExecuteCalls[0].success == true)
// Data was deleted
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
}
#expect(count == 0)
#expect(response.summary.contains("deleted") || response.summary.contains("Deleted") || response.summary.contains("1"))
}
}
// MARK: - Tests: Delegate Context Correctness
@Suite("Destructive Operations - Delegate Context")
struct DestructiveOperationContextTests {
@Test("Delegate receives correct context for DELETE on specific table")
func delegateReceivesCorrectContext() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM orders WHERE amount < 50"
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
do {
_ = try await engine.send("Delete cheap orders")
Issue.record("Expected confirmationRequired error")
} catch is SwiftDBAIError {
// Expected
}
#expect(delegate.confirmCalls.count == 1)
let ctx = delegate.confirmCalls[0]
#expect(ctx.statementKind == .delete)
#expect(ctx.classification == .destructive(.delete))
#expect(ctx.classification.requiresConfirmation == true)
#expect(ctx.sql.uppercased().contains("DELETE FROM ORDERS"))
#expect(ctx.targetTable == "orders")
#expect(!ctx.description.isEmpty)
}
@Test("Non-destructive operations do not consult delegate")
func selectDoesNotConsultDelegate() async throws {
let db = try makeTestDatabase()
let delegate = ApprovingTrackingDelegate()
let model = TestSequentialModel(responses: [
"SELECT COUNT(*) FROM users",
"There are 3 users."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .unrestricted,
delegate: delegate
)
_ = try await engine.send("How many users?")
// Delegate should NOT have been asked to confirm (SELECT is not destructive)
#expect(delegate.confirmCalls.isEmpty)
// But willExecute/didExecute should still be called (observation hooks)
#expect(delegate.willExecuteCalls.count == 1)
#expect(delegate.didExecuteCalls.count == 1)
}
@Test("INSERT does not require confirmation even with delegate")
func insertDoesNotRequireConfirmation() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"INSERT INTO users (name, email) VALUES ('Dave', 'dave@example.com')",
"Inserted 1 row."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .standard,
delegate: delegate
)
let response = try await engine.send("Add user Dave")
// No confirmation needed for INSERT
#expect(delegate.confirmCalls.isEmpty)
#expect(response.sql?.uppercased().contains("INSERT") == true)
// Verify the insert happened
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE name = 'Dave'")
}
#expect(count == 1)
}
@Test("UPDATE does not require confirmation even with delegate")
func updateDoesNotRequireConfirmation() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"UPDATE users SET email = 'alice-new@example.com' WHERE id = 1",
"Updated 1 row."
])
let engine = ChatEngine(
database: db,
model: model,
allowlist: .standard,
delegate: delegate
)
let response = try await engine.send("Update Alice's email")
// No confirmation needed for UPDATE
#expect(delegate.confirmCalls.isEmpty)
#expect(response.sql?.uppercased().contains("UPDATE") == true)
}
}
// MARK: - Tests: MutationPolicy Confirmation Flag
@Suite("Destructive Operations - MutationPolicy Confirmation Control")
struct MutationPolicyConfirmationTests {
@Test("DELETE skips confirmation when requiresDestructiveConfirmation is false")
func deleteSkipsConfirmationWhenDisabled() async throws {
let db = try makeTestDatabase()
let delegate = RejectingTrackingDelegate()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 1",
"Deleted 1 user."
])
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
requiresDestructiveConfirmation: false // Explicitly disabled
)
let engine = ChatEngine(
database: db,
model: model,
mutationPolicy: policy,
delegate: delegate
)
// Should succeed without confirmation since the policy disables it
let response = try await engine.send("Delete user 1")
// Delegate should NOT have been consulted for confirmation
#expect(delegate.confirmCalls.isEmpty)
// But the SQL should have executed
#expect(response.sql?.uppercased().contains("DELETE") == true)
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
}
#expect(count == 0, "User should have been deleted without confirmation")
}
@Test("MutationPolicy.requiresConfirmation only triggers for DELETE")
func requiresConfirmationOnlyForDelete() {
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
requiresDestructiveConfirmation: true
)
#expect(policy.requiresConfirmation(for: .delete) == true)
#expect(policy.requiresConfirmation(for: .select) == false)
#expect(policy.requiresConfirmation(for: .insert) == false)
#expect(policy.requiresConfirmation(for: .update) == false)
}
@Test("MutationPolicy.readOnly never requires confirmation (no delete allowed)")
func readOnlyNeverRequiresConfirmation() {
let policy = MutationPolicy.readOnly
#expect(policy.requiresConfirmation(for: .select) == false)
#expect(policy.requiresConfirmation(for: .delete) == true) // Would require confirmation IF allowed
#expect(policy.isOperationAllowed(.delete) == false) // But it's not allowed at all
}
@Test("Table-restricted DELETE is blocked for disallowed tables")
func tableRestrictedDeleteBlocked() async throws {
let db = try makeTestDatabase()
let model = TestSequentialModel(responses: [
"DELETE FROM users WHERE id = 1"
])
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
allowedTables: ["orders"], // Only orders, NOT users
requiresDestructiveConfirmation: true
)
let engine = ChatEngine(
database: db,
model: model,
mutationPolicy: policy
)
do {
_ = try await engine.send("Delete user 1")
Issue.record("Expected tableNotAllowedForMutation error")
} catch let error as SwiftDBAIError {
guard case .tableNotAllowedForMutation(let tableName, let operation) = error else {
Issue.record("Expected tableNotAllowedForMutation, got: \(error)")
return
}
#expect(tableName == "users")
#expect(operation == "delete")
}
// User was not deleted
let count = try await db.read { db in
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
}
#expect(count == 1)
}
}

View File

@@ -0,0 +1,49 @@
// MockLanguageModel.swift
// SwiftDBAI Tests
//
// A mock LanguageModel for unit tests that returns canned responses.
import AnyLanguageModel
import Foundation
/// A mock language model that returns a configurable canned response.
///
/// Used in tests to avoid hitting a real LLM provider.
struct MockLanguageModel: LanguageModel {
typealias UnavailableReason = Never
/// The text the mock will return from `respond(...)`.
let responseText: String
init(responseText: String = "Mock summary response.") {
self.responseText = responseText
}
func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
let rawContent = GeneratedContent(kind: .string(responseText))
let content = try Content(rawContent)
return LanguageModelSession.Response(
content: content,
rawContent: rawContent,
transcriptEntries: [][...]
)
}
func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
let rawContent = GeneratedContent(kind: .string(responseText))
let content = try! Content(rawContent)
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
}
}

View File

@@ -0,0 +1,337 @@
// LocalProviderConfigurationTests.swift
// SwiftDBAI Tests
//
// Tests for local/self-hosted provider configurations (Ollama, llama.cpp):
// factory methods, endpoint discovery, connection handling, and model creation.
import AnyLanguageModel
import Foundation
import GRDB
@testable import SwiftDBAI
import Testing
@Suite("Local Provider Configuration")
struct LocalProviderConfigurationTests {
// MARK: - Ollama Configuration
@Test("Ollama configuration stores provider and model")
func ollamaBasicConfiguration() {
let config = ProviderConfiguration.ollama(model: "llama3.2")
#expect(config.provider == .ollama)
#expect(config.model == "llama3.2")
#expect(config.baseURL == OllamaLanguageModel.defaultBaseURL)
}
@Test("Ollama configuration produces OllamaLanguageModel")
func ollamaMakeModel() {
let config = ProviderConfiguration.ollama(model: "qwen2.5")
let model = config.makeModel()
#expect(model is OllamaLanguageModel)
}
@Test("Ollama with custom base URL for remote instance")
func ollamaCustomBaseURL() {
let remoteURL = URL(string: "http://192.168.1.100:11434")!
let config = ProviderConfiguration.ollama(
model: "mistral",
baseURL: remoteURL
)
#expect(config.baseURL == remoteURL)
#expect(config.provider == .ollama)
let model = config.makeModel()
#expect(model is OllamaLanguageModel)
}
@Test("Ollama does not require an API key")
func ollamaNoAPIKey() {
let config = ProviderConfiguration.ollama(model: "llama3.2")
// Ollama doesn't need an API key, so the key is empty
#expect(config.apiKey == "")
// hasValidAPIKey returns false because key is empty, but that's expected
// for local providers they don't need authentication
#expect(!config.hasValidAPIKey)
}
@Test("Ollama model is available without API key")
func ollamaModelAvailable() {
let config = ProviderConfiguration.ollama(model: "llama3.2")
let model = config.makeModel()
#expect(model.isAvailable)
}
// MARK: - llama.cpp Configuration
@Test("llama.cpp configuration stores provider and model")
func llamaCppBasicConfiguration() {
let config = ProviderConfiguration.llamaCpp(model: "my-model")
#expect(config.provider == .llamaCpp)
#expect(config.model == "my-model")
#expect(config.baseURL == LocalProviderDiscovery.defaultLlamaCppURL)
}
@Test("llama.cpp uses 'default' model name by default")
func llamaCppDefaultModel() {
let config = ProviderConfiguration.llamaCpp()
#expect(config.model == "default")
}
@Test("llama.cpp configuration produces OpenAILanguageModel (compatible API)")
func llamaCppMakeModel() {
let config = ProviderConfiguration.llamaCpp(model: "my-gguf")
let model = config.makeModel()
// llama.cpp uses OpenAI-compatible API
#expect(model is OpenAILanguageModel)
}
@Test("llama.cpp with custom base URL")
func llamaCppCustomBaseURL() {
let customURL = URL(string: "http://localhost:9090")!
let config = ProviderConfiguration.llamaCpp(
model: "custom-model",
baseURL: customURL
)
#expect(config.baseURL == customURL)
let model = config.makeModel()
#expect(model is OpenAILanguageModel)
}
@Test("llama.cpp with API key authentication")
func llamaCppWithAPIKey() {
let config = ProviderConfiguration.llamaCpp(
model: "secured-model",
apiKey: "my-secret-key"
)
#expect(config.apiKey == "my-secret-key")
#expect(config.hasValidAPIKey)
}
@Test("llama.cpp without API key")
func llamaCppNoAPIKey() {
let config = ProviderConfiguration.llamaCpp(model: "open-model")
#expect(config.apiKey == "")
}
// MARK: - Provider Enum
@Test("Provider enum includes ollama and llamaCpp cases")
func providerEnumHasLocalCases() {
let cases = ProviderConfiguration.Provider.allCases
#expect(cases.contains(.ollama))
#expect(cases.contains(.llamaCpp))
// Total: openAI, anthropic, gemini, openAICompatible, ollama, llamaCpp
#expect(cases.count == 6)
}
// MARK: - fromEnvironment
@Test("fromEnvironment creates Ollama configuration")
func fromEnvironmentOllama() {
let config = ProviderConfiguration.fromEnvironment(
provider: .ollama,
environmentVariable: "NONEXISTENT_OLLAMA_KEY",
model: "llama3.2"
)
#expect(config.provider == .ollama)
#expect(config.model == "llama3.2")
}
@Test("fromEnvironment creates llama.cpp configuration")
func fromEnvironmentLlamaCpp() {
let config = ProviderConfiguration.fromEnvironment(
provider: .llamaCpp,
environmentVariable: "NONEXISTENT_LLAMACPP_KEY",
model: "default"
)
#expect(config.provider == .llamaCpp)
#expect(config.model == "default")
}
// MARK: - ChatEngine Convenience Init with Local Providers
@Test("ChatEngine can be created with Ollama provider")
func chatEngineWithOllama() throws {
let dbQueue = try GRDB.DatabaseQueue()
let config = ProviderConfiguration.ollama(model: "llama3.2")
let engine = ChatEngine(database: dbQueue, provider: config)
#expect(engine.tableCount == nil) // schema not yet introspected
}
@Test("ChatEngine can be created with llama.cpp provider")
func chatEngineWithLlamaCpp() throws {
let dbQueue = try GRDB.DatabaseQueue()
let config = ProviderConfiguration.llamaCpp()
let engine = ChatEngine(database: dbQueue, provider: config)
#expect(engine.tableCount == nil)
}
// MARK: - LocalProviderType
@Test("LocalProviderType has expected raw values")
func localProviderTypeRawValues() {
#expect(LocalProviderType.ollama.rawValue == "ollama")
#expect(LocalProviderType.llamaCpp.rawValue == "llama.cpp")
}
@Test("LocalProviderType CaseIterable includes both cases")
func localProviderTypeCases() {
let cases = LocalProviderType.allCases
#expect(cases.count == 2)
#expect(cases.contains(.ollama))
#expect(cases.contains(.llamaCpp))
}
// MARK: - LocalProviderEndpoint
@Test("LocalProviderEndpoint description includes status and model count")
func endpointDescription() {
let endpoint = LocalProviderEndpoint(
baseURL: URL(string: "http://localhost:11434")!,
providerType: .ollama,
isReachable: true,
availableModels: ["llama3.2", "qwen2.5"]
)
#expect(endpoint.description.contains("ollama"))
#expect(endpoint.description.contains("reachable"))
#expect(endpoint.description.contains("2 models"))
}
@Test("LocalProviderEndpoint shows unreachable when not connected")
func endpointUnreachableDescription() {
let endpoint = LocalProviderEndpoint(
baseURL: URL(string: "http://localhost:8080")!,
providerType: .llamaCpp,
isReachable: false,
availableModels: []
)
#expect(endpoint.description.contains("unreachable"))
#expect(endpoint.description.contains("0 models"))
}
@Test("LocalProviderEndpoint equality works correctly")
func endpointEquality() {
let a = LocalProviderEndpoint(
baseURL: URL(string: "http://localhost:11434")!,
providerType: .ollama,
isReachable: true,
availableModels: ["llama3.2"]
)
let b = LocalProviderEndpoint(
baseURL: URL(string: "http://localhost:11434")!,
providerType: .ollama,
isReachable: true,
availableModels: ["llama3.2"]
)
let c = LocalProviderEndpoint(
baseURL: URL(string: "http://localhost:11434")!,
providerType: .ollama,
isReachable: false,
availableModels: []
)
#expect(a == b)
#expect(a != c)
}
// MARK: - Discovery (No Local Server Running)
@Test("Discovery returns unreachable when no server is running")
func discoveryUnreachableEndpoint() async {
// Use a port that's almost certainly not running anything
let endpoint = await LocalProviderDiscovery.discover(
providerType: .ollama,
host: "127.0.0.1",
port: 59999,
timeout: 1
)
#expect(!endpoint.isReachable)
#expect(endpoint.availableModels.isEmpty)
#expect(endpoint.providerType == .ollama)
}
@Test("isOllamaRunning returns false for unreachable endpoint")
func ollamaNotRunning() async {
let unreachableURL = URL(string: "http://127.0.0.1:59998")!
let running = await LocalProviderDiscovery.isOllamaRunning(
at: unreachableURL,
timeout: 1
)
#expect(!running)
}
@Test("isLlamaCppRunning returns false for unreachable endpoint")
func llamaCppNotRunning() async {
let unreachableURL = URL(string: "http://127.0.0.1:59997")!
let running = await LocalProviderDiscovery.isLlamaCppRunning(
at: unreachableURL,
timeout: 1
)
#expect(!running)
}
@Test("listOllamaModels returns empty for unreachable endpoint")
func ollamaModelsUnreachable() async {
let unreachableURL = URL(string: "http://127.0.0.1:59996")!
let models = await LocalProviderDiscovery.listOllamaModels(
at: unreachableURL,
timeout: 1
)
#expect(models.isEmpty)
}
@Test("listLlamaCppModels returns empty for unreachable endpoint")
func llamaCppModelsUnreachable() async {
let unreachableURL = URL(string: "http://127.0.0.1:59995")!
let models = await LocalProviderDiscovery.listLlamaCppModels(
at: unreachableURL,
timeout: 1
)
#expect(models.isEmpty)
}
@Test("discoverAll returns endpoints for both provider types")
func discoverAllReturnsAllProviders() async {
// Use very short timeout since we likely don't have servers running
let endpoints = await LocalProviderDiscovery.discoverAll(timeout: 0.5)
// Should return exactly 2 endpoints (one per well-known provider)
#expect(endpoints.count == 2)
let types = Set(endpoints.map(\.providerType))
#expect(types.contains(.ollama))
#expect(types.contains(.llamaCpp))
}
// MARK: - Default URLs
@Test("Default Ollama URL is correct")
func defaultOllamaURL() {
#expect(LocalProviderDiscovery.defaultOllamaURL.absoluteString == "http://localhost:11434")
}
@Test("Default llama.cpp URL is correct")
func defaultLlamaCppURL() {
#expect(LocalProviderDiscovery.defaultLlamaCppURL.absoluteString == "http://localhost:8080")
}
}

View File

@@ -0,0 +1,363 @@
// MultiTurnContextTests.swift
// SwiftDBAI Tests
//
// Tests verifying multi-turn conversation context follow-up queries
// correctly reference the prior query's table, columns, and results.
import AnyLanguageModel
import Foundation
import GRDB
import Testing
@testable import SwiftDBAI
@Suite("Multi-Turn Context Tests")
struct MultiTurnContextTests {
// MARK: - Test Database Setup
/// Creates an in-memory database with users (including age) and orders.
private func makeTestDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue(path: ":memory:")
try db.write { db in
try db.execute(sql: """
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
age INTEGER NOT NULL,
email TEXT NOT NULL,
city TEXT NOT NULL
)
""")
try db.execute(sql: """
INSERT INTO users (name, age, email, city) VALUES
('Alice', 25, 'alice@example.com', 'New York'),
('Bob', 35, 'bob@example.com', 'San Francisco'),
('Charlie', 42, 'charlie@example.com', 'New York'),
('Diana', 28, 'diana@example.com', 'Chicago'),
('Eve', 55, 'eve@example.com', 'San Francisco')
""")
try db.execute(sql: """
CREATE TABLE orders (
id INTEGER PRIMARY KEY,
user_id INTEGER NOT NULL,
amount REAL NOT NULL,
status TEXT NOT NULL,
created_at TEXT NOT NULL,
FOREIGN KEY (user_id) REFERENCES users(id)
)
""")
try db.execute(sql: """
INSERT INTO orders (user_id, amount, status, created_at) VALUES
(1, 99.99, 'completed', '2024-01-15'),
(1, 49.50, 'pending', '2024-02-20'),
(2, 150.00, 'completed', '2024-01-10'),
(3, 200.00, 'completed', '2024-03-01'),
(4, 75.00, 'cancelled', '2024-02-05')
""")
}
return db
}
// MARK: - Multi-Turn Context Tests
@Test("Follow-up 'filter those by age > 30' references prior 'show all users' context")
func followUpFilterReferencesUsersTable() async throws {
let db = try makeTestDatabase()
// Turn 1: "show all users" SELECT * FROM users (returns 5 rows, LLM summary needed)
// Turn 2: "filter those by age > 30" should reference users table from context
let mock = PromptCapturingMockModel(responses: [
"SELECT * FROM users",
"Here are all 5 users in the database.",
"SELECT * FROM users WHERE age > 30",
"Found 3 users over 30: Bob (35), Charlie (42), and Eve (55)."
])
let engine = ChatEngine(database: db, model: mock)
// First turn: show all users
let response1 = try await engine.send("show all users")
#expect(response1.sql == "SELECT * FROM users")
#expect(response1.queryResult?.rowCount == 5)
// Second turn: follow-up with implicit reference
let response2 = try await engine.send("filter those by age > 30")
#expect(response2.sql == "SELECT * FROM users WHERE age > 30")
#expect(response2.queryResult?.rowCount == 3)
// Verify the follow-up prompt includes conversation history
let prompts = mock.capturedPrompts
// Find the prompt for the second SQL generation (skip summary prompts)
let followUpSQLPrompt = prompts.first { prompt in
prompt.contains("filter those by age > 30") && prompt.contains("CONVERSATION HISTORY")
}
#expect(followUpSQLPrompt != nil, "Follow-up prompt should contain CONVERSATION HISTORY")
// The conversation history should include the prior query and its SQL
if let prompt = followUpSQLPrompt {
#expect(prompt.contains("show all users"), "History should contain prior user message")
#expect(prompt.contains("SELECT * FROM users"), "History should contain prior SQL")
#expect(prompt.contains("filter those by age > 30"), "Prompt should contain current question")
}
}
@Test("Follow-up correctly inherits table context across multiple turns")
func multipleFollowUpsInheritContext() async throws {
let db = try makeTestDatabase()
// 3-turn conversation narrowing down results
let mock = PromptCapturingMockModel(responses: [
"SELECT * FROM users",
"Here are all 5 users.",
"SELECT * FROM users WHERE city = 'New York'",
"Found 2 users in New York: Alice and Charlie.",
"SELECT * FROM users WHERE city = 'New York' AND age > 30",
"Charlie (42) is the only New York user over 30."
])
let engine = ChatEngine(database: db, model: mock)
// Turn 1
_ = try await engine.send("show all users")
// Turn 2 narrows by city
let response2 = try await engine.send("only those in New York")
#expect(response2.sql == "SELECT * FROM users WHERE city = 'New York'")
#expect(response2.queryResult?.rowCount == 2)
// Turn 3 further narrows by age
let response3 = try await engine.send("now filter by age over 30")
#expect(response3.sql == "SELECT * FROM users WHERE city = 'New York' AND age > 30")
#expect(response3.queryResult?.rowCount == 1)
// Verify third turn's prompt includes the full conversation history
let prompts = mock.capturedPrompts
let thirdTurnPrompt = prompts.last { prompt in
prompt.contains("now filter by age over 30") && prompt.contains("CONVERSATION HISTORY")
}
#expect(thirdTurnPrompt != nil)
if let prompt = thirdTurnPrompt {
// Should include both prior user messages
#expect(prompt.contains("show all users"))
#expect(prompt.contains("only those in New York"))
// Should include prior SQL
#expect(prompt.contains("SELECT * FROM users"))
#expect(prompt.contains("SELECT * FROM users WHERE city = 'New York'"))
}
}
@Test("Follow-up switching tables preserves cross-table context")
func followUpSwitchesTableWithContext() async throws {
let db = try makeTestDatabase()
// Turn 1: query users, Turn 2: ask about their orders
let mock = PromptCapturingMockModel(responses: [
"SELECT name, age FROM users WHERE age > 30",
"Found 3 users over 30.",
"SELECT o.id, u.name, o.amount, o.status FROM orders o JOIN users u ON o.user_id = u.id WHERE u.age > 30",
"Bob has a $150 completed order, Charlie has a $200 completed order."
])
let engine = ChatEngine(database: db, model: mock)
// Turn 1: users over 30
let response1 = try await engine.send("show users over 30")
#expect(response1.queryResult?.rowCount == 3)
// Turn 2: their orders references the previous result context
let response2 = try await engine.send("show their orders")
#expect(response2.sql?.contains("JOIN") == true)
// Verify the follow-up prompt contains the users context
let prompts = mock.capturedPrompts
let orderPrompt = prompts.first { prompt in
prompt.contains("show their orders") && prompt.contains("CONVERSATION HISTORY")
}
#expect(orderPrompt != nil)
if let prompt = orderPrompt {
#expect(prompt.contains("show users over 30"), "Should contain prior user message")
#expect(prompt.contains("age > 30"), "Should contain prior SQL context for table reference")
}
}
@Test("Conversation history includes SQL from prior turns for context")
func historyIncludesSQLFromPriorTurns() async throws {
let db = try makeTestDatabase()
// Both queries are aggregates no LLM summarization needed
let mock = PromptCapturingMockModel(responses: [
"SELECT COUNT(*) FROM users",
"SELECT COUNT(*) FROM users WHERE age > 30",
])
let engine = ChatEngine(database: db, model: mock)
// Turn 1
let r1 = try await engine.send("how many users are there?")
#expect(r1.sql == "SELECT COUNT(*) FROM users")
// Turn 2 references "those" implicitly
let r2 = try await engine.send("how many of those are over 30?")
#expect(r2.sql == "SELECT COUNT(*) FROM users WHERE age > 30")
// Verify engine history has all 4 messages (2 user + 2 assistant)
let messages = engine.messages
#expect(messages.count == 4)
#expect(messages[0].role == .user)
#expect(messages[0].content == "how many users are there?")
#expect(messages[1].role == .assistant)
#expect(messages[1].sql == "SELECT COUNT(*) FROM users")
#expect(messages[2].role == .user)
#expect(messages[2].content == "how many of those are over 30?")
#expect(messages[3].role == .assistant)
#expect(messages[3].sql == "SELECT COUNT(*) FROM users WHERE age > 30")
// The second prompt should reference the first query SQL
let prompts = mock.capturedPrompts
#expect(prompts.count >= 2)
let secondPrompt = prompts[1]
#expect(secondPrompt.contains("CONVERSATION HISTORY"))
#expect(secondPrompt.contains("SELECT COUNT(*) FROM users"))
#expect(secondPrompt.contains("how many users are there?"))
}
@Test("Follow-up after aggregate uses prior table context")
func followUpAfterAggregateUsesTableContext() async throws {
let db = try makeTestDatabase()
// Turn 1: aggregate (no LLM summary needed)
// Turn 2: follow-up referencing "those"
let mock = PromptCapturingMockModel(responses: [
"SELECT AVG(age) FROM users",
"SELECT name, age FROM users WHERE age > 35",
"Charlie (42) and Eve (55) are older than average."
])
let engine = ChatEngine(database: db, model: mock)
// Turn 1: average age aggregate, template summary
let r1 = try await engine.send("what is the average age of users?")
#expect(r1.sql == "SELECT AVG(age) FROM users")
// Turn 2: "who is above that?" needs the avg context
let r2 = try await engine.send("who is above average?")
#expect(r2.queryResult?.rowCount == 2)
// Verify context passed
let prompts = mock.capturedPrompts
let followUp = prompts.first { prompt in
prompt.contains("who is above average?") && prompt.contains("CONVERSATION HISTORY")
}
#expect(followUp != nil)
if let prompt = followUp {
#expect(prompt.contains("AVG(age)"), "Should include prior aggregate SQL for context")
#expect(prompt.contains("users"), "Should include table reference from prior turn")
}
}
@Test("Context window limits how much history is visible in follow-ups")
func contextWindowLimitsHistoryInFollowUps() async throws {
let db = try makeTestDatabase()
// 3 turns, but context window of 2 messages
let mock = PromptCapturingMockModel(responses: [
"SELECT COUNT(*) FROM users",
"SELECT COUNT(*) FROM orders",
"SELECT COUNT(*) FROM users WHERE age > 30",
])
let config = ChatEngineConfiguration(
queryTimeout: nil,
contextWindowSize: 2
)
let engine = ChatEngine(
database: db,
model: mock,
configuration: config
)
_ = try await engine.send("how many users?")
_ = try await engine.send("how many orders?")
_ = try await engine.send("how many users over 30?")
// The third prompt should only have the last 2 messages from turn 2
let prompts = mock.capturedPrompts
#expect(prompts.count >= 3)
let thirdPrompt = prompts[2]
#expect(thirdPrompt.contains("CONVERSATION HISTORY"))
// Turn 2 context should be present
#expect(thirdPrompt.contains("how many orders?"))
#expect(thirdPrompt.contains("SELECT COUNT(*) FROM orders"))
// Turn 1 context should be trimmed (window=2 means last 2 messages)
#expect(!thirdPrompt.contains("how many users?\n"), "First turn should be trimmed from context window")
}
@Test("clearHistory resets context so follow-ups have no prior history")
func clearHistoryResetsFollowUpContext() async throws {
let db = try makeTestDatabase()
let mock = PromptCapturingMockModel(responses: [
"SELECT * FROM users",
"Here are the 5 users.",
"SELECT COUNT(*) FROM users",
])
let engine = ChatEngine(database: db, model: mock)
// Turn 1
_ = try await engine.send("show all users")
#expect(engine.messages.count == 2)
// Clear history
engine.clearHistory()
#expect(engine.messages.isEmpty)
// Turn 2 after clear should NOT have conversation history
_ = try await engine.send("count all users")
let prompts = mock.capturedPrompts
let lastPrompt = prompts.last!
// After clearing, the prompt should NOT contain conversation history
#expect(!lastPrompt.contains("CONVERSATION HISTORY"),
"After clearHistory(), follow-up should not have prior context")
#expect(!lastPrompt.contains("show all users"),
"After clearHistory(), prior messages should be gone")
}
@Test("Multi-turn with result data in context enables informed follow-ups")
func resultDataInContextEnablesInformedFollowUps() async throws {
let db = try makeTestDatabase()
// Turn 1: list users multi-row result, LLM summarizes
// Turn 2: "sort those by age" references same table
let mock = PromptCapturingMockModel(responses: [
"SELECT name, age, city FROM users",
"Found 5 users: Alice (25, NY), Bob (35, SF), Charlie (42, NY), Diana (28, Chicago), Eve (55, SF).",
"SELECT name, age, city FROM users ORDER BY age DESC",
"Users sorted by age: Eve (55), Charlie (42), Bob (35), Diana (28), Alice (25)."
])
let engine = ChatEngine(database: db, model: mock)
let r1 = try await engine.send("list all users with their age and city")
#expect(r1.queryResult?.rowCount == 5)
#expect(r1.queryResult?.columns.contains("age") == true)
#expect(r1.queryResult?.columns.contains("city") == true)
let r2 = try await engine.send("sort those by age descending")
#expect(r2.sql == "SELECT name, age, city FROM users ORDER BY age DESC")
// Verify the assistant message in history includes the SQL
let messages = engine.messages
#expect(messages.count == 4)
// First assistant message should have the SQL recorded
#expect(messages[1].sql == "SELECT name, age, city FROM users")
// Second assistant should have the sorted SQL
#expect(messages[3].sql == "SELECT name, age, city FROM users ORDER BY age DESC")
}
}

View File

@@ -0,0 +1,508 @@
// OnDeviceProviderConfigurationTests.swift
// SwiftDBAI Tests
//
// Tests for on-device provider configurations (CoreML, MLX) including
// configuration validation, inference pipeline setup, and system readiness.
import AnyLanguageModel
import Foundation
@testable import SwiftDBAI
import Testing
@Suite("OnDeviceProviderConfiguration")
struct OnDeviceProviderConfigurationTests {
// MARK: - OnDeviceProviderType
@Test("OnDeviceProviderType has CoreML and MLX cases")
func providerTypeCases() {
let cases = OnDeviceProviderType.allCases
#expect(cases.count == 2)
#expect(cases.contains(.coreML))
#expect(cases.contains(.mlx))
}
@Test("OnDeviceProviderType raw values are descriptive")
func providerTypeRawValues() {
#expect(OnDeviceProviderType.coreML.rawValue == "coreML")
#expect(OnDeviceProviderType.mlx.rawValue == "mlx")
}
// MARK: - CoreML Configuration
@Test("CoreML configuration stores all properties")
func coreMLBasicConfiguration() {
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
let config = CoreMLProviderConfiguration(
modelURL: url,
computeUnits: .cpuAndGPU,
maxResponseTokens: 1024,
useSampling: true,
temperature: 0.3
)
#expect(config.modelURL == url)
#expect(config.computeUnits == .cpuAndGPU)
#expect(config.maxResponseTokens == 1024)
#expect(config.useSampling == true)
#expect(config.temperature == 0.3)
}
@Test("CoreML configuration uses sensible defaults")
func coreMLDefaultConfiguration() {
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
let config = CoreMLProviderConfiguration(modelURL: url)
#expect(config.computeUnits == .all)
#expect(config.maxResponseTokens == 2048)
#expect(config.useSampling == false)
#expect(config.temperature == 0.1)
}
@Test("CoreML validation fails for non-mlmodelc extension")
func coreMLValidateWrongExtension() {
let url = URL(fileURLWithPath: "/tmp/TestModel.onnx")
let config = CoreMLProviderConfiguration(modelURL: url)
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("CoreML validation fails for missing model file")
func coreMLValidateMissingFile() {
let url = URL(fileURLWithPath: "/nonexistent/path/Model.mlmodelc")
let config = CoreMLProviderConfiguration(modelURL: url)
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("CoreML configuration is Equatable")
func coreMLEquatable() {
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
let a = CoreMLProviderConfiguration(modelURL: url, computeUnits: .all)
let b = CoreMLProviderConfiguration(modelURL: url, computeUnits: .all)
let c = CoreMLProviderConfiguration(modelURL: url, computeUnits: .cpuOnly)
#expect(a == b)
#expect(a != c)
}
// MARK: - ComputeUnitPreference
@Test("ComputeUnitPreference has all expected cases")
func computeUnitCases() {
let cases = ComputeUnitPreference.allCases
#expect(cases.count == 4)
#expect(cases.contains(.all))
#expect(cases.contains(.cpuOnly))
#expect(cases.contains(.cpuAndGPU))
#expect(cases.contains(.cpuAndNeuralEngine))
}
// MARK: - MLX Configuration
@Test("MLX configuration stores all properties")
func mlxBasicConfiguration() {
let dir = URL(fileURLWithPath: "/tmp/models/my-model")
let config = MLXProviderConfiguration(
modelId: "mlx-community/Test-Model-4bit",
localDirectory: dir,
gpuMemory: .minimal,
maxResponseTokens: 512,
temperature: 0.2,
topP: 0.9,
repetitionPenalty: 1.2
)
#expect(config.modelId == "mlx-community/Test-Model-4bit")
#expect(config.localDirectory == dir)
#expect(config.gpuMemory == .minimal)
#expect(config.maxResponseTokens == 512)
#expect(config.temperature == 0.2)
#expect(config.topP == 0.9)
#expect(config.repetitionPenalty == 1.2)
}
@Test("MLX configuration uses sensible defaults")
func mlxDefaultConfiguration() {
let config = MLXProviderConfiguration(modelId: "test-model")
#expect(config.localDirectory == nil)
#expect(config.gpuMemory == .automatic)
#expect(config.maxResponseTokens == 2048)
#expect(config.temperature == 0.1)
#expect(config.topP == 0.95)
#expect(config.repetitionPenalty == 1.1)
}
@Test("MLX validation fails for empty model ID")
func mlxValidateEmptyModelId() {
let config = MLXProviderConfiguration(modelId: "")
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("MLX validation fails for nonexistent local directory")
func mlxValidateMissingDirectory() {
let config = MLXProviderConfiguration(
modelId: "test-model",
localDirectory: URL(fileURLWithPath: "/nonexistent/directory")
)
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("MLX validation fails for negative temperature")
func mlxValidateNegativeTemperature() {
let config = MLXProviderConfiguration(
modelId: "test-model",
temperature: -0.5
)
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("MLX validation fails for topP out of range")
func mlxValidateInvalidTopP() {
let configZero = MLXProviderConfiguration(
modelId: "test-model",
topP: 0.0
)
#expect(throws: OnDeviceProviderError.self) {
try configZero.validate()
}
let configOver = MLXProviderConfiguration(
modelId: "test-model",
topP: 1.5
)
#expect(throws: OnDeviceProviderError.self) {
try configOver.validate()
}
}
@Test("MLX validation fails for zero repetition penalty")
func mlxValidateInvalidRepetitionPenalty() {
let config = MLXProviderConfiguration(
modelId: "test-model",
repetitionPenalty: 0.0
)
#expect(throws: OnDeviceProviderError.self) {
try config.validate()
}
}
@Test("MLX validation succeeds for valid configuration")
func mlxValidateSuccess() throws {
let config = MLXProviderConfiguration(modelId: "test-model")
// Should not throw (no local directory set, model ID is non-empty)
try config.validate()
}
@Test("MLX configuration is Equatable")
func mlxEquatable() {
let a = MLXProviderConfiguration(modelId: "model-a")
let b = MLXProviderConfiguration(modelId: "model-a")
let c = MLXProviderConfiguration(modelId: "model-b")
#expect(a == b)
#expect(a != c)
}
// MARK: - Well-Known MLX Models
@Test("Llama 3.2 3B preset has correct model ID")
func llama3_2_3BPreset() {
let config = MLXProviderConfiguration.llama3_2_3B()
#expect(config.modelId == "mlx-community/Llama-3.2-3B-Instruct-4bit")
#expect(config.temperature == 0.1)
#expect(config.maxResponseTokens == 2048)
}
@Test("Qwen 2.5 Coder 3B preset has correct model ID")
func qwen2_5_coder3BPreset() {
let config = MLXProviderConfiguration.qwen2_5_coder_3B()
#expect(config.modelId == "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit")
#expect(config.temperature == 0.05)
}
@Test("Phi 3.5 Mini preset has correct model ID")
func phi3_5_miniPreset() {
let config = MLXProviderConfiguration.phi3_5_mini()
#expect(config.modelId == "mlx-community/Phi-3.5-mini-instruct-4bit")
#expect(config.temperature == 0.1)
}
@Test("Well-known models accept custom GPU memory config")
func wellKnownModelsCustomGPU() {
let config = MLXProviderConfiguration.llama3_2_3B(
gpuMemory: .minimal
)
#expect(config.gpuMemory == .minimal)
}
// MARK: - GPU Memory Configuration
@Test("Automatic GPU memory config scales with RAM")
func automaticGPUMemory() {
let config = MLXGPUMemoryConfig.automatic
#expect(config.activeCacheLimit > 0)
#expect(config.idleCacheLimit == 50_000_000)
#expect(config.clearCacheOnEviction == true)
}
@Test("Minimal GPU memory config is conservative")
func minimalGPUMemory() {
let config = MLXGPUMemoryConfig.minimal
#expect(config.activeCacheLimit == 64_000_000)
#expect(config.idleCacheLimit == 16_000_000)
#expect(config.clearCacheOnEviction == true)
}
@Test("Unconstrained GPU memory config uses max values")
func unconstrainedGPUMemory() {
let config = MLXGPUMemoryConfig.unconstrained
#expect(config.activeCacheLimit == Int.max)
#expect(config.idleCacheLimit == Int.max)
#expect(config.clearCacheOnEviction == false)
}
@Test("GPU memory config is Equatable")
func gpuMemoryEquatable() {
#expect(MLXGPUMemoryConfig.minimal == MLXGPUMemoryConfig.minimal)
#expect(MLXGPUMemoryConfig.minimal != MLXGPUMemoryConfig.unconstrained)
}
// MARK: - On-Device Provider Errors
@Test("OnDeviceProviderError has descriptive messages")
func errorDescriptions() {
let errors: [OnDeviceProviderError] = [
.modelNotFound(URL(fileURLWithPath: "/tmp/model")),
.invalidModelFormat(expected: ".mlmodelc", actual: ".onnx"),
.emptyModelId,
.invalidParameter(name: "temperature", value: "-1", reason: "Must be non-negative"),
.providerUnavailable(.mlx, reason: "MLX build flag not enabled"),
.modelLoadFailed(reason: "Out of memory"),
.inferenceFailed(reason: "Token limit exceeded"),
]
for error in errors {
#expect(error.errorDescription != nil)
#expect(!error.errorDescription!.isEmpty)
}
}
@Test("OnDeviceProviderError is Equatable")
func errorEquatable() {
let a = OnDeviceProviderError.emptyModelId
let b = OnDeviceProviderError.emptyModelId
let c = OnDeviceProviderError.modelLoadFailed(reason: "test")
#expect(a == b)
#expect(a != c)
}
// MARK: - Inference Pipeline
@Test("MLX inference pipeline initializes with correct type")
func mlxPipelineInit() {
let config = MLXProviderConfiguration.llama3_2_3B()
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
#expect(pipeline.providerType == .mlx)
#expect(pipeline.mlxConfiguration != nil)
#expect(pipeline.coreMLConfiguration == nil)
#expect(pipeline.status == .notLoaded)
}
@Test("CoreML inference pipeline initializes with correct type")
func coreMLPipelineInit() {
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
let config = CoreMLProviderConfiguration(modelURL: url)
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
#expect(pipeline.providerType == .coreML)
#expect(pipeline.coreMLConfiguration != nil)
#expect(pipeline.mlxConfiguration == nil)
#expect(pipeline.status == .notLoaded)
}
@Test("Pipeline validates MLX configuration")
func pipelineValidatesMLX() throws {
let validConfig = MLXProviderConfiguration(modelId: "test-model")
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: validConfig)
try pipeline.validateConfiguration()
let invalidConfig = MLXProviderConfiguration(modelId: "")
let invalidPipeline = OnDeviceInferencePipeline(mlxConfiguration: invalidConfig)
#expect(throws: OnDeviceProviderError.self) {
try invalidPipeline.validateConfiguration()
}
}
@Test("Pipeline validates CoreML configuration")
func pipelineValidatesCoreML() {
let url = URL(fileURLWithPath: "/tmp/TestModel.onnx")
let config = CoreMLProviderConfiguration(modelURL: url)
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
#expect(throws: OnDeviceProviderError.self) {
try pipeline.validateConfiguration()
}
}
@Test("Pipeline provides SQL generation hints for MLX")
func mlxSQLHints() {
let config = MLXProviderConfiguration(
modelId: "test-model",
maxResponseTokens: 512,
temperature: 0.2
)
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
let hints = pipeline.recommendedSQLGenerationHints
#expect(hints.maxTokens == 512)
#expect(hints.temperature == 0.2)
#expect(hints.useSampling == true)
#expect(hints.systemPromptSuffix.contains("MLX"))
}
@Test("Pipeline provides SQL generation hints for CoreML")
func coreMLSQLHints() {
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
let config = CoreMLProviderConfiguration(
modelURL: url,
maxResponseTokens: 1024,
useSampling: false,
temperature: 0.05
)
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
let hints = pipeline.recommendedSQLGenerationHints
#expect(hints.maxTokens == 1024)
#expect(hints.temperature == 0.05)
#expect(hints.useSampling == false)
#expect(hints.systemPromptSuffix.contains("SQL"))
}
// MARK: - System Readiness
@Test("System capability check returns valid data")
func systemCapability() {
let capability = OnDeviceModelReadiness.checkSystemCapability()
#expect(capability.totalRAM > 0)
// On any modern test machine, we should have at least some RAM
#expect(capability.totalRAM > 1024 * 1024 * 1024) // > 1GB
// On Apple silicon Macs, this should be true
#if arch(arm64)
#expect(capability.hasNeuralEngine == true)
#endif
}
@Test("Suggested MLX model returns a valid configuration")
func suggestedMLXModel() {
let config = OnDeviceModelReadiness.suggestedMLXModel()
#expect(!config.modelId.isEmpty)
#expect(config.temperature >= 0)
#expect(config.maxResponseTokens > 0)
}
@Test("Recommended model size enum has correct raw values")
func recommendedModelSizeRawValues() {
#expect(OnDeviceModelReadiness.RecommendedModelSize.small.rawValue == "small")
#expect(OnDeviceModelReadiness.RecommendedModelSize.medium.rawValue == "medium")
#expect(OnDeviceModelReadiness.RecommendedModelSize.large.rawValue == "large")
}
// MARK: - ProviderConfiguration Integration
@Test("onDeviceMLX creates a ProviderConfiguration")
func onDeviceMLXProviderConfig() {
let mlxConfig = MLXProviderConfiguration.llama3_2_3B()
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
#expect(providerConfig.model == mlxConfig.modelId)
#expect(!providerConfig.hasValidAPIKey) // No API key needed for on-device
}
@Test("onDeviceCoreML creates a ProviderConfiguration")
func onDeviceCoreMLProviderConfig() {
let url = URL(fileURLWithPath: "/tmp/SQLModel.mlmodelc")
let coreMLConfig = CoreMLProviderConfiguration(modelURL: url)
let providerConfig = ProviderConfiguration.onDeviceCoreML(coreMLConfig)
#expect(providerConfig.model == "SQLModel.mlmodelc")
#expect(!providerConfig.hasValidAPIKey)
}
// MARK: - Pipeline Status
@Test("Pipeline status transitions")
func pipelineStatusTransitions() {
let config = MLXProviderConfiguration(modelId: "test-model")
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
#expect(pipeline.status == .notLoaded)
pipeline.setStatus(.loading)
#expect(pipeline.status == .loading)
pipeline.setStatus(.ready)
#expect(pipeline.status == .ready)
pipeline.setStatus(.failed("Out of memory"))
#expect(pipeline.status == .failed("Out of memory"))
}
@Test("Pipeline Status is Equatable")
func pipelineStatusEquatable() {
#expect(OnDeviceInferencePipeline.Status.notLoaded == .notLoaded)
#expect(OnDeviceInferencePipeline.Status.loading == .loading)
#expect(OnDeviceInferencePipeline.Status.ready == .ready)
#expect(OnDeviceInferencePipeline.Status.failed("a") == .failed("a"))
#expect(OnDeviceInferencePipeline.Status.failed("a") != .failed("b"))
#expect(OnDeviceInferencePipeline.Status.notLoaded != .ready)
}
// MARK: - SQL Generation Hints
@Test("SQL generation hints are Equatable")
func sqlHintsEquatable() {
let a = OnDeviceSQLGenerationHints(
maxTokens: 512,
temperature: 0.1,
systemPromptSuffix: "test",
useSampling: true
)
let b = OnDeviceSQLGenerationHints(
maxTokens: 512,
temperature: 0.1,
systemPromptSuffix: "test",
useSampling: true
)
let c = OnDeviceSQLGenerationHints(
maxTokens: 1024,
temperature: 0.1,
systemPromptSuffix: "test",
useSampling: true
)
#expect(a == b)
#expect(a != c)
}
}

View File

@@ -0,0 +1,247 @@
// PresentationTests.swift
// SwiftDBAITests
//
// Tests for presentation modalities: DataChatSheet, DataChatViewController,
// and view modifier helpers.
import SwiftUI
import Testing
import ViewInspector
import GRDB
@testable import SwiftDBAI
// MARK: - Helpers
private func makeSampleDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue()
try db.write { db in
try db.execute(sql: """
CREATE TABLE items (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL
);
INSERT INTO items (name) VALUES ('Alpha');
""")
}
return db
}
// MARK: - DataChatSheet Tests
@Suite("DataChatSheet Tests")
struct DataChatSheetTests {
@Test("DataChatSheet renders NavigationStack with title")
@MainActor
func sheetRendersNavigationStackWithTitle() throws {
let db = try makeSampleDatabase()
let sheet = DataChatSheet(
database: db,
model: MockLanguageModel(),
title: "Test Chat"
)
let view = try sheet.inspect()
// NavigationStack should be the root
let navStack = try view.navigationStack()
#expect(navStack != nil)
}
@Test("DataChatSheet has Done button")
@MainActor
func sheetHasDoneButton() throws {
let db = try makeSampleDatabase()
let sheet = DataChatSheet(
database: db,
model: MockLanguageModel()
)
let view = try sheet.inspect()
// Find the Done button in the toolbar
let button = try view.find(button: "Done")
#expect(button != nil)
}
@Test("DataChatSheet renders DataChatView inside")
@MainActor
func sheetContainsDataChatView() throws {
let db = try makeSampleDatabase()
let sheet = DataChatSheet(
database: db,
model: MockLanguageModel()
)
let view = try sheet.inspect()
// DataChatView should be present within the NavigationStack
let dataChatView = try view.find(DataChatView.self)
#expect(dataChatView != nil)
}
@Test("DataChatSheet path-based init works")
@MainActor
func sheetPathInit() throws {
let tempDir = FileManager.default.temporaryDirectory
let dbPath = tempDir.appendingPathComponent("sheet_test_\(UUID().uuidString).sqlite").path
let db = try DatabaseQueue(path: dbPath)
try db.write { db in
try db.execute(sql: "CREATE TABLE t (id INTEGER PRIMARY KEY)")
}
let sheet = DataChatSheet(
databasePath: dbPath,
model: MockLanguageModel(),
title: "Path Chat"
)
let view = try sheet.inspect()
let navStack = try view.navigationStack()
#expect(navStack != nil)
try? FileManager.default.removeItem(atPath: dbPath)
}
@Test("DataChatSheet uses custom title")
@MainActor
func sheetCustomTitle() throws {
let db = try makeSampleDatabase()
let sheet = DataChatSheet(
database: db,
model: MockLanguageModel(),
title: "My Custom Title"
)
// Verify the title property is set correctly
#expect(sheet.title == "My Custom Title")
}
@Test("DataChatSheet defaults to AI Chat title")
@MainActor
func sheetDefaultTitle() throws {
let db = try makeSampleDatabase()
let sheet = DataChatSheet(
database: db,
model: MockLanguageModel()
)
#expect(sheet.title == "AI Chat")
}
@Test("DataChatSheet defaults to read-only allowlist")
@MainActor
func sheetDefaultAllowlist() throws {
let db = try makeSampleDatabase()
let sheet = DataChatSheet(
database: db,
model: MockLanguageModel()
)
#expect(sheet.allowlist == .readOnly)
}
}
// MARK: - DataChatViewController Tests
#if canImport(UIKit) && !os(watchOS)
@Suite("DataChatViewController Tests")
struct DataChatViewControllerTests {
@Test("DataChatViewController can be instantiated with database path")
@MainActor
func viewControllerPathInit() throws {
let tempDir = FileManager.default.temporaryDirectory
let dbPath = tempDir.appendingPathComponent("vc_test_\(UUID().uuidString).sqlite").path
let db = try DatabaseQueue(path: dbPath)
try db.write { db in
try db.execute(sql: "CREATE TABLE t (id INTEGER PRIMARY KEY)")
}
let vc = DataChatViewController(
databasePath: dbPath,
model: MockLanguageModel()
)
#expect(vc.modalPresentationStyle == .formSheet)
try? FileManager.default.removeItem(atPath: dbPath)
}
@Test("DataChatViewController can be instantiated with database connection")
@MainActor
func viewControllerDatabaseInit() throws {
let db = try makeSampleDatabase()
let vc = DataChatViewController(
database: db,
model: MockLanguageModel(),
title: "VC Chat"
)
#expect(vc.modalPresentationStyle == .formSheet)
}
}
#endif
// MARK: - View Modifier Tests
@Suite("DataChatSheet Modifier Tests")
struct DataChatSheetModifierTests {
@Test("dataChatSheet modifier creates sheet correctly")
@MainActor
func sheetModifierCreatesSheet() throws {
let db = try makeSampleDatabase()
struct TestHost: View {
@State var showChat = false
let db: DatabaseQueue
var body: some View {
Text("Hello")
.dataChatSheet(
isPresented: $showChat,
database: db,
model: MockLanguageModel(),
title: "Modifier Chat"
)
}
}
let host = TestHost(db: db)
// Verify it compiles and can be inspected
let view = try host.inspect()
let text = try view.find(text: "Hello")
#expect(text != nil)
}
@Test("dataChatSheet path modifier creates sheet correctly")
@MainActor
func sheetPathModifierCreatesSheet() throws {
let tempDir = FileManager.default.temporaryDirectory
let dbPath = tempDir.appendingPathComponent("mod_test_\(UUID().uuidString).sqlite").path
let db = try DatabaseQueue(path: dbPath)
try db.write { db in
try db.execute(sql: "CREATE TABLE t (id INTEGER PRIMARY KEY)")
}
struct TestHost: View {
@State var showChat = false
let dbPath: String
var body: some View {
Text("World")
.dataChatSheet(
isPresented: $showChat,
databasePath: dbPath,
model: MockLanguageModel()
)
}
}
let host = TestHost(dbPath: dbPath)
let view = try host.inspect()
let text = try view.find(text: "World")
#expect(text != nil)
try? FileManager.default.removeItem(atPath: dbPath)
}
}

View File

@@ -0,0 +1,254 @@
// PromptBuilderTests.swift
// SwiftDBAI
import Testing
@testable import SwiftDBAI
@Suite("PromptBuilder")
struct PromptBuilderTests {
// MARK: - Helpers
/// Creates a sample schema for testing.
private func makeSampleSchema() -> DatabaseSchema {
let usersTable = TableSchema(
name: "users",
columns: [
ColumnSchema(cid: 0, name: "id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: true),
ColumnSchema(cid: 1, name: "name", type: "TEXT", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
ColumnSchema(cid: 2, name: "email", type: "TEXT", isNotNull: false, defaultValue: nil, isPrimaryKey: false),
ColumnSchema(cid: 3, name: "created_at", type: "TEXT", isNotNull: false, defaultValue: "CURRENT_TIMESTAMP", isPrimaryKey: false),
],
primaryKey: ["id"],
foreignKeys: [],
indexes: [
IndexSchema(name: "idx_users_email", isUnique: true, columns: ["email"])
]
)
let ordersTable = TableSchema(
name: "orders",
columns: [
ColumnSchema(cid: 0, name: "id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: true),
ColumnSchema(cid: 1, name: "user_id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
ColumnSchema(cid: 2, name: "total", type: "REAL", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
ColumnSchema(cid: 3, name: "status", type: "TEXT", isNotNull: true, defaultValue: "'pending'", isPrimaryKey: false),
],
primaryKey: ["id"],
foreignKeys: [
ForeignKeySchema(fromColumn: "user_id", toTable: "users", toColumn: "id", onUpdate: "NO ACTION", onDelete: "CASCADE")
],
indexes: []
)
return DatabaseSchema(
tables: ["users": usersTable, "orders": ordersTable],
tableNames: ["users", "orders"]
)
}
private func makeEmptySchema() -> DatabaseSchema {
DatabaseSchema(tables: [:], tableNames: [])
}
// MARK: - System Instructions Tests
@Test("System instructions contain role section")
func systemInstructionsContainRole() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("ROLE"))
#expect(instructions.contains("SQL assistant"))
#expect(instructions.contains("SQLite database"))
}
@Test("System instructions contain schema")
func systemInstructionsContainSchema() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("DATABASE SCHEMA"))
#expect(instructions.contains("TABLE users"))
#expect(instructions.contains("TABLE orders"))
#expect(instructions.contains("name TEXT"))
#expect(instructions.contains("email TEXT"))
}
@Test("System instructions contain foreign keys from schema")
func systemInstructionsContainForeignKeys() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("FOREIGN KEY"))
#expect(instructions.contains("REFERENCES users(id)"))
}
@Test("System instructions contain SQL generation rules")
func systemInstructionsContainRules() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("SQL GENERATION RULES"))
#expect(instructions.contains("Use ONLY the tables and columns"))
#expect(instructions.contains("Never generate DDL"))
}
@Test("System instructions contain output format section")
func systemInstructionsContainOutputFormat() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("OUTPUT FORMAT"))
}
@Test("Default allowlist is read-only")
func defaultAllowlistIsReadOnly() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("ONLY generate SELECT queries"))
#expect(instructions.contains("No data modifications"))
}
@Test("Standard allowlist shows correct operations")
func standardAllowlistInstructions() {
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: .standard)
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("INSERT"))
#expect(instructions.contains("SELECT"))
#expect(instructions.contains("UPDATE"))
}
@Test("Unrestricted allowlist warns about DELETE")
func unrestrictedAllowlistWarnsAboutDelete() {
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: .unrestricted)
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("DELETE"))
#expect(instructions.contains("destructive"))
#expect(instructions.contains("confirmation"))
}
@Test("Additional context is appended")
func additionalContextAppended() {
let builder = PromptBuilder(
schema: makeSampleSchema(),
additionalContext: "All dates are stored in ISO 8601 format."
)
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("ADDITIONAL CONTEXT"))
#expect(instructions.contains("ISO 8601"))
}
@Test("No additional context section when nil")
func noAdditionalContextWhenNil() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(!instructions.contains("ADDITIONAL CONTEXT"))
}
@Test("No additional context section when empty string")
func noAdditionalContextWhenEmpty() {
let builder = PromptBuilder(schema: makeSampleSchema(), additionalContext: "")
let instructions = builder.buildSystemInstructions()
#expect(!instructions.contains("ADDITIONAL CONTEXT"))
}
@Test("Empty schema produces valid instructions")
func emptySchemaProducesValidInstructions() {
let builder = PromptBuilder(schema: makeEmptySchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("ROLE"))
#expect(instructions.contains("SQL GENERATION RULES"))
// Schema section should still be present, just empty
#expect(instructions.contains("DATABASE SCHEMA"))
}
// MARK: - User Prompt Tests
@Test("User prompt passes through question directly")
func userPromptPassesThrough() {
let builder = PromptBuilder(schema: makeSampleSchema())
let prompt = builder.buildUserPrompt("How many users signed up this week?")
#expect(prompt == "How many users signed up this week?")
}
// MARK: - Follow-up Prompt Tests
@Test("Follow-up prompt includes previous context")
func followUpPromptIncludesPreviousContext() {
let builder = PromptBuilder(schema: makeSampleSchema())
let prompt = builder.buildFollowUpPrompt(
"Now sort them by name",
previousSQL: "SELECT * FROM users WHERE created_at > date('now', '-7 days')",
previousResultSummary: "Found 42 users who signed up this week"
)
#expect(prompt.contains("Previous query:"))
#expect(prompt.contains("SELECT * FROM users"))
#expect(prompt.contains("Previous result:"))
#expect(prompt.contains("42 users"))
#expect(prompt.contains("Follow-up question:"))
#expect(prompt.contains("sort them by name"))
}
// MARK: - Schema Description Quality
@Test("Schema includes column types and constraints")
func schemaIncludesColumnDetails() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
// Should include type info
#expect(instructions.contains("INTEGER"))
#expect(instructions.contains("TEXT"))
#expect(instructions.contains("REAL"))
// Should include constraints
#expect(instructions.contains("NOT NULL"))
#expect(instructions.contains("PRIMARY KEY"))
}
@Test("Schema includes index information")
func schemaIncludesIndexes() {
let builder = PromptBuilder(schema: makeSampleSchema())
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("INDEX"))
#expect(instructions.contains("idx_users_email"))
}
// MARK: - Sendable Conformance
@Test("PromptBuilder is Sendable")
func promptBuilderIsSendable() async {
let builder = PromptBuilder(schema: makeSampleSchema())
// Verify it can be sent across concurrency boundaries
let instructions = await Task.detached {
builder.buildSystemInstructions()
}.value
#expect(instructions.contains("ROLE"))
}
// MARK: - Custom Allowlist
@Test("Custom allowlist with select and delete only")
func customAllowlist() {
let allowlist = OperationAllowlist([.select, .delete])
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: allowlist)
let instructions = builder.buildSystemInstructions()
#expect(instructions.contains("DELETE"))
#expect(instructions.contains("SELECT"))
#expect(instructions.contains("destructive"))
}
}

View File

@@ -0,0 +1,325 @@
// ProviderConfigurationTests.swift
// SwiftDBAI Tests
//
// Tests for ProviderConfiguration verifying all cloud provider configurations
// produce valid LanguageModel instances with correct settings.
import AnyLanguageModel
import Foundation
@testable import SwiftDBAI
import Testing
@Suite("ProviderConfiguration")
struct ProviderConfigurationTests {
// MARK: - OpenAI Configuration
@Test("OpenAI configuration stores provider and model")
func openAIBasicConfiguration() {
let config = ProviderConfiguration.openAI(
apiKey: "sk-test-key-123",
model: "gpt-4o"
)
#expect(config.provider == .openAI)
#expect(config.model == "gpt-4o")
#expect(config.apiKey == "sk-test-key-123")
#expect(config.hasValidAPIKey)
}
@Test("OpenAI configuration produces a valid LanguageModel")
func openAIMakeModel() {
let config = ProviderConfiguration.openAI(
apiKey: "sk-test-key",
model: "gpt-4o-mini"
)
let model = config.makeModel()
#expect(model is OpenAILanguageModel)
}
@Test("OpenAI with custom base URL for compatible services")
func openAICustomBaseURL() {
let customURL = URL(string: "https://my-proxy.example.com/v1/")!
let config = ProviderConfiguration.openAI(
apiKey: "sk-proxy-key",
model: "gpt-4o",
baseURL: customURL
)
#expect(config.baseURL == customURL)
let model = config.makeModel()
#expect(model is OpenAILanguageModel)
}
@Test("OpenAI with Responses API variant")
func openAIResponsesVariant() {
let config = ProviderConfiguration.openAI(
apiKey: "sk-test",
model: "gpt-4o",
variant: .responses
)
#expect(config.openAIVariant == .responses)
let model = config.makeModel()
#expect(model is OpenAILanguageModel)
}
@Test("OpenAI with dynamic key provider captures key by reference")
func openAIDynamicKeyProvider() {
nonisolated(unsafe) var currentKey = "sk-initial"
let config = ProviderConfiguration.openAI(
apiKeyProvider: { currentKey },
model: "gpt-4o"
)
#expect(config.apiKey == "sk-initial")
currentKey = "sk-rotated"
#expect(config.apiKey == "sk-rotated")
}
// MARK: - Anthropic Configuration
@Test("Anthropic configuration stores provider and model")
func anthropicBasicConfiguration() {
let config = ProviderConfiguration.anthropic(
apiKey: "sk-ant-test-key",
model: "claude-sonnet-4-20250514"
)
#expect(config.provider == .anthropic)
#expect(config.model == "claude-sonnet-4-20250514")
#expect(config.apiKey == "sk-ant-test-key")
#expect(config.hasValidAPIKey)
}
@Test("Anthropic configuration produces a valid LanguageModel")
func anthropicMakeModel() {
let config = ProviderConfiguration.anthropic(
apiKey: "sk-ant-test",
model: "claude-sonnet-4-20250514"
)
let model = config.makeModel()
#expect(model is AnthropicLanguageModel)
}
@Test("Anthropic with API version and betas")
func anthropicWithVersionAndBetas() {
let config = ProviderConfiguration.anthropic(
apiKey: "sk-ant-test",
model: "claude-sonnet-4-20250514",
apiVersion: "2024-01-01",
betas: ["computer-use"]
)
#expect(config.apiVersion == "2024-01-01")
#expect(config.betas == ["computer-use"])
let model = config.makeModel()
#expect(model is AnthropicLanguageModel)
}
@Test("Anthropic with dynamic key provider captures key by reference")
func anthropicDynamicKeyProvider() {
nonisolated(unsafe) var currentKey = "sk-ant-initial"
let config = ProviderConfiguration.anthropic(
apiKeyProvider: { currentKey },
model: "claude-sonnet-4-20250514"
)
#expect(config.apiKey == "sk-ant-initial")
currentKey = "sk-ant-rotated"
#expect(config.apiKey == "sk-ant-rotated")
}
// MARK: - Gemini Configuration
@Test("Gemini configuration stores provider and model")
func geminiBasicConfiguration() {
let config = ProviderConfiguration.gemini(
apiKey: "AIzaSyTest123",
model: "gemini-2.0-flash"
)
#expect(config.provider == .gemini)
#expect(config.model == "gemini-2.0-flash")
#expect(config.apiKey == "AIzaSyTest123")
#expect(config.hasValidAPIKey)
}
@Test("Gemini configuration produces a valid LanguageModel")
func geminiMakeModel() {
let config = ProviderConfiguration.gemini(
apiKey: "AIzaSyTest",
model: "gemini-2.0-flash"
)
let model = config.makeModel()
#expect(model is GeminiLanguageModel)
}
@Test("Gemini with custom API version")
func geminiCustomVersion() {
let config = ProviderConfiguration.gemini(
apiKey: "AIzaSyTest",
model: "gemini-2.0-flash",
apiVersion: "v1"
)
#expect(config.apiVersion == "v1")
let model = config.makeModel()
#expect(model is GeminiLanguageModel)
}
@Test("Gemini with dynamic key provider captures key by reference")
func geminiDynamicKeyProvider() {
nonisolated(unsafe) var currentKey = "AIza-initial"
let config = ProviderConfiguration.gemini(
apiKeyProvider: { currentKey },
model: "gemini-2.0-flash"
)
#expect(config.apiKey == "AIza-initial")
currentKey = "AIza-rotated"
#expect(config.apiKey == "AIza-rotated")
}
// MARK: - OpenAI-Compatible Configuration
@Test("OpenAI-compatible configuration with custom base URL")
func openAICompatibleConfiguration() {
let baseURL = URL(string: "https://api.together.xyz/v1/")!
let config = ProviderConfiguration.openAICompatible(
apiKey: "together-key",
model: "meta-llama/Llama-3.1-70B",
baseURL: baseURL
)
#expect(config.provider == .openAICompatible)
#expect(config.model == "meta-llama/Llama-3.1-70B")
#expect(config.baseURL == baseURL)
let model = config.makeModel()
#expect(model is OpenAILanguageModel)
}
@Test("OpenAI-compatible with dynamic key provider")
func openAICompatibleDynamicKey() {
let baseURL = URL(string: "http://localhost:1234/v1/")!
nonisolated(unsafe) var currentKey = "local-key"
let config = ProviderConfiguration.openAICompatible(
apiKeyProvider: { currentKey },
model: "local-model",
baseURL: baseURL
)
#expect(config.apiKey == "local-key")
currentKey = "new-local-key"
#expect(config.apiKey == "new-local-key")
}
// MARK: - API Key Validation
@Test("Empty API key reports invalid")
func emptyAPIKeyInvalid() {
let config = ProviderConfiguration.openAI(
apiKey: "",
model: "gpt-4o"
)
#expect(!config.hasValidAPIKey)
}
@Test("Whitespace-only API key reports invalid")
func whitespaceAPIKeyInvalid() {
let config = ProviderConfiguration.openAI(
apiKey: " \n\t ",
model: "gpt-4o"
)
#expect(!config.hasValidAPIKey)
}
@Test("Non-empty API key reports valid")
func nonEmptyAPIKeyValid() {
let config = ProviderConfiguration.openAI(
apiKey: "x",
model: "gpt-4o"
)
#expect(config.hasValidAPIKey)
}
// MARK: - Environment Variable Configuration
@Test("fromEnvironment creates configuration for each provider")
func fromEnvironmentCreatesConfig() {
let openAI = ProviderConfiguration.fromEnvironment(
provider: .openAI,
environmentVariable: "SWIFTDAI_TEST_OPENAI_KEY",
model: "gpt-4o"
)
#expect(openAI.provider == .openAI)
#expect(openAI.model == "gpt-4o")
let anthropic = ProviderConfiguration.fromEnvironment(
provider: .anthropic,
environmentVariable: "SWIFTDAI_TEST_ANTHROPIC_KEY",
model: "claude-sonnet-4-20250514"
)
#expect(anthropic.provider == .anthropic)
let gemini = ProviderConfiguration.fromEnvironment(
provider: .gemini,
environmentVariable: "SWIFTDAI_TEST_GEMINI_KEY",
model: "gemini-2.0-flash"
)
#expect(gemini.provider == .gemini)
}
@Test("fromEnvironment returns empty key when variable not set")
func fromEnvironmentMissingVariable() {
let config = ProviderConfiguration.fromEnvironment(
provider: .openAI,
environmentVariable: "NONEXISTENT_KEY_VAR_SWIFTDBAI_TEST",
model: "gpt-4o"
)
#expect(!config.hasValidAPIKey)
#expect(config.apiKey == "")
}
// MARK: - Provider Enum
@Test("Provider enum has all expected cases")
func providerCases() {
let cases = ProviderConfiguration.Provider.allCases
#expect(cases.count == 6)
#expect(cases.contains(.openAI))
#expect(cases.contains(.anthropic))
#expect(cases.contains(.gemini))
#expect(cases.contains(.openAICompatible))
#expect(cases.contains(.ollama))
#expect(cases.contains(.llamaCpp))
}
// MARK: - Cross-Provider Model Creation
@Test("All providers produce available models")
func allProvidersCreateAvailableModels() {
let configs: [ProviderConfiguration] = [
.openAI(apiKey: "test", model: "gpt-4o"),
.anthropic(apiKey: "test", model: "claude-sonnet-4-20250514"),
.gemini(apiKey: "test", model: "gemini-2.0-flash"),
.openAICompatible(
apiKey: "test",
model: "local",
baseURL: URL(string: "http://localhost:8080/v1/")!
),
]
for config in configs {
let model = config.makeModel()
#expect(model.isAvailable, "Model for \(config.provider) should be available")
}
}
}

View File

@@ -0,0 +1,629 @@
// SQLQueryParserTests.swift
// SwiftDBAITests
import Testing
@testable import SwiftDBAI
@Suite("SQLQueryParser")
struct SQLQueryParserTests {
let readOnlyParser = SQLQueryParser(allowlist: .readOnly)
let standardParser = SQLQueryParser(allowlist: .standard)
let unrestrictedParser = SQLQueryParser(allowlist: .unrestricted)
// MARK: - Extraction from code blocks
@Test("Extracts SQL from markdown sql code block")
func extractFromSQLCodeBlock() throws {
let text = """
Here's the query to find the top users:
```sql
SELECT name, COUNT(*) as count FROM users GROUP BY name ORDER BY count DESC
```
This will give you the results.
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT name, COUNT(*) as count FROM users GROUP BY name ORDER BY count DESC")
#expect(result.operation == .select)
#expect(result.requiresConfirmation == false)
}
@Test("Extracts SQL from generic code block")
func extractFromGenericCodeBlock() throws {
let text = """
Here you go:
```
SELECT * FROM products WHERE price > 100
```
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM products WHERE price > 100")
}
@Test("Extracts SQL from labeled text")
func extractFromLabel() throws {
let text = """
I can help with that.
SQL: SELECT id, name FROM categories WHERE active = 1
That should work.
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT id, name FROM categories WHERE active = 1")
}
@Test("Extracts direct SQL from plain text")
func extractDirectSQL() throws {
let text = "SELECT COUNT(*) FROM orders WHERE status = 'shipped'"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT COUNT(*) FROM orders WHERE status = 'shipped'")
}
@Test("Handles SQL with trailing semicolons")
func trailingSemicolon() throws {
let text = "```sql\nSELECT * FROM users;\n```"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Handles multiline SQL in code block")
func multilineSQL() throws {
let text = """
```sql
SELECT u.name, COUNT(o.id) as order_count
FROM users u
JOIN orders o ON u.id = o.user_id
GROUP BY u.name
ORDER BY order_count DESC
LIMIT 10
```
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("SELECT u.name"))
#expect(result.sql.contains("LIMIT 10"))
}
@Test("Handles WITH (CTE) queries as SELECT")
func cteQuery() throws {
let text = """
```sql
WITH top_users AS (
SELECT user_id, COUNT(*) as cnt FROM orders GROUP BY user_id
)
SELECT * FROM top_users WHERE cnt > 5
```
"""
let result = try readOnlyParser.parse(text)
#expect(result.operation == .select)
}
// MARK: - No SQL found
@Test("Throws noSQLFound for text without SQL")
func noSQLFound() throws {
let text = "I'm sorry, I can't help with that request."
#expect(throws: SQLParsingError.noSQLFound) {
try readOnlyParser.parse(text)
}
}
@Test("Throws noSQLFound for empty input")
func emptyInput() throws {
#expect(throws: SQLParsingError.noSQLFound) {
try readOnlyParser.parse("")
}
}
// MARK: - Operation detection
@Test("Detects INSERT operation")
func detectInsert() throws {
let text = "```sql\nINSERT INTO users (name) VALUES ('Alice')\n```"
let result = try standardParser.parse(text)
#expect(result.operation == .insert)
}
@Test("Detects UPDATE operation")
func detectUpdate() throws {
let text = "```sql\nUPDATE users SET name = 'Bob' WHERE id = 1\n```"
let result = try standardParser.parse(text)
#expect(result.operation == .update)
}
@Test("Detects DELETE operation and requires confirmation")
func detectDeleteRequiresConfirmation() throws {
let text = "```sql\nDELETE FROM users WHERE id = 99\n```"
let result = try unrestrictedParser.parse(text)
#expect(result.operation == .delete)
#expect(result.requiresConfirmation == true)
}
// MARK: - Allowlist enforcement
@Test("Rejects INSERT on read-only allowlist")
func rejectInsertOnReadOnly() throws {
let text = "```sql\nINSERT INTO users (name) VALUES ('Mallory')\n```"
#expect(throws: SQLParsingError.operationNotAllowed(.insert)) {
try readOnlyParser.parse(text)
}
}
@Test("Rejects UPDATE on read-only allowlist")
func rejectUpdateOnReadOnly() {
let text = "```sql\nUPDATE users SET name = 'Eve' WHERE id = 1\n```"
#expect(throws: SQLParsingError.operationNotAllowed(.update)) {
try readOnlyParser.parse(text)
}
}
@Test("Rejects DELETE on standard allowlist")
func rejectDeleteOnStandard() {
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
#expect(throws: SQLParsingError.operationNotAllowed(.delete)) {
try standardParser.parse(text)
}
}
// MARK: - Dangerous operations
@Test("Rejects DROP TABLE")
func rejectDrop() {
let text = "```sql\nDROP TABLE users\n```"
#expect(throws: SQLParsingError.dangerousOperation("DROP")) {
try unrestrictedParser.parse(text)
}
}
@Test("Rejects ALTER TABLE")
func rejectAlter() {
let text = "```sql\nALTER TABLE users ADD COLUMN age INTEGER\n```"
#expect(throws: SQLParsingError.dangerousOperation("ALTER")) {
try unrestrictedParser.parse(text)
}
}
@Test("Rejects PRAGMA")
func rejectPragma() {
let text = "```sql\nPRAGMA table_info(users)\n```"
#expect(throws: SQLParsingError.dangerousOperation("PRAGMA")) {
try unrestrictedParser.parse(text)
}
}
@Test("Does not match dangerous keywords inside identifiers")
func noFalsePositiveOnSubstring() throws {
// "DROPDOWN" contains "DROP" as substring but is not the keyword
let text = "SELECT dropdown_value FROM settings"
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("dropdown_value"))
}
// MARK: - Multiple statements
@Test("Rejects multiple statements separated by semicolons")
func rejectMultipleStatements() {
let text = "```sql\nSELECT * FROM users; SELECT * FROM orders\n```"
#expect(throws: SQLParsingError.multipleStatements) {
try readOnlyParser.parse(text)
}
}
@Test("Allows semicolons inside string literals")
func allowSemicolonInString() throws {
let text = "SELECT * FROM users WHERE bio = 'hello; world'"
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("hello; world"))
}
// MARK: - ParsedSQL equality
@Test("ParsedSQL equality works")
func parsedSQLEquality() {
let a = ParsedSQL(sql: "SELECT 1", operation: .select)
let b = ParsedSQL(sql: "SELECT 1", operation: .select)
#expect(a == b)
}
// MARK: - Error descriptions
@Test("Error descriptions are meaningful")
func errorDescriptions() {
#expect(SQLParsingError.noSQLFound.description.contains("No SQL"))
#expect(SQLParsingError.operationNotAllowed(.insert).description.contains("INSERT"))
#expect(SQLParsingError.dangerousOperation("DROP").description.contains("DROP"))
#expect(SQLParsingError.multipleStatements.description.contains("single"))
}
// MARK: - MutationPolicy integration
@Test("MutationPolicy allows INSERT on permitted table")
func mutationPolicyAllowsInsertOnPermittedTable() throws {
let policy = MutationPolicy(
allowedOperations: [.insert, .update],
allowedTables: ["orders", "order_items"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nINSERT INTO orders (product, qty) VALUES ('Widget', 3)\n```"
let result = try parser.parse(text)
#expect(result.operation == .insert)
#expect(result.requiresConfirmation == false)
}
@Test("MutationPolicy rejects INSERT on non-permitted table")
func mutationPolicyRejectsInsertOnForbiddenTable() {
let policy = MutationPolicy(
allowedOperations: [.insert, .update],
allowedTables: ["orders"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nINSERT INTO users (name) VALUES ('Alice')\n```"
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .insert)) {
try parser.parse(text)
}
}
@Test("MutationPolicy rejects UPDATE on non-permitted table")
func mutationPolicyRejectsUpdateOnForbiddenTable() {
let policy = MutationPolicy(
allowedOperations: [.insert, .update],
allowedTables: ["orders"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nUPDATE users SET name = 'Bob' WHERE id = 1\n```"
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .update)) {
try parser.parse(text)
}
}
@Test("MutationPolicy rejects DELETE on non-permitted table")
func mutationPolicyRejectsDeleteOnForbiddenTable() {
let policy = MutationPolicy(
allowedOperations: [.insert, .update, .delete],
allowedTables: ["temp_data"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nDELETE FROM users WHERE id = 99\n```"
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .delete)) {
try parser.parse(text)
}
}
@Test("MutationPolicy allows mutation on any table when allowedTables is nil")
func mutationPolicyAllowsAllTablesWhenNil() throws {
let policy = MutationPolicy(allowedOperations: [.insert, .update])
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nINSERT INTO any_table (col) VALUES ('val')\n```"
let result = try parser.parse(text)
#expect(result.operation == .insert)
}
@Test("MutationPolicy SELECT is never restricted by table allowlist")
func mutationPolicySelectIgnoresTableRestrictions() throws {
let policy = MutationPolicy(
allowedOperations: [.insert],
allowedTables: ["orders"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
// SELECT from a table NOT in allowedTables should still work
let text = "```sql\nSELECT * FROM users\n```"
let result = try parser.parse(text)
#expect(result.operation == .select)
#expect(result.requiresConfirmation == false)
}
@Test("MutationPolicy DELETE requires confirmation by default")
func mutationPolicyDeleteRequiresConfirmation() throws {
let policy = MutationPolicy(allowedOperations: [.delete])
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
let result = try parser.parse(text)
#expect(result.operation == .delete)
#expect(result.requiresConfirmation == true)
}
@Test("MutationPolicy DELETE skips confirmation when configured")
func mutationPolicyDeleteNoConfirmation() throws {
let policy = MutationPolicy(
allowedOperations: [.delete],
requiresDestructiveConfirmation: false
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
let result = try parser.parse(text)
#expect(result.operation == .delete)
#expect(result.requiresConfirmation == false)
}
@Test("MutationPolicy readOnly preset rejects all mutations")
func mutationPolicyReadOnlyRejectsAll() {
let parser = SQLQueryParser(mutationPolicy: .readOnly)
#expect(throws: SQLParsingError.operationNotAllowed(.insert)) {
try parser.parse("INSERT INTO t (a) VALUES (1)")
}
#expect(throws: SQLParsingError.operationNotAllowed(.update)) {
try parser.parse("UPDATE t SET a = 1")
}
#expect(throws: SQLParsingError.operationNotAllowed(.delete)) {
try parser.parse("DELETE FROM t WHERE id = 1")
}
}
@Test("MutationPolicy table matching is case-insensitive")
func mutationPolicyTableCaseInsensitive() throws {
let policy = MutationPolicy(
allowedOperations: [.insert],
allowedTables: ["Orders"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
let text = "INSERT INTO orders (product) VALUES ('Widget')"
let result = try parser.parse(text)
#expect(result.operation == .insert)
}
@Test("MutationPolicy handles quoted table names")
func mutationPolicyQuotedTableNames() throws {
let policy = MutationPolicy(
allowedOperations: [.insert, .update],
allowedTables: ["order_items"]
)
let parser = SQLQueryParser(mutationPolicy: policy)
// Backtick-quoted
let backtick = "INSERT INTO `order_items` (qty) VALUES (5)"
let r1 = try parser.parse(backtick)
#expect(r1.operation == .insert)
// Double-quote-quoted
let doubleQuote = "UPDATE \"order_items\" SET qty = 10 WHERE id = 1"
let r2 = try parser.parse(doubleQuote)
#expect(r2.operation == .update)
}
@Test("Error description for tableNotAllowed is meaningful")
func tableNotAllowedDescription() {
let error = SQLParsingError.tableNotAllowed(table: "secret", operation: .delete)
#expect(error.description.contains("secret"))
#expect(error.description.contains("DELETE"))
}
@Test("Error description for confirmationRequired is meaningful")
func confirmationRequiredDescription() {
let error = SQLParsingError.confirmationRequired(sql: "DELETE FROM x", operation: .delete)
#expect(error.description.contains("DELETE"))
#expect(error.description.contains("confirmation"))
}
// MARK: - Robust extraction edge cases
@Test("Extracts plain SQL without any wrapping")
func plainSQL() throws {
let text = "SELECT * FROM users"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Extracts SQL from markdown sql code block")
func markdownSQLBlock() throws {
let text = "```sql\nSELECT * FROM users\n```"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Extracts SQL from generic code block")
func genericCodeBlock() throws {
let text = "```\nSELECT * FROM users\n```"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Strips trailing semicolons")
func trailingSemicolonEdge() throws {
let text = "SELECT * FROM users;"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Extracts SQL with preamble text")
func preambleText() throws {
let text = "Here's the query:\nSELECT * FROM users"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Handles trailing backticks only (no opening fence)")
func trailingBackticksOnly() throws {
let text = "SELECT * FROM users\n```"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Extracts SQL from single-line code block")
func singleLineCodeBlock() throws {
let text = "```sql SELECT * FROM users ```"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Handles no newline before closing fence")
func noNewlineBeforeClosingFence() throws {
let text = "```sql\nSELECT * FROM users```"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Extracts SQL inline with text prefix")
func inlineWithText() throws {
let text = "The SQL query is: SELECT * FROM users"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Handles extra whitespace around SQL")
func extraWhitespace() throws {
let text = "\n\nSELECT * FROM users\n\n"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Extracts SQL from chatty LLM response with preamble and postamble")
func chattyLLMResponse() throws {
let text = "Sure! Here's the SQL:\n\n```sql\nSELECT * FROM users\n```\n\nThis will return all users."
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Preserves SQL comments")
func sqlWithComments() throws {
let text = "SELECT * FROM users -- get all users"
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("-- get all users"))
}
@Test("Preserves backtick-quoted identifiers in SQL")
func backtickQuotedIdentifiers() throws {
let text = "SELECT `column name` FROM users"
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("`column name`"))
}
@Test("Strips think tags from Qwen-style models")
func thinkTags() throws {
let text = "<think>I need to query the users table</think>\nSELECT * FROM users"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
#expect(!result.sql.contains("think"))
}
@Test("Handles 4 or 5 backtick fences")
func extraBacktickFences() throws {
let text4 = "````sql\nSELECT * FROM users\n````"
let result4 = try readOnlyParser.parse(text4)
#expect(result4.sql == "SELECT * FROM users")
let text5 = "`````\nSELECT * FROM users\n`````"
let result5 = try readOnlyParser.parse(text5)
#expect(result5.sql == "SELECT * FROM users")
}
@Test("Handles mixed case SQL keywords")
func mixedCaseSQL() throws {
let text = "select * from USERS"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "select * from USERS")
}
@Test("Handles WITH clause (CTE) queries")
func withClause() throws {
let text = "WITH cte AS (SELECT id FROM orders) SELECT * FROM cte"
let result = try readOnlyParser.parse(text)
#expect(result.sql.hasPrefix("WITH"))
#expect(result.operation == .select)
}
@Test("Handles WITH clause in code block")
func withClauseInCodeBlock() throws {
let text = "```sql\nWITH top AS (\n SELECT user_id, COUNT(*) as cnt FROM orders GROUP BY user_id\n)\nSELECT * FROM top WHERE cnt > 5\n```"
let result = try readOnlyParser.parse(text)
#expect(result.sql.hasPrefix("WITH"))
#expect(result.operation == .select)
}
@Test("Multi-line SQL with JOINs and subqueries in code block")
func multiLineJoinsAndSubqueries() throws {
let text = """
```sql
SELECT u.name, o.total
FROM users u
INNER JOIN orders o ON u.id = o.user_id
WHERE o.total > (SELECT AVG(total) FROM orders)
ORDER BY o.total DESC
```
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("INNER JOIN"))
#expect(result.sql.contains("SELECT AVG(total)"))
#expect(result.sql.contains("ORDER BY"))
}
@Test("Handles response with both explanation text and SQL")
func explanationAndSQL() throws {
let text = """
To find all active users, we need to query the users table
and filter by the active column. Here's the query:
SELECT * FROM users WHERE active = 1
This should give you the results you're looking for.
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users WHERE active = 1")
}
@Test("Throws noSQLFound for empty response")
func emptyResponse() throws {
#expect(throws: SQLParsingError.noSQLFound) {
try readOnlyParser.parse("")
}
#expect(throws: SQLParsingError.noSQLFound) {
try readOnlyParser.parse(" \n\n ")
}
}
@Test("Throws noSQLFound for response with no SQL at all")
func noSQLAtAll() throws {
#expect(throws: SQLParsingError.noSQLFound) {
try readOnlyParser.parse("I cannot help with that question. Please try asking about your data.")
}
}
@Test("Handles response with multiple SQL statements in code block (rejects them)")
func multipleStatementsInCodeBlock() throws {
// When multiple statements are in a code block, the parser sees both and rejects
let text = "```sql\nSELECT * FROM users; SELECT * FROM orders\n```"
#expect(throws: SQLParsingError.multipleStatements) {
try readOnlyParser.parse(text)
}
}
@Test("Extracts first SQL statement from plain text with multiple statements")
func multipleStatementsPlainText() throws {
// In plain text, the direct extraction stops at the semicolon and extracts the first statement
let text = "SELECT * FROM users; SELECT * FROM orders"
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Preserves backtick identifiers inside code blocks")
func backtickIdentifiersInCodeBlock() throws {
let text = "```sql\nSELECT `first name`, `last name` FROM `user data`\n```"
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("`first name`"))
#expect(result.sql.contains("`last name`"))
#expect(result.sql.contains("`user data`"))
}
@Test("Strips think tags with multiline reasoning content")
func multilineThinkTags() throws {
let text = """
<think>
The user wants to find all users.
I should use SELECT * FROM users.
Let me think about which columns to include...
</think>
SELECT * FROM users
"""
let result = try readOnlyParser.parse(text)
#expect(result.sql == "SELECT * FROM users")
}
@Test("Handles mixed backtick styles in response")
func mixedBacktickStyles() throws {
// Code fences + backtick-quoted identifiers inside
let text = "```sql\nSELECT `user name` FROM users WHERE `is active` = 1\n```"
let result = try readOnlyParser.parse(text)
#expect(result.sql.contains("`user name`"))
#expect(result.sql.contains("`is active`"))
}
}

View File

@@ -0,0 +1,234 @@
// SchemaIntrospectorTests.swift
// SwiftDBAI
import Testing
import GRDB
@testable import SwiftDBAI
@Suite("SchemaIntrospector")
struct SchemaIntrospectorTests {
// MARK: - Helper
/// Creates an in-memory database with a sample schema for testing.
private func makeTestDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue(configuration: {
var config = Configuration()
config.foreignKeysEnabled = true
return config
}())
try db.write { db in
try db.execute(sql: """
CREATE TABLE authors (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT UNIQUE
);
""")
try db.execute(sql: """
CREATE TABLE books (
id INTEGER PRIMARY KEY AUTOINCREMENT,
title TEXT NOT NULL,
author_id INTEGER NOT NULL REFERENCES authors(id) ON DELETE CASCADE,
published_date TEXT,
price REAL DEFAULT 9.99
);
""")
try db.execute(sql: """
CREATE INDEX idx_books_author ON books(author_id);
""")
try db.execute(sql: """
CREATE INDEX idx_books_title ON books(title);
""")
try db.execute(sql: """
CREATE TABLE reviews (
id INTEGER PRIMARY KEY,
book_id INTEGER NOT NULL REFERENCES books(id),
rating INTEGER NOT NULL,
comment TEXT
);
""")
}
return db
}
// MARK: - Tests
@Test("Discovers all user tables")
func discoversAllTables() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
#expect(schema.tableNames.count == 3)
#expect(schema.tableNames.contains("authors"))
#expect(schema.tableNames.contains("books"))
#expect(schema.tableNames.contains("reviews"))
}
@Test("Excludes sqlite_ internal tables")
func excludesInternalTables() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
for name in schema.tableNames {
#expect(!name.hasPrefix("sqlite_"))
}
}
@Test("Introspects column names and types")
func introspectsColumns() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let books = try #require(schema.tables["books"])
#expect(books.columns.count == 5)
let titleCol = try #require(books.columns.first { $0.name == "title" })
#expect(titleCol.type == "TEXT")
#expect(titleCol.isNotNull == true)
#expect(titleCol.isPrimaryKey == false)
let priceCol = try #require(books.columns.first { $0.name == "price" })
#expect(priceCol.type == "REAL")
#expect(priceCol.defaultValue == "9.99")
}
@Test("Detects primary keys")
func detectsPrimaryKeys() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let authors = try #require(schema.tables["authors"])
#expect(authors.primaryKey == ["id"])
let idCol = try #require(authors.columns.first { $0.name == "id" })
#expect(idCol.isPrimaryKey == true)
}
@Test("Detects foreign keys")
func detectsForeignKeys() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let books = try #require(schema.tables["books"])
#expect(books.foreignKeys.count == 1)
let fk = books.foreignKeys[0]
#expect(fk.fromColumn == "author_id")
#expect(fk.toTable == "authors")
#expect(fk.toColumn == "id")
#expect(fk.onDelete == "CASCADE")
}
@Test("Detects indexes")
func detectsIndexes() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let books = try #require(schema.tables["books"])
let indexNames = books.indexes.map(\.name)
#expect(indexNames.contains("idx_books_author"))
#expect(indexNames.contains("idx_books_title"))
}
@Test("Detects NOT NULL constraints")
func detectsNotNull() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let reviews = try #require(schema.tables["reviews"])
let ratingCol = try #require(reviews.columns.first { $0.name == "rating" })
#expect(ratingCol.isNotNull == true)
let commentCol = try #require(reviews.columns.first { $0.name == "comment" })
#expect(commentCol.isNotNull == false)
}
@Test("Generates LLM-friendly schema description")
func generatesSchemaDescription() async throws {
let db = try makeTestDatabase()
let schema = try await SchemaIntrospector.introspect(database: db)
let description = schema.schemaDescription
#expect(description.contains("TABLE authors"))
#expect(description.contains("TABLE books"))
#expect(description.contains("FOREIGN KEY"))
#expect(description.contains("REFERENCES authors(id)"))
#expect(description.contains("INDEX idx_books_author"))
}
@Test("Handles empty database")
func handlesEmptyDatabase() async throws {
let db = try DatabaseQueue()
let schema = try await SchemaIntrospector.introspect(database: db)
#expect(schema.tables.isEmpty)
#expect(schema.tableNames.isEmpty)
#expect(schema.schemaDescription.isEmpty)
}
@Test("Handles composite primary keys")
func handlesCompositePrimaryKey() async throws {
let db = try DatabaseQueue()
try await db.write { db in
try db.execute(sql: """
CREATE TABLE book_tags (
book_id INTEGER NOT NULL,
tag_id INTEGER NOT NULL,
PRIMARY KEY (book_id, tag_id)
);
""")
}
let schema = try await SchemaIntrospector.introspect(database: db)
let bookTags = try #require(schema.tables["book_tags"])
#expect(bookTags.primaryKey.count == 2)
#expect(bookTags.primaryKey.contains("book_id"))
#expect(bookTags.primaryKey.contains("tag_id"))
}
@Test("Handles tables with no explicit types (SQLite dynamic typing)")
func handlesDynamicTyping() async throws {
let db = try DatabaseQueue()
try await db.write { db in
try db.execute(sql: """
CREATE TABLE flexible (
id INTEGER PRIMARY KEY,
data,
info BLOB
);
""")
}
let schema = try await SchemaIntrospector.introspect(database: db)
let flexible = try #require(schema.tables["flexible"])
let dataCol = try #require(flexible.columns.first { $0.name == "data" })
#expect(dataCol.type == "") // No declared type
let infoCol = try #require(flexible.columns.first { $0.name == "info" })
#expect(infoCol.type == "BLOB")
}
@Test("Synchronous introspection works within database access")
func synchronousIntrospection() async throws {
let db = try DatabaseQueue()
try await db.write { db in
try db.execute(sql: "CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT);")
}
let schema = try await db.read { db in
try SchemaIntrospector.introspect(db: db)
}
#expect(schema.tableNames == ["test"])
let table = try #require(schema.tables["test"])
#expect(table.columns.count == 2)
}
}

View File

@@ -0,0 +1,133 @@
// ScrollableDataTableViewTests.swift
// SwiftDBAITests
//
// Tests for the ScrollableDataTableView component.
import Foundation
import Testing
@testable import SwiftDBAI
@Suite("ScrollableDataTableView")
@MainActor
struct ScrollableDataTableViewTests {
// MARK: - Test Helpers
private func makeDataTable(
columnNames: [String] = ["id", "name", "score"],
inferredTypes: [DataTable.InferredType] = [.integer, .text, .real],
rowCount: Int = 5
) -> DataTable {
let columns = columnNames.enumerated().map { idx, name in
DataTable.Column(name: name, index: idx, inferredType: inferredTypes[idx])
}
let rows = (0..<rowCount).map { i in
DataTable.Row(
id: i,
values: [
.integer(Int64(i + 1)),
.text("Item \(i + 1)"),
.real(Double(i) * 10.5),
],
columnNames: columnNames
)
}
return DataTable(columns: columns, rows: rows, sql: "SELECT * FROM test", executionTime: 0.015)
}
private func makeEmptyDataTable() -> DataTable {
DataTable(columns: [], rows: [], sql: "", executionTime: 0)
}
// MARK: - Initialization Tests
@Test("Initializes with default parameters")
func initWithDefaults() {
let table = makeDataTable()
let view = ScrollableDataTableView(dataTable: table)
#expect(view.minimumColumnWidth == 80)
#expect(view.maximumColumnWidth == 250)
#expect(view.showAlternatingRows == true)
#expect(view.showFooter == true)
}
@Test("Initializes with custom parameters")
func initWithCustomParams() {
let table = makeDataTable()
let view = ScrollableDataTableView(
dataTable: table,
minimumColumnWidth: 100,
maximumColumnWidth: 300,
showAlternatingRows: false,
showFooter: false
)
#expect(view.minimumColumnWidth == 100)
#expect(view.maximumColumnWidth == 300)
#expect(view.showAlternatingRows == false)
#expect(view.showFooter == false)
}
@Test("Handles empty data table")
func handlesEmptyTable() {
let table = makeEmptyDataTable()
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.isEmpty)
}
@Test("Handles single row table")
func handlesSingleRow() {
let table = makeDataTable(rowCount: 1)
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.rowCount == 1)
#expect(view.dataTable.columnCount == 3)
}
@Test("Handles single column table")
func handlesSingleColumn() {
let columns = [DataTable.Column(name: "count", index: 0, inferredType: .integer)]
let rows = [
DataTable.Row(id: 0, values: [.integer(42)], columnNames: ["count"])
]
let table = DataTable(columns: columns, rows: rows, sql: "SELECT count(*) FROM t", executionTime: 0.001)
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.columnCount == 1)
#expect(view.dataTable.rowCount == 1)
}
@Test("Handles large number of rows")
func handlesLargeRowCount() {
let table = makeDataTable(rowCount: 1000)
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.rowCount == 1000)
}
@Test("Handles null values in cells")
func handlesNullValues() {
let columns = [
DataTable.Column(name: "name", index: 0, inferredType: .text),
DataTable.Column(name: "value", index: 1, inferredType: .null),
]
let rows = [
DataTable.Row(id: 0, values: [.text("test"), .null], columnNames: ["name", "value"])
]
let table = DataTable(columns: columns, rows: rows)
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.rows[0][1] == .null)
}
@Test("Handles blob values in cells")
func handlesBlobValues() {
let columns = [
DataTable.Column(name: "data", index: 0, inferredType: .blob),
]
let blobData = Data([0x00, 0xFF, 0xAB])
let rows = [
DataTable.Row(id: 0, values: [.blob(blobData)], columnNames: ["data"])
]
let table = DataTable(columns: columns, rows: rows)
let view = ScrollableDataTableView(dataTable: table)
#expect(view.dataTable.rows[0][0] == QueryResult.Value.blob(blobData))
}
}

View File

@@ -0,0 +1,301 @@
// TextSummaryRendererTests.swift
// SwiftDBAI
import AnyLanguageModel
import Testing
import Foundation
@testable import SwiftDBAI
@Suite("TextSummaryRenderer")
struct TextSummaryRendererTests {
// MARK: - QueryResult.Value Tests
@Test("Value description renders correctly")
func valueDescriptions() {
#expect(QueryResult.Value.text("hello").description == "hello")
#expect(QueryResult.Value.integer(42).description == "42")
#expect(QueryResult.Value.real(3.14).description == "3.14")
#expect(QueryResult.Value.null.description == "NULL")
#expect(QueryResult.Value.blob(Data([0x01, 0x02])).description == "<2 bytes>")
}
@Test("Value doubleValue extracts numeric values")
func valueDoubleValues() {
#expect(QueryResult.Value.integer(42).doubleValue == 42.0)
#expect(QueryResult.Value.real(3.14).doubleValue == 3.14)
#expect(QueryResult.Value.text("100").doubleValue == 100.0)
#expect(QueryResult.Value.text("not a number").doubleValue == nil)
#expect(QueryResult.Value.null.doubleValue == nil)
#expect(QueryResult.Value.blob(Data()).doubleValue == nil)
}
@Test("Value isNull works correctly")
func valueIsNull() {
#expect(QueryResult.Value.null.isNull == true)
#expect(QueryResult.Value.text("").isNull == false)
#expect(QueryResult.Value.integer(0).isNull == false)
}
// MARK: - QueryResult Tests
@Test("Empty result has correct properties")
func emptyResult() {
let result = QueryResult(
columns: ["id", "name"],
rows: [],
sql: "SELECT id, name FROM users",
executionTime: 0.01
)
#expect(result.rowCount == 0)
#expect(result.isAggregate == false)
#expect(result.tabularDescription == "(empty result set)")
}
@Test("Single aggregate result is detected")
func aggregateDetection() {
let result = QueryResult(
columns: ["COUNT(*)"],
rows: [["COUNT(*)": .integer(42)]],
sql: "SELECT COUNT(*) FROM users",
executionTime: 0.01
)
#expect(result.isAggregate == true)
}
@Test("Multi-row result is not aggregate")
func nonAggregateDetection() {
let result = QueryResult(
columns: ["name"],
rows: [
["name": .text("Alice")],
["name": .text("Bob")],
],
sql: "SELECT name FROM users",
executionTime: 0.01
)
#expect(result.isAggregate == false)
}
@Test("Tabular description formats correctly")
func tabularDescription() {
let result = QueryResult(
columns: ["id", "name"],
rows: [
["id": .integer(1), "name": .text("Alice")],
["id": .integer(2), "name": .text("Bob")],
],
sql: "SELECT id, name FROM users",
executionTime: 0.01
)
let desc = result.tabularDescription
#expect(desc.contains("id | name"))
#expect(desc.contains("1 | Alice"))
#expect(desc.contains("2 | Bob"))
}
@Test("values(forColumn:) extracts column values")
func valuesForColumn() {
let result = QueryResult(
columns: ["name"],
rows: [
["name": .text("Alice")],
["name": .text("Bob")],
],
sql: "SELECT name FROM users",
executionTime: 0.01
)
let values = result.values(forColumn: "name")
#expect(values.count == 2)
#expect(values[0] == .text("Alice"))
}
// MARK: - Local Summary Tests (no LLM required)
@Test("Local summary for empty result")
func localSummaryEmpty() {
let result = makeResult(columns: ["id"], rows: [])
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Any users?")
#expect(summary == "No results found for your query.")
}
@Test("Local summary for single aggregate")
func localSummarySingleAggregate() {
let result = makeResult(
columns: ["COUNT(*)"],
rows: [["COUNT(*)": .integer(42)]]
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "How many?")
#expect(summary.contains("42"))
}
@Test("Local summary for multiple aggregates")
func localSummaryMultipleAggregates() {
let result = makeResult(
columns: ["COUNT(*)", "AVG(price)"],
rows: [["COUNT(*)": .integer(10), "AVG(price)": .real(25.5)]]
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Stats?")
#expect(summary.contains("count"))
#expect(summary.contains("average price"))
}
@Test("Local summary for single record")
func localSummarySingleRecord() {
let result = makeResult(
columns: ["name", "email"],
rows: [["name": .text("Alice"), "email": .text("alice@example.com")]]
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Who?")
#expect(summary.contains("1 result"))
#expect(summary.contains("Alice"))
}
@Test("Local summary for multiple records with name column")
func localSummaryMultipleWithNames() {
let result = makeResult(
columns: ["name", "age"],
rows: [
["name": .text("Alice"), "age": .integer(30)],
["name": .text("Bob"), "age": .integer(25)],
["name": .text("Charlie"), "age": .integer(35)],
["name": .text("Diana"), "age": .integer(28)],
]
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "List users")
#expect(summary.contains("4 results"))
#expect(summary.contains("Alice"))
#expect(summary.contains("1 more"))
}
@Test("Local summary for mutation result")
func localSummaryMutation() {
let result = QueryResult(
columns: [],
rows: [],
sql: "INSERT INTO users (name) VALUES ('Test')",
executionTime: 0.01,
rowsAffected: 1
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Add user")
#expect(summary == "Successfully inserted 1 row.")
}
@Test("Local summary for delete mutation")
func localSummaryDelete() {
let result = QueryResult(
columns: [],
rows: [],
sql: "DELETE FROM users WHERE id = 5",
executionTime: 0.01,
rowsAffected: 3
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Delete old users")
#expect(summary == "Successfully deleted 3 rows.")
}
@Test("Local summary for update mutation")
func localSummaryUpdate() {
let result = QueryResult(
columns: [],
rows: [],
sql: "UPDATE users SET active = 0 WHERE id = 1",
executionTime: 0.01,
rowsAffected: 1
)
let renderer = makeMockRenderer()
let summary = renderer.localSummary(result: result, userQuestion: "Deactivate user")
#expect(summary == "Successfully updated 1 row.")
}
// MARK: - LLM-based Summary Tests (using MockLanguageModel)
@Test("Summarize with LLM returns mock response for multi-row results")
func summarizeWithLLM() async throws {
let result = makeResult(
columns: ["name", "age"],
rows: [
["name": .text("Alice"), "age": .integer(30)],
["name": .text("Bob"), "age": .integer(25)],
]
)
let mockModel = MockLanguageModel(responseText: "There are 2 users: Alice (30) and Bob (25).")
let renderer = TextSummaryRenderer(model: mockModel)
let summary = try await renderer.summarize(result: result, userQuestion: "List all users")
#expect(summary == "There are 2 users: Alice (30) and Bob (25).")
}
@Test("Summarize returns empty result message without calling LLM")
func summarizeEmptyResult() async throws {
let result = makeResult(columns: ["id"], rows: [])
let renderer = makeMockRenderer()
let summary = try await renderer.summarize(result: result, userQuestion: "Find users")
#expect(summary == "No results found for your query.")
}
@Test("Summarize returns direct aggregate without calling LLM")
func summarizeAggregate() async throws {
let result = makeResult(
columns: ["COUNT(*)"],
rows: [["COUNT(*)": .integer(42)]]
)
let renderer = makeMockRenderer()
let summary = try await renderer.summarize(result: result, userQuestion: "How many?")
#expect(summary.contains("42"))
}
@Test("Summarize mutation returns template without calling LLM")
func summarizeMutation() async throws {
let result = QueryResult(
columns: [],
rows: [],
sql: "UPDATE users SET name = 'Test' WHERE id = 1",
executionTime: 0.01,
rowsAffected: 1
)
let renderer = makeMockRenderer()
let summary = try await renderer.summarize(result: result, userQuestion: "Update user")
#expect(summary == "Successfully updated 1 row.")
}
@Test("Summarize passes context to LLM prompt")
func summarizeWithContext() async throws {
let result = makeResult(
columns: ["total"],
rows: [
["total": .real(100.0)],
["total": .real(200.0)],
]
)
let mockModel = MockLanguageModel(responseText: "The totals are 100 and 200.")
let renderer = TextSummaryRenderer(model: mockModel)
let summary = try await renderer.summarize(
result: result,
userQuestion: "Show totals",
context: "Amounts are in USD"
)
#expect(summary == "The totals are 100 and 200.")
}
// MARK: - Helpers
private func makeResult(
columns: [String],
rows: [[String: QueryResult.Value]],
sql: String = "SELECT * FROM test"
) -> QueryResult {
QueryResult(columns: columns, rows: rows, sql: sql, executionTime: 0.01)
}
/// Creates a renderer with a mock model (for localSummary tests that don't hit the LLM).
private func makeMockRenderer() -> TextSummaryRenderer {
TextSummaryRenderer(model: MockLanguageModel())
}
}

View File

@@ -0,0 +1,246 @@
// ToolExecutionDelegateTests.swift
// SwiftDBAITests
import Foundation
import Testing
@testable import SwiftDBAI
@Suite("DestructiveClassification")
struct DestructiveClassificationTests {
// MARK: - Safe statements
@Test("SELECT is classified as safe")
func selectIsSafe() {
let result = classifySQL("SELECT * FROM users")
#expect(result == .safe)
#expect(!result.requiresConfirmation)
#expect(!result.isMutating)
}
@Test("WITH (CTE) is classified as safe")
func withIsSafe() {
let result = classifySQL("WITH cte AS (SELECT 1) SELECT * FROM cte")
#expect(result == .safe)
}
// MARK: - Mutation statements
@Test("INSERT is classified as mutation")
func insertIsMutation() {
let result = classifySQL("INSERT INTO users (name) VALUES ('Alice')")
#expect(result == .mutation(.insert))
#expect(!result.requiresConfirmation)
#expect(result.isMutating)
}
@Test("UPDATE is classified as mutation")
func updateIsMutation() {
let result = classifySQL("UPDATE users SET name = 'Bob' WHERE id = 1")
#expect(result == .mutation(.update))
#expect(!result.requiresConfirmation)
#expect(result.isMutating)
}
// MARK: - Destructive statements
@Test("DELETE is classified as destructive")
func deleteIsDestructive() {
let result = classifySQL("DELETE FROM users WHERE id = 1")
#expect(result == .destructive(.delete))
#expect(result.requiresConfirmation)
#expect(result.isMutating)
}
@Test("DROP is classified as destructive")
func dropIsDestructive() {
let result = classifySQL("DROP TABLE users")
#expect(result == .destructive(.drop))
#expect(result.requiresConfirmation)
}
@Test("ALTER is classified as destructive")
func alterIsDestructive() {
let result = classifySQL("ALTER TABLE users ADD COLUMN age INTEGER")
#expect(result == .destructive(.alter))
#expect(result.requiresConfirmation)
}
@Test("TRUNCATE is classified as destructive")
func truncateIsDestructive() {
let result = classifySQL("TRUNCATE TABLE users")
#expect(result == .destructive(.truncate))
#expect(result.requiresConfirmation)
}
// MARK: - Case insensitivity
@Test("Classification is case-insensitive")
func caseInsensitive() {
#expect(classifySQL("delete from users") == .destructive(.delete))
#expect(classifySQL("Drop Table foo") == .destructive(.drop))
#expect(classifySQL("select 1") == .safe)
#expect(classifySQL("INSERT into t values (1)") == .mutation(.insert))
}
// MARK: - Leading whitespace
@Test("Classification ignores leading whitespace")
func leadingWhitespace() {
#expect(classifySQL(" \n DELETE FROM users") == .destructive(.delete))
#expect(classifySQL("\t SELECT 1") == .safe)
}
// MARK: - SQLStatementKind
@Test("Destructive kinds are correct")
func destructiveKinds() {
#expect(SQLStatementKind.delete.isDestructive)
#expect(SQLStatementKind.drop.isDestructive)
#expect(SQLStatementKind.alter.isDestructive)
#expect(SQLStatementKind.truncate.isDestructive)
#expect(!SQLStatementKind.select.isDestructive)
#expect(!SQLStatementKind.insert.isDestructive)
#expect(!SQLStatementKind.update.isDestructive)
}
@Test("Mutation kinds are correct")
func mutationKinds() {
#expect(SQLStatementKind.insert.isMutation)
#expect(SQLStatementKind.update.isMutation)
#expect(!SQLStatementKind.select.isMutation)
#expect(!SQLStatementKind.delete.isMutation)
}
}
@Suite("ToolExecutionDelegate")
struct ToolExecutionDelegateProtocolTests {
@Test("AutoApproveDelegate approves all operations")
func autoApprove() async {
let delegate = AutoApproveDelegate()
let context = DestructiveOperationContext(
sql: "DELETE FROM users",
statementKind: .delete,
classification: .destructive(.delete),
description: "Delete all rows from users"
)
let result = await delegate.confirmDestructiveOperation(context)
#expect(result == true)
}
@Test("RejectAllDelegate rejects all operations")
func rejectAll() async {
let delegate = RejectAllDelegate()
let context = DestructiveOperationContext(
sql: "DROP TABLE users",
statementKind: .drop,
classification: .destructive(.drop),
description: "Drop the users table"
)
let result = await delegate.confirmDestructiveOperation(context)
#expect(result == false)
}
@Test("Default delegate implementation rejects destructive operations")
func defaultRejects() async {
struct EmptyDelegate: ToolExecutionDelegate {}
let delegate = EmptyDelegate()
let context = DestructiveOperationContext(
sql: "DELETE FROM users",
statementKind: .delete,
classification: .destructive(.delete),
description: "Delete rows"
)
let result = await delegate.confirmDestructiveOperation(context)
#expect(result == false)
}
}
// MARK: - Tracking Delegate for Integration Tests
/// A delegate that records all calls for verification in tests.
private final class TrackingDelegate: ToolExecutionDelegate, @unchecked Sendable {
private let lock = NSLock()
private var _confirmCalls: [DestructiveOperationContext] = []
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
private var _confirmResult: Bool
var confirmCalls: [DestructiveOperationContext] {
lock.withLock { _confirmCalls }
}
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
lock.withLock { _willExecuteCalls }
}
var didExecuteCalls: [(sql: String, success: Bool)] {
lock.withLock { _didExecuteCalls }
}
init(confirmResult: Bool) {
self._confirmResult = confirmResult
}
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
lock.withLock { _confirmCalls.append(context) }
return _confirmResult
}
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
}
func didExecuteSQL(_ sql: String, success: Bool) async {
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
}
}
@Suite("ToolExecutionDelegate - ChatEngine Integration")
struct DelegateIntegrationTests {
@Test("DestructiveOperationContext captures target table")
func contextCapturesTable() {
let context = DestructiveOperationContext(
sql: "DELETE FROM users WHERE id = 1",
statementKind: .delete,
classification: .destructive(.delete),
description: "Delete from users",
targetTable: "users"
)
#expect(context.targetTable == "users")
#expect(context.statementKind == .delete)
#expect(context.classification.requiresConfirmation)
}
@Test("classifySQL returns destructive for DELETE")
func classifySQLDestructive() {
let result = classifySQL("DELETE FROM orders WHERE id = 5")
#expect(result == .destructive(.delete))
#expect(result.requiresConfirmation)
}
@Test("classifySQL returns safe for SELECT")
func classifySQLSafe() {
let result = classifySQL("SELECT * FROM users")
#expect(result == .safe)
#expect(!result.requiresConfirmation)
}
@Test("classifySQL returns mutation for INSERT")
func classifySQLMutation() {
let result = classifySQL("INSERT INTO users (name) VALUES ('test')")
#expect(result == .mutation(.insert))
#expect(!result.requiresConfirmation)
}
@Test("DestructiveClassification.isMutating is true for mutations and destructive")
func isMutatingCovers() {
#expect(DestructiveClassification.mutation(.insert).isMutating)
#expect(DestructiveClassification.mutation(.update).isMutating)
#expect(DestructiveClassification.destructive(.delete).isMutating)
#expect(!DestructiveClassification.safe.isMutating)
}
}

View File

@@ -0,0 +1,617 @@
// UnifiedProviderTestHarness.swift
// SwiftDBAI Tests
//
// A unified test harness that validates all seven provider types
// conform to the AnyLanguageModel protocol and produce consistent
// ChatEngine-compatible output. Covers: OpenAI, Anthropic, Gemini,
// OpenAI-Compatible, Ollama, llama.cpp, and on-device (MLX/CoreML).
import AnyLanguageModel
import Foundation
import GRDB
import Testing
@testable import SwiftDBAI
// MARK: - Provider-Simulating Mock Models
/// A mock that records which LanguageModel protocol methods were called,
/// the arguments passed, and returns configurable responses.
/// Used to validate that every provider path through ChatEngine
/// exercises the same protocol surface.
final class ProviderConformanceMock: LanguageModel, @unchecked Sendable {
typealias UnavailableReason = Never
/// Track calls to verify protocol conformance exercised fully.
struct CallRecord: Sendable {
let method: String
let promptDescription: String
let timestamp: Date
}
private let lock = NSLock()
private var _calls: [CallRecord] = []
private let _responses: [String]
private var _callIndex = 0
/// Label for diagnostics.
let providerName: String
var calls: [CallRecord] {
lock.lock()
defer { lock.unlock() }
return _calls
}
init(providerName: String, responses: [String]) {
self.providerName = providerName
self._responses = responses
}
private func nextResponse() -> String {
lock.lock()
defer { lock.unlock() }
let idx = _callIndex
_callIndex += 1
return idx < _responses.count ? _responses[idx] : "fallback response"
}
private func recordCall(method: String, prompt: String) {
lock.lock()
_calls.append(CallRecord(method: method, promptDescription: prompt, timestamp: Date()))
lock.unlock()
}
func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
recordCall(method: "respond", prompt: prompt.description)
let text = nextResponse()
let rawContent = GeneratedContent(kind: .string(text))
let content = try Content(rawContent)
return LanguageModelSession.Response(
content: content,
rawContent: rawContent,
transcriptEntries: [][...]
)
}
func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
includeSchemaInPrompt: Bool,
options: GenerationOptions
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
recordCall(method: "streamResponse", prompt: prompt.description)
let text = nextResponse()
let rawContent = GeneratedContent(kind: .string(text))
let content = try! Content(rawContent)
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
}
}
// MARK: - Test Database Helper
/// Creates a minimal in-memory database for provider integration tests.
private func makeProviderTestDatabase() throws -> DatabaseQueue {
let db = try DatabaseQueue(path: ":memory:")
try db.write { db in
try db.execute(sql: """
CREATE TABLE products (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
price REAL NOT NULL,
category TEXT NOT NULL
)
""")
try db.execute(sql: """
INSERT INTO products (name, price, category) VALUES
('Widget', 9.99, 'tools'),
('Gadget', 24.99, 'electronics'),
('Doohickey', 4.50, 'tools')
""")
}
return db
}
// MARK: - Unified Provider Test Harness
@Suite("Unified Provider Test Harness")
struct UnifiedProviderTestHarness {
// MARK: - Provider Configuration Enumeration
/// All seven provider types that SwiftDBAI supports.
enum TestedProvider: String, CaseIterable {
case openAI
case anthropic
case gemini
case openAICompatible
case ollama
case llamaCpp
case onDevice
}
/// Creates a ProviderConformanceMock simulating each provider type.
private func makeMock(for provider: TestedProvider, responses: [String]) -> ProviderConformanceMock {
ProviderConformanceMock(providerName: provider.rawValue, responses: responses)
}
// MARK: - 1. Protocol Conformance All Providers Are LanguageModel
@Test("All provider types produce instances conforming to LanguageModel protocol")
func allProvidersConformToLanguageModel() {
// Cloud providers via ProviderConfiguration.makeModel()
let openAI = ProviderConfiguration.openAI(apiKey: "test-key", model: "gpt-4o").makeModel()
let anthropic = ProviderConfiguration.anthropic(apiKey: "test-key", model: "claude-sonnet-4-20250514").makeModel()
let gemini = ProviderConfiguration.gemini(apiKey: "test-key", model: "gemini-2.0-flash").makeModel()
let openAICompatible = ProviderConfiguration.openAICompatible(
apiKey: "test-key",
model: "local-model",
baseURL: URL(string: "http://localhost:8080/v1/")!
).makeModel()
let ollama = ProviderConfiguration.ollama(model: "llama3.2").makeModel()
let llamaCpp = ProviderConfiguration.llamaCpp(model: "default").makeModel()
// On-device MLX (wraps as openAICompatible internally)
let onDeviceMLX = ProviderConfiguration.onDeviceMLX(
MLXProviderConfiguration(modelId: "test-model")
).makeModel()
// Verify all are LanguageModel
let models: [(String, any LanguageModel)] = [
("OpenAI", openAI),
("Anthropic", anthropic),
("Gemini", gemini),
("OpenAI-Compatible", openAICompatible),
("Ollama", ollama),
("llama.cpp", llamaCpp),
("On-Device MLX", onDeviceMLX),
]
for (name, model) in models {
// Protocol conformance is compile-time, but we verify isAvailable works
#expect(model.isAvailable, "\(name) model should report as available")
}
}
@Test("All provider configurations produce correct concrete model types")
func providerConfigurationsProduceCorrectTypes() {
let openAI = ProviderConfiguration.openAI(apiKey: "k", model: "m").makeModel()
#expect(openAI is OpenAILanguageModel, "OpenAI config should produce OpenAILanguageModel")
let anthropic = ProviderConfiguration.anthropic(apiKey: "k", model: "m").makeModel()
#expect(anthropic is AnthropicLanguageModel, "Anthropic config should produce AnthropicLanguageModel")
let gemini = ProviderConfiguration.gemini(apiKey: "k", model: "m").makeModel()
#expect(gemini is GeminiLanguageModel, "Gemini config should produce GeminiLanguageModel")
let openAICompat = ProviderConfiguration.openAICompatible(
apiKey: "k", model: "m", baseURL: URL(string: "http://localhost:1234")!
).makeModel()
#expect(openAICompat is OpenAILanguageModel, "OpenAI-Compatible config should produce OpenAILanguageModel")
let ollama = ProviderConfiguration.ollama(model: "m").makeModel()
#expect(ollama is OllamaLanguageModel, "Ollama config should produce OllamaLanguageModel")
let llamaCpp = ProviderConfiguration.llamaCpp(model: "m").makeModel()
#expect(llamaCpp is OpenAILanguageModel, "llama.cpp config should produce OpenAILanguageModel (OpenAI-compatible)")
// On-device uses OpenAILanguageModel internally as a wrapper
let onDevice = ProviderConfiguration.onDeviceMLX(
MLXProviderConfiguration(modelId: "test")
).makeModel()
#expect(onDevice is OpenAILanguageModel, "On-device MLX config should produce OpenAILanguageModel wrapper")
}
// MARK: - 2. Consistent ChatEngine-Compatible Output
@Test("Every provider mock produces valid ChatEngine responses for SELECT queries",
arguments: TestedProvider.allCases)
func providerProducesValidChatEngineResponse(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: [
"SELECT COUNT(*) FROM products", // SQL generation
"There are 3 products in the database.", // Summary (fallback)
])
let engine = ChatEngine(database: db, model: mock)
let response = try await engine.send("How many products are there?")
// All providers must produce:
// 1. Non-empty summary
#expect(!response.summary.isEmpty, "\(provider.rawValue): summary must not be empty")
// 2. Valid SQL that was executed
#expect(response.sql == "SELECT COUNT(*) FROM products",
"\(provider.rawValue): SQL must match generated query")
// 3. A QueryResult with data
#expect(response.queryResult != nil, "\(provider.rawValue): queryResult must exist")
#expect(response.queryResult?.rowCount == 1, "\(provider.rawValue): should have 1 row for COUNT")
}
@Test("Every provider mock produces valid ChatEngine responses for multi-row SELECT",
arguments: TestedProvider.allCases)
func providerProducesMultiRowResponse(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: [
"SELECT name, price FROM products ORDER BY price DESC",
"Here are the products sorted by price.",
])
let engine = ChatEngine(database: db, model: mock)
let response = try await engine.send("List products by price")
#expect(response.queryResult != nil, "\(provider.rawValue): queryResult must exist")
#expect(response.queryResult?.rowCount == 3, "\(provider.rawValue): should return all 3 products")
#expect(response.queryResult?.columns.contains("name") == true,
"\(provider.rawValue): columns must include 'name'")
#expect(response.queryResult?.columns.contains("price") == true,
"\(provider.rawValue): columns must include 'price'")
}
// MARK: - 3. Consistent LanguageModelSession Integration
@Test("Every provider mock works through LanguageModelSession.respond(to:)",
arguments: TestedProvider.allCases)
func providerWorksWithSession(provider: TestedProvider) async throws {
let mock = makeMock(for: provider, responses: [
"SELECT 1 AS test",
])
let session = LanguageModelSession(
model: mock,
instructions: "You are a SQL assistant."
)
let response = try await session.respond(to: "Generate a test query")
// Verify the response content is the expected string
#expect(response.content == "SELECT 1 AS test",
"\(provider.rawValue): session response should match mock output")
// Verify the mock received the call
#expect(mock.calls.count == 1, "\(provider.rawValue): should have exactly 1 call")
#expect(mock.calls.first?.method == "respond",
"\(provider.rawValue): should call respond method")
}
@Test("Every provider mock works through LanguageModelSession.streamResponse(to:)",
arguments: TestedProvider.allCases)
func providerWorksWithStreamSession(provider: TestedProvider) async throws {
let mock = makeMock(for: provider, responses: [
"SELECT 42 AS answer",
])
let session = LanguageModelSession(
model: mock,
instructions: "You are a SQL assistant."
)
let stream = session.streamResponse(to: "Give me a number")
let collected = try await stream.collect()
#expect(collected.content == "SELECT 42 AS answer",
"\(provider.rawValue): stream collected response should match mock output")
#expect(mock.calls.count == 1, "\(provider.rawValue): should have exactly 1 call")
#expect(mock.calls.first?.method == "streamResponse",
"\(provider.rawValue): should call streamResponse method")
}
// MARK: - 4. Schema Introspection Works Identically Across Providers
@Test("Schema introspection returns same schema regardless of provider",
arguments: TestedProvider.allCases)
func schemaIntrospectionIsProviderAgnostic(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: ["SELECT 1"])
let engine = ChatEngine(database: db, model: mock)
let schema = try await engine.prepareSchema()
#expect(schema.tableNames.contains("products"),
"\(provider.rawValue): schema must include 'products' table")
#expect(schema.tableNames.count == 1,
"\(provider.rawValue): should have exactly 1 table")
let table = schema.tables["products"]
#expect(table != nil, "\(provider.rawValue): must find products table")
#expect(table?.columns.count == 4,
"\(provider.rawValue): products table must have 4 columns")
}
// MARK: - 5. Error Handling Consistency
@Test("All providers handle empty schema consistently",
arguments: TestedProvider.allCases)
func emptySchemaHandledConsistently(provider: TestedProvider) async throws {
let db = try DatabaseQueue(path: ":memory:")
let mock = makeMock(for: provider, responses: ["SELECT 1"])
let engine = ChatEngine(database: db, model: mock)
do {
_ = try await engine.send("Show me data")
Issue.record("\(provider.rawValue): should throw for empty schema")
} catch let error as SwiftDBAIError {
#expect(error == .emptySchema,
"\(provider.rawValue): must throw .emptySchema for database with no tables")
}
}
@Test("All providers reject disallowed SQL operations consistently",
arguments: TestedProvider.allCases)
func disallowedSQLRejectedConsistently(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: [
"DELETE FROM products WHERE id = 1",
])
// Default allowlist is readOnly (SELECT only)
let engine = ChatEngine(database: db, model: mock)
do {
_ = try await engine.send("Delete the first product")
Issue.record("\(provider.rawValue): should reject DELETE when allowlist is readOnly")
} catch {
// All providers must trigger the same error path for disallowed operations
#expect(error is SwiftDBAIError,
"\(provider.rawValue): error must be SwiftDBAIError")
}
}
// MARK: - 6. Conversation History Consistency
@Test("Conversation history works identically for all providers",
arguments: TestedProvider.allCases)
func conversationHistoryConsistent(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
// ChatEngine calls LLM for SQL generation, then TextSummaryRenderer
// may call LLM for summarization. For aggregate queries (COUNT, AVG),
// TextSummaryRenderer uses a template and skips the LLM call.
// So the mock sequence is: SQL1, SQL2 (each followed by template summary).
let mock = makeMock(for: provider, responses: [
"SELECT COUNT(*) FROM products",
"SELECT AVG(price) FROM products",
])
let engine = ChatEngine(database: db, model: mock)
_ = try await engine.send("How many products?")
_ = try await engine.send("What is the average price?")
let messages = engine.messages
#expect(messages.count == 4,
"\(provider.rawValue): should have 4 messages (2 user + 2 assistant)")
#expect(messages[0].role == .user, "\(provider.rawValue): first message should be user")
#expect(messages[1].role == .assistant, "\(provider.rawValue): second message should be assistant")
#expect(messages[2].role == .user, "\(provider.rawValue): third message should be user")
#expect(messages[3].role == .assistant, "\(provider.rawValue): fourth message should be assistant")
// Both assistant messages must have SQL
#expect(messages[1].sql != nil, "\(provider.rawValue): first response must have SQL")
#expect(messages[3].sql != nil, "\(provider.rawValue): second response must have SQL")
}
// MARK: - 7. ProviderConfiguration Roundtrip
@Test("All cloud provider configurations roundtrip through makeModel()")
func allCloudProvidersRoundtrip() {
let configs: [(String, ProviderConfiguration)] = [
("OpenAI", .openAI(apiKey: "sk-test", model: "gpt-4o")),
("OpenAI Responses", .openAI(apiKey: "sk-test", model: "gpt-4o", variant: .responses)),
("Anthropic", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514")),
("Anthropic+version", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514", apiVersion: "2024-01-01")),
("Anthropic+betas", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514", betas: ["computer-use"])),
("Gemini", .gemini(apiKey: "AIza-test", model: "gemini-2.0-flash")),
("Gemini+version", .gemini(apiKey: "AIza-test", model: "gemini-2.0-flash", apiVersion: "v1")),
("OpenAI-Compatible", .openAICompatible(
apiKey: "key", model: "model", baseURL: URL(string: "http://localhost:1234")!
)),
("Ollama", .ollama(model: "llama3.2")),
("Ollama+custom URL", .ollama(model: "qwen2.5", baseURL: URL(string: "http://192.168.1.100:11434")!)),
("llama.cpp", .llamaCpp(model: "default")),
("llama.cpp+custom", .llamaCpp(model: "my-model", baseURL: URL(string: "http://localhost:9090")!)),
]
for (name, config) in configs {
let model = config.makeModel()
#expect(model.isAvailable, "\(name): model must be available after makeModel()")
}
}
@Test("On-device provider configurations produce valid models")
func onDeviceProvidersRoundtrip() {
let mlxConfigs: [MLXProviderConfiguration] = [
.llama3_2_3B(),
.qwen2_5_coder_3B(),
.phi3_5_mini(),
MLXProviderConfiguration(modelId: "custom-model", temperature: 0.2),
]
for mlxConfig in mlxConfigs {
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
let model = providerConfig.makeModel()
#expect(model.isAvailable, "MLX model '\(mlxConfig.modelId)' must be available")
}
}
// MARK: - 8. Write Operation Allowlist Consistency
@Test("Write operations require explicit opt-in for all providers",
arguments: TestedProvider.allCases)
func writeOperationsRequireOptIn(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
// Mock returns an INSERT statement
let mock = makeMock(for: provider, responses: [
"INSERT INTO products (name, price, category) VALUES ('New', 1.00, 'misc')",
])
// readOnly allowlist (default)
let readOnlyEngine = ChatEngine(database: db, model: mock)
do {
_ = try await readOnlyEngine.send("Add a new product")
Issue.record("\(provider.rawValue): INSERT should be rejected with readOnly allowlist")
} catch {
#expect(error is SwiftDBAIError,
"\(provider.rawValue): must throw SwiftDBAIError for disallowed INSERT")
}
}
@Test("Allowed write operations work for all providers",
arguments: TestedProvider.allCases)
func allowedWriteOperationsWork(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: [
"INSERT INTO products (name, price, category) VALUES ('NewItem', 1.00, 'misc')",
"Successfully added 1 product.",
])
let engine = ChatEngine(
database: db,
model: mock,
allowlist: .standard
)
let response = try await engine.send("Add a product called NewItem")
#expect(response.sql?.uppercased().hasPrefix("INSERT") == true,
"\(provider.rawValue): SQL should be an INSERT")
}
// MARK: - 9. Response Format Consistency
@Test("ChatResponse structure is identical regardless of provider",
arguments: TestedProvider.allCases)
func responseStructureConsistent(provider: TestedProvider) async throws {
let db = try makeProviderTestDatabase()
let mock = makeMock(for: provider, responses: [
"SELECT name, price, category FROM products",
"Found 3 products across 2 categories.",
])
let engine = ChatEngine(database: db, model: mock)
let response = try await engine.send("Show all products")
// ChatResponse must always have these properties populated
#expect(response.summary.count > 0,
"\(provider.rawValue): summary must be non-empty")
#expect(response.sql != nil,
"\(provider.rawValue): sql must be present")
#expect(response.queryResult != nil,
"\(provider.rawValue): queryResult must be present")
// QueryResult structure must match the query
let qr = response.queryResult!
#expect(qr.columns == ["name", "price", "category"],
"\(provider.rawValue): columns must match SELECT clause")
#expect(qr.rowCount == 3,
"\(provider.rawValue): must return all rows")
#expect(qr.sql == "SELECT name, price, category FROM products",
"\(provider.rawValue): QueryResult.sql must match executed SQL")
#expect(qr.executionTime >= 0,
"\(provider.rawValue): execution time must be non-negative")
}
// MARK: - 10. Provider Enum Completeness
@Test("TestedProvider covers all ProviderConfiguration.Provider cases plus on-device")
func testedProviderCoversAllCases() {
// ProviderConfiguration.Provider has 6 cases
let configProviderCount = ProviderConfiguration.Provider.allCases.count
#expect(configProviderCount == 6, "ProviderConfiguration.Provider should have 6 cases")
// TestedProvider adds on-device for 7 total
#expect(TestedProvider.allCases.count == 7, "TestedProvider should cover all 7 provider types")
// Verify 1:1 mapping for the config providers
let configNames = Set(ProviderConfiguration.Provider.allCases.map(\.rawValue))
for tested in TestedProvider.allCases where tested != .onDevice {
#expect(configNames.contains(tested.rawValue),
"\(tested.rawValue) must map to a ProviderConfiguration.Provider case")
}
}
// MARK: - 11. ChatEngine Convenience Init Consistency
@Test("ChatEngine convenience init with ProviderConfiguration works for all cloud providers")
func chatEngineConvenienceInitWorks() throws {
let db = try makeProviderTestDatabase()
let configs: [ProviderConfiguration] = [
.openAI(apiKey: "test", model: "gpt-4o"),
.anthropic(apiKey: "test", model: "claude-sonnet-4-20250514"),
.gemini(apiKey: "test", model: "gemini-2.0-flash"),
.openAICompatible(apiKey: "test", model: "m", baseURL: URL(string: "http://localhost:1234")!),
.ollama(model: "llama3.2"),
.llamaCpp(model: "default"),
]
for config in configs {
// This should not throw it only creates the engine, doesn't call the LLM
let engine = ChatEngine(database: db, provider: config)
#expect(engine.tableCount == nil, "tableCount should be nil before first query")
}
}
// MARK: - 12. Availability Reporting
@Test("All real provider models report available by default")
func allModelsReportAvailable() {
let models: [(String, any LanguageModel)] = [
("OpenAI", OpenAILanguageModel(apiKey: "k", model: "m")),
("Anthropic", AnthropicLanguageModel(apiKey: "k", model: "m")),
("Gemini", GeminiLanguageModel(apiKey: "k", model: "m")),
("Ollama", OllamaLanguageModel(model: "m")),
]
for (name, model) in models {
#expect(model.isAvailable, "\(name) should be available by default")
}
}
// MARK: - 13. On-Device Pipeline Status
@Test("On-device inference pipeline starts in notLoaded state")
func onDevicePipelineInitialState() {
let mlxPipeline = OnDeviceInferencePipeline(
mlxConfiguration: .llama3_2_3B()
)
#expect(mlxPipeline.status == .notLoaded)
#expect(mlxPipeline.providerType == .mlx)
let coreMLPipeline = OnDeviceInferencePipeline(
coreMLConfiguration: CoreMLProviderConfiguration(
modelURL: URL(fileURLWithPath: "/tmp/test.mlmodelc")
)
)
#expect(coreMLPipeline.status == .notLoaded)
#expect(coreMLPipeline.providerType == .coreML)
}
@Test("On-device SQL generation hints are populated for both provider types")
func onDeviceSQLHints() {
let mlxPipeline = OnDeviceInferencePipeline(mlxConfiguration: .llama3_2_3B())
let mlxHints = mlxPipeline.recommendedSQLGenerationHints
#expect(mlxHints.maxTokens > 0)
#expect(mlxHints.temperature >= 0)
#expect(!mlxHints.systemPromptSuffix.isEmpty)
let coreMLPipeline = OnDeviceInferencePipeline(
coreMLConfiguration: CoreMLProviderConfiguration(
modelURL: URL(fileURLWithPath: "/tmp/test.mlmodelc")
)
)
let coreMLHints = coreMLPipeline.recommendedSQLGenerationHints
#expect(coreMLHints.maxTokens > 0)
#expect(coreMLHints.temperature >= 0)
#expect(!coreMLHints.systemPromptSuffix.isEmpty)
}
}

View File

@@ -0,0 +1,489 @@
// ViewInspectorTests.swift
// SwiftDBAITests
//
// ViewInspector-based tests for SwiftDBAI's SwiftUI views.
// Tests content and structure of MessageBubbleView, ErrorMessageView,
// ScrollableDataTableView, ChatViewConfiguration, and BarChartView.
import SwiftUI
import Testing
import ViewInspector
@testable import SwiftDBAI
// MARK: - Test Helpers
/// Helper to build a DataTable for tests.
private func makeDataTable(
columnNames: [String] = ["id", "name", "score"],
inferredTypes: [DataTable.InferredType] = [.integer, .text, .real],
rowCount: Int = 3
) -> DataTable {
let columns = columnNames.enumerated().map { idx, name in
DataTable.Column(name: name, index: idx, inferredType: inferredTypes[idx])
}
let rows = (0..<rowCount).map { i in
DataTable.Row(
id: i,
values: [
.integer(Int64(i + 1)),
.text("Item \(i + 1)"),
.real(Double(i) * 10.5),
],
columnNames: columnNames
)
}
return DataTable(columns: columns, rows: rows, sql: "SELECT * FROM test", executionTime: 0.015)
}
/// Helper to build a QueryResult for tests.
private func makeQueryResult(
columns: [String] = ["id", "name"],
rowCount: Int = 2
) -> QueryResult {
let rows: [[String: QueryResult.Value]] = (0..<rowCount).map { i in
["id": .integer(Int64(i + 1)), "name": .text("User \(i + 1)")]
}
return QueryResult(
columns: columns,
rows: rows,
sql: "SELECT id, name FROM users",
executionTime: 0.01
)
}
// MARK: - MessageBubbleView Tests
@Suite("MessageBubbleView - ViewInspector")
struct MessageBubbleViewInspectorTests {
@Test("User message bubble renders the user text")
@MainActor
func userMessageShowsText() throws {
let message = ChatMessage(role: .user, content: "Show me all users")
let view = MessageBubbleView(message: message)
let inspected = try view.inspect()
let found = try inspected.find(text: "Show me all users")
#expect(try found.string() == "Show me all users")
}
@Test("Assistant message renders summary text")
@MainActor
func assistantMessageShowsSummary() throws {
let message = ChatMessage(
role: .assistant,
content: "Found 42 users in the database."
)
let view = MessageBubbleView(message: message)
let inspected = try view.inspect()
let found = try inspected.find(text: "Found 42 users in the database.")
#expect(try found.string() == "Found 42 users in the database.")
}
@Test("Assistant message with SQL shows disclosure group")
@MainActor
func assistantMessageWithSQLShowsDisclosure() throws {
let message = ChatMessage(
role: .assistant,
content: "Here are the results.",
sql: "SELECT * FROM users"
)
let view = MessageBubbleView(message: message)
let inspected = try view.inspect()
// The SQL disclosure contains "SQL Query" label text
let sqlLabel = try inspected.find(text: "SQL Query")
#expect(try sqlLabel.string() == "SQL Query")
}
@Test("Error message renders error text")
@MainActor
func errorMessageShowsText() throws {
let error = SwiftDBAIError.databaseError(reason: "connection lost")
let message = ChatMessage(
role: .error,
content: error.localizedDescription,
error: error
)
let view = MessageBubbleView(message: message)
let inspected = try view.inspect()
// The error message text should be present
let found = try inspected.find(text: error.localizedDescription)
#expect(try found.string() == error.localizedDescription)
}
}
// MARK: - ErrorMessageView Tests
@Suite("ErrorMessageView - ViewInspector")
struct ErrorMessageViewInspectorTests {
@Test("Safety error shows Operation Blocked title")
@MainActor
func safetyErrorShowsTitle() throws {
let error = SwiftDBAIError.dangerousOperationBlocked(keyword: "DROP")
let view = ErrorMessageView(error: error)
let inspected = try view.inspect()
let title = try inspected.find(text: "Operation Blocked")
#expect(try title.string() == "Operation Blocked")
}
@Test("Safety error shows error message")
@MainActor
func safetyErrorShowsMessage() throws {
let error = SwiftDBAIError.operationNotAllowed(operation: "DELETE")
let view = ErrorMessageView(error: error)
let inspected = try view.inspect()
let msg = try inspected.find(text: error.localizedDescription)
#expect(try msg.string() == error.localizedDescription)
}
@Test("LLM response unparseable error shows recovery hint")
@MainActor
func parsingErrorShowsRecoveryHint() throws {
let error = SwiftDBAIError.llmResponseUnparseable(response: "gibberish")
let view = ErrorMessageView(error: error)
let inspected = try view.inspect()
let hint = try inspected.find(text: "Try rephrasing your question.")
#expect(try hint.string() == "Try rephrasing your question.")
}
@Test("Database error shows Database Error title")
@MainActor
func databaseErrorShowsTitle() throws {
let error = SwiftDBAIError.databaseError(reason: "disk full")
let view = ErrorMessageView(error: error)
let inspected = try view.inspect()
let title = try inspected.find(text: "Database Error")
#expect(try title.string() == "Database Error")
}
@Test("LLM timeout shows AI Provider Error title and recovery hint")
@MainActor
func timeoutErrorShowsTitleAndHint() throws {
let error = SwiftDBAIError.llmTimeout(seconds: 30)
let view = ErrorMessageView(error: error)
let inspected = try view.inspect()
let title = try inspected.find(text: "AI Provider Error")
#expect(try title.string() == "AI Provider Error")
let hint = try inspected.find(text: "The AI took too long. Try a simpler question.")
#expect(try hint.string() == "The AI took too long. Try a simpler question.")
}
@Test("LLM failure error shows AI Provider Error title")
@MainActor
func llmFailureShowsTitle() throws {
let error = SwiftDBAIError.llmFailure(reason: "rate limited")
let view = ErrorMessageView(error: error)
let inspected = try view.inspect()
let title = try inspected.find(text: "AI Provider Error")
#expect(try title.string() == "AI Provider Error")
}
@Test("Generic error from plain string shows message text")
@MainActor
func genericStringErrorShowsMessage() throws {
let view = ErrorMessageView(message: "Something went wrong")
let inspected = try view.inspect()
let msg = try inspected.find(text: "Something went wrong")
#expect(try msg.string() == "Something went wrong")
}
@Test("Recoverable error with retry shows retry button")
@MainActor
func recoverableErrorShowsRetryButton() throws {
let error = SwiftDBAIError.noSQLGenerated
let view = ErrorMessageView(error: error, onRetry: { })
let inspected = try view.inspect()
let button = try inspected.find(text: "Try Again")
#expect(try button.string() == "Try Again")
}
@Test("LLM error with retry shows Retry button")
@MainActor
func llmErrorShowsRetryButton() throws {
let error = SwiftDBAIError.llmFailure(reason: "timeout")
let view = ErrorMessageView(error: error, onRetry: { })
let inspected = try view.inspect()
let button = try inspected.find(text: "Retry")
#expect(try button.string() == "Retry")
}
@Test("Query timed out shows Database Error title and recovery hint")
@MainActor
func queryTimedOutShowsTitleAndHint() throws {
let error = SwiftDBAIError.queryTimedOut(seconds: 10)
let view = ErrorMessageView(error: error)
let inspected = try view.inspect()
let title = try inspected.find(text: "Database Error")
#expect(try title.string() == "Database Error")
let hint = try inspected.find(text: "Try a simpler query or add database indexes.")
#expect(try hint.string() == "Try a simpler query or add database indexes.")
}
@Test("Empty schema error shows Database Error title and recovery hint")
@MainActor
func emptySchemaShowsTitleAndHint() throws {
let error = SwiftDBAIError.emptySchema
let view = ErrorMessageView(error: error)
let inspected = try view.inspect()
let title = try inspected.find(text: "Database Error")
#expect(try title.string() == "Database Error")
let hint = try inspected.find(text: "Add some tables to your database first.")
#expect(try hint.string() == "Add some tables to your database first.")
}
@Test("Configuration error shows Configuration Error title")
@MainActor
func configurationErrorShowsTitle() throws {
let error = SwiftDBAIError.configurationError(reason: "missing API key")
let view = ErrorMessageView(error: error)
let inspected = try view.inspect()
let title = try inspected.find(text: "Configuration Error")
#expect(try title.string() == "Configuration Error")
}
}
// MARK: - ChatViewConfiguration Tests
@Suite("ChatViewConfiguration - ViewInspector")
struct ChatViewConfigurationInspectorTests {
@Test("Dark configuration has expected color values")
func darkConfigHasCorrectColors() {
let dark = ChatViewConfiguration.dark
#expect(dark.userTextColor == .white)
#expect(dark.backgroundColor == .black)
#expect(dark.accentColor == .blue)
}
@Test("Default configuration has expected placeholder and empty state text")
func defaultConfigHasExpectedText() {
let config = ChatViewConfiguration.default
#expect(config.inputPlaceholder == "Ask about your data\u{2026}")
#expect(config.emptyStateTitle == "Ask a question about your data")
#expect(config.emptyStateSubtitle == "Try something like \"How many records are in the database?\"")
}
@Test("Custom inputPlaceholder propagates through environment")
@MainActor
func customPlaceholderInEnvironment() throws {
var config = ChatViewConfiguration.default
config.inputPlaceholder = "Ask about recipes..."
config.emptyStateTitle = "Recipe Search"
// Verify the configuration values are set correctly
#expect(config.inputPlaceholder == "Ask about recipes...")
#expect(config.emptyStateTitle == "Recipe Search")
}
@Test("Compact configuration has smaller padding and hidden SQL disclosure")
func compactConfigProperties() {
let compact = ChatViewConfiguration.compact
#expect(compact.messagePadding == 8)
#expect(compact.bubbleCornerRadius == 10)
#expect(compact.showSQLDisclosure == false)
#expect(compact.showTimestamps == false)
}
@Test("Dark configuration userBubbleColor is dark gray")
func darkConfigUserBubble() {
let dark = ChatViewConfiguration.dark
// Dark config uses Color(white: 0.25) for user bubble
#expect(dark.userBubbleColor == Color(white: 0.25))
#expect(dark.assistantBubbleColor == Color(white: 0.15))
#expect(dark.inputBarBackgroundColor == Color(white: 0.1))
}
@Test("ErrorMessageView uses environment config for database error color")
@MainActor
func errorViewUsesDarkConfig() throws {
let error = SwiftDBAIError.databaseError(reason: "test error")
let view = ErrorMessageView(error: error)
.chatViewConfiguration(.dark)
let inspected = try view.inspect()
// Should still render the error message text
let msg = try inspected.find(text: error.localizedDescription)
#expect(try msg.string() == error.localizedDescription)
}
}
// MARK: - ScrollableDataTableView Tests
@Suite("ScrollableDataTableView - ViewInspector")
struct ScrollableDataTableViewInspectorTests {
@Test("Column headers appear in the view")
@MainActor
func columnHeadersAppear() throws {
let table = makeDataTable()
let view = ScrollableDataTableView(dataTable: table)
let inspected = try view.inspect()
// Each column header should be present
let idHeader = try inspected.find(text: "id")
#expect(try idHeader.string() == "id")
let nameHeader = try inspected.find(text: "name")
#expect(try nameHeader.string() == "name")
let scoreHeader = try inspected.find(text: "score")
#expect(try scoreHeader.string() == "score")
}
@Test("Row count text appears in footer")
@MainActor
func rowCountFooterAppears() throws {
let table = makeDataTable(rowCount: 5)
let view = ScrollableDataTableView(dataTable: table, showFooter: true)
let inspected = try view.inspect()
let footer = try inspected.find(text: "5 rows")
#expect(try footer.string() == "5 rows")
}
@Test("Single row shows singular 'row' text")
@MainActor
func singleRowFooter() throws {
let table = makeDataTable(rowCount: 1)
let view = ScrollableDataTableView(dataTable: table, showFooter: true)
let inspected = try view.inspect()
let footer = try inspected.find(text: "1 row")
#expect(try footer.string() == "1 row")
}
@Test("Empty table shows No results text")
@MainActor
func emptyTableShowsNoResults() throws {
let table = DataTable(columns: [], rows: [], sql: "", executionTime: 0)
let view = ScrollableDataTableView(dataTable: table)
let inspected = try view.inspect()
let empty = try inspected.find(text: "No results")
#expect(try empty.string() == "No results")
}
@Test("Execution time appears in footer when > 0")
@MainActor
func executionTimeAppearsInFooter() throws {
let columns = [DataTable.Column(name: "val", index: 0, inferredType: .integer)]
let rows = [DataTable.Row(id: 0, values: [.integer(1)], columnNames: ["val"])]
let table = DataTable(columns: columns, rows: rows, sql: "SELECT 1", executionTime: 0.023)
let view = ScrollableDataTableView(dataTable: table, showFooter: true)
let inspected = try view.inspect()
let timing = try inspected.find(text: "23.0 ms")
#expect(try timing.string() == "23.0 ms")
}
}
// MARK: - BarChartView Tests
@Suite("BarChartView - ViewInspector")
struct BarChartViewInspectorTests {
@Test("BarChartView with title renders the title text")
@MainActor
func barChartShowsTitle() throws {
let columns: [DataTable.Column] = [
.init(name: "dept", index: 0, inferredType: .text),
.init(name: "revenue", index: 1, inferredType: .real),
]
let rows: [DataTable.Row] = [
.init(id: 0, values: [.text("Sales"), .real(100.0)], columnNames: ["dept", "revenue"]),
.init(id: 1, values: [.text("Eng"), .real(200.0)], columnNames: ["dept", "revenue"]),
]
let table = DataTable(columns: columns, rows: rows)
let view = BarChartView(
dataTable: table,
categoryColumn: "dept",
valueColumn: "revenue",
title: "Revenue by Department"
)
let inspected = try view.inspect()
let title = try inspected.find(text: "Revenue by Department")
#expect(try title.string() == "Revenue by Department")
}
@Test("BarChartView with empty data shows empty state")
@MainActor
func barChartEmptyState() throws {
let table = DataTable(columns: [], rows: [])
let view = BarChartView(
dataTable: table,
categoryColumn: "x",
valueColumn: "y"
)
let inspected = try view.inspect()
let empty = try inspected.find(text: "No chartable data")
#expect(try empty.string() == "No chartable data")
}
@Test("BarChartView with truncated data shows truncation notice")
@MainActor
func barChartTruncationNotice() throws {
let columns: [DataTable.Column] = [
.init(name: "cat", index: 0, inferredType: .text),
.init(name: "val", index: 1, inferredType: .real),
]
// Create 10 rows but set maxBars to 3
let rows: [DataTable.Row] = (0..<10).map { i in
.init(id: i, values: [.text("Cat \(i)"), .real(Double(i) * 10)], columnNames: ["cat", "val"])
}
let table = DataTable(columns: columns, rows: rows)
let view = BarChartView(
dataTable: table,
categoryColumn: "cat",
valueColumn: "val",
maxBars: 3
)
let inspected = try view.inspect()
let notice = try inspected.find(text: "Showing 3 of 10 categories")
#expect(try notice.string() == "Showing 3 of 10 categories")
}
}
// MARK: - PieChartView Tests
@Suite("PieChartView - ViewInspector")
struct PieChartViewInspectorTests {
@Test("PieChartView with title renders the title text")
@MainActor
func pieChartShowsTitle() throws {
let columns: [DataTable.Column] = [
.init(name: "status", index: 0, inferredType: .text),
.init(name: "count", index: 1, inferredType: .integer),
]
let rows: [DataTable.Row] = [
.init(id: 0, values: [.text("Active"), .integer(40)], columnNames: ["status", "count"]),
.init(id: 1, values: [.text("Inactive"), .integer(10)], columnNames: ["status", "count"]),
]
let table = DataTable(columns: columns, rows: rows)
let view = PieChartView(
dataTable: table,
categoryColumn: "status",
valueColumn: "count",
title: "Users by Status"
)
let inspected = try view.inspect()
let title = try inspected.find(text: "Users by Status")
#expect(try title.string() == "Users by Status")
}
@Test("PieChartView with empty data shows empty state")
@MainActor
func pieChartEmptyState() throws {
let table = DataTable(columns: [], rows: [])
let view = PieChartView(
dataTable: table,
categoryColumn: "x",
valueColumn: "y"
)
let inspected = try view.inspect()
let empty = try inspected.find(text: "No chartable data")
#expect(try empty.string() == "No chartable data")
}
}

61
screenshots/GALLERY.md Normal file
View File

@@ -0,0 +1,61 @@
# SwiftDBAI Screenshots
## Query Results
Bar chart and data table from a natural language query against a GitHub stars database.
![Results with chart](results-chart.png)
## Customization
### Custom Theme
Purple accent, custom placeholder ("Search GitHub repos..."), custom empty state icon and text.
```swift
var config = ChatViewConfiguration.default
config.userBubbleColor = .purple
config.accentColor = .purple
config.inputPlaceholder = "Search GitHub repos..."
config.emptyStateTitle = "Explore GitHub Data"
config.emptyStateIcon = "star.circle"
DataChatView(databasePath: path, model: myLLM)
.chatViewConfiguration(config)
```
| Empty state | With results |
|---|---|
| ![Custom empty](custom-theme.png) | ![Custom results](custom-results.png) |
### Dark Theme
```swift
DataChatView(databasePath: path, model: myLLM)
.chatViewConfiguration(.dark)
```
![Dark theme](dark-theme.png)
### Compact Theme
```swift
DataChatView(databasePath: path, model: myLLM)
.chatViewConfiguration(.compact)
```
![Compact theme](compact-theme.png)
## Presentation Modes
![Presentation modes](presentation-modes.png)
### Sheet
```swift
.sheet(isPresented: $showChat) {
DataChatSheet(databasePath: path, model: myLLM, title: "GitHub Stars")
}
```
![Sheet presentation](sheet-presentation.png)

Binary file not shown.

After

Width:  |  Height:  |  Size: 28 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 128 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

BIN
screenshots/dark-theme.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 69 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 127 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 77 KiB

BIN
screenshots/tool-api.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 144 KiB