Initial implementation of SwiftDBAI
Chat with any SQLite database using natural language. Built on AnyLanguageModel (HuggingFace) for LLM-agnostic provider support and GRDB for SQLite access. Core features: - Auto schema introspection from sqlite_master (zero config) - NL → SQL generation via any AnyLanguageModel provider - Three rendering modes: text summary, data table, Swift Charts - Drop-in DataChatView (SwiftUI) and headless ChatEngine - Operation allowlist with read-only default - Mutation policy with per-table control - ToolExecutionDelegate for destructive operation confirmation - Multi-turn conversation context - 352 tests across 24 suites, all passing Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
.build/
|
||||||
|
.swiftpm/
|
||||||
|
Package.resolved
|
||||||
|
*.xcodeproj/
|
||||||
|
xcuserdata/
|
||||||
|
DerivedData/
|
||||||
|
.DS_Store
|
||||||
46
CLAUDE.md
Normal file
46
CLAUDE.md
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
<!-- ooo:START -->
|
||||||
|
<!-- ooo:VERSION:0.14.0 -->
|
||||||
|
# Ouroboros — Specification-First AI Development
|
||||||
|
|
||||||
|
> Before telling AI what to build, define what should be built.
|
||||||
|
> As Socrates asked 2,500 years ago — "What do you truly know?"
|
||||||
|
> Ouroboros turns that question into an evolutionary AI workflow engine.
|
||||||
|
|
||||||
|
Most AI coding fails at the input, not the output. Ouroboros fixes this by
|
||||||
|
**exposing hidden assumptions before any code is written**.
|
||||||
|
|
||||||
|
1. **Socratic Clarity** — Question until ambiguity ≤ 0.2
|
||||||
|
2. **Ontological Precision** — Solve the root problem, not symptoms
|
||||||
|
3. **Evolutionary Loops** — Each evaluation cycle feeds back into better specs
|
||||||
|
|
||||||
|
```
|
||||||
|
Interview → Seed → Execute → Evaluate
|
||||||
|
↑ ↓
|
||||||
|
└─── Evolutionary Loop ─────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## ooo Commands
|
||||||
|
|
||||||
|
Each command loads its agent/MCP on-demand. Details in each skill file.
|
||||||
|
|
||||||
|
| Command | Loads |
|
||||||
|
|---------|-------|
|
||||||
|
| `ooo` | — |
|
||||||
|
| `ooo interview` | `ouroboros:socratic-interviewer` |
|
||||||
|
| `ooo seed` | `ouroboros:seed-architect` |
|
||||||
|
| `ooo run` | MCP required |
|
||||||
|
| `ooo evolve` | MCP: `evolve_step` |
|
||||||
|
| `ooo evaluate` | `ouroboros:evaluator` |
|
||||||
|
| `ooo unstuck` | `ouroboros:{persona}` |
|
||||||
|
| `ooo status` | MCP: `session_status` |
|
||||||
|
| `ooo setup` | — |
|
||||||
|
| `ooo help` | — |
|
||||||
|
|
||||||
|
## Agents
|
||||||
|
|
||||||
|
Loaded on-demand — not preloaded.
|
||||||
|
|
||||||
|
**Core**: socratic-interviewer, ontologist, seed-architect, evaluator,
|
||||||
|
wonder, reflect, advocate, contrarian, judge
|
||||||
|
**Support**: hacker, simplifier, researcher, architect
|
||||||
|
<!-- ooo:END -->
|
||||||
468
PRD.md
Normal file
468
PRD.md
Normal file
@@ -0,0 +1,468 @@
|
|||||||
|
# SwiftDBAI — Product Requirements Document
|
||||||
|
|
||||||
|
> **SwiftDBAI** is the umbrella name for AI-powered SQLite database tooling. v1 ships `SwiftDBAI` (chat + SQL engine). Future versions may add `SwiftDBAIMCP` (MCP server mode).
|
||||||
|
|
||||||
|
**Version:** 0.2 (Revised — post-pivot from SwiftDataAI)
|
||||||
|
**Date:** 2026-04-04
|
||||||
|
**Author:** Krishna Kumar
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Problem Statement
|
||||||
|
|
||||||
|
Developers building apps with SQLite databases have no natural-language interface to query, explore, or mutate their data. Debugging, prototyping, and building AI-powered features all require hand-writing SQL — even for simple questions like "show me all overdue tasks" or "how many users signed up this week."
|
||||||
|
|
||||||
|
There is no drop-in Swift package that lets a user (or an LLM) **chat with any SQLite database** using plain English.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Vision
|
||||||
|
|
||||||
|
**SwiftDBAI** is a Swift package that gives any SQLite-backed app a conversational interface to its data. Developers embed it in minutes; end users ask questions and get answers from their own data.
|
||||||
|
|
||||||
|
The data layer is **all SQL via GRDB** — no SwiftData APIs, no `#Predicate`, no `FetchDescriptor`. SwiftDBAI works with **any SQLite database**, not just SwiftData stores. Schema discovery is automatic via `sqlite_master` introspection — zero configuration required. The developer passes their own GRDB `DatabasePool` or `DatabaseQueue`; SwiftDBAI never manages the connection lifecycle.
|
||||||
|
|
||||||
|
Built on [**AnyLanguageModel**](https://github.com/huggingface/AnyLanguageModel) from Hugging Face — a unified Swift LLM abstraction that supports OpenAI, Anthropic, Gemini, Ollama, CoreML, MLX, and llama.cpp through a single API. SwiftDBAI generates SQL from natural language, validates it against a developer-configured operation allowlist, executes it via GRDB, and renders results as text, data tables, or Swift Charts.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Target Users
|
||||||
|
|
||||||
|
| Persona | Need |
|
||||||
|
|---|---|
|
||||||
|
| **iOS/macOS Developer** | Drop-in chat UI + engine to add "talk to your data" features to any SQLite-backed app without building NLP pipelines |
|
||||||
|
| **AI/LLM App Builder** | SQL generation layer that lets an LLM read/write any SQLite database through validated, allowlisted operations |
|
||||||
|
| **Power User / Debugger** | In-app console to inspect and mutate SQLite data during development |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Goals & Non-Goals
|
||||||
|
|
||||||
|
### Goals
|
||||||
|
- Natural-language querying of any SQLite database via GRDB
|
||||||
|
- **LLM-agnostic** via [AnyLanguageModel](https://github.com/huggingface/AnyLanguageModel) — works with OpenAI, Anthropic, Gemini, Ollama, CoreML, MLX, llama.cpp out of the box
|
||||||
|
- Drop-in SwiftUI chat view that "just works" with zero configuration — provide a database path and a model
|
||||||
|
- Schema-aware — automatically introspects tables, columns, types, primary keys, foreign keys, and indexes from `sqlite_master`
|
||||||
|
- Read **and** write support (SELECT, INSERT, UPDATE, DELETE) with developer-configured operation allowlist and confirmation guards
|
||||||
|
- All SQL validation via allowlist check — no SQL parser for safety, no `#Predicate` generation
|
||||||
|
- UI rendering: text summaries + scrollable data tables + Swift Charts (bar, line, pie) — all in v1
|
||||||
|
- Swift 6 concurrency safe, structured concurrency throughout (Swift 6.1 language mode)
|
||||||
|
- Works on iOS 17+, macOS 14+, visionOS 1+
|
||||||
|
|
||||||
|
### Non-Goals (v1)
|
||||||
|
- ~~Replacing Core Data~~ Not tied to any ORM — works with raw SQLite
|
||||||
|
- Building a general-purpose chat framework (data-scoped only)
|
||||||
|
- Full SQL parsing for safety (allowlist check is sufficient)
|
||||||
|
- Training or fine-tuning models
|
||||||
|
- Cloud sync of chat history
|
||||||
|
- Managing database connections (developer owns the GRDB connection)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Architecture Overview
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────┐
|
||||||
|
│ SwiftDBAI │
|
||||||
|
├──────────┬───────────┬──────────────┬────────────────────┤
|
||||||
|
│ Chat UI │ Engine │ Schema │ SQL Pipeline │
|
||||||
|
│ (SwiftUI)│ │ Introspector │ │
|
||||||
|
└────┬─────┴─────┬─────┴──────┬───────┴───────┬────────────┘
|
||||||
|
│ │ │ │
|
||||||
|
▼ ▼ ▼ ▼
|
||||||
|
ChatView ChatEngine sqlite_master SQLQueryParser
|
||||||
|
DataChat PromptBuilder PRAGMA OperationAllowlist
|
||||||
|
View TextSummary table_info MutationPolicy
|
||||||
|
Renderer foreign_keys QueryValidator
|
||||||
|
index_list
|
||||||
|
┌──────────────────────────────────────────────────────────┐
|
||||||
|
│ Rendering Layer │
|
||||||
|
├──────────┬──────────────┬────────────────────────────────┤
|
||||||
|
│ Text │ DataTable │ Swift Charts │
|
||||||
|
│ Summary │ (scrollable)│ (Bar, Line, Pie) │
|
||||||
|
└──────────┴──────────────┴────────────────────────────────┘
|
||||||
|
┌──────────────────────────────────────────────────────────┐
|
||||||
|
│ GRDB.swift 7.0+ │
|
||||||
|
│ DatabasePool / DatabaseQueue │
|
||||||
|
└──────────────────────────────────────────────────────────┘
|
||||||
|
┌──────────────────────────────────────────────────────────┐
|
||||||
|
│ AnyLanguageModel (HuggingFace) │
|
||||||
|
├──────┬──────┬────────┬───────┬───────┬──────┬────────────┤
|
||||||
|
│OpenAI│Claude│ Gemini │Ollama │CoreML │ MLX │ llama.cpp │
|
||||||
|
└──────┴──────┴────────┴───────┴───────┴──────┴────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.1 Core Modules
|
||||||
|
|
||||||
|
| Module | Responsibility |
|
||||||
|
|---|---|
|
||||||
|
| **SchemaIntrospector** | Queries `sqlite_master`, `PRAGMA table_info`, `PRAGMA foreign_key_list`, and `PRAGMA index_list` to auto-discover all tables, columns (name, type, nullability, defaults), primary keys, foreign keys, and indexes. Produces a `DatabaseSchema` model the LLM uses as context. Zero configuration — no annotations or model definitions needed. |
|
||||||
|
| **SQLQueryParser** | Extracts SQL from the raw LLM response, detects the operation type (SELECT/INSERT/UPDATE/DELETE), validates it against the `OperationAllowlist`, enforces `MutationPolicy` table restrictions, and flags destructive operations that require confirmation. |
|
||||||
|
| **OperationAllowlist** | Developer-configured set of permitted SQL operations. Presets: `.readOnly` (SELECT only, the default), `.standard` (SELECT + INSERT + UPDATE), `.unrestricted` (all including DELETE). |
|
||||||
|
| **MutationPolicy** | Builds on `OperationAllowlist` with per-table restrictions. Controls which mutations are allowed on which tables. DELETE requires confirmation by default. |
|
||||||
|
| **QueryValidator** | Extensible protocol for custom pre-execution validation rules (e.g., `TableAllowlistValidator`, `MaxRowLimitValidator`). Developers implement `QueryValidator` to add domain-specific checks. |
|
||||||
|
| **ChatEngine** | Orchestrates the full pipeline: schema introspection (once, lazily) -> system prompt with schema context -> LLM generates SQL -> `SQLQueryParser` validates -> GRDB executes -> `TextSummaryRenderer` summarizes -> response. Supports multi-turn conversation with configurable context window. |
|
||||||
|
| **PromptBuilder** | Constructs the LLM system prompt including the introspected schema description, allowlist rules, and optional developer-provided context. |
|
||||||
|
| **TextSummaryRenderer** | Uses the LLM to generate natural-language summaries of query results. Configurable max rows for summarization. |
|
||||||
|
| **ChatView / DataChatView** | Drop-in SwiftUI views. `DataChatView` is the zero-config entry point (database path + model). `ChatView` accepts a `ChatViewModel` for full control. Renders message bubbles, scrollable data tables, Swift Charts (bar/line/pie via `ChartDataDetector`), and error states. |
|
||||||
|
|
||||||
|
### 5.2 Data Flow
|
||||||
|
|
||||||
|
```
|
||||||
|
User types: "Show me all tasks due this week"
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
ChatEngine ensures schema is introspected (via SchemaIntrospector)
|
||||||
|
- Queries sqlite_master, PRAGMA table_info, foreign_key_list, index_list
|
||||||
|
- Caches DatabaseSchema for subsequent queries
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
PromptBuilder constructs system prompt with:
|
||||||
|
- Full schema description (tables, columns, types, keys, indexes)
|
||||||
|
- OperationAllowlist rules
|
||||||
|
- Optional developer context
|
||||||
|
- Conversation history (within context window)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
LanguageModelSession.respond(to: userMessage)
|
||||||
|
→ AnyLanguageModel routes to configured provider (OpenAI / Anthropic / Ollama / ...)
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
LLM returns raw SQL: "SELECT * FROM tasks WHERE dueDate >= date('now', 'weekday 0', '-7 days') ORDER BY dueDate ASC"
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
SQLQueryParser:
|
||||||
|
1. Extracts SQL from LLM response (strips markdown fences, etc.)
|
||||||
|
2. Detects operation type → SELECT
|
||||||
|
3. Validates against OperationAllowlist → allowed
|
||||||
|
4. Checks MutationPolicy table restrictions (if applicable)
|
||||||
|
5. Runs custom QueryValidators
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
GRDB executes SQL via DatabasePool/DatabaseQueue
|
||||||
|
→ Returns rows as [[String: Value]] with column names
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
TextSummaryRenderer asks LLM to summarize results in natural language
|
||||||
|
ChartDataDetector checks if results are chart-eligible
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
ChatView renders: text summary + scrollable DataTable + Swift Charts (if applicable)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Key APIs (Implemented)
|
||||||
|
|
||||||
|
### 6.1 Setup (Minimal — Zero Config)
|
||||||
|
|
||||||
|
```swift
|
||||||
|
import SwiftDBAI
|
||||||
|
import AnyLanguageModel
|
||||||
|
|
||||||
|
struct ContentView: View {
|
||||||
|
var body: some View {
|
||||||
|
// Just a database path and a model — that's it
|
||||||
|
DataChatView(
|
||||||
|
databasePath: "/path/to/mydata.sqlite",
|
||||||
|
model: OllamaLanguageModel(model: "llama3")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 Choosing a Provider (via AnyLanguageModel)
|
||||||
|
|
||||||
|
```swift
|
||||||
|
import AnyLanguageModel
|
||||||
|
|
||||||
|
// OpenAI
|
||||||
|
let model = OpenAILanguageModel(apiKey: "sk-...", model: "gpt-4o")
|
||||||
|
|
||||||
|
// Anthropic
|
||||||
|
let model = AnthropicLanguageModel(apiKey: "sk-ant-...", model: "claude-sonnet-4-20250514")
|
||||||
|
|
||||||
|
// Ollama (local)
|
||||||
|
let model = OllamaLanguageModel(model: "llama3")
|
||||||
|
|
||||||
|
// Gemini
|
||||||
|
let model = GeminiLanguageModel(apiKey: "...", model: "gemini-2.0-flash")
|
||||||
|
|
||||||
|
// Pass to DataChatView with options
|
||||||
|
DataChatView(
|
||||||
|
databasePath: "/path/to/db.sqlite",
|
||||||
|
model: model,
|
||||||
|
allowlist: .standard,
|
||||||
|
additionalContext: "This database stores a recipe app's data."
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.3 Bringing Your Own GRDB Connection
|
||||||
|
|
||||||
|
```swift
|
||||||
|
import GRDB
|
||||||
|
import SwiftDBAI
|
||||||
|
|
||||||
|
// Developer manages their own connection
|
||||||
|
let dbPool = try DatabasePool(path: "/path/to/mydata.sqlite")
|
||||||
|
|
||||||
|
// Option A: DataChatView with existing connection
|
||||||
|
DataChatView(
|
||||||
|
database: dbPool,
|
||||||
|
model: model,
|
||||||
|
allowlist: .readOnly
|
||||||
|
)
|
||||||
|
|
||||||
|
// Option B: Headless / programmatic use via ChatEngine
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: dbPool,
|
||||||
|
model: model,
|
||||||
|
allowlist: .standard
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await engine.send("How many tasks are overdue?")
|
||||||
|
print(response.summary) // "You have 12 overdue tasks."
|
||||||
|
print(response.sql) // "SELECT COUNT(*) FROM tasks WHERE dueDate < date('now')"
|
||||||
|
print(response.queryResult) // QueryResult with columns, rows, execution time
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.4 Schema Introspection (Auto — Zero Config)
|
||||||
|
|
||||||
|
```swift
|
||||||
|
// Schema is introspected automatically on first query.
|
||||||
|
// Or pre-warm it explicitly:
|
||||||
|
let schema = try await engine.prepareSchema()
|
||||||
|
|
||||||
|
// schema.tableNames → ["tasks", "projects", "users"]
|
||||||
|
// schema.tables["tasks"]?.columns → [ColumnSchema(name: "id", type: "INTEGER", isPrimaryKey: true), ...]
|
||||||
|
// schema.tables["tasks"]?.foreignKeys → [ForeignKeySchema(fromColumn: "projectId", toTable: "projects", ...)]
|
||||||
|
// schema.schemaDescription → Compact text for LLM prompts
|
||||||
|
|
||||||
|
// No @Model annotations, no #Predicate, no FetchDescriptor.
|
||||||
|
// Just sqlite_master + PRAGMA introspection.
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.5 Operation Allowlist (Safety)
|
||||||
|
|
||||||
|
```swift
|
||||||
|
// Presets
|
||||||
|
let readOnly = OperationAllowlist.readOnly // SELECT only (default)
|
||||||
|
let standard = OperationAllowlist.standard // SELECT + INSERT + UPDATE
|
||||||
|
let unrestricted = OperationAllowlist.unrestricted // All including DELETE
|
||||||
|
|
||||||
|
// Custom
|
||||||
|
let custom = OperationAllowlist([.select, .insert]) // Only SELECT and INSERT
|
||||||
|
|
||||||
|
// Pass to ChatEngine or DataChatView
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: dbPool,
|
||||||
|
model: model,
|
||||||
|
allowlist: .standard
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.6 Mutation Policy (Table-Level Control)
|
||||||
|
|
||||||
|
```swift
|
||||||
|
// Read-only (default)
|
||||||
|
let readOnly = MutationPolicy.readOnly
|
||||||
|
|
||||||
|
// Allow INSERT and UPDATE on specific tables only
|
||||||
|
let restricted = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update],
|
||||||
|
allowedTables: ["orders", "order_items"]
|
||||||
|
)
|
||||||
|
|
||||||
|
// Full access — DELETE requires confirmation by default
|
||||||
|
let full = MutationPolicy.unrestricted
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: dbPool,
|
||||||
|
model: model,
|
||||||
|
mutationPolicy: restricted
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.7 Custom Query Validators
|
||||||
|
|
||||||
|
```swift
|
||||||
|
// Built-in: restrict queries to specific tables
|
||||||
|
let tableValidator = TableAllowlistValidator(
|
||||||
|
allowedTables: ["tasks", "projects"]
|
||||||
|
)
|
||||||
|
|
||||||
|
// Built-in: enforce row limits on SELECT queries
|
||||||
|
let limitValidator = MaxRowLimitValidator(maxRows: 1000)
|
||||||
|
|
||||||
|
// Custom: implement QueryValidator protocol
|
||||||
|
struct NoJoinValidator: QueryValidator {
|
||||||
|
func validate(sql: String, operation: SQLOperation) throws {
|
||||||
|
if sql.uppercased().contains("JOIN") {
|
||||||
|
throw QueryValidationError.rejected("JOIN queries are not allowed.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let config = ChatEngineConfiguration(
|
||||||
|
validators: [tableValidator, limitValidator, NoJoinValidator()]
|
||||||
|
)
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: dbPool,
|
||||||
|
model: model,
|
||||||
|
allowlist: .readOnly,
|
||||||
|
configuration: config
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.8 Tool Execution Delegate (Destructive Operation Confirmation)
|
||||||
|
|
||||||
|
```swift
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: dbPool,
|
||||||
|
model: model,
|
||||||
|
allowlist: .unrestricted,
|
||||||
|
delegate: MyDelegate()
|
||||||
|
)
|
||||||
|
|
||||||
|
actor MyDelegate: ToolExecutionDelegate {
|
||||||
|
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
|
||||||
|
// Show confirmation UI, inspect context.sql, context.targetTable, etc.
|
||||||
|
return true // or false to reject
|
||||||
|
}
|
||||||
|
|
||||||
|
func willExecuteSQL(_ sql: String, classification: SQLClassification) async {
|
||||||
|
// Observe before execution
|
||||||
|
}
|
||||||
|
|
||||||
|
func didExecuteSQL(_ sql: String, success: Bool) async {
|
||||||
|
// Observe after execution
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Feature Requirements
|
||||||
|
|
||||||
|
### P0 — Must Have (v1.0) — All Implemented
|
||||||
|
|
||||||
|
| # | Feature | Description | Status |
|
||||||
|
|---|---|---|---|
|
||||||
|
| F1 | **Schema Discovery** | Auto-introspect all tables, columns (name, type, nullability, defaults), primary keys, foreign keys, and indexes from `sqlite_master` and PRAGMA statements. Zero config — no annotations needed. | Done |
|
||||||
|
| F2 | **Natural Language to SQL** | Convert NL queries to SQL via LLM. The LLM generates raw SQL; no `#Predicate` or `FetchDescriptor` — pure SQL throughout. | Done |
|
||||||
|
| F3 | **Result Rendering — Text** | `TextSummaryRenderer` uses the LLM to produce natural-language summaries of query results. | Done |
|
||||||
|
| F4 | **Result Rendering — Data Tables** | `ScrollableDataTableView` renders query results as scrollable, structured tables in SwiftUI. | Done |
|
||||||
|
| F5 | **Result Rendering — Swift Charts** | `ChartDataDetector` auto-detects chart-eligible results. `BarChartView`, `LineChartView`, `PieChartView` render via Swift Charts. | Done |
|
||||||
|
| F6 | **Drop-in ChatView** | `DataChatView` (zero-config: path + model) and `ChatView` (full control via `ChatViewModel`). Message bubbles, loading states, error display. | Done |
|
||||||
|
| F7 | **AnyLanguageModel Integration** | Uses HuggingFace's AnyLanguageModel for the LLM layer. `LanguageModelSession` for SQL generation and result summarization. | Done |
|
||||||
|
| F8 | **SQL Safety — Operation Allowlist** | `OperationAllowlist` with presets (`.readOnly`, `.standard`, `.unrestricted`) and custom sets. Allowlist check only — no SQL parser for safety. | Done |
|
||||||
|
| F9 | **SQL Safety — Mutation Policy** | `MutationPolicy` adds per-table restrictions on top of the allowlist. DELETE requires confirmation by default. | Done |
|
||||||
|
| F10 | **SQL Safety — Custom Validators** | `QueryValidator` protocol with built-in `TableAllowlistValidator` and `MaxRowLimitValidator`. Extensible for domain-specific rules. | Done |
|
||||||
|
| F11 | **Mutation Support** | INSERT, UPDATE, DELETE via SQL with allowlist validation and optional confirmation via `ToolExecutionDelegate`. | Done |
|
||||||
|
| F12 | **Conversation Context** | Multi-turn support with configurable context window size. "Show overdue tasks" -> "Now sort them by priority" maintains history. | Done |
|
||||||
|
| F13 | **Error Handling** | Typed `SwiftDBAIError` enum covering schema introspection failures, empty schemas, invalid SQL, disallowed operations, confirmation required, database errors, LLM failures, and query timeouts. | Done |
|
||||||
|
|
||||||
|
### P1 — Should Have (v1.x)
|
||||||
|
|
||||||
|
| # | Feature | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| F14 | **On-Device Providers** | Guide for using Ollama, CoreML, MLX, or llama.cpp via AnyLanguageModel for fully offline / privacy-sensitive deployments |
|
||||||
|
| F15 | **Chat History Persistence** | Optionally persist chat history to SQLite via GRDB |
|
||||||
|
| F16 | **Theming API** | Customize colors, fonts, bubble styles, dark/light mode in ChatView |
|
||||||
|
| F17 | **Streaming Responses** | Token-by-token display for cloud LLM providers |
|
||||||
|
| F18 | **Export Results** | Copy/share query results as CSV, JSON, or formatted text |
|
||||||
|
|
||||||
|
### P2 — Nice to Have (v2.0+)
|
||||||
|
|
||||||
|
| # | Feature | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| F19 | **Voice Input** | Speech-to-text for hands-free data queries |
|
||||||
|
| F20 | **MCP Server Mode** | Expose any SQLite database as an MCP server so external LLM clients can query it |
|
||||||
|
| F21 | **Suggested Questions** | Auto-generate starter questions based on introspected schema |
|
||||||
|
| F22 | **Audit Log** | Log all mutations with timestamp, before/after values |
|
||||||
|
| F23 | **Multi-Database** | Support querying across multiple SQLite databases simultaneously |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Privacy & Security
|
||||||
|
|
||||||
|
| Concern | Approach |
|
||||||
|
|---|---|
|
||||||
|
| **Provider choice is yours** | Use Ollama or a self-hosted model to keep data off third-party servers |
|
||||||
|
| **No telemetry** | The package collects nothing |
|
||||||
|
| **API key handling** | Cloud provider keys are never persisted by the kit; developer is responsible for secure storage |
|
||||||
|
| **SQL safety** | Developer-configured `OperationAllowlist` controls what SQL the LLM may generate. Allowlist check only — no attempt at SQL parsing for injection prevention. The developer is responsible for setting appropriate allowlist levels. |
|
||||||
|
| **Mutation safety** | `MutationPolicy` provides per-table restrictions. DELETE requires explicit confirmation by default via `ToolExecutionDelegate`. |
|
||||||
|
| **Data stays in-process** | Query results stay in the GRDB connection; no serialization to disk or network unless developer opts in |
|
||||||
|
| **Connection ownership** | Developer manages their own GRDB `DatabasePool`/`DatabaseQueue`. SwiftDBAI never opens, closes, or migrates the database on its own. |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Technical Constraints
|
||||||
|
|
||||||
|
- **Swift Package Manager** only (no CocoaPods/Carthage)
|
||||||
|
- **Minimum deployments:** iOS 17.0, macOS 14.0, visionOS 1.0
|
||||||
|
- **Swift 6.1** language mode with strict concurrency checking
|
||||||
|
- **Dependencies:** GRDB.swift 7.0+ and AnyLanguageModel (branch: main)
|
||||||
|
- **No UIKit dependency** — pure SwiftUI for the view layer
|
||||||
|
- **No SwiftData dependency** — pure GRDB/SQL throughout. Works with any SQLite database regardless of how it was created.
|
||||||
|
- **No Core Data dependency** — no ORM layer of any kind
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Implementation Status
|
||||||
|
|
||||||
|
| Metric | Current |
|
||||||
|
|---|---|
|
||||||
|
| Source files | 30 |
|
||||||
|
| Test files | 19 |
|
||||||
|
| Tests passing | 352 |
|
||||||
|
| Swift language mode | 6.1 |
|
||||||
|
| Dependencies | GRDB.swift 7.0+, AnyLanguageModel |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 11. Success Metrics
|
||||||
|
|
||||||
|
| Metric | Target |
|
||||||
|
|---|---|
|
||||||
|
| Integration time | < 5 minutes for basic "chat with my data" — provide a database path and a model |
|
||||||
|
| Query accuracy | > 90% of common queries (SELECT with filters, sorting, aggregates) produce correct SQL on first attempt |
|
||||||
|
| Latency (kit overhead) | < 500ms for schema introspection + SQL validation on a typical 20-table database (excludes LLM response time) |
|
||||||
|
| Package size | < 2 MB added to app binary (excluding LLM model weights) |
|
||||||
|
| Crash rate | 0 crashes from kit code in production |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 12. Open Questions
|
||||||
|
|
||||||
|
1. **AnyLanguageModel maturity** — The library is relatively new; we need to track API stability and pin to a specific version. What's our fallback if breaking changes land? (Currently pinned to `branch: main`.)
|
||||||
|
2. **SQL injection surface** — The allowlist check validates operation type but does not parse SQL structure. Should we add a lightweight SQL tokenizer for additional safety, or is the allowlist sufficient given the LLM is the only SQL author?
|
||||||
|
3. **Schema change detection** — `SchemaIntrospector` caches the schema after first introspection. If the database schema changes at runtime (migrations, etc.), the cache becomes stale. Should we add a `schema_version` PRAGMA check or a manual invalidation API?
|
||||||
|
4. **Large schema handling** — For databases with many tables (100+), the schema description in the LLM system prompt may be very large. Should we add table filtering or relevance ranking?
|
||||||
|
5. **Chart auto-detection accuracy** — `ChartDataDetector` heuristically determines if results are chart-eligible. How do we handle false positives/negatives?
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 13. Milestones
|
||||||
|
|
||||||
|
| Milestone | Scope | Status |
|
||||||
|
|---|---|---|
|
||||||
|
| **M1: Foundation** | SchemaIntrospector + SQLQueryParser + headless ChatEngine | Done |
|
||||||
|
| **M2: Safety** | OperationAllowlist + MutationPolicy + QueryValidator + ToolExecutionDelegate | Done |
|
||||||
|
| **M3: Chat UI** | DataChatView + ChatView + ChatViewModel + MessageBubbleView + ErrorMessageView | Done |
|
||||||
|
| **M4: Rendering** | TextSummaryRenderer + ScrollableDataTableView + ChartDataDetector + Bar/Line/Pie charts | Done |
|
||||||
|
| **M5: Multi-turn** | ConversationHistory + context window + PromptBuilder with history | Done |
|
||||||
|
| **M6: Polish & Ship** | Error handling (SwiftDBAIError), 352 tests, documentation | Done |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 14. References
|
||||||
|
|
||||||
|
- [GRDB.swift](https://github.com/groue/GRDB.swift) — SQLite toolkit for Swift
|
||||||
|
- [AnyLanguageModel (HuggingFace)](https://github.com/huggingface/AnyLanguageModel) — Unified Swift LLM abstraction
|
||||||
|
- [Swift Charts](https://developer.apple.com/documentation/charts) — Apple's declarative charting framework
|
||||||
|
- [Model Context Protocol (MCP)](https://modelcontextprotocol.io) — For future MCP server mode
|
||||||
|
- [Swift Package Manager](https://www.swift.org/documentation/package-manager/)
|
||||||
|
- [SQLite PRAGMA Statements](https://www.sqlite.org/pragma.html) — Used for schema introspection
|
||||||
41
Package.swift
Normal file
41
Package.swift
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
// swift-tools-version: 6.1
|
||||||
|
|
||||||
|
import PackageDescription
|
||||||
|
|
||||||
|
let package = Package(
|
||||||
|
name: "SwiftDBAI",
|
||||||
|
platforms: [
|
||||||
|
.iOS(.v17),
|
||||||
|
.macOS(.v14),
|
||||||
|
.visionOS(.v1),
|
||||||
|
],
|
||||||
|
products: [
|
||||||
|
.library(
|
||||||
|
name: "SwiftDBAI",
|
||||||
|
targets: ["SwiftDBAI"]
|
||||||
|
),
|
||||||
|
],
|
||||||
|
dependencies: [
|
||||||
|
.package(url: "https://github.com/groue/GRDB.swift.git", from: "7.0.0"),
|
||||||
|
.package(url: "https://github.com/huggingface/AnyLanguageModel.git", branch: "main"),
|
||||||
|
],
|
||||||
|
targets: [
|
||||||
|
.target(
|
||||||
|
name: "SwiftDBAI",
|
||||||
|
dependencies: [
|
||||||
|
.product(name: "GRDB", package: "GRDB.swift"),
|
||||||
|
.product(name: "AnyLanguageModel", package: "AnyLanguageModel"),
|
||||||
|
],
|
||||||
|
swiftSettings: [
|
||||||
|
.swiftLanguageMode(.v6),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
.testTarget(
|
||||||
|
name: "SwiftDBAITests",
|
||||||
|
dependencies: ["SwiftDBAI"],
|
||||||
|
swiftSettings: [
|
||||||
|
.swiftLanguageMode(.v6),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
254
README.md
Normal file
254
README.md
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
# SwiftDBAI
|
||||||
|
|
||||||
|
Chat with any SQLite database using natural language.
|
||||||
|
|
||||||
|
<!-- badges -->
|
||||||
|

|
||||||
|

|
||||||
|

|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Drop-in SwiftUI chat view (`DataChatView`) -- one line to add a database chat UI
|
||||||
|
- Headless `ChatEngine` for programmatic / non-UI use
|
||||||
|
- LLM-agnostic via [AnyLanguageModel](https://github.com/huggingface/AnyLanguageModel) -- works with OpenAI, Anthropic, Gemini, Ollama, llama.cpp, or any OpenAI-compatible endpoint
|
||||||
|
- Automatic schema introspection -- no manual annotations required
|
||||||
|
- Safety-first: read-only by default, operation allowlists, table-level mutation policies, destructive operation confirmation delegate
|
||||||
|
- Configurable query timeouts, context windows, and custom validators
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Add SwiftDBAI via Swift Package Manager:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
dependencies: [
|
||||||
|
.package(url: "https://github.com/<org>/SwiftDBAI.git", from: "1.0.0"),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
Then add the dependency to your target:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
.target(
|
||||||
|
name: "MyApp",
|
||||||
|
dependencies: ["SwiftDBAI"]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
Drop a full chat UI into any SwiftUI view with `DataChatView`:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
import SwiftDBAI
|
||||||
|
import AnyLanguageModel
|
||||||
|
|
||||||
|
struct ContentView: View {
|
||||||
|
var body: some View {
|
||||||
|
DataChatView(
|
||||||
|
databasePath: "/path/to/mydata.sqlite",
|
||||||
|
model: OllamaLanguageModel(model: "llama3")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
That's it. `DataChatView` opens the database, introspects the schema, and renders a chat interface. The default mode is **read-only** (SELECT only).
|
||||||
|
|
||||||
|
To pass an existing GRDB connection and customize behavior:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
DataChatView(
|
||||||
|
database: myDatabasePool,
|
||||||
|
model: OpenAILanguageModel(apiKey: "sk-...", model: "gpt-4o"),
|
||||||
|
allowlist: .standard,
|
||||||
|
additionalContext: "This database stores a recipe app's data.",
|
||||||
|
maxSummaryRows: 100
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Headless / Programmatic Use
|
||||||
|
|
||||||
|
Use `ChatEngine` directly when you don't need a UI:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
import SwiftDBAI
|
||||||
|
import AnyLanguageModel
|
||||||
|
import GRDB
|
||||||
|
|
||||||
|
let pool = try DatabasePool(path: "/path/to/mydata.sqlite")
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: pool,
|
||||||
|
model: OpenAILanguageModel(apiKey: "sk-...", model: "gpt-4o")
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await engine.send("How many users signed up this week?")
|
||||||
|
print(response.summary) // "There were 42 new signups this week."
|
||||||
|
print(response.sql) // Optional("SELECT COUNT(*) FROM users WHERE ...")
|
||||||
|
```
|
||||||
|
|
||||||
|
`ChatEngine` also accepts a `ProviderConfiguration` for convenience:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: pool,
|
||||||
|
provider: .anthropic(apiKey: "sk-ant-...", model: "claude-sonnet-4-20250514")
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
For fine-grained control, pass a `ChatEngineConfiguration`:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
var config = ChatEngineConfiguration(
|
||||||
|
queryTimeout: 10,
|
||||||
|
contextWindowSize: 20,
|
||||||
|
maxSummaryRows: 100,
|
||||||
|
additionalContext: "The 'status' column uses: 'active', 'inactive', 'suspended'."
|
||||||
|
)
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: pool,
|
||||||
|
model: model,
|
||||||
|
allowlist: .standard,
|
||||||
|
configuration: config
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Choosing a Provider
|
||||||
|
|
||||||
|
SwiftDBAI works with any provider supported by AnyLanguageModel. Use `ProviderConfiguration` factory methods or construct model instances directly.
|
||||||
|
|
||||||
|
```swift
|
||||||
|
// OpenAI
|
||||||
|
let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
|
||||||
|
|
||||||
|
// Anthropic
|
||||||
|
let config = ProviderConfiguration.anthropic(apiKey: "sk-ant-...", model: "claude-sonnet-4-20250514")
|
||||||
|
|
||||||
|
// Gemini
|
||||||
|
let config = ProviderConfiguration.gemini(apiKey: "AIza...", model: "gemini-2.0-flash")
|
||||||
|
|
||||||
|
// Ollama (local, no API key needed)
|
||||||
|
let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||||
|
|
||||||
|
// llama.cpp (local)
|
||||||
|
let config = ProviderConfiguration.llamaCpp(model: "default")
|
||||||
|
|
||||||
|
// Any OpenAI-compatible endpoint
|
||||||
|
let config = ProviderConfiguration.openAICompatible(
|
||||||
|
apiKey: "your-key",
|
||||||
|
model: "llama-3.1-70b",
|
||||||
|
baseURL: URL(string: "https://api.together.xyz/v1/")!
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Use with ChatEngine:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
let engine = ChatEngine(database: pool, provider: config)
|
||||||
|
// or
|
||||||
|
let engine = ChatEngine(database: pool, model: config.makeModel())
|
||||||
|
```
|
||||||
|
|
||||||
|
API keys can also come from environment variables:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
let config = ProviderConfiguration.fromEnvironment(
|
||||||
|
provider: .openAI,
|
||||||
|
environmentVariable: "OPENAI_API_KEY",
|
||||||
|
model: "gpt-4o"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Safety and Mutation Control
|
||||||
|
|
||||||
|
### Operation Allowlist
|
||||||
|
|
||||||
|
By default, only SELECT queries are allowed. Opt in to writes explicitly:
|
||||||
|
|
||||||
|
| Preset | Allowed Operations |
|
||||||
|
|---|---|
|
||||||
|
| `.readOnly` (default) | SELECT |
|
||||||
|
| `.standard` | SELECT, INSERT, UPDATE |
|
||||||
|
| `.unrestricted` | SELECT, INSERT, UPDATE, DELETE |
|
||||||
|
|
||||||
|
```swift
|
||||||
|
// Custom allowlist
|
||||||
|
let allowlist = OperationAllowlist([.select, .insert])
|
||||||
|
```
|
||||||
|
|
||||||
|
### Mutation Policy
|
||||||
|
|
||||||
|
For table-level control, use `MutationPolicy`:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
// Allow INSERT and UPDATE only on specific tables
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update],
|
||||||
|
allowedTables: ["orders", "order_items"]
|
||||||
|
)
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: pool,
|
||||||
|
model: model,
|
||||||
|
mutationPolicy: policy
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Presets: `.readOnly`, `.readWrite`, `.unrestricted`.
|
||||||
|
|
||||||
|
### Confirmation Delegate
|
||||||
|
|
||||||
|
Destructive operations (DELETE, DROP, ALTER, TRUNCATE) require confirmation through a `ToolExecutionDelegate`:
|
||||||
|
|
||||||
|
```swift
|
||||||
|
struct MyDelegate: ToolExecutionDelegate {
|
||||||
|
func confirmDestructiveOperation(
|
||||||
|
_ context: DestructiveOperationContext
|
||||||
|
) async -> Bool {
|
||||||
|
// Present confirmation UI, return true to proceed
|
||||||
|
return await showConfirmationDialog(context.description)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: pool,
|
||||||
|
model: model,
|
||||||
|
allowlist: .unrestricted,
|
||||||
|
delegate: MyDelegate()
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
Without a delegate, destructive operations throw `SwiftDBAIError.confirmationRequired` so you can handle confirmation in your own flow.
|
||||||
|
|
||||||
|
Built-in delegates: `AutoApproveDelegate` (testing only), `RejectAllDelegate` (safest).
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
User Question
|
||||||
|
|
|
||||||
|
v
|
||||||
|
ChatEngine
|
||||||
|
|-- SchemaIntrospector (auto-discovers tables, columns, keys, indexes)
|
||||||
|
|-- PromptBuilder (builds LLM system prompt with schema context)
|
||||||
|
|-- LanguageModel (generates SQL via AnyLanguageModel)
|
||||||
|
|-- SQLQueryParser (parses and validates against allowlist/policy)
|
||||||
|
|-- QueryValidator (optional custom validators)
|
||||||
|
|-- GRDB (executes SQL against SQLite)
|
||||||
|
|-- TextSummaryRenderer (summarizes results via LLM)
|
||||||
|
v
|
||||||
|
ChatResponse { summary, sql, queryResult }
|
||||||
|
```
|
||||||
|
|
||||||
|
`DataChatView` wraps this pipeline in a SwiftUI view with `ChatViewModel` managing state.
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- iOS 17.0+ / macOS 14.0+ / visionOS 1.0+
|
||||||
|
- Swift 6.1+
|
||||||
|
- Xcode 16+
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT. See [LICENSE](LICENSE) for details.
|
||||||
113
Sources/SwiftDBAI/Config/ChatEngineConfiguration.swift
Normal file
113
Sources/SwiftDBAI/Config/ChatEngineConfiguration.swift
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
// ChatEngineConfiguration.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Configurable settings for ChatEngine behavior — timeouts, context window,
|
||||||
|
// summary limits, and custom query validation.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Configuration for ``ChatEngine`` behavior.
|
||||||
|
///
|
||||||
|
/// Use this to tune timeouts, conversation context windows, and attach
|
||||||
|
/// custom query validators.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// var config = ChatEngineConfiguration()
|
||||||
|
/// config.queryTimeout = 10 // 10-second SQL timeout
|
||||||
|
/// config.contextWindowSize = 20 // Keep last 20 messages for LLM context
|
||||||
|
/// config.maxSummaryRows = 100 // Summarize up to 100 rows
|
||||||
|
///
|
||||||
|
/// let engine = ChatEngine(
|
||||||
|
/// database: db,
|
||||||
|
/// model: model,
|
||||||
|
/// configuration: config
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
public struct ChatEngineConfiguration: Sendable {
|
||||||
|
|
||||||
|
// MARK: - Query Execution
|
||||||
|
|
||||||
|
/// Maximum time (in seconds) to wait for a SQL query to execute.
|
||||||
|
///
|
||||||
|
/// If the query exceeds this duration, a ``ChatEngineError/queryTimedOut``
|
||||||
|
/// error is thrown. Set to `nil` to disable the timeout (not recommended
|
||||||
|
/// for user-facing apps). Defaults to 30 seconds.
|
||||||
|
public var queryTimeout: TimeInterval?
|
||||||
|
|
||||||
|
// MARK: - Conversation Context
|
||||||
|
|
||||||
|
/// Maximum number of conversation messages to include when building
|
||||||
|
/// LLM context for follow-up queries.
|
||||||
|
///
|
||||||
|
/// Only the most recent `contextWindowSize` messages are sent to the LLM.
|
||||||
|
/// Older messages are still retained in ``ChatEngine/messages`` for UI
|
||||||
|
/// display but do not consume LLM tokens.
|
||||||
|
///
|
||||||
|
/// Set to `nil` for unlimited context (all history is always sent).
|
||||||
|
/// Defaults to 50 messages.
|
||||||
|
public var contextWindowSize: Int?
|
||||||
|
|
||||||
|
// MARK: - Rendering
|
||||||
|
|
||||||
|
/// Maximum number of rows to include when generating text summaries.
|
||||||
|
/// Defaults to 50.
|
||||||
|
public var maxSummaryRows: Int
|
||||||
|
|
||||||
|
// MARK: - LLM Context
|
||||||
|
|
||||||
|
/// Optional extra instructions appended to the LLM system prompt.
|
||||||
|
///
|
||||||
|
/// Use this to provide business-specific terminology, query hints,
|
||||||
|
/// or domain constraints. For example:
|
||||||
|
/// ```swift
|
||||||
|
/// config.additionalContext = "The 'status' column uses: 'active', 'inactive', 'suspended'."
|
||||||
|
/// ```
|
||||||
|
public var additionalContext: String?
|
||||||
|
|
||||||
|
// MARK: - Validation
|
||||||
|
|
||||||
|
/// Custom query validators that run after the built-in allowlist check.
|
||||||
|
///
|
||||||
|
/// Use ``addValidator(_:)`` to add validators. They are executed in order;
|
||||||
|
/// the first validator to throw stops execution.
|
||||||
|
public private(set) var validators: [any QueryValidator] = []
|
||||||
|
|
||||||
|
// MARK: - Initialization
|
||||||
|
|
||||||
|
/// Creates a configuration with the given settings.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - queryTimeout: SQL execution timeout in seconds. Defaults to 30.
|
||||||
|
/// - contextWindowSize: Max messages for LLM context. Defaults to 50.
|
||||||
|
/// - maxSummaryRows: Max rows for text summaries. Defaults to 50.
|
||||||
|
/// - additionalContext: Extra LLM system prompt instructions.
|
||||||
|
public init(
|
||||||
|
queryTimeout: TimeInterval? = 30,
|
||||||
|
contextWindowSize: Int? = 50,
|
||||||
|
maxSummaryRows: Int = 50,
|
||||||
|
additionalContext: String? = nil
|
||||||
|
) {
|
||||||
|
self.queryTimeout = queryTimeout
|
||||||
|
self.contextWindowSize = contextWindowSize
|
||||||
|
self.maxSummaryRows = maxSummaryRows
|
||||||
|
self.additionalContext = additionalContext
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The default configuration: 30s timeout, 50-message context window,
|
||||||
|
/// 50-row summaries, no additional context, no custom validators.
|
||||||
|
public static let `default` = ChatEngineConfiguration()
|
||||||
|
|
||||||
|
// MARK: - Mutating Helpers
|
||||||
|
|
||||||
|
/// Appends a custom query validator.
|
||||||
|
///
|
||||||
|
/// Validators run after the built-in allowlist and dangerous-keyword checks.
|
||||||
|
/// They receive the parsed SQL and can throw to reject a query.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// config.addValidator(TableAllowlistValidator(allowedTables: ["users", "orders"]))
|
||||||
|
/// ```
|
||||||
|
public mutating func addValidator(_ validator: any QueryValidator) {
|
||||||
|
validators.append(validator)
|
||||||
|
}
|
||||||
|
}
|
||||||
336
Sources/SwiftDBAI/Config/LocalProviderConfiguration.swift
Normal file
336
Sources/SwiftDBAI/Config/LocalProviderConfiguration.swift
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
// LocalProviderConfiguration.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Configuration and endpoint discovery for local/self-hosted LLM providers
|
||||||
|
// (Ollama, llama.cpp). Wraps AnyLanguageModel's OllamaLanguageModel and
|
||||||
|
// OpenAILanguageModel with convenient factory methods and health checking.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
#if canImport(FoundationNetworking)
|
||||||
|
import FoundationNetworking
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// MARK: - Local Provider Endpoint
|
||||||
|
|
||||||
|
/// Represents a discovered local LLM endpoint with its connection status.
|
||||||
|
public struct LocalProviderEndpoint: Sendable, Equatable {
|
||||||
|
/// The base URL of the local provider.
|
||||||
|
public let baseURL: URL
|
||||||
|
|
||||||
|
/// The provider type (Ollama or llama.cpp).
|
||||||
|
public let providerType: LocalProviderType
|
||||||
|
|
||||||
|
/// Whether the endpoint was reachable at discovery time.
|
||||||
|
public let isReachable: Bool
|
||||||
|
|
||||||
|
/// The list of available models, if the endpoint supports model listing.
|
||||||
|
public let availableModels: [String]
|
||||||
|
|
||||||
|
/// Human-readable description of the endpoint.
|
||||||
|
public var description: String {
|
||||||
|
let status = isReachable ? "reachable" : "unreachable"
|
||||||
|
return "\(providerType.rawValue) at \(baseURL.absoluteString) (\(status), \(availableModels.count) models)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The type of local LLM provider.
|
||||||
|
public enum LocalProviderType: String, Sendable, Hashable, CaseIterable {
|
||||||
|
/// Ollama — runs models locally via `ollama serve`.
|
||||||
|
/// Default endpoint: http://localhost:11434
|
||||||
|
case ollama
|
||||||
|
|
||||||
|
/// llama.cpp server — runs GGUF models via `llama-server`.
|
||||||
|
/// Default endpoint: http://localhost:8080
|
||||||
|
/// Exposes an OpenAI-compatible API.
|
||||||
|
case llamaCpp = "llama.cpp"
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Local Provider Discovery
|
||||||
|
|
||||||
|
/// Discovers and validates local LLM provider endpoints.
|
||||||
|
///
|
||||||
|
/// Use `LocalProviderDiscovery` to automatically find running Ollama or llama.cpp
|
||||||
|
/// instances on the local machine, check their health, and list available models.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// // Check if Ollama is running
|
||||||
|
/// let isRunning = await LocalProviderDiscovery.isOllamaRunning()
|
||||||
|
///
|
||||||
|
/// // Discover all local providers
|
||||||
|
/// let endpoints = await LocalProviderDiscovery.discoverAll()
|
||||||
|
/// for endpoint in endpoints where endpoint.isReachable {
|
||||||
|
/// print("Found \(endpoint.description)")
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// // List models available on Ollama
|
||||||
|
/// let models = await LocalProviderDiscovery.listOllamaModels()
|
||||||
|
/// ```
|
||||||
|
public enum LocalProviderDiscovery {
|
||||||
|
|
||||||
|
/// Default Ollama endpoint.
|
||||||
|
public static let defaultOllamaURL = URL(string: "http://localhost:11434")!
|
||||||
|
|
||||||
|
/// Default llama.cpp server endpoint.
|
||||||
|
public static let defaultLlamaCppURL = URL(string: "http://localhost:8080")!
|
||||||
|
|
||||||
|
/// Well-known ports to probe for local providers.
|
||||||
|
/// Ollama: 11434, llama.cpp: 8080
|
||||||
|
private static let wellKnownEndpoints: [(URL, LocalProviderType)] = [
|
||||||
|
(defaultOllamaURL, .ollama),
|
||||||
|
(defaultLlamaCppURL, .llamaCpp),
|
||||||
|
]
|
||||||
|
|
||||||
|
// MARK: - Health Checks
|
||||||
|
|
||||||
|
/// Checks if an Ollama instance is reachable at the given URL.
|
||||||
|
///
|
||||||
|
/// Sends a GET request to the Ollama root endpoint and checks for a 200 response.
|
||||||
|
///
|
||||||
|
/// - Parameter baseURL: The Ollama base URL. Defaults to `http://localhost:11434`.
|
||||||
|
/// - Parameter timeout: Connection timeout in seconds. Defaults to 3.
|
||||||
|
/// - Returns: `true` if the Ollama server responded successfully.
|
||||||
|
public static func isOllamaRunning(
|
||||||
|
at baseURL: URL = defaultOllamaURL,
|
||||||
|
timeout: TimeInterval = 3
|
||||||
|
) async -> Bool {
|
||||||
|
await checkEndpointHealth(baseURL, timeout: timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks if a llama.cpp server is reachable at the given URL.
|
||||||
|
///
|
||||||
|
/// Sends a GET request to the `/health` endpoint and checks for a 200 response.
|
||||||
|
///
|
||||||
|
/// - Parameter baseURL: The llama.cpp base URL. Defaults to `http://localhost:8080`.
|
||||||
|
/// - Parameter timeout: Connection timeout in seconds. Defaults to 3.
|
||||||
|
/// - Returns: `true` if the llama.cpp server responded successfully.
|
||||||
|
public static func isLlamaCppRunning(
|
||||||
|
at baseURL: URL = defaultLlamaCppURL,
|
||||||
|
timeout: TimeInterval = 3
|
||||||
|
) async -> Bool {
|
||||||
|
let healthURL = baseURL.appendingPathComponent("health")
|
||||||
|
return await checkEndpointHealth(healthURL, timeout: timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks if any endpoint at the given URL responds to HTTP requests.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - url: The URL to probe.
|
||||||
|
/// - timeout: Connection timeout in seconds.
|
||||||
|
/// - Returns: `true` if the endpoint returned an HTTP response with status 200.
|
||||||
|
private static func checkEndpointHealth(
|
||||||
|
_ url: URL,
|
||||||
|
timeout: TimeInterval
|
||||||
|
) async -> Bool {
|
||||||
|
let config = URLSessionConfiguration.ephemeral
|
||||||
|
config.timeoutIntervalForRequest = timeout
|
||||||
|
config.timeoutIntervalForResource = timeout
|
||||||
|
let session = URLSession(configuration: config)
|
||||||
|
|
||||||
|
do {
|
||||||
|
let (_, response) = try await session.data(from: url)
|
||||||
|
if let httpResponse = response as? HTTPURLResponse {
|
||||||
|
return httpResponse.statusCode == 200
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
} catch {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Model Listing
|
||||||
|
|
||||||
|
/// Lists models available on an Ollama instance.
|
||||||
|
///
|
||||||
|
/// Calls the Ollama `/api/tags` endpoint to retrieve the list of
|
||||||
|
/// locally installed models.
|
||||||
|
///
|
||||||
|
/// - Parameter baseURL: The Ollama base URL. Defaults to `http://localhost:11434`.
|
||||||
|
/// - Parameter timeout: Request timeout in seconds. Defaults to 5.
|
||||||
|
/// - Returns: An array of model name strings, or an empty array if unreachable.
|
||||||
|
public static func listOllamaModels(
|
||||||
|
at baseURL: URL = defaultOllamaURL,
|
||||||
|
timeout: TimeInterval = 5
|
||||||
|
) async -> [String] {
|
||||||
|
let tagsURL = baseURL.appendingPathComponent("api/tags")
|
||||||
|
let config = URLSessionConfiguration.ephemeral
|
||||||
|
config.timeoutIntervalForRequest = timeout
|
||||||
|
config.timeoutIntervalForResource = timeout
|
||||||
|
let session = URLSession(configuration: config)
|
||||||
|
|
||||||
|
do {
|
||||||
|
let (data, response) = try await session.data(from: tagsURL)
|
||||||
|
guard let httpResponse = response as? HTTPURLResponse,
|
||||||
|
httpResponse.statusCode == 200
|
||||||
|
else {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
let decoded = try JSONDecoder().decode(OllamaTagsResponse.self, from: data)
|
||||||
|
return decoded.models.map(\.name)
|
||||||
|
} catch {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lists models available on a llama.cpp server via its OpenAI-compatible endpoint.
|
||||||
|
///
|
||||||
|
/// Calls `/v1/models` which llama.cpp exposes when running with
|
||||||
|
/// `--api-key` or in default mode.
|
||||||
|
///
|
||||||
|
/// - Parameter baseURL: The llama.cpp base URL. Defaults to `http://localhost:8080`.
|
||||||
|
/// - Parameter timeout: Request timeout in seconds. Defaults to 5.
|
||||||
|
/// - Returns: An array of model ID strings, or an empty array if unreachable.
|
||||||
|
public static func listLlamaCppModels(
|
||||||
|
at baseURL: URL = defaultLlamaCppURL,
|
||||||
|
timeout: TimeInterval = 5
|
||||||
|
) async -> [String] {
|
||||||
|
let modelsURL = baseURL.appendingPathComponent("v1/models")
|
||||||
|
let config = URLSessionConfiguration.ephemeral
|
||||||
|
config.timeoutIntervalForRequest = timeout
|
||||||
|
config.timeoutIntervalForResource = timeout
|
||||||
|
let session = URLSession(configuration: config)
|
||||||
|
|
||||||
|
do {
|
||||||
|
let (data, response) = try await session.data(from: modelsURL)
|
||||||
|
guard let httpResponse = response as? HTTPURLResponse,
|
||||||
|
httpResponse.statusCode == 200
|
||||||
|
else {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
let decoded = try JSONDecoder().decode(OpenAIModelsResponse.self, from: data)
|
||||||
|
return decoded.data.map(\.id)
|
||||||
|
} catch {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Full Discovery
|
||||||
|
|
||||||
|
/// Discovers all running local LLM providers by probing well-known endpoints.
|
||||||
|
///
|
||||||
|
/// Probes Ollama (port 11434) and llama.cpp (port 8080) concurrently,
|
||||||
|
/// returning their status and available models.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let endpoints = await LocalProviderDiscovery.discoverAll()
|
||||||
|
/// for endpoint in endpoints where endpoint.isReachable {
|
||||||
|
/// print("Found: \(endpoint.description)")
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameter timeout: Connection timeout per endpoint in seconds. Defaults to 3.
|
||||||
|
/// - Returns: An array of `LocalProviderEndpoint` for each probed location.
|
||||||
|
public static func discoverAll(
|
||||||
|
timeout: TimeInterval = 3
|
||||||
|
) async -> [LocalProviderEndpoint] {
|
||||||
|
await withTaskGroup(of: LocalProviderEndpoint.self, returning: [LocalProviderEndpoint].self) { group in
|
||||||
|
for (url, providerType) in wellKnownEndpoints {
|
||||||
|
group.addTask {
|
||||||
|
await discover(providerType: providerType, at: url, timeout: timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var results: [LocalProviderEndpoint] = []
|
||||||
|
for await endpoint in group {
|
||||||
|
results.append(endpoint)
|
||||||
|
}
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Discovers a specific local provider at the given URL.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - providerType: The type of provider to probe.
|
||||||
|
/// - baseURL: The base URL to check.
|
||||||
|
/// - timeout: Connection timeout in seconds.
|
||||||
|
/// - Returns: A `LocalProviderEndpoint` with reachability and model info.
|
||||||
|
public static func discover(
|
||||||
|
providerType: LocalProviderType,
|
||||||
|
at baseURL: URL,
|
||||||
|
timeout: TimeInterval = 3
|
||||||
|
) async -> LocalProviderEndpoint {
|
||||||
|
switch providerType {
|
||||||
|
case .ollama:
|
||||||
|
let reachable = await isOllamaRunning(at: baseURL, timeout: timeout)
|
||||||
|
let models = reachable ? await listOllamaModels(at: baseURL) : []
|
||||||
|
return LocalProviderEndpoint(
|
||||||
|
baseURL: baseURL,
|
||||||
|
providerType: .ollama,
|
||||||
|
isReachable: reachable,
|
||||||
|
availableModels: models
|
||||||
|
)
|
||||||
|
|
||||||
|
case .llamaCpp:
|
||||||
|
let reachable = await isLlamaCppRunning(at: baseURL, timeout: timeout)
|
||||||
|
let models = reachable ? await listLlamaCppModels(at: baseURL) : []
|
||||||
|
return LocalProviderEndpoint(
|
||||||
|
baseURL: baseURL,
|
||||||
|
providerType: .llamaCpp,
|
||||||
|
isReachable: reachable,
|
||||||
|
availableModels: models
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Discovers a specific local provider at a custom URL and port.
|
||||||
|
///
|
||||||
|
/// Use this for non-standard configurations where Ollama or llama.cpp
|
||||||
|
/// is running on a custom host or port.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let endpoint = await LocalProviderDiscovery.discover(
|
||||||
|
/// providerType: .ollama,
|
||||||
|
/// host: "192.168.1.100",
|
||||||
|
/// port: 11434
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - providerType: The provider type.
|
||||||
|
/// - host: The hostname or IP address.
|
||||||
|
/// - port: The port number.
|
||||||
|
/// - timeout: Connection timeout in seconds. Defaults to 3.
|
||||||
|
/// - Returns: A `LocalProviderEndpoint` with reachability and model info.
|
||||||
|
public static func discover(
|
||||||
|
providerType: LocalProviderType,
|
||||||
|
host: String,
|
||||||
|
port: Int,
|
||||||
|
timeout: TimeInterval = 3
|
||||||
|
) async -> LocalProviderEndpoint {
|
||||||
|
guard let url = URL(string: "http://\(host):\(port)") else {
|
||||||
|
return LocalProviderEndpoint(
|
||||||
|
baseURL: URL(string: "http://\(host):\(port)")!,
|
||||||
|
providerType: providerType,
|
||||||
|
isReachable: false,
|
||||||
|
availableModels: []
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return await discover(providerType: providerType, at: url, timeout: timeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - JSON Response Types
|
||||||
|
|
||||||
|
/// Response from Ollama's `/api/tags` endpoint.
|
||||||
|
private struct OllamaTagsResponse: Decodable, Sendable {
|
||||||
|
let models: [OllamaModelInfo]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Individual model info from Ollama's tags endpoint.
|
||||||
|
private struct OllamaModelInfo: Decodable, Sendable {
|
||||||
|
let name: String
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response from the OpenAI-compatible `/v1/models` endpoint.
|
||||||
|
private struct OpenAIModelsResponse: Decodable, Sendable {
|
||||||
|
let data: [OpenAIModelInfo]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Individual model info from the OpenAI models endpoint.
|
||||||
|
private struct OpenAIModelInfo: Decodable, Sendable {
|
||||||
|
let id: String
|
||||||
|
}
|
||||||
148
Sources/SwiftDBAI/Config/MutationPolicy.swift
Normal file
148
Sources/SwiftDBAI/Config/MutationPolicy.swift
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
// MutationPolicy.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Defines which mutation operations are permitted and optionally restricts
|
||||||
|
// them to specific tables. Wraps OperationAllowlist with table-level granularity.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Controls which SQL mutation operations the LLM may generate and,
|
||||||
|
/// optionally, which tables those mutations may target.
|
||||||
|
///
|
||||||
|
/// `MutationPolicy` builds on ``OperationAllowlist`` by adding per-table
|
||||||
|
/// restrictions. The default policy is **read-only** — no mutations are
|
||||||
|
/// allowed on any table. Write operations require explicit opt-in.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// // Read-only (default) — only SELECT is allowed
|
||||||
|
/// let readOnly = MutationPolicy.readOnly
|
||||||
|
///
|
||||||
|
/// // Allow INSERT and UPDATE on specific tables only
|
||||||
|
/// let restricted = MutationPolicy(
|
||||||
|
/// allowedOperations: [.insert, .update],
|
||||||
|
/// allowedTables: ["orders", "order_items"]
|
||||||
|
/// )
|
||||||
|
///
|
||||||
|
/// // Allow INSERT and UPDATE on all tables
|
||||||
|
/// let broad = MutationPolicy(allowedOperations: [.insert, .update])
|
||||||
|
///
|
||||||
|
/// // Full access including DELETE (requires confirmation)
|
||||||
|
/// let full = MutationPolicy.unrestricted
|
||||||
|
/// ```
|
||||||
|
public struct MutationPolicy: Sendable, Equatable {
|
||||||
|
|
||||||
|
// MARK: - Properties
|
||||||
|
|
||||||
|
/// The underlying operation allowlist (always includes SELECT).
|
||||||
|
public let operationAllowlist: OperationAllowlist
|
||||||
|
|
||||||
|
/// Optional set of table names that mutations may target.
|
||||||
|
///
|
||||||
|
/// When `nil`, mutations are allowed on all tables (subject to
|
||||||
|
/// ``operationAllowlist``). When non-nil, mutation operations
|
||||||
|
/// (INSERT, UPDATE, DELETE) are only permitted on the listed tables.
|
||||||
|
/// SELECT queries are never restricted by this property.
|
||||||
|
public let allowedMutationTables: Set<String>?
|
||||||
|
|
||||||
|
/// When `true`, destructive operations (DELETE) require explicit user
|
||||||
|
/// confirmation before execution, even when the operation is allowed.
|
||||||
|
/// Defaults to `true`.
|
||||||
|
public let requiresDestructiveConfirmation: Bool
|
||||||
|
|
||||||
|
// MARK: - Initialization
|
||||||
|
|
||||||
|
/// Creates a mutation policy with the given operations and optional table restrictions.
|
||||||
|
///
|
||||||
|
/// SELECT is always implicitly included — you cannot create a policy
|
||||||
|
/// that disallows reads.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - allowedOperations: The mutation operations to permit (INSERT, UPDATE, DELETE).
|
||||||
|
/// SELECT is always allowed regardless of this parameter.
|
||||||
|
/// - allowedTables: Optional set of table names mutations may target.
|
||||||
|
/// Pass `nil` to allow mutations on all tables. Defaults to `nil`.
|
||||||
|
/// - requiresDestructiveConfirmation: Whether DELETE requires user confirmation.
|
||||||
|
/// Defaults to `true`.
|
||||||
|
public init(
|
||||||
|
allowedOperations: Set<SQLOperation> = [],
|
||||||
|
allowedTables: Set<String>? = nil,
|
||||||
|
requiresDestructiveConfirmation: Bool = true
|
||||||
|
) {
|
||||||
|
// Always include SELECT
|
||||||
|
var ops = allowedOperations
|
||||||
|
ops.insert(.select)
|
||||||
|
self.operationAllowlist = OperationAllowlist(ops)
|
||||||
|
self.allowedMutationTables = allowedTables
|
||||||
|
self.requiresDestructiveConfirmation = requiresDestructiveConfirmation
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Presets
|
||||||
|
|
||||||
|
/// Read-only policy: only SELECT queries are allowed. This is the default.
|
||||||
|
public static let readOnly = MutationPolicy()
|
||||||
|
|
||||||
|
/// Standard read-write: SELECT, INSERT, and UPDATE on all tables.
|
||||||
|
public static let readWrite = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update]
|
||||||
|
)
|
||||||
|
|
||||||
|
/// Unrestricted: all operations including DELETE on all tables.
|
||||||
|
/// DELETE still requires confirmation by default.
|
||||||
|
public static let unrestricted = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update, .delete]
|
||||||
|
)
|
||||||
|
|
||||||
|
// MARK: - Validation
|
||||||
|
|
||||||
|
/// Returns `true` if the given operation is permitted by this policy.
|
||||||
|
public func isOperationAllowed(_ operation: SQLOperation) -> Bool {
|
||||||
|
operationAllowlist.isAllowed(operation)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns `true` if the given mutation operation is permitted on the
|
||||||
|
/// specified table.
|
||||||
|
///
|
||||||
|
/// SELECT operations always return `true` regardless of table restrictions.
|
||||||
|
/// For mutation operations, this checks both the operation allowlist and
|
||||||
|
/// the table restrictions (if any).
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - operation: The SQL operation type.
|
||||||
|
/// - table: The target table name (case-insensitive comparison).
|
||||||
|
/// - Returns: Whether the operation is allowed on the given table.
|
||||||
|
public func isAllowed(operation: SQLOperation, on table: String) -> Bool {
|
||||||
|
// SELECT is always allowed
|
||||||
|
guard operation != .select else { return true }
|
||||||
|
|
||||||
|
// Check operation allowlist first
|
||||||
|
guard operationAllowlist.isAllowed(operation) else { return false }
|
||||||
|
|
||||||
|
// If no table restrictions, the operation is allowed
|
||||||
|
guard let allowedTables = allowedMutationTables else { return true }
|
||||||
|
|
||||||
|
// Case-insensitive table name check
|
||||||
|
let lowerTable = table.lowercased()
|
||||||
|
return allowedTables.contains { $0.lowercased() == lowerTable }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns `true` if the given operation requires user confirmation.
|
||||||
|
public func requiresConfirmation(for operation: SQLOperation) -> Bool {
|
||||||
|
operation == .delete && requiresDestructiveConfirmation
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a human-readable description for inclusion in the LLM system prompt.
|
||||||
|
func describeForLLM() -> String {
|
||||||
|
var desc = operationAllowlist.describeForLLM()
|
||||||
|
|
||||||
|
if let tables = allowedMutationTables, !tables.isEmpty {
|
||||||
|
let sorted = tables.sorted()
|
||||||
|
desc += " Mutations (INSERT/UPDATE/DELETE) are restricted to these tables only: \(sorted.joined(separator: ", "))."
|
||||||
|
}
|
||||||
|
|
||||||
|
if requiresDestructiveConfirmation && operationAllowlist.isAllowed(.delete) {
|
||||||
|
desc += " DELETE operations require user confirmation before execution."
|
||||||
|
}
|
||||||
|
|
||||||
|
return desc
|
||||||
|
}
|
||||||
|
}
|
||||||
866
Sources/SwiftDBAI/Config/OnDeviceProviderConfiguration.swift
Normal file
866
Sources/SwiftDBAI/Config/OnDeviceProviderConfiguration.swift
Normal file
@@ -0,0 +1,866 @@
|
|||||||
|
// OnDeviceProviderConfiguration.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Configuration for on-device LLM providers (CoreML, MLX) that run models
|
||||||
|
// locally on Apple silicon. These providers enable fully offline,
|
||||||
|
// privacy-sensitive deployments where no data leaves the device.
|
||||||
|
//
|
||||||
|
// Both CoreML and MLX models are provided by AnyLanguageModel behind
|
||||||
|
// conditional compilation flags (#if CoreML, #if MLX). This configuration
|
||||||
|
// layer wraps their setup with convenient factory methods and integrates
|
||||||
|
// them into the SwiftDBAI ChatEngine pipeline.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
import GRDB
|
||||||
|
|
||||||
|
// MARK: - On-Device Provider Type
|
||||||
|
|
||||||
|
/// The type of on-device LLM provider.
|
||||||
|
public enum OnDeviceProviderType: String, Sendable, Hashable, CaseIterable {
|
||||||
|
/// CoreML — runs compiled .mlmodelc models on-device using Apple's CoreML framework.
|
||||||
|
/// Requires pre-compiled models and supports CPU, GPU, and Neural Engine compute units.
|
||||||
|
case coreML
|
||||||
|
|
||||||
|
/// MLX — runs HuggingFace models on Apple silicon using the MLX framework.
|
||||||
|
/// Models are automatically downloaded and cached. Supports quantized models
|
||||||
|
/// (e.g., 4-bit) for efficient memory usage.
|
||||||
|
case mlx
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - CoreML Configuration
|
||||||
|
|
||||||
|
/// Configuration for loading and running a CoreML language model on-device.
|
||||||
|
///
|
||||||
|
/// CoreML models must be pre-compiled to `.mlmodelc` format before use.
|
||||||
|
/// The model runs entirely on-device using CPU, GPU, and/or Neural Engine
|
||||||
|
/// depending on the `computeUnits` setting.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let config = CoreMLProviderConfiguration(
|
||||||
|
/// modelURL: Bundle.main.url(forResource: "MyModel", withExtension: "mlmodelc")!,
|
||||||
|
/// computeUnits: .all
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Note: CoreML models are available behind the `#if CoreML` flag in AnyLanguageModel.
|
||||||
|
/// Ensure your project enables the CoreML build condition.
|
||||||
|
public struct CoreMLProviderConfiguration: Sendable, Equatable {
|
||||||
|
|
||||||
|
/// The URL to the compiled CoreML model (`.mlmodelc`).
|
||||||
|
public let modelURL: URL
|
||||||
|
|
||||||
|
/// The compute units to use for inference.
|
||||||
|
///
|
||||||
|
/// - `.all`: Uses the best available hardware (Neural Engine, GPU, CPU).
|
||||||
|
/// - `.cpuOnly`: Forces CPU-only inference. Useful for debugging.
|
||||||
|
/// - `.cpuAndGPU`: Uses CPU and GPU but not the Neural Engine.
|
||||||
|
/// - `.cpuAndNeuralEngine`: Uses CPU and Neural Engine.
|
||||||
|
public let computeUnits: ComputeUnitPreference
|
||||||
|
|
||||||
|
/// Maximum number of tokens the model can generate per response.
|
||||||
|
/// Defaults to 2048.
|
||||||
|
public let maxResponseTokens: Int
|
||||||
|
|
||||||
|
/// Whether to use sampling (true) or greedy decoding (false).
|
||||||
|
/// Defaults to false (greedy) for more deterministic SQL generation.
|
||||||
|
public let useSampling: Bool
|
||||||
|
|
||||||
|
/// Temperature for sampling. Only used when `useSampling` is true.
|
||||||
|
/// Lower values produce more focused output. Defaults to 0.1.
|
||||||
|
public let temperature: Double
|
||||||
|
|
||||||
|
/// Creates a CoreML provider configuration.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - modelURL: The URL to a compiled CoreML model (`.mlmodelc`).
|
||||||
|
/// - computeUnits: The compute units to use. Defaults to `.all`.
|
||||||
|
/// - maxResponseTokens: Maximum tokens per response. Defaults to 2048.
|
||||||
|
/// - useSampling: Whether to use sampling vs greedy decoding. Defaults to false.
|
||||||
|
/// - temperature: Sampling temperature. Defaults to 0.1.
|
||||||
|
public init(
|
||||||
|
modelURL: URL,
|
||||||
|
computeUnits: ComputeUnitPreference = .all,
|
||||||
|
maxResponseTokens: Int = 2048,
|
||||||
|
useSampling: Bool = false,
|
||||||
|
temperature: Double = 0.1
|
||||||
|
) {
|
||||||
|
self.modelURL = modelURL
|
||||||
|
self.computeUnits = computeUnits
|
||||||
|
self.maxResponseTokens = maxResponseTokens
|
||||||
|
self.useSampling = useSampling
|
||||||
|
self.temperature = temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validates that the model URL points to a compiled CoreML model.
|
||||||
|
///
|
||||||
|
/// - Throws: ``OnDeviceProviderError`` if the URL is invalid.
|
||||||
|
public func validate() throws {
|
||||||
|
guard modelURL.pathExtension == "mlmodelc" else {
|
||||||
|
throw OnDeviceProviderError.invalidModelFormat(
|
||||||
|
expected: ".mlmodelc",
|
||||||
|
actual: modelURL.pathExtension
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
guard FileManager.default.fileExists(atPath: modelURL.path) else {
|
||||||
|
throw OnDeviceProviderError.modelNotFound(modelURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compute unit preference for CoreML inference.
|
||||||
|
///
|
||||||
|
/// Maps to `MLComputeUnits` in the CoreML framework.
|
||||||
|
public enum ComputeUnitPreference: String, Sendable, Hashable, CaseIterable {
|
||||||
|
/// Use all available compute units (Neural Engine, GPU, CPU).
|
||||||
|
/// This is the recommended setting for production use.
|
||||||
|
case all
|
||||||
|
|
||||||
|
/// Force CPU-only execution. Useful for debugging or testing.
|
||||||
|
case cpuOnly
|
||||||
|
|
||||||
|
/// Use CPU and GPU, but not the Neural Engine.
|
||||||
|
case cpuAndGPU
|
||||||
|
|
||||||
|
/// Use CPU and Neural Engine, but not the GPU.
|
||||||
|
case cpuAndNeuralEngine
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - MLX Configuration
|
||||||
|
|
||||||
|
/// Configuration for loading and running an MLX language model on Apple silicon.
|
||||||
|
///
|
||||||
|
/// MLX models are loaded from HuggingFace Hub or a local directory. The MLX
|
||||||
|
/// framework provides efficient inference on Apple silicon with support for
|
||||||
|
/// quantized models (4-bit, 8-bit) for reduced memory usage.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// // From HuggingFace Hub (auto-downloaded)
|
||||||
|
/// let config = MLXProviderConfiguration(
|
||||||
|
/// modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
|
/// )
|
||||||
|
///
|
||||||
|
/// // From a local directory
|
||||||
|
/// let config = MLXProviderConfiguration(
|
||||||
|
/// modelId: "my-local-model",
|
||||||
|
/// localDirectory: URL(fileURLWithPath: "/path/to/model")
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Note: MLX models are available behind the `#if MLX` flag in AnyLanguageModel.
|
||||||
|
/// Ensure your project enables the MLX build condition.
|
||||||
|
public struct MLXProviderConfiguration: Sendable, Equatable {
|
||||||
|
|
||||||
|
/// The HuggingFace model identifier (e.g., "mlx-community/Llama-3.2-3B-Instruct-4bit").
|
||||||
|
public let modelId: String
|
||||||
|
|
||||||
|
/// Optional local directory containing the model files.
|
||||||
|
/// When set, the model is loaded from this directory instead of downloading from Hub.
|
||||||
|
public let localDirectory: URL?
|
||||||
|
|
||||||
|
/// GPU memory management configuration.
|
||||||
|
public let gpuMemory: MLXGPUMemoryConfig
|
||||||
|
|
||||||
|
/// Maximum number of tokens the model can generate per response.
|
||||||
|
/// Defaults to 2048.
|
||||||
|
public let maxResponseTokens: Int
|
||||||
|
|
||||||
|
/// Temperature for text generation. Lower values produce more deterministic output.
|
||||||
|
/// Defaults to 0.1 for SQL generation accuracy.
|
||||||
|
public let temperature: Double
|
||||||
|
|
||||||
|
/// Top-P (nucleus) sampling threshold. Only tokens with cumulative probability
|
||||||
|
/// below this threshold are considered. Defaults to 0.95.
|
||||||
|
public let topP: Double
|
||||||
|
|
||||||
|
/// Repetition penalty to reduce repetitive output. Defaults to 1.1.
|
||||||
|
public let repetitionPenalty: Double
|
||||||
|
|
||||||
|
/// Creates an MLX provider configuration.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - modelId: The HuggingFace model ID or local identifier.
|
||||||
|
/// - localDirectory: Optional path to a local model directory.
|
||||||
|
/// - gpuMemory: GPU memory configuration. Defaults to `.automatic`.
|
||||||
|
/// - maxResponseTokens: Maximum tokens per response. Defaults to 2048.
|
||||||
|
/// - temperature: Generation temperature. Defaults to 0.1.
|
||||||
|
/// - topP: Top-P sampling threshold. Defaults to 0.95.
|
||||||
|
/// - repetitionPenalty: Repetition penalty. Defaults to 1.1.
|
||||||
|
public init(
|
||||||
|
modelId: String,
|
||||||
|
localDirectory: URL? = nil,
|
||||||
|
gpuMemory: MLXGPUMemoryConfig = .automatic,
|
||||||
|
maxResponseTokens: Int = 2048,
|
||||||
|
temperature: Double = 0.1,
|
||||||
|
topP: Double = 0.95,
|
||||||
|
repetitionPenalty: Double = 1.1
|
||||||
|
) {
|
||||||
|
self.modelId = modelId
|
||||||
|
self.localDirectory = localDirectory
|
||||||
|
self.gpuMemory = gpuMemory
|
||||||
|
self.maxResponseTokens = maxResponseTokens
|
||||||
|
self.temperature = temperature
|
||||||
|
self.topP = topP
|
||||||
|
self.repetitionPenalty = repetitionPenalty
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validates the configuration parameters.
|
||||||
|
///
|
||||||
|
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
|
||||||
|
public func validate() throws {
|
||||||
|
guard !modelId.isEmpty else {
|
||||||
|
throw OnDeviceProviderError.emptyModelId
|
||||||
|
}
|
||||||
|
|
||||||
|
if let dir = localDirectory {
|
||||||
|
guard FileManager.default.fileExists(atPath: dir.path) else {
|
||||||
|
throw OnDeviceProviderError.modelNotFound(dir)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
guard temperature >= 0 else {
|
||||||
|
throw OnDeviceProviderError.invalidParameter(
|
||||||
|
name: "temperature",
|
||||||
|
value: "\(temperature)",
|
||||||
|
reason: "Must be non-negative"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
guard topP > 0, topP <= 1.0 else {
|
||||||
|
throw OnDeviceProviderError.invalidParameter(
|
||||||
|
name: "topP",
|
||||||
|
value: "\(topP)",
|
||||||
|
reason: "Must be between 0 (exclusive) and 1.0 (inclusive)"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
guard repetitionPenalty > 0 else {
|
||||||
|
throw OnDeviceProviderError.invalidParameter(
|
||||||
|
name: "repetitionPenalty",
|
||||||
|
value: "\(repetitionPenalty)",
|
||||||
|
reason: "Must be positive"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Well-Known Models
|
||||||
|
|
||||||
|
/// Pre-configured for Llama 3.2 3B Instruct (4-bit quantized).
|
||||||
|
/// Good balance of quality and memory usage (~2GB RAM).
|
||||||
|
public static func llama3_2_3B(
|
||||||
|
localDirectory: URL? = nil,
|
||||||
|
gpuMemory: MLXGPUMemoryConfig = .automatic
|
||||||
|
) -> MLXProviderConfiguration {
|
||||||
|
MLXProviderConfiguration(
|
||||||
|
modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit",
|
||||||
|
localDirectory: localDirectory,
|
||||||
|
gpuMemory: gpuMemory,
|
||||||
|
maxResponseTokens: 2048,
|
||||||
|
temperature: 0.1
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pre-configured for Qwen 2.5 Coder 3B Instruct (4-bit quantized).
|
||||||
|
/// Optimized for code and SQL generation.
|
||||||
|
public static func qwen2_5_coder_3B(
|
||||||
|
localDirectory: URL? = nil,
|
||||||
|
gpuMemory: MLXGPUMemoryConfig = .automatic
|
||||||
|
) -> MLXProviderConfiguration {
|
||||||
|
MLXProviderConfiguration(
|
||||||
|
modelId: "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit",
|
||||||
|
localDirectory: localDirectory,
|
||||||
|
gpuMemory: gpuMemory,
|
||||||
|
maxResponseTokens: 2048,
|
||||||
|
temperature: 0.05
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Pre-configured for Phi-3.5 Mini Instruct (4-bit quantized).
|
||||||
|
/// Compact model suitable for devices with limited memory (~1.5GB RAM).
|
||||||
|
public static func phi3_5_mini(
|
||||||
|
localDirectory: URL? = nil,
|
||||||
|
gpuMemory: MLXGPUMemoryConfig = .automatic
|
||||||
|
) -> MLXProviderConfiguration {
|
||||||
|
MLXProviderConfiguration(
|
||||||
|
modelId: "mlx-community/Phi-3.5-mini-instruct-4bit",
|
||||||
|
localDirectory: localDirectory,
|
||||||
|
gpuMemory: gpuMemory,
|
||||||
|
maxResponseTokens: 2048,
|
||||||
|
temperature: 0.1
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GPU memory management configuration for MLX models.
|
||||||
|
///
|
||||||
|
/// Controls how aggressively the MLX runtime manages GPU buffer caches
|
||||||
|
/// during active generation and idle phases.
|
||||||
|
public struct MLXGPUMemoryConfig: Sendable, Equatable {
|
||||||
|
/// GPU cache limit (in bytes) during active generation.
|
||||||
|
public let activeCacheLimit: Int
|
||||||
|
|
||||||
|
/// GPU cache limit (in bytes) when idle.
|
||||||
|
public let idleCacheLimit: Int
|
||||||
|
|
||||||
|
/// Whether to clear cached GPU buffers when eviction is safe.
|
||||||
|
public let clearCacheOnEviction: Bool
|
||||||
|
|
||||||
|
/// Creates a GPU memory configuration.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - activeCacheLimit: Cache limit during active generation (bytes).
|
||||||
|
/// - idleCacheLimit: Cache limit when idle (bytes).
|
||||||
|
/// - clearCacheOnEviction: Whether to clear cache on eviction.
|
||||||
|
public init(
|
||||||
|
activeCacheLimit: Int,
|
||||||
|
idleCacheLimit: Int,
|
||||||
|
clearCacheOnEviction: Bool = true
|
||||||
|
) {
|
||||||
|
self.activeCacheLimit = activeCacheLimit
|
||||||
|
self.idleCacheLimit = idleCacheLimit
|
||||||
|
self.clearCacheOnEviction = clearCacheOnEviction
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Automatically determined based on device physical memory.
|
||||||
|
///
|
||||||
|
/// - Devices with <4GB RAM: 128MB active cache
|
||||||
|
/// - Devices with <6GB RAM: 256MB active cache
|
||||||
|
/// - Devices with <8GB RAM: 512MB active cache
|
||||||
|
/// - Devices with 8GB+ RAM: 768MB active cache
|
||||||
|
/// - Idle cache: 50MB for all devices
|
||||||
|
public static var automatic: MLXGPUMemoryConfig {
|
||||||
|
let ramBytes = ProcessInfo.processInfo.physicalMemory
|
||||||
|
let ramGB = ramBytes / (1024 * 1024 * 1024)
|
||||||
|
let active: Int
|
||||||
|
switch ramGB {
|
||||||
|
case ..<4:
|
||||||
|
active = 128_000_000
|
||||||
|
case ..<6:
|
||||||
|
active = 256_000_000
|
||||||
|
case ..<8:
|
||||||
|
active = 512_000_000
|
||||||
|
default:
|
||||||
|
active = 768_000_000
|
||||||
|
}
|
||||||
|
|
||||||
|
return .init(
|
||||||
|
activeCacheLimit: active,
|
||||||
|
idleCacheLimit: 50_000_000,
|
||||||
|
clearCacheOnEviction: true
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Minimal memory configuration for constrained devices.
|
||||||
|
/// Uses 64MB active cache and 16MB idle cache.
|
||||||
|
public static var minimal: MLXGPUMemoryConfig {
|
||||||
|
.init(
|
||||||
|
activeCacheLimit: 64_000_000,
|
||||||
|
idleCacheLimit: 16_000_000,
|
||||||
|
clearCacheOnEviction: true
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unconstrained configuration for maximum performance.
|
||||||
|
/// Leaves GPU cache effectively unlimited. Use when your app
|
||||||
|
/// can afford maximum memory usage.
|
||||||
|
public static var unconstrained: MLXGPUMemoryConfig {
|
||||||
|
.init(
|
||||||
|
activeCacheLimit: Int.max,
|
||||||
|
idleCacheLimit: Int.max,
|
||||||
|
clearCacheOnEviction: false
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - On-Device Provider Errors
|
||||||
|
|
||||||
|
/// Errors specific to on-device provider configuration and model loading.
|
||||||
|
public enum OnDeviceProviderError: Error, LocalizedError, Sendable, Equatable {
|
||||||
|
/// The model file was not found at the specified URL.
|
||||||
|
case modelNotFound(URL)
|
||||||
|
|
||||||
|
/// The model file format is not what was expected.
|
||||||
|
case invalidModelFormat(expected: String, actual: String)
|
||||||
|
|
||||||
|
/// The model ID is empty.
|
||||||
|
case emptyModelId
|
||||||
|
|
||||||
|
/// A configuration parameter is invalid.
|
||||||
|
case invalidParameter(name: String, value: String, reason: String)
|
||||||
|
|
||||||
|
/// The on-device provider is not available on this platform.
|
||||||
|
/// CoreML requires macOS 15+ / iOS 18+. MLX requires the MLX build flag.
|
||||||
|
case providerUnavailable(OnDeviceProviderType, reason: String)
|
||||||
|
|
||||||
|
/// Model loading failed with an underlying error.
|
||||||
|
case modelLoadFailed(reason: String)
|
||||||
|
|
||||||
|
/// Model inference failed.
|
||||||
|
case inferenceFailed(reason: String)
|
||||||
|
|
||||||
|
public var errorDescription: String? {
|
||||||
|
switch self {
|
||||||
|
case .modelNotFound(let url):
|
||||||
|
return "On-device model not found at: \(url.path)"
|
||||||
|
case .invalidModelFormat(let expected, let actual):
|
||||||
|
return "Invalid model format: expected \(expected), got .\(actual)"
|
||||||
|
case .emptyModelId:
|
||||||
|
return "Model ID must not be empty"
|
||||||
|
case .invalidParameter(let name, let value, let reason):
|
||||||
|
return "Invalid parameter '\(name)' = \(value): \(reason)"
|
||||||
|
case .providerUnavailable(let type, let reason):
|
||||||
|
return "\(type.rawValue) provider unavailable: \(reason)"
|
||||||
|
case .modelLoadFailed(let reason):
|
||||||
|
return "Failed to load on-device model: \(reason)"
|
||||||
|
case .inferenceFailed(let reason):
|
||||||
|
return "On-device inference failed: \(reason)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - On-Device Inference Pipeline
|
||||||
|
|
||||||
|
/// Manages the on-device model inference pipeline.
|
||||||
|
///
|
||||||
|
/// `OnDeviceInferencePipeline` provides a unified interface for preparing,
|
||||||
|
/// loading, and running inference with on-device models (CoreML, MLX).
|
||||||
|
/// It handles model lifecycle management including loading, warm-up,
|
||||||
|
/// and memory cleanup.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// // Create a pipeline for an MLX model
|
||||||
|
/// let mlxConfig = MLXProviderConfiguration.llama3_2_3B()
|
||||||
|
/// let pipeline = OnDeviceInferencePipeline(mlxConfiguration: mlxConfig)
|
||||||
|
///
|
||||||
|
/// // Check readiness
|
||||||
|
/// let status = pipeline.status
|
||||||
|
///
|
||||||
|
/// // Use with ChatEngine
|
||||||
|
/// let engine = try ChatEngine(
|
||||||
|
/// database: db,
|
||||||
|
/// provider: .onDevice(mlx: mlxConfig)
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
public final class OnDeviceInferencePipeline: @unchecked Sendable {
|
||||||
|
|
||||||
|
/// The current status of the on-device inference pipeline.
|
||||||
|
public enum Status: Sendable, Equatable {
|
||||||
|
/// The model has not been loaded yet.
|
||||||
|
case notLoaded
|
||||||
|
|
||||||
|
/// The model is currently being loaded/downloaded.
|
||||||
|
case loading
|
||||||
|
|
||||||
|
/// The model is loaded and ready for inference.
|
||||||
|
case ready
|
||||||
|
|
||||||
|
/// The model failed to load.
|
||||||
|
case failed(String)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The type of on-device provider this pipeline uses.
|
||||||
|
public let providerType: OnDeviceProviderType
|
||||||
|
|
||||||
|
/// The MLX configuration, if this is an MLX pipeline.
|
||||||
|
public let mlxConfiguration: MLXProviderConfiguration?
|
||||||
|
|
||||||
|
/// The CoreML configuration, if this is a CoreML pipeline.
|
||||||
|
public let coreMLConfiguration: CoreMLProviderConfiguration?
|
||||||
|
|
||||||
|
/// The current status of the pipeline.
|
||||||
|
private let _statusLock = NSLock()
|
||||||
|
private var _status: Status = .notLoaded
|
||||||
|
|
||||||
|
/// The current pipeline status.
|
||||||
|
public var status: Status {
|
||||||
|
_statusLock.lock()
|
||||||
|
defer { _statusLock.unlock() }
|
||||||
|
return _status
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an MLX inference pipeline.
|
||||||
|
///
|
||||||
|
/// - Parameter configuration: The MLX model configuration.
|
||||||
|
public init(mlxConfiguration: MLXProviderConfiguration) {
|
||||||
|
self.providerType = .mlx
|
||||||
|
self.mlxConfiguration = mlxConfiguration
|
||||||
|
self.coreMLConfiguration = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a CoreML inference pipeline.
|
||||||
|
///
|
||||||
|
/// - Parameter configuration: The CoreML model configuration.
|
||||||
|
public init(coreMLConfiguration: CoreMLProviderConfiguration) {
|
||||||
|
self.providerType = .coreML
|
||||||
|
self.coreMLConfiguration = coreMLConfiguration
|
||||||
|
self.mlxConfiguration = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validates the configuration before attempting to load.
|
||||||
|
///
|
||||||
|
/// Call this to check configuration validity without triggering model loading.
|
||||||
|
///
|
||||||
|
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
|
||||||
|
public func validateConfiguration() throws {
|
||||||
|
switch providerType {
|
||||||
|
case .coreML:
|
||||||
|
guard let config = coreMLConfiguration else {
|
||||||
|
throw OnDeviceProviderError.providerUnavailable(
|
||||||
|
.coreML,
|
||||||
|
reason: "No CoreML configuration provided"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
try config.validate()
|
||||||
|
|
||||||
|
case .mlx:
|
||||||
|
guard let config = mlxConfiguration else {
|
||||||
|
throw OnDeviceProviderError.providerUnavailable(
|
||||||
|
.mlx,
|
||||||
|
reason: "No MLX configuration provided"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
try config.validate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Updates the pipeline status.
|
||||||
|
internal func setStatus(_ newStatus: Status) {
|
||||||
|
_statusLock.lock()
|
||||||
|
_status = newStatus
|
||||||
|
_statusLock.unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provides recommended generation options optimized for SQL generation
|
||||||
|
/// based on the pipeline's configuration.
|
||||||
|
///
|
||||||
|
/// On-device models benefit from specific generation parameters that
|
||||||
|
/// balance accuracy with performance for SQL output.
|
||||||
|
public var recommendedSQLGenerationHints: OnDeviceSQLGenerationHints {
|
||||||
|
switch providerType {
|
||||||
|
case .coreML:
|
||||||
|
let config = coreMLConfiguration ?? CoreMLProviderConfiguration(
|
||||||
|
modelURL: URL(fileURLWithPath: "/dev/null")
|
||||||
|
)
|
||||||
|
return OnDeviceSQLGenerationHints(
|
||||||
|
maxTokens: config.maxResponseTokens,
|
||||||
|
temperature: config.temperature,
|
||||||
|
systemPromptSuffix: """
|
||||||
|
You are a SQL assistant running on-device. Generate only valid SQLite SQL.
|
||||||
|
Be concise — output ONLY the SQL query with no explanation.
|
||||||
|
""",
|
||||||
|
useSampling: config.useSampling
|
||||||
|
)
|
||||||
|
|
||||||
|
case .mlx:
|
||||||
|
let config = mlxConfiguration ?? .llama3_2_3B()
|
||||||
|
return OnDeviceSQLGenerationHints(
|
||||||
|
maxTokens: config.maxResponseTokens,
|
||||||
|
temperature: config.temperature,
|
||||||
|
systemPromptSuffix: """
|
||||||
|
You are a SQL assistant running on-device via MLX. Generate only valid SQLite SQL.
|
||||||
|
Be concise — output ONLY the SQL query with no explanation.
|
||||||
|
""",
|
||||||
|
useSampling: true
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Hints for optimizing SQL generation with on-device models.
|
||||||
|
///
|
||||||
|
/// On-device models are typically smaller than cloud models and benefit
|
||||||
|
/// from more constrained generation parameters to produce accurate SQL.
|
||||||
|
public struct OnDeviceSQLGenerationHints: Sendable, Equatable {
|
||||||
|
/// Recommended maximum token count for SQL responses.
|
||||||
|
public let maxTokens: Int
|
||||||
|
|
||||||
|
/// Recommended temperature for SQL generation.
|
||||||
|
public let temperature: Double
|
||||||
|
|
||||||
|
/// Additional system prompt text optimized for on-device SQL generation.
|
||||||
|
public let systemPromptSuffix: String
|
||||||
|
|
||||||
|
/// Whether to use sampling or greedy decoding.
|
||||||
|
public let useSampling: Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - ProviderConfiguration Extension
|
||||||
|
|
||||||
|
extension ProviderConfiguration {
|
||||||
|
|
||||||
|
/// Creates a configuration for an on-device MLX model.
|
||||||
|
///
|
||||||
|
/// MLX models run entirely on Apple silicon using the MLX framework.
|
||||||
|
/// Models are automatically downloaded from HuggingFace Hub on first use.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// // Using a pre-configured model
|
||||||
|
/// let config = ProviderConfiguration.onDeviceMLX(.llama3_2_3B())
|
||||||
|
///
|
||||||
|
/// // Using a custom model
|
||||||
|
/// let config = ProviderConfiguration.onDeviceMLX(
|
||||||
|
/// MLXProviderConfiguration(
|
||||||
|
/// modelId: "mlx-community/Qwen2.5-7B-Instruct-4bit",
|
||||||
|
/// temperature: 0.05
|
||||||
|
/// )
|
||||||
|
/// )
|
||||||
|
///
|
||||||
|
/// let engine = ChatEngine(database: db, provider: config)
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameter mlxConfig: The MLX model configuration.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration` that wraps the MLX model.
|
||||||
|
///
|
||||||
|
/// - Note: The returned configuration uses `.openAICompatible` as the provider
|
||||||
|
/// type internally. The actual model is created via MLX APIs when `#if MLX` is
|
||||||
|
/// available. If MLX is not available at compile time, the model factory will
|
||||||
|
/// produce a placeholder that reports unavailability.
|
||||||
|
public static func onDeviceMLX(
|
||||||
|
_ mlxConfig: MLXProviderConfiguration
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .openAICompatible,
|
||||||
|
model: mlxConfig.modelId,
|
||||||
|
apiKeyProvider: { "" },
|
||||||
|
baseURL: nil,
|
||||||
|
apiVersion: nil,
|
||||||
|
betas: nil,
|
||||||
|
openAIVariant: nil
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a configuration for an on-device CoreML model.
|
||||||
|
///
|
||||||
|
/// CoreML models must be pre-compiled to `.mlmodelc` format.
|
||||||
|
/// They run on CPU, GPU, and/or Neural Engine depending on the
|
||||||
|
/// compute units configuration.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let modelURL = Bundle.main.url(forResource: "SQLModel", withExtension: "mlmodelc")!
|
||||||
|
/// let config = ProviderConfiguration.onDeviceCoreML(
|
||||||
|
/// CoreMLProviderConfiguration(modelURL: modelURL)
|
||||||
|
/// )
|
||||||
|
/// let engine = ChatEngine(database: db, provider: config)
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameter coreMLConfig: The CoreML model configuration.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration` that wraps the CoreML model.
|
||||||
|
///
|
||||||
|
/// - Note: Requires macOS 15+ / iOS 18+ and the `CoreML` build flag in AnyLanguageModel.
|
||||||
|
public static func onDeviceCoreML(
|
||||||
|
_ coreMLConfig: CoreMLProviderConfiguration
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .openAICompatible,
|
||||||
|
model: coreMLConfig.modelURL.lastPathComponent,
|
||||||
|
apiKeyProvider: { "" },
|
||||||
|
baseURL: nil,
|
||||||
|
apiVersion: nil,
|
||||||
|
betas: nil,
|
||||||
|
openAIVariant: nil
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - ChatEngine On-Device Convenience
|
||||||
|
|
||||||
|
extension ChatEngine {
|
||||||
|
|
||||||
|
/// Creates a ChatEngine with an on-device MLX model.
|
||||||
|
///
|
||||||
|
/// This convenience initializer sets up a ChatEngine configured for
|
||||||
|
/// on-device inference. It validates the MLX configuration and creates
|
||||||
|
/// an inference pipeline.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let engine = try ChatEngine.onDevice(
|
||||||
|
/// database: db,
|
||||||
|
/// mlx: .llama3_2_3B()
|
||||||
|
/// )
|
||||||
|
/// let response = try await engine.send("How many users are there?")
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||||
|
/// - mlx: The MLX model configuration.
|
||||||
|
/// - allowlist: SQL operations allowed. Defaults to read-only.
|
||||||
|
/// - configuration: Engine configuration.
|
||||||
|
/// - Returns: A configured `ChatEngine` instance.
|
||||||
|
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
|
||||||
|
public static func onDevice(
|
||||||
|
database: any DatabaseWriter,
|
||||||
|
mlx mlxConfig: MLXProviderConfiguration,
|
||||||
|
allowlist: OperationAllowlist = .readOnly,
|
||||||
|
configuration: ChatEngineConfiguration = .default
|
||||||
|
) throws -> ChatEngine {
|
||||||
|
// Validate configuration
|
||||||
|
try mlxConfig.validate()
|
||||||
|
|
||||||
|
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: mlxConfig)
|
||||||
|
|
||||||
|
// Build a ChatEngineConfiguration that includes on-device hints
|
||||||
|
var engineConfig = configuration
|
||||||
|
let hints = pipeline.recommendedSQLGenerationHints
|
||||||
|
if engineConfig.additionalContext == nil {
|
||||||
|
engineConfig.additionalContext = hints.systemPromptSuffix
|
||||||
|
} else {
|
||||||
|
engineConfig.additionalContext! += "\n\n" + hints.systemPromptSuffix
|
||||||
|
}
|
||||||
|
|
||||||
|
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
|
||||||
|
|
||||||
|
return ChatEngine(
|
||||||
|
database: database,
|
||||||
|
provider: providerConfig,
|
||||||
|
allowlist: allowlist,
|
||||||
|
configuration: engineConfig
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a ChatEngine with an on-device CoreML model.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let modelURL = Bundle.main.url(forResource: "SQLModel", withExtension: "mlmodelc")!
|
||||||
|
/// let coreMLConfig = CoreMLProviderConfiguration(modelURL: modelURL)
|
||||||
|
/// let engine = try ChatEngine.onDevice(
|
||||||
|
/// database: db,
|
||||||
|
/// coreML: coreMLConfig
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||||
|
/// - coreML: The CoreML model configuration.
|
||||||
|
/// - allowlist: SQL operations allowed. Defaults to read-only.
|
||||||
|
/// - configuration: Engine configuration.
|
||||||
|
/// - Returns: A configured `ChatEngine` instance.
|
||||||
|
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
|
||||||
|
public static func onDevice(
|
||||||
|
database: any DatabaseWriter,
|
||||||
|
coreML coreMLConfig: CoreMLProviderConfiguration,
|
||||||
|
allowlist: OperationAllowlist = .readOnly,
|
||||||
|
configuration: ChatEngineConfiguration = .default
|
||||||
|
) throws -> ChatEngine {
|
||||||
|
// Validate configuration
|
||||||
|
try coreMLConfig.validate()
|
||||||
|
|
||||||
|
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: coreMLConfig)
|
||||||
|
|
||||||
|
var engineConfig = configuration
|
||||||
|
let hints = pipeline.recommendedSQLGenerationHints
|
||||||
|
if engineConfig.additionalContext == nil {
|
||||||
|
engineConfig.additionalContext = hints.systemPromptSuffix
|
||||||
|
} else {
|
||||||
|
engineConfig.additionalContext! += "\n\n" + hints.systemPromptSuffix
|
||||||
|
}
|
||||||
|
|
||||||
|
let providerConfig = ProviderConfiguration.onDeviceCoreML(coreMLConfig)
|
||||||
|
|
||||||
|
return ChatEngine(
|
||||||
|
database: database,
|
||||||
|
provider: providerConfig,
|
||||||
|
allowlist: allowlist,
|
||||||
|
configuration: engineConfig
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Model Readiness Checker
|
||||||
|
|
||||||
|
/// Utility for checking on-device model availability and system capability.
|
||||||
|
public enum OnDeviceModelReadiness {
|
||||||
|
|
||||||
|
/// System capability information for on-device inference.
|
||||||
|
public struct SystemCapability: Sendable, Equatable {
|
||||||
|
/// Total physical RAM in bytes.
|
||||||
|
public let totalRAM: UInt64
|
||||||
|
|
||||||
|
/// Whether the device has sufficient RAM for typical on-device models.
|
||||||
|
/// Generally requires at least 4GB for 3B parameter models.
|
||||||
|
public let hasSufficientRAM: Bool
|
||||||
|
|
||||||
|
/// Whether Apple Neural Engine is likely available.
|
||||||
|
/// True on devices with Apple silicon.
|
||||||
|
public let hasNeuralEngine: Bool
|
||||||
|
|
||||||
|
/// Recommended model size category based on available RAM.
|
||||||
|
public let recommendedModelSize: RecommendedModelSize
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Recommended model size based on device capabilities.
|
||||||
|
public enum RecommendedModelSize: String, Sendable, Equatable {
|
||||||
|
/// Small models (1-2B parameters, 4-bit quantized).
|
||||||
|
/// Suitable for devices with 4GB RAM.
|
||||||
|
case small
|
||||||
|
|
||||||
|
/// Medium models (3-4B parameters, 4-bit quantized).
|
||||||
|
/// Suitable for devices with 6-8GB RAM.
|
||||||
|
case medium
|
||||||
|
|
||||||
|
/// Large models (7-8B parameters, 4-bit quantized).
|
||||||
|
/// Suitable for devices with 16GB+ RAM.
|
||||||
|
case large
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks the current device's capability for on-device inference.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let capability = OnDeviceModelReadiness.checkSystemCapability()
|
||||||
|
/// if capability.hasSufficientRAM {
|
||||||
|
/// print("Recommended size: \(capability.recommendedModelSize)")
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Returns: A `SystemCapability` describing the device's readiness.
|
||||||
|
public static func checkSystemCapability() -> SystemCapability {
|
||||||
|
let totalRAM = ProcessInfo.processInfo.physicalMemory
|
||||||
|
let ramGB = totalRAM / (1024 * 1024 * 1024)
|
||||||
|
|
||||||
|
let recommendedSize: RecommendedModelSize
|
||||||
|
switch ramGB {
|
||||||
|
case ..<4:
|
||||||
|
recommendedSize = .small
|
||||||
|
case ..<8:
|
||||||
|
recommendedSize = .medium
|
||||||
|
default:
|
||||||
|
recommendedSize = .large
|
||||||
|
}
|
||||||
|
|
||||||
|
return SystemCapability(
|
||||||
|
totalRAM: totalRAM,
|
||||||
|
hasSufficientRAM: ramGB >= 4,
|
||||||
|
hasNeuralEngine: hasAppleSilicon(),
|
||||||
|
recommendedModelSize: recommendedSize
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Suggests an MLX model configuration based on system capabilities.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let config = OnDeviceModelReadiness.suggestedMLXModel()
|
||||||
|
/// let engine = try ChatEngine.onDevice(database: db, mlx: config)
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Returns: An `MLXProviderConfiguration` appropriate for this device.
|
||||||
|
public static func suggestedMLXModel() -> MLXProviderConfiguration {
|
||||||
|
let capability = checkSystemCapability()
|
||||||
|
switch capability.recommendedModelSize {
|
||||||
|
case .small:
|
||||||
|
return .phi3_5_mini()
|
||||||
|
case .medium:
|
||||||
|
return .llama3_2_3B()
|
||||||
|
case .large:
|
||||||
|
return .qwen2_5_coder_3B()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks if the current device uses Apple silicon.
|
||||||
|
private static func hasAppleSilicon() -> Bool {
|
||||||
|
#if arch(arm64)
|
||||||
|
return true
|
||||||
|
#else
|
||||||
|
return false
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
54
Sources/SwiftDBAI/Config/OperationAllowlist.swift
Normal file
54
Sources/SwiftDBAI/Config/OperationAllowlist.swift
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
/// Defines which SQL operations the LLM is permitted to generate.
|
||||||
|
///
|
||||||
|
/// The default is ``readOnly`` (SELECT only). Write operations require
|
||||||
|
/// explicit opt-in. This is the safety-by-default principle.
|
||||||
|
public struct OperationAllowlist: Sendable, Equatable {
|
||||||
|
/// The set of permitted SQL operation types.
|
||||||
|
public let allowedOperations: Set<SQLOperation>
|
||||||
|
|
||||||
|
/// Creates an allowlist from the given set of operations.
|
||||||
|
public init(_ operations: Set<SQLOperation>) {
|
||||||
|
self.allowedOperations = operations
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read-only: only SELECT queries are permitted. This is the default.
|
||||||
|
public static let readOnly = OperationAllowlist([.select])
|
||||||
|
|
||||||
|
/// Standard read-write: SELECT, INSERT, and UPDATE are permitted.
|
||||||
|
public static let standard = OperationAllowlist([.select, .insert, .update])
|
||||||
|
|
||||||
|
/// Unrestricted: all operations including DELETE are permitted.
|
||||||
|
/// DELETE still requires confirmation via `ToolExecutionDelegate`.
|
||||||
|
public static let unrestricted = OperationAllowlist([.select, .insert, .update, .delete])
|
||||||
|
|
||||||
|
/// Returns true if the given operation is allowed.
|
||||||
|
public func isAllowed(_ operation: SQLOperation) -> Bool {
|
||||||
|
allowedOperations.contains(operation)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a human-readable description of what's allowed, for inclusion
|
||||||
|
/// in the LLM system prompt.
|
||||||
|
func describeForLLM() -> String {
|
||||||
|
if allowedOperations == [.select] {
|
||||||
|
return "You may ONLY generate SELECT queries. No data modifications are allowed."
|
||||||
|
}
|
||||||
|
|
||||||
|
let sorted = allowedOperations.sorted { $0.rawValue < $1.rawValue }
|
||||||
|
let names = sorted.map { $0.rawValue.uppercased() }
|
||||||
|
var desc = "Allowed SQL operations: \(names.joined(separator: ", "))."
|
||||||
|
|
||||||
|
if allowedOperations.contains(.delete) {
|
||||||
|
desc += " DELETE operations are destructive and require user confirmation before execution."
|
||||||
|
}
|
||||||
|
|
||||||
|
return desc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The types of SQL operations that can be controlled via the allowlist.
|
||||||
|
public enum SQLOperation: String, Sendable, Hashable, CaseIterable {
|
||||||
|
case select
|
||||||
|
case insert
|
||||||
|
case update
|
||||||
|
case delete
|
||||||
|
}
|
||||||
609
Sources/SwiftDBAI/Config/ProviderConfiguration.swift
Normal file
609
Sources/SwiftDBAI/Config/ProviderConfiguration.swift
Normal file
@@ -0,0 +1,609 @@
|
|||||||
|
// ProviderConfiguration.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Unified provider configuration for cloud-based LLM providers.
|
||||||
|
// Wraps AnyLanguageModel provider types with convenient factory methods.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
import GRDB
|
||||||
|
|
||||||
|
/// Configuration for connecting to a cloud-based LLM provider.
|
||||||
|
///
|
||||||
|
/// `ProviderConfiguration` provides a unified way to configure any supported
|
||||||
|
/// LLM provider (OpenAI, Anthropic, Gemini, or OpenAI-compatible services).
|
||||||
|
/// Each configuration produces a properly configured `LanguageModel` instance
|
||||||
|
/// that works with ``ChatEngine`` and ``TextSummaryRenderer``.
|
||||||
|
///
|
||||||
|
/// ## Quick Start
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// // OpenAI
|
||||||
|
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
|
||||||
|
///
|
||||||
|
/// // Anthropic
|
||||||
|
/// let config = ProviderConfiguration.anthropic(apiKey: "sk-ant-...", model: "claude-sonnet-4-20250514")
|
||||||
|
///
|
||||||
|
/// // Gemini
|
||||||
|
/// let config = ProviderConfiguration.gemini(apiKey: "AIza...", model: "gemini-2.0-flash")
|
||||||
|
///
|
||||||
|
/// // Use with ChatEngine
|
||||||
|
/// let engine = ChatEngine(database: db, model: config.makeModel())
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// ## API Key Handling
|
||||||
|
///
|
||||||
|
/// API keys are stored as closures to support both static strings and
|
||||||
|
/// dynamic retrieval from keychains, environment variables, or secure storage:
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// // Static key
|
||||||
|
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
|
||||||
|
///
|
||||||
|
/// // Dynamic key from environment
|
||||||
|
/// let config = ProviderConfiguration.openAI(
|
||||||
|
/// apiKeyProvider: { ProcessInfo.processInfo.environment["OPENAI_API_KEY"] ?? "" },
|
||||||
|
/// model: "gpt-4o"
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
public struct ProviderConfiguration: Sendable {
|
||||||
|
|
||||||
|
/// The supported LLM provider types.
|
||||||
|
public enum Provider: String, Sendable, Hashable, CaseIterable {
|
||||||
|
/// OpenAI's GPT models via the Chat Completions or Responses API.
|
||||||
|
case openAI
|
||||||
|
|
||||||
|
/// Anthropic's Claude models.
|
||||||
|
case anthropic
|
||||||
|
|
||||||
|
/// Google's Gemini models.
|
||||||
|
case gemini
|
||||||
|
|
||||||
|
/// Any OpenAI-compatible API (e.g., local servers, third-party providers).
|
||||||
|
case openAICompatible
|
||||||
|
|
||||||
|
/// Ollama — local models via `ollama serve`.
|
||||||
|
/// Default endpoint: http://localhost:11434
|
||||||
|
case ollama
|
||||||
|
|
||||||
|
/// llama.cpp server — local GGUF models via `llama-server`.
|
||||||
|
/// Default endpoint: http://localhost:8080
|
||||||
|
/// Uses the OpenAI-compatible API.
|
||||||
|
case llamaCpp
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The provider type for this configuration.
|
||||||
|
public let provider: Provider
|
||||||
|
|
||||||
|
/// The model identifier (e.g., "gpt-4o", "claude-sonnet-4-20250514", "gemini-2.0-flash").
|
||||||
|
public let model: String
|
||||||
|
|
||||||
|
/// A closure that provides the API key on demand.
|
||||||
|
///
|
||||||
|
/// Using a closure allows lazy evaluation and integration with secure
|
||||||
|
/// storage systems (Keychain, environment variables, etc.).
|
||||||
|
private let apiKeyProvider: @Sendable () -> String
|
||||||
|
|
||||||
|
/// Optional custom base URL for OpenAI-compatible providers.
|
||||||
|
public let baseURL: URL?
|
||||||
|
|
||||||
|
/// Optional API version override (used by Anthropic and Gemini).
|
||||||
|
public let apiVersion: String?
|
||||||
|
|
||||||
|
/// Optional beta headers (used by Anthropic).
|
||||||
|
public let betas: [String]?
|
||||||
|
|
||||||
|
/// The OpenAI API variant to use (Chat Completions or Responses).
|
||||||
|
public let openAIVariant: OpenAILanguageModel.APIVariant?
|
||||||
|
|
||||||
|
// MARK: - Internal Init
|
||||||
|
|
||||||
|
/// Internal memberwise initializer used by factory methods.
|
||||||
|
internal init(
|
||||||
|
provider: Provider,
|
||||||
|
model: String,
|
||||||
|
apiKeyProvider: @escaping @Sendable () -> String,
|
||||||
|
baseURL: URL?,
|
||||||
|
apiVersion: String?,
|
||||||
|
betas: [String]?,
|
||||||
|
openAIVariant: OpenAILanguageModel.APIVariant?
|
||||||
|
) {
|
||||||
|
self.provider = provider
|
||||||
|
self.model = model
|
||||||
|
self.apiKeyProvider = apiKeyProvider
|
||||||
|
self.baseURL = baseURL
|
||||||
|
self.apiVersion = apiVersion
|
||||||
|
self.betas = betas
|
||||||
|
self.openAIVariant = openAIVariant
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Factory Methods
|
||||||
|
|
||||||
|
/// Creates a configuration for OpenAI's API.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - apiKey: Your OpenAI API key (e.g., "sk-...").
|
||||||
|
/// - model: The model identifier (e.g., "gpt-4o", "gpt-4o-mini").
|
||||||
|
/// - variant: The API variant to use. Defaults to `.chatCompletions`.
|
||||||
|
/// - baseURL: Optional custom base URL. Defaults to OpenAI's API.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func openAI(
|
||||||
|
apiKey: String,
|
||||||
|
model: String,
|
||||||
|
variant: OpenAILanguageModel.APIVariant = .chatCompletions,
|
||||||
|
baseURL: URL? = nil
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .openAI,
|
||||||
|
model: model,
|
||||||
|
apiKeyProvider: { apiKey },
|
||||||
|
baseURL: baseURL,
|
||||||
|
apiVersion: nil,
|
||||||
|
betas: nil,
|
||||||
|
openAIVariant: variant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a configuration for OpenAI's API with a dynamic key provider.
|
||||||
|
///
|
||||||
|
/// Use this when the API key comes from a keychain, environment variable,
|
||||||
|
/// or other dynamic source.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - apiKeyProvider: A closure that returns the API key.
|
||||||
|
/// - model: The model identifier.
|
||||||
|
/// - variant: The API variant to use. Defaults to `.chatCompletions`.
|
||||||
|
/// - baseURL: Optional custom base URL.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func openAI(
|
||||||
|
apiKeyProvider: @escaping @Sendable () -> String,
|
||||||
|
model: String,
|
||||||
|
variant: OpenAILanguageModel.APIVariant = .chatCompletions,
|
||||||
|
baseURL: URL? = nil
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .openAI,
|
||||||
|
model: model,
|
||||||
|
apiKeyProvider: apiKeyProvider,
|
||||||
|
baseURL: baseURL,
|
||||||
|
apiVersion: nil,
|
||||||
|
betas: nil,
|
||||||
|
openAIVariant: variant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a configuration for Anthropic's Claude API.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - apiKey: Your Anthropic API key (e.g., "sk-ant-...").
|
||||||
|
/// - model: The model identifier (e.g., "claude-sonnet-4-20250514").
|
||||||
|
/// - apiVersion: Optional API version override.
|
||||||
|
/// - betas: Optional beta feature headers.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func anthropic(
|
||||||
|
apiKey: String,
|
||||||
|
model: String,
|
||||||
|
apiVersion: String? = nil,
|
||||||
|
betas: [String]? = nil
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .anthropic,
|
||||||
|
model: model,
|
||||||
|
apiKeyProvider: { apiKey },
|
||||||
|
baseURL: nil,
|
||||||
|
apiVersion: apiVersion,
|
||||||
|
betas: betas,
|
||||||
|
openAIVariant: nil
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a configuration for Anthropic's Claude API with a dynamic key provider.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - apiKeyProvider: A closure that returns the API key.
|
||||||
|
/// - model: The model identifier.
|
||||||
|
/// - apiVersion: Optional API version override.
|
||||||
|
/// - betas: Optional beta feature headers.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func anthropic(
|
||||||
|
apiKeyProvider: @escaping @Sendable () -> String,
|
||||||
|
model: String,
|
||||||
|
apiVersion: String? = nil,
|
||||||
|
betas: [String]? = nil
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .anthropic,
|
||||||
|
model: model,
|
||||||
|
apiKeyProvider: apiKeyProvider,
|
||||||
|
baseURL: nil,
|
||||||
|
apiVersion: apiVersion,
|
||||||
|
betas: betas,
|
||||||
|
openAIVariant: nil
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a configuration for Google's Gemini API.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - apiKey: Your Gemini API key (e.g., "AIza...").
|
||||||
|
/// - model: The model identifier (e.g., "gemini-2.0-flash").
|
||||||
|
/// - apiVersion: Optional API version override (defaults to "v1beta").
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func gemini(
|
||||||
|
apiKey: String,
|
||||||
|
model: String,
|
||||||
|
apiVersion: String? = nil
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .gemini,
|
||||||
|
model: model,
|
||||||
|
apiKeyProvider: { apiKey },
|
||||||
|
baseURL: nil,
|
||||||
|
apiVersion: apiVersion,
|
||||||
|
betas: nil,
|
||||||
|
openAIVariant: nil
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a configuration for Google's Gemini API with a dynamic key provider.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - apiKeyProvider: A closure that returns the API key.
|
||||||
|
/// - model: The model identifier.
|
||||||
|
/// - apiVersion: Optional API version override.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func gemini(
|
||||||
|
apiKeyProvider: @escaping @Sendable () -> String,
|
||||||
|
model: String,
|
||||||
|
apiVersion: String? = nil
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .gemini,
|
||||||
|
model: model,
|
||||||
|
apiKeyProvider: apiKeyProvider,
|
||||||
|
baseURL: nil,
|
||||||
|
apiVersion: apiVersion,
|
||||||
|
betas: nil,
|
||||||
|
openAIVariant: nil
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a configuration for any OpenAI-compatible API.
|
||||||
|
///
|
||||||
|
/// Use this for third-party services that implement the OpenAI Chat Completions
|
||||||
|
/// API (e.g., local LLM servers, Groq, Together AI, etc.).
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let config = ProviderConfiguration.openAICompatible(
|
||||||
|
/// apiKey: "your-key",
|
||||||
|
/// model: "llama-3.1-70b",
|
||||||
|
/// baseURL: URL(string: "https://api.together.xyz/v1/")!
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - apiKey: The API key for the service.
|
||||||
|
/// - model: The model identifier.
|
||||||
|
/// - baseURL: The base URL of the compatible API.
|
||||||
|
/// - variant: The API variant. Defaults to `.chatCompletions`.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func openAICompatible(
|
||||||
|
apiKey: String,
|
||||||
|
model: String,
|
||||||
|
baseURL: URL,
|
||||||
|
variant: OpenAILanguageModel.APIVariant = .chatCompletions
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .openAICompatible,
|
||||||
|
model: model,
|
||||||
|
apiKeyProvider: { apiKey },
|
||||||
|
baseURL: baseURL,
|
||||||
|
apiVersion: nil,
|
||||||
|
betas: nil,
|
||||||
|
openAIVariant: variant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a configuration for any OpenAI-compatible API with a dynamic key provider.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - apiKeyProvider: A closure that returns the API key.
|
||||||
|
/// - model: The model identifier.
|
||||||
|
/// - baseURL: The base URL of the compatible API.
|
||||||
|
/// - variant: The API variant. Defaults to `.chatCompletions`.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func openAICompatible(
|
||||||
|
apiKeyProvider: @escaping @Sendable () -> String,
|
||||||
|
model: String,
|
||||||
|
baseURL: URL,
|
||||||
|
variant: OpenAILanguageModel.APIVariant = .chatCompletions
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .openAICompatible,
|
||||||
|
model: model,
|
||||||
|
apiKeyProvider: apiKeyProvider,
|
||||||
|
baseURL: baseURL,
|
||||||
|
apiVersion: nil,
|
||||||
|
betas: nil,
|
||||||
|
openAIVariant: variant
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Local Provider Factory Methods
|
||||||
|
|
||||||
|
/// Creates a configuration for a local Ollama instance.
|
||||||
|
///
|
||||||
|
/// Ollama runs models locally and exposes a native API on port 11434.
|
||||||
|
/// No API key is required by default.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// // Default local Ollama
|
||||||
|
/// let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||||
|
///
|
||||||
|
/// // Ollama on a remote machine
|
||||||
|
/// let config = ProviderConfiguration.ollama(
|
||||||
|
/// model: "qwen2.5",
|
||||||
|
/// baseURL: URL(string: "http://192.168.1.100:11434")!
|
||||||
|
/// )
|
||||||
|
///
|
||||||
|
/// // Use with ChatEngine
|
||||||
|
/// let engine = ChatEngine(database: db, provider: config)
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - model: The Ollama model name (e.g., "llama3.2", "qwen2.5", "mistral").
|
||||||
|
/// - baseURL: The Ollama server URL. Defaults to `http://localhost:11434`.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func ollama(
|
||||||
|
model: String,
|
||||||
|
baseURL: URL = OllamaLanguageModel.defaultBaseURL
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .ollama,
|
||||||
|
model: model,
|
||||||
|
apiKeyProvider: { "" },
|
||||||
|
baseURL: baseURL,
|
||||||
|
apiVersion: nil,
|
||||||
|
betas: nil,
|
||||||
|
openAIVariant: nil
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a configuration for a local llama.cpp server.
|
||||||
|
///
|
||||||
|
/// llama.cpp's `llama-server` exposes an OpenAI-compatible Chat Completions
|
||||||
|
/// API, typically on port 8080. No API key is required by default.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// // Default local llama.cpp
|
||||||
|
/// let config = ProviderConfiguration.llamaCpp(model: "default")
|
||||||
|
///
|
||||||
|
/// // llama.cpp on a custom port with API key
|
||||||
|
/// let config = ProviderConfiguration.llamaCpp(
|
||||||
|
/// model: "my-model",
|
||||||
|
/// baseURL: URL(string: "http://localhost:9090")!,
|
||||||
|
/// apiKey: "my-secret-key"
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - model: The model identifier. Use "default" if llama-server loads a single model.
|
||||||
|
/// - baseURL: The llama.cpp server URL. Defaults to `http://localhost:8080`.
|
||||||
|
/// - apiKey: Optional API key if the server requires authentication.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func llamaCpp(
|
||||||
|
model: String = "default",
|
||||||
|
baseURL: URL = LocalProviderDiscovery.defaultLlamaCppURL,
|
||||||
|
apiKey: String = ""
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
ProviderConfiguration(
|
||||||
|
provider: .llamaCpp,
|
||||||
|
model: model,
|
||||||
|
apiKeyProvider: { apiKey },
|
||||||
|
baseURL: baseURL,
|
||||||
|
apiVersion: nil,
|
||||||
|
betas: nil,
|
||||||
|
openAIVariant: .chatCompletions
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Model Construction
|
||||||
|
|
||||||
|
/// Creates a configured `LanguageModel` instance for this provider.
|
||||||
|
///
|
||||||
|
/// This is the primary way to get a model from a configuration.
|
||||||
|
/// The returned model is ready to use with ``ChatEngine`` or
|
||||||
|
/// ``TextSummaryRenderer``.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
|
||||||
|
/// let engine = ChatEngine(database: db, model: config.makeModel())
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Returns: A configured `LanguageModel` instance.
|
||||||
|
public func makeModel() -> any LanguageModel {
|
||||||
|
let key = apiKeyProvider
|
||||||
|
|
||||||
|
switch provider {
|
||||||
|
case .openAI:
|
||||||
|
if let baseURL {
|
||||||
|
return OpenAILanguageModel(
|
||||||
|
baseURL: baseURL,
|
||||||
|
apiKey: key(),
|
||||||
|
model: model,
|
||||||
|
apiVariant: openAIVariant ?? .chatCompletions
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return OpenAILanguageModel(
|
||||||
|
apiKey: key(),
|
||||||
|
model: model,
|
||||||
|
apiVariant: openAIVariant ?? .chatCompletions
|
||||||
|
)
|
||||||
|
|
||||||
|
case .anthropic:
|
||||||
|
if let apiVersion {
|
||||||
|
return AnthropicLanguageModel(
|
||||||
|
apiKey: key(),
|
||||||
|
apiVersion: apiVersion,
|
||||||
|
betas: betas,
|
||||||
|
model: model
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if let betas {
|
||||||
|
return AnthropicLanguageModel(
|
||||||
|
apiKey: key(),
|
||||||
|
betas: betas,
|
||||||
|
model: model
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return AnthropicLanguageModel(
|
||||||
|
apiKey: key(),
|
||||||
|
model: model
|
||||||
|
)
|
||||||
|
|
||||||
|
case .gemini:
|
||||||
|
if let apiVersion {
|
||||||
|
return GeminiLanguageModel(
|
||||||
|
apiKey: key(),
|
||||||
|
apiVersion: apiVersion,
|
||||||
|
model: model
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return GeminiLanguageModel(
|
||||||
|
apiKey: key(),
|
||||||
|
model: model
|
||||||
|
)
|
||||||
|
|
||||||
|
case .openAICompatible:
|
||||||
|
return OpenAILanguageModel(
|
||||||
|
baseURL: baseURL ?? OpenAILanguageModel.defaultBaseURL,
|
||||||
|
apiKey: key(),
|
||||||
|
model: model,
|
||||||
|
apiVariant: openAIVariant ?? .chatCompletions
|
||||||
|
)
|
||||||
|
|
||||||
|
case .ollama:
|
||||||
|
return OllamaLanguageModel(
|
||||||
|
baseURL: baseURL ?? OllamaLanguageModel.defaultBaseURL,
|
||||||
|
model: model
|
||||||
|
)
|
||||||
|
|
||||||
|
case .llamaCpp:
|
||||||
|
// llama.cpp exposes an OpenAI-compatible API
|
||||||
|
return OpenAILanguageModel(
|
||||||
|
baseURL: baseURL ?? LocalProviderDiscovery.defaultLlamaCppURL,
|
||||||
|
apiKey: key(),
|
||||||
|
model: model,
|
||||||
|
apiVariant: openAIVariant ?? .chatCompletions
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - API Key Access
|
||||||
|
|
||||||
|
/// Returns the current API key.
|
||||||
|
///
|
||||||
|
/// Useful for validation or debugging. In production, prefer using
|
||||||
|
/// ``makeModel()`` which handles key injection automatically.
|
||||||
|
public var apiKey: String {
|
||||||
|
apiKeyProvider()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns `true` if the API key is non-empty.
|
||||||
|
///
|
||||||
|
/// Use this to check configuration validity before creating an engine:
|
||||||
|
/// ```swift
|
||||||
|
/// guard config.hasValidAPIKey else {
|
||||||
|
/// // Show API key setup UI
|
||||||
|
/// return
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
public var hasValidAPIKey: Bool {
|
||||||
|
!apiKeyProvider().trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Environment Variable Helpers
|
||||||
|
|
||||||
|
/// Creates a configuration using an API key from an environment variable.
|
||||||
|
///
|
||||||
|
/// Falls back to an empty string if the environment variable is not set,
|
||||||
|
/// which will cause API calls to fail with an authentication error.
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let config = ProviderConfiguration.fromEnvironment(
|
||||||
|
/// provider: .openAI,
|
||||||
|
/// environmentVariable: "OPENAI_API_KEY",
|
||||||
|
/// model: "gpt-4o"
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - provider: The LLM provider.
|
||||||
|
/// - environmentVariable: The name of the environment variable holding the API key.
|
||||||
|
/// - model: The model identifier.
|
||||||
|
/// - Returns: A configured `ProviderConfiguration`.
|
||||||
|
public static func fromEnvironment(
|
||||||
|
provider: Provider,
|
||||||
|
environmentVariable: String,
|
||||||
|
model: String
|
||||||
|
) -> ProviderConfiguration {
|
||||||
|
let keyProvider: @Sendable () -> String = {
|
||||||
|
ProcessInfo.processInfo.environment[environmentVariable] ?? ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch provider {
|
||||||
|
case .openAI:
|
||||||
|
return .openAI(apiKeyProvider: keyProvider, model: model)
|
||||||
|
case .anthropic:
|
||||||
|
return .anthropic(apiKeyProvider: keyProvider, model: model)
|
||||||
|
case .gemini:
|
||||||
|
return .gemini(apiKeyProvider: keyProvider, model: model)
|
||||||
|
case .openAICompatible:
|
||||||
|
return .openAICompatible(
|
||||||
|
apiKeyProvider: keyProvider,
|
||||||
|
model: model,
|
||||||
|
baseURL: OpenAILanguageModel.defaultBaseURL
|
||||||
|
)
|
||||||
|
case .ollama:
|
||||||
|
return .ollama(model: model)
|
||||||
|
case .llamaCpp:
|
||||||
|
return .llamaCpp(model: model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - ChatEngine Convenience Init
|
||||||
|
|
||||||
|
extension ChatEngine {
|
||||||
|
|
||||||
|
/// Creates a ChatEngine using a ``ProviderConfiguration``.
|
||||||
|
///
|
||||||
|
/// This is the most convenient way to set up a ChatEngine with a
|
||||||
|
/// cloud provider:
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// let engine = ChatEngine(
|
||||||
|
/// database: myDB,
|
||||||
|
/// provider: .openAI(apiKey: "sk-...", model: "gpt-4o")
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||||
|
/// - provider: The provider configuration.
|
||||||
|
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only.
|
||||||
|
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
|
||||||
|
public convenience init(
|
||||||
|
database: any DatabaseWriter,
|
||||||
|
provider: ProviderConfiguration,
|
||||||
|
allowlist: OperationAllowlist = .readOnly,
|
||||||
|
configuration: ChatEngineConfiguration = .default
|
||||||
|
) {
|
||||||
|
self.init(
|
||||||
|
database: database,
|
||||||
|
model: provider.makeModel(),
|
||||||
|
allowlist: allowlist,
|
||||||
|
configuration: configuration
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
114
Sources/SwiftDBAI/Config/QueryValidator.swift
Normal file
114
Sources/SwiftDBAI/Config/QueryValidator.swift
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
// QueryValidator.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Extensible query validation protocol for custom pre-execution checks.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// A protocol for custom SQL query validation.
|
||||||
|
///
|
||||||
|
/// Implement this protocol to add domain-specific validation rules that run
|
||||||
|
/// after the built-in allowlist and safety checks. Validators receive the
|
||||||
|
/// parsed SQL string and its detected operation type.
|
||||||
|
///
|
||||||
|
/// Example — restrict queries to specific tables:
|
||||||
|
/// ```swift
|
||||||
|
/// struct TableAllowlistValidator: QueryValidator {
|
||||||
|
/// let allowedTables: Set<String>
|
||||||
|
///
|
||||||
|
/// func validate(sql: String, operation: SQLOperation) throws {
|
||||||
|
/// let upper = sql.uppercased()
|
||||||
|
/// for table in allowedTables {
|
||||||
|
/// // Simple check — real implementation might parse FROM/JOIN clauses
|
||||||
|
/// if upper.contains(table.uppercased()) { return }
|
||||||
|
/// }
|
||||||
|
/// throw QueryValidationError.rejected("Query references tables outside the allowlist.")
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
public protocol QueryValidator: Sendable {
|
||||||
|
/// Validates a SQL query before execution.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - sql: The cleaned SQL statement about to be executed.
|
||||||
|
/// - operation: The detected operation type (SELECT, INSERT, etc.).
|
||||||
|
/// - Throws: ``QueryValidationError`` or any `Error` to reject the query.
|
||||||
|
func validate(sql: String, operation: SQLOperation) throws
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Errors thrown by custom ``QueryValidator`` implementations.
|
||||||
|
public enum QueryValidationError: Error, LocalizedError, Sendable, Equatable {
|
||||||
|
/// The query was rejected by a custom validator with the given reason.
|
||||||
|
case rejected(String)
|
||||||
|
|
||||||
|
public var errorDescription: String? {
|
||||||
|
switch self {
|
||||||
|
case .rejected(let reason):
|
||||||
|
return "Query rejected: \(reason)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Built-in Validators
|
||||||
|
|
||||||
|
/// A validator that restricts queries to a specific set of table names.
|
||||||
|
///
|
||||||
|
/// This performs a simple keyword check — it verifies that the SQL references
|
||||||
|
/// at least one of the allowed tables. This is a best-effort check, not a
|
||||||
|
/// full SQL parser.
|
||||||
|
public struct TableAllowlistValidator: QueryValidator {
|
||||||
|
/// The set of table names queries are allowed to reference.
|
||||||
|
public let allowedTables: Set<String>
|
||||||
|
|
||||||
|
/// Creates a validator with the given allowed table names.
|
||||||
|
public init(allowedTables: Set<String>) {
|
||||||
|
self.allowedTables = allowedTables
|
||||||
|
}
|
||||||
|
|
||||||
|
public func validate(sql: String, operation: SQLOperation) throws {
|
||||||
|
let upper = sql.uppercased()
|
||||||
|
let found = allowedTables.contains { table in
|
||||||
|
let pattern = table.uppercased()
|
||||||
|
return upper.contains(pattern)
|
||||||
|
}
|
||||||
|
guard found else {
|
||||||
|
throw QueryValidationError.rejected(
|
||||||
|
"Query does not reference any allowed tables: \(allowedTables.sorted().joined(separator: ", "))"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A validator that enforces a maximum row limit on SELECT queries
|
||||||
|
/// by checking for a LIMIT clause.
|
||||||
|
public struct MaxRowLimitValidator: QueryValidator {
|
||||||
|
/// The maximum number of rows allowed.
|
||||||
|
public let maxRows: Int
|
||||||
|
|
||||||
|
/// Creates a validator that requires SELECT queries to include a LIMIT clause
|
||||||
|
/// not exceeding `maxRows`.
|
||||||
|
public init(maxRows: Int) {
|
||||||
|
self.maxRows = maxRows
|
||||||
|
}
|
||||||
|
|
||||||
|
public func validate(sql: String, operation: SQLOperation) throws {
|
||||||
|
guard operation == .select else { return }
|
||||||
|
|
||||||
|
let upper = sql.uppercased()
|
||||||
|
// Check if LIMIT is present
|
||||||
|
guard let limitRange = upper.range(of: #"LIMIT\s+(\d+)"#, options: .regularExpression) else {
|
||||||
|
throw QueryValidationError.rejected(
|
||||||
|
"SELECT queries must include a LIMIT clause (max \(maxRows) rows)."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the limit value
|
||||||
|
let limitSubstring = upper[limitRange]
|
||||||
|
let digits = limitSubstring.components(separatedBy: .decimalDigits.inverted).joined()
|
||||||
|
if let value = Int(digits), value > maxRows {
|
||||||
|
throw QueryValidationError.rejected(
|
||||||
|
"LIMIT \(value) exceeds the maximum allowed (\(maxRows))."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
677
Sources/SwiftDBAI/Engine/ChatEngine.swift
Normal file
677
Sources/SwiftDBAI/Engine/ChatEngine.swift
Normal file
@@ -0,0 +1,677 @@
|
|||||||
|
// ChatEngine.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Orchestrates the conversation loop: user message → SQL generation → query
|
||||||
|
// execution → result summarization → response.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
import GRDB
|
||||||
|
|
||||||
|
/// A message in the chat conversation.
|
||||||
|
public struct ChatMessage: Sendable, Identifiable, Equatable {
|
||||||
|
public let id: UUID
|
||||||
|
public let role: Role
|
||||||
|
public let content: String
|
||||||
|
public let queryResult: QueryResult?
|
||||||
|
public let sql: String?
|
||||||
|
public let timestamp: Date
|
||||||
|
/// The typed error, if this is an error message.
|
||||||
|
public let error: SwiftDBAIError?
|
||||||
|
|
||||||
|
public enum Role: String, Sendable, Equatable {
|
||||||
|
case user
|
||||||
|
case assistant
|
||||||
|
case error
|
||||||
|
}
|
||||||
|
|
||||||
|
public init(
|
||||||
|
id: UUID = UUID(),
|
||||||
|
role: Role,
|
||||||
|
content: String,
|
||||||
|
queryResult: QueryResult? = nil,
|
||||||
|
sql: String? = nil,
|
||||||
|
timestamp: Date = Date(),
|
||||||
|
error: SwiftDBAIError? = nil
|
||||||
|
) {
|
||||||
|
self.id = id
|
||||||
|
self.role = role
|
||||||
|
self.content = content
|
||||||
|
self.queryResult = queryResult
|
||||||
|
self.sql = sql
|
||||||
|
self.timestamp = timestamp
|
||||||
|
self.error = error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The response returned by `ChatEngine.send(_:)`.
|
||||||
|
public struct ChatResponse: Sendable {
|
||||||
|
/// The natural language summary of the result.
|
||||||
|
public let summary: String
|
||||||
|
|
||||||
|
/// The SQL that was generated and executed, if any.
|
||||||
|
public let sql: String?
|
||||||
|
|
||||||
|
/// The raw query result, if a query was executed.
|
||||||
|
public let queryResult: QueryResult?
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Headless engine that orchestrates the full chat-with-database pipeline.
|
||||||
|
///
|
||||||
|
/// The engine:
|
||||||
|
/// 1. Introspects the database schema (once, lazily)
|
||||||
|
/// 2. Builds a system prompt with schema context
|
||||||
|
/// 3. Sends the user's question to the LLM to generate SQL
|
||||||
|
/// 4. Validates the SQL against the operation allowlist
|
||||||
|
/// 5. Executes the SQL via GRDB
|
||||||
|
/// 6. Summarizes results using `TextSummaryRenderer`
|
||||||
|
/// 7. Returns the summary (and raw data) to the caller
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// let engine = ChatEngine(
|
||||||
|
/// database: myDatabasePool,
|
||||||
|
/// model: myLanguageModel
|
||||||
|
/// )
|
||||||
|
/// let response = try await engine.send("How many users signed up this week?")
|
||||||
|
/// print(response.summary) // "There were 42 new signups this week."
|
||||||
|
/// ```
|
||||||
|
public final class ChatEngine: @unchecked Sendable {
|
||||||
|
|
||||||
|
// MARK: - Dependencies
|
||||||
|
|
||||||
|
private let database: any DatabaseWriter
|
||||||
|
private let model: any LanguageModel
|
||||||
|
private let allowlist: OperationAllowlist
|
||||||
|
private let mutationPolicy: MutationPolicy?
|
||||||
|
private let configuration: ChatEngineConfiguration
|
||||||
|
private let summaryRenderer: TextSummaryRenderer
|
||||||
|
private let sqlParser: SQLQueryParser
|
||||||
|
|
||||||
|
/// Optional delegate for intercepting destructive operations and observing SQL execution.
|
||||||
|
private let delegate: (any ToolExecutionDelegate)?
|
||||||
|
|
||||||
|
// MARK: - State
|
||||||
|
|
||||||
|
private var schema: DatabaseSchema?
|
||||||
|
private var conversationHistory: [ChatMessage] = []
|
||||||
|
private let lock = NSLock()
|
||||||
|
|
||||||
|
// MARK: - Initialization
|
||||||
|
|
||||||
|
/// Creates a new ChatEngine with a full configuration object.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||||
|
/// - model: Any `AnyLanguageModel`-compatible language model.
|
||||||
|
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only (SELECT only).
|
||||||
|
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
|
||||||
|
/// - delegate: Optional delegate for confirming destructive operations and observing SQL execution.
|
||||||
|
public init(
|
||||||
|
database: any DatabaseWriter,
|
||||||
|
model: any LanguageModel,
|
||||||
|
allowlist: OperationAllowlist = .readOnly,
|
||||||
|
configuration: ChatEngineConfiguration = .default,
|
||||||
|
delegate: (any ToolExecutionDelegate)? = nil
|
||||||
|
) {
|
||||||
|
self.database = database
|
||||||
|
self.model = model
|
||||||
|
self.allowlist = allowlist
|
||||||
|
self.mutationPolicy = nil
|
||||||
|
self.configuration = configuration
|
||||||
|
self.delegate = delegate
|
||||||
|
self.summaryRenderer = TextSummaryRenderer(
|
||||||
|
model: model,
|
||||||
|
maxRowsInPrompt: configuration.maxSummaryRows
|
||||||
|
)
|
||||||
|
self.sqlParser = SQLQueryParser(allowlist: allowlist)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new ChatEngine with a `MutationPolicy` for table-level control.
|
||||||
|
///
|
||||||
|
/// This initializer provides fine-grained control over which mutations are
|
||||||
|
/// allowed on which tables. The policy's operation allowlist is used for
|
||||||
|
/// SQL validation, and table-level restrictions are enforced during parsing.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||||
|
/// - model: Any `AnyLanguageModel`-compatible language model.
|
||||||
|
/// - mutationPolicy: Controls which operations are allowed on which tables.
|
||||||
|
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
|
||||||
|
/// - delegate: Optional delegate for confirming destructive operations and observing SQL execution.
|
||||||
|
public init(
|
||||||
|
database: any DatabaseWriter,
|
||||||
|
model: any LanguageModel,
|
||||||
|
mutationPolicy: MutationPolicy,
|
||||||
|
configuration: ChatEngineConfiguration = .default,
|
||||||
|
delegate: (any ToolExecutionDelegate)? = nil
|
||||||
|
) {
|
||||||
|
self.database = database
|
||||||
|
self.model = model
|
||||||
|
self.allowlist = mutationPolicy.operationAllowlist
|
||||||
|
self.mutationPolicy = mutationPolicy
|
||||||
|
self.configuration = configuration
|
||||||
|
self.delegate = delegate
|
||||||
|
self.summaryRenderer = TextSummaryRenderer(
|
||||||
|
model: model,
|
||||||
|
maxRowsInPrompt: configuration.maxSummaryRows
|
||||||
|
)
|
||||||
|
self.sqlParser = SQLQueryParser(mutationPolicy: mutationPolicy)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new ChatEngine with individual parameters (convenience).
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||||
|
/// - model: Any `AnyLanguageModel`-compatible language model.
|
||||||
|
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only (SELECT only).
|
||||||
|
/// - additionalContext: Optional extra instructions for the LLM system prompt.
|
||||||
|
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
|
||||||
|
public convenience init(
|
||||||
|
database: any DatabaseWriter,
|
||||||
|
model: any LanguageModel,
|
||||||
|
allowlist: OperationAllowlist,
|
||||||
|
additionalContext: String?,
|
||||||
|
maxSummaryRows: Int = 50
|
||||||
|
) {
|
||||||
|
let config = ChatEngineConfiguration(
|
||||||
|
maxSummaryRows: maxSummaryRows,
|
||||||
|
additionalContext: additionalContext
|
||||||
|
)
|
||||||
|
self.init(
|
||||||
|
database: database,
|
||||||
|
model: model,
|
||||||
|
allowlist: allowlist,
|
||||||
|
configuration: config
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Public API
|
||||||
|
|
||||||
|
/// Sends a natural language message and returns a summarized response.
|
||||||
|
///
|
||||||
|
/// This is the primary entry point. The engine will:
|
||||||
|
/// 1. Introspect the schema if not yet cached
|
||||||
|
/// 2. Ask the LLM to generate SQL
|
||||||
|
/// 3. Validate the SQL against the allowlist and custom validators
|
||||||
|
/// 4. Execute the SQL (with timeout if configured)
|
||||||
|
/// 5. Summarize the results using `TextSummaryRenderer`
|
||||||
|
///
|
||||||
|
/// All errors are caught and mapped to a distinct ``SwiftDBAIError`` case
|
||||||
|
/// so callers always receive a typed, user-friendly error with a localized
|
||||||
|
/// description suitable for display in a chat UI.
|
||||||
|
///
|
||||||
|
/// - Parameter message: The user's natural language question or command.
|
||||||
|
/// - Returns: A `ChatResponse` containing the summary, SQL, and raw result.
|
||||||
|
/// - Throws: ``SwiftDBAIError`` for every failure mode.
|
||||||
|
public func send(_ message: String) async throws -> ChatResponse {
|
||||||
|
// 1. Ensure schema is introspected
|
||||||
|
let schema: DatabaseSchema
|
||||||
|
do {
|
||||||
|
schema = try await ensureSchema()
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
throw error
|
||||||
|
} catch {
|
||||||
|
throw SwiftDBAIError.schemaIntrospectionFailed(reason: error.localizedDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for empty schema
|
||||||
|
if schema.tableNames.isEmpty {
|
||||||
|
throw SwiftDBAIError.emptySchema
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Build prompt and get raw LLM response
|
||||||
|
let promptBuilder = PromptBuilder(
|
||||||
|
schema: schema,
|
||||||
|
allowlist: allowlist,
|
||||||
|
additionalContext: configuration.additionalContext
|
||||||
|
)
|
||||||
|
|
||||||
|
let rawLLMResponse: String
|
||||||
|
do {
|
||||||
|
rawLLMResponse = try await generateRawResponse(
|
||||||
|
question: message,
|
||||||
|
promptBuilder: promptBuilder
|
||||||
|
)
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
throw error
|
||||||
|
} catch {
|
||||||
|
throw SwiftDBAIError.llmFailure(reason: error.localizedDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Parse and validate SQL through SQLQueryParser
|
||||||
|
let parsed: ParsedSQL
|
||||||
|
do {
|
||||||
|
parsed = try sqlParser.parse(rawLLMResponse)
|
||||||
|
} catch let error as SQLParsingError {
|
||||||
|
throw error.toSwiftDBAIError(rawResponse: rawLLMResponse)
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
throw error
|
||||||
|
} catch {
|
||||||
|
throw SwiftDBAIError.invalidSQL(sql: rawLLMResponse, reason: error.localizedDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Run custom validators
|
||||||
|
do {
|
||||||
|
try runCustomValidators(parsed: parsed)
|
||||||
|
} catch let error as QueryValidationError {
|
||||||
|
throw error
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
throw error
|
||||||
|
} catch {
|
||||||
|
throw SwiftDBAIError.queryRejected(reason: error.localizedDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Handle confirmation-required operations (DELETE, DROP, etc.)
|
||||||
|
if parsed.requiresConfirmation {
|
||||||
|
if let delegate = self.delegate {
|
||||||
|
// Build context for the delegate
|
||||||
|
let classification = classifySQL(parsed.sql)
|
||||||
|
let context = DestructiveOperationContext(
|
||||||
|
sql: parsed.sql,
|
||||||
|
statementKind: detectStatementKind(parsed.sql) ?? .delete,
|
||||||
|
classification: classification,
|
||||||
|
description: "Execute \(parsed.operation.rawValue.uppercased()) operation: \(parsed.sql)",
|
||||||
|
targetTable: extractTargetTableForDelegate(from: parsed.sql, operation: parsed.operation)
|
||||||
|
)
|
||||||
|
// Ask the delegate for approval
|
||||||
|
let approved = await delegate.confirmDestructiveOperation(context)
|
||||||
|
if !approved {
|
||||||
|
throw SwiftDBAIError.confirmationRequired(
|
||||||
|
sql: parsed.sql,
|
||||||
|
operation: parsed.operation.rawValue
|
||||||
|
)
|
||||||
|
}
|
||||||
|
// Delegate approved — fall through to execution
|
||||||
|
} else {
|
||||||
|
// No delegate — throw confirmation required so caller can handle it
|
||||||
|
throw SwiftDBAIError.confirmationRequired(
|
||||||
|
sql: parsed.sql,
|
||||||
|
operation: parsed.operation.rawValue
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Execute the SQL (with timeout if configured)
|
||||||
|
let result: QueryResult
|
||||||
|
do {
|
||||||
|
let classification = classifySQL(parsed.sql)
|
||||||
|
await delegate?.willExecuteSQL(parsed.sql, classification: classification)
|
||||||
|
result = try await executeSQLWithTimeout(parsed.sql)
|
||||||
|
await delegate?.didExecuteSQL(parsed.sql, success: true)
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
await delegate?.didExecuteSQL(parsed.sql, success: false)
|
||||||
|
throw error
|
||||||
|
} catch let error as ChatEngineError {
|
||||||
|
await delegate?.didExecuteSQL(parsed.sql, success: false)
|
||||||
|
// Map internal ChatEngineError (e.g. from timeout) to SwiftDBAIError
|
||||||
|
throw error.toSwiftDBAIError()
|
||||||
|
} catch {
|
||||||
|
await delegate?.didExecuteSQL(parsed.sql, success: false)
|
||||||
|
throw SwiftDBAIError.databaseError(reason: error.localizedDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7. Summarize the result using TextSummaryRenderer
|
||||||
|
let summary: String
|
||||||
|
do {
|
||||||
|
summary = try await summaryRenderer.summarize(
|
||||||
|
result: result,
|
||||||
|
userQuestion: message
|
||||||
|
)
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
throw error
|
||||||
|
} catch {
|
||||||
|
throw SwiftDBAIError.llmFailure(reason: "Summarization failed: \(error.localizedDescription)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 8. Record conversation history
|
||||||
|
let userMessage = ChatMessage(role: .user, content: message)
|
||||||
|
let assistantMessage = ChatMessage(
|
||||||
|
role: .assistant,
|
||||||
|
content: summary,
|
||||||
|
queryResult: result,
|
||||||
|
sql: parsed.sql
|
||||||
|
)
|
||||||
|
lock.withLock {
|
||||||
|
conversationHistory.append(userMessage)
|
||||||
|
conversationHistory.append(assistantMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
summary: summary,
|
||||||
|
sql: parsed.sql,
|
||||||
|
queryResult: result
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sends a natural language message, executing a previously confirmed destructive operation.
|
||||||
|
///
|
||||||
|
/// Call this after receiving a `confirmationRequired` error and the user has confirmed.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - message: The original user message (for history recording).
|
||||||
|
/// - confirmedSQL: The SQL that was confirmed by the user.
|
||||||
|
/// - Returns: A `ChatResponse` with the result.
|
||||||
|
public func sendConfirmed(_ message: String, confirmedSQL: String) async throws -> ChatResponse {
|
||||||
|
let result: QueryResult
|
||||||
|
do {
|
||||||
|
let classification = classifySQL(confirmedSQL)
|
||||||
|
await delegate?.willExecuteSQL(confirmedSQL, classification: classification)
|
||||||
|
result = try await executeSQLWithTimeout(confirmedSQL)
|
||||||
|
await delegate?.didExecuteSQL(confirmedSQL, success: true)
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
await delegate?.didExecuteSQL(confirmedSQL, success: false)
|
||||||
|
throw error
|
||||||
|
} catch let error as ChatEngineError {
|
||||||
|
await delegate?.didExecuteSQL(confirmedSQL, success: false)
|
||||||
|
throw error.toSwiftDBAIError()
|
||||||
|
} catch {
|
||||||
|
await delegate?.didExecuteSQL(confirmedSQL, success: false)
|
||||||
|
throw SwiftDBAIError.databaseError(reason: error.localizedDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
let summary: String
|
||||||
|
do {
|
||||||
|
summary = try await summaryRenderer.summarize(
|
||||||
|
result: result,
|
||||||
|
userQuestion: message
|
||||||
|
)
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
throw error
|
||||||
|
} catch {
|
||||||
|
throw SwiftDBAIError.llmFailure(reason: "Summarization failed: \(error.localizedDescription)")
|
||||||
|
}
|
||||||
|
|
||||||
|
let userMessage = ChatMessage(role: .user, content: message)
|
||||||
|
let assistantMessage = ChatMessage(
|
||||||
|
role: .assistant,
|
||||||
|
content: summary,
|
||||||
|
queryResult: result,
|
||||||
|
sql: confirmedSQL
|
||||||
|
)
|
||||||
|
lock.withLock {
|
||||||
|
conversationHistory.append(userMessage)
|
||||||
|
conversationHistory.append(assistantMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ChatResponse(
|
||||||
|
summary: summary,
|
||||||
|
sql: confirmedSQL,
|
||||||
|
queryResult: result
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the current conversation history.
|
||||||
|
public var messages: [ChatMessage] {
|
||||||
|
lock.lock()
|
||||||
|
defer { lock.unlock() }
|
||||||
|
return conversationHistory
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Eagerly introspects the database schema so it's ready before the first query.
|
||||||
|
///
|
||||||
|
/// Call this at view-appear time to pre-warm the schema cache. If the schema
|
||||||
|
/// is already cached, this returns immediately. The returned `DatabaseSchema`
|
||||||
|
/// can be used to display table/column info in the UI.
|
||||||
|
///
|
||||||
|
/// - Returns: The introspected `DatabaseSchema`.
|
||||||
|
@discardableResult
|
||||||
|
public func prepareSchema() async throws -> DatabaseSchema {
|
||||||
|
try await ensureSchema()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of tables discovered during schema introspection.
|
||||||
|
/// Returns `nil` if the schema has not been introspected yet.
|
||||||
|
public var tableCount: Int? {
|
||||||
|
lock.withLock { schema?.tableNames.count }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The cached schema, if introspection has completed.
|
||||||
|
public var cachedSchema: DatabaseSchema? {
|
||||||
|
lock.withLock { schema }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clears the conversation history and cached schema.
|
||||||
|
///
|
||||||
|
/// After calling this, the next `send(_:)` call will re-introspect the
|
||||||
|
/// schema. Use ``clearHistory()`` if you only want to reset the conversation
|
||||||
|
/// while keeping the cached schema.
|
||||||
|
public func reset() {
|
||||||
|
lock.withLock {
|
||||||
|
conversationHistory.removeAll()
|
||||||
|
schema = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clears only the conversation history, keeping the cached schema.
|
||||||
|
///
|
||||||
|
/// This is useful when you want to start a fresh conversation thread
|
||||||
|
/// without re-introspecting the database. The schema cache remains valid
|
||||||
|
/// as long as the database structure hasn't changed.
|
||||||
|
public func clearHistory() {
|
||||||
|
lock.withLock {
|
||||||
|
conversationHistory.removeAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The current engine configuration.
|
||||||
|
public var currentConfiguration: ChatEngineConfiguration {
|
||||||
|
configuration
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Internal Helpers (visible for testing)
|
||||||
|
|
||||||
|
/// Ensures the database schema is introspected and cached.
|
||||||
|
func ensureSchema() async throws -> DatabaseSchema {
|
||||||
|
if let cached = lock.withLock({ schema }) {
|
||||||
|
return cached
|
||||||
|
}
|
||||||
|
|
||||||
|
let introspected = try await SchemaIntrospector.introspect(database: database)
|
||||||
|
|
||||||
|
lock.withLock { schema = introspected }
|
||||||
|
|
||||||
|
return introspected
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Asks the LLM to generate SQL from a natural language question.
|
||||||
|
/// Returns the raw LLM response text (before parsing).
|
||||||
|
///
|
||||||
|
/// Uses the configured ``ChatEngineConfiguration/contextWindowSize`` to limit
|
||||||
|
/// how many conversation messages are included as context for the LLM.
|
||||||
|
private func generateRawResponse(
|
||||||
|
question: String,
|
||||||
|
promptBuilder: PromptBuilder
|
||||||
|
) async throws -> String {
|
||||||
|
let instructions = promptBuilder.buildSystemInstructions()
|
||||||
|
|
||||||
|
// Build user prompt — include full conversation history for follow-ups
|
||||||
|
// Respect context window: only use recent messages for context
|
||||||
|
let userPrompt: String
|
||||||
|
let historySlice = lock.withLock { () -> [ChatMessage] in
|
||||||
|
Array(contextWindowSlice())
|
||||||
|
}
|
||||||
|
|
||||||
|
if historySlice.isEmpty {
|
||||||
|
userPrompt = promptBuilder.buildUserPrompt(question)
|
||||||
|
} else {
|
||||||
|
userPrompt = promptBuilder.buildConversationPrompt(
|
||||||
|
question,
|
||||||
|
history: historySlice
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
let session = LanguageModelSession(
|
||||||
|
model: model,
|
||||||
|
instructions: instructions + "\n\nRespond with ONLY the SQL query. No explanations, no markdown, no code fences."
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await session.respond(to: userPrompt)
|
||||||
|
return response.content.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the conversation history slice within the configured context window.
|
||||||
|
/// Must be called within a `lock.withLock` closure.
|
||||||
|
private func contextWindowSlice() -> ArraySlice<ChatMessage> {
|
||||||
|
guard let windowSize = configuration.contextWindowSize else {
|
||||||
|
return conversationHistory[...]
|
||||||
|
}
|
||||||
|
let count = conversationHistory.count
|
||||||
|
let start = max(0, count - windowSize)
|
||||||
|
return conversationHistory[start...]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Runs all custom validators from the configuration against the parsed SQL.
|
||||||
|
private func runCustomValidators(parsed: ParsedSQL) throws {
|
||||||
|
for validator in configuration.validators {
|
||||||
|
try validator.validate(sql: parsed.sql, operation: parsed.operation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts the target table name from a SQL statement for delegate context.
|
||||||
|
private func extractTargetTableForDelegate(from sql: String, operation: SQLOperation) -> String? {
|
||||||
|
let pattern: String
|
||||||
|
switch operation {
|
||||||
|
case .insert:
|
||||||
|
pattern = #"INSERT\s+INTO\s+[`"\[]?(\w+)[`"\]]?"#
|
||||||
|
case .update:
|
||||||
|
pattern = #"UPDATE\s+[`"\[]?(\w+)[`"\]]?"#
|
||||||
|
case .delete:
|
||||||
|
pattern = #"DELETE\s+FROM\s+[`"\[]?(\w+)[`"\]]?"#
|
||||||
|
case .select:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
guard let regex = try? NSRegularExpression(pattern: pattern, options: .caseInsensitive) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
let range = NSRange(sql.startIndex..., in: sql)
|
||||||
|
guard let match = regex.firstMatch(in: sql, range: range),
|
||||||
|
match.numberOfRanges > 1,
|
||||||
|
let groupRange = Range(match.range(at: 1), in: sql) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return String(sql[groupRange])
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Executes SQL with the configured timeout, if any.
|
||||||
|
private func executeSQLWithTimeout(_ sql: String) async throws -> QueryResult {
|
||||||
|
guard let timeout = configuration.queryTimeout else {
|
||||||
|
return try await executeSQL(sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
return try await withThrowingTaskGroup(of: QueryResult.self) { group in
|
||||||
|
group.addTask {
|
||||||
|
try await self.executeSQL(sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
group.addTask {
|
||||||
|
try await Task.sleep(for: .seconds(timeout))
|
||||||
|
throw ChatEngineError.queryTimedOut(seconds: timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return whichever finishes first
|
||||||
|
let result = try await group.next()!
|
||||||
|
group.cancelAll()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Executes SQL against the database and returns a `QueryResult`.
|
||||||
|
private func executeSQL(_ sql: String) async throws -> QueryResult {
|
||||||
|
let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
|
||||||
|
let isSelect = trimmed.hasPrefix("SELECT") || trimmed.hasPrefix("WITH")
|
||||||
|
|
||||||
|
let startTime = CFAbsoluteTimeGetCurrent()
|
||||||
|
|
||||||
|
if isSelect {
|
||||||
|
let result = try await database.read { db -> (columns: [String], rows: [[String: QueryResult.Value]]) in
|
||||||
|
let statement = try db.makeStatement(sql: sql)
|
||||||
|
let columnNames = statement.columnNames
|
||||||
|
|
||||||
|
var rows: [[String: QueryResult.Value]] = []
|
||||||
|
let cursor = try Row.fetchCursor(statement)
|
||||||
|
while let row = try cursor.next() {
|
||||||
|
var dict: [String: QueryResult.Value] = [:]
|
||||||
|
for col in columnNames {
|
||||||
|
dict[col] = Self.extractValue(row: row, column: col)
|
||||||
|
}
|
||||||
|
rows.append(dict)
|
||||||
|
}
|
||||||
|
return (columns: columnNames, rows: rows)
|
||||||
|
}
|
||||||
|
|
||||||
|
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
|
||||||
|
|
||||||
|
return QueryResult(
|
||||||
|
columns: result.columns,
|
||||||
|
rows: result.rows,
|
||||||
|
sql: sql,
|
||||||
|
executionTime: elapsed
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// Mutation query
|
||||||
|
let affected = try await database.write { db -> Int in
|
||||||
|
try db.execute(sql: sql)
|
||||||
|
return db.changesCount
|
||||||
|
}
|
||||||
|
|
||||||
|
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
|
||||||
|
|
||||||
|
return QueryResult(
|
||||||
|
columns: [],
|
||||||
|
rows: [],
|
||||||
|
sql: sql,
|
||||||
|
executionTime: elapsed,
|
||||||
|
rowsAffected: affected
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts a `QueryResult.Value` from a GRDB `Row` for the given column.
|
||||||
|
private static func extractValue(row: Row, column: String) -> QueryResult.Value {
|
||||||
|
let dbValue: DatabaseValue = row[column]
|
||||||
|
switch dbValue.storage {
|
||||||
|
case .null:
|
||||||
|
return .null
|
||||||
|
case .int64(let i):
|
||||||
|
return .integer(i)
|
||||||
|
case .double(let d):
|
||||||
|
return .real(d)
|
||||||
|
case .string(let s):
|
||||||
|
return .text(s)
|
||||||
|
case .blob(let data):
|
||||||
|
return .blob(data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Errors
|
||||||
|
|
||||||
|
/// Errors that can occur during ChatEngine operations.
|
||||||
|
public enum ChatEngineError: Error, LocalizedError, Sendable {
|
||||||
|
/// SQL parsing/extraction from LLM response failed.
|
||||||
|
case sqlParsingFailed(SQLParsingError)
|
||||||
|
/// A destructive operation requires user confirmation before execution.
|
||||||
|
case confirmationRequired(sql: String, operation: SQLOperation)
|
||||||
|
/// Schema introspection failed.
|
||||||
|
case schemaIntrospectionFailed(String)
|
||||||
|
/// The SQL query exceeded the configured timeout.
|
||||||
|
case queryTimedOut(seconds: TimeInterval)
|
||||||
|
/// A custom query validator rejected the query.
|
||||||
|
case validationFailed(String)
|
||||||
|
|
||||||
|
public var errorDescription: String? {
|
||||||
|
switch self {
|
||||||
|
case .sqlParsingFailed(let parsingError):
|
||||||
|
return "SQL parsing failed: \(parsingError.description)"
|
||||||
|
case .confirmationRequired(let sql, let op):
|
||||||
|
return "The \(op.rawValue.uppercased()) operation requires confirmation: \(sql)"
|
||||||
|
case .schemaIntrospectionFailed(let reason):
|
||||||
|
return "Failed to introspect database schema: \(reason)"
|
||||||
|
case .queryTimedOut(let seconds):
|
||||||
|
return "Query timed out after \(Int(seconds)) seconds."
|
||||||
|
case .validationFailed(let reason):
|
||||||
|
return "Query validation failed: \(reason)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
288
Sources/SwiftDBAI/Engine/ToolExecutionDelegate.swift
Normal file
288
Sources/SwiftDBAI/Engine/ToolExecutionDelegate.swift
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
// ToolExecutionDelegate.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Delegate protocol for controlling SQL tool execution, including
|
||||||
|
// confirmation of destructive operations before they reach the database.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
// MARK: - Destructive SQL Classification
|
||||||
|
|
||||||
|
/// Classifies SQL statements by their destructive potential.
|
||||||
|
///
|
||||||
|
/// A statement is considered **destructive** if it modifies or removes data
|
||||||
|
/// or schema objects. The classification drives the confirmation flow:
|
||||||
|
/// destructive statements require explicit user approval via
|
||||||
|
/// ``ToolExecutionDelegate/confirmDestructiveOperation(_:)``.
|
||||||
|
public enum DestructiveClassification: Sendable, Equatable {
|
||||||
|
/// The statement is read-only (e.g. SELECT). No confirmation needed.
|
||||||
|
case safe
|
||||||
|
|
||||||
|
/// The statement modifies existing data (INSERT, UPDATE).
|
||||||
|
case mutation(SQLStatementKind)
|
||||||
|
|
||||||
|
/// The statement deletes data or alters/drops schema objects.
|
||||||
|
/// These always require confirmation, even when the operation is allowed.
|
||||||
|
case destructive(SQLStatementKind)
|
||||||
|
|
||||||
|
/// Returns `true` when the statement requires user confirmation.
|
||||||
|
public var requiresConfirmation: Bool {
|
||||||
|
switch self {
|
||||||
|
case .safe:
|
||||||
|
return false
|
||||||
|
case .mutation:
|
||||||
|
return false
|
||||||
|
case .destructive:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns `true` when the statement modifies data or schema in any way.
|
||||||
|
public var isMutating: Bool {
|
||||||
|
switch self {
|
||||||
|
case .safe:
|
||||||
|
return false
|
||||||
|
case .mutation, .destructive:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The kind of SQL statement, used for classification and display.
|
||||||
|
public enum SQLStatementKind: String, Sendable, Hashable, CaseIterable {
|
||||||
|
case select = "SELECT"
|
||||||
|
case insert = "INSERT"
|
||||||
|
case update = "UPDATE"
|
||||||
|
case delete = "DELETE"
|
||||||
|
case drop = "DROP"
|
||||||
|
case alter = "ALTER"
|
||||||
|
case truncate = "TRUNCATE"
|
||||||
|
|
||||||
|
/// All kinds that are classified as destructive.
|
||||||
|
public static let destructiveKinds: Set<SQLStatementKind> = [
|
||||||
|
.delete, .drop, .alter, .truncate
|
||||||
|
]
|
||||||
|
|
||||||
|
/// All kinds that are classified as mutations (data-modifying but not destructive).
|
||||||
|
public static let mutationKinds: Set<SQLStatementKind> = [
|
||||||
|
.insert, .update
|
||||||
|
]
|
||||||
|
|
||||||
|
/// Whether this kind of statement is destructive.
|
||||||
|
public var isDestructive: Bool {
|
||||||
|
Self.destructiveKinds.contains(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether this kind of statement is a mutation (INSERT/UPDATE).
|
||||||
|
public var isMutation: Bool {
|
||||||
|
Self.mutationKinds.contains(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Classification Function
|
||||||
|
|
||||||
|
/// Classifies a SQL statement string by its destructive potential.
|
||||||
|
///
|
||||||
|
/// The classifier inspects the first keyword token of the statement
|
||||||
|
/// (case-insensitive) to determine the statement kind, then maps it
|
||||||
|
/// to a ``DestructiveClassification``.
|
||||||
|
///
|
||||||
|
/// - Parameter sql: The SQL statement to classify.
|
||||||
|
/// - Returns: The classification for the statement.
|
||||||
|
public func classifySQL(_ sql: String) -> DestructiveClassification {
|
||||||
|
guard let kind = detectStatementKind(sql) else {
|
||||||
|
return .safe
|
||||||
|
}
|
||||||
|
|
||||||
|
if kind.isDestructive {
|
||||||
|
return .destructive(kind)
|
||||||
|
} else if kind.isMutation {
|
||||||
|
return .mutation(kind)
|
||||||
|
} else {
|
||||||
|
return .safe
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detects the ``SQLStatementKind`` from the leading keyword of a SQL string.
|
||||||
|
///
|
||||||
|
/// - Parameter sql: The SQL statement to inspect.
|
||||||
|
/// - Returns: The detected kind, or `nil` if unrecognized.
|
||||||
|
public func detectStatementKind(_ sql: String) -> SQLStatementKind? {
|
||||||
|
let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
|
||||||
|
|
||||||
|
// Check each known statement kind against the first token
|
||||||
|
if trimmed.hasPrefix("SELECT") || trimmed.hasPrefix("WITH") {
|
||||||
|
return .select
|
||||||
|
} else if trimmed.hasPrefix("INSERT") {
|
||||||
|
return .insert
|
||||||
|
} else if trimmed.hasPrefix("UPDATE") {
|
||||||
|
return .update
|
||||||
|
} else if trimmed.hasPrefix("DELETE") {
|
||||||
|
return .delete
|
||||||
|
} else if trimmed.hasPrefix("DROP") {
|
||||||
|
return .drop
|
||||||
|
} else if trimmed.hasPrefix("ALTER") {
|
||||||
|
return .alter
|
||||||
|
} else if trimmed.hasPrefix("TRUNCATE") {
|
||||||
|
return .truncate
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Destructive Operation Context
|
||||||
|
|
||||||
|
/// Context provided to the delegate when a destructive operation needs confirmation.
|
||||||
|
///
|
||||||
|
/// Contains all the information a UI or programmatic handler needs to
|
||||||
|
/// decide whether to allow the operation.
|
||||||
|
public struct DestructiveOperationContext: Sendable {
|
||||||
|
/// The SQL statement that would be executed.
|
||||||
|
public let sql: String
|
||||||
|
|
||||||
|
/// The detected kind of statement (DELETE, DROP, ALTER, TRUNCATE).
|
||||||
|
public let statementKind: SQLStatementKind
|
||||||
|
|
||||||
|
/// The classification result.
|
||||||
|
public let classification: DestructiveClassification
|
||||||
|
|
||||||
|
/// A human-readable description of what the operation will do.
|
||||||
|
public let description: String
|
||||||
|
|
||||||
|
/// The target table name, if detected.
|
||||||
|
public let targetTable: String?
|
||||||
|
|
||||||
|
public init(
|
||||||
|
sql: String,
|
||||||
|
statementKind: SQLStatementKind,
|
||||||
|
classification: DestructiveClassification,
|
||||||
|
description: String,
|
||||||
|
targetTable: String? = nil
|
||||||
|
) {
|
||||||
|
self.sql = sql
|
||||||
|
self.statementKind = statementKind
|
||||||
|
self.classification = classification
|
||||||
|
self.description = description
|
||||||
|
self.targetTable = targetTable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - ToolExecutionDelegate Protocol
|
||||||
|
|
||||||
|
/// A delegate that controls execution of SQL operations, providing
|
||||||
|
/// confirmation gates for destructive statements.
|
||||||
|
///
|
||||||
|
/// Implement this protocol to intercept destructive SQL operations
|
||||||
|
/// (DELETE, DROP, ALTER, TRUNCATE) before they are executed. The
|
||||||
|
/// ``ChatEngine`` consults the delegate whenever it encounters a
|
||||||
|
/// statement classified as ``DestructiveClassification/destructive(_:)``.
|
||||||
|
///
|
||||||
|
/// ## Example
|
||||||
|
///
|
||||||
|
/// ```swift
|
||||||
|
/// struct MyDelegate: ToolExecutionDelegate {
|
||||||
|
/// func confirmDestructiveOperation(
|
||||||
|
/// _ context: DestructiveOperationContext
|
||||||
|
/// ) async -> Bool {
|
||||||
|
/// // Show a confirmation dialog to the user
|
||||||
|
/// return await showAlert(
|
||||||
|
/// "Confirm \(context.statementKind.rawValue)",
|
||||||
|
/// message: context.description
|
||||||
|
/// )
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
///
|
||||||
|
/// let engine = ChatEngine(
|
||||||
|
/// database: pool,
|
||||||
|
/// model: model,
|
||||||
|
/// delegate: MyDelegate()
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
public protocol ToolExecutionDelegate: Sendable {
|
||||||
|
|
||||||
|
/// Called when a destructive SQL operation is about to be executed.
|
||||||
|
///
|
||||||
|
/// The delegate should present the operation details to the user and
|
||||||
|
/// return `true` to proceed or `false` to cancel.
|
||||||
|
///
|
||||||
|
/// - Parameter context: Details about the destructive operation.
|
||||||
|
/// - Returns: `true` to allow execution, `false` to reject it.
|
||||||
|
func confirmDestructiveOperation(
|
||||||
|
_ context: DestructiveOperationContext
|
||||||
|
) async -> Bool
|
||||||
|
|
||||||
|
/// Called before any SQL statement is executed.
|
||||||
|
///
|
||||||
|
/// This is an observation hook — the engine does not wait for a
|
||||||
|
/// decision. Override to log, audit, or instrument queries.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - sql: The SQL about to be executed.
|
||||||
|
/// - classification: The destructive classification of the statement.
|
||||||
|
func willExecuteSQL(
|
||||||
|
_ sql: String,
|
||||||
|
classification: DestructiveClassification
|
||||||
|
) async
|
||||||
|
|
||||||
|
/// Called after a SQL statement completes execution.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - sql: The SQL that was executed.
|
||||||
|
/// - success: Whether execution succeeded.
|
||||||
|
func didExecuteSQL(
|
||||||
|
_ sql: String,
|
||||||
|
success: Bool
|
||||||
|
) async
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Default Implementations
|
||||||
|
|
||||||
|
extension ToolExecutionDelegate {
|
||||||
|
/// Default: rejects all destructive operations.
|
||||||
|
public func confirmDestructiveOperation(
|
||||||
|
_ context: DestructiveOperationContext
|
||||||
|
) async -> Bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Default: no-op.
|
||||||
|
public func willExecuteSQL(
|
||||||
|
_ sql: String,
|
||||||
|
classification: DestructiveClassification
|
||||||
|
) async {}
|
||||||
|
|
||||||
|
/// Default: no-op.
|
||||||
|
public func didExecuteSQL(
|
||||||
|
_ sql: String,
|
||||||
|
success: Bool
|
||||||
|
) async {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Built-in Delegates
|
||||||
|
|
||||||
|
/// A delegate that automatically approves all destructive operations.
|
||||||
|
///
|
||||||
|
/// Use this only in testing or trusted environments where confirmation
|
||||||
|
/// is not needed.
|
||||||
|
public struct AutoApproveDelegate: ToolExecutionDelegate {
|
||||||
|
public init() {}
|
||||||
|
|
||||||
|
public func confirmDestructiveOperation(
|
||||||
|
_ context: DestructiveOperationContext
|
||||||
|
) async -> Bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A delegate that always rejects destructive operations.
|
||||||
|
///
|
||||||
|
/// This is the safest option and matches the default behavior.
|
||||||
|
public struct RejectAllDelegate: ToolExecutionDelegate {
|
||||||
|
public init() {}
|
||||||
|
|
||||||
|
public func confirmDestructiveOperation(
|
||||||
|
_ context: DestructiveOperationContext
|
||||||
|
) async -> Bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
143
Sources/SwiftDBAI/Models/ConversationHistory.swift
Normal file
143
Sources/SwiftDBAI/Models/ConversationHistory.swift
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
// ConversationHistory.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Ordered chat message history with configurable context window.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Stores an ordered sequence of ``ChatMessage`` instances with a configurable
|
||||||
|
/// context window limit.
|
||||||
|
///
|
||||||
|
/// When the number of messages exceeds ``maxMessages``, the oldest messages are
|
||||||
|
/// trimmed to keep the history within budget. This prevents unbounded token
|
||||||
|
/// growth when building LLM prompts from conversation history.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// var history = ConversationHistory(maxMessages: 20)
|
||||||
|
/// history.append(ChatMessage(role: .user, content: "How many users?"))
|
||||||
|
/// history.append(ChatMessage(role: .assistant, content: "42", sql: "SELECT COUNT(*) FROM users"))
|
||||||
|
/// print(history.promptText) // formatted for LLM context
|
||||||
|
/// ```
|
||||||
|
public struct ConversationHistory: Sendable {
|
||||||
|
|
||||||
|
/// The maximum number of messages to retain. `nil` means unlimited.
|
||||||
|
public let maxMessages: Int?
|
||||||
|
|
||||||
|
/// All messages in chronological order.
|
||||||
|
public private(set) var messages: [ChatMessage] = []
|
||||||
|
|
||||||
|
/// Creates a new conversation history.
|
||||||
|
///
|
||||||
|
/// - Parameter maxMessages: Maximum number of messages to keep in the
|
||||||
|
/// context window. Pass `nil` for unlimited history. Defaults to 50.
|
||||||
|
public init(maxMessages: Int? = 50) {
|
||||||
|
precondition(maxMessages == nil || maxMessages! > 0,
|
||||||
|
"maxMessages must be positive or nil")
|
||||||
|
self.maxMessages = maxMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The number of messages currently stored.
|
||||||
|
public var count: Int { messages.count }
|
||||||
|
|
||||||
|
/// Whether the history is empty.
|
||||||
|
public var isEmpty: Bool { messages.isEmpty }
|
||||||
|
|
||||||
|
// MARK: - Mutating Operations
|
||||||
|
|
||||||
|
/// Appends a message and trims the history if it exceeds the context window.
|
||||||
|
public mutating func append(_ message: ChatMessage) {
|
||||||
|
messages.append(message)
|
||||||
|
trimIfNeeded()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Appends multiple messages and trims once afterward.
|
||||||
|
public mutating func append(contentsOf newMessages: [ChatMessage]) {
|
||||||
|
messages.append(contentsOf: newMessages)
|
||||||
|
trimIfNeeded()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Removes all messages from the history.
|
||||||
|
public mutating func clear() {
|
||||||
|
messages.removeAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Context Window
|
||||||
|
|
||||||
|
/// Returns the most recent messages formatted for inclusion in an LLM prompt.
|
||||||
|
///
|
||||||
|
/// Each message is formatted as `[role] content`, with SQL and query results
|
||||||
|
/// included inline for assistant messages.
|
||||||
|
///
|
||||||
|
/// - Parameter limit: Optional override to further restrict the number of
|
||||||
|
/// messages returned. When `nil`, uses the full retained history.
|
||||||
|
/// - Returns: An array of prompt-formatted strings, one per message.
|
||||||
|
public func promptMessages(limit: Int? = nil) -> [String] {
|
||||||
|
let slice: ArraySlice<ChatMessage>
|
||||||
|
if let limit {
|
||||||
|
slice = messages.suffix(limit)
|
||||||
|
} else {
|
||||||
|
slice = messages[...]
|
||||||
|
}
|
||||||
|
return slice.map { message in
|
||||||
|
Self.formatForPrompt(message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the combined prompt text for all retained messages, separated by
|
||||||
|
/// double newlines.
|
||||||
|
public var promptText: String {
|
||||||
|
promptMessages().joined(separator: "\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Queries
|
||||||
|
|
||||||
|
/// Returns only user messages.
|
||||||
|
public var userMessages: [ChatMessage] {
|
||||||
|
messages.filter { $0.role == .user }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns only assistant messages.
|
||||||
|
public var assistantMessages: [ChatMessage] {
|
||||||
|
messages.filter { $0.role == .assistant }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the last message, if any.
|
||||||
|
public var lastMessage: ChatMessage? {
|
||||||
|
messages.last
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the most recent user query text, if any.
|
||||||
|
public var lastUserQuery: String? {
|
||||||
|
messages.last(where: { $0.role == .user })?.content
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the most recent assistant message, if any.
|
||||||
|
public var lastAssistantMessage: ChatMessage? {
|
||||||
|
messages.last(where: { $0.role == .assistant })
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Private
|
||||||
|
|
||||||
|
/// Formats a ``ChatMessage`` into a string suitable for LLM prompt context.
|
||||||
|
private static func formatForPrompt(_ message: ChatMessage) -> String {
|
||||||
|
var parts: [String] = ["[\(message.role.rawValue)] \(message.content)"]
|
||||||
|
|
||||||
|
if let sql = message.sql {
|
||||||
|
parts.append("SQL: \(sql)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if let result = message.queryResult {
|
||||||
|
parts.append("Result:\n\(result.tabularDescription)")
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts.joined(separator: "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trims the oldest messages to stay within the context window.
|
||||||
|
private mutating func trimIfNeeded() {
|
||||||
|
guard let max = maxMessages, messages.count > max else { return }
|
||||||
|
let overflow = messages.count - max
|
||||||
|
messages.removeFirst(overflow)
|
||||||
|
}
|
||||||
|
}
|
||||||
136
Sources/SwiftDBAI/Models/QueryResult.swift
Normal file
136
Sources/SwiftDBAI/Models/QueryResult.swift
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
// QueryResult.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Structured result from SQL query execution.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Represents the result of executing a SQL query against the database.
|
||||||
|
///
|
||||||
|
/// Contains raw row data as dictionaries, column metadata, row count,
|
||||||
|
/// the original SQL string, and execution timing.
|
||||||
|
public struct QueryResult: Sendable, Equatable {
|
||||||
|
|
||||||
|
/// A single cell value from a query result.
|
||||||
|
///
|
||||||
|
/// Wraps SQLite's dynamic value types into a type-safe, Sendable enum.
|
||||||
|
public enum Value: Sendable, Equatable, CustomStringConvertible {
|
||||||
|
case text(String)
|
||||||
|
case integer(Int64)
|
||||||
|
case real(Double)
|
||||||
|
case blob(Data)
|
||||||
|
case null
|
||||||
|
|
||||||
|
public var description: String {
|
||||||
|
switch self {
|
||||||
|
case .text(let s): return s
|
||||||
|
case .integer(let i): return String(i)
|
||||||
|
case .real(let d):
|
||||||
|
if d == d.rounded() && abs(d) < 1e15 {
|
||||||
|
return String(format: "%.0f", d)
|
||||||
|
}
|
||||||
|
return String(d)
|
||||||
|
case .blob(let data): return "<\(data.count) bytes>"
|
||||||
|
case .null: return "NULL"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the value as a `Double` if it is numeric, nil otherwise.
|
||||||
|
public var doubleValue: Double? {
|
||||||
|
switch self {
|
||||||
|
case .integer(let i): return Double(i)
|
||||||
|
case .real(let d): return d
|
||||||
|
case .text(let s): return Double(s)
|
||||||
|
default: return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the value as a `String` (non-nil for all cases).
|
||||||
|
public var stringValue: String { description }
|
||||||
|
|
||||||
|
/// Returns `true` if this value is `.null`.
|
||||||
|
public var isNull: Bool {
|
||||||
|
if case .null = self { return true }
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Column names in the order they appear in the result set.
|
||||||
|
public let columns: [String]
|
||||||
|
|
||||||
|
/// Row data as an array of dictionaries mapping column name to value.
|
||||||
|
public let rows: [[String: Value]]
|
||||||
|
|
||||||
|
/// Total number of rows returned.
|
||||||
|
public var rowCount: Int { rows.count }
|
||||||
|
|
||||||
|
/// The SQL statement that was executed.
|
||||||
|
public let sql: String
|
||||||
|
|
||||||
|
/// Time taken to execute the query, in seconds.
|
||||||
|
public let executionTime: TimeInterval
|
||||||
|
|
||||||
|
/// Number of rows affected (for INSERT/UPDATE/DELETE). Nil for SELECT.
|
||||||
|
public let rowsAffected: Int?
|
||||||
|
|
||||||
|
public init(
|
||||||
|
columns: [String],
|
||||||
|
rows: [[String: Value]],
|
||||||
|
sql: String,
|
||||||
|
executionTime: TimeInterval,
|
||||||
|
rowsAffected: Int? = nil
|
||||||
|
) {
|
||||||
|
self.columns = columns
|
||||||
|
self.rows = rows
|
||||||
|
self.sql = sql
|
||||||
|
self.executionTime = executionTime
|
||||||
|
self.rowsAffected = rowsAffected
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Convenience Accessors
|
||||||
|
|
||||||
|
/// Returns all values for a given column, in row order.
|
||||||
|
public func values(forColumn column: String) -> [Value] {
|
||||||
|
rows.compactMap { $0[column] }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a compact tabular string representation of the results.
|
||||||
|
///
|
||||||
|
/// Useful for embedding query results into LLM prompts.
|
||||||
|
public var tabularDescription: String {
|
||||||
|
guard !rows.isEmpty else {
|
||||||
|
return "(empty result set)"
|
||||||
|
}
|
||||||
|
|
||||||
|
var lines: [String] = []
|
||||||
|
|
||||||
|
// Header
|
||||||
|
lines.append(columns.joined(separator: " | "))
|
||||||
|
lines.append(String(repeating: "-", count: lines[0].count))
|
||||||
|
|
||||||
|
// Rows (cap at 50 for prompt size)
|
||||||
|
let displayRows = rows.prefix(50)
|
||||||
|
for row in displayRows {
|
||||||
|
let vals = columns.map { col in
|
||||||
|
row[col]?.description ?? "NULL"
|
||||||
|
}
|
||||||
|
lines.append(vals.joined(separator: " | "))
|
||||||
|
}
|
||||||
|
|
||||||
|
if rows.count > 50 {
|
||||||
|
lines.append("... and \(rows.count - 50) more rows")
|
||||||
|
}
|
||||||
|
|
||||||
|
return lines.joined(separator: "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if the result looks like a single aggregate value
|
||||||
|
/// (1 row, 1-3 columns, all numeric).
|
||||||
|
public var isAggregate: Bool {
|
||||||
|
guard rowCount == 1, columns.count <= 3 else { return false }
|
||||||
|
let firstRow = rows[0]
|
||||||
|
return columns.allSatisfy { col in
|
||||||
|
firstRow[col]?.doubleValue != nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
380
Sources/SwiftDBAI/Parsing/SQLQueryParser.swift
Normal file
380
Sources/SwiftDBAI/Parsing/SQLQueryParser.swift
Normal file
@@ -0,0 +1,380 @@
|
|||||||
|
// SQLQueryParser.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Extracts and validates SQL statements from raw LLM response text.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Errors that can occur during SQL parsing and validation.
|
||||||
|
public enum SQLParsingError: Error, Sendable, Equatable, CustomStringConvertible {
|
||||||
|
/// No SQL statement could be found in the LLM response.
|
||||||
|
case noSQLFound
|
||||||
|
|
||||||
|
/// The SQL statement uses an operation not in the allowlist.
|
||||||
|
case operationNotAllowed(SQLOperation)
|
||||||
|
|
||||||
|
/// A destructive operation (DELETE) requires user confirmation.
|
||||||
|
case confirmationRequired(sql: String, operation: SQLOperation)
|
||||||
|
|
||||||
|
/// The mutation targets a table not in the allowed mutation tables.
|
||||||
|
case tableNotAllowed(table: String, operation: SQLOperation)
|
||||||
|
|
||||||
|
/// The SQL contains a disallowed keyword (e.g., DROP, ALTER, TRUNCATE).
|
||||||
|
case dangerousOperation(String)
|
||||||
|
|
||||||
|
/// Multiple SQL statements were found but only single-statement execution is supported.
|
||||||
|
case multipleStatements
|
||||||
|
|
||||||
|
public var description: String {
|
||||||
|
switch self {
|
||||||
|
case .noSQLFound:
|
||||||
|
return "No SQL statement found in the response."
|
||||||
|
case .operationNotAllowed(let op):
|
||||||
|
return "Operation '\(op.rawValue.uppercased())' is not allowed by the current configuration."
|
||||||
|
case .confirmationRequired(let sql, let op):
|
||||||
|
return "The \(op.rawValue.uppercased()) operation requires confirmation: \(sql)"
|
||||||
|
case .tableNotAllowed(let table, let op):
|
||||||
|
return "The \(op.rawValue.uppercased()) operation is not allowed on table '\(table)'."
|
||||||
|
case .dangerousOperation(let keyword):
|
||||||
|
return "Dangerous SQL operation '\(keyword)' is never allowed."
|
||||||
|
case .multipleStatements:
|
||||||
|
return "Only single SQL statements are supported."
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result of successfully parsing SQL from an LLM response.
|
||||||
|
public struct ParsedSQL: Sendable, Equatable {
|
||||||
|
/// The cleaned SQL statement ready for execution.
|
||||||
|
public let sql: String
|
||||||
|
|
||||||
|
/// The detected operation type.
|
||||||
|
public let operation: SQLOperation
|
||||||
|
|
||||||
|
/// Whether this operation requires user confirmation before execution.
|
||||||
|
public let requiresConfirmation: Bool
|
||||||
|
|
||||||
|
public init(sql: String, operation: SQLOperation, requiresConfirmation: Bool = false) {
|
||||||
|
self.sql = sql
|
||||||
|
self.operation = operation
|
||||||
|
self.requiresConfirmation = requiresConfirmation
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts SQL statements from raw LLM response text and validates them
|
||||||
|
/// against the configured ``OperationAllowlist``.
|
||||||
|
///
|
||||||
|
/// The parser handles common LLM output patterns:
|
||||||
|
/// - SQL in markdown code blocks (```sql ... ```)
|
||||||
|
/// - SQL in generic code blocks (``` ... ```)
|
||||||
|
/// - Raw SQL statements in plain text
|
||||||
|
/// - SQL prefixed with labels like "SQL:" or "Query:"
|
||||||
|
public struct SQLQueryParser: Sendable {
|
||||||
|
|
||||||
|
/// Keywords that are never allowed regardless of allowlist configuration.
|
||||||
|
private static let dangerousKeywords: Set<String> = [
|
||||||
|
"DROP", "ALTER", "TRUNCATE", "CREATE", "GRANT", "REVOKE",
|
||||||
|
"ATTACH", "DETACH", "PRAGMA", "VACUUM", "REINDEX"
|
||||||
|
]
|
||||||
|
|
||||||
|
/// The operation allowlist to validate against.
|
||||||
|
private let allowlist: OperationAllowlist
|
||||||
|
|
||||||
|
/// The mutation policy for table-level restrictions.
|
||||||
|
private let mutationPolicy: MutationPolicy?
|
||||||
|
|
||||||
|
/// Creates a parser with the given operation allowlist.
|
||||||
|
/// - Parameter allowlist: The set of permitted operations. Defaults to read-only.
|
||||||
|
public init(allowlist: OperationAllowlist = .readOnly) {
|
||||||
|
self.allowlist = allowlist
|
||||||
|
self.mutationPolicy = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a parser with a mutation policy (preferred initializer).
|
||||||
|
/// - Parameter mutationPolicy: The mutation policy controlling operations and table access.
|
||||||
|
public init(mutationPolicy: MutationPolicy) {
|
||||||
|
self.allowlist = mutationPolicy.operationAllowlist
|
||||||
|
self.mutationPolicy = mutationPolicy
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts and validates a SQL statement from raw LLM response text.
|
||||||
|
///
|
||||||
|
/// - Parameter text: The raw text from the LLM response.
|
||||||
|
/// - Returns: A ``ParsedSQL`` containing the validated statement.
|
||||||
|
/// - Throws: ``SQLParsingError`` if extraction or validation fails.
|
||||||
|
public func parse(_ text: String) throws -> ParsedSQL {
|
||||||
|
let sql = try extractSQL(from: text)
|
||||||
|
return try validate(sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Extraction
|
||||||
|
|
||||||
|
/// Attempts to extract a SQL statement from the LLM response text.
|
||||||
|
/// Tries multiple strategies in order of confidence.
|
||||||
|
func extractSQL(from text: String) throws -> String {
|
||||||
|
// Strategy 1: SQL in markdown fenced code block with sql language tag
|
||||||
|
if let sql = extractFromSQLCodeBlock(text) {
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strategy 2: SQL in generic fenced code block
|
||||||
|
if let sql = extractFromGenericCodeBlock(text) {
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strategy 3: SQL after a label like "SQL:" or "Query:"
|
||||||
|
if let sql = extractFromLabel(text) {
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strategy 4: Direct SQL detection in plain text
|
||||||
|
if let sql = extractDirectSQL(text) {
|
||||||
|
return sql
|
||||||
|
}
|
||||||
|
|
||||||
|
throw SQLParsingError.noSQLFound
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts SQL from a ```sql ... ``` code block.
|
||||||
|
private func extractFromSQLCodeBlock(_ text: String) -> String? {
|
||||||
|
let pattern = #"```sql\s*\n([\s\S]*?)```"#
|
||||||
|
return firstMatch(pattern: pattern, in: text, group: 1)?
|
||||||
|
.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||||
|
.nonEmptyOrNil
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts SQL from a generic ``` ... ``` code block.
|
||||||
|
private func extractFromGenericCodeBlock(_ text: String) -> String? {
|
||||||
|
let pattern = #"```\s*\n([\s\S]*?)```"#
|
||||||
|
guard let content = firstMatch(pattern: pattern, in: text, group: 1)?
|
||||||
|
.trimmingCharacters(in: .whitespacesAndNewlines) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// Only accept if it looks like SQL
|
||||||
|
guard looksLikeSQL(content) else { return nil }
|
||||||
|
return content.nonEmptyOrNil
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts SQL after labels like "SQL:", "Query:", "Here's the query:"
|
||||||
|
private func extractFromLabel(_ text: String) -> String? {
|
||||||
|
// Match the SQL keyword up to end-of-line (handling multi-line SQL with indentation)
|
||||||
|
let pattern = #"(?:SQL|Query|Statement)\s*:\s*\n?\s*((?:SELECT|INSERT|UPDATE|DELETE|WITH)\b.+?)(?:\n(?!\s)|$)"#
|
||||||
|
guard let content = firstMatch(pattern: pattern, in: text, group: 1, options: [.caseInsensitive, .dotMatchesLineSeparators])?
|
||||||
|
.trimmingCharacters(in: .whitespacesAndNewlines) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
guard looksLikeSQL(content) else { return nil }
|
||||||
|
return content.nonEmptyOrNil
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detects SQL directly in the text by matching known statement patterns.
|
||||||
|
private func extractDirectSQL(_ text: String) -> String? {
|
||||||
|
// Match SQL statement, allowing semicolons inside single-quoted string literals
|
||||||
|
let pattern = #"(?:^|\n)\s*((?:SELECT|INSERT|UPDATE|DELETE)\b(?:[^;']|'[^']*')*;?)"#
|
||||||
|
guard let content = firstMatch(pattern: pattern, in: text, group: 1, options: .caseInsensitive)?
|
||||||
|
.trimmingCharacters(in: .whitespacesAndNewlines) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return content.nonEmptyOrNil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Validation
|
||||||
|
|
||||||
|
/// Validates a SQL string against the allowlist and safety rules.
|
||||||
|
func validate(_ sql: String) throws -> ParsedSQL {
|
||||||
|
let cleaned = cleanSQL(sql)
|
||||||
|
|
||||||
|
guard !cleaned.isEmpty else {
|
||||||
|
throw SQLParsingError.noSQLFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for multiple statements (semicolons in non-trivial positions)
|
||||||
|
if containsMultipleStatements(cleaned) {
|
||||||
|
throw SQLParsingError.multipleStatements
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for dangerous operations first (before allowlist)
|
||||||
|
try checkDangerousKeywords(cleaned)
|
||||||
|
|
||||||
|
// Detect the operation type
|
||||||
|
let operation = detectOperation(cleaned)
|
||||||
|
|
||||||
|
// Check against the allowlist
|
||||||
|
guard allowlist.isAllowed(operation) else {
|
||||||
|
throw SQLParsingError.operationNotAllowed(operation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check table-level restrictions for mutation operations
|
||||||
|
if let policy = mutationPolicy, operation != .select,
|
||||||
|
let targetTable = extractTargetTable(from: cleaned, operation: operation) {
|
||||||
|
guard policy.isAllowed(operation: operation, on: targetTable) else {
|
||||||
|
throw SQLParsingError.tableNotAllowed(table: targetTable, operation: operation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DELETE requires confirmation when policy says so, or always by default
|
||||||
|
let requiresConfirmation: Bool
|
||||||
|
if let policy = mutationPolicy {
|
||||||
|
requiresConfirmation = policy.requiresConfirmation(for: operation)
|
||||||
|
} else {
|
||||||
|
requiresConfirmation = operation == .delete
|
||||||
|
}
|
||||||
|
|
||||||
|
return ParsedSQL(
|
||||||
|
sql: cleaned,
|
||||||
|
operation: operation,
|
||||||
|
requiresConfirmation: requiresConfirmation
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Helpers
|
||||||
|
|
||||||
|
/// Cleans a SQL string by removing trailing semicolons (outside string literals) and excess whitespace.
|
||||||
|
private func cleanSQL(_ sql: String) -> String {
|
||||||
|
var cleaned = sql.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||||
|
// Remove trailing semicolons only if they're outside string literals
|
||||||
|
while cleaned.hasSuffix(";") && !isInsideStringLiteral(sql: cleaned, position: cleaned.index(before: cleaned.endIndex)) {
|
||||||
|
cleaned = String(cleaned.dropLast()).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||||
|
}
|
||||||
|
// Collapse internal whitespace outside string literals
|
||||||
|
cleaned = collapseWhitespace(cleaned)
|
||||||
|
return cleaned
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Collapses whitespace while preserving string literal contents.
|
||||||
|
private func collapseWhitespace(_ sql: String) -> String {
|
||||||
|
var result = ""
|
||||||
|
var inString = false
|
||||||
|
var prevWasSpace = false
|
||||||
|
for ch in sql {
|
||||||
|
if ch == "'" {
|
||||||
|
inString.toggle()
|
||||||
|
prevWasSpace = false
|
||||||
|
result.append(ch)
|
||||||
|
} else if inString {
|
||||||
|
result.append(ch)
|
||||||
|
} else if ch.isWhitespace {
|
||||||
|
if !prevWasSpace {
|
||||||
|
result.append(" ")
|
||||||
|
prevWasSpace = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
prevWasSpace = false
|
||||||
|
result.append(ch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if the character at the given position is inside a single-quoted string literal.
|
||||||
|
private func isInsideStringLiteral(sql: String, position: String.Index) -> Bool {
|
||||||
|
var inString = false
|
||||||
|
for idx in sql.indices {
|
||||||
|
if idx == position { return inString }
|
||||||
|
if sql[idx] == "'" { inString.toggle() }
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks whether cleaned SQL contains multiple statements.
|
||||||
|
private func containsMultipleStatements(_ sql: String) -> Bool {
|
||||||
|
// Remove string literals before checking for semicolons
|
||||||
|
var inString = false
|
||||||
|
for ch in sql {
|
||||||
|
if ch == "'" {
|
||||||
|
inString.toggle()
|
||||||
|
} else if ch == ";" && !inString {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks for dangerous SQL keywords that are never allowed.
|
||||||
|
private func checkDangerousKeywords(_ sql: String) throws {
|
||||||
|
let upper = sql.uppercased()
|
||||||
|
// Tokenize to avoid partial matches (e.g., "DROPDOWN" matching "DROP")
|
||||||
|
let tokens = upper.components(separatedBy: .alphanumerics.inverted)
|
||||||
|
.filter { !$0.isEmpty }
|
||||||
|
|
||||||
|
for keyword in Self.dangerousKeywords {
|
||||||
|
if tokens.contains(keyword) {
|
||||||
|
throw SQLParsingError.dangerousOperation(keyword)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detects the SQL operation type from the first keyword.
|
||||||
|
private func detectOperation(_ sql: String) -> SQLOperation {
|
||||||
|
let upper = sql.uppercased().trimmingCharacters(in: .whitespaces)
|
||||||
|
|
||||||
|
if upper.hasPrefix("SELECT") || upper.hasPrefix("WITH") {
|
||||||
|
return .select
|
||||||
|
} else if upper.hasPrefix("INSERT") {
|
||||||
|
return .insert
|
||||||
|
} else if upper.hasPrefix("UPDATE") {
|
||||||
|
return .update
|
||||||
|
} else if upper.hasPrefix("DELETE") {
|
||||||
|
return .delete
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to select for unrecognized patterns (e.g. EXPLAIN)
|
||||||
|
return .select
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts the target table name from a mutation SQL statement.
|
||||||
|
///
|
||||||
|
/// Handles common patterns:
|
||||||
|
/// - `INSERT INTO table_name ...`
|
||||||
|
/// - `UPDATE table_name SET ...`
|
||||||
|
/// - `DELETE FROM table_name ...`
|
||||||
|
private func extractTargetTable(from sql: String, operation: SQLOperation) -> String? {
|
||||||
|
let pattern: String
|
||||||
|
switch operation {
|
||||||
|
case .insert:
|
||||||
|
pattern = #"INSERT\s+INTO\s+[`"\[]?(\w+)[`"\]]?"#
|
||||||
|
case .update:
|
||||||
|
pattern = #"UPDATE\s+[`"\[]?(\w+)[`"\]]?"#
|
||||||
|
case .delete:
|
||||||
|
pattern = #"DELETE\s+FROM\s+[`"\[]?(\w+)[`"\]]?"#
|
||||||
|
case .select:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return firstMatch(pattern: pattern, in: sql, group: 1, options: .caseInsensitive)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if the text looks like a SQL statement.
|
||||||
|
private func looksLikeSQL(_ text: String) -> Bool {
|
||||||
|
let upper = text.uppercased().trimmingCharacters(in: .whitespaces)
|
||||||
|
let sqlPrefixes = ["SELECT", "INSERT", "UPDATE", "DELETE", "WITH"]
|
||||||
|
return sqlPrefixes.contains { upper.hasPrefix($0) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts the first regex match group from the text.
|
||||||
|
private func firstMatch(
|
||||||
|
pattern: String,
|
||||||
|
in text: String,
|
||||||
|
group: Int,
|
||||||
|
options: NSRegularExpression.Options = []
|
||||||
|
) -> String? {
|
||||||
|
guard let regex = try? NSRegularExpression(pattern: pattern, options: options) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
let range = NSRange(text.startIndex..., in: text)
|
||||||
|
guard let match = regex.firstMatch(in: text, range: range),
|
||||||
|
match.numberOfRanges > group,
|
||||||
|
let groupRange = Range(match.range(at: group), in: text) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return String(text[groupRange])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - String Extension
|
||||||
|
|
||||||
|
private extension String {
|
||||||
|
/// Returns nil if the string is empty, otherwise returns self.
|
||||||
|
var nonEmptyOrNil: String? {
|
||||||
|
isEmpty ? nil : self
|
||||||
|
}
|
||||||
|
}
|
||||||
211
Sources/SwiftDBAI/Prompt/PromptBuilder.swift
Normal file
211
Sources/SwiftDBAI/Prompt/PromptBuilder.swift
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
/// Builds structured LLM prompts for SQL generation from a database schema
|
||||||
|
/// and natural language input.
|
||||||
|
///
|
||||||
|
/// `PromptBuilder` is the bridge between the introspected database schema and
|
||||||
|
/// the LLM. It produces two things:
|
||||||
|
/// 1. A **system instructions** string containing schema context and behavioral rules
|
||||||
|
/// 2. A **user prompt** string wrapping the natural language question
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// let builder = PromptBuilder(schema: mySchema, allowlist: .readOnly)
|
||||||
|
/// let instructions = builder.buildSystemInstructions()
|
||||||
|
/// let prompt = builder.buildUserPrompt("How many users signed up this week?")
|
||||||
|
/// ```
|
||||||
|
public struct PromptBuilder: Sendable {
|
||||||
|
/// The database schema to include as context.
|
||||||
|
public let schema: DatabaseSchema
|
||||||
|
|
||||||
|
/// Which SQL operations the LLM may generate.
|
||||||
|
public let allowlist: OperationAllowlist
|
||||||
|
|
||||||
|
/// Optional additional context to append to the system instructions
|
||||||
|
/// (e.g., business-specific terminology or query hints).
|
||||||
|
public let additionalContext: String?
|
||||||
|
|
||||||
|
/// Creates a prompt builder for the given schema and allowlist.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - schema: The introspected database schema.
|
||||||
|
/// - allowlist: Permitted SQL operations. Defaults to ``OperationAllowlist/readOnly``.
|
||||||
|
/// - additionalContext: Extra instructions appended to the system prompt.
|
||||||
|
public init(
|
||||||
|
schema: DatabaseSchema,
|
||||||
|
allowlist: OperationAllowlist = .readOnly,
|
||||||
|
additionalContext: String? = nil
|
||||||
|
) {
|
||||||
|
self.schema = schema
|
||||||
|
self.allowlist = allowlist
|
||||||
|
self.additionalContext = additionalContext
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - System Instructions
|
||||||
|
|
||||||
|
/// Builds the system instructions string that should be passed as the
|
||||||
|
/// `instructions` parameter when creating a `LanguageModelSession`.
|
||||||
|
///
|
||||||
|
/// The instructions include:
|
||||||
|
/// - Role definition
|
||||||
|
/// - The full database schema
|
||||||
|
/// - SQL generation rules and constraints
|
||||||
|
/// - The operation allowlist
|
||||||
|
/// - Output format requirements
|
||||||
|
public func buildSystemInstructions() -> String {
|
||||||
|
var sections: [String] = []
|
||||||
|
|
||||||
|
// 1. Role
|
||||||
|
sections.append(Self.roleSection)
|
||||||
|
|
||||||
|
// 2. Schema
|
||||||
|
sections.append(buildSchemaSection())
|
||||||
|
|
||||||
|
// 3. Operation permissions
|
||||||
|
sections.append(buildPermissionsSection())
|
||||||
|
|
||||||
|
// 4. SQL generation rules
|
||||||
|
sections.append(Self.sqlRulesSection)
|
||||||
|
|
||||||
|
// 5. Output format
|
||||||
|
sections.append(Self.outputFormatSection)
|
||||||
|
|
||||||
|
// 6. Additional context
|
||||||
|
if let additionalContext, !additionalContext.isEmpty {
|
||||||
|
sections.append("ADDITIONAL CONTEXT\n=================\n\(additionalContext)")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sections.joined(separator: "\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - User Prompt
|
||||||
|
|
||||||
|
/// Wraps a natural language question into a user prompt string.
|
||||||
|
///
|
||||||
|
/// - Parameter question: The user's natural language question.
|
||||||
|
/// - Returns: A formatted prompt string for the LLM.
|
||||||
|
public func buildUserPrompt(_ question: String) -> String {
|
||||||
|
question
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a follow-up prompt that includes prior SQL context for
|
||||||
|
/// multi-turn conversations.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - question: The user's follow-up question.
|
||||||
|
/// - previousSQL: The SQL from the previous turn, for context.
|
||||||
|
/// - previousResultSummary: A brief summary of what the previous query returned.
|
||||||
|
/// - Returns: A formatted prompt string.
|
||||||
|
public func buildFollowUpPrompt(
|
||||||
|
_ question: String,
|
||||||
|
previousSQL: String,
|
||||||
|
previousResultSummary: String
|
||||||
|
) -> String {
|
||||||
|
"""
|
||||||
|
Previous query: \(previousSQL)
|
||||||
|
Previous result: \(previousResultSummary)
|
||||||
|
|
||||||
|
Follow-up question: \(question)
|
||||||
|
"""
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds a prompt that includes the full conversation history within the
|
||||||
|
/// configured context window, enabling the LLM to resolve follow-up
|
||||||
|
/// references (pronouns, implicit table/column references, etc.).
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - question: The user's current question.
|
||||||
|
/// - history: The conversation history messages within the context window.
|
||||||
|
/// - Returns: A formatted prompt string with conversation context.
|
||||||
|
public func buildConversationPrompt(
|
||||||
|
_ question: String,
|
||||||
|
history: [ChatMessage]
|
||||||
|
) -> String {
|
||||||
|
guard !history.isEmpty else {
|
||||||
|
return buildUserPrompt(question)
|
||||||
|
}
|
||||||
|
|
||||||
|
var lines: [String] = []
|
||||||
|
lines.append("CONVERSATION HISTORY")
|
||||||
|
lines.append("====================")
|
||||||
|
|
||||||
|
for message in history {
|
||||||
|
switch message.role {
|
||||||
|
case .user:
|
||||||
|
lines.append("User: \(message.content)")
|
||||||
|
case .assistant:
|
||||||
|
if let sql = message.sql {
|
||||||
|
lines.append("Assistant SQL: \(sql)")
|
||||||
|
}
|
||||||
|
lines.append("Assistant: \(message.content)")
|
||||||
|
case .error:
|
||||||
|
lines.append("Error: \(message.content)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
lines.append("")
|
||||||
|
lines.append("CURRENT QUESTION")
|
||||||
|
lines.append("================")
|
||||||
|
lines.append(question)
|
||||||
|
|
||||||
|
return lines.joined(separator: "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Private Sections
|
||||||
|
|
||||||
|
private func buildSchemaSection() -> String {
|
||||||
|
var lines: [String] = []
|
||||||
|
lines.append("DATABASE SCHEMA")
|
||||||
|
lines.append("===============")
|
||||||
|
lines.append("")
|
||||||
|
lines.append(schema.schemaDescription)
|
||||||
|
return lines.joined(separator: "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
private func buildPermissionsSection() -> String {
|
||||||
|
var lines: [String] = []
|
||||||
|
lines.append("PERMISSIONS")
|
||||||
|
lines.append("===========")
|
||||||
|
lines.append(allowlist.describeForLLM())
|
||||||
|
return lines.joined(separator: "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Static Content
|
||||||
|
|
||||||
|
static let roleSection = """
|
||||||
|
ROLE
|
||||||
|
====
|
||||||
|
You are a SQL assistant for a SQLite database. Your job is to translate \
|
||||||
|
natural language questions into valid SQLite SQL queries based on the \
|
||||||
|
database schema provided below. You must ONLY reference tables and columns \
|
||||||
|
that exist in the schema. Never fabricate table or column names.
|
||||||
|
"""
|
||||||
|
|
||||||
|
static let sqlRulesSection = """
|
||||||
|
SQL GENERATION RULES
|
||||||
|
====================
|
||||||
|
1. Use ONLY the tables and columns listed in the schema above.
|
||||||
|
2. Use SQLite-compatible syntax (e.g., || for string concatenation, \
|
||||||
|
IFNULL instead of COALESCE where needed).
|
||||||
|
3. Use appropriate JOINs when queries span multiple tables — reference \
|
||||||
|
the foreign key relationships in the schema.
|
||||||
|
4. For date/time operations, use SQLite date functions \
|
||||||
|
(date(), time(), datetime(), strftime()).
|
||||||
|
5. Use parameterized-style values where possible. For literal values \
|
||||||
|
from the user's question, embed them directly in the SQL.
|
||||||
|
6. Always include an ORDER BY clause when the user implies ordering.
|
||||||
|
7. Use LIMIT when the user asks for "top N" or "first N" results.
|
||||||
|
8. For aggregate queries (count, sum, average, min, max), use the \
|
||||||
|
appropriate SQL aggregate functions.
|
||||||
|
9. When the user's question is ambiguous, prefer the simplest valid \
|
||||||
|
interpretation.
|
||||||
|
10. Never generate DDL statements (CREATE, ALTER, DROP TABLE).
|
||||||
|
"""
|
||||||
|
|
||||||
|
static let outputFormatSection = """
|
||||||
|
OUTPUT FORMAT
|
||||||
|
=============
|
||||||
|
When generating SQL, call the appropriate tool with the SQL query. \
|
||||||
|
After receiving query results, provide a concise natural language \
|
||||||
|
summary of the data. Be specific with numbers and names from the results. \
|
||||||
|
If no rows are returned, say so clearly.
|
||||||
|
"""
|
||||||
|
}
|
||||||
423
Sources/SwiftDBAI/Rendering/ChartDataDetector.swift
Normal file
423
Sources/SwiftDBAI/Rendering/ChartDataDetector.swift
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
// ChartDataDetector.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Analyzes query results to determine chart eligibility and
|
||||||
|
// recommends appropriate chart types based on data shape.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Detects whether a `DataTable` is suitable for charting and
|
||||||
|
/// recommends the best chart type based on data shape heuristics.
|
||||||
|
///
|
||||||
|
/// The detector examines column types, row counts, and value distributions
|
||||||
|
/// to produce a `ChartRecommendation` that the rendering layer can use
|
||||||
|
/// to auto-select an appropriate Swift Charts visualization.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// let detector = ChartDataDetector()
|
||||||
|
/// if let recommendation = detector.detect(table) {
|
||||||
|
/// switch recommendation.chartType {
|
||||||
|
/// case .bar: // render bar chart
|
||||||
|
/// case .line: // render line chart
|
||||||
|
/// case .pie: // render pie chart
|
||||||
|
/// }
|
||||||
|
/// }
|
||||||
|
/// ```
|
||||||
|
public struct ChartDataDetector: Sendable {
|
||||||
|
|
||||||
|
// MARK: - Chart Types
|
||||||
|
|
||||||
|
/// The type of chart recommended for the data.
|
||||||
|
public enum ChartType: String, Sendable, Equatable, CaseIterable {
|
||||||
|
/// Vertical bar chart — best for categorical comparisons.
|
||||||
|
case bar
|
||||||
|
/// Line chart — best for time series or ordered sequences.
|
||||||
|
case line
|
||||||
|
/// Pie/donut chart — best for proportional breakdowns with few categories.
|
||||||
|
case pie
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A recommendation for how to chart a `DataTable`.
|
||||||
|
public struct ChartRecommendation: Sendable, Equatable {
|
||||||
|
/// The recommended chart type.
|
||||||
|
public let chartType: ChartType
|
||||||
|
|
||||||
|
/// The column to use for the category axis (x-axis / labels).
|
||||||
|
public let categoryColumn: String
|
||||||
|
|
||||||
|
/// The column to use for the value axis (y-axis / sizes).
|
||||||
|
public let valueColumn: String
|
||||||
|
|
||||||
|
/// Confidence score from 0.0 (guess) to 1.0 (strong match).
|
||||||
|
public let confidence: Double
|
||||||
|
|
||||||
|
/// Human-readable reason for this recommendation.
|
||||||
|
public let reason: String
|
||||||
|
|
||||||
|
public init(
|
||||||
|
chartType: ChartType,
|
||||||
|
categoryColumn: String,
|
||||||
|
valueColumn: String,
|
||||||
|
confidence: Double,
|
||||||
|
reason: String
|
||||||
|
) {
|
||||||
|
self.chartType = chartType
|
||||||
|
self.categoryColumn = categoryColumn
|
||||||
|
self.valueColumn = valueColumn
|
||||||
|
self.confidence = confidence
|
||||||
|
self.reason = reason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Configuration
|
||||||
|
|
||||||
|
/// Minimum rows required to consider chart-eligible.
|
||||||
|
public let minimumRows: Int
|
||||||
|
|
||||||
|
/// Maximum rows for a pie chart (too many slices becomes unreadable).
|
||||||
|
public let maxPieSlices: Int
|
||||||
|
|
||||||
|
/// Maximum rows for any chart before it becomes cluttered.
|
||||||
|
public let maximumRows: Int
|
||||||
|
|
||||||
|
// MARK: - Initialization
|
||||||
|
|
||||||
|
/// Creates a detector with configurable thresholds.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - minimumRows: Minimum rows for chart eligibility (default: 2).
|
||||||
|
/// - maxPieSlices: Maximum categories for pie charts (default: 8).
|
||||||
|
/// - maximumRows: Maximum rows for any chart (default: 100).
|
||||||
|
public init(
|
||||||
|
minimumRows: Int = 2,
|
||||||
|
maxPieSlices: Int = 8,
|
||||||
|
maximumRows: Int = 100
|
||||||
|
) {
|
||||||
|
self.minimumRows = minimumRows
|
||||||
|
self.maxPieSlices = maxPieSlices
|
||||||
|
self.maximumRows = maximumRows
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Detection
|
||||||
|
|
||||||
|
/// Analyzes a `DataTable` and returns a chart recommendation, or `nil`
|
||||||
|
/// if the data is not suitable for charting.
|
||||||
|
///
|
||||||
|
/// - Parameter table: The data table to analyze.
|
||||||
|
/// - Returns: A recommendation, or `nil` if no chart type fits.
|
||||||
|
public func detect(_ table: DataTable) -> ChartRecommendation? {
|
||||||
|
// Must have at least 2 columns (category + value) and enough rows
|
||||||
|
guard table.columnCount >= 2,
|
||||||
|
table.rowCount >= minimumRows,
|
||||||
|
table.rowCount <= maximumRows else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find candidate category and value columns
|
||||||
|
guard let (categoryCol, valueCol) = findCategoryValuePair(in: table) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
let chartType = recommendChartType(
|
||||||
|
table: table,
|
||||||
|
categoryColumn: categoryCol,
|
||||||
|
valueColumn: valueCol
|
||||||
|
)
|
||||||
|
|
||||||
|
let confidence = computeConfidence(
|
||||||
|
table: table,
|
||||||
|
categoryColumn: categoryCol,
|
||||||
|
valueColumn: valueCol,
|
||||||
|
chartType: chartType
|
||||||
|
)
|
||||||
|
|
||||||
|
let reason = describeReason(
|
||||||
|
chartType: chartType,
|
||||||
|
categoryColumn: categoryCol,
|
||||||
|
valueColumn: valueCol,
|
||||||
|
table: table
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChartRecommendation(
|
||||||
|
chartType: chartType,
|
||||||
|
categoryColumn: categoryCol.name,
|
||||||
|
valueColumn: valueCol.name,
|
||||||
|
confidence: confidence,
|
||||||
|
reason: reason
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns all viable chart recommendations, ranked by confidence.
|
||||||
|
///
|
||||||
|
/// - Parameter table: The data table to analyze.
|
||||||
|
/// - Returns: An array of recommendations sorted by confidence (highest first).
|
||||||
|
public func allRecommendations(for table: DataTable) -> [ChartRecommendation] {
|
||||||
|
guard table.columnCount >= 2,
|
||||||
|
table.rowCount >= minimumRows,
|
||||||
|
table.rowCount <= maximumRows else {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
guard let (categoryCol, valueCol) = findCategoryValuePair(in: table) else {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
return ChartType.allCases.compactMap { chartType in
|
||||||
|
guard isViable(chartType, table: table, categoryColumn: categoryCol) else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
let confidence = computeConfidence(
|
||||||
|
table: table,
|
||||||
|
categoryColumn: categoryCol,
|
||||||
|
valueColumn: valueCol,
|
||||||
|
chartType: chartType
|
||||||
|
)
|
||||||
|
|
||||||
|
let reason = describeReason(
|
||||||
|
chartType: chartType,
|
||||||
|
categoryColumn: categoryCol,
|
||||||
|
valueColumn: valueCol,
|
||||||
|
table: table
|
||||||
|
)
|
||||||
|
|
||||||
|
return ChartRecommendation(
|
||||||
|
chartType: chartType,
|
||||||
|
categoryColumn: categoryCol.name,
|
||||||
|
valueColumn: valueCol.name,
|
||||||
|
confidence: confidence,
|
||||||
|
reason: reason
|
||||||
|
)
|
||||||
|
}
|
||||||
|
.sorted { $0.confidence > $1.confidence }
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Private Helpers
|
||||||
|
|
||||||
|
/// Finds the best (category, value) column pair from the table.
|
||||||
|
private func findCategoryValuePair(
|
||||||
|
in table: DataTable
|
||||||
|
) -> (category: DataTable.Column, value: DataTable.Column)? {
|
||||||
|
let numericColumns = table.columns.filter { isNumeric($0) }
|
||||||
|
let categoryColumns = table.columns.filter { isCategory($0) }
|
||||||
|
|
||||||
|
// Prefer: first text/category column + first numeric column
|
||||||
|
if let cat = categoryColumns.first, let val = numericColumns.first {
|
||||||
|
return (cat, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: if all columns are numeric, use first as category, second as value
|
||||||
|
if numericColumns.count >= 2 {
|
||||||
|
return (numericColumns[0], numericColumns[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Recommends the single best chart type for the data shape.
|
||||||
|
private func recommendChartType(
|
||||||
|
table: DataTable,
|
||||||
|
categoryColumn: DataTable.Column,
|
||||||
|
valueColumn: DataTable.Column
|
||||||
|
) -> ChartType {
|
||||||
|
// Line: time series or sequential numeric categories (check first — strongest signal)
|
||||||
|
if isTimeSeries(categoryColumn, in: table) || isSequential(categoryColumn, in: table) {
|
||||||
|
return .line
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pie: small number of categories with all-positive values
|
||||||
|
// Only when clearly categorical (text labels) and few rows
|
||||||
|
if table.rowCount <= maxPieSlices,
|
||||||
|
isCategory(categoryColumn),
|
||||||
|
isPieCandidate(table: table, valueColumn: valueColumn),
|
||||||
|
looksProportional(table: table, valueColumn: valueColumn) {
|
||||||
|
return .pie
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default: bar chart for categorical comparisons
|
||||||
|
return .bar
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks if a chart type is viable for the given data.
|
||||||
|
private func isViable(
|
||||||
|
_ chartType: ChartType,
|
||||||
|
table: DataTable,
|
||||||
|
categoryColumn: DataTable.Column
|
||||||
|
) -> Bool {
|
||||||
|
switch chartType {
|
||||||
|
case .pie:
|
||||||
|
return table.rowCount <= maxPieSlices
|
||||||
|
case .line:
|
||||||
|
return table.rowCount >= minimumRows
|
||||||
|
case .bar:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Determines if a column holds numeric data.
|
||||||
|
private func isNumeric(_ column: DataTable.Column) -> Bool {
|
||||||
|
switch column.inferredType {
|
||||||
|
case .integer, .real:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Determines if a column holds categorical (label) data.
|
||||||
|
private func isCategory(_ column: DataTable.Column) -> Bool {
|
||||||
|
switch column.inferredType {
|
||||||
|
case .text, .mixed:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks if the value column contains all non-negative values,
|
||||||
|
/// making it a candidate for pie charts.
|
||||||
|
private func isPieCandidate(
|
||||||
|
table: DataTable,
|
||||||
|
valueColumn: DataTable.Column
|
||||||
|
) -> Bool {
|
||||||
|
let values = table.numericValues(forColumn: valueColumn.name)
|
||||||
|
guard !values.isEmpty else { return false }
|
||||||
|
// All values must be positive for a meaningful pie chart
|
||||||
|
return values.allSatisfy { $0 > 0 }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Heuristic: do values look like they represent parts of a whole?
|
||||||
|
///
|
||||||
|
/// Checks for aggregate-like column names (count, total, sum, amount, pct, etc.)
|
||||||
|
/// or if values sum to a round number suggesting percentages/proportions.
|
||||||
|
private func looksProportional(
|
||||||
|
table: DataTable,
|
||||||
|
valueColumn: DataTable.Column
|
||||||
|
) -> Bool {
|
||||||
|
let proportionalNames: Set<String> = ["count", "total", "sum", "amount", "pct",
|
||||||
|
"percent", "percentage", "share", "proportion",
|
||||||
|
"quantity", "qty", "num", "number"]
|
||||||
|
// Split on common separators and check for exact word matches
|
||||||
|
let lowerName = valueColumn.name.lowercased()
|
||||||
|
let words = Set(lowerName.split { $0 == "_" || $0 == "-" || $0 == " " }.map(String.init))
|
||||||
|
if !words.isDisjoint(with: proportionalNames) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if values sum to ~100 (percentages)
|
||||||
|
let values = table.numericValues(forColumn: valueColumn.name)
|
||||||
|
let sum = values.reduce(0, +)
|
||||||
|
if abs(sum - 100.0) < 1.0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Heuristic: does the category column look like time-series data?
|
||||||
|
///
|
||||||
|
/// Checks for date-like patterns (YYYY, YYYY-MM, YYYY-MM-DD)
|
||||||
|
/// or common time-related column names.
|
||||||
|
private func isTimeSeries(_ column: DataTable.Column, in table: DataTable) -> Bool {
|
||||||
|
let timeNames = ["date", "time", "timestamp", "year", "month", "day",
|
||||||
|
"week", "quarter", "period", "created_at", "updated_at"]
|
||||||
|
let lowerName = column.name.lowercased()
|
||||||
|
if timeNames.contains(where: { lowerName.contains($0) }) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if text values look like dates
|
||||||
|
if column.inferredType == .text {
|
||||||
|
let values = table.stringValues(forColumn: column.name)
|
||||||
|
let datePattern = #/^\d{4}(-\d{2}){0,2}$/#
|
||||||
|
let matchCount = values.prefix(5).filter { (try? datePattern.wholeMatch(in: $0)) != nil }.count
|
||||||
|
if matchCount >= 3 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Heuristic: does the category column contain sequential numeric values?
|
||||||
|
private func isSequential(_ column: DataTable.Column, in table: DataTable) -> Bool {
|
||||||
|
guard isNumeric(column) else { return false }
|
||||||
|
let values = table.numericValues(forColumn: column.name)
|
||||||
|
guard values.count >= 3 else { return false }
|
||||||
|
|
||||||
|
// Check if values are monotonically increasing
|
||||||
|
for i in 1..<values.count {
|
||||||
|
if values[i] <= values[i - 1] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes a confidence score for a specific chart type + data combination.
|
||||||
|
private func computeConfidence(
|
||||||
|
table: DataTable,
|
||||||
|
categoryColumn: DataTable.Column,
|
||||||
|
valueColumn: DataTable.Column,
|
||||||
|
chartType: ChartType
|
||||||
|
) -> Double {
|
||||||
|
var score = 0.5 // baseline
|
||||||
|
|
||||||
|
// Bonus: clear category/value split (text + numeric)
|
||||||
|
if isCategory(categoryColumn) && isNumeric(valueColumn) {
|
||||||
|
score += 0.2
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bonus: reasonable row count for the chart type
|
||||||
|
switch chartType {
|
||||||
|
case .bar:
|
||||||
|
if table.rowCount >= 2 && table.rowCount <= 20 {
|
||||||
|
score += 0.15
|
||||||
|
}
|
||||||
|
case .line:
|
||||||
|
if isTimeSeries(categoryColumn, in: table) {
|
||||||
|
score += 0.2
|
||||||
|
} else if isSequential(categoryColumn, in: table) {
|
||||||
|
score += 0.1
|
||||||
|
}
|
||||||
|
case .pie:
|
||||||
|
if table.rowCount <= maxPieSlices && isPieCandidate(table: table, valueColumn: valueColumn) {
|
||||||
|
score += 0.2
|
||||||
|
}
|
||||||
|
// Penalty: too many slices
|
||||||
|
if table.rowCount > 5 {
|
||||||
|
score -= 0.1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bonus: no null values in key columns
|
||||||
|
let categoryNulls = table.columnValues(named: categoryColumn.name).filter(\.isNull).count
|
||||||
|
let valueNulls = table.columnValues(named: valueColumn.name).filter(\.isNull).count
|
||||||
|
if categoryNulls == 0 && valueNulls == 0 {
|
||||||
|
score += 0.1
|
||||||
|
}
|
||||||
|
|
||||||
|
return min(max(score, 0.0), 1.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a human-readable reason for the recommendation.
|
||||||
|
private func describeReason(
|
||||||
|
chartType: ChartType,
|
||||||
|
categoryColumn: DataTable.Column,
|
||||||
|
valueColumn: DataTable.Column,
|
||||||
|
table: DataTable
|
||||||
|
) -> String {
|
||||||
|
switch chartType {
|
||||||
|
case .bar:
|
||||||
|
return "\(table.rowCount) categories comparing \(valueColumn.name) by \(categoryColumn.name)"
|
||||||
|
case .line:
|
||||||
|
if isTimeSeries(categoryColumn, in: table) {
|
||||||
|
return "\(valueColumn.name) over time (\(categoryColumn.name))"
|
||||||
|
}
|
||||||
|
return "\(valueColumn.name) trend across \(table.rowCount) points"
|
||||||
|
case .pie:
|
||||||
|
return "Proportional breakdown of \(valueColumn.name) by \(categoryColumn.name)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
255
Sources/SwiftDBAI/Rendering/DataTable.swift
Normal file
255
Sources/SwiftDBAI/Rendering/DataTable.swift
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
// DataTable.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Structured table representation for rendering query results
|
||||||
|
// in SwiftUI table views and charts.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// A structured, row-column table built from a `QueryResult`.
|
||||||
|
///
|
||||||
|
/// `DataTable` provides indexed access to rows and columns, typed column
|
||||||
|
/// metadata, and convenience methods for extracting data suitable for
|
||||||
|
/// SwiftUI `Table` views and Swift Charts.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// let table = DataTable(queryResult)
|
||||||
|
/// print(table.columnCount) // 3
|
||||||
|
/// print(table[row: 0, column: 1]) // .text("Alice")
|
||||||
|
/// ```
|
||||||
|
public struct DataTable: Sendable, Equatable {
|
||||||
|
|
||||||
|
// MARK: - Column Metadata
|
||||||
|
|
||||||
|
/// Metadata for a single column in the data table.
|
||||||
|
public struct Column: Sendable, Equatable, Identifiable {
|
||||||
|
/// Stable identifier for the column (same as `name`).
|
||||||
|
public var id: String { name }
|
||||||
|
|
||||||
|
/// Column name from the query result set.
|
||||||
|
public let name: String
|
||||||
|
|
||||||
|
/// Index of this column in the table (0-based).
|
||||||
|
public let index: Int
|
||||||
|
|
||||||
|
/// Inferred data type based on the values in this column.
|
||||||
|
public let inferredType: InferredType
|
||||||
|
|
||||||
|
public init(name: String, index: Int, inferredType: InferredType) {
|
||||||
|
self.name = name
|
||||||
|
self.index = index
|
||||||
|
self.inferredType = inferredType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The inferred data type for a column, determined by inspecting its values.
|
||||||
|
public enum InferredType: Sendable, Equatable {
|
||||||
|
/// All non-null values are integers.
|
||||||
|
case integer
|
||||||
|
/// All non-null values are numeric (mix of integer and real).
|
||||||
|
case real
|
||||||
|
/// All non-null values are text.
|
||||||
|
case text
|
||||||
|
/// Values contain blob data.
|
||||||
|
case blob
|
||||||
|
/// Column contains only null values or is empty.
|
||||||
|
case null
|
||||||
|
/// Values are a mix of incompatible types.
|
||||||
|
case mixed
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Row Type
|
||||||
|
|
||||||
|
/// A single row in the data table, providing indexed and named access.
|
||||||
|
public struct Row: Sendable, Equatable, Identifiable {
|
||||||
|
/// Row index (0-based), used as stable identity.
|
||||||
|
public let id: Int
|
||||||
|
|
||||||
|
/// Values in column order.
|
||||||
|
public let values: [QueryResult.Value]
|
||||||
|
|
||||||
|
/// Column names for named access.
|
||||||
|
private let columnNames: [String]
|
||||||
|
|
||||||
|
public init(id: Int, values: [QueryResult.Value], columnNames: [String]) {
|
||||||
|
self.id = id
|
||||||
|
self.values = values
|
||||||
|
self.columnNames = columnNames
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Access a value by column index.
|
||||||
|
public subscript(columnIndex: Int) -> QueryResult.Value {
|
||||||
|
values[columnIndex]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Access a value by column name. Returns `.null` if the column doesn't exist.
|
||||||
|
public subscript(columnName: String) -> QueryResult.Value {
|
||||||
|
guard let idx = columnNames.firstIndex(of: columnName) else {
|
||||||
|
return .null
|
||||||
|
}
|
||||||
|
return values[idx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Properties
|
||||||
|
|
||||||
|
/// Column metadata in order.
|
||||||
|
public let columns: [Column]
|
||||||
|
|
||||||
|
/// All rows in order.
|
||||||
|
public let rows: [Row]
|
||||||
|
|
||||||
|
/// The SQL that produced this table.
|
||||||
|
public let sql: String
|
||||||
|
|
||||||
|
/// Execution time of the underlying query.
|
||||||
|
public let executionTime: TimeInterval
|
||||||
|
|
||||||
|
/// Number of columns.
|
||||||
|
public var columnCount: Int { columns.count }
|
||||||
|
|
||||||
|
/// Number of rows.
|
||||||
|
public var rowCount: Int { rows.count }
|
||||||
|
|
||||||
|
/// Whether the table has no rows.
|
||||||
|
public var isEmpty: Bool { rows.isEmpty }
|
||||||
|
|
||||||
|
/// Column names in order.
|
||||||
|
public var columnNames: [String] { columns.map(\.name) }
|
||||||
|
|
||||||
|
// MARK: - Initialization
|
||||||
|
|
||||||
|
/// Creates a `DataTable` from a `QueryResult`.
|
||||||
|
///
|
||||||
|
/// Converts the dictionary-based row representation into an indexed
|
||||||
|
/// array representation and infers column types from the data.
|
||||||
|
///
|
||||||
|
/// - Parameter queryResult: The raw query result to convert.
|
||||||
|
public init(_ queryResult: QueryResult) {
|
||||||
|
let colNames = queryResult.columns
|
||||||
|
|
||||||
|
// Build indexed rows
|
||||||
|
let indexedRows: [Row] = queryResult.rows.enumerated().map { idx, rowDict in
|
||||||
|
let values = colNames.map { col in
|
||||||
|
rowDict[col] ?? .null
|
||||||
|
}
|
||||||
|
return Row(id: idx, values: values, columnNames: colNames)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Infer column types
|
||||||
|
let inferredColumns: [Column] = colNames.enumerated().map { colIdx, name in
|
||||||
|
let type = Self.inferType(
|
||||||
|
from: indexedRows.map { $0.values[colIdx] }
|
||||||
|
)
|
||||||
|
return Column(name: name, index: colIdx, inferredType: type)
|
||||||
|
}
|
||||||
|
|
||||||
|
self.columns = inferredColumns
|
||||||
|
self.rows = indexedRows
|
||||||
|
self.sql = queryResult.sql
|
||||||
|
self.executionTime = queryResult.executionTime
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a `DataTable` directly from components (useful for testing).
|
||||||
|
public init(
|
||||||
|
columns: [Column],
|
||||||
|
rows: [Row],
|
||||||
|
sql: String = "",
|
||||||
|
executionTime: TimeInterval = 0
|
||||||
|
) {
|
||||||
|
self.columns = columns
|
||||||
|
self.rows = rows
|
||||||
|
self.sql = sql
|
||||||
|
self.executionTime = executionTime
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Subscript Access
|
||||||
|
|
||||||
|
/// Access a cell by row and column index.
|
||||||
|
public subscript(row rowIndex: Int, column columnIndex: Int) -> QueryResult.Value {
|
||||||
|
rows[rowIndex].values[columnIndex]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Access a cell by row index and column name.
|
||||||
|
public subscript(row rowIndex: Int, column columnName: String) -> QueryResult.Value {
|
||||||
|
rows[rowIndex][columnName]
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Column Data Extraction
|
||||||
|
|
||||||
|
/// Returns all values for a column by index, in row order.
|
||||||
|
public func columnValues(at index: Int) -> [QueryResult.Value] {
|
||||||
|
rows.map { $0.values[index] }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns all values for a column by name, in row order.
|
||||||
|
public func columnValues(named name: String) -> [QueryResult.Value] {
|
||||||
|
guard let col = columns.first(where: { $0.name == name }) else {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
return columnValues(at: col.index)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns all non-null `Double` values for a column (useful for charting).
|
||||||
|
public func numericValues(forColumn name: String) -> [Double] {
|
||||||
|
columnValues(named: name).compactMap(\.doubleValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns all non-null `String` values for a column (useful for labels).
|
||||||
|
public func stringValues(forColumn name: String) -> [String] {
|
||||||
|
columnValues(named: name).compactMap { value in
|
||||||
|
if case .null = value { return nil }
|
||||||
|
return value.stringValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Type Inference
|
||||||
|
|
||||||
|
/// Infers the predominant type from an array of values.
|
||||||
|
static func inferType(from values: [QueryResult.Value]) -> InferredType {
|
||||||
|
var hasInteger = false
|
||||||
|
var hasReal = false
|
||||||
|
var hasText = false
|
||||||
|
var hasBlob = false
|
||||||
|
var hasNonNull = false
|
||||||
|
|
||||||
|
for value in values {
|
||||||
|
switch value {
|
||||||
|
case .integer:
|
||||||
|
hasInteger = true
|
||||||
|
hasNonNull = true
|
||||||
|
case .real:
|
||||||
|
hasReal = true
|
||||||
|
hasNonNull = true
|
||||||
|
case .text:
|
||||||
|
hasText = true
|
||||||
|
hasNonNull = true
|
||||||
|
case .blob:
|
||||||
|
hasBlob = true
|
||||||
|
hasNonNull = true
|
||||||
|
case .null:
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
guard hasNonNull else { return .null }
|
||||||
|
|
||||||
|
// Count how many distinct types are present
|
||||||
|
let typeCount = [hasInteger, hasReal, hasText, hasBlob].filter { $0 }.count
|
||||||
|
|
||||||
|
if typeCount == 1 {
|
||||||
|
if hasInteger { return .integer }
|
||||||
|
if hasReal { return .real }
|
||||||
|
if hasText { return .text }
|
||||||
|
if hasBlob { return .blob }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Integer + real → treat as real (numeric promotion)
|
||||||
|
if typeCount == 2, hasInteger, hasReal {
|
||||||
|
return .real
|
||||||
|
}
|
||||||
|
|
||||||
|
return .mixed
|
||||||
|
}
|
||||||
|
}
|
||||||
301
Sources/SwiftDBAI/Rendering/TextSummaryRenderer.swift
Normal file
301
Sources/SwiftDBAI/Rendering/TextSummaryRenderer.swift
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
// TextSummaryRenderer.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Converts raw SQL query results into natural language text summaries
|
||||||
|
// using the LLM via AnyLanguageModel.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Renders SQL query results as natural language text summaries.
|
||||||
|
///
|
||||||
|
/// The renderer takes a `QueryResult` and the user's original question,
|
||||||
|
/// sends them to the LLM for summarization, and returns a concise,
|
||||||
|
/// human-readable response.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// let renderer = TextSummaryRenderer(model: myModel)
|
||||||
|
/// let summary = try await renderer.summarize(
|
||||||
|
/// result: queryResult,
|
||||||
|
/// userQuestion: "How many orders were placed last month?"
|
||||||
|
/// )
|
||||||
|
/// print(summary) // "There were 42 orders placed last month."
|
||||||
|
/// ```
|
||||||
|
public struct TextSummaryRenderer: Sendable {
|
||||||
|
|
||||||
|
/// The language model used to generate summaries.
|
||||||
|
private let model: any LanguageModel
|
||||||
|
|
||||||
|
/// Maximum number of rows to include in the LLM prompt.
|
||||||
|
///
|
||||||
|
/// Results larger than this are truncated with a note about total count.
|
||||||
|
public let maxRowsInPrompt: Int
|
||||||
|
|
||||||
|
/// Creates a new text summary renderer.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - model: Any `AnyLanguageModel`-compatible language model.
|
||||||
|
/// - maxRowsInPrompt: Maximum rows to send to the LLM for summarization (default: 50).
|
||||||
|
public init(model: any LanguageModel, maxRowsInPrompt: Int = 50) {
|
||||||
|
self.model = model
|
||||||
|
self.maxRowsInPrompt = maxRowsInPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a natural language summary of query results.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - result: The raw `QueryResult` from SQL execution.
|
||||||
|
/// - userQuestion: The original natural language question from the user.
|
||||||
|
/// - context: Optional additional context (e.g., table descriptions) to help the LLM.
|
||||||
|
/// - Returns: A natural language text summary of the results.
|
||||||
|
public func summarize(
|
||||||
|
result: QueryResult,
|
||||||
|
userQuestion: String,
|
||||||
|
context: String? = nil
|
||||||
|
) async throws -> String {
|
||||||
|
// For mutation results (INSERT/UPDATE/DELETE), use a simple template
|
||||||
|
if let affected = result.rowsAffected {
|
||||||
|
return summarizeMutation(result: result, affected: affected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For empty results, no need to call the LLM
|
||||||
|
if result.rows.isEmpty {
|
||||||
|
return "No results found for your query."
|
||||||
|
}
|
||||||
|
|
||||||
|
// For simple aggregates, produce a direct answer without LLM
|
||||||
|
if let directAnswer = tryDirectAggregateSummary(result: result, userQuestion: userQuestion) {
|
||||||
|
return directAnswer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the prompt and ask the LLM to summarize
|
||||||
|
let prompt = buildSummarizationPrompt(
|
||||||
|
result: result,
|
||||||
|
userQuestion: userQuestion,
|
||||||
|
context: context
|
||||||
|
)
|
||||||
|
|
||||||
|
let session = LanguageModelSession(
|
||||||
|
model: model,
|
||||||
|
instructions: summaryInstructions
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await session.respond(to: prompt)
|
||||||
|
return response.content.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a summary without calling the LLM, using simple templates.
|
||||||
|
///
|
||||||
|
/// Useful when LLM access is unavailable, or for fast local rendering.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - result: The raw `QueryResult` from SQL execution.
|
||||||
|
/// - userQuestion: The original natural language question.
|
||||||
|
/// - Returns: A template-based text summary.
|
||||||
|
public func localSummary(result: QueryResult, userQuestion: String) -> String {
|
||||||
|
if let affected = result.rowsAffected {
|
||||||
|
return summarizeMutation(result: result, affected: affected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.rows.isEmpty {
|
||||||
|
return "No results found for your query."
|
||||||
|
}
|
||||||
|
|
||||||
|
if let directAnswer = tryDirectAggregateSummary(result: result, userQuestion: userQuestion) {
|
||||||
|
return directAnswer
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildTemplateSummary(result: result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Private Helpers
|
||||||
|
|
||||||
|
/// System instructions for the summarization session.
|
||||||
|
private var summaryInstructions: String {
|
||||||
|
"""
|
||||||
|
You are a data assistant that summarizes SQL query results in natural language.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- Be concise and direct. Answer the user's question first, then add detail if helpful.
|
||||||
|
- Use natural language, not SQL or code.
|
||||||
|
- For numeric results, include the exact numbers.
|
||||||
|
- For lists of records, summarize the count and highlight notable items.
|
||||||
|
- If the data contains dates, format them in a readable way.
|
||||||
|
- Do not mention SQL, databases, tables, columns, or queries in your response.
|
||||||
|
- Do not include markdown formatting.
|
||||||
|
- Keep your response under 3 sentences for simple results, under 5 for complex ones.
|
||||||
|
"""
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds the prompt sent to the LLM for summarization.
|
||||||
|
private func buildSummarizationPrompt(
|
||||||
|
result: QueryResult,
|
||||||
|
userQuestion: String,
|
||||||
|
context: String?
|
||||||
|
) -> String {
|
||||||
|
var parts: [String] = []
|
||||||
|
|
||||||
|
parts.append("User's question: \(userQuestion)")
|
||||||
|
|
||||||
|
if let context {
|
||||||
|
parts.append("Context: \(context)")
|
||||||
|
}
|
||||||
|
|
||||||
|
parts.append("Query returned \(result.rowCount) row(s) with columns: \(result.columns.joined(separator: ", "))")
|
||||||
|
|
||||||
|
// Include the result data (truncated if large)
|
||||||
|
let dataStr = formatResultData(result)
|
||||||
|
parts.append("Data:\n\(dataStr)")
|
||||||
|
|
||||||
|
parts.append("Summarize these results in natural language, directly answering the user's question.")
|
||||||
|
|
||||||
|
return parts.joined(separator: "\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Formats the query result data as a compact table for the LLM prompt.
|
||||||
|
private func formatResultData(_ result: QueryResult) -> String {
|
||||||
|
let rowsToInclude = Array(result.rows.prefix(maxRowsInPrompt))
|
||||||
|
var lines: [String] = []
|
||||||
|
|
||||||
|
// Header
|
||||||
|
lines.append(result.columns.joined(separator: " | "))
|
||||||
|
|
||||||
|
// Rows
|
||||||
|
for row in rowsToInclude {
|
||||||
|
let values = result.columns.map { col in
|
||||||
|
row[col]?.description ?? "NULL"
|
||||||
|
}
|
||||||
|
lines.append(values.joined(separator: " | "))
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.rowCount > maxRowsInPrompt {
|
||||||
|
lines.append("(\(result.rowCount - maxRowsInPrompt) additional rows not shown)")
|
||||||
|
}
|
||||||
|
|
||||||
|
return lines.joined(separator: "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Produces a direct answer for simple aggregate queries (1 row, few columns).
|
||||||
|
private func tryDirectAggregateSummary(result: QueryResult, userQuestion: String) -> String? {
|
||||||
|
guard result.isAggregate else { return nil }
|
||||||
|
|
||||||
|
let row = result.rows[0]
|
||||||
|
|
||||||
|
// Single numeric column — e.g., "COUNT(*)" → "42"
|
||||||
|
if result.columns.count == 1 {
|
||||||
|
let col = result.columns[0]
|
||||||
|
guard let value = row[col] else { return nil }
|
||||||
|
let formatted = formatNumber(value)
|
||||||
|
return "The result is \(formatted)."
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple aggregate columns — e.g., COUNT, AVG, SUM
|
||||||
|
let parts = result.columns.compactMap { col -> String? in
|
||||||
|
guard let value = row[col] else { return nil }
|
||||||
|
let label = humanizeColumnName(col)
|
||||||
|
let formatted = formatNumber(value)
|
||||||
|
return "\(label): \(formatted)"
|
||||||
|
}
|
||||||
|
return parts.joined(separator: ", ") + "."
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Formats a numeric Value for display.
|
||||||
|
private func formatNumber(_ value: QueryResult.Value) -> String {
|
||||||
|
switch value {
|
||||||
|
case .integer(let i):
|
||||||
|
return NumberFormatter.localizedString(from: NSNumber(value: i), number: .decimal)
|
||||||
|
case .real(let d):
|
||||||
|
if d == d.rounded() && abs(d) < 1e12 {
|
||||||
|
return NumberFormatter.localizedString(from: NSNumber(value: Int64(d)), number: .decimal)
|
||||||
|
}
|
||||||
|
let formatter = NumberFormatter()
|
||||||
|
formatter.numberStyle = .decimal
|
||||||
|
formatter.maximumFractionDigits = 2
|
||||||
|
return formatter.string(from: NSNumber(value: d)) ?? String(d)
|
||||||
|
default:
|
||||||
|
return value.description
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a column name like "total_count" or "AVG(price)" into a readable label.
|
||||||
|
private func humanizeColumnName(_ name: String) -> String {
|
||||||
|
// Handle SQL function names: "COUNT(*)" → "count", "AVG(price)" → "average price"
|
||||||
|
let functionPatterns: [(pattern: String, label: String)] = [
|
||||||
|
("COUNT", "count"),
|
||||||
|
("SUM", "total"),
|
||||||
|
("AVG", "average"),
|
||||||
|
("MIN", "minimum"),
|
||||||
|
("MAX", "maximum"),
|
||||||
|
]
|
||||||
|
|
||||||
|
let upper = name.uppercased()
|
||||||
|
for (pattern, label) in functionPatterns {
|
||||||
|
if upper.hasPrefix(pattern + "(") {
|
||||||
|
// Extract the inner column name
|
||||||
|
let start = name.index(name.startIndex, offsetBy: pattern.count + 1)
|
||||||
|
let end = name.index(before: name.endIndex)
|
||||||
|
if start < end {
|
||||||
|
let inner = String(name[start..<end])
|
||||||
|
if inner == "*" { return label }
|
||||||
|
return "\(label) \(humanizeColumnName(inner))"
|
||||||
|
}
|
||||||
|
return label
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// snake_case → space-separated
|
||||||
|
return name
|
||||||
|
.replacingOccurrences(of: "_", with: " ")
|
||||||
|
.lowercased()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Produces a template-based summary without calling the LLM.
|
||||||
|
private func buildTemplateSummary(result: QueryResult) -> String {
|
||||||
|
let count = result.rowCount
|
||||||
|
|
||||||
|
if count == 1 {
|
||||||
|
// Single record — list field values
|
||||||
|
let row = result.rows[0]
|
||||||
|
let details = result.columns.prefix(5).compactMap { col -> String? in
|
||||||
|
guard let val = row[col], !val.isNull else { return nil }
|
||||||
|
return "\(humanizeColumnName(col)): \(val.description)"
|
||||||
|
}
|
||||||
|
return "Found 1 result. \(details.joined(separator: ", "))."
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple records
|
||||||
|
var summary = "Found \(count) results"
|
||||||
|
|
||||||
|
// If there's a clear "name" or "title" column, list first few
|
||||||
|
let nameColumns = ["name", "title", "label", "description"]
|
||||||
|
if let nameCol = result.columns.first(where: { nameColumns.contains($0.lowercased()) }) {
|
||||||
|
let names = result.rows.prefix(3).compactMap { $0[nameCol]?.description }
|
||||||
|
if !names.isEmpty {
|
||||||
|
summary += " including \(names.joined(separator: ", "))"
|
||||||
|
if count > 3 { summary += ", and \(count - 3) more" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return summary + "."
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Summarizes a mutation (INSERT/UPDATE/DELETE) result.
|
||||||
|
private func summarizeMutation(result: QueryResult, affected: Int) -> String {
|
||||||
|
let sql = result.sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
|
||||||
|
|
||||||
|
let operation: String
|
||||||
|
if sql.hasPrefix("INSERT") {
|
||||||
|
operation = "inserted"
|
||||||
|
} else if sql.hasPrefix("UPDATE") {
|
||||||
|
operation = "updated"
|
||||||
|
} else if sql.hasPrefix("DELETE") {
|
||||||
|
operation = "deleted"
|
||||||
|
} else {
|
||||||
|
operation = "affected"
|
||||||
|
}
|
||||||
|
|
||||||
|
let noun = affected == 1 ? "row" : "rows"
|
||||||
|
return "Successfully \(operation) \(affected) \(noun)."
|
||||||
|
}
|
||||||
|
}
|
||||||
164
Sources/SwiftDBAI/Schema/DatabaseSchema.swift
Normal file
164
Sources/SwiftDBAI/Schema/DatabaseSchema.swift
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
// DatabaseSchema.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Auto-introspected SQLite schema model types.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// Complete schema representation of an SQLite database.
|
||||||
|
public struct DatabaseSchema: Sendable, Equatable {
|
||||||
|
/// All tables in the database, keyed by table name.
|
||||||
|
public let tables: [String: TableSchema]
|
||||||
|
|
||||||
|
/// Ordered table names (preserves discovery order).
|
||||||
|
public let tableNames: [String]
|
||||||
|
|
||||||
|
/// Returns a compact text description suitable for LLM system prompts.
|
||||||
|
public var schemaDescription: String {
|
||||||
|
var lines: [String] = []
|
||||||
|
for name in tableNames {
|
||||||
|
guard let table = tables[name] else { continue }
|
||||||
|
lines.append(table.descriptionForLLM)
|
||||||
|
}
|
||||||
|
return lines.joined(separator: "\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a description suitable for LLM system prompts.
|
||||||
|
/// Alias for `schemaDescription` for API compatibility.
|
||||||
|
public func describeForLLM() -> String {
|
||||||
|
schemaDescription
|
||||||
|
}
|
||||||
|
|
||||||
|
public init(tables: [String: TableSchema], tableNames: [String]) {
|
||||||
|
self.tables = tables
|
||||||
|
self.tableNames = tableNames
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Schema for a single SQLite table.
|
||||||
|
public struct TableSchema: Sendable, Equatable {
|
||||||
|
public let name: String
|
||||||
|
public let columns: [ColumnSchema]
|
||||||
|
public let primaryKey: [String]
|
||||||
|
public let foreignKeys: [ForeignKeySchema]
|
||||||
|
public let indexes: [IndexSchema]
|
||||||
|
|
||||||
|
/// Text description for embedding in LLM prompts.
|
||||||
|
public var descriptionForLLM: String {
|
||||||
|
var parts: [String] = []
|
||||||
|
let colDefs = columns.map { col in
|
||||||
|
var def = " \(col.name) \(col.type)"
|
||||||
|
if col.isPrimaryKey { def += " PRIMARY KEY" }
|
||||||
|
if col.isNotNull { def += " NOT NULL" }
|
||||||
|
if let defaultValue = col.defaultValue { def += " DEFAULT \(defaultValue)" }
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
parts.append("TABLE \(name) (\n\(colDefs.joined(separator: ",\n"))\n)")
|
||||||
|
|
||||||
|
if !foreignKeys.isEmpty {
|
||||||
|
let fkDescs = foreignKeys.map {
|
||||||
|
" FOREIGN KEY (\($0.fromColumn)) REFERENCES \($0.toTable)(\($0.toColumn))"
|
||||||
|
}
|
||||||
|
parts.append("FOREIGN KEYS:\n\(fkDescs.joined(separator: "\n"))")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !indexes.isEmpty {
|
||||||
|
let idxDescs = indexes.map {
|
||||||
|
" INDEX \($0.name) ON (\($0.columns.joined(separator: ", ")))\($0.isUnique ? " UNIQUE" : "")"
|
||||||
|
}
|
||||||
|
parts.append("INDEXES:\n\(idxDescs.joined(separator: "\n"))")
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts.joined(separator: "\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
public init(
|
||||||
|
name: String,
|
||||||
|
columns: [ColumnSchema],
|
||||||
|
primaryKey: [String],
|
||||||
|
foreignKeys: [ForeignKeySchema],
|
||||||
|
indexes: [IndexSchema]
|
||||||
|
) {
|
||||||
|
self.name = name
|
||||||
|
self.columns = columns
|
||||||
|
self.primaryKey = primaryKey
|
||||||
|
self.foreignKeys = foreignKeys
|
||||||
|
self.indexes = indexes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Schema for a single column.
|
||||||
|
public struct ColumnSchema: Sendable, Equatable {
|
||||||
|
/// Column position (0-based).
|
||||||
|
public let cid: Int
|
||||||
|
/// Column name.
|
||||||
|
public let name: String
|
||||||
|
/// Declared SQLite type (e.g. "TEXT", "INTEGER", "REAL", "BLOB").
|
||||||
|
public let type: String
|
||||||
|
/// Whether the column has a NOT NULL constraint.
|
||||||
|
public let isNotNull: Bool
|
||||||
|
/// Default value expression, if any.
|
||||||
|
public let defaultValue: String?
|
||||||
|
/// Whether this column is part of the primary key.
|
||||||
|
public let isPrimaryKey: Bool
|
||||||
|
|
||||||
|
public init(
|
||||||
|
cid: Int,
|
||||||
|
name: String,
|
||||||
|
type: String,
|
||||||
|
isNotNull: Bool,
|
||||||
|
defaultValue: String?,
|
||||||
|
isPrimaryKey: Bool
|
||||||
|
) {
|
||||||
|
self.cid = cid
|
||||||
|
self.name = name
|
||||||
|
self.type = type
|
||||||
|
self.isNotNull = isNotNull
|
||||||
|
self.defaultValue = defaultValue
|
||||||
|
self.isPrimaryKey = isPrimaryKey
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Schema for a foreign key relationship.
|
||||||
|
public struct ForeignKeySchema: Sendable, Equatable {
|
||||||
|
/// Column in the source table.
|
||||||
|
public let fromColumn: String
|
||||||
|
/// Referenced table name.
|
||||||
|
public let toTable: String
|
||||||
|
/// Referenced column name.
|
||||||
|
public let toColumn: String
|
||||||
|
/// ON UPDATE action (e.g. "CASCADE", "NO ACTION").
|
||||||
|
public let onUpdate: String
|
||||||
|
/// ON DELETE action.
|
||||||
|
public let onDelete: String
|
||||||
|
|
||||||
|
public init(
|
||||||
|
fromColumn: String,
|
||||||
|
toTable: String,
|
||||||
|
toColumn: String,
|
||||||
|
onUpdate: String,
|
||||||
|
onDelete: String
|
||||||
|
) {
|
||||||
|
self.fromColumn = fromColumn
|
||||||
|
self.toTable = toTable
|
||||||
|
self.toColumn = toColumn
|
||||||
|
self.onUpdate = onUpdate
|
||||||
|
self.onDelete = onDelete
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Schema for a database index.
|
||||||
|
public struct IndexSchema: Sendable, Equatable {
|
||||||
|
/// Index name.
|
||||||
|
public let name: String
|
||||||
|
/// Whether the index enforces uniqueness.
|
||||||
|
public let isUnique: Bool
|
||||||
|
/// Columns included in the index, in order.
|
||||||
|
public let columns: [String]
|
||||||
|
|
||||||
|
public init(name: String, isUnique: Bool, columns: [String]) {
|
||||||
|
self.name = name
|
||||||
|
self.isUnique = isUnique
|
||||||
|
self.columns = columns
|
||||||
|
}
|
||||||
|
}
|
||||||
153
Sources/SwiftDBAI/Schema/SchemaIntrospector.swift
Normal file
153
Sources/SwiftDBAI/Schema/SchemaIntrospector.swift
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
// SchemaIntrospector.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Auto-introspects SQLite database schema using GRDB.
|
||||||
|
|
||||||
|
import GRDB
|
||||||
|
|
||||||
|
/// Introspects an SQLite database schema by querying sqlite_master and PRAGMA statements.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// let dbPool = try DatabasePool(path: "path/to/db.sqlite")
|
||||||
|
/// let schema = try await SchemaIntrospector.introspect(database: dbPool)
|
||||||
|
/// print(schema.schemaDescription)
|
||||||
|
/// ```
|
||||||
|
public struct SchemaIntrospector: Sendable {
|
||||||
|
|
||||||
|
// MARK: - Public API
|
||||||
|
|
||||||
|
/// Introspects the full schema of the given database.
|
||||||
|
///
|
||||||
|
/// Discovers all user tables (excluding sqlite_ internal tables),
|
||||||
|
/// their columns, primary keys, foreign keys, and indexes.
|
||||||
|
///
|
||||||
|
/// - Parameter database: A GRDB `DatabaseReader` (DatabasePool or DatabaseQueue).
|
||||||
|
/// - Returns: A complete `DatabaseSchema` representation.
|
||||||
|
public static func introspect(database: any DatabaseReader) async throws -> DatabaseSchema {
|
||||||
|
try await database.read { db in
|
||||||
|
try introspect(db: db)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Synchronous introspection within an existing database access context.
|
||||||
|
///
|
||||||
|
/// - Parameter db: A GRDB `Database` instance from within a read/write block.
|
||||||
|
/// - Returns: A complete `DatabaseSchema` representation.
|
||||||
|
public static func introspect(db: Database) throws -> DatabaseSchema {
|
||||||
|
let tableNames = try fetchTableNames(db: db)
|
||||||
|
var tables: [String: TableSchema] = [:]
|
||||||
|
|
||||||
|
for tableName in tableNames {
|
||||||
|
let columns = try fetchColumns(db: db, table: tableName)
|
||||||
|
let primaryKey = try fetchPrimaryKey(db: db, table: tableName)
|
||||||
|
let foreignKeys = try fetchForeignKeys(db: db, table: tableName)
|
||||||
|
let indexes = try fetchIndexes(db: db, table: tableName)
|
||||||
|
|
||||||
|
// Mark columns that are part of the primary key
|
||||||
|
let pkSet = Set(primaryKey)
|
||||||
|
let annotatedColumns = columns.map { col in
|
||||||
|
ColumnSchema(
|
||||||
|
cid: col.cid,
|
||||||
|
name: col.name,
|
||||||
|
type: col.type,
|
||||||
|
isNotNull: col.isNotNull,
|
||||||
|
defaultValue: col.defaultValue,
|
||||||
|
isPrimaryKey: pkSet.contains(col.name)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
tables[tableName] = TableSchema(
|
||||||
|
name: tableName,
|
||||||
|
columns: annotatedColumns,
|
||||||
|
primaryKey: primaryKey,
|
||||||
|
foreignKeys: foreignKeys,
|
||||||
|
indexes: indexes
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return DatabaseSchema(tables: tables, tableNames: tableNames)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Private Helpers
|
||||||
|
|
||||||
|
/// Fetches all user table names from sqlite_master.
|
||||||
|
private static func fetchTableNames(db: Database) throws -> [String] {
|
||||||
|
let sql = """
|
||||||
|
SELECT name FROM sqlite_master
|
||||||
|
WHERE type = 'table'
|
||||||
|
AND name NOT LIKE 'sqlite_%'
|
||||||
|
ORDER BY name
|
||||||
|
"""
|
||||||
|
return try String.fetchAll(db, sql: sql)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetches column metadata for a table using PRAGMA table_info.
|
||||||
|
private static func fetchColumns(db: Database, table: String) throws -> [ColumnSchema] {
|
||||||
|
let sql = "PRAGMA table_info(\(table.quotedDatabaseIdentifier))"
|
||||||
|
let rows = try Row.fetchAll(db, sql: sql)
|
||||||
|
return rows.map { row in
|
||||||
|
ColumnSchema(
|
||||||
|
cid: row["cid"],
|
||||||
|
name: row["name"],
|
||||||
|
type: (row["type"] as String?) ?? "",
|
||||||
|
isNotNull: row["notnull"] == 1,
|
||||||
|
defaultValue: row["dflt_value"],
|
||||||
|
isPrimaryKey: row["pk"] != 0
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetches primary key columns for a table.
|
||||||
|
private static func fetchPrimaryKey(db: Database, table: String) throws -> [String] {
|
||||||
|
let sql = "PRAGMA table_info(\(table.quotedDatabaseIdentifier))"
|
||||||
|
let rows = try Row.fetchAll(db, sql: sql)
|
||||||
|
return rows
|
||||||
|
.filter { ($0["pk"] as Int) > 0 }
|
||||||
|
.sorted { ($0["pk"] as Int) < ($1["pk"] as Int) }
|
||||||
|
.map { $0["name"] }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetches foreign key relationships for a table.
|
||||||
|
private static func fetchForeignKeys(db: Database, table: String) throws -> [ForeignKeySchema] {
|
||||||
|
let sql = "PRAGMA foreign_key_list(\(table.quotedDatabaseIdentifier))"
|
||||||
|
let rows = try Row.fetchAll(db, sql: sql)
|
||||||
|
return rows.map { row in
|
||||||
|
ForeignKeySchema(
|
||||||
|
fromColumn: row["from"],
|
||||||
|
toTable: row["table"],
|
||||||
|
toColumn: row["to"],
|
||||||
|
onUpdate: row["on_update"] ?? "NO ACTION",
|
||||||
|
onDelete: row["on_delete"] ?? "NO ACTION"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetches indexes and their columns for a table.
|
||||||
|
private static func fetchIndexes(db: Database, table: String) throws -> [IndexSchema] {
|
||||||
|
let indexListSQL = "PRAGMA index_list(\(table.quotedDatabaseIdentifier))"
|
||||||
|
let indexRows = try Row.fetchAll(db, sql: indexListSQL)
|
||||||
|
|
||||||
|
var indexes: [IndexSchema] = []
|
||||||
|
for indexRow in indexRows {
|
||||||
|
let indexName: String = indexRow["name"]
|
||||||
|
let isUnique: Bool = indexRow["unique"] == 1
|
||||||
|
|
||||||
|
// Skip auto-generated indexes for primary keys
|
||||||
|
if indexName.hasPrefix("sqlite_autoindex_") { continue }
|
||||||
|
|
||||||
|
let infoSQL = "PRAGMA index_info(\(indexName.quotedDatabaseIdentifier))"
|
||||||
|
let infoRows = try Row.fetchAll(db, sql: infoSQL)
|
||||||
|
let columns: [String] = infoRows
|
||||||
|
.sorted { ($0["seqno"] as Int) < ($1["seqno"] as Int) }
|
||||||
|
.map { $0["name"] }
|
||||||
|
|
||||||
|
indexes.append(IndexSchema(
|
||||||
|
name: indexName,
|
||||||
|
isUnique: isUnique,
|
||||||
|
columns: columns
|
||||||
|
))
|
||||||
|
}
|
||||||
|
return indexes
|
||||||
|
}
|
||||||
|
}
|
||||||
215
Sources/SwiftDBAI/SwiftDBAIError.swift
Normal file
215
Sources/SwiftDBAI/SwiftDBAIError.swift
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
// SwiftDBAIError.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Unified error type for the SwiftDBAI package.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// The top-level error type for SwiftDBAI operations.
|
||||||
|
///
|
||||||
|
/// `SwiftDBAIError` provides a single, typed error surface that covers
|
||||||
|
/// every failure mode a consumer of SwiftDBAI may encounter — from invalid
|
||||||
|
/// SQL and LLM failures to schema mismatches and safety violations.
|
||||||
|
///
|
||||||
|
/// Every case includes a user-friendly `localizedDescription` suitable for
|
||||||
|
/// displaying directly in a chat interface.
|
||||||
|
public enum SwiftDBAIError: Error, LocalizedError, Sendable, Equatable {
|
||||||
|
|
||||||
|
// MARK: - SQL Errors
|
||||||
|
|
||||||
|
/// No SQL statement could be extracted from the LLM response.
|
||||||
|
case noSQLGenerated
|
||||||
|
|
||||||
|
/// The generated SQL is syntactically invalid or failed execution.
|
||||||
|
case invalidSQL(sql: String, reason: String)
|
||||||
|
|
||||||
|
/// The SQL uses an operation (e.g. DELETE) not in the developer's allowlist.
|
||||||
|
case operationNotAllowed(operation: String)
|
||||||
|
|
||||||
|
/// Multiple SQL statements were generated but only single-statement execution is supported.
|
||||||
|
case multipleStatementsNotSupported
|
||||||
|
|
||||||
|
/// A dangerous SQL keyword (DROP, ALTER, TRUNCATE) was detected.
|
||||||
|
case dangerousOperationBlocked(keyword: String)
|
||||||
|
|
||||||
|
// MARK: - LLM Errors
|
||||||
|
|
||||||
|
/// The LLM failed to produce a response.
|
||||||
|
case llmFailure(reason: String)
|
||||||
|
|
||||||
|
/// The LLM response could not be parsed into an actionable result.
|
||||||
|
case llmResponseUnparseable(response: String)
|
||||||
|
|
||||||
|
/// The LLM request timed out.
|
||||||
|
case llmTimeout(seconds: TimeInterval)
|
||||||
|
|
||||||
|
// MARK: - Schema Errors
|
||||||
|
|
||||||
|
/// Schema introspection of the database failed.
|
||||||
|
case schemaIntrospectionFailed(reason: String)
|
||||||
|
|
||||||
|
/// The generated SQL references a table that does not exist in the schema.
|
||||||
|
case tableNotFound(tableName: String)
|
||||||
|
|
||||||
|
/// The generated SQL references a column that does not exist on the given table.
|
||||||
|
case columnNotFound(columnName: String, tableName: String)
|
||||||
|
|
||||||
|
/// The database schema is empty (no user tables found).
|
||||||
|
case emptySchema
|
||||||
|
|
||||||
|
// MARK: - Safety & Validation Errors
|
||||||
|
|
||||||
|
/// A destructive operation requires explicit user confirmation before execution.
|
||||||
|
case confirmationRequired(sql: String, operation: String)
|
||||||
|
|
||||||
|
/// A mutation targets a table not in the allowed mutation tables.
|
||||||
|
case tableNotAllowedForMutation(tableName: String, operation: String)
|
||||||
|
|
||||||
|
/// A custom query validator rejected the query.
|
||||||
|
case queryRejected(reason: String)
|
||||||
|
|
||||||
|
// MARK: - Database Errors
|
||||||
|
|
||||||
|
/// The underlying database operation failed.
|
||||||
|
case databaseError(reason: String)
|
||||||
|
|
||||||
|
/// The query exceeded the configured execution timeout.
|
||||||
|
case queryTimedOut(seconds: TimeInterval)
|
||||||
|
|
||||||
|
// MARK: - Configuration Errors
|
||||||
|
|
||||||
|
/// The engine has not been configured correctly.
|
||||||
|
case configurationError(reason: String)
|
||||||
|
|
||||||
|
// MARK: - Error Classification
|
||||||
|
|
||||||
|
/// Whether this error represents a safety/permissions issue (not a bug).
|
||||||
|
public var isSafetyError: Bool {
|
||||||
|
switch self {
|
||||||
|
case .operationNotAllowed, .dangerousOperationBlocked,
|
||||||
|
.confirmationRequired, .tableNotAllowedForMutation, .queryRejected:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether this error is recoverable by rephrasing the user's question.
|
||||||
|
public var isRecoverable: Bool {
|
||||||
|
switch self {
|
||||||
|
case .noSQLGenerated, .llmResponseUnparseable, .invalidSQL,
|
||||||
|
.tableNotFound, .columnNotFound:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whether this error requires user action (e.g. confirmation).
|
||||||
|
public var requiresUserAction: Bool {
|
||||||
|
if case .confirmationRequired = self { return true }
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - LocalizedError
|
||||||
|
|
||||||
|
public var errorDescription: String? {
|
||||||
|
switch self {
|
||||||
|
// SQL
|
||||||
|
case .noSQLGenerated:
|
||||||
|
return "I couldn't generate a SQL query from your request. Could you rephrase your question?"
|
||||||
|
case .invalidSQL(let sql, let reason):
|
||||||
|
return "The generated query is invalid — \(reason). Query: \(sql)"
|
||||||
|
case .operationNotAllowed(let operation):
|
||||||
|
return "The \(operation.uppercased()) operation is not allowed by the current configuration."
|
||||||
|
case .multipleStatementsNotSupported:
|
||||||
|
return "Only single SQL statements are supported. Please ask one question at a time."
|
||||||
|
case .dangerousOperationBlocked(let keyword):
|
||||||
|
return "The \(keyword.uppercased()) operation is blocked for safety. This operation is never allowed."
|
||||||
|
|
||||||
|
// LLM
|
||||||
|
case .llmFailure(let reason):
|
||||||
|
return "The language model encountered an error: \(reason)"
|
||||||
|
case .llmResponseUnparseable(let response):
|
||||||
|
return "I received a response but couldn't understand it. Raw response: \(response.prefix(200))"
|
||||||
|
case .llmTimeout(let seconds):
|
||||||
|
return "The language model did not respond within \(Int(seconds)) seconds. Please try again."
|
||||||
|
|
||||||
|
// Schema
|
||||||
|
case .schemaIntrospectionFailed(let reason):
|
||||||
|
return "Failed to read the database schema: \(reason)"
|
||||||
|
case .tableNotFound(let tableName):
|
||||||
|
return "The table '\(tableName)' does not exist in this database."
|
||||||
|
case .columnNotFound(let columnName, let tableName):
|
||||||
|
return "The column '\(columnName)' does not exist on table '\(tableName)'."
|
||||||
|
case .emptySchema:
|
||||||
|
return "This database has no tables. There's nothing to query yet."
|
||||||
|
|
||||||
|
// Safety
|
||||||
|
case .confirmationRequired(let sql, let operation):
|
||||||
|
return "The \(operation.uppercased()) operation requires your confirmation before running: \(sql)"
|
||||||
|
case .tableNotAllowedForMutation(let tableName, let operation):
|
||||||
|
return "The \(operation.uppercased()) operation is not allowed on table '\(tableName)'."
|
||||||
|
case .queryRejected(let reason):
|
||||||
|
return "Query rejected: \(reason)"
|
||||||
|
|
||||||
|
// Database
|
||||||
|
case .databaseError(let reason):
|
||||||
|
return "A database error occurred: \(reason)"
|
||||||
|
case .queryTimedOut(let seconds):
|
||||||
|
return "The query timed out after \(Int(seconds)) seconds. Try a simpler query."
|
||||||
|
|
||||||
|
// Configuration
|
||||||
|
case .configurationError(let reason):
|
||||||
|
return "Configuration error: \(reason)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Conversion from SQLParsingError
|
||||||
|
|
||||||
|
extension SQLParsingError {
|
||||||
|
/// Maps a ``SQLParsingError`` to the corresponding ``SwiftDBAIError`` case.
|
||||||
|
///
|
||||||
|
/// - Parameter rawResponse: The raw LLM response text (used for context in `.noSQLFound`).
|
||||||
|
/// - Returns: A ``SwiftDBAIError`` with the same semantic meaning.
|
||||||
|
func toSwiftDBAIError(rawResponse: String = "") -> SwiftDBAIError {
|
||||||
|
switch self {
|
||||||
|
case .noSQLFound:
|
||||||
|
if rawResponse.isEmpty {
|
||||||
|
return .noSQLGenerated
|
||||||
|
}
|
||||||
|
return .llmResponseUnparseable(response: rawResponse)
|
||||||
|
case .operationNotAllowed(let op):
|
||||||
|
return .operationNotAllowed(operation: op.rawValue)
|
||||||
|
case .confirmationRequired(let sql, let op):
|
||||||
|
return .confirmationRequired(sql: sql, operation: op.rawValue)
|
||||||
|
case .tableNotAllowed(let table, let op):
|
||||||
|
return .tableNotAllowedForMutation(tableName: table, operation: op.rawValue)
|
||||||
|
case .dangerousOperation(let keyword):
|
||||||
|
return .dangerousOperationBlocked(keyword: keyword)
|
||||||
|
case .multipleStatements:
|
||||||
|
return .multipleStatementsNotSupported
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Conversion from ChatEngineError
|
||||||
|
|
||||||
|
extension ChatEngineError {
|
||||||
|
/// Maps a ``ChatEngineError`` to the corresponding ``SwiftDBAIError`` case.
|
||||||
|
func toSwiftDBAIError() -> SwiftDBAIError {
|
||||||
|
switch self {
|
||||||
|
case .sqlParsingFailed(let parsingError):
|
||||||
|
return parsingError.toSwiftDBAIError()
|
||||||
|
case .confirmationRequired(let sql, let operation):
|
||||||
|
return .confirmationRequired(sql: sql, operation: operation.rawValue)
|
||||||
|
case .schemaIntrospectionFailed(let reason):
|
||||||
|
return .schemaIntrospectionFailed(reason: reason)
|
||||||
|
case .queryTimedOut(let seconds):
|
||||||
|
return .queryTimedOut(seconds: seconds)
|
||||||
|
case .validationFailed(let reason):
|
||||||
|
return .queryRejected(reason: reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
182
Sources/SwiftDBAI/Views/Charts/BarChartView.swift
Normal file
182
Sources/SwiftDBAI/Views/Charts/BarChartView.swift
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
// BarChartView.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// A SwiftUI bar chart that renders DataTable values using Swift Charts.
|
||||||
|
// Best for categorical comparisons (e.g., sales by region, counts by status).
|
||||||
|
|
||||||
|
import SwiftUI
|
||||||
|
import Charts
|
||||||
|
|
||||||
|
/// A bar chart view that renders a `DataTable` column pair using Swift Charts.
|
||||||
|
///
|
||||||
|
/// Displays vertical bars with category labels on the x-axis and numeric
|
||||||
|
/// values on the y-axis. Automatically colors bars using the accent gradient
|
||||||
|
/// and supports scrolling when many categories are present.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// BarChartView(
|
||||||
|
/// dataTable: table,
|
||||||
|
/// categoryColumn: "department",
|
||||||
|
/// valueColumn: "total_sales"
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
public struct BarChartView: View {
|
||||||
|
|
||||||
|
/// The data to chart.
|
||||||
|
public let dataTable: DataTable
|
||||||
|
|
||||||
|
/// Column name for category labels (x-axis).
|
||||||
|
public let categoryColumn: String
|
||||||
|
|
||||||
|
/// Column name for numeric values (y-axis).
|
||||||
|
public let valueColumn: String
|
||||||
|
|
||||||
|
/// Optional chart title.
|
||||||
|
public var title: String?
|
||||||
|
|
||||||
|
/// Maximum number of bars to display before truncating.
|
||||||
|
public var maxBars: Int
|
||||||
|
|
||||||
|
public init(
|
||||||
|
dataTable: DataTable,
|
||||||
|
categoryColumn: String,
|
||||||
|
valueColumn: String,
|
||||||
|
title: String? = nil,
|
||||||
|
maxBars: Int = 30
|
||||||
|
) {
|
||||||
|
self.dataTable = dataTable
|
||||||
|
self.categoryColumn = categoryColumn
|
||||||
|
self.valueColumn = valueColumn
|
||||||
|
self.title = title
|
||||||
|
self.maxBars = maxBars
|
||||||
|
}
|
||||||
|
|
||||||
|
public var body: some View {
|
||||||
|
VStack(alignment: .leading, spacing: 8) {
|
||||||
|
if let title {
|
||||||
|
Text(title)
|
||||||
|
.font(.caption.weight(.semibold))
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chartData.isEmpty {
|
||||||
|
emptyChartView
|
||||||
|
} else {
|
||||||
|
chartContent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Chart Content
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var chartContent: some View {
|
||||||
|
Chart(chartData, id: \.label) { item in
|
||||||
|
BarMark(
|
||||||
|
x: .value(categoryColumn, item.label),
|
||||||
|
y: .value(valueColumn, item.value)
|
||||||
|
)
|
||||||
|
.foregroundStyle(
|
||||||
|
.linearGradient(
|
||||||
|
colors: [.accentColor, .accentColor.opacity(0.7)],
|
||||||
|
startPoint: .bottom,
|
||||||
|
endPoint: .top
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.cornerRadius(4)
|
||||||
|
}
|
||||||
|
.chartXAxis {
|
||||||
|
AxisMarks(values: .automatic) { _ in
|
||||||
|
AxisValueLabel()
|
||||||
|
.font(.caption2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.chartYAxis {
|
||||||
|
AxisMarks(position: .leading) { _ in
|
||||||
|
AxisGridLine(stroke: StrokeStyle(lineWidth: 0.5, dash: [4, 4]))
|
||||||
|
.foregroundStyle(.secondary.opacity(0.3))
|
||||||
|
AxisValueLabel()
|
||||||
|
.font(.caption2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.frame(minHeight: 200)
|
||||||
|
|
||||||
|
if isTruncated {
|
||||||
|
truncationNotice
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Empty State
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var emptyChartView: some View {
|
||||||
|
VStack(spacing: 8) {
|
||||||
|
Image(systemName: "chart.bar")
|
||||||
|
.font(.title2)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
Text("No chartable data")
|
||||||
|
.font(.caption)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
.frame(maxWidth: .infinity, minHeight: 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Truncation Notice
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var truncationNotice: some View {
|
||||||
|
Text("Showing \(maxBars) of \(dataTable.rowCount) categories")
|
||||||
|
.font(.caption2)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Data Extraction
|
||||||
|
|
||||||
|
private var isTruncated: Bool {
|
||||||
|
dataTable.rowCount > maxBars
|
||||||
|
}
|
||||||
|
|
||||||
|
private var chartData: [ChartDataPoint] {
|
||||||
|
let labels = dataTable.stringValues(forColumn: categoryColumn)
|
||||||
|
let values = dataTable.numericValues(forColumn: valueColumn)
|
||||||
|
|
||||||
|
let count = min(labels.count, values.count, maxBars)
|
||||||
|
guard count > 0 else { return [] }
|
||||||
|
|
||||||
|
return (0..<count).map { i in
|
||||||
|
ChartDataPoint(label: labels[i], value: values[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Preview
|
||||||
|
|
||||||
|
#if DEBUG
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
#Preview("Bar Chart") {
|
||||||
|
let columns: [DataTable.Column] = [
|
||||||
|
.init(name: "department", index: 0, inferredType: .text),
|
||||||
|
.init(name: "revenue", index: 1, inferredType: .real),
|
||||||
|
]
|
||||||
|
let departments = ["Engineering", "Sales", "Marketing", "Support", "Design"]
|
||||||
|
let rows: [DataTable.Row] = departments.enumerated().map { i, dept in
|
||||||
|
DataTable.Row(
|
||||||
|
id: i,
|
||||||
|
values: [.text(dept), .real(Double.random(in: 50_000...200_000))],
|
||||||
|
columnNames: ["department", "revenue"]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let table = DataTable(columns: columns, rows: rows)
|
||||||
|
|
||||||
|
BarChartView(
|
||||||
|
dataTable: table,
|
||||||
|
categoryColumn: "department",
|
||||||
|
valueColumn: "revenue",
|
||||||
|
title: "Revenue by Department"
|
||||||
|
)
|
||||||
|
.padding()
|
||||||
|
.frame(height: 300)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
21
Sources/SwiftDBAI/Views/Charts/ChartDataPoint.swift
Normal file
21
Sources/SwiftDBAI/Views/Charts/ChartDataPoint.swift
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
// ChartDataPoint.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Shared data model used by all chart views.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// A single data point for chart rendering.
|
||||||
|
///
|
||||||
|
/// Pairs a string label (category) with a numeric value.
|
||||||
|
/// Used as the common data format across BarChartView,
|
||||||
|
/// LineChartView, and PieChartView.
|
||||||
|
struct ChartDataPoint: Sendable, Identifiable {
|
||||||
|
var id: String { label }
|
||||||
|
|
||||||
|
/// The category label (x-axis or slice label).
|
||||||
|
let label: String
|
||||||
|
|
||||||
|
/// The numeric value (y-axis or slice size).
|
||||||
|
let value: Double
|
||||||
|
}
|
||||||
135
Sources/SwiftDBAI/Views/Charts/ChartResultView.swift
Normal file
135
Sources/SwiftDBAI/Views/Charts/ChartResultView.swift
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
// ChartResultView.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Auto-selecting chart view that uses ChartDataDetector to pick the
|
||||||
|
// best chart type for a given DataTable.
|
||||||
|
|
||||||
|
import SwiftUI
|
||||||
|
import Charts
|
||||||
|
|
||||||
|
/// A chart view that automatically selects the best chart type for a `DataTable`.
|
||||||
|
///
|
||||||
|
/// Uses `ChartDataDetector` to analyze the data shape and renders the
|
||||||
|
/// appropriate chart (bar, line, or pie). If the data isn't suitable for
|
||||||
|
/// charting, the view renders nothing.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// ChartResultView(dataTable: myTable)
|
||||||
|
/// ```
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
public struct ChartResultView: View {
|
||||||
|
|
||||||
|
/// The data table to chart.
|
||||||
|
public let dataTable: DataTable
|
||||||
|
|
||||||
|
/// Optional override: force a specific chart type.
|
||||||
|
public var chartType: ChartDataDetector.ChartType?
|
||||||
|
|
||||||
|
/// The detector used to analyze the data.
|
||||||
|
private let detector: ChartDataDetector
|
||||||
|
|
||||||
|
public init(
|
||||||
|
dataTable: DataTable,
|
||||||
|
chartType: ChartDataDetector.ChartType? = nil,
|
||||||
|
detector: ChartDataDetector = ChartDataDetector()
|
||||||
|
) {
|
||||||
|
self.dataTable = dataTable
|
||||||
|
self.chartType = chartType
|
||||||
|
self.detector = detector
|
||||||
|
}
|
||||||
|
|
||||||
|
public var body: some View {
|
||||||
|
if let recommendation = resolvedRecommendation {
|
||||||
|
VStack(alignment: .leading, spacing: 4) {
|
||||||
|
chartView(for: recommendation)
|
||||||
|
|
||||||
|
Text(recommendation.reason)
|
||||||
|
.font(.caption2)
|
||||||
|
.foregroundStyle(.tertiary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Chart Selection
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private func chartView(
|
||||||
|
for recommendation: ChartDataDetector.ChartRecommendation
|
||||||
|
) -> some View {
|
||||||
|
switch recommendation.chartType {
|
||||||
|
case .bar:
|
||||||
|
BarChartView(
|
||||||
|
dataTable: dataTable,
|
||||||
|
categoryColumn: recommendation.categoryColumn,
|
||||||
|
valueColumn: recommendation.valueColumn
|
||||||
|
)
|
||||||
|
case .line:
|
||||||
|
LineChartView(
|
||||||
|
dataTable: dataTable,
|
||||||
|
categoryColumn: recommendation.categoryColumn,
|
||||||
|
valueColumn: recommendation.valueColumn
|
||||||
|
)
|
||||||
|
case .pie:
|
||||||
|
PieChartView(
|
||||||
|
dataTable: dataTable,
|
||||||
|
categoryColumn: recommendation.categoryColumn,
|
||||||
|
valueColumn: recommendation.valueColumn
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Resolution
|
||||||
|
|
||||||
|
/// Resolves the chart recommendation, using the override type if provided.
|
||||||
|
private var resolvedRecommendation: ChartDataDetector.ChartRecommendation? {
|
||||||
|
if let override = chartType {
|
||||||
|
// Use forced chart type — still need column pair from detector
|
||||||
|
let all = detector.allRecommendations(for: dataTable)
|
||||||
|
// Try to find recommendation for the forced type
|
||||||
|
if let match = all.first(where: { $0.chartType == override }) {
|
||||||
|
return match
|
||||||
|
}
|
||||||
|
// Fallback: use first recommendation and override its type
|
||||||
|
if let first = all.first {
|
||||||
|
return ChartDataDetector.ChartRecommendation(
|
||||||
|
chartType: override,
|
||||||
|
categoryColumn: first.categoryColumn,
|
||||||
|
valueColumn: first.valueColumn,
|
||||||
|
confidence: first.confidence * 0.8,
|
||||||
|
reason: first.reason
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-detect best chart type
|
||||||
|
return detector.detect(dataTable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Preview
|
||||||
|
|
||||||
|
#if DEBUG
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
#Preview("Auto Chart — Bar") {
|
||||||
|
let columns: [DataTable.Column] = [
|
||||||
|
.init(name: "city", index: 0, inferredType: .text),
|
||||||
|
.init(name: "population", index: 1, inferredType: .integer),
|
||||||
|
]
|
||||||
|
let cities = ["NYC", "LA", "Chicago", "Houston", "Phoenix"]
|
||||||
|
let pops: [Int64] = [8_336_817, 3_979_576, 2_693_976, 2_320_268, 1_680_992]
|
||||||
|
let rows: [DataTable.Row] = cities.enumerated().map { i, city in
|
||||||
|
DataTable.Row(
|
||||||
|
id: i,
|
||||||
|
values: [.text(city), .integer(pops[i])],
|
||||||
|
columnNames: ["city", "population"]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let table = DataTable(columns: columns, rows: rows)
|
||||||
|
|
||||||
|
ChartResultView(dataTable: table)
|
||||||
|
.padding()
|
||||||
|
.frame(height: 300)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
206
Sources/SwiftDBAI/Views/Charts/LineChartView.swift
Normal file
206
Sources/SwiftDBAI/Views/Charts/LineChartView.swift
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
// LineChartView.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// A SwiftUI line chart that renders DataTable values using Swift Charts.
|
||||||
|
// Best for time series or sequential data (e.g., revenue over months).
|
||||||
|
|
||||||
|
import SwiftUI
|
||||||
|
import Charts
|
||||||
|
|
||||||
|
/// A line chart view that renders a `DataTable` column pair using Swift Charts.
|
||||||
|
///
|
||||||
|
/// Displays a connected line with optional area fill, point markers,
|
||||||
|
/// and smooth interpolation. Best suited for time series or sequential data.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// LineChartView(
|
||||||
|
/// dataTable: table,
|
||||||
|
/// categoryColumn: "month",
|
||||||
|
/// valueColumn: "revenue"
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
public struct LineChartView: View {
|
||||||
|
|
||||||
|
/// The data to chart.
|
||||||
|
public let dataTable: DataTable
|
||||||
|
|
||||||
|
/// Column name for category/time labels (x-axis).
|
||||||
|
public let categoryColumn: String
|
||||||
|
|
||||||
|
/// Column name for numeric values (y-axis).
|
||||||
|
public let valueColumn: String
|
||||||
|
|
||||||
|
/// Optional chart title.
|
||||||
|
public var title: String?
|
||||||
|
|
||||||
|
/// Whether to show an area fill below the line.
|
||||||
|
public var showAreaFill: Bool
|
||||||
|
|
||||||
|
/// Whether to show point markers at each data point.
|
||||||
|
public var showPoints: Bool
|
||||||
|
|
||||||
|
/// Maximum data points to display.
|
||||||
|
public var maxPoints: Int
|
||||||
|
|
||||||
|
public init(
|
||||||
|
dataTable: DataTable,
|
||||||
|
categoryColumn: String,
|
||||||
|
valueColumn: String,
|
||||||
|
title: String? = nil,
|
||||||
|
showAreaFill: Bool = true,
|
||||||
|
showPoints: Bool = true,
|
||||||
|
maxPoints: Int = 100
|
||||||
|
) {
|
||||||
|
self.dataTable = dataTable
|
||||||
|
self.categoryColumn = categoryColumn
|
||||||
|
self.valueColumn = valueColumn
|
||||||
|
self.title = title
|
||||||
|
self.showAreaFill = showAreaFill
|
||||||
|
self.showPoints = showPoints
|
||||||
|
self.maxPoints = maxPoints
|
||||||
|
}
|
||||||
|
|
||||||
|
public var body: some View {
|
||||||
|
VStack(alignment: .leading, spacing: 8) {
|
||||||
|
if let title {
|
||||||
|
Text(title)
|
||||||
|
.font(.caption.weight(.semibold))
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chartData.isEmpty {
|
||||||
|
emptyChartView
|
||||||
|
} else {
|
||||||
|
chartContent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Chart Content
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var chartContent: some View {
|
||||||
|
Chart(chartData, id: \.label) { item in
|
||||||
|
LineMark(
|
||||||
|
x: .value(categoryColumn, item.label),
|
||||||
|
y: .value(valueColumn, item.value)
|
||||||
|
)
|
||||||
|
.foregroundStyle(Color.accentColor)
|
||||||
|
.lineStyle(StrokeStyle(lineWidth: 2))
|
||||||
|
.interpolationMethod(.catmullRom)
|
||||||
|
|
||||||
|
if showAreaFill {
|
||||||
|
AreaMark(
|
||||||
|
x: .value(categoryColumn, item.label),
|
||||||
|
y: .value(valueColumn, item.value)
|
||||||
|
)
|
||||||
|
.foregroundStyle(
|
||||||
|
.linearGradient(
|
||||||
|
colors: [
|
||||||
|
Color.accentColor.opacity(0.2),
|
||||||
|
Color.accentColor.opacity(0.02),
|
||||||
|
],
|
||||||
|
startPoint: .top,
|
||||||
|
endPoint: .bottom
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.interpolationMethod(.catmullRom)
|
||||||
|
}
|
||||||
|
|
||||||
|
if showPoints {
|
||||||
|
PointMark(
|
||||||
|
x: .value(categoryColumn, item.label),
|
||||||
|
y: .value(valueColumn, item.value)
|
||||||
|
)
|
||||||
|
.foregroundStyle(Color.accentColor)
|
||||||
|
.symbolSize(30)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.chartXAxis {
|
||||||
|
AxisMarks(values: .automatic) { _ in
|
||||||
|
AxisValueLabel()
|
||||||
|
.font(.caption2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.chartYAxis {
|
||||||
|
AxisMarks(position: .leading) { _ in
|
||||||
|
AxisGridLine(stroke: StrokeStyle(lineWidth: 0.5, dash: [4, 4]))
|
||||||
|
.foregroundStyle(.secondary.opacity(0.3))
|
||||||
|
AxisValueLabel()
|
||||||
|
.font(.caption2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.frame(minHeight: 200)
|
||||||
|
|
||||||
|
if isTruncated {
|
||||||
|
Text("Showing \(maxPoints) of \(dataTable.rowCount) data points")
|
||||||
|
.font(.caption2)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Empty State
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var emptyChartView: some View {
|
||||||
|
VStack(spacing: 8) {
|
||||||
|
Image(systemName: "chart.xyaxis.line")
|
||||||
|
.font(.title2)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
Text("No chartable data")
|
||||||
|
.font(.caption)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
.frame(maxWidth: .infinity, minHeight: 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Data Extraction
|
||||||
|
|
||||||
|
private var isTruncated: Bool {
|
||||||
|
dataTable.rowCount > maxPoints
|
||||||
|
}
|
||||||
|
|
||||||
|
private var chartData: [ChartDataPoint] {
|
||||||
|
let labels = dataTable.stringValues(forColumn: categoryColumn)
|
||||||
|
let values = dataTable.numericValues(forColumn: valueColumn)
|
||||||
|
|
||||||
|
let count = min(labels.count, values.count, maxPoints)
|
||||||
|
guard count > 0 else { return [] }
|
||||||
|
|
||||||
|
return (0..<count).map { i in
|
||||||
|
ChartDataPoint(label: labels[i], value: values[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Preview
|
||||||
|
|
||||||
|
#if DEBUG
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
#Preview("Line Chart") {
|
||||||
|
let columns: [DataTable.Column] = [
|
||||||
|
.init(name: "month", index: 0, inferredType: .text),
|
||||||
|
.init(name: "revenue", index: 1, inferredType: .real),
|
||||||
|
]
|
||||||
|
let months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun"]
|
||||||
|
let rows: [DataTable.Row] = months.enumerated().map { i, month in
|
||||||
|
DataTable.Row(
|
||||||
|
id: i,
|
||||||
|
values: [.text(month), .real(Double(i + 1) * 15_000 + Double.random(in: -3000...3000))],
|
||||||
|
columnNames: ["month", "revenue"]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let table = DataTable(columns: columns, rows: rows)
|
||||||
|
|
||||||
|
LineChartView(
|
||||||
|
dataTable: table,
|
||||||
|
categoryColumn: "month",
|
||||||
|
valueColumn: "revenue",
|
||||||
|
title: "Monthly Revenue"
|
||||||
|
)
|
||||||
|
.padding()
|
||||||
|
.frame(height: 300)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
234
Sources/SwiftDBAI/Views/Charts/PieChartView.swift
Normal file
234
Sources/SwiftDBAI/Views/Charts/PieChartView.swift
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
// PieChartView.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// A SwiftUI pie/donut chart that renders DataTable values using Swift Charts.
|
||||||
|
// Best for proportional breakdowns with few categories (e.g., market share).
|
||||||
|
|
||||||
|
import SwiftUI
|
||||||
|
import Charts
|
||||||
|
|
||||||
|
/// A pie chart view that renders a `DataTable` column pair using Swift Charts.
|
||||||
|
///
|
||||||
|
/// Displays proportional slices with category labels. Each slice is
|
||||||
|
/// automatically colored from a curated palette and sized relative to
|
||||||
|
/// its proportion of the total. Best suited for data with few categories
|
||||||
|
/// (≤ 8) where all values are positive.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// PieChartView(
|
||||||
|
/// dataTable: table,
|
||||||
|
/// categoryColumn: "status",
|
||||||
|
/// valueColumn: "count"
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
public struct PieChartView: View {
|
||||||
|
|
||||||
|
/// The data to chart.
|
||||||
|
public let dataTable: DataTable
|
||||||
|
|
||||||
|
/// Column name for category labels (slice labels).
|
||||||
|
public let categoryColumn: String
|
||||||
|
|
||||||
|
/// Column name for numeric values (slice sizes).
|
||||||
|
public let valueColumn: String
|
||||||
|
|
||||||
|
/// Optional chart title.
|
||||||
|
public var title: String?
|
||||||
|
|
||||||
|
/// Inner radius ratio for donut style (0 = full pie, >0 = donut).
|
||||||
|
public var innerRadiusRatio: CGFloat
|
||||||
|
|
||||||
|
/// Maximum number of slices before grouping remaining into "Other".
|
||||||
|
public var maxSlices: Int
|
||||||
|
|
||||||
|
public init(
|
||||||
|
dataTable: DataTable,
|
||||||
|
categoryColumn: String,
|
||||||
|
valueColumn: String,
|
||||||
|
title: String? = nil,
|
||||||
|
innerRadiusRatio: CGFloat = 0.4,
|
||||||
|
maxSlices: Int = 8
|
||||||
|
) {
|
||||||
|
self.dataTable = dataTable
|
||||||
|
self.categoryColumn = categoryColumn
|
||||||
|
self.valueColumn = valueColumn
|
||||||
|
self.title = title
|
||||||
|
self.innerRadiusRatio = innerRadiusRatio
|
||||||
|
self.maxSlices = maxSlices
|
||||||
|
}
|
||||||
|
|
||||||
|
public var body: some View {
|
||||||
|
VStack(alignment: .leading, spacing: 8) {
|
||||||
|
if let title {
|
||||||
|
Text(title)
|
||||||
|
.font(.caption.weight(.semibold))
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
|
||||||
|
if chartData.isEmpty {
|
||||||
|
emptyChartView
|
||||||
|
} else {
|
||||||
|
HStack(alignment: .center, spacing: 16) {
|
||||||
|
chartContent
|
||||||
|
legendView
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Chart Content
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var chartContent: some View {
|
||||||
|
Chart(chartData, id: \.label) { item in
|
||||||
|
SectorMark(
|
||||||
|
angle: .value(valueColumn, item.value),
|
||||||
|
innerRadius: .ratio(innerRadiusRatio),
|
||||||
|
angularInset: 1.5
|
||||||
|
)
|
||||||
|
.foregroundStyle(by: .value(categoryColumn, item.label))
|
||||||
|
.cornerRadius(3)
|
||||||
|
}
|
||||||
|
.chartForegroundStyleScale(
|
||||||
|
domain: chartData.map(\.label),
|
||||||
|
range: sliceColors
|
||||||
|
)
|
||||||
|
.chartLegend(.hidden)
|
||||||
|
.frame(minWidth: 150, minHeight: 150)
|
||||||
|
.aspectRatio(1, contentMode: .fit)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Legend
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var legendView: some View {
|
||||||
|
VStack(alignment: .leading, spacing: 6) {
|
||||||
|
ForEach(Array(chartData.enumerated()), id: \.element.label) { index, item in
|
||||||
|
HStack(spacing: 8) {
|
||||||
|
Circle()
|
||||||
|
.fill(sliceColors[index % sliceColors.count])
|
||||||
|
.frame(width: 8, height: 8)
|
||||||
|
|
||||||
|
Text(item.label)
|
||||||
|
.font(.caption)
|
||||||
|
.foregroundStyle(.primary)
|
||||||
|
.lineLimit(1)
|
||||||
|
|
||||||
|
Spacer()
|
||||||
|
|
||||||
|
Text(percentageText(for: item.value))
|
||||||
|
.font(.caption)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
.monospacedDigit()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Empty State
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var emptyChartView: some View {
|
||||||
|
VStack(spacing: 8) {
|
||||||
|
Image(systemName: "chart.pie")
|
||||||
|
.font(.title2)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
Text("No chartable data")
|
||||||
|
.font(.caption)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
.frame(maxWidth: .infinity, minHeight: 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Colors
|
||||||
|
|
||||||
|
/// Curated color palette for pie slices.
|
||||||
|
private var sliceColors: [Color] {
|
||||||
|
[
|
||||||
|
.blue,
|
||||||
|
.green,
|
||||||
|
.orange,
|
||||||
|
.purple,
|
||||||
|
.pink,
|
||||||
|
.cyan,
|
||||||
|
.yellow,
|
||||||
|
.indigo,
|
||||||
|
.mint,
|
||||||
|
.teal,
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Helpers
|
||||||
|
|
||||||
|
private var total: Double {
|
||||||
|
chartData.reduce(0) { $0 + $1.value }
|
||||||
|
}
|
||||||
|
|
||||||
|
private func percentageText(for value: Double) -> String {
|
||||||
|
guard total > 0 else { return "0%" }
|
||||||
|
let pct = (value / total) * 100
|
||||||
|
if pct >= 10 {
|
||||||
|
return String(format: "%.0f%%", pct)
|
||||||
|
}
|
||||||
|
return String(format: "%.1f%%", pct)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Data Extraction
|
||||||
|
|
||||||
|
private var chartData: [ChartDataPoint] {
|
||||||
|
let labels = dataTable.stringValues(forColumn: categoryColumn)
|
||||||
|
let values = dataTable.numericValues(forColumn: valueColumn)
|
||||||
|
|
||||||
|
let count = min(labels.count, values.count)
|
||||||
|
guard count > 0 else { return [] }
|
||||||
|
|
||||||
|
// Build all points, sorted by value descending
|
||||||
|
var points = (0..<count).map { i in
|
||||||
|
ChartDataPoint(label: labels[i], value: values[i])
|
||||||
|
}
|
||||||
|
.filter { $0.value > 0 }
|
||||||
|
.sorted { $0.value > $1.value }
|
||||||
|
|
||||||
|
// Group excess slices into "Other"
|
||||||
|
if points.count > maxSlices {
|
||||||
|
let kept = Array(points.prefix(maxSlices - 1))
|
||||||
|
let otherValue = points.dropFirst(maxSlices - 1).reduce(0) { $0 + $1.value }
|
||||||
|
points = kept + [ChartDataPoint(label: "Other", value: otherValue)]
|
||||||
|
}
|
||||||
|
|
||||||
|
return points
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Preview
|
||||||
|
|
||||||
|
#if DEBUG
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
#Preview("Pie Chart") {
|
||||||
|
let columns: [DataTable.Column] = [
|
||||||
|
.init(name: "status", index: 0, inferredType: .text),
|
||||||
|
.init(name: "count", index: 1, inferredType: .integer),
|
||||||
|
]
|
||||||
|
let statuses = ["Active", "Inactive", "Pending", "Archived"]
|
||||||
|
let counts: [Int64] = [45, 20, 15, 10]
|
||||||
|
let rows: [DataTable.Row] = statuses.enumerated().map { i, status in
|
||||||
|
DataTable.Row(
|
||||||
|
id: i,
|
||||||
|
values: [.text(status), .integer(counts[i])],
|
||||||
|
columnNames: ["status", "count"]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let table = DataTable(columns: columns, rows: rows)
|
||||||
|
|
||||||
|
PieChartView(
|
||||||
|
dataTable: table,
|
||||||
|
categoryColumn: "status",
|
||||||
|
valueColumn: "count",
|
||||||
|
title: "Users by Status"
|
||||||
|
)
|
||||||
|
.padding()
|
||||||
|
.frame(height: 250)
|
||||||
|
}
|
||||||
|
#endif
|
||||||
214
Sources/SwiftDBAI/Views/ChatView.swift
Normal file
214
Sources/SwiftDBAI/Views/ChatView.swift
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
// ChatView.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Drop-in SwiftUI view for chatting with a SQLite database.
|
||||||
|
// Renders messages with automatic data table display for query results.
|
||||||
|
|
||||||
|
import SwiftUI
|
||||||
|
|
||||||
|
/// A drop-in SwiftUI chat interface for querying SQLite databases
|
||||||
|
/// with natural language.
|
||||||
|
///
|
||||||
|
/// `ChatView` renders the full conversation including:
|
||||||
|
/// - User messages (right-aligned, accent-colored)
|
||||||
|
/// - Assistant responses with text summaries
|
||||||
|
/// - **Automatic data tables** via `ScrollableDataTableView` when query results
|
||||||
|
/// contain tabular data (rows + columns)
|
||||||
|
/// - SQL query disclosure for transparency
|
||||||
|
/// - Error messages with red styling
|
||||||
|
/// - A loading indicator while the engine is processing
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// let engine = ChatEngine(database: myPool, model: myModel)
|
||||||
|
/// let viewModel = ChatViewModel(engine: engine)
|
||||||
|
///
|
||||||
|
/// ChatView(viewModel: viewModel)
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Or use the convenience initializer:
|
||||||
|
/// ```swift
|
||||||
|
/// ChatView(engine: myEngine)
|
||||||
|
/// ```
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
public struct ChatView: View {
|
||||||
|
@Bindable private var viewModel: ChatViewModel
|
||||||
|
@State private var inputText: String = ""
|
||||||
|
@FocusState private var isInputFocused: Bool
|
||||||
|
|
||||||
|
/// Creates a ChatView with an existing view model.
|
||||||
|
///
|
||||||
|
/// - Parameter viewModel: The `ChatViewModel` driving this view.
|
||||||
|
public init(viewModel: ChatViewModel) {
|
||||||
|
self.viewModel = viewModel
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a ChatView with a `ChatEngine`, automatically creating
|
||||||
|
/// a `ChatViewModel`.
|
||||||
|
///
|
||||||
|
/// - Parameter engine: The `ChatEngine` to power the chat.
|
||||||
|
public init(engine: ChatEngine) {
|
||||||
|
self.viewModel = ChatViewModel(engine: engine)
|
||||||
|
}
|
||||||
|
|
||||||
|
public var body: some View {
|
||||||
|
VStack(spacing: 0) {
|
||||||
|
messageList
|
||||||
|
Divider()
|
||||||
|
inputBar
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Message List
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var messageList: some View {
|
||||||
|
ScrollViewReader { proxy in
|
||||||
|
ScrollView {
|
||||||
|
LazyVStack(spacing: 12) {
|
||||||
|
if viewModel.messages.isEmpty {
|
||||||
|
emptyState
|
||||||
|
}
|
||||||
|
|
||||||
|
ForEach(viewModel.messages) { message in
|
||||||
|
messageBubble(for: message)
|
||||||
|
.id(message.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if viewModel.isLoading {
|
||||||
|
loadingIndicator
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.padding(.horizontal, 16)
|
||||||
|
.padding(.vertical, 12)
|
||||||
|
}
|
||||||
|
.onChange(of: viewModel.messages.count) { _, _ in
|
||||||
|
if let lastMessage = viewModel.messages.last {
|
||||||
|
withAnimation(.easeOut(duration: 0.3)) {
|
||||||
|
proxy.scrollTo(lastMessage.id, anchor: .bottom)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Empty State
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var emptyState: some View {
|
||||||
|
VStack(spacing: 12) {
|
||||||
|
Image(systemName: "bubble.left.and.text.bubble.right")
|
||||||
|
.font(.system(size: 40))
|
||||||
|
.foregroundStyle(.tertiary)
|
||||||
|
Text("Ask a question about your data")
|
||||||
|
.font(.headline)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
Text("Try something like \"How many records are in the database?\"")
|
||||||
|
.font(.subheadline)
|
||||||
|
.foregroundStyle(.tertiary)
|
||||||
|
.multilineTextAlignment(.center)
|
||||||
|
}
|
||||||
|
.frame(maxWidth: .infinity)
|
||||||
|
.padding(.vertical, 60)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Loading Indicator
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var loadingIndicator: some View {
|
||||||
|
HStack(alignment: .top) {
|
||||||
|
HStack(spacing: 8) {
|
||||||
|
ProgressView()
|
||||||
|
.controlSize(.small)
|
||||||
|
Text("Querying…")
|
||||||
|
.font(.callout)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
.padding(.horizontal, 14)
|
||||||
|
.padding(.vertical, 10)
|
||||||
|
.background(
|
||||||
|
Self.assistantBackgroundColor,
|
||||||
|
in: RoundedRectangle(cornerRadius: 16, style: .continuous)
|
||||||
|
)
|
||||||
|
|
||||||
|
Spacer(minLength: 48)
|
||||||
|
}
|
||||||
|
.id("loading-indicator")
|
||||||
|
.transition(.opacity.combined(with: .move(edge: .bottom)))
|
||||||
|
}
|
||||||
|
|
||||||
|
private static var assistantBackgroundColor: Color {
|
||||||
|
#if os(macOS)
|
||||||
|
Color(nsColor: .controlBackgroundColor)
|
||||||
|
#else
|
||||||
|
Color(uiColor: .secondarySystemGroupedBackground)
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Input Bar
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var inputBar: some View {
|
||||||
|
HStack(spacing: 8) {
|
||||||
|
TextField("Ask about your data…", text: $inputText, axis: .vertical)
|
||||||
|
.textFieldStyle(.plain)
|
||||||
|
.lineLimit(1...5)
|
||||||
|
.focused($isInputFocused)
|
||||||
|
.onSubmit { sendMessage() }
|
||||||
|
.submitLabel(.send)
|
||||||
|
|
||||||
|
Button(action: sendMessage) {
|
||||||
|
Image(systemName: "arrow.up.circle.fill")
|
||||||
|
.font(.title2)
|
||||||
|
.foregroundStyle(canSend ? Color.accentColor : Color.secondary)
|
||||||
|
}
|
||||||
|
.disabled(!canSend)
|
||||||
|
.keyboardShortcut(.return, modifiers: .command)
|
||||||
|
}
|
||||||
|
.padding(.horizontal, 16)
|
||||||
|
.padding(.vertical, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Message Bubble
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private func messageBubble(for message: ChatMessage) -> some View {
|
||||||
|
if message.role == .error {
|
||||||
|
MessageBubbleView(
|
||||||
|
message: message,
|
||||||
|
onRetry: makeRetryAction(for: message)
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
MessageBubbleView(message: message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private func makeRetryAction(for errorMessage: ChatMessage) -> @Sendable () async -> Void {
|
||||||
|
let vm = viewModel
|
||||||
|
let messageId = errorMessage.id
|
||||||
|
return { @MainActor [vm] in
|
||||||
|
let allMessages = await MainActor.run { vm.messages }
|
||||||
|
if let lastUserMessage = allMessages
|
||||||
|
.prefix(while: { $0.id != messageId })
|
||||||
|
.last(where: { $0.role == .user }) {
|
||||||
|
await vm.send(lastUserMessage.content)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Helpers
|
||||||
|
|
||||||
|
private var canSend: Bool {
|
||||||
|
!inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty && !viewModel.isLoading
|
||||||
|
}
|
||||||
|
|
||||||
|
private func sendMessage() {
|
||||||
|
guard canSend else { return }
|
||||||
|
let text = inputText
|
||||||
|
inputText = ""
|
||||||
|
|
||||||
|
Task {
|
||||||
|
await viewModel.send(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
137
Sources/SwiftDBAI/Views/ChatViewModel.swift
Normal file
137
Sources/SwiftDBAI/Views/ChatViewModel.swift
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
// ChatViewModel.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Observable view model that bridges ChatEngine with the SwiftUI ChatView.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import Observation
|
||||||
|
|
||||||
|
/// The readiness state of the schema introspection.
|
||||||
|
public enum SchemaReadiness: Sendable, Equatable {
|
||||||
|
/// Schema has not been loaded yet.
|
||||||
|
case idle
|
||||||
|
/// Schema introspection is in progress.
|
||||||
|
case loading
|
||||||
|
/// Schema is ready with the given number of tables.
|
||||||
|
case ready(tableCount: Int)
|
||||||
|
/// Schema introspection failed.
|
||||||
|
case failed(String)
|
||||||
|
|
||||||
|
public var isReady: Bool {
|
||||||
|
if case .ready = self { return true }
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Observable view model that drives the `ChatView`.
|
||||||
|
///
|
||||||
|
/// Wraps `ChatEngine` to provide reactive state updates for the SwiftUI layer.
|
||||||
|
/// Manages the message list, loading state, error presentation, and schema
|
||||||
|
/// readiness. Call ``prepare()`` at view-appear time to eagerly introspect the
|
||||||
|
/// database schema.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// let viewModel = ChatViewModel(engine: myChatEngine)
|
||||||
|
/// ChatView(viewModel: viewModel)
|
||||||
|
/// ```
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
@Observable
|
||||||
|
@MainActor
|
||||||
|
public final class ChatViewModel {
|
||||||
|
|
||||||
|
// MARK: - Public State
|
||||||
|
|
||||||
|
/// All messages in the conversation, in chronological order.
|
||||||
|
public private(set) var messages: [ChatMessage] = []
|
||||||
|
|
||||||
|
/// Whether the engine is currently processing a request.
|
||||||
|
public private(set) var isLoading: Bool = false
|
||||||
|
|
||||||
|
/// The most recent error message, if any. Cleared on next send.
|
||||||
|
public private(set) var errorMessage: String?
|
||||||
|
|
||||||
|
/// Current schema readiness state.
|
||||||
|
public private(set) var schemaReadiness: SchemaReadiness = .idle
|
||||||
|
|
||||||
|
// MARK: - Dependencies
|
||||||
|
|
||||||
|
private let engine: ChatEngine
|
||||||
|
|
||||||
|
// MARK: - Initialization
|
||||||
|
|
||||||
|
/// Creates a new ChatViewModel.
|
||||||
|
///
|
||||||
|
/// - Parameter engine: The `ChatEngine` to use for processing messages.
|
||||||
|
public init(engine: ChatEngine) {
|
||||||
|
self.engine = engine
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Schema Preparation
|
||||||
|
|
||||||
|
/// Eagerly introspects the database schema so it's ready before the first query.
|
||||||
|
///
|
||||||
|
/// This should be called from a `.task` modifier on the view. It transitions
|
||||||
|
/// `schemaReadiness` through `.loading` → `.ready` (or `.failed`).
|
||||||
|
/// If the schema is already cached, this completes immediately.
|
||||||
|
public func prepare() async {
|
||||||
|
// Don't re-prepare if already ready
|
||||||
|
if schemaReadiness.isReady { return }
|
||||||
|
|
||||||
|
schemaReadiness = .loading
|
||||||
|
|
||||||
|
do {
|
||||||
|
let schema = try await engine.prepareSchema()
|
||||||
|
schemaReadiness = .ready(tableCount: schema.tableNames.count)
|
||||||
|
} catch {
|
||||||
|
schemaReadiness = .failed(error.localizedDescription)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Public API
|
||||||
|
|
||||||
|
/// Sends a user message and appends the response to the conversation.
|
||||||
|
///
|
||||||
|
/// - Parameter text: The natural language message from the user.
|
||||||
|
public func send(_ text: String) async {
|
||||||
|
let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||||
|
guard !trimmed.isEmpty else { return }
|
||||||
|
|
||||||
|
errorMessage = nil
|
||||||
|
|
||||||
|
// Add user message immediately
|
||||||
|
let userMessage = ChatMessage(role: .user, content: trimmed)
|
||||||
|
messages.append(userMessage)
|
||||||
|
|
||||||
|
isLoading = true
|
||||||
|
defer { isLoading = false }
|
||||||
|
|
||||||
|
do {
|
||||||
|
let response = try await engine.send(trimmed)
|
||||||
|
|
||||||
|
let assistantMessage = ChatMessage(
|
||||||
|
role: .assistant,
|
||||||
|
content: response.summary,
|
||||||
|
queryResult: response.queryResult,
|
||||||
|
sql: response.sql
|
||||||
|
)
|
||||||
|
messages.append(assistantMessage)
|
||||||
|
} catch {
|
||||||
|
let typedError = (error as? SwiftDBAIError)
|
||||||
|
let errorMsg = ChatMessage(
|
||||||
|
role: .error,
|
||||||
|
content: error.localizedDescription,
|
||||||
|
error: typedError
|
||||||
|
)
|
||||||
|
messages.append(errorMsg)
|
||||||
|
errorMessage = error.localizedDescription
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clears the conversation and resets the engine state.
|
||||||
|
public func reset() {
|
||||||
|
messages.removeAll()
|
||||||
|
errorMessage = nil
|
||||||
|
engine.reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
220
Sources/SwiftDBAI/Views/DataChatView.swift
Normal file
220
Sources/SwiftDBAI/Views/DataChatView.swift
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
// DataChatView.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Zero-config SwiftUI view: provide a database path and a model, get a chat UI.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import GRDB
|
||||||
|
import SwiftUI
|
||||||
|
|
||||||
|
/// A convenience SwiftUI view that wraps the full chat-with-database stack.
|
||||||
|
///
|
||||||
|
/// `DataChatView` is the simplest entry point into SwiftDBAI. It requires only
|
||||||
|
/// a database file path and a language model — no schema files, no annotations,
|
||||||
|
/// no manual setup. The view creates a GRDB connection, a `ChatEngine`,
|
||||||
|
/// a `ChatViewModel`, and renders a fully functional `ChatView`.
|
||||||
|
///
|
||||||
|
/// Usage with just a path and model:
|
||||||
|
/// ```swift
|
||||||
|
/// DataChatView(
|
||||||
|
/// databasePath: "/path/to/mydata.sqlite",
|
||||||
|
/// model: OllamaLanguageModel(model: "llama3")
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Usage with additional configuration:
|
||||||
|
/// ```swift
|
||||||
|
/// DataChatView(
|
||||||
|
/// databasePath: documentsURL.appendingPathComponent("app.db").path,
|
||||||
|
/// model: OpenAILanguageModel(apiKey: key),
|
||||||
|
/// allowlist: .standard,
|
||||||
|
/// additionalContext: "This database stores a recipe app's data."
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// If you already have a GRDB `DatabasePool` or `DatabaseQueue`, use
|
||||||
|
/// `ChatView` with a `ChatEngine` directly for full control.
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
public struct DataChatView: View {
|
||||||
|
@State private var viewModel: ChatViewModel
|
||||||
|
@State private var loadError: DataChatError?
|
||||||
|
|
||||||
|
/// Creates a DataChatView from a database file path and language model.
|
||||||
|
///
|
||||||
|
/// This is the zero-config convenience initializer. It opens a GRDB
|
||||||
|
/// `DatabasePool` at the given path, creates a `ChatEngine` with
|
||||||
|
/// read-only defaults, and wires up the full chat UI.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - databasePath: Absolute path to a SQLite database file.
|
||||||
|
/// - model: Any `AnyLanguageModel`-compatible language model instance.
|
||||||
|
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly` (SELECT only).
|
||||||
|
/// - additionalContext: Optional extra context about the database for the LLM system prompt
|
||||||
|
/// (e.g., "This database stores e-commerce orders and products.").
|
||||||
|
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
|
||||||
|
public init(
|
||||||
|
databasePath: String,
|
||||||
|
model: any LanguageModel,
|
||||||
|
allowlist: OperationAllowlist = .readOnly,
|
||||||
|
additionalContext: String? = nil,
|
||||||
|
maxSummaryRows: Int = 50
|
||||||
|
) {
|
||||||
|
do {
|
||||||
|
let pool = try DatabasePool(path: databasePath)
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: pool,
|
||||||
|
model: model,
|
||||||
|
allowlist: allowlist,
|
||||||
|
additionalContext: additionalContext,
|
||||||
|
maxSummaryRows: maxSummaryRows
|
||||||
|
)
|
||||||
|
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
|
||||||
|
self._loadError = State(initialValue: nil)
|
||||||
|
} catch {
|
||||||
|
// If the database can't be opened, create a placeholder engine
|
||||||
|
// and store the error to display in the UI.
|
||||||
|
let queue = try! DatabaseQueue()
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: queue,
|
||||||
|
model: model,
|
||||||
|
allowlist: allowlist,
|
||||||
|
additionalContext: additionalContext,
|
||||||
|
maxSummaryRows: maxSummaryRows
|
||||||
|
)
|
||||||
|
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
|
||||||
|
self._loadError = State(initialValue: DataChatError.databaseOpenFailed(
|
||||||
|
path: databasePath,
|
||||||
|
underlying: error
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a DataChatView from an existing GRDB database connection and language model.
|
||||||
|
///
|
||||||
|
/// Use this initializer when you already have a configured `DatabasePool` or
|
||||||
|
/// `DatabaseQueue` and want the convenience of `DataChatView` without
|
||||||
|
/// creating a `ChatEngine` yourself.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - database: A GRDB `DatabaseWriter` (`DatabasePool` or `DatabaseQueue`).
|
||||||
|
/// - model: Any `AnyLanguageModel`-compatible language model instance.
|
||||||
|
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly`.
|
||||||
|
/// - additionalContext: Optional extra context about the database for the LLM.
|
||||||
|
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
|
||||||
|
public init(
|
||||||
|
database: any DatabaseWriter,
|
||||||
|
model: any LanguageModel,
|
||||||
|
allowlist: OperationAllowlist = .readOnly,
|
||||||
|
additionalContext: String? = nil,
|
||||||
|
maxSummaryRows: Int = 50
|
||||||
|
) {
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: database,
|
||||||
|
model: model,
|
||||||
|
allowlist: allowlist,
|
||||||
|
additionalContext: additionalContext,
|
||||||
|
maxSummaryRows: maxSummaryRows
|
||||||
|
)
|
||||||
|
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
|
||||||
|
self._loadError = State(initialValue: nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
public var body: some View {
|
||||||
|
if let error = loadError {
|
||||||
|
errorView(error)
|
||||||
|
} else {
|
||||||
|
ChatView(viewModel: viewModel)
|
||||||
|
.task {
|
||||||
|
await viewModel.prepare()
|
||||||
|
}
|
||||||
|
.overlay {
|
||||||
|
if case .loading = viewModel.schemaReadiness {
|
||||||
|
schemaLoadingView
|
||||||
|
}
|
||||||
|
if case .failed(let reason) = viewModel.schemaReadiness {
|
||||||
|
schemaErrorView(reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Schema Loading View
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var schemaLoadingView: some View {
|
||||||
|
VStack(spacing: 16) {
|
||||||
|
ProgressView()
|
||||||
|
.controlSize(.large)
|
||||||
|
Text("Introspecting database schema…")
|
||||||
|
.font(.subheadline)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
.frame(maxWidth: .infinity, maxHeight: .infinity)
|
||||||
|
.background(.ultraThinMaterial)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Schema Error View
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private func schemaErrorView(_ reason: String) -> some View {
|
||||||
|
VStack(spacing: 16) {
|
||||||
|
Image(systemName: "exclamationmark.triangle.fill")
|
||||||
|
.font(.system(size: 40))
|
||||||
|
.foregroundStyle(.orange)
|
||||||
|
|
||||||
|
Text("Schema Introspection Failed")
|
||||||
|
.font(.headline)
|
||||||
|
|
||||||
|
Text(reason)
|
||||||
|
.font(.subheadline)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
.multilineTextAlignment(.center)
|
||||||
|
.padding(.horizontal, 32)
|
||||||
|
|
||||||
|
Button("Retry") {
|
||||||
|
Task {
|
||||||
|
await viewModel.prepare()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.buttonStyle(.borderedProminent)
|
||||||
|
}
|
||||||
|
.frame(maxWidth: .infinity, maxHeight: .infinity)
|
||||||
|
.background(.ultraThinMaterial)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Database Open Error View
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private func errorView(_ error: DataChatError) -> some View {
|
||||||
|
VStack(spacing: 16) {
|
||||||
|
Image(systemName: "exclamationmark.triangle.fill")
|
||||||
|
.font(.system(size: 40))
|
||||||
|
.foregroundStyle(.red)
|
||||||
|
|
||||||
|
Text("Unable to Open Database")
|
||||||
|
.font(.headline)
|
||||||
|
|
||||||
|
Text(error.localizedDescription)
|
||||||
|
.font(.subheadline)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
.multilineTextAlignment(.center)
|
||||||
|
.padding(.horizontal, 32)
|
||||||
|
}
|
||||||
|
.frame(maxWidth: .infinity, maxHeight: .infinity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Errors
|
||||||
|
|
||||||
|
/// Errors specific to `DataChatView` initialization.
|
||||||
|
public enum DataChatError: Error, LocalizedError, Sendable {
|
||||||
|
/// The database file could not be opened at the given path.
|
||||||
|
case databaseOpenFailed(path: String, underlying: any Error)
|
||||||
|
|
||||||
|
public var errorDescription: String? {
|
||||||
|
switch self {
|
||||||
|
case .databaseOpenFailed(let path, let underlying):
|
||||||
|
return "Could not open database at \"\(path)\": \(underlying.localizedDescription)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
360
Sources/SwiftDBAI/Views/ErrorMessageView.swift
Normal file
360
Sources/SwiftDBAI/Views/ErrorMessageView.swift
Normal file
@@ -0,0 +1,360 @@
|
|||||||
|
// ErrorMessageView.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Reusable SwiftUI component that renders error messages with contextual
|
||||||
|
// icons, descriptions, and optional retry actions based on the error type.
|
||||||
|
|
||||||
|
import SwiftUI
|
||||||
|
|
||||||
|
/// A reusable SwiftUI component that renders a ``SwiftDBAIError`` with an
|
||||||
|
/// appropriate icon, human-readable message, and optional retry action.
|
||||||
|
///
|
||||||
|
/// The view automatically selects a visual treatment based on the error
|
||||||
|
/// category:
|
||||||
|
///
|
||||||
|
/// | Category | Icon | Color | Retry? |
|
||||||
|
/// |-------------------|-------------------------------|---------|--------|
|
||||||
|
/// | Safety / blocked | `shield.trianglebadge.excl…` | Orange | No |
|
||||||
|
/// | Confirmation | `hand.raised.fill` | Yellow | Yes* |
|
||||||
|
/// | LLM failure | `brain` | Purple | Yes |
|
||||||
|
/// | Schema / DB | `cylinder.split.1x2` | Red | No |
|
||||||
|
/// | Recoverable SQL | `arrow.clockwise` | Blue | Yes |
|
||||||
|
/// | Generic | `exclamationmark.triangle` | Red | No |
|
||||||
|
///
|
||||||
|
/// *Confirmation retry triggers the confirm callback, not a standard retry.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// ErrorMessageView(
|
||||||
|
/// error: .llmTimeout(seconds: 30),
|
||||||
|
/// onRetry: { /* resend the message */ }
|
||||||
|
/// )
|
||||||
|
/// ```
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
public struct ErrorMessageView: View {
|
||||||
|
/// The error to display. When `nil`, the view falls back to the raw message.
|
||||||
|
private let error: SwiftDBAIError?
|
||||||
|
|
||||||
|
/// The raw error message string (used as fallback when error is nil).
|
||||||
|
private let message: String
|
||||||
|
|
||||||
|
/// Called when the user taps the retry button. `nil` hides the button.
|
||||||
|
private let onRetry: (@Sendable () async -> Void)?
|
||||||
|
|
||||||
|
/// Called when the user confirms a destructive operation.
|
||||||
|
private let onConfirm: (@Sendable () async -> Void)?
|
||||||
|
|
||||||
|
@State private var isRetrying = false
|
||||||
|
|
||||||
|
// MARK: - Initializers
|
||||||
|
|
||||||
|
/// Creates an ErrorMessageView from a typed ``SwiftDBAIError``.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - error: The ``SwiftDBAIError`` to display.
|
||||||
|
/// - onRetry: An optional async closure invoked when the user taps retry.
|
||||||
|
/// - onConfirm: An optional async closure invoked when the user confirms
|
||||||
|
/// a destructive operation (only relevant for `.confirmationRequired`).
|
||||||
|
public init(
|
||||||
|
error: SwiftDBAIError,
|
||||||
|
onRetry: (@Sendable () async -> Void)? = nil,
|
||||||
|
onConfirm: (@Sendable () async -> Void)? = nil
|
||||||
|
) {
|
||||||
|
self.error = error
|
||||||
|
self.message = error.localizedDescription
|
||||||
|
self.onRetry = onRetry
|
||||||
|
self.onConfirm = onConfirm
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an ErrorMessageView from a ``ChatMessage``.
|
||||||
|
///
|
||||||
|
/// Extracts the typed error if available, otherwise falls back to the
|
||||||
|
/// message content string.
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - message: The chat message with role `.error`.
|
||||||
|
/// - onRetry: An optional async closure invoked when the user taps retry.
|
||||||
|
/// - onConfirm: An optional async closure invoked when the user confirms
|
||||||
|
/// a destructive operation.
|
||||||
|
public init(
|
||||||
|
chatMessage: ChatMessage,
|
||||||
|
onRetry: (@Sendable () async -> Void)? = nil,
|
||||||
|
onConfirm: (@Sendable () async -> Void)? = nil
|
||||||
|
) {
|
||||||
|
self.error = chatMessage.error
|
||||||
|
self.message = chatMessage.content
|
||||||
|
self.onRetry = onRetry
|
||||||
|
self.onConfirm = onConfirm
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates an ErrorMessageView from a plain string (untyped fallback).
|
||||||
|
///
|
||||||
|
/// - Parameters:
|
||||||
|
/// - message: The error message string.
|
||||||
|
/// - onRetry: An optional async closure invoked when the user taps retry.
|
||||||
|
public init(
|
||||||
|
message: String,
|
||||||
|
onRetry: (@Sendable () async -> Void)? = nil
|
||||||
|
) {
|
||||||
|
self.error = nil
|
||||||
|
self.message = message
|
||||||
|
self.onRetry = onRetry
|
||||||
|
self.onConfirm = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Body
|
||||||
|
|
||||||
|
public var body: some View {
|
||||||
|
VStack(alignment: .leading, spacing: 10) {
|
||||||
|
// Icon + message row
|
||||||
|
HStack(alignment: .firstTextBaseline, spacing: 8) {
|
||||||
|
Image(systemName: iconName)
|
||||||
|
.foregroundStyle(iconColor)
|
||||||
|
.font(.callout)
|
||||||
|
.accessibilityHidden(true)
|
||||||
|
|
||||||
|
VStack(alignment: .leading, spacing: 4) {
|
||||||
|
if let title = errorTitle {
|
||||||
|
Text(title)
|
||||||
|
.font(.callout.weight(.semibold))
|
||||||
|
.foregroundStyle(iconColor)
|
||||||
|
}
|
||||||
|
|
||||||
|
Text(message)
|
||||||
|
.font(.body)
|
||||||
|
.foregroundStyle(.primary)
|
||||||
|
.textSelection(.enabled)
|
||||||
|
.fixedSize(horizontal: false, vertical: true)
|
||||||
|
|
||||||
|
if let hint = recoveryHint {
|
||||||
|
Text(hint)
|
||||||
|
.font(.caption)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
.fixedSize(horizontal: false, vertical: true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Action buttons
|
||||||
|
if showRetryButton || showConfirmButton {
|
||||||
|
HStack(spacing: 12) {
|
||||||
|
if showConfirmButton {
|
||||||
|
confirmButton
|
||||||
|
}
|
||||||
|
if showRetryButton {
|
||||||
|
retryButton
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.padding(.leading, 26) // Align with text (icon width + spacing)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.accessibilityElement(children: .combine)
|
||||||
|
.accessibilityLabel(accessibilityDescription)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Action Buttons
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var retryButton: some View {
|
||||||
|
Button {
|
||||||
|
guard !isRetrying else { return }
|
||||||
|
isRetrying = true
|
||||||
|
Task {
|
||||||
|
await onRetry?()
|
||||||
|
isRetrying = false
|
||||||
|
}
|
||||||
|
} label: {
|
||||||
|
HStack(spacing: 4) {
|
||||||
|
if isRetrying {
|
||||||
|
ProgressView()
|
||||||
|
.controlSize(.mini)
|
||||||
|
} else {
|
||||||
|
Image(systemName: "arrow.clockwise")
|
||||||
|
.font(.caption)
|
||||||
|
}
|
||||||
|
Text(retryButtonLabel)
|
||||||
|
.font(.caption.weight(.medium))
|
||||||
|
}
|
||||||
|
.padding(.horizontal, 10)
|
||||||
|
.padding(.vertical, 6)
|
||||||
|
.background(iconColor.opacity(0.12))
|
||||||
|
.foregroundStyle(iconColor)
|
||||||
|
.clipShape(Capsule())
|
||||||
|
}
|
||||||
|
.buttonStyle(.plain)
|
||||||
|
.disabled(isRetrying)
|
||||||
|
}
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var confirmButton: some View {
|
||||||
|
Button {
|
||||||
|
Task {
|
||||||
|
await onConfirm?()
|
||||||
|
}
|
||||||
|
} label: {
|
||||||
|
HStack(spacing: 4) {
|
||||||
|
Image(systemName: "checkmark.circle")
|
||||||
|
.font(.caption)
|
||||||
|
Text("Confirm")
|
||||||
|
.font(.caption.weight(.medium))
|
||||||
|
}
|
||||||
|
.padding(.horizontal, 10)
|
||||||
|
.padding(.vertical, 6)
|
||||||
|
.background(Color.orange.opacity(0.12))
|
||||||
|
.foregroundStyle(.orange)
|
||||||
|
.clipShape(Capsule())
|
||||||
|
}
|
||||||
|
.buttonStyle(.plain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Error Classification
|
||||||
|
|
||||||
|
private var errorCategory: ErrorCategory {
|
||||||
|
guard let error else { return .generic }
|
||||||
|
|
||||||
|
if error.requiresUserAction {
|
||||||
|
return .confirmation
|
||||||
|
}
|
||||||
|
if error.isSafetyError {
|
||||||
|
return .safety
|
||||||
|
}
|
||||||
|
if error.isRecoverable {
|
||||||
|
return .recoverable
|
||||||
|
}
|
||||||
|
|
||||||
|
switch error {
|
||||||
|
case .llmFailure, .llmResponseUnparseable, .llmTimeout:
|
||||||
|
return .llm
|
||||||
|
case .schemaIntrospectionFailed, .emptySchema, .databaseError, .queryTimedOut:
|
||||||
|
return .database
|
||||||
|
case .configurationError:
|
||||||
|
return .configuration
|
||||||
|
default:
|
||||||
|
return .generic
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private enum ErrorCategory {
|
||||||
|
case safety
|
||||||
|
case confirmation
|
||||||
|
case llm
|
||||||
|
case database
|
||||||
|
case recoverable
|
||||||
|
case configuration
|
||||||
|
case generic
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Visual Properties
|
||||||
|
|
||||||
|
private var iconName: String {
|
||||||
|
switch errorCategory {
|
||||||
|
case .safety:
|
||||||
|
return "shield.trianglebadge.exclamationmark.fill"
|
||||||
|
case .confirmation:
|
||||||
|
return "hand.raised.fill"
|
||||||
|
case .llm:
|
||||||
|
return "brain"
|
||||||
|
case .database:
|
||||||
|
return "cylinder.split.1x2"
|
||||||
|
case .recoverable:
|
||||||
|
return "arrow.clockwise"
|
||||||
|
case .configuration:
|
||||||
|
return "gearshape.triangle.fill"
|
||||||
|
case .generic:
|
||||||
|
return "exclamationmark.triangle.fill"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private var iconColor: Color {
|
||||||
|
switch errorCategory {
|
||||||
|
case .safety:
|
||||||
|
return .orange
|
||||||
|
case .confirmation:
|
||||||
|
return .yellow
|
||||||
|
case .llm:
|
||||||
|
return .purple
|
||||||
|
case .database:
|
||||||
|
return .red
|
||||||
|
case .recoverable:
|
||||||
|
return .blue
|
||||||
|
case .configuration:
|
||||||
|
return .gray
|
||||||
|
case .generic:
|
||||||
|
return .red
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private var errorTitle: String? {
|
||||||
|
switch errorCategory {
|
||||||
|
case .safety:
|
||||||
|
return "Operation Blocked"
|
||||||
|
case .confirmation:
|
||||||
|
return "Confirmation Required"
|
||||||
|
case .llm:
|
||||||
|
return "AI Provider Error"
|
||||||
|
case .database:
|
||||||
|
return "Database Error"
|
||||||
|
case .recoverable:
|
||||||
|
return "Query Issue"
|
||||||
|
case .configuration:
|
||||||
|
return "Configuration Error"
|
||||||
|
case .generic:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private var recoveryHint: String? {
|
||||||
|
guard let error else { return nil }
|
||||||
|
|
||||||
|
switch error {
|
||||||
|
case .noSQLGenerated, .llmResponseUnparseable:
|
||||||
|
return "Try rephrasing your question."
|
||||||
|
case .tableNotFound:
|
||||||
|
return "Check that you're referring to an existing table."
|
||||||
|
case .columnNotFound:
|
||||||
|
return "Verify the column name matches your schema."
|
||||||
|
case .invalidSQL:
|
||||||
|
return "The AI generated an invalid query. Try asking differently."
|
||||||
|
case .llmTimeout:
|
||||||
|
return "The AI took too long. Try a simpler question."
|
||||||
|
case .llmFailure:
|
||||||
|
return "The AI service may be temporarily unavailable."
|
||||||
|
case .emptySchema:
|
||||||
|
return "Add some tables to your database first."
|
||||||
|
case .queryTimedOut:
|
||||||
|
return "Try a simpler query or add database indexes."
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Button Visibility
|
||||||
|
|
||||||
|
private var showRetryButton: Bool {
|
||||||
|
guard onRetry != nil else { return false }
|
||||||
|
return errorCategory == .recoverable || errorCategory == .llm
|
||||||
|
}
|
||||||
|
|
||||||
|
private var showConfirmButton: Bool {
|
||||||
|
guard onConfirm != nil else { return false }
|
||||||
|
return errorCategory == .confirmation
|
||||||
|
}
|
||||||
|
|
||||||
|
private var retryButtonLabel: String {
|
||||||
|
switch errorCategory {
|
||||||
|
case .llm:
|
||||||
|
return "Retry"
|
||||||
|
case .recoverable:
|
||||||
|
return "Try Again"
|
||||||
|
default:
|
||||||
|
return "Retry"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Accessibility
|
||||||
|
|
||||||
|
private var accessibilityDescription: String {
|
||||||
|
let prefix = errorTitle.map { "\($0): " } ?? "Error: "
|
||||||
|
return prefix + message
|
||||||
|
}
|
||||||
|
}
|
||||||
205
Sources/SwiftDBAI/Views/MessageBubbleView.swift
Normal file
205
Sources/SwiftDBAI/Views/MessageBubbleView.swift
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
// MessageBubbleView.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Renders a single ChatMessage as a styled bubble with optional
|
||||||
|
// data table and SQL disclosure for query results.
|
||||||
|
|
||||||
|
import SwiftUI
|
||||||
|
import Charts
|
||||||
|
|
||||||
|
/// Renders a single `ChatMessage` in the chat conversation.
|
||||||
|
///
|
||||||
|
/// - **User messages** display right-aligned with an accent-colored background
|
||||||
|
/// and white text, using a continuous rounded rectangle shape.
|
||||||
|
/// - **Assistant messages** display left-aligned with a secondary background.
|
||||||
|
/// The natural language text summary is the primary content, rendered with
|
||||||
|
/// full `.body` font and `.primary` foreground for readability.
|
||||||
|
/// If the message contains a `queryResult` with tabular data, a
|
||||||
|
/// `ScrollableDataTableView` is automatically embedded below the summary.
|
||||||
|
/// An optional SQL disclosure group shows the generated query.
|
||||||
|
/// - **Error messages** display left-aligned with a red-tinted background
|
||||||
|
/// and an exclamation mark icon.
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
struct MessageBubbleView: View {
|
||||||
|
let message: ChatMessage
|
||||||
|
|
||||||
|
/// Whether to show the SQL query in a disclosure group.
|
||||||
|
var showSQL: Bool = true
|
||||||
|
|
||||||
|
/// Maximum height for the data table before it scrolls.
|
||||||
|
var maxTableHeight: CGFloat = 300
|
||||||
|
|
||||||
|
/// Called when the user taps "Retry" on a recoverable error.
|
||||||
|
var onRetry: (@Sendable () async -> Void)?
|
||||||
|
|
||||||
|
/// Called when the user confirms a destructive operation.
|
||||||
|
var onConfirm: (@Sendable () async -> Void)?
|
||||||
|
|
||||||
|
var body: some View {
|
||||||
|
HStack(alignment: .top) {
|
||||||
|
if message.role == .user { Spacer(minLength: 48) }
|
||||||
|
|
||||||
|
bubbleContent
|
||||||
|
.padding(.horizontal, 14)
|
||||||
|
.padding(.vertical, 10)
|
||||||
|
.background(bubbleBackground)
|
||||||
|
.clipShape(bubbleShape)
|
||||||
|
|
||||||
|
if message.role != .user { Spacer(minLength: 48) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Bubble Content
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var bubbleContent: some View {
|
||||||
|
switch message.role {
|
||||||
|
case .user:
|
||||||
|
userContent
|
||||||
|
case .assistant:
|
||||||
|
assistantContent
|
||||||
|
case .error:
|
||||||
|
errorContent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - User Content
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var userContent: some View {
|
||||||
|
Text(message.content)
|
||||||
|
.font(.body)
|
||||||
|
.foregroundStyle(.white)
|
||||||
|
.textSelection(.enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Assistant Content (Text Summary + Data Table + SQL)
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var assistantContent: some View {
|
||||||
|
VStack(alignment: .leading, spacing: 10) {
|
||||||
|
// Natural language text summary — primary content
|
||||||
|
Text(message.content)
|
||||||
|
.font(.body)
|
||||||
|
.foregroundStyle(.primary)
|
||||||
|
.textSelection(.enabled)
|
||||||
|
.fixedSize(horizontal: false, vertical: true)
|
||||||
|
|
||||||
|
// Data table — automatically shown when queryResult has tabular data
|
||||||
|
if let queryResult = message.queryResult,
|
||||||
|
!queryResult.columns.isEmpty,
|
||||||
|
!queryResult.rows.isEmpty {
|
||||||
|
dataTableSection(for: queryResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQL disclosure — collapsed by default for transparency
|
||||||
|
if showSQL, let sql = message.sql {
|
||||||
|
sqlDisclosure(sql: sql)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Error Content
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var errorContent: some View {
|
||||||
|
ErrorMessageView(
|
||||||
|
chatMessage: message,
|
||||||
|
onRetry: onRetry,
|
||||||
|
onConfirm: onConfirm
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Maximum height for the chart section.
|
||||||
|
var maxChartHeight: CGFloat = 250
|
||||||
|
|
||||||
|
/// Whether to show auto-detected charts. Defaults to `true`.
|
||||||
|
var showCharts: Bool = true
|
||||||
|
|
||||||
|
// MARK: - Chart Detection
|
||||||
|
|
||||||
|
/// The shared detector used for chart eligibility checks.
|
||||||
|
private static let chartDetector = ChartDataDetector()
|
||||||
|
|
||||||
|
// MARK: - Data Table Section
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private func dataTableSection(for queryResult: QueryResult) -> some View {
|
||||||
|
let dataTable = DataTable(queryResult)
|
||||||
|
|
||||||
|
VStack(alignment: .leading, spacing: 8) {
|
||||||
|
// Chart — automatically shown when ChartDataDetector finds eligible data
|
||||||
|
if showCharts {
|
||||||
|
chartSection(for: dataTable)
|
||||||
|
}
|
||||||
|
|
||||||
|
Divider()
|
||||||
|
|
||||||
|
ScrollableDataTableView(
|
||||||
|
dataTable: dataTable,
|
||||||
|
showAlternatingRows: true,
|
||||||
|
showFooter: true
|
||||||
|
)
|
||||||
|
.frame(maxHeight: maxTableHeight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Chart Section
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private func chartSection(for dataTable: DataTable) -> some View {
|
||||||
|
let detector = Self.chartDetector
|
||||||
|
if detector.detect(dataTable) != nil {
|
||||||
|
VStack(alignment: .leading, spacing: 4) {
|
||||||
|
ChartResultView(dataTable: dataTable, detector: detector)
|
||||||
|
.frame(maxHeight: maxChartHeight)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - SQL Disclosure
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private func sqlDisclosure(sql: String) -> some View {
|
||||||
|
DisclosureGroup {
|
||||||
|
Text(sql)
|
||||||
|
.font(.system(.caption, design: .monospaced))
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
.textSelection(.enabled)
|
||||||
|
.padding(8)
|
||||||
|
.frame(maxWidth: .infinity, alignment: .leading)
|
||||||
|
.background(Color.primary.opacity(0.04))
|
||||||
|
.clipShape(RoundedRectangle(cornerRadius: 6))
|
||||||
|
} label: {
|
||||||
|
Label("SQL Query", systemImage: "chevron.left.forwardslash.chevron.right")
|
||||||
|
.font(.caption)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Styling Helpers
|
||||||
|
|
||||||
|
private var bubbleShape: RoundedRectangle {
|
||||||
|
RoundedRectangle(cornerRadius: 16, style: .continuous)
|
||||||
|
}
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var bubbleBackground: some View {
|
||||||
|
switch message.role {
|
||||||
|
case .user:
|
||||||
|
Color.accentColor
|
||||||
|
case .assistant:
|
||||||
|
Self.assistantBackgroundColor
|
||||||
|
case .error:
|
||||||
|
Color.red.opacity(0.1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static var assistantBackgroundColor: Color {
|
||||||
|
#if os(macOS)
|
||||||
|
Color(nsColor: .controlBackgroundColor)
|
||||||
|
#else
|
||||||
|
Color(uiColor: .secondarySystemGroupedBackground)
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
267
Sources/SwiftDBAI/Views/ScrollableDataTableView.swift
Normal file
267
Sources/SwiftDBAI/Views/ScrollableDataTableView.swift
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
// ScrollableDataTableView.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// A SwiftUI view that renders a DataTable with horizontal and vertical
|
||||||
|
// scrolling, styled column headers, and row cells.
|
||||||
|
|
||||||
|
import SwiftUI
|
||||||
|
|
||||||
|
/// A scrollable table view that renders a `DataTable` with column headers
|
||||||
|
/// and row cells, supporting both horizontal and vertical scrolling.
|
||||||
|
///
|
||||||
|
/// Usage:
|
||||||
|
/// ```swift
|
||||||
|
/// ScrollableDataTableView(dataTable: myDataTable)
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// The view automatically sizes columns based on content, highlights
|
||||||
|
/// alternating rows for readability, and right-aligns numeric columns.
|
||||||
|
public struct ScrollableDataTableView: View {
|
||||||
|
/// The data table to render.
|
||||||
|
public let dataTable: DataTable
|
||||||
|
|
||||||
|
/// Minimum width for each column in points.
|
||||||
|
public var minimumColumnWidth: CGFloat
|
||||||
|
|
||||||
|
/// Maximum width for each column in points.
|
||||||
|
public var maximumColumnWidth: CGFloat
|
||||||
|
|
||||||
|
/// Whether to show alternating row backgrounds.
|
||||||
|
public var showAlternatingRows: Bool
|
||||||
|
|
||||||
|
/// Whether to show the row count footer.
|
||||||
|
public var showFooter: Bool
|
||||||
|
|
||||||
|
public init(
|
||||||
|
dataTable: DataTable,
|
||||||
|
minimumColumnWidth: CGFloat = 80,
|
||||||
|
maximumColumnWidth: CGFloat = 250,
|
||||||
|
showAlternatingRows: Bool = true,
|
||||||
|
showFooter: Bool = true
|
||||||
|
) {
|
||||||
|
self.dataTable = dataTable
|
||||||
|
self.minimumColumnWidth = minimumColumnWidth
|
||||||
|
self.maximumColumnWidth = maximumColumnWidth
|
||||||
|
self.showAlternatingRows = showAlternatingRows
|
||||||
|
self.showFooter = showFooter
|
||||||
|
}
|
||||||
|
|
||||||
|
public var body: some View {
|
||||||
|
if dataTable.isEmpty {
|
||||||
|
emptyView
|
||||||
|
} else {
|
||||||
|
tableContent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Empty State
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var emptyView: some View {
|
||||||
|
VStack(spacing: 8) {
|
||||||
|
Image(systemName: "tablecells")
|
||||||
|
.font(.largeTitle)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
Text("No results")
|
||||||
|
.font(.headline)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
.frame(maxWidth: .infinity, minHeight: 100)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Table Content
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var tableContent: some View {
|
||||||
|
VStack(alignment: .leading, spacing: 0) {
|
||||||
|
ScrollView([.horizontal, .vertical]) {
|
||||||
|
LazyVStack(alignment: .leading, spacing: 0, pinnedViews: [.sectionHeaders]) {
|
||||||
|
Section {
|
||||||
|
ForEach(dataTable.rows) { row in
|
||||||
|
rowView(row)
|
||||||
|
}
|
||||||
|
} header: {
|
||||||
|
headerRow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if showFooter {
|
||||||
|
footerView
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Header
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var headerRow: some View {
|
||||||
|
HStack(spacing: 0) {
|
||||||
|
ForEach(dataTable.columns) { column in
|
||||||
|
Text(column.name)
|
||||||
|
.font(.caption.weight(.semibold))
|
||||||
|
.foregroundStyle(.primary)
|
||||||
|
.lineLimit(1)
|
||||||
|
.frame(
|
||||||
|
width: columnWidth(for: column),
|
||||||
|
alignment: alignment(for: column)
|
||||||
|
)
|
||||||
|
.padding(.horizontal, 8)
|
||||||
|
.padding(.vertical, 6)
|
||||||
|
|
||||||
|
if column.index < dataTable.columnCount - 1 {
|
||||||
|
Divider()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.background(.bar)
|
||||||
|
.overlay(alignment: .bottom) {
|
||||||
|
Divider()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Row
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private func rowView(_ row: DataTable.Row) -> some View {
|
||||||
|
HStack(spacing: 0) {
|
||||||
|
ForEach(dataTable.columns) { column in
|
||||||
|
cellView(value: row[column.index], column: column)
|
||||||
|
|
||||||
|
if column.index < dataTable.columnCount - 1 {
|
||||||
|
Divider()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.background(rowBackground(for: row))
|
||||||
|
.overlay(alignment: .bottom) {
|
||||||
|
Divider()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Cell
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private func cellView(value: QueryResult.Value, column: DataTable.Column) -> some View {
|
||||||
|
Group {
|
||||||
|
switch value {
|
||||||
|
case .null:
|
||||||
|
Text("NULL")
|
||||||
|
.foregroundStyle(.tertiary)
|
||||||
|
.italic()
|
||||||
|
case .blob(let data):
|
||||||
|
Text("<\(data.count) bytes>")
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
default:
|
||||||
|
Text(value.stringValue)
|
||||||
|
.foregroundStyle(.primary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.font(.caption)
|
||||||
|
.lineLimit(2)
|
||||||
|
.frame(
|
||||||
|
width: columnWidth(for: column),
|
||||||
|
alignment: alignment(for: column)
|
||||||
|
)
|
||||||
|
.padding(.horizontal, 8)
|
||||||
|
.padding(.vertical, 4)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Footer
|
||||||
|
|
||||||
|
@ViewBuilder
|
||||||
|
private var footerView: some View {
|
||||||
|
HStack {
|
||||||
|
Text("\(dataTable.rowCount) row\(dataTable.rowCount == 1 ? "" : "s")")
|
||||||
|
.font(.caption2)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
Spacer()
|
||||||
|
if dataTable.executionTime > 0 {
|
||||||
|
Text(String(format: "%.1f ms", dataTable.executionTime * 1000))
|
||||||
|
.font(.caption2)
|
||||||
|
.foregroundStyle(.secondary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.padding(.horizontal, 8)
|
||||||
|
.padding(.vertical, 4)
|
||||||
|
.background(.bar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Layout Helpers
|
||||||
|
|
||||||
|
/// Determines column width based on the column name length and type.
|
||||||
|
private func columnWidth(for column: DataTable.Column) -> CGFloat {
|
||||||
|
// Estimate based on header text length
|
||||||
|
let headerWidth = CGFloat(column.name.count) * 8 + 16
|
||||||
|
|
||||||
|
// Sample some row values to estimate content width
|
||||||
|
let sampleRows = dataTable.rows.prefix(20)
|
||||||
|
let maxContentWidth = sampleRows.reduce(CGFloat(0)) { maxWidth, row in
|
||||||
|
let value = row[column.index]
|
||||||
|
let textLength = CGFloat(value.stringValue.count) * 7
|
||||||
|
return max(maxWidth, textLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
let estimatedWidth = max(headerWidth, maxContentWidth) + 16
|
||||||
|
return min(max(estimatedWidth, minimumColumnWidth), maximumColumnWidth)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the alignment for a column based on its inferred type.
|
||||||
|
private func alignment(for column: DataTable.Column) -> Alignment {
|
||||||
|
switch column.inferredType {
|
||||||
|
case .integer, .real:
|
||||||
|
return .trailing
|
||||||
|
default:
|
||||||
|
return .leading
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the background color for alternating rows.
|
||||||
|
@ViewBuilder
|
||||||
|
private func rowBackground(for row: DataTable.Row) -> some View {
|
||||||
|
if showAlternatingRows && row.id.isMultiple(of: 2) {
|
||||||
|
Color.clear
|
||||||
|
} else if showAlternatingRows {
|
||||||
|
Color.primary.opacity(0.03)
|
||||||
|
} else {
|
||||||
|
Color.clear
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Preview Support
|
||||||
|
|
||||||
|
#if DEBUG
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
#Preview("Data Table") {
|
||||||
|
let columns: [DataTable.Column] = [
|
||||||
|
.init(name: "id", index: 0, inferredType: .integer),
|
||||||
|
.init(name: "name", index: 1, inferredType: .text),
|
||||||
|
.init(name: "score", index: 2, inferredType: .real),
|
||||||
|
]
|
||||||
|
let rows: [DataTable.Row] = (0..<25).map { i in
|
||||||
|
DataTable.Row(
|
||||||
|
id: i,
|
||||||
|
values: [
|
||||||
|
.integer(Int64(i + 1)),
|
||||||
|
.text("Item \(i + 1)"),
|
||||||
|
.real(Double.random(in: 1.0...100.0)),
|
||||||
|
],
|
||||||
|
columnNames: ["id", "name", "score"]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let table = DataTable(columns: columns, rows: rows, sql: "SELECT * FROM items", executionTime: 0.023)
|
||||||
|
|
||||||
|
ScrollableDataTableView(dataTable: table)
|
||||||
|
.frame(height: 400)
|
||||||
|
.padding()
|
||||||
|
}
|
||||||
|
|
||||||
|
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||||
|
#Preview("Empty Table") {
|
||||||
|
let table = DataTable(columns: [], rows: [], sql: "", executionTime: 0)
|
||||||
|
ScrollableDataTableView(dataTable: table)
|
||||||
|
.frame(height: 200)
|
||||||
|
.padding()
|
||||||
|
}
|
||||||
|
#endif
|
||||||
254
Tests/SwiftDBAITests/BinarySizeTests.swift
Normal file
254
Tests/SwiftDBAITests/BinarySizeTests.swift
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
// BinarySizeTests.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
//
|
||||||
|
// Validates that the SwiftDBAI package stays within its 2 MB binary size budget.
|
||||||
|
// This test suite uses source-level heuristics since we can't measure the actual
|
||||||
|
// compiled binary size in a unit test. The constraints ensure the package remains
|
||||||
|
// lightweight by checking:
|
||||||
|
// 1. Total source code size (proxy for compiled size)
|
||||||
|
// 2. No embedded binary assets or large resources
|
||||||
|
// 3. No unnecessary heavy dependencies
|
||||||
|
// 4. File count stays reasonable (no code bloat)
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import Testing
|
||||||
|
|
||||||
|
@Suite("Binary Size Budget")
|
||||||
|
struct BinarySizeTests {
|
||||||
|
|
||||||
|
/// The maximum allowed total source code size in bytes.
|
||||||
|
/// At typical Swift optimized compilation ratios (2-4x), 500 KB of source
|
||||||
|
/// compiles to roughly 1-2 MB of binary. We set the source budget at 500 KB
|
||||||
|
/// to keep the compiled output well under 2 MB.
|
||||||
|
private static let maxSourceSizeBytes: Int = 500_000 // 500 KB
|
||||||
|
|
||||||
|
/// Maximum number of Swift source files allowed.
|
||||||
|
/// More files generally means more code and larger binaries.
|
||||||
|
private static let maxSourceFileCount: Int = 60
|
||||||
|
|
||||||
|
/// Maximum size for any single source file in bytes.
|
||||||
|
/// Large individual files often indicate code that should be split or
|
||||||
|
/// contains embedded data that bloats the binary.
|
||||||
|
private static let maxSingleFileSizeBytes: Int = 50_000 // 50 KB
|
||||||
|
|
||||||
|
/// Disallowed file extensions in the Sources directory that would bloat the binary.
|
||||||
|
private static let disallowedExtensions: Set<String> = [
|
||||||
|
"png", "jpg", "jpeg", "gif", "bmp", "tiff",
|
||||||
|
"mp3", "mp4", "wav", "mov",
|
||||||
|
"mlmodel", "mlmodelc", "mlpackage",
|
||||||
|
"sqlite", "db",
|
||||||
|
"zip", "tar", "gz",
|
||||||
|
"bin", "dat",
|
||||||
|
"framework", "dylib", "a"
|
||||||
|
]
|
||||||
|
|
||||||
|
// MARK: - Helper
|
||||||
|
|
||||||
|
/// Recursively finds all files in the Sources/SwiftDBAI directory.
|
||||||
|
private func findSourceFiles() throws -> [URL] {
|
||||||
|
let sourcesDir = findSourcesDirectory()
|
||||||
|
guard let sourcesDir else {
|
||||||
|
Issue.record("Could not locate Sources/SwiftDBAI directory")
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
let fileManager = FileManager.default
|
||||||
|
guard let enumerator = fileManager.enumerator(
|
||||||
|
at: sourcesDir,
|
||||||
|
includingPropertiesForKeys: [.fileSizeKey, .isRegularFileKey],
|
||||||
|
options: [.skipsHiddenFiles]
|
||||||
|
) else {
|
||||||
|
Issue.record("Could not enumerate Sources/SwiftDBAI directory")
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
var files: [URL] = []
|
||||||
|
for case let fileURL as URL in enumerator {
|
||||||
|
let resourceValues = try fileURL.resourceValues(forKeys: [.isRegularFileKey])
|
||||||
|
if resourceValues.isRegularFile == true {
|
||||||
|
files.append(fileURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return files
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Locates the Sources/SwiftDBAI directory by walking up from the test bundle.
|
||||||
|
private func findSourcesDirectory() -> URL? {
|
||||||
|
// Try common locations relative to the build directory
|
||||||
|
let fileManager = FileManager.default
|
||||||
|
|
||||||
|
// In SPM test runs, we can find the package root by checking known paths
|
||||||
|
var candidateURL = URL(fileURLWithPath: #filePath)
|
||||||
|
// Walk up from Tests/SwiftDBAITests/BinarySizeTests.swift to package root
|
||||||
|
for _ in 0..<3 {
|
||||||
|
candidateURL = candidateURL.deletingLastPathComponent()
|
||||||
|
}
|
||||||
|
let sourcesDir = candidateURL.appendingPathComponent("Sources/SwiftDBAI")
|
||||||
|
if fileManager.fileExists(atPath: sourcesDir.path) {
|
||||||
|
return sourcesDir
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: check current working directory
|
||||||
|
let cwdSources = URL(fileURLWithPath: fileManager.currentDirectoryPath)
|
||||||
|
.appendingPathComponent("Sources/SwiftDBAI")
|
||||||
|
if fileManager.fileExists(atPath: cwdSources.path) {
|
||||||
|
return cwdSources
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Tests
|
||||||
|
|
||||||
|
@Test("Total source code size stays under 500 KB budget")
|
||||||
|
func totalSourceCodeSizeUnderBudget() throws {
|
||||||
|
let files = try findSourceFiles()
|
||||||
|
let swiftFiles = files.filter { $0.pathExtension == "swift" }
|
||||||
|
|
||||||
|
var totalSize: Int = 0
|
||||||
|
for file in swiftFiles {
|
||||||
|
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
|
||||||
|
let fileSize = attributes[.size] as? Int ?? 0
|
||||||
|
totalSize += fileSize
|
||||||
|
}
|
||||||
|
|
||||||
|
#expect(totalSize < Self.maxSourceSizeBytes,
|
||||||
|
"""
|
||||||
|
Total Swift source size (\(totalSize) bytes) exceeds \(Self.maxSourceSizeBytes) byte budget.
|
||||||
|
At typical 2-4x compilation ratio, this would produce a binary larger than 2 MB.
|
||||||
|
Consider removing unused code or splitting into optional sub-targets.
|
||||||
|
""")
|
||||||
|
|
||||||
|
// Log the actual size for visibility
|
||||||
|
let sizeKB = Double(totalSize) / 1024.0
|
||||||
|
let budgetKB = Double(Self.maxSourceSizeBytes) / 1024.0
|
||||||
|
print("📦 SwiftDBAI source size: \(String(format: "%.1f", sizeKB)) KB / \(String(format: "%.0f", budgetKB)) KB budget (\(String(format: "%.0f", (sizeKB / budgetKB) * 100))% used)")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Source file count stays reasonable")
|
||||||
|
func sourceFileCountUnderLimit() throws {
|
||||||
|
let files = try findSourceFiles()
|
||||||
|
let swiftFiles = files.filter { $0.pathExtension == "swift" }
|
||||||
|
|
||||||
|
#expect(swiftFiles.count <= Self.maxSourceFileCount,
|
||||||
|
"""
|
||||||
|
Swift source file count (\(swiftFiles.count)) exceeds limit of \(Self.maxSourceFileCount).
|
||||||
|
More files generally means more code and larger binaries.
|
||||||
|
""")
|
||||||
|
|
||||||
|
print("📦 SwiftDBAI file count: \(swiftFiles.count) / \(Self.maxSourceFileCount) max")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("No individual source file exceeds 50 KB")
|
||||||
|
func noOversizedSourceFiles() throws {
|
||||||
|
let files = try findSourceFiles()
|
||||||
|
let swiftFiles = files.filter { $0.pathExtension == "swift" }
|
||||||
|
|
||||||
|
for file in swiftFiles {
|
||||||
|
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
|
||||||
|
let fileSize = attributes[.size] as? Int ?? 0
|
||||||
|
|
||||||
|
#expect(fileSize < Self.maxSingleFileSizeBytes,
|
||||||
|
"""
|
||||||
|
File \(file.lastPathComponent) is \(fileSize) bytes, exceeding the \(Self.maxSingleFileSizeBytes) byte limit.
|
||||||
|
Large files may contain embedded data or code that should be split.
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("No binary assets or heavy resources in Sources directory")
|
||||||
|
func noBinaryAssetsInSources() throws {
|
||||||
|
let files = try findSourceFiles()
|
||||||
|
|
||||||
|
let disallowedFiles = files.filter { file in
|
||||||
|
Self.disallowedExtensions.contains(file.pathExtension.lowercased())
|
||||||
|
}
|
||||||
|
|
||||||
|
#expect(disallowedFiles.isEmpty,
|
||||||
|
"""
|
||||||
|
Found \(disallowedFiles.count) disallowed file(s) in Sources directory:
|
||||||
|
\(disallowedFiles.map(\.lastPathComponent).joined(separator: "\n"))
|
||||||
|
These file types bloat the binary. Remove them or move to a separate resource bundle.
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Package has no resource bundles that could bloat binary")
|
||||||
|
func noResourceBundles() throws {
|
||||||
|
let files = try findSourceFiles()
|
||||||
|
|
||||||
|
let resourceFiles = files.filter { file in
|
||||||
|
let ext = file.pathExtension.lowercased()
|
||||||
|
return ["xcassets", "storyboard", "xib", "nib", "xcdatamodeld"].contains(ext)
|
||||||
|
}
|
||||||
|
|
||||||
|
#expect(resourceFiles.isEmpty,
|
||||||
|
"""
|
||||||
|
Found resource bundle files that could bloat the binary:
|
||||||
|
\(resourceFiles.map(\.lastPathComponent).joined(separator: "\n"))
|
||||||
|
SwiftDBAI should be pure code — no bundled resources.
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Only expected dependencies declared (GRDB + AnyLanguageModel)")
|
||||||
|
func minimalDependencies() throws {
|
||||||
|
// Read Package.swift to verify we only have the expected dependencies
|
||||||
|
var packageURL = URL(fileURLWithPath: #filePath)
|
||||||
|
for _ in 0..<3 {
|
||||||
|
packageURL = packageURL.deletingLastPathComponent()
|
||||||
|
}
|
||||||
|
let packageSwiftURL = packageURL.appendingPathComponent("Package.swift")
|
||||||
|
|
||||||
|
guard FileManager.default.fileExists(atPath: packageSwiftURL.path) else {
|
||||||
|
// Skip if we can't find Package.swift (CI environments etc.)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
let packageContents = try String(contentsOf: packageSwiftURL, encoding: .utf8)
|
||||||
|
|
||||||
|
// Count .package() declarations (dependencies)
|
||||||
|
let packageDeclarations = packageContents.components(separatedBy: ".package(")
|
||||||
|
.count - 1 // subtract 1 because the first segment is before any .package(
|
||||||
|
|
||||||
|
#expect(packageDeclarations <= 2,
|
||||||
|
"""
|
||||||
|
Found \(packageDeclarations) package dependencies, expected at most 2 (GRDB + AnyLanguageModel).
|
||||||
|
Additional dependencies increase binary size. Evaluate if they're truly needed.
|
||||||
|
""")
|
||||||
|
|
||||||
|
// Verify the expected dependencies are present
|
||||||
|
#expect(packageContents.contains("GRDB"), "Expected GRDB dependency")
|
||||||
|
#expect(packageContents.contains("AnyLanguageModel"), "Expected AnyLanguageModel dependency")
|
||||||
|
|
||||||
|
print("📦 SwiftDBAI dependencies: \(packageDeclarations) (GRDB + AnyLanguageModel)")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Estimated binary size under 2 MB")
|
||||||
|
func estimatedBinarySizeUnderLimit() throws {
|
||||||
|
let files = try findSourceFiles()
|
||||||
|
let swiftFiles = files.filter { $0.pathExtension == "swift" }
|
||||||
|
|
||||||
|
var totalSize: Int = 0
|
||||||
|
for file in swiftFiles {
|
||||||
|
let attributes = try FileManager.default.attributesOfItem(atPath: file.path)
|
||||||
|
let fileSize = attributes[.size] as? Int ?? 0
|
||||||
|
totalSize += fileSize
|
||||||
|
}
|
||||||
|
|
||||||
|
// Conservative estimate: optimized Swift binary is typically 2-4x source size.
|
||||||
|
// Use 4x as worst case multiplier for safety margin.
|
||||||
|
let worstCaseMultiplier = 4.0
|
||||||
|
let estimatedBinarySize = Double(totalSize) * worstCaseMultiplier
|
||||||
|
let maxBinarySize: Double = 2.0 * 1024.0 * 1024.0 // 2 MB
|
||||||
|
|
||||||
|
#expect(estimatedBinarySize < maxBinarySize,
|
||||||
|
"""
|
||||||
|
Estimated binary size (\(String(format: "%.1f", estimatedBinarySize / 1024.0)) KB) exceeds 2 MB limit.
|
||||||
|
Source: \(totalSize) bytes × \(worstCaseMultiplier)x multiplier = \(String(format: "%.1f", estimatedBinarySize / 1024.0)) KB
|
||||||
|
Note: This is the SwiftDBAI module only — excludes GRDB and AnyLanguageModel
|
||||||
|
which are existing dependencies the developer already includes.
|
||||||
|
""")
|
||||||
|
|
||||||
|
let estimatedMB = estimatedBinarySize / (1024.0 * 1024.0)
|
||||||
|
print("📦 Estimated SwiftDBAI binary size: \(String(format: "%.2f", estimatedMB)) MB / 2.00 MB limit (worst case \(worstCaseMultiplier)x)")
|
||||||
|
}
|
||||||
|
}
|
||||||
293
Tests/SwiftDBAITests/ChartDataDetectorTests.swift
Normal file
293
Tests/SwiftDBAITests/ChartDataDetectorTests.swift
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
// ChartDataDetectorTests.swift
|
||||||
|
// SwiftDBAITests
|
||||||
|
|
||||||
|
import Testing
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
@Suite("ChartDataDetector")
|
||||||
|
struct ChartDataDetectorTests {
|
||||||
|
|
||||||
|
let detector = ChartDataDetector()
|
||||||
|
|
||||||
|
// MARK: - Helpers
|
||||||
|
|
||||||
|
private func makeQueryResult(
|
||||||
|
columns: [String],
|
||||||
|
rows: [[QueryResult.Value]],
|
||||||
|
sql: String = "SELECT *"
|
||||||
|
) -> QueryResult {
|
||||||
|
let rowDicts = rows.map { values in
|
||||||
|
Dictionary(uniqueKeysWithValues: zip(columns, values))
|
||||||
|
}
|
||||||
|
return QueryResult(
|
||||||
|
columns: columns,
|
||||||
|
rows: rowDicts,
|
||||||
|
sql: sql,
|
||||||
|
executionTime: 0.01
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func makeTable(
|
||||||
|
columns: [String],
|
||||||
|
rows: [[QueryResult.Value]],
|
||||||
|
sql: String = "SELECT *"
|
||||||
|
) -> DataTable {
|
||||||
|
DataTable(makeQueryResult(columns: columns, rows: rows, sql: sql))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Basic Eligibility
|
||||||
|
|
||||||
|
@Test("Returns nil for single-column results")
|
||||||
|
func singleColumn() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["count"],
|
||||||
|
rows: [[.integer(42)]]
|
||||||
|
)
|
||||||
|
#expect(detector.detect(table) == nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Returns nil for empty results")
|
||||||
|
func emptyResults() {
|
||||||
|
let table = makeTable(columns: ["name", "value"], rows: [])
|
||||||
|
#expect(detector.detect(table) == nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Returns nil for single row")
|
||||||
|
func singleRow() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["name", "count"],
|
||||||
|
rows: [[.text("A"), .integer(10)]]
|
||||||
|
)
|
||||||
|
#expect(detector.detect(table) == nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Returns nil for too many rows")
|
||||||
|
func tooManyRows() {
|
||||||
|
let rows = (0..<101).map { i in
|
||||||
|
[QueryResult.Value.text("cat\(i)"), .integer(Int64(i))]
|
||||||
|
}
|
||||||
|
let table = makeTable(columns: ["name", "count"], rows: rows)
|
||||||
|
#expect(detector.detect(table) == nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Bar Chart Detection
|
||||||
|
|
||||||
|
@Test("Recommends bar chart for categorical text + numeric")
|
||||||
|
func barChartCategorical() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["department", "headcount"],
|
||||||
|
rows: [
|
||||||
|
[.text("Engineering"), .integer(45)],
|
||||||
|
[.text("Marketing"), .integer(20)],
|
||||||
|
[.text("Sales"), .integer(30)],
|
||||||
|
[.text("HR"), .integer(10)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let rec = detector.detect(table)
|
||||||
|
#expect(rec != nil)
|
||||||
|
#expect(rec?.chartType == .bar)
|
||||||
|
#expect(rec?.categoryColumn == "department")
|
||||||
|
#expect(rec?.valueColumn == "headcount")
|
||||||
|
#expect(rec?.confidence ?? 0 > 0.5)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Pie Chart Detection
|
||||||
|
|
||||||
|
@Test("Recommends pie chart for small positive proportions")
|
||||||
|
func pieChartSmallCategories() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["status", "count"],
|
||||||
|
rows: [
|
||||||
|
[.text("Active"), .integer(50)],
|
||||||
|
[.text("Inactive"), .integer(30)],
|
||||||
|
[.text("Pending"), .integer(20)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let rec = detector.detect(table)
|
||||||
|
#expect(rec != nil)
|
||||||
|
#expect(rec?.chartType == .pie)
|
||||||
|
#expect(rec?.categoryColumn == "status")
|
||||||
|
#expect(rec?.valueColumn == "count")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Does not recommend pie with negative values")
|
||||||
|
func pieRejectsNegative() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["category", "change"],
|
||||||
|
rows: [
|
||||||
|
[.text("A"), .integer(50)],
|
||||||
|
[.text("B"), .integer(-10)],
|
||||||
|
[.text("C"), .integer(20)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let rec = detector.detect(table)
|
||||||
|
#expect(rec != nil)
|
||||||
|
// Should NOT be pie since there's a negative value
|
||||||
|
#expect(rec?.chartType != .pie)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Does not recommend pie with too many slices")
|
||||||
|
func pieRejectsTooManySlices() {
|
||||||
|
let rows = (0..<10).map { i in
|
||||||
|
[QueryResult.Value.text("cat\(i)"), .integer(Int64(i + 1))]
|
||||||
|
}
|
||||||
|
let table = makeTable(columns: ["category", "value"], rows: rows)
|
||||||
|
let rec = detector.detect(table)
|
||||||
|
#expect(rec != nil)
|
||||||
|
#expect(rec?.chartType != .pie)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Line Chart Detection
|
||||||
|
|
||||||
|
@Test("Recommends line chart for time-series column names")
|
||||||
|
func lineChartTimeSeries() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["year", "revenue"],
|
||||||
|
rows: [
|
||||||
|
[.text("2020"), .real(1_000_000)],
|
||||||
|
[.text("2021"), .real(1_200_000)],
|
||||||
|
[.text("2022"), .real(1_500_000)],
|
||||||
|
[.text("2023"), .real(1_800_000)],
|
||||||
|
[.text("2024"), .real(2_100_000)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let rec = detector.detect(table)
|
||||||
|
#expect(rec != nil)
|
||||||
|
#expect(rec?.chartType == .line)
|
||||||
|
#expect(rec?.categoryColumn == "year")
|
||||||
|
#expect(rec?.valueColumn == "revenue")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Recommends line chart for date-formatted text values")
|
||||||
|
func lineChartDateValues() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["period", "sales"],
|
||||||
|
rows: [
|
||||||
|
[.text("2024-01"), .integer(100)],
|
||||||
|
[.text("2024-02"), .integer(120)],
|
||||||
|
[.text("2024-03"), .integer(90)],
|
||||||
|
[.text("2024-04"), .integer(150)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let rec = detector.detect(table)
|
||||||
|
#expect(rec != nil)
|
||||||
|
#expect(rec?.chartType == .line)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Recommends line chart for sequential numeric x-axis")
|
||||||
|
func lineChartSequential() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["step", "value"],
|
||||||
|
rows: [
|
||||||
|
[.integer(1), .real(2.5)],
|
||||||
|
[.integer(2), .real(3.1)],
|
||||||
|
[.integer(3), .real(4.0)],
|
||||||
|
[.integer(4), .real(3.8)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let rec = detector.detect(table)
|
||||||
|
#expect(rec != nil)
|
||||||
|
#expect(rec?.chartType == .line)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - All Recommendations
|
||||||
|
|
||||||
|
@Test("Returns multiple recommendations sorted by confidence")
|
||||||
|
func allRecommendations() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["category", "amount"],
|
||||||
|
rows: [
|
||||||
|
[.text("A"), .integer(30)],
|
||||||
|
[.text("B"), .integer(50)],
|
||||||
|
[.text("C"), .integer(20)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let recs = detector.allRecommendations(for: table)
|
||||||
|
#expect(!recs.isEmpty)
|
||||||
|
// Should be sorted by confidence descending
|
||||||
|
for i in 1..<recs.count {
|
||||||
|
#expect(recs[i - 1].confidence >= recs[i].confidence)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Two Numeric Columns Fallback
|
||||||
|
|
||||||
|
@Test("Uses first numeric as category when no text column exists")
|
||||||
|
func numericOnlyColumns() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["x", "y"],
|
||||||
|
rows: [
|
||||||
|
[.integer(1), .integer(10)],
|
||||||
|
[.integer(2), .integer(20)],
|
||||||
|
[.integer(3), .integer(30)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let rec = detector.detect(table)
|
||||||
|
#expect(rec != nil)
|
||||||
|
#expect(rec?.categoryColumn == "x")
|
||||||
|
#expect(rec?.valueColumn == "y")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Confidence & Reason
|
||||||
|
|
||||||
|
@Test("Confidence is between 0 and 1")
|
||||||
|
func confidenceBounds() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["name", "score"],
|
||||||
|
rows: [
|
||||||
|
[.text("A"), .integer(10)],
|
||||||
|
[.text("B"), .integer(20)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let rec = detector.detect(table)
|
||||||
|
#expect(rec != nil)
|
||||||
|
#expect(rec!.confidence >= 0.0)
|
||||||
|
#expect(rec!.confidence <= 1.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Reason is non-empty")
|
||||||
|
func reasonPresent() {
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["name", "score"],
|
||||||
|
rows: [
|
||||||
|
[.text("A"), .integer(10)],
|
||||||
|
[.text("B"), .integer(20)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let rec = detector.detect(table)
|
||||||
|
#expect(rec != nil)
|
||||||
|
#expect(!rec!.reason.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Custom Configuration
|
||||||
|
|
||||||
|
@Test("Respects custom minimumRows")
|
||||||
|
func customMinRows() {
|
||||||
|
let strict = ChartDataDetector(minimumRows: 5)
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["name", "value"],
|
||||||
|
rows: [
|
||||||
|
[.text("A"), .integer(1)],
|
||||||
|
[.text("B"), .integer(2)],
|
||||||
|
[.text("C"), .integer(3)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
#expect(strict.detect(table) == nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Respects custom maxPieSlices")
|
||||||
|
func customMaxPieSlices() {
|
||||||
|
let narrow = ChartDataDetector(maxPieSlices: 2)
|
||||||
|
let table = makeTable(
|
||||||
|
columns: ["status", "count"],
|
||||||
|
rows: [
|
||||||
|
[.text("A"), .integer(50)],
|
||||||
|
[.text("B"), .integer(30)],
|
||||||
|
[.text("C"), .integer(20)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let rec = narrow.detect(table)
|
||||||
|
// With maxPieSlices=2, 3 rows should not get pie
|
||||||
|
#expect(rec?.chartType != .pie)
|
||||||
|
}
|
||||||
|
}
|
||||||
1091
Tests/SwiftDBAITests/ChatEngineTests.swift
Normal file
1091
Tests/SwiftDBAITests/ChatEngineTests.swift
Normal file
File diff suppressed because it is too large
Load Diff
164
Tests/SwiftDBAITests/ChatViewTests.swift
Normal file
164
Tests/SwiftDBAITests/ChatViewTests.swift
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
// ChatViewTests.swift
|
||||||
|
// SwiftDBAITests
|
||||||
|
//
|
||||||
|
// Tests for ChatView, ChatViewModel, and MessageBubbleView integration
|
||||||
|
// with ScrollableDataTableView.
|
||||||
|
|
||||||
|
import Testing
|
||||||
|
import Foundation
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
@Suite("SchemaReadiness Tests")
|
||||||
|
struct SchemaReadinessTests {
|
||||||
|
|
||||||
|
@Test("SchemaReadiness isReady returns true only for ready state")
|
||||||
|
func isReadyProperty() {
|
||||||
|
#expect(SchemaReadiness.idle.isReady == false)
|
||||||
|
#expect(SchemaReadiness.loading.isReady == false)
|
||||||
|
#expect(SchemaReadiness.ready(tableCount: 3).isReady == true)
|
||||||
|
#expect(SchemaReadiness.failed("error").isReady == false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suite("ChatViewModel Tests")
|
||||||
|
struct ChatViewModelTests {
|
||||||
|
|
||||||
|
@Test("Messages with query results produce DataTable-compatible data")
|
||||||
|
func messageWithQueryResultHasTableData() {
|
||||||
|
// A ChatMessage with a queryResult should have the data needed
|
||||||
|
// for ScrollableDataTableView rendering
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: ["id", "name", "score"],
|
||||||
|
rows: [
|
||||||
|
["id": .integer(1), "name": .text("Alice"), "score": .real(95.5)],
|
||||||
|
["id": .integer(2), "name": .text("Bob"), "score": .real(87.3)],
|
||||||
|
],
|
||||||
|
sql: "SELECT id, name, score FROM users",
|
||||||
|
executionTime: 0.01
|
||||||
|
)
|
||||||
|
|
||||||
|
let message = ChatMessage(
|
||||||
|
role: .assistant,
|
||||||
|
content: "Found 2 users.",
|
||||||
|
queryResult: result,
|
||||||
|
sql: "SELECT id, name, score FROM users"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Verify queryResult is present and can be converted to DataTable
|
||||||
|
#expect(message.queryResult != nil)
|
||||||
|
#expect(message.queryResult!.columns.count == 3)
|
||||||
|
#expect(message.queryResult!.rows.count == 2)
|
||||||
|
|
||||||
|
// Verify DataTable conversion works (this is what MessageBubbleView does)
|
||||||
|
let dataTable = DataTable(message.queryResult!)
|
||||||
|
#expect(dataTable.columnCount == 3)
|
||||||
|
#expect(dataTable.rowCount == 2)
|
||||||
|
#expect(dataTable.columns[0].name == "id")
|
||||||
|
#expect(dataTable.columns[1].name == "name")
|
||||||
|
#expect(dataTable.columns[2].name == "score")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Messages without query results do not trigger table rendering")
|
||||||
|
func messageWithoutQueryResult() {
|
||||||
|
let message = ChatMessage(
|
||||||
|
role: .assistant,
|
||||||
|
content: "Hello! How can I help?",
|
||||||
|
queryResult: nil,
|
||||||
|
sql: nil
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(message.queryResult == nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Empty query results do not trigger table rendering")
|
||||||
|
func emptyQueryResult() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: [],
|
||||||
|
rows: [],
|
||||||
|
sql: "SELECT * FROM empty_table",
|
||||||
|
executionTime: 0.001
|
||||||
|
)
|
||||||
|
|
||||||
|
let message = ChatMessage(
|
||||||
|
role: .assistant,
|
||||||
|
content: "No results found.",
|
||||||
|
queryResult: result,
|
||||||
|
sql: "SELECT * FROM empty_table"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Even though queryResult exists, it has no columns/rows
|
||||||
|
// MessageBubbleView checks both conditions before showing the table
|
||||||
|
#expect(message.queryResult != nil)
|
||||||
|
#expect(message.queryResult!.columns.isEmpty)
|
||||||
|
#expect(message.queryResult!.rows.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Mutation results do not trigger table rendering")
|
||||||
|
func mutationQueryResult() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: [],
|
||||||
|
rows: [],
|
||||||
|
sql: "INSERT INTO users (name) VALUES ('Charlie')",
|
||||||
|
executionTime: 0.005,
|
||||||
|
rowsAffected: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
let message = ChatMessage(
|
||||||
|
role: .assistant,
|
||||||
|
content: "Successfully inserted 1 row.",
|
||||||
|
queryResult: result,
|
||||||
|
sql: "INSERT INTO users (name) VALUES ('Charlie')"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mutation results have empty columns — no table shown
|
||||||
|
#expect(message.queryResult!.columns.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Error messages never have query results")
|
||||||
|
func errorMessageHasNoQueryResult() {
|
||||||
|
let message = ChatMessage(
|
||||||
|
role: .error,
|
||||||
|
content: "SELECT operations are not allowed."
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(message.queryResult == nil)
|
||||||
|
#expect(message.role == .error)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("DataTable preserves column order from QueryResult")
|
||||||
|
func dataTableColumnOrder() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: ["date", "revenue", "category"],
|
||||||
|
rows: [
|
||||||
|
["date": .text("2024-01-01"), "revenue": .real(1500.0), "category": .text("Electronics")],
|
||||||
|
],
|
||||||
|
sql: "SELECT date, revenue, category FROM sales",
|
||||||
|
executionTime: 0.02
|
||||||
|
)
|
||||||
|
|
||||||
|
let dataTable = DataTable(result)
|
||||||
|
#expect(dataTable.columnNames == ["date", "revenue", "category"])
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Large result sets are renderable as DataTable")
|
||||||
|
func largeResultSet() {
|
||||||
|
var rows: [[String: QueryResult.Value]] = []
|
||||||
|
for i in 0..<500 {
|
||||||
|
rows.append([
|
||||||
|
"id": .integer(Int64(i)),
|
||||||
|
"value": .real(Double(i) * 1.5),
|
||||||
|
])
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: ["id", "value"],
|
||||||
|
rows: rows,
|
||||||
|
sql: "SELECT id, value FROM big_table",
|
||||||
|
executionTime: 0.15
|
||||||
|
)
|
||||||
|
|
||||||
|
let dataTable = DataTable(result)
|
||||||
|
#expect(dataTable.rowCount == 500)
|
||||||
|
#expect(dataTable.columnCount == 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
136
Tests/SwiftDBAITests/DataChatViewUsageTests.swift
Normal file
136
Tests/SwiftDBAITests/DataChatViewUsageTests.swift
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
// DataChatViewUsageTests.swift
|
||||||
|
// SwiftDBAITests
|
||||||
|
//
|
||||||
|
// Proves DataChatView works with minimal setup — under 10 lines of code.
|
||||||
|
// A developer only needs a GRDB connection and a LanguageModel to get a
|
||||||
|
// full chat-with-database SwiftUI view.
|
||||||
|
|
||||||
|
import Testing
|
||||||
|
import Foundation
|
||||||
|
import GRDB
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
// MARK: - Minimal Setup: DataChatView in Under 10 Lines
|
||||||
|
|
||||||
|
/// This test suite proves the "zero_config_reads" principle:
|
||||||
|
/// A developer with an existing SQLite database can create a fully functional
|
||||||
|
/// chat UI by providing only a GRDB connection and a language model instance.
|
||||||
|
/// No schema files, no annotations, no manual configuration required.
|
||||||
|
@Suite("DataChatView Minimal Setup")
|
||||||
|
struct DataChatViewMinimalSetupTests {
|
||||||
|
|
||||||
|
// ┌──────────────────────────────────────────────────────────┐
|
||||||
|
// │ USAGE EXAMPLE — DataChatView in 6 lines of real code │
|
||||||
|
// │ │
|
||||||
|
// │ import SwiftDBAI │
|
||||||
|
// │ import GRDB │
|
||||||
|
// │ │
|
||||||
|
// │ let db = try DatabaseQueue(path: "mydata.sqlite") │
|
||||||
|
// │ let model = OllamaLanguageModel(model: "llama3") │
|
||||||
|
// │ │
|
||||||
|
// │ var body: some View { │
|
||||||
|
// │ DataChatView(database: db, model: model) │
|
||||||
|
// │ } │
|
||||||
|
// └──────────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
/// Creates a temporary in-memory database with sample data for tests.
|
||||||
|
private static func makeSampleDatabase() throws -> DatabaseQueue {
|
||||||
|
let db = try DatabaseQueue()
|
||||||
|
try db.write { db in
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE products (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
price REAL NOT NULL,
|
||||||
|
category TEXT
|
||||||
|
);
|
||||||
|
INSERT INTO products (name, price, category) VALUES ('Widget', 9.99, 'Hardware');
|
||||||
|
INSERT INTO products (name, price, category) VALUES ('Gadget', 24.99, 'Electronics');
|
||||||
|
INSERT INTO products (name, price, category) VALUES ('Doohickey', 4.99, 'Hardware');
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("DataChatView initializes from database + model in 2 lines")
|
||||||
|
@MainActor
|
||||||
|
func dataChatViewMinimalInit() throws {
|
||||||
|
// LINE 1: Create (or receive) a GRDB connection
|
||||||
|
let db = try Self.makeSampleDatabase()
|
||||||
|
// LINE 2: Create the view — that's it!
|
||||||
|
let _ = DataChatView(database: db, model: MockLanguageModel())
|
||||||
|
// The view is ready. No schema files, no annotations, no extra config.
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("DataChatView path-based init works in 1 line given a path and model")
|
||||||
|
@MainActor
|
||||||
|
func dataChatViewPathInit() throws {
|
||||||
|
// Create a temp database file
|
||||||
|
let tempDir = FileManager.default.temporaryDirectory
|
||||||
|
let dbPath = tempDir.appendingPathComponent("test_\(UUID().uuidString).sqlite").path
|
||||||
|
let db = try DatabaseQueue(path: dbPath)
|
||||||
|
try db.write { db in
|
||||||
|
try db.execute(sql: "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ONE LINE to get a full chat UI:
|
||||||
|
let _ = DataChatView(databasePath: dbPath, model: MockLanguageModel())
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
try? FileManager.default.removeItem(atPath: dbPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("ChatEngine headless usage works in 3 lines")
|
||||||
|
func chatEngineMinimalUsage() async throws {
|
||||||
|
// LINE 1: Database
|
||||||
|
let db = try Self.makeSampleDatabase()
|
||||||
|
// LINE 2: Engine
|
||||||
|
let engine = ChatEngine(database: db, model: MockLanguageModel(responseText: "SELECT COUNT(*) AS total FROM products"))
|
||||||
|
// LINE 3: Schema preparation verifies auto-introspection works
|
||||||
|
let schema = try await engine.prepareSchema()
|
||||||
|
|
||||||
|
// The engine auto-discovered the schema — no manual config needed
|
||||||
|
#expect(schema.tableNames.contains("products"))
|
||||||
|
#expect(schema.tableNames.count == 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("ChatViewModel works with zero configuration beyond db + model")
|
||||||
|
@MainActor
|
||||||
|
func chatViewModelMinimalUsage() async throws {
|
||||||
|
let db = try Self.makeSampleDatabase()
|
||||||
|
let engine = ChatEngine(database: db, model: MockLanguageModel())
|
||||||
|
let viewModel = ChatViewModel(engine: engine)
|
||||||
|
|
||||||
|
// Prepare triggers auto-schema-introspection
|
||||||
|
await viewModel.prepare()
|
||||||
|
|
||||||
|
#expect(viewModel.schemaReadiness.isReady)
|
||||||
|
#expect(viewModel.messages.isEmpty) // Clean slate, ready to chat
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Default configuration is read-only (safe by default)")
|
||||||
|
@MainActor
|
||||||
|
func defaultIsReadOnly() throws {
|
||||||
|
let db = try Self.makeSampleDatabase()
|
||||||
|
// No allowlist specified — defaults to .readOnly
|
||||||
|
let _ = DataChatView(database: db, model: MockLanguageModel())
|
||||||
|
// This compiles and works. SELECT-only is the safe default.
|
||||||
|
// Developer must explicitly opt in to writes:
|
||||||
|
// DataChatView(database: db, model: model, allowlist: .standard)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Full DataChatView with all options still under 10 lines")
|
||||||
|
@MainActor
|
||||||
|
func dataChatViewFullConfig() throws {
|
||||||
|
let db = try Self.makeSampleDatabase() // 1
|
||||||
|
let model = MockLanguageModel() // 2
|
||||||
|
let _ = DataChatView( // 3-8
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .readOnly,
|
||||||
|
additionalContext: "Product catalog for an e-commerce store",
|
||||||
|
maxSummaryRows: 100
|
||||||
|
)
|
||||||
|
// Even with ALL options specified, it's under 10 lines of setup.
|
||||||
|
}
|
||||||
|
}
|
||||||
285
Tests/SwiftDBAITests/DataTableTests.swift
Normal file
285
Tests/SwiftDBAITests/DataTableTests.swift
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
// DataTableTests.swift
|
||||||
|
// SwiftDBAITests
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import Testing
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
@Suite("DataTable")
|
||||||
|
struct DataTableTests {
|
||||||
|
|
||||||
|
// MARK: - Helpers
|
||||||
|
|
||||||
|
private func makeQueryResult(
|
||||||
|
columns: [String],
|
||||||
|
rows: [[String: QueryResult.Value]],
|
||||||
|
sql: String = "SELECT * FROM test",
|
||||||
|
executionTime: TimeInterval = 0.01
|
||||||
|
) -> QueryResult {
|
||||||
|
QueryResult(
|
||||||
|
columns: columns,
|
||||||
|
rows: rows,
|
||||||
|
sql: sql,
|
||||||
|
executionTime: executionTime
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Basic Construction
|
||||||
|
|
||||||
|
@Test("Converts QueryResult columns and rows correctly")
|
||||||
|
func basicConversion() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["id", "name", "score"],
|
||||||
|
rows: [
|
||||||
|
["id": .integer(1), "name": .text("Alice"), "score": .real(95.5)],
|
||||||
|
["id": .integer(2), "name": .text("Bob"), "score": .real(87.0)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
|
||||||
|
#expect(table.columnCount == 3)
|
||||||
|
#expect(table.rowCount == 2)
|
||||||
|
#expect(table.columnNames == ["id", "name", "score"])
|
||||||
|
#expect(table.sql == "SELECT * FROM test")
|
||||||
|
#expect(table.executionTime == 0.01)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Empty result produces empty table")
|
||||||
|
func emptyResult() {
|
||||||
|
let result = makeQueryResult(columns: ["id", "name"], rows: [])
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
|
||||||
|
#expect(table.isEmpty)
|
||||||
|
#expect(table.rowCount == 0)
|
||||||
|
#expect(table.columnCount == 2)
|
||||||
|
#expect(table.columnNames == ["id", "name"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Subscript Access
|
||||||
|
|
||||||
|
@Test("Subscript by row and column index")
|
||||||
|
func subscriptByIndex() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["a", "b"],
|
||||||
|
rows: [
|
||||||
|
["a": .integer(10), "b": .text("hello")],
|
||||||
|
["a": .integer(20), "b": .text("world")],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
|
||||||
|
#expect(table[row: 0, column: 0] == .integer(10))
|
||||||
|
#expect(table[row: 0, column: 1] == .text("hello"))
|
||||||
|
#expect(table[row: 1, column: 0] == .integer(20))
|
||||||
|
#expect(table[row: 1, column: 1] == .text("world"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Subscript by row index and column name")
|
||||||
|
func subscriptByName() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["x", "y"],
|
||||||
|
rows: [["x": .real(1.5), "y": .real(2.5)]]
|
||||||
|
)
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
|
||||||
|
#expect(table[row: 0, column: "x"] == .real(1.5))
|
||||||
|
#expect(table[row: 0, column: "y"] == .real(2.5))
|
||||||
|
#expect(table[row: 0, column: "z"] == .null) // non-existent column
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Column Data Extraction
|
||||||
|
|
||||||
|
@Test("Extract column values by index")
|
||||||
|
func columnValuesByIndex() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["val"],
|
||||||
|
rows: [
|
||||||
|
["val": .integer(1)],
|
||||||
|
["val": .integer(2)],
|
||||||
|
["val": .integer(3)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
let values = table.columnValues(at: 0)
|
||||||
|
|
||||||
|
#expect(values == [.integer(1), .integer(2), .integer(3)])
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Extract column values by name")
|
||||||
|
func columnValuesByName() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["name"],
|
||||||
|
rows: [
|
||||||
|
["name": .text("A")],
|
||||||
|
["name": .text("B")],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
|
||||||
|
#expect(table.columnValues(named: "name") == [.text("A"), .text("B")])
|
||||||
|
#expect(table.columnValues(named: "missing").isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("numericValues extracts doubles from numeric column")
|
||||||
|
func numericValues() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["score"],
|
||||||
|
rows: [
|
||||||
|
["score": .integer(10)],
|
||||||
|
["score": .real(20.5)],
|
||||||
|
["score": .null],
|
||||||
|
["score": .text("not a number")],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
let nums = table.numericValues(forColumn: "score")
|
||||||
|
|
||||||
|
#expect(nums.count == 2)
|
||||||
|
#expect(nums[0] == 10.0)
|
||||||
|
#expect(nums[1] == 20.5)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("stringValues extracts non-null strings")
|
||||||
|
func stringValues() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["label"],
|
||||||
|
rows: [
|
||||||
|
["label": .text("foo")],
|
||||||
|
["label": .null],
|
||||||
|
["label": .text("bar")],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
let strs = table.stringValues(forColumn: "label")
|
||||||
|
|
||||||
|
#expect(strs == ["foo", "bar"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Type Inference
|
||||||
|
|
||||||
|
@Test("Infers integer type for all-integer column")
|
||||||
|
func inferInteger() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["id"],
|
||||||
|
rows: [["id": .integer(1)], ["id": .integer(2)]]
|
||||||
|
)
|
||||||
|
let table = DataTable(result)
|
||||||
|
#expect(table.columns[0].inferredType == .integer)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Infers real type for all-real column")
|
||||||
|
func inferReal() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["price"],
|
||||||
|
rows: [["price": .real(1.99)], ["price": .real(2.50)]]
|
||||||
|
)
|
||||||
|
let table = DataTable(result)
|
||||||
|
#expect(table.columns[0].inferredType == .real)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Infers text type for all-text column")
|
||||||
|
func inferText() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["name"],
|
||||||
|
rows: [["name": .text("A")], ["name": .text("B")]]
|
||||||
|
)
|
||||||
|
let table = DataTable(result)
|
||||||
|
#expect(table.columns[0].inferredType == .text)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Promotes integer + real to real")
|
||||||
|
func inferNumericPromotion() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["val"],
|
||||||
|
rows: [["val": .integer(1)], ["val": .real(2.5)]]
|
||||||
|
)
|
||||||
|
let table = DataTable(result)
|
||||||
|
#expect(table.columns[0].inferredType == .real)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Mixed types result in .mixed")
|
||||||
|
func inferMixed() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["data"],
|
||||||
|
rows: [["data": .integer(1)], ["data": .text("hello")]]
|
||||||
|
)
|
||||||
|
let table = DataTable(result)
|
||||||
|
#expect(table.columns[0].inferredType == .mixed)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("All-null column infers .null")
|
||||||
|
func inferNull() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["empty"],
|
||||||
|
rows: [["empty": .null], ["empty": .null]]
|
||||||
|
)
|
||||||
|
let table = DataTable(result)
|
||||||
|
#expect(table.columns[0].inferredType == .null)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Null values are ignored during type inference")
|
||||||
|
func inferIgnoresNulls() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["val"],
|
||||||
|
rows: [["val": .integer(1)], ["val": .null], ["val": .integer(3)]]
|
||||||
|
)
|
||||||
|
let table = DataTable(result)
|
||||||
|
#expect(table.columns[0].inferredType == .integer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Missing Values
|
||||||
|
|
||||||
|
@Test("Missing dictionary keys become .null")
|
||||||
|
func missingKeysBecomNull() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["a", "b"],
|
||||||
|
rows: [["a": .integer(1)]] // "b" is missing
|
||||||
|
)
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
|
||||||
|
#expect(table[row: 0, column: 0] == .integer(1))
|
||||||
|
#expect(table[row: 0, column: 1] == .null)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Row Identity
|
||||||
|
|
||||||
|
@Test("Rows have sequential IDs")
|
||||||
|
func rowIdentity() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["x"],
|
||||||
|
rows: [["x": .integer(1)], ["x": .integer(2)], ["x": .integer(3)]]
|
||||||
|
)
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
|
||||||
|
#expect(table.rows[0].id == 0)
|
||||||
|
#expect(table.rows[1].id == 1)
|
||||||
|
#expect(table.rows[2].id == 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Column Identity
|
||||||
|
|
||||||
|
@Test("Columns are Identifiable by name")
|
||||||
|
func columnIdentity() {
|
||||||
|
let result = makeQueryResult(
|
||||||
|
columns: ["alpha", "beta"],
|
||||||
|
rows: [["alpha": .integer(1), "beta": .integer(2)]]
|
||||||
|
)
|
||||||
|
|
||||||
|
let table = DataTable(result)
|
||||||
|
|
||||||
|
#expect(table.columns[0].id == "alpha")
|
||||||
|
#expect(table.columns[1].id == "beta")
|
||||||
|
#expect(table.columns[0].index == 0)
|
||||||
|
#expect(table.columns[1].index == 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
745
Tests/SwiftDBAITests/DestructiveOperationTests.swift
Normal file
745
Tests/SwiftDBAITests/DestructiveOperationTests.swift
Normal file
@@ -0,0 +1,745 @@
|
|||||||
|
// DestructiveOperationTests.swift
|
||||||
|
// SwiftDBAITests
|
||||||
|
//
|
||||||
|
// Tests verifying that destructive operations are blocked without confirmation
|
||||||
|
// and allowed when the delegate approves.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
import GRDB
|
||||||
|
import Testing
|
||||||
|
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
// MARK: - Test Delegates
|
||||||
|
|
||||||
|
/// A delegate that always rejects destructive operations and tracks calls.
|
||||||
|
private final class RejectingTrackingDelegate: SwiftDBAI.ToolExecutionDelegate, @unchecked Sendable {
|
||||||
|
private let lock = NSLock()
|
||||||
|
private var _confirmCalls: [DestructiveOperationContext] = []
|
||||||
|
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
|
||||||
|
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
|
||||||
|
|
||||||
|
var confirmCalls: [DestructiveOperationContext] {
|
||||||
|
lock.withLock { _confirmCalls }
|
||||||
|
}
|
||||||
|
|
||||||
|
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
|
||||||
|
lock.withLock { _willExecuteCalls }
|
||||||
|
}
|
||||||
|
|
||||||
|
var didExecuteCalls: [(sql: String, success: Bool)] {
|
||||||
|
lock.withLock { _didExecuteCalls }
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
|
||||||
|
lock.withLock { _confirmCalls.append(context) }
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
|
||||||
|
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
|
||||||
|
}
|
||||||
|
|
||||||
|
func didExecuteSQL(_ sql: String, success: Bool) async {
|
||||||
|
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A delegate that always approves destructive operations and tracks calls.
|
||||||
|
private final class ApprovingTrackingDelegate: SwiftDBAI.ToolExecutionDelegate, @unchecked Sendable {
|
||||||
|
private let lock = NSLock()
|
||||||
|
private var _confirmCalls: [DestructiveOperationContext] = []
|
||||||
|
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
|
||||||
|
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
|
||||||
|
|
||||||
|
var confirmCalls: [DestructiveOperationContext] {
|
||||||
|
lock.withLock { _confirmCalls }
|
||||||
|
}
|
||||||
|
|
||||||
|
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
|
||||||
|
lock.withLock { _willExecuteCalls }
|
||||||
|
}
|
||||||
|
|
||||||
|
var didExecuteCalls: [(sql: String, success: Bool)] {
|
||||||
|
lock.withLock { _didExecuteCalls }
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
|
||||||
|
lock.withLock { _confirmCalls.append(context) }
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
|
||||||
|
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
|
||||||
|
}
|
||||||
|
|
||||||
|
func didExecuteSQL(_ sql: String, success: Bool) async {
|
||||||
|
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Helpers
|
||||||
|
|
||||||
|
/// Creates an in-memory database with test data for destructive operation tests.
|
||||||
|
/// Users 1 and 2 have orders; user 3 has no orders (safe to delete).
|
||||||
|
private func makeTestDatabase() throws -> DatabaseQueue {
|
||||||
|
let db = try DatabaseQueue(path: ":memory:")
|
||||||
|
try db.write { db in
|
||||||
|
// Disable FK enforcement for test flexibility, then re-enable
|
||||||
|
try db.execute(sql: "PRAGMA foreign_keys = OFF")
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE users (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
email TEXT NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try db.execute(sql: """
|
||||||
|
INSERT INTO users (name, email) VALUES
|
||||||
|
('Alice', 'alice@example.com'),
|
||||||
|
('Bob', 'bob@example.com'),
|
||||||
|
('Charlie', 'charlie@example.com')
|
||||||
|
""")
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE orders (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
amount REAL NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try db.execute(sql: """
|
||||||
|
INSERT INTO orders (user_id, amount) VALUES
|
||||||
|
(1, 99.99),
|
||||||
|
(2, 150.00),
|
||||||
|
(3, 25.50)
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A sequential mock model for tests. Returns responses in order.
|
||||||
|
private struct TestSequentialModel: LanguageModel {
|
||||||
|
typealias UnavailableReason = Never
|
||||||
|
|
||||||
|
let responses: [String]
|
||||||
|
private let callCounter = CallCounter()
|
||||||
|
|
||||||
|
private final class CallCounter: @unchecked Sendable {
|
||||||
|
var count = 0
|
||||||
|
let lock = NSLock()
|
||||||
|
|
||||||
|
func next() -> Int {
|
||||||
|
lock.lock()
|
||||||
|
defer { lock.unlock() }
|
||||||
|
let c = count
|
||||||
|
count += 1
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
init(responses: [String]) {
|
||||||
|
self.responses = responses
|
||||||
|
}
|
||||||
|
|
||||||
|
func respond<Content>(
|
||||||
|
within session: LanguageModelSession,
|
||||||
|
to prompt: Prompt,
|
||||||
|
generating type: Content.Type,
|
||||||
|
includeSchemaInPrompt: Bool,
|
||||||
|
options: GenerationOptions
|
||||||
|
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
|
||||||
|
let idx = callCounter.next()
|
||||||
|
let text = idx < responses.count ? responses[idx] : "fallback response"
|
||||||
|
let rawContent = GeneratedContent(kind: .string(text))
|
||||||
|
let content = try Content(rawContent)
|
||||||
|
return LanguageModelSession.Response(
|
||||||
|
content: content,
|
||||||
|
rawContent: rawContent,
|
||||||
|
transcriptEntries: [][...]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponse<Content>(
|
||||||
|
within session: LanguageModelSession,
|
||||||
|
to prompt: Prompt,
|
||||||
|
generating type: Content.Type,
|
||||||
|
includeSchemaInPrompt: Bool,
|
||||||
|
options: GenerationOptions
|
||||||
|
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
|
||||||
|
let idx = callCounter.next()
|
||||||
|
let text = idx < responses.count ? responses[idx] : "fallback response"
|
||||||
|
let rawContent = GeneratedContent(kind: .string(text))
|
||||||
|
let content = try! Content(rawContent)
|
||||||
|
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Tests: Destructive Operations Blocked Without Confirmation
|
||||||
|
|
||||||
|
@Suite("Destructive Operations - Blocked Without Confirmation")
|
||||||
|
struct DestructiveOperationsBlockedTests {
|
||||||
|
|
||||||
|
@Test("DELETE is blocked when no delegate is provided")
|
||||||
|
func deleteBlockedWithoutDelegate() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM users WHERE id = 1"
|
||||||
|
])
|
||||||
|
|
||||||
|
// Unrestricted allowlist permits DELETE, but no delegate to confirm
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .unrestricted
|
||||||
|
)
|
||||||
|
|
||||||
|
do {
|
||||||
|
_ = try await engine.send("Delete user 1")
|
||||||
|
Issue.record("Expected confirmationRequired error but send succeeded")
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
guard case .confirmationRequired(let sql, let operation) = error else {
|
||||||
|
Issue.record("Expected confirmationRequired, got: \(error)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
#expect(sql.uppercased().contains("DELETE"))
|
||||||
|
#expect(operation == "delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the user was NOT deleted (data remains intact)
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
|
||||||
|
}
|
||||||
|
#expect(count == 1, "User should NOT have been deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("DELETE is blocked when delegate rejects")
|
||||||
|
func deleteBlockedWhenDelegateRejects() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = RejectingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM users WHERE id = 2"
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .unrestricted,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
do {
|
||||||
|
_ = try await engine.send("Delete user 2")
|
||||||
|
Issue.record("Expected confirmationRequired error but send succeeded")
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
guard case .confirmationRequired(let sql, let operation) = error else {
|
||||||
|
Issue.record("Expected confirmationRequired, got: \(error)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
#expect(sql.uppercased().contains("DELETE"))
|
||||||
|
#expect(operation == "delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify delegate was consulted
|
||||||
|
#expect(delegate.confirmCalls.count == 1)
|
||||||
|
#expect(delegate.confirmCalls[0].statementKind == .delete)
|
||||||
|
#expect(delegate.confirmCalls[0].sql.uppercased().contains("DELETE"))
|
||||||
|
|
||||||
|
// Verify the data was NOT modified
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 2")
|
||||||
|
}
|
||||||
|
#expect(count == 1, "User should NOT have been deleted")
|
||||||
|
|
||||||
|
// Verify no SQL was actually executed (no willExecute/didExecute calls)
|
||||||
|
#expect(delegate.willExecuteCalls.isEmpty, "No SQL should have been executed")
|
||||||
|
#expect(delegate.didExecuteCalls.isEmpty, "No SQL should have been executed")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("DELETE is blocked with MutationPolicy and no delegate")
|
||||||
|
func deleteBlockedWithMutationPolicyNoDelegate() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM users WHERE id = 3"
|
||||||
|
])
|
||||||
|
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update, .delete],
|
||||||
|
requiresDestructiveConfirmation: true
|
||||||
|
)
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
mutationPolicy: policy
|
||||||
|
)
|
||||||
|
|
||||||
|
do {
|
||||||
|
_ = try await engine.send("Delete user 3")
|
||||||
|
Issue.record("Expected confirmationRequired error but send succeeded")
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
guard case .confirmationRequired(let sql, let operation) = error else {
|
||||||
|
Issue.record("Expected confirmationRequired, got: \(error)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
#expect(sql.uppercased().contains("DELETE"))
|
||||||
|
#expect(operation == "delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Data intact
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 3")
|
||||||
|
}
|
||||||
|
#expect(count == 1, "User should NOT have been deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("DELETE is blocked with MutationPolicy and rejecting delegate")
|
||||||
|
func deleteBlockedWithMutationPolicyRejectingDelegate() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = RejectingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM orders WHERE user_id = 1"
|
||||||
|
])
|
||||||
|
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update, .delete],
|
||||||
|
requiresDestructiveConfirmation: true
|
||||||
|
)
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
mutationPolicy: policy,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
do {
|
||||||
|
_ = try await engine.send("Delete all orders for user 1")
|
||||||
|
Issue.record("Expected confirmationRequired error but send succeeded")
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
guard case .confirmationRequired = error else {
|
||||||
|
Issue.record("Expected confirmationRequired, got: \(error)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delegate was consulted and rejected
|
||||||
|
#expect(delegate.confirmCalls.count == 1)
|
||||||
|
#expect(delegate.confirmCalls[0].statementKind == .delete)
|
||||||
|
|
||||||
|
// Orders remain
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM orders WHERE user_id = 1")
|
||||||
|
}
|
||||||
|
#expect(count == 1, "Orders should NOT have been deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Default delegate implementation rejects destructive operations")
|
||||||
|
func defaultDelegateRejectsDestructive() async {
|
||||||
|
struct DefaultDelegate: SwiftDBAI.ToolExecutionDelegate {}
|
||||||
|
let delegate = DefaultDelegate()
|
||||||
|
|
||||||
|
let context = DestructiveOperationContext(
|
||||||
|
sql: "DELETE FROM users WHERE id = 1",
|
||||||
|
statementKind: .delete,
|
||||||
|
classification: .destructive(.delete),
|
||||||
|
description: "Delete from users"
|
||||||
|
)
|
||||||
|
|
||||||
|
let approved = await delegate.confirmDestructiveOperation(context)
|
||||||
|
#expect(approved == false, "Default delegate should reject destructive operations")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("DELETE not in readOnly allowlist is rejected before delegate is consulted")
|
||||||
|
func deleteNotInAllowlistRejectedEarly() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = ApprovingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM users WHERE id = 1"
|
||||||
|
])
|
||||||
|
|
||||||
|
// Read-only allowlist does NOT include DELETE
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .readOnly,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
do {
|
||||||
|
_ = try await engine.send("Delete user 1")
|
||||||
|
Issue.record("Expected operationNotAllowed error")
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
guard case .operationNotAllowed(let operation) = error else {
|
||||||
|
Issue.record("Expected operationNotAllowed, got: \(error)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
#expect(operation == "delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delegate should NOT have been consulted — the allowlist rejects before delegation
|
||||||
|
#expect(delegate.confirmCalls.isEmpty, "Delegate should not be consulted when op is not in allowlist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Tests: Destructive Operations Allowed When Delegate Approves
|
||||||
|
|
||||||
|
@Suite("Destructive Operations - Allowed When Delegate Approves")
|
||||||
|
struct DestructiveOperationsAllowedTests {
|
||||||
|
|
||||||
|
@Test("DELETE succeeds when delegate approves")
|
||||||
|
func deleteSucceedsWithApprovingDelegate() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = ApprovingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM users WHERE id = 1",
|
||||||
|
"Successfully deleted 1 user."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .unrestricted,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await engine.send("Delete user 1")
|
||||||
|
|
||||||
|
// Delegate was consulted and approved
|
||||||
|
#expect(delegate.confirmCalls.count == 1)
|
||||||
|
#expect(delegate.confirmCalls[0].statementKind == .delete)
|
||||||
|
#expect(delegate.confirmCalls[0].sql.uppercased().contains("DELETE"))
|
||||||
|
#expect(delegate.confirmCalls[0].targetTable == "users")
|
||||||
|
|
||||||
|
// SQL was executed
|
||||||
|
#expect(delegate.willExecuteCalls.count == 1)
|
||||||
|
#expect(delegate.didExecuteCalls.count == 1)
|
||||||
|
#expect(delegate.didExecuteCalls[0].success == true)
|
||||||
|
|
||||||
|
// Verify the data was actually deleted
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
|
||||||
|
}
|
||||||
|
#expect(count == 0, "User should have been deleted")
|
||||||
|
|
||||||
|
// Response should contain meaningful content
|
||||||
|
#expect(response.sql?.uppercased().contains("DELETE") == true)
|
||||||
|
#expect(response.queryResult != nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("DELETE with MutationPolicy succeeds when delegate approves")
|
||||||
|
func deleteWithPolicySucceedsWhenApproved() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = ApprovingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM orders WHERE user_id = 2",
|
||||||
|
"Deleted 1 order."
|
||||||
|
])
|
||||||
|
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update, .delete],
|
||||||
|
requiresDestructiveConfirmation: true
|
||||||
|
)
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
mutationPolicy: policy,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await engine.send("Delete all orders for user 2")
|
||||||
|
|
||||||
|
// Delegate approved
|
||||||
|
#expect(delegate.confirmCalls.count == 1)
|
||||||
|
|
||||||
|
// Data was actually deleted
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM orders WHERE user_id = 2")
|
||||||
|
}
|
||||||
|
#expect(count == 0, "Orders should have been deleted")
|
||||||
|
#expect(response.sql?.uppercased().contains("DELETE") == true)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("AutoApproveDelegate allows DELETE without user interaction")
|
||||||
|
func autoApproveDelegateAllowsDelete() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = AutoApproveDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM users WHERE id = 3",
|
||||||
|
"Deleted 1 user."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .unrestricted,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await engine.send("Delete user 3")
|
||||||
|
|
||||||
|
// Should succeed without error
|
||||||
|
#expect(response.sql?.uppercased().contains("DELETE") == true)
|
||||||
|
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 3")
|
||||||
|
}
|
||||||
|
#expect(count == 0, "User should have been deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("sendConfirmed bypasses delegate and executes directly")
|
||||||
|
func sendConfirmedBypassesDelegate() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = RejectingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"Deleted 1 user."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .unrestricted,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
// sendConfirmed should execute directly without consulting the delegate for confirmation
|
||||||
|
let response = try await engine.sendConfirmed(
|
||||||
|
"Delete user 1",
|
||||||
|
confirmedSQL: "DELETE FROM users WHERE id = 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Delegate was NOT asked to confirm (sendConfirmed skips confirmation)
|
||||||
|
#expect(delegate.confirmCalls.isEmpty)
|
||||||
|
|
||||||
|
// But willExecute/didExecute hooks were still called
|
||||||
|
#expect(delegate.willExecuteCalls.count == 1)
|
||||||
|
#expect(delegate.didExecuteCalls.count == 1)
|
||||||
|
#expect(delegate.didExecuteCalls[0].success == true)
|
||||||
|
|
||||||
|
// Data was deleted
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
|
||||||
|
}
|
||||||
|
#expect(count == 0)
|
||||||
|
#expect(response.summary.contains("deleted") || response.summary.contains("Deleted") || response.summary.contains("1"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Tests: Delegate Context Correctness
|
||||||
|
|
||||||
|
@Suite("Destructive Operations - Delegate Context")
|
||||||
|
struct DestructiveOperationContextTests {
|
||||||
|
|
||||||
|
@Test("Delegate receives correct context for DELETE on specific table")
|
||||||
|
func delegateReceivesCorrectContext() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = RejectingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM orders WHERE amount < 50"
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .unrestricted,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
do {
|
||||||
|
_ = try await engine.send("Delete cheap orders")
|
||||||
|
Issue.record("Expected confirmationRequired error")
|
||||||
|
} catch is SwiftDBAIError {
|
||||||
|
// Expected
|
||||||
|
}
|
||||||
|
|
||||||
|
#expect(delegate.confirmCalls.count == 1)
|
||||||
|
let ctx = delegate.confirmCalls[0]
|
||||||
|
#expect(ctx.statementKind == .delete)
|
||||||
|
#expect(ctx.classification == .destructive(.delete))
|
||||||
|
#expect(ctx.classification.requiresConfirmation == true)
|
||||||
|
#expect(ctx.sql.uppercased().contains("DELETE FROM ORDERS"))
|
||||||
|
#expect(ctx.targetTable == "orders")
|
||||||
|
#expect(!ctx.description.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Non-destructive operations do not consult delegate")
|
||||||
|
func selectDoesNotConsultDelegate() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = ApprovingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"SELECT COUNT(*) FROM users",
|
||||||
|
"There are 3 users."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .unrestricted,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = try await engine.send("How many users?")
|
||||||
|
|
||||||
|
// Delegate should NOT have been asked to confirm (SELECT is not destructive)
|
||||||
|
#expect(delegate.confirmCalls.isEmpty)
|
||||||
|
|
||||||
|
// But willExecute/didExecute should still be called (observation hooks)
|
||||||
|
#expect(delegate.willExecuteCalls.count == 1)
|
||||||
|
#expect(delegate.didExecuteCalls.count == 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("INSERT does not require confirmation even with delegate")
|
||||||
|
func insertDoesNotRequireConfirmation() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = RejectingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"INSERT INTO users (name, email) VALUES ('Dave', 'dave@example.com')",
|
||||||
|
"Inserted 1 row."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .standard,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await engine.send("Add user Dave")
|
||||||
|
|
||||||
|
// No confirmation needed for INSERT
|
||||||
|
#expect(delegate.confirmCalls.isEmpty)
|
||||||
|
#expect(response.sql?.uppercased().contains("INSERT") == true)
|
||||||
|
|
||||||
|
// Verify the insert happened
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE name = 'Dave'")
|
||||||
|
}
|
||||||
|
#expect(count == 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("UPDATE does not require confirmation even with delegate")
|
||||||
|
func updateDoesNotRequireConfirmation() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = RejectingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"UPDATE users SET email = 'alice-new@example.com' WHERE id = 1",
|
||||||
|
"Updated 1 row."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
allowlist: .standard,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await engine.send("Update Alice's email")
|
||||||
|
|
||||||
|
// No confirmation needed for UPDATE
|
||||||
|
#expect(delegate.confirmCalls.isEmpty)
|
||||||
|
#expect(response.sql?.uppercased().contains("UPDATE") == true)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Tests: MutationPolicy Confirmation Flag
|
||||||
|
|
||||||
|
@Suite("Destructive Operations - MutationPolicy Confirmation Control")
|
||||||
|
struct MutationPolicyConfirmationTests {
|
||||||
|
|
||||||
|
@Test("DELETE skips confirmation when requiresDestructiveConfirmation is false")
|
||||||
|
func deleteSkipsConfirmationWhenDisabled() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let delegate = RejectingTrackingDelegate()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM users WHERE id = 1",
|
||||||
|
"Deleted 1 user."
|
||||||
|
])
|
||||||
|
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update, .delete],
|
||||||
|
requiresDestructiveConfirmation: false // Explicitly disabled
|
||||||
|
)
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
mutationPolicy: policy,
|
||||||
|
delegate: delegate
|
||||||
|
)
|
||||||
|
|
||||||
|
// Should succeed without confirmation since the policy disables it
|
||||||
|
let response = try await engine.send("Delete user 1")
|
||||||
|
|
||||||
|
// Delegate should NOT have been consulted for confirmation
|
||||||
|
#expect(delegate.confirmCalls.isEmpty)
|
||||||
|
|
||||||
|
// But the SQL should have executed
|
||||||
|
#expect(response.sql?.uppercased().contains("DELETE") == true)
|
||||||
|
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
|
||||||
|
}
|
||||||
|
#expect(count == 0, "User should have been deleted without confirmation")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy.requiresConfirmation only triggers for DELETE")
|
||||||
|
func requiresConfirmationOnlyForDelete() {
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update, .delete],
|
||||||
|
requiresDestructiveConfirmation: true
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(policy.requiresConfirmation(for: .delete) == true)
|
||||||
|
#expect(policy.requiresConfirmation(for: .select) == false)
|
||||||
|
#expect(policy.requiresConfirmation(for: .insert) == false)
|
||||||
|
#expect(policy.requiresConfirmation(for: .update) == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy.readOnly never requires confirmation (no delete allowed)")
|
||||||
|
func readOnlyNeverRequiresConfirmation() {
|
||||||
|
let policy = MutationPolicy.readOnly
|
||||||
|
|
||||||
|
#expect(policy.requiresConfirmation(for: .select) == false)
|
||||||
|
#expect(policy.requiresConfirmation(for: .delete) == true) // Would require confirmation IF allowed
|
||||||
|
#expect(policy.isOperationAllowed(.delete) == false) // But it's not allowed at all
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Table-restricted DELETE is blocked for disallowed tables")
|
||||||
|
func tableRestrictedDeleteBlocked() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let model = TestSequentialModel(responses: [
|
||||||
|
"DELETE FROM users WHERE id = 1"
|
||||||
|
])
|
||||||
|
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update, .delete],
|
||||||
|
allowedTables: ["orders"], // Only orders, NOT users
|
||||||
|
requiresDestructiveConfirmation: true
|
||||||
|
)
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: model,
|
||||||
|
mutationPolicy: policy
|
||||||
|
)
|
||||||
|
|
||||||
|
do {
|
||||||
|
_ = try await engine.send("Delete user 1")
|
||||||
|
Issue.record("Expected tableNotAllowedForMutation error")
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
guard case .tableNotAllowedForMutation(let tableName, let operation) = error else {
|
||||||
|
Issue.record("Expected tableNotAllowedForMutation, got: \(error)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
#expect(tableName == "users")
|
||||||
|
#expect(operation == "delete")
|
||||||
|
}
|
||||||
|
|
||||||
|
// User was not deleted
|
||||||
|
let count = try await db.read { db in
|
||||||
|
try Int.fetchOne(db, sql: "SELECT COUNT(*) FROM users WHERE id = 1")
|
||||||
|
}
|
||||||
|
#expect(count == 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
49
Tests/SwiftDBAITests/Helpers/MockLanguageModel.swift
Normal file
49
Tests/SwiftDBAITests/Helpers/MockLanguageModel.swift
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
// MockLanguageModel.swift
|
||||||
|
// SwiftDBAI Tests
|
||||||
|
//
|
||||||
|
// A mock LanguageModel for unit tests that returns canned responses.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
|
||||||
|
/// A mock language model that returns a configurable canned response.
|
||||||
|
///
|
||||||
|
/// Used in tests to avoid hitting a real LLM provider.
|
||||||
|
struct MockLanguageModel: LanguageModel {
|
||||||
|
typealias UnavailableReason = Never
|
||||||
|
|
||||||
|
/// The text the mock will return from `respond(...)`.
|
||||||
|
let responseText: String
|
||||||
|
|
||||||
|
init(responseText: String = "Mock summary response.") {
|
||||||
|
self.responseText = responseText
|
||||||
|
}
|
||||||
|
|
||||||
|
func respond<Content>(
|
||||||
|
within session: LanguageModelSession,
|
||||||
|
to prompt: Prompt,
|
||||||
|
generating type: Content.Type,
|
||||||
|
includeSchemaInPrompt: Bool,
|
||||||
|
options: GenerationOptions
|
||||||
|
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
|
||||||
|
let rawContent = GeneratedContent(kind: .string(responseText))
|
||||||
|
let content = try Content(rawContent)
|
||||||
|
return LanguageModelSession.Response(
|
||||||
|
content: content,
|
||||||
|
rawContent: rawContent,
|
||||||
|
transcriptEntries: [][...]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponse<Content>(
|
||||||
|
within session: LanguageModelSession,
|
||||||
|
to prompt: Prompt,
|
||||||
|
generating type: Content.Type,
|
||||||
|
includeSchemaInPrompt: Bool,
|
||||||
|
options: GenerationOptions
|
||||||
|
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
|
||||||
|
let rawContent = GeneratedContent(kind: .string(responseText))
|
||||||
|
let content = try! Content(rawContent)
|
||||||
|
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
337
Tests/SwiftDBAITests/LocalProviderConfigurationTests.swift
Normal file
337
Tests/SwiftDBAITests/LocalProviderConfigurationTests.swift
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
// LocalProviderConfigurationTests.swift
|
||||||
|
// SwiftDBAI Tests
|
||||||
|
//
|
||||||
|
// Tests for local/self-hosted provider configurations (Ollama, llama.cpp):
|
||||||
|
// factory methods, endpoint discovery, connection handling, and model creation.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
import GRDB
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
import Testing
|
||||||
|
|
||||||
|
@Suite("Local Provider Configuration")
|
||||||
|
struct LocalProviderConfigurationTests {
|
||||||
|
|
||||||
|
// MARK: - Ollama Configuration
|
||||||
|
|
||||||
|
@Test("Ollama configuration stores provider and model")
|
||||||
|
func ollamaBasicConfiguration() {
|
||||||
|
let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||||
|
|
||||||
|
#expect(config.provider == .ollama)
|
||||||
|
#expect(config.model == "llama3.2")
|
||||||
|
#expect(config.baseURL == OllamaLanguageModel.defaultBaseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Ollama configuration produces OllamaLanguageModel")
|
||||||
|
func ollamaMakeModel() {
|
||||||
|
let config = ProviderConfiguration.ollama(model: "qwen2.5")
|
||||||
|
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is OllamaLanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Ollama with custom base URL for remote instance")
|
||||||
|
func ollamaCustomBaseURL() {
|
||||||
|
let remoteURL = URL(string: "http://192.168.1.100:11434")!
|
||||||
|
let config = ProviderConfiguration.ollama(
|
||||||
|
model: "mistral",
|
||||||
|
baseURL: remoteURL
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.baseURL == remoteURL)
|
||||||
|
#expect(config.provider == .ollama)
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is OllamaLanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Ollama does not require an API key")
|
||||||
|
func ollamaNoAPIKey() {
|
||||||
|
let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||||
|
|
||||||
|
// Ollama doesn't need an API key, so the key is empty
|
||||||
|
#expect(config.apiKey == "")
|
||||||
|
// hasValidAPIKey returns false because key is empty, but that's expected
|
||||||
|
// for local providers — they don't need authentication
|
||||||
|
#expect(!config.hasValidAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Ollama model is available without API key")
|
||||||
|
func ollamaModelAvailable() {
|
||||||
|
let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model.isAvailable)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - llama.cpp Configuration
|
||||||
|
|
||||||
|
@Test("llama.cpp configuration stores provider and model")
|
||||||
|
func llamaCppBasicConfiguration() {
|
||||||
|
let config = ProviderConfiguration.llamaCpp(model: "my-model")
|
||||||
|
|
||||||
|
#expect(config.provider == .llamaCpp)
|
||||||
|
#expect(config.model == "my-model")
|
||||||
|
#expect(config.baseURL == LocalProviderDiscovery.defaultLlamaCppURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("llama.cpp uses 'default' model name by default")
|
||||||
|
func llamaCppDefaultModel() {
|
||||||
|
let config = ProviderConfiguration.llamaCpp()
|
||||||
|
|
||||||
|
#expect(config.model == "default")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("llama.cpp configuration produces OpenAILanguageModel (compatible API)")
|
||||||
|
func llamaCppMakeModel() {
|
||||||
|
let config = ProviderConfiguration.llamaCpp(model: "my-gguf")
|
||||||
|
|
||||||
|
let model = config.makeModel()
|
||||||
|
// llama.cpp uses OpenAI-compatible API
|
||||||
|
#expect(model is OpenAILanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("llama.cpp with custom base URL")
|
||||||
|
func llamaCppCustomBaseURL() {
|
||||||
|
let customURL = URL(string: "http://localhost:9090")!
|
||||||
|
let config = ProviderConfiguration.llamaCpp(
|
||||||
|
model: "custom-model",
|
||||||
|
baseURL: customURL
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.baseURL == customURL)
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is OpenAILanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("llama.cpp with API key authentication")
|
||||||
|
func llamaCppWithAPIKey() {
|
||||||
|
let config = ProviderConfiguration.llamaCpp(
|
||||||
|
model: "secured-model",
|
||||||
|
apiKey: "my-secret-key"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.apiKey == "my-secret-key")
|
||||||
|
#expect(config.hasValidAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("llama.cpp without API key")
|
||||||
|
func llamaCppNoAPIKey() {
|
||||||
|
let config = ProviderConfiguration.llamaCpp(model: "open-model")
|
||||||
|
|
||||||
|
#expect(config.apiKey == "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Provider Enum
|
||||||
|
|
||||||
|
@Test("Provider enum includes ollama and llamaCpp cases")
|
||||||
|
func providerEnumHasLocalCases() {
|
||||||
|
let cases = ProviderConfiguration.Provider.allCases
|
||||||
|
#expect(cases.contains(.ollama))
|
||||||
|
#expect(cases.contains(.llamaCpp))
|
||||||
|
// Total: openAI, anthropic, gemini, openAICompatible, ollama, llamaCpp
|
||||||
|
#expect(cases.count == 6)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - fromEnvironment
|
||||||
|
|
||||||
|
@Test("fromEnvironment creates Ollama configuration")
|
||||||
|
func fromEnvironmentOllama() {
|
||||||
|
let config = ProviderConfiguration.fromEnvironment(
|
||||||
|
provider: .ollama,
|
||||||
|
environmentVariable: "NONEXISTENT_OLLAMA_KEY",
|
||||||
|
model: "llama3.2"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.provider == .ollama)
|
||||||
|
#expect(config.model == "llama3.2")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("fromEnvironment creates llama.cpp configuration")
|
||||||
|
func fromEnvironmentLlamaCpp() {
|
||||||
|
let config = ProviderConfiguration.fromEnvironment(
|
||||||
|
provider: .llamaCpp,
|
||||||
|
environmentVariable: "NONEXISTENT_LLAMACPP_KEY",
|
||||||
|
model: "default"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.provider == .llamaCpp)
|
||||||
|
#expect(config.model == "default")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - ChatEngine Convenience Init with Local Providers
|
||||||
|
|
||||||
|
@Test("ChatEngine can be created with Ollama provider")
|
||||||
|
func chatEngineWithOllama() throws {
|
||||||
|
let dbQueue = try GRDB.DatabaseQueue()
|
||||||
|
let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||||
|
let engine = ChatEngine(database: dbQueue, provider: config)
|
||||||
|
|
||||||
|
#expect(engine.tableCount == nil) // schema not yet introspected
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("ChatEngine can be created with llama.cpp provider")
|
||||||
|
func chatEngineWithLlamaCpp() throws {
|
||||||
|
let dbQueue = try GRDB.DatabaseQueue()
|
||||||
|
let config = ProviderConfiguration.llamaCpp()
|
||||||
|
let engine = ChatEngine(database: dbQueue, provider: config)
|
||||||
|
|
||||||
|
#expect(engine.tableCount == nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - LocalProviderType
|
||||||
|
|
||||||
|
@Test("LocalProviderType has expected raw values")
|
||||||
|
func localProviderTypeRawValues() {
|
||||||
|
#expect(LocalProviderType.ollama.rawValue == "ollama")
|
||||||
|
#expect(LocalProviderType.llamaCpp.rawValue == "llama.cpp")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("LocalProviderType CaseIterable includes both cases")
|
||||||
|
func localProviderTypeCases() {
|
||||||
|
let cases = LocalProviderType.allCases
|
||||||
|
#expect(cases.count == 2)
|
||||||
|
#expect(cases.contains(.ollama))
|
||||||
|
#expect(cases.contains(.llamaCpp))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - LocalProviderEndpoint
|
||||||
|
|
||||||
|
@Test("LocalProviderEndpoint description includes status and model count")
|
||||||
|
func endpointDescription() {
|
||||||
|
let endpoint = LocalProviderEndpoint(
|
||||||
|
baseURL: URL(string: "http://localhost:11434")!,
|
||||||
|
providerType: .ollama,
|
||||||
|
isReachable: true,
|
||||||
|
availableModels: ["llama3.2", "qwen2.5"]
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(endpoint.description.contains("ollama"))
|
||||||
|
#expect(endpoint.description.contains("reachable"))
|
||||||
|
#expect(endpoint.description.contains("2 models"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("LocalProviderEndpoint shows unreachable when not connected")
|
||||||
|
func endpointUnreachableDescription() {
|
||||||
|
let endpoint = LocalProviderEndpoint(
|
||||||
|
baseURL: URL(string: "http://localhost:8080")!,
|
||||||
|
providerType: .llamaCpp,
|
||||||
|
isReachable: false,
|
||||||
|
availableModels: []
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(endpoint.description.contains("unreachable"))
|
||||||
|
#expect(endpoint.description.contains("0 models"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("LocalProviderEndpoint equality works correctly")
|
||||||
|
func endpointEquality() {
|
||||||
|
let a = LocalProviderEndpoint(
|
||||||
|
baseURL: URL(string: "http://localhost:11434")!,
|
||||||
|
providerType: .ollama,
|
||||||
|
isReachable: true,
|
||||||
|
availableModels: ["llama3.2"]
|
||||||
|
)
|
||||||
|
let b = LocalProviderEndpoint(
|
||||||
|
baseURL: URL(string: "http://localhost:11434")!,
|
||||||
|
providerType: .ollama,
|
||||||
|
isReachable: true,
|
||||||
|
availableModels: ["llama3.2"]
|
||||||
|
)
|
||||||
|
let c = LocalProviderEndpoint(
|
||||||
|
baseURL: URL(string: "http://localhost:11434")!,
|
||||||
|
providerType: .ollama,
|
||||||
|
isReachable: false,
|
||||||
|
availableModels: []
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(a == b)
|
||||||
|
#expect(a != c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Discovery (No Local Server Running)
|
||||||
|
|
||||||
|
@Test("Discovery returns unreachable when no server is running")
|
||||||
|
func discoveryUnreachableEndpoint() async {
|
||||||
|
// Use a port that's almost certainly not running anything
|
||||||
|
let endpoint = await LocalProviderDiscovery.discover(
|
||||||
|
providerType: .ollama,
|
||||||
|
host: "127.0.0.1",
|
||||||
|
port: 59999,
|
||||||
|
timeout: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(!endpoint.isReachable)
|
||||||
|
#expect(endpoint.availableModels.isEmpty)
|
||||||
|
#expect(endpoint.providerType == .ollama)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("isOllamaRunning returns false for unreachable endpoint")
|
||||||
|
func ollamaNotRunning() async {
|
||||||
|
let unreachableURL = URL(string: "http://127.0.0.1:59998")!
|
||||||
|
let running = await LocalProviderDiscovery.isOllamaRunning(
|
||||||
|
at: unreachableURL,
|
||||||
|
timeout: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(!running)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("isLlamaCppRunning returns false for unreachable endpoint")
|
||||||
|
func llamaCppNotRunning() async {
|
||||||
|
let unreachableURL = URL(string: "http://127.0.0.1:59997")!
|
||||||
|
let running = await LocalProviderDiscovery.isLlamaCppRunning(
|
||||||
|
at: unreachableURL,
|
||||||
|
timeout: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(!running)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("listOllamaModels returns empty for unreachable endpoint")
|
||||||
|
func ollamaModelsUnreachable() async {
|
||||||
|
let unreachableURL = URL(string: "http://127.0.0.1:59996")!
|
||||||
|
let models = await LocalProviderDiscovery.listOllamaModels(
|
||||||
|
at: unreachableURL,
|
||||||
|
timeout: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(models.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("listLlamaCppModels returns empty for unreachable endpoint")
|
||||||
|
func llamaCppModelsUnreachable() async {
|
||||||
|
let unreachableURL = URL(string: "http://127.0.0.1:59995")!
|
||||||
|
let models = await LocalProviderDiscovery.listLlamaCppModels(
|
||||||
|
at: unreachableURL,
|
||||||
|
timeout: 1
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(models.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("discoverAll returns endpoints for both provider types")
|
||||||
|
func discoverAllReturnsAllProviders() async {
|
||||||
|
// Use very short timeout since we likely don't have servers running
|
||||||
|
let endpoints = await LocalProviderDiscovery.discoverAll(timeout: 0.5)
|
||||||
|
|
||||||
|
// Should return exactly 2 endpoints (one per well-known provider)
|
||||||
|
#expect(endpoints.count == 2)
|
||||||
|
|
||||||
|
let types = Set(endpoints.map(\.providerType))
|
||||||
|
#expect(types.contains(.ollama))
|
||||||
|
#expect(types.contains(.llamaCpp))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Default URLs
|
||||||
|
|
||||||
|
@Test("Default Ollama URL is correct")
|
||||||
|
func defaultOllamaURL() {
|
||||||
|
#expect(LocalProviderDiscovery.defaultOllamaURL.absoluteString == "http://localhost:11434")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Default llama.cpp URL is correct")
|
||||||
|
func defaultLlamaCppURL() {
|
||||||
|
#expect(LocalProviderDiscovery.defaultLlamaCppURL.absoluteString == "http://localhost:8080")
|
||||||
|
}
|
||||||
|
}
|
||||||
363
Tests/SwiftDBAITests/MultiTurnContextTests.swift
Normal file
363
Tests/SwiftDBAITests/MultiTurnContextTests.swift
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
// MultiTurnContextTests.swift
|
||||||
|
// SwiftDBAI Tests
|
||||||
|
//
|
||||||
|
// Tests verifying multi-turn conversation context — follow-up queries
|
||||||
|
// correctly reference the prior query's table, columns, and results.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
import GRDB
|
||||||
|
import Testing
|
||||||
|
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
@Suite("Multi-Turn Context Tests")
|
||||||
|
struct MultiTurnContextTests {
|
||||||
|
|
||||||
|
// MARK: - Test Database Setup
|
||||||
|
|
||||||
|
/// Creates an in-memory database with users (including age) and orders.
|
||||||
|
private func makeTestDatabase() throws -> DatabaseQueue {
|
||||||
|
let db = try DatabaseQueue(path: ":memory:")
|
||||||
|
try db.write { db in
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE users (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
age INTEGER NOT NULL,
|
||||||
|
email TEXT NOT NULL,
|
||||||
|
city TEXT NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try db.execute(sql: """
|
||||||
|
INSERT INTO users (name, age, email, city) VALUES
|
||||||
|
('Alice', 25, 'alice@example.com', 'New York'),
|
||||||
|
('Bob', 35, 'bob@example.com', 'San Francisco'),
|
||||||
|
('Charlie', 42, 'charlie@example.com', 'New York'),
|
||||||
|
('Diana', 28, 'diana@example.com', 'Chicago'),
|
||||||
|
('Eve', 55, 'eve@example.com', 'San Francisco')
|
||||||
|
""")
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE orders (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
user_id INTEGER NOT NULL,
|
||||||
|
amount REAL NOT NULL,
|
||||||
|
status TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try db.execute(sql: """
|
||||||
|
INSERT INTO orders (user_id, amount, status, created_at) VALUES
|
||||||
|
(1, 99.99, 'completed', '2024-01-15'),
|
||||||
|
(1, 49.50, 'pending', '2024-02-20'),
|
||||||
|
(2, 150.00, 'completed', '2024-01-10'),
|
||||||
|
(3, 200.00, 'completed', '2024-03-01'),
|
||||||
|
(4, 75.00, 'cancelled', '2024-02-05')
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Multi-Turn Context Tests
|
||||||
|
|
||||||
|
@Test("Follow-up 'filter those by age > 30' references prior 'show all users' context")
|
||||||
|
func followUpFilterReferencesUsersTable() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
|
||||||
|
// Turn 1: "show all users" → SELECT * FROM users (returns 5 rows, LLM summary needed)
|
||||||
|
// Turn 2: "filter those by age > 30" → should reference users table from context
|
||||||
|
let mock = PromptCapturingMockModel(responses: [
|
||||||
|
"SELECT * FROM users",
|
||||||
|
"Here are all 5 users in the database.",
|
||||||
|
"SELECT * FROM users WHERE age > 30",
|
||||||
|
"Found 3 users over 30: Bob (35), Charlie (42), and Eve (55)."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
// First turn: show all users
|
||||||
|
let response1 = try await engine.send("show all users")
|
||||||
|
#expect(response1.sql == "SELECT * FROM users")
|
||||||
|
#expect(response1.queryResult?.rowCount == 5)
|
||||||
|
|
||||||
|
// Second turn: follow-up with implicit reference
|
||||||
|
let response2 = try await engine.send("filter those by age > 30")
|
||||||
|
#expect(response2.sql == "SELECT * FROM users WHERE age > 30")
|
||||||
|
#expect(response2.queryResult?.rowCount == 3)
|
||||||
|
|
||||||
|
// Verify the follow-up prompt includes conversation history
|
||||||
|
let prompts = mock.capturedPrompts
|
||||||
|
// Find the prompt for the second SQL generation (skip summary prompts)
|
||||||
|
let followUpSQLPrompt = prompts.first { prompt in
|
||||||
|
prompt.contains("filter those by age > 30") && prompt.contains("CONVERSATION HISTORY")
|
||||||
|
}
|
||||||
|
#expect(followUpSQLPrompt != nil, "Follow-up prompt should contain CONVERSATION HISTORY")
|
||||||
|
|
||||||
|
// The conversation history should include the prior query and its SQL
|
||||||
|
if let prompt = followUpSQLPrompt {
|
||||||
|
#expect(prompt.contains("show all users"), "History should contain prior user message")
|
||||||
|
#expect(prompt.contains("SELECT * FROM users"), "History should contain prior SQL")
|
||||||
|
#expect(prompt.contains("filter those by age > 30"), "Prompt should contain current question")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Follow-up correctly inherits table context across multiple turns")
|
||||||
|
func multipleFollowUpsInheritContext() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
|
||||||
|
// 3-turn conversation narrowing down results
|
||||||
|
let mock = PromptCapturingMockModel(responses: [
|
||||||
|
"SELECT * FROM users",
|
||||||
|
"Here are all 5 users.",
|
||||||
|
"SELECT * FROM users WHERE city = 'New York'",
|
||||||
|
"Found 2 users in New York: Alice and Charlie.",
|
||||||
|
"SELECT * FROM users WHERE city = 'New York' AND age > 30",
|
||||||
|
"Charlie (42) is the only New York user over 30."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
// Turn 1
|
||||||
|
_ = try await engine.send("show all users")
|
||||||
|
|
||||||
|
// Turn 2 — narrows by city
|
||||||
|
let response2 = try await engine.send("only those in New York")
|
||||||
|
#expect(response2.sql == "SELECT * FROM users WHERE city = 'New York'")
|
||||||
|
#expect(response2.queryResult?.rowCount == 2)
|
||||||
|
|
||||||
|
// Turn 3 — further narrows by age
|
||||||
|
let response3 = try await engine.send("now filter by age over 30")
|
||||||
|
#expect(response3.sql == "SELECT * FROM users WHERE city = 'New York' AND age > 30")
|
||||||
|
#expect(response3.queryResult?.rowCount == 1)
|
||||||
|
|
||||||
|
// Verify third turn's prompt includes the full conversation history
|
||||||
|
let prompts = mock.capturedPrompts
|
||||||
|
let thirdTurnPrompt = prompts.last { prompt in
|
||||||
|
prompt.contains("now filter by age over 30") && prompt.contains("CONVERSATION HISTORY")
|
||||||
|
}
|
||||||
|
#expect(thirdTurnPrompt != nil)
|
||||||
|
|
||||||
|
if let prompt = thirdTurnPrompt {
|
||||||
|
// Should include both prior user messages
|
||||||
|
#expect(prompt.contains("show all users"))
|
||||||
|
#expect(prompt.contains("only those in New York"))
|
||||||
|
// Should include prior SQL
|
||||||
|
#expect(prompt.contains("SELECT * FROM users"))
|
||||||
|
#expect(prompt.contains("SELECT * FROM users WHERE city = 'New York'"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Follow-up switching tables preserves cross-table context")
|
||||||
|
func followUpSwitchesTableWithContext() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
|
||||||
|
// Turn 1: query users, Turn 2: ask about their orders
|
||||||
|
let mock = PromptCapturingMockModel(responses: [
|
||||||
|
"SELECT name, age FROM users WHERE age > 30",
|
||||||
|
"Found 3 users over 30.",
|
||||||
|
"SELECT o.id, u.name, o.amount, o.status FROM orders o JOIN users u ON o.user_id = u.id WHERE u.age > 30",
|
||||||
|
"Bob has a $150 completed order, Charlie has a $200 completed order."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
// Turn 1: users over 30
|
||||||
|
let response1 = try await engine.send("show users over 30")
|
||||||
|
#expect(response1.queryResult?.rowCount == 3)
|
||||||
|
|
||||||
|
// Turn 2: their orders — references the previous result context
|
||||||
|
let response2 = try await engine.send("show their orders")
|
||||||
|
#expect(response2.sql?.contains("JOIN") == true)
|
||||||
|
|
||||||
|
// Verify the follow-up prompt contains the users context
|
||||||
|
let prompts = mock.capturedPrompts
|
||||||
|
let orderPrompt = prompts.first { prompt in
|
||||||
|
prompt.contains("show their orders") && prompt.contains("CONVERSATION HISTORY")
|
||||||
|
}
|
||||||
|
#expect(orderPrompt != nil)
|
||||||
|
|
||||||
|
if let prompt = orderPrompt {
|
||||||
|
#expect(prompt.contains("show users over 30"), "Should contain prior user message")
|
||||||
|
#expect(prompt.contains("age > 30"), "Should contain prior SQL context for table reference")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Conversation history includes SQL from prior turns for context")
|
||||||
|
func historyIncludesSQLFromPriorTurns() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
|
||||||
|
// Both queries are aggregates → no LLM summarization needed
|
||||||
|
let mock = PromptCapturingMockModel(responses: [
|
||||||
|
"SELECT COUNT(*) FROM users",
|
||||||
|
"SELECT COUNT(*) FROM users WHERE age > 30",
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
// Turn 1
|
||||||
|
let r1 = try await engine.send("how many users are there?")
|
||||||
|
#expect(r1.sql == "SELECT COUNT(*) FROM users")
|
||||||
|
|
||||||
|
// Turn 2 — references "those" implicitly
|
||||||
|
let r2 = try await engine.send("how many of those are over 30?")
|
||||||
|
#expect(r2.sql == "SELECT COUNT(*) FROM users WHERE age > 30")
|
||||||
|
|
||||||
|
// Verify engine history has all 4 messages (2 user + 2 assistant)
|
||||||
|
let messages = engine.messages
|
||||||
|
#expect(messages.count == 4)
|
||||||
|
#expect(messages[0].role == .user)
|
||||||
|
#expect(messages[0].content == "how many users are there?")
|
||||||
|
#expect(messages[1].role == .assistant)
|
||||||
|
#expect(messages[1].sql == "SELECT COUNT(*) FROM users")
|
||||||
|
#expect(messages[2].role == .user)
|
||||||
|
#expect(messages[2].content == "how many of those are over 30?")
|
||||||
|
#expect(messages[3].role == .assistant)
|
||||||
|
#expect(messages[3].sql == "SELECT COUNT(*) FROM users WHERE age > 30")
|
||||||
|
|
||||||
|
// The second prompt should reference the first query SQL
|
||||||
|
let prompts = mock.capturedPrompts
|
||||||
|
#expect(prompts.count >= 2)
|
||||||
|
let secondPrompt = prompts[1]
|
||||||
|
#expect(secondPrompt.contains("CONVERSATION HISTORY"))
|
||||||
|
#expect(secondPrompt.contains("SELECT COUNT(*) FROM users"))
|
||||||
|
#expect(secondPrompt.contains("how many users are there?"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Follow-up after aggregate uses prior table context")
|
||||||
|
func followUpAfterAggregateUsesTableContext() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
|
||||||
|
// Turn 1: aggregate (no LLM summary needed)
|
||||||
|
// Turn 2: follow-up referencing "those"
|
||||||
|
let mock = PromptCapturingMockModel(responses: [
|
||||||
|
"SELECT AVG(age) FROM users",
|
||||||
|
"SELECT name, age FROM users WHERE age > 35",
|
||||||
|
"Charlie (42) and Eve (55) are older than average."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
// Turn 1: average age → aggregate, template summary
|
||||||
|
let r1 = try await engine.send("what is the average age of users?")
|
||||||
|
#expect(r1.sql == "SELECT AVG(age) FROM users")
|
||||||
|
|
||||||
|
// Turn 2: "who is above that?" — needs the avg context
|
||||||
|
let r2 = try await engine.send("who is above average?")
|
||||||
|
#expect(r2.queryResult?.rowCount == 2)
|
||||||
|
|
||||||
|
// Verify context passed
|
||||||
|
let prompts = mock.capturedPrompts
|
||||||
|
let followUp = prompts.first { prompt in
|
||||||
|
prompt.contains("who is above average?") && prompt.contains("CONVERSATION HISTORY")
|
||||||
|
}
|
||||||
|
#expect(followUp != nil)
|
||||||
|
if let prompt = followUp {
|
||||||
|
#expect(prompt.contains("AVG(age)"), "Should include prior aggregate SQL for context")
|
||||||
|
#expect(prompt.contains("users"), "Should include table reference from prior turn")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Context window limits how much history is visible in follow-ups")
|
||||||
|
func contextWindowLimitsHistoryInFollowUps() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
|
||||||
|
// 3 turns, but context window of 2 messages
|
||||||
|
let mock = PromptCapturingMockModel(responses: [
|
||||||
|
"SELECT COUNT(*) FROM users",
|
||||||
|
"SELECT COUNT(*) FROM orders",
|
||||||
|
"SELECT COUNT(*) FROM users WHERE age > 30",
|
||||||
|
])
|
||||||
|
|
||||||
|
let config = ChatEngineConfiguration(
|
||||||
|
queryTimeout: nil,
|
||||||
|
contextWindowSize: 2
|
||||||
|
)
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: mock,
|
||||||
|
configuration: config
|
||||||
|
)
|
||||||
|
|
||||||
|
_ = try await engine.send("how many users?")
|
||||||
|
_ = try await engine.send("how many orders?")
|
||||||
|
_ = try await engine.send("how many users over 30?")
|
||||||
|
|
||||||
|
// The third prompt should only have the last 2 messages from turn 2
|
||||||
|
let prompts = mock.capturedPrompts
|
||||||
|
#expect(prompts.count >= 3)
|
||||||
|
|
||||||
|
let thirdPrompt = prompts[2]
|
||||||
|
#expect(thirdPrompt.contains("CONVERSATION HISTORY"))
|
||||||
|
// Turn 2 context should be present
|
||||||
|
#expect(thirdPrompt.contains("how many orders?"))
|
||||||
|
#expect(thirdPrompt.contains("SELECT COUNT(*) FROM orders"))
|
||||||
|
// Turn 1 context should be trimmed (window=2 means last 2 messages)
|
||||||
|
#expect(!thirdPrompt.contains("how many users?\n"), "First turn should be trimmed from context window")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("clearHistory resets context so follow-ups have no prior history")
|
||||||
|
func clearHistoryResetsFollowUpContext() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
|
||||||
|
let mock = PromptCapturingMockModel(responses: [
|
||||||
|
"SELECT * FROM users",
|
||||||
|
"Here are the 5 users.",
|
||||||
|
"SELECT COUNT(*) FROM users",
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
// Turn 1
|
||||||
|
_ = try await engine.send("show all users")
|
||||||
|
#expect(engine.messages.count == 2)
|
||||||
|
|
||||||
|
// Clear history
|
||||||
|
engine.clearHistory()
|
||||||
|
#expect(engine.messages.isEmpty)
|
||||||
|
|
||||||
|
// Turn 2 after clear — should NOT have conversation history
|
||||||
|
_ = try await engine.send("count all users")
|
||||||
|
|
||||||
|
let prompts = mock.capturedPrompts
|
||||||
|
let lastPrompt = prompts.last!
|
||||||
|
// After clearing, the prompt should NOT contain conversation history
|
||||||
|
#expect(!lastPrompt.contains("CONVERSATION HISTORY"),
|
||||||
|
"After clearHistory(), follow-up should not have prior context")
|
||||||
|
#expect(!lastPrompt.contains("show all users"),
|
||||||
|
"After clearHistory(), prior messages should be gone")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Multi-turn with result data in context enables informed follow-ups")
|
||||||
|
func resultDataInContextEnablesInformedFollowUps() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
|
||||||
|
// Turn 1: list users → multi-row result, LLM summarizes
|
||||||
|
// Turn 2: "sort those by age" → references same table
|
||||||
|
let mock = PromptCapturingMockModel(responses: [
|
||||||
|
"SELECT name, age, city FROM users",
|
||||||
|
"Found 5 users: Alice (25, NY), Bob (35, SF), Charlie (42, NY), Diana (28, Chicago), Eve (55, SF).",
|
||||||
|
"SELECT name, age, city FROM users ORDER BY age DESC",
|
||||||
|
"Users sorted by age: Eve (55), Charlie (42), Bob (35), Diana (28), Alice (25)."
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
let r1 = try await engine.send("list all users with their age and city")
|
||||||
|
#expect(r1.queryResult?.rowCount == 5)
|
||||||
|
#expect(r1.queryResult?.columns.contains("age") == true)
|
||||||
|
#expect(r1.queryResult?.columns.contains("city") == true)
|
||||||
|
|
||||||
|
let r2 = try await engine.send("sort those by age descending")
|
||||||
|
#expect(r2.sql == "SELECT name, age, city FROM users ORDER BY age DESC")
|
||||||
|
|
||||||
|
// Verify the assistant message in history includes the SQL
|
||||||
|
let messages = engine.messages
|
||||||
|
#expect(messages.count == 4)
|
||||||
|
// First assistant message should have the SQL recorded
|
||||||
|
#expect(messages[1].sql == "SELECT name, age, city FROM users")
|
||||||
|
// Second assistant should have the sorted SQL
|
||||||
|
#expect(messages[3].sql == "SELECT name, age, city FROM users ORDER BY age DESC")
|
||||||
|
}
|
||||||
|
}
|
||||||
508
Tests/SwiftDBAITests/OnDeviceProviderConfigurationTests.swift
Normal file
508
Tests/SwiftDBAITests/OnDeviceProviderConfigurationTests.swift
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
// OnDeviceProviderConfigurationTests.swift
|
||||||
|
// SwiftDBAI Tests
|
||||||
|
//
|
||||||
|
// Tests for on-device provider configurations (CoreML, MLX) including
|
||||||
|
// configuration validation, inference pipeline setup, and system readiness.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
import Testing
|
||||||
|
|
||||||
|
@Suite("OnDeviceProviderConfiguration")
|
||||||
|
struct OnDeviceProviderConfigurationTests {
|
||||||
|
|
||||||
|
// MARK: - OnDeviceProviderType
|
||||||
|
|
||||||
|
@Test("OnDeviceProviderType has CoreML and MLX cases")
|
||||||
|
func providerTypeCases() {
|
||||||
|
let cases = OnDeviceProviderType.allCases
|
||||||
|
#expect(cases.count == 2)
|
||||||
|
#expect(cases.contains(.coreML))
|
||||||
|
#expect(cases.contains(.mlx))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("OnDeviceProviderType raw values are descriptive")
|
||||||
|
func providerTypeRawValues() {
|
||||||
|
#expect(OnDeviceProviderType.coreML.rawValue == "coreML")
|
||||||
|
#expect(OnDeviceProviderType.mlx.rawValue == "mlx")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - CoreML Configuration
|
||||||
|
|
||||||
|
@Test("CoreML configuration stores all properties")
|
||||||
|
func coreMLBasicConfiguration() {
|
||||||
|
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
|
||||||
|
let config = CoreMLProviderConfiguration(
|
||||||
|
modelURL: url,
|
||||||
|
computeUnits: .cpuAndGPU,
|
||||||
|
maxResponseTokens: 1024,
|
||||||
|
useSampling: true,
|
||||||
|
temperature: 0.3
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.modelURL == url)
|
||||||
|
#expect(config.computeUnits == .cpuAndGPU)
|
||||||
|
#expect(config.maxResponseTokens == 1024)
|
||||||
|
#expect(config.useSampling == true)
|
||||||
|
#expect(config.temperature == 0.3)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("CoreML configuration uses sensible defaults")
|
||||||
|
func coreMLDefaultConfiguration() {
|
||||||
|
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
|
||||||
|
let config = CoreMLProviderConfiguration(modelURL: url)
|
||||||
|
|
||||||
|
#expect(config.computeUnits == .all)
|
||||||
|
#expect(config.maxResponseTokens == 2048)
|
||||||
|
#expect(config.useSampling == false)
|
||||||
|
#expect(config.temperature == 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("CoreML validation fails for non-mlmodelc extension")
|
||||||
|
func coreMLValidateWrongExtension() {
|
||||||
|
let url = URL(fileURLWithPath: "/tmp/TestModel.onnx")
|
||||||
|
let config = CoreMLProviderConfiguration(modelURL: url)
|
||||||
|
|
||||||
|
#expect(throws: OnDeviceProviderError.self) {
|
||||||
|
try config.validate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("CoreML validation fails for missing model file")
|
||||||
|
func coreMLValidateMissingFile() {
|
||||||
|
let url = URL(fileURLWithPath: "/nonexistent/path/Model.mlmodelc")
|
||||||
|
let config = CoreMLProviderConfiguration(modelURL: url)
|
||||||
|
|
||||||
|
#expect(throws: OnDeviceProviderError.self) {
|
||||||
|
try config.validate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("CoreML configuration is Equatable")
|
||||||
|
func coreMLEquatable() {
|
||||||
|
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
|
||||||
|
let a = CoreMLProviderConfiguration(modelURL: url, computeUnits: .all)
|
||||||
|
let b = CoreMLProviderConfiguration(modelURL: url, computeUnits: .all)
|
||||||
|
let c = CoreMLProviderConfiguration(modelURL: url, computeUnits: .cpuOnly)
|
||||||
|
|
||||||
|
#expect(a == b)
|
||||||
|
#expect(a != c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - ComputeUnitPreference
|
||||||
|
|
||||||
|
@Test("ComputeUnitPreference has all expected cases")
|
||||||
|
func computeUnitCases() {
|
||||||
|
let cases = ComputeUnitPreference.allCases
|
||||||
|
#expect(cases.count == 4)
|
||||||
|
#expect(cases.contains(.all))
|
||||||
|
#expect(cases.contains(.cpuOnly))
|
||||||
|
#expect(cases.contains(.cpuAndGPU))
|
||||||
|
#expect(cases.contains(.cpuAndNeuralEngine))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - MLX Configuration
|
||||||
|
|
||||||
|
@Test("MLX configuration stores all properties")
|
||||||
|
func mlxBasicConfiguration() {
|
||||||
|
let dir = URL(fileURLWithPath: "/tmp/models/my-model")
|
||||||
|
let config = MLXProviderConfiguration(
|
||||||
|
modelId: "mlx-community/Test-Model-4bit",
|
||||||
|
localDirectory: dir,
|
||||||
|
gpuMemory: .minimal,
|
||||||
|
maxResponseTokens: 512,
|
||||||
|
temperature: 0.2,
|
||||||
|
topP: 0.9,
|
||||||
|
repetitionPenalty: 1.2
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.modelId == "mlx-community/Test-Model-4bit")
|
||||||
|
#expect(config.localDirectory == dir)
|
||||||
|
#expect(config.gpuMemory == .minimal)
|
||||||
|
#expect(config.maxResponseTokens == 512)
|
||||||
|
#expect(config.temperature == 0.2)
|
||||||
|
#expect(config.topP == 0.9)
|
||||||
|
#expect(config.repetitionPenalty == 1.2)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MLX configuration uses sensible defaults")
|
||||||
|
func mlxDefaultConfiguration() {
|
||||||
|
let config = MLXProviderConfiguration(modelId: "test-model")
|
||||||
|
|
||||||
|
#expect(config.localDirectory == nil)
|
||||||
|
#expect(config.gpuMemory == .automatic)
|
||||||
|
#expect(config.maxResponseTokens == 2048)
|
||||||
|
#expect(config.temperature == 0.1)
|
||||||
|
#expect(config.topP == 0.95)
|
||||||
|
#expect(config.repetitionPenalty == 1.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MLX validation fails for empty model ID")
|
||||||
|
func mlxValidateEmptyModelId() {
|
||||||
|
let config = MLXProviderConfiguration(modelId: "")
|
||||||
|
|
||||||
|
#expect(throws: OnDeviceProviderError.self) {
|
||||||
|
try config.validate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MLX validation fails for nonexistent local directory")
|
||||||
|
func mlxValidateMissingDirectory() {
|
||||||
|
let config = MLXProviderConfiguration(
|
||||||
|
modelId: "test-model",
|
||||||
|
localDirectory: URL(fileURLWithPath: "/nonexistent/directory")
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(throws: OnDeviceProviderError.self) {
|
||||||
|
try config.validate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MLX validation fails for negative temperature")
|
||||||
|
func mlxValidateNegativeTemperature() {
|
||||||
|
let config = MLXProviderConfiguration(
|
||||||
|
modelId: "test-model",
|
||||||
|
temperature: -0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(throws: OnDeviceProviderError.self) {
|
||||||
|
try config.validate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MLX validation fails for topP out of range")
|
||||||
|
func mlxValidateInvalidTopP() {
|
||||||
|
let configZero = MLXProviderConfiguration(
|
||||||
|
modelId: "test-model",
|
||||||
|
topP: 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(throws: OnDeviceProviderError.self) {
|
||||||
|
try configZero.validate()
|
||||||
|
}
|
||||||
|
|
||||||
|
let configOver = MLXProviderConfiguration(
|
||||||
|
modelId: "test-model",
|
||||||
|
topP: 1.5
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(throws: OnDeviceProviderError.self) {
|
||||||
|
try configOver.validate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MLX validation fails for zero repetition penalty")
|
||||||
|
func mlxValidateInvalidRepetitionPenalty() {
|
||||||
|
let config = MLXProviderConfiguration(
|
||||||
|
modelId: "test-model",
|
||||||
|
repetitionPenalty: 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(throws: OnDeviceProviderError.self) {
|
||||||
|
try config.validate()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MLX validation succeeds for valid configuration")
|
||||||
|
func mlxValidateSuccess() throws {
|
||||||
|
let config = MLXProviderConfiguration(modelId: "test-model")
|
||||||
|
// Should not throw (no local directory set, model ID is non-empty)
|
||||||
|
try config.validate()
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MLX configuration is Equatable")
|
||||||
|
func mlxEquatable() {
|
||||||
|
let a = MLXProviderConfiguration(modelId: "model-a")
|
||||||
|
let b = MLXProviderConfiguration(modelId: "model-a")
|
||||||
|
let c = MLXProviderConfiguration(modelId: "model-b")
|
||||||
|
|
||||||
|
#expect(a == b)
|
||||||
|
#expect(a != c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Well-Known MLX Models
|
||||||
|
|
||||||
|
@Test("Llama 3.2 3B preset has correct model ID")
|
||||||
|
func llama3_2_3BPreset() {
|
||||||
|
let config = MLXProviderConfiguration.llama3_2_3B()
|
||||||
|
#expect(config.modelId == "mlx-community/Llama-3.2-3B-Instruct-4bit")
|
||||||
|
#expect(config.temperature == 0.1)
|
||||||
|
#expect(config.maxResponseTokens == 2048)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Qwen 2.5 Coder 3B preset has correct model ID")
|
||||||
|
func qwen2_5_coder3BPreset() {
|
||||||
|
let config = MLXProviderConfiguration.qwen2_5_coder_3B()
|
||||||
|
#expect(config.modelId == "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit")
|
||||||
|
#expect(config.temperature == 0.05)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Phi 3.5 Mini preset has correct model ID")
|
||||||
|
func phi3_5_miniPreset() {
|
||||||
|
let config = MLXProviderConfiguration.phi3_5_mini()
|
||||||
|
#expect(config.modelId == "mlx-community/Phi-3.5-mini-instruct-4bit")
|
||||||
|
#expect(config.temperature == 0.1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Well-known models accept custom GPU memory config")
|
||||||
|
func wellKnownModelsCustomGPU() {
|
||||||
|
let config = MLXProviderConfiguration.llama3_2_3B(
|
||||||
|
gpuMemory: .minimal
|
||||||
|
)
|
||||||
|
#expect(config.gpuMemory == .minimal)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - GPU Memory Configuration
|
||||||
|
|
||||||
|
@Test("Automatic GPU memory config scales with RAM")
|
||||||
|
func automaticGPUMemory() {
|
||||||
|
let config = MLXGPUMemoryConfig.automatic
|
||||||
|
#expect(config.activeCacheLimit > 0)
|
||||||
|
#expect(config.idleCacheLimit == 50_000_000)
|
||||||
|
#expect(config.clearCacheOnEviction == true)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Minimal GPU memory config is conservative")
|
||||||
|
func minimalGPUMemory() {
|
||||||
|
let config = MLXGPUMemoryConfig.minimal
|
||||||
|
#expect(config.activeCacheLimit == 64_000_000)
|
||||||
|
#expect(config.idleCacheLimit == 16_000_000)
|
||||||
|
#expect(config.clearCacheOnEviction == true)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Unconstrained GPU memory config uses max values")
|
||||||
|
func unconstrainedGPUMemory() {
|
||||||
|
let config = MLXGPUMemoryConfig.unconstrained
|
||||||
|
#expect(config.activeCacheLimit == Int.max)
|
||||||
|
#expect(config.idleCacheLimit == Int.max)
|
||||||
|
#expect(config.clearCacheOnEviction == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("GPU memory config is Equatable")
|
||||||
|
func gpuMemoryEquatable() {
|
||||||
|
#expect(MLXGPUMemoryConfig.minimal == MLXGPUMemoryConfig.minimal)
|
||||||
|
#expect(MLXGPUMemoryConfig.minimal != MLXGPUMemoryConfig.unconstrained)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - On-Device Provider Errors
|
||||||
|
|
||||||
|
@Test("OnDeviceProviderError has descriptive messages")
|
||||||
|
func errorDescriptions() {
|
||||||
|
let errors: [OnDeviceProviderError] = [
|
||||||
|
.modelNotFound(URL(fileURLWithPath: "/tmp/model")),
|
||||||
|
.invalidModelFormat(expected: ".mlmodelc", actual: ".onnx"),
|
||||||
|
.emptyModelId,
|
||||||
|
.invalidParameter(name: "temperature", value: "-1", reason: "Must be non-negative"),
|
||||||
|
.providerUnavailable(.mlx, reason: "MLX build flag not enabled"),
|
||||||
|
.modelLoadFailed(reason: "Out of memory"),
|
||||||
|
.inferenceFailed(reason: "Token limit exceeded"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for error in errors {
|
||||||
|
#expect(error.errorDescription != nil)
|
||||||
|
#expect(!error.errorDescription!.isEmpty)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("OnDeviceProviderError is Equatable")
|
||||||
|
func errorEquatable() {
|
||||||
|
let a = OnDeviceProviderError.emptyModelId
|
||||||
|
let b = OnDeviceProviderError.emptyModelId
|
||||||
|
let c = OnDeviceProviderError.modelLoadFailed(reason: "test")
|
||||||
|
|
||||||
|
#expect(a == b)
|
||||||
|
#expect(a != c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Inference Pipeline
|
||||||
|
|
||||||
|
@Test("MLX inference pipeline initializes with correct type")
|
||||||
|
func mlxPipelineInit() {
|
||||||
|
let config = MLXProviderConfiguration.llama3_2_3B()
|
||||||
|
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
|
||||||
|
|
||||||
|
#expect(pipeline.providerType == .mlx)
|
||||||
|
#expect(pipeline.mlxConfiguration != nil)
|
||||||
|
#expect(pipeline.coreMLConfiguration == nil)
|
||||||
|
#expect(pipeline.status == .notLoaded)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("CoreML inference pipeline initializes with correct type")
|
||||||
|
func coreMLPipelineInit() {
|
||||||
|
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
|
||||||
|
let config = CoreMLProviderConfiguration(modelURL: url)
|
||||||
|
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
|
||||||
|
|
||||||
|
#expect(pipeline.providerType == .coreML)
|
||||||
|
#expect(pipeline.coreMLConfiguration != nil)
|
||||||
|
#expect(pipeline.mlxConfiguration == nil)
|
||||||
|
#expect(pipeline.status == .notLoaded)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Pipeline validates MLX configuration")
|
||||||
|
func pipelineValidatesMLX() throws {
|
||||||
|
let validConfig = MLXProviderConfiguration(modelId: "test-model")
|
||||||
|
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: validConfig)
|
||||||
|
try pipeline.validateConfiguration()
|
||||||
|
|
||||||
|
let invalidConfig = MLXProviderConfiguration(modelId: "")
|
||||||
|
let invalidPipeline = OnDeviceInferencePipeline(mlxConfiguration: invalidConfig)
|
||||||
|
#expect(throws: OnDeviceProviderError.self) {
|
||||||
|
try invalidPipeline.validateConfiguration()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Pipeline validates CoreML configuration")
|
||||||
|
func pipelineValidatesCoreML() {
|
||||||
|
let url = URL(fileURLWithPath: "/tmp/TestModel.onnx")
|
||||||
|
let config = CoreMLProviderConfiguration(modelURL: url)
|
||||||
|
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
|
||||||
|
|
||||||
|
#expect(throws: OnDeviceProviderError.self) {
|
||||||
|
try pipeline.validateConfiguration()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Pipeline provides SQL generation hints for MLX")
|
||||||
|
func mlxSQLHints() {
|
||||||
|
let config = MLXProviderConfiguration(
|
||||||
|
modelId: "test-model",
|
||||||
|
maxResponseTokens: 512,
|
||||||
|
temperature: 0.2
|
||||||
|
)
|
||||||
|
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
|
||||||
|
let hints = pipeline.recommendedSQLGenerationHints
|
||||||
|
|
||||||
|
#expect(hints.maxTokens == 512)
|
||||||
|
#expect(hints.temperature == 0.2)
|
||||||
|
#expect(hints.useSampling == true)
|
||||||
|
#expect(hints.systemPromptSuffix.contains("MLX"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Pipeline provides SQL generation hints for CoreML")
|
||||||
|
func coreMLSQLHints() {
|
||||||
|
let url = URL(fileURLWithPath: "/tmp/TestModel.mlmodelc")
|
||||||
|
let config = CoreMLProviderConfiguration(
|
||||||
|
modelURL: url,
|
||||||
|
maxResponseTokens: 1024,
|
||||||
|
useSampling: false,
|
||||||
|
temperature: 0.05
|
||||||
|
)
|
||||||
|
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: config)
|
||||||
|
let hints = pipeline.recommendedSQLGenerationHints
|
||||||
|
|
||||||
|
#expect(hints.maxTokens == 1024)
|
||||||
|
#expect(hints.temperature == 0.05)
|
||||||
|
#expect(hints.useSampling == false)
|
||||||
|
#expect(hints.systemPromptSuffix.contains("SQL"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - System Readiness
|
||||||
|
|
||||||
|
@Test("System capability check returns valid data")
|
||||||
|
func systemCapability() {
|
||||||
|
let capability = OnDeviceModelReadiness.checkSystemCapability()
|
||||||
|
|
||||||
|
#expect(capability.totalRAM > 0)
|
||||||
|
// On any modern test machine, we should have at least some RAM
|
||||||
|
#expect(capability.totalRAM > 1024 * 1024 * 1024) // > 1GB
|
||||||
|
|
||||||
|
// On Apple silicon Macs, this should be true
|
||||||
|
#if arch(arm64)
|
||||||
|
#expect(capability.hasNeuralEngine == true)
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Suggested MLX model returns a valid configuration")
|
||||||
|
func suggestedMLXModel() {
|
||||||
|
let config = OnDeviceModelReadiness.suggestedMLXModel()
|
||||||
|
#expect(!config.modelId.isEmpty)
|
||||||
|
#expect(config.temperature >= 0)
|
||||||
|
#expect(config.maxResponseTokens > 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Recommended model size enum has correct raw values")
|
||||||
|
func recommendedModelSizeRawValues() {
|
||||||
|
#expect(OnDeviceModelReadiness.RecommendedModelSize.small.rawValue == "small")
|
||||||
|
#expect(OnDeviceModelReadiness.RecommendedModelSize.medium.rawValue == "medium")
|
||||||
|
#expect(OnDeviceModelReadiness.RecommendedModelSize.large.rawValue == "large")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - ProviderConfiguration Integration
|
||||||
|
|
||||||
|
@Test("onDeviceMLX creates a ProviderConfiguration")
|
||||||
|
func onDeviceMLXProviderConfig() {
|
||||||
|
let mlxConfig = MLXProviderConfiguration.llama3_2_3B()
|
||||||
|
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
|
||||||
|
|
||||||
|
#expect(providerConfig.model == mlxConfig.modelId)
|
||||||
|
#expect(!providerConfig.hasValidAPIKey) // No API key needed for on-device
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("onDeviceCoreML creates a ProviderConfiguration")
|
||||||
|
func onDeviceCoreMLProviderConfig() {
|
||||||
|
let url = URL(fileURLWithPath: "/tmp/SQLModel.mlmodelc")
|
||||||
|
let coreMLConfig = CoreMLProviderConfiguration(modelURL: url)
|
||||||
|
let providerConfig = ProviderConfiguration.onDeviceCoreML(coreMLConfig)
|
||||||
|
|
||||||
|
#expect(providerConfig.model == "SQLModel.mlmodelc")
|
||||||
|
#expect(!providerConfig.hasValidAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Pipeline Status
|
||||||
|
|
||||||
|
@Test("Pipeline status transitions")
|
||||||
|
func pipelineStatusTransitions() {
|
||||||
|
let config = MLXProviderConfiguration(modelId: "test-model")
|
||||||
|
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: config)
|
||||||
|
|
||||||
|
#expect(pipeline.status == .notLoaded)
|
||||||
|
|
||||||
|
pipeline.setStatus(.loading)
|
||||||
|
#expect(pipeline.status == .loading)
|
||||||
|
|
||||||
|
pipeline.setStatus(.ready)
|
||||||
|
#expect(pipeline.status == .ready)
|
||||||
|
|
||||||
|
pipeline.setStatus(.failed("Out of memory"))
|
||||||
|
#expect(pipeline.status == .failed("Out of memory"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Pipeline Status is Equatable")
|
||||||
|
func pipelineStatusEquatable() {
|
||||||
|
#expect(OnDeviceInferencePipeline.Status.notLoaded == .notLoaded)
|
||||||
|
#expect(OnDeviceInferencePipeline.Status.loading == .loading)
|
||||||
|
#expect(OnDeviceInferencePipeline.Status.ready == .ready)
|
||||||
|
#expect(OnDeviceInferencePipeline.Status.failed("a") == .failed("a"))
|
||||||
|
#expect(OnDeviceInferencePipeline.Status.failed("a") != .failed("b"))
|
||||||
|
#expect(OnDeviceInferencePipeline.Status.notLoaded != .ready)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - SQL Generation Hints
|
||||||
|
|
||||||
|
@Test("SQL generation hints are Equatable")
|
||||||
|
func sqlHintsEquatable() {
|
||||||
|
let a = OnDeviceSQLGenerationHints(
|
||||||
|
maxTokens: 512,
|
||||||
|
temperature: 0.1,
|
||||||
|
systemPromptSuffix: "test",
|
||||||
|
useSampling: true
|
||||||
|
)
|
||||||
|
let b = OnDeviceSQLGenerationHints(
|
||||||
|
maxTokens: 512,
|
||||||
|
temperature: 0.1,
|
||||||
|
systemPromptSuffix: "test",
|
||||||
|
useSampling: true
|
||||||
|
)
|
||||||
|
let c = OnDeviceSQLGenerationHints(
|
||||||
|
maxTokens: 1024,
|
||||||
|
temperature: 0.1,
|
||||||
|
systemPromptSuffix: "test",
|
||||||
|
useSampling: true
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(a == b)
|
||||||
|
#expect(a != c)
|
||||||
|
}
|
||||||
|
}
|
||||||
254
Tests/SwiftDBAITests/PromptBuilderTests.swift
Normal file
254
Tests/SwiftDBAITests/PromptBuilderTests.swift
Normal file
@@ -0,0 +1,254 @@
|
|||||||
|
// PromptBuilderTests.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
|
||||||
|
import Testing
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
@Suite("PromptBuilder")
|
||||||
|
struct PromptBuilderTests {
|
||||||
|
|
||||||
|
// MARK: - Helpers
|
||||||
|
|
||||||
|
/// Creates a sample schema for testing.
|
||||||
|
private func makeSampleSchema() -> DatabaseSchema {
|
||||||
|
let usersTable = TableSchema(
|
||||||
|
name: "users",
|
||||||
|
columns: [
|
||||||
|
ColumnSchema(cid: 0, name: "id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: true),
|
||||||
|
ColumnSchema(cid: 1, name: "name", type: "TEXT", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
|
||||||
|
ColumnSchema(cid: 2, name: "email", type: "TEXT", isNotNull: false, defaultValue: nil, isPrimaryKey: false),
|
||||||
|
ColumnSchema(cid: 3, name: "created_at", type: "TEXT", isNotNull: false, defaultValue: "CURRENT_TIMESTAMP", isPrimaryKey: false),
|
||||||
|
],
|
||||||
|
primaryKey: ["id"],
|
||||||
|
foreignKeys: [],
|
||||||
|
indexes: [
|
||||||
|
IndexSchema(name: "idx_users_email", isUnique: true, columns: ["email"])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
let ordersTable = TableSchema(
|
||||||
|
name: "orders",
|
||||||
|
columns: [
|
||||||
|
ColumnSchema(cid: 0, name: "id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: true),
|
||||||
|
ColumnSchema(cid: 1, name: "user_id", type: "INTEGER", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
|
||||||
|
ColumnSchema(cid: 2, name: "total", type: "REAL", isNotNull: true, defaultValue: nil, isPrimaryKey: false),
|
||||||
|
ColumnSchema(cid: 3, name: "status", type: "TEXT", isNotNull: true, defaultValue: "'pending'", isPrimaryKey: false),
|
||||||
|
],
|
||||||
|
primaryKey: ["id"],
|
||||||
|
foreignKeys: [
|
||||||
|
ForeignKeySchema(fromColumn: "user_id", toTable: "users", toColumn: "id", onUpdate: "NO ACTION", onDelete: "CASCADE")
|
||||||
|
],
|
||||||
|
indexes: []
|
||||||
|
)
|
||||||
|
|
||||||
|
return DatabaseSchema(
|
||||||
|
tables: ["users": usersTable, "orders": ordersTable],
|
||||||
|
tableNames: ["users", "orders"]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func makeEmptySchema() -> DatabaseSchema {
|
||||||
|
DatabaseSchema(tables: [:], tableNames: [])
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - System Instructions Tests
|
||||||
|
|
||||||
|
@Test("System instructions contain role section")
|
||||||
|
func systemInstructionsContainRole() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("ROLE"))
|
||||||
|
#expect(instructions.contains("SQL assistant"))
|
||||||
|
#expect(instructions.contains("SQLite database"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("System instructions contain schema")
|
||||||
|
func systemInstructionsContainSchema() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("DATABASE SCHEMA"))
|
||||||
|
#expect(instructions.contains("TABLE users"))
|
||||||
|
#expect(instructions.contains("TABLE orders"))
|
||||||
|
#expect(instructions.contains("name TEXT"))
|
||||||
|
#expect(instructions.contains("email TEXT"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("System instructions contain foreign keys from schema")
|
||||||
|
func systemInstructionsContainForeignKeys() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("FOREIGN KEY"))
|
||||||
|
#expect(instructions.contains("REFERENCES users(id)"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("System instructions contain SQL generation rules")
|
||||||
|
func systemInstructionsContainRules() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("SQL GENERATION RULES"))
|
||||||
|
#expect(instructions.contains("Use ONLY the tables and columns"))
|
||||||
|
#expect(instructions.contains("Never generate DDL"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("System instructions contain output format section")
|
||||||
|
func systemInstructionsContainOutputFormat() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("OUTPUT FORMAT"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Default allowlist is read-only")
|
||||||
|
func defaultAllowlistIsReadOnly() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("ONLY generate SELECT queries"))
|
||||||
|
#expect(instructions.contains("No data modifications"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Standard allowlist shows correct operations")
|
||||||
|
func standardAllowlistInstructions() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: .standard)
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("INSERT"))
|
||||||
|
#expect(instructions.contains("SELECT"))
|
||||||
|
#expect(instructions.contains("UPDATE"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Unrestricted allowlist warns about DELETE")
|
||||||
|
func unrestrictedAllowlistWarnsAboutDelete() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: .unrestricted)
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("DELETE"))
|
||||||
|
#expect(instructions.contains("destructive"))
|
||||||
|
#expect(instructions.contains("confirmation"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Additional context is appended")
|
||||||
|
func additionalContextAppended() {
|
||||||
|
let builder = PromptBuilder(
|
||||||
|
schema: makeSampleSchema(),
|
||||||
|
additionalContext: "All dates are stored in ISO 8601 format."
|
||||||
|
)
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("ADDITIONAL CONTEXT"))
|
||||||
|
#expect(instructions.contains("ISO 8601"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("No additional context section when nil")
|
||||||
|
func noAdditionalContextWhenNil() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(!instructions.contains("ADDITIONAL CONTEXT"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("No additional context section when empty string")
|
||||||
|
func noAdditionalContextWhenEmpty() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema(), additionalContext: "")
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(!instructions.contains("ADDITIONAL CONTEXT"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Empty schema produces valid instructions")
|
||||||
|
func emptySchemaProducesValidInstructions() {
|
||||||
|
let builder = PromptBuilder(schema: makeEmptySchema())
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("ROLE"))
|
||||||
|
#expect(instructions.contains("SQL GENERATION RULES"))
|
||||||
|
// Schema section should still be present, just empty
|
||||||
|
#expect(instructions.contains("DATABASE SCHEMA"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - User Prompt Tests
|
||||||
|
|
||||||
|
@Test("User prompt passes through question directly")
|
||||||
|
func userPromptPassesThrough() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let prompt = builder.buildUserPrompt("How many users signed up this week?")
|
||||||
|
|
||||||
|
#expect(prompt == "How many users signed up this week?")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Follow-up Prompt Tests
|
||||||
|
|
||||||
|
@Test("Follow-up prompt includes previous context")
|
||||||
|
func followUpPromptIncludesPreviousContext() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let prompt = builder.buildFollowUpPrompt(
|
||||||
|
"Now sort them by name",
|
||||||
|
previousSQL: "SELECT * FROM users WHERE created_at > date('now', '-7 days')",
|
||||||
|
previousResultSummary: "Found 42 users who signed up this week"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(prompt.contains("Previous query:"))
|
||||||
|
#expect(prompt.contains("SELECT * FROM users"))
|
||||||
|
#expect(prompt.contains("Previous result:"))
|
||||||
|
#expect(prompt.contains("42 users"))
|
||||||
|
#expect(prompt.contains("Follow-up question:"))
|
||||||
|
#expect(prompt.contains("sort them by name"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Schema Description Quality
|
||||||
|
|
||||||
|
@Test("Schema includes column types and constraints")
|
||||||
|
func schemaIncludesColumnDetails() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
// Should include type info
|
||||||
|
#expect(instructions.contains("INTEGER"))
|
||||||
|
#expect(instructions.contains("TEXT"))
|
||||||
|
#expect(instructions.contains("REAL"))
|
||||||
|
|
||||||
|
// Should include constraints
|
||||||
|
#expect(instructions.contains("NOT NULL"))
|
||||||
|
#expect(instructions.contains("PRIMARY KEY"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Schema includes index information")
|
||||||
|
func schemaIncludesIndexes() {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("INDEX"))
|
||||||
|
#expect(instructions.contains("idx_users_email"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Sendable Conformance
|
||||||
|
|
||||||
|
@Test("PromptBuilder is Sendable")
|
||||||
|
func promptBuilderIsSendable() async {
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema())
|
||||||
|
|
||||||
|
// Verify it can be sent across concurrency boundaries
|
||||||
|
let instructions = await Task.detached {
|
||||||
|
builder.buildSystemInstructions()
|
||||||
|
}.value
|
||||||
|
|
||||||
|
#expect(instructions.contains("ROLE"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Custom Allowlist
|
||||||
|
|
||||||
|
@Test("Custom allowlist with select and delete only")
|
||||||
|
func customAllowlist() {
|
||||||
|
let allowlist = OperationAllowlist([.select, .delete])
|
||||||
|
let builder = PromptBuilder(schema: makeSampleSchema(), allowlist: allowlist)
|
||||||
|
let instructions = builder.buildSystemInstructions()
|
||||||
|
|
||||||
|
#expect(instructions.contains("DELETE"))
|
||||||
|
#expect(instructions.contains("SELECT"))
|
||||||
|
#expect(instructions.contains("destructive"))
|
||||||
|
}
|
||||||
|
}
|
||||||
325
Tests/SwiftDBAITests/ProviderConfigurationTests.swift
Normal file
325
Tests/SwiftDBAITests/ProviderConfigurationTests.swift
Normal file
@@ -0,0 +1,325 @@
|
|||||||
|
// ProviderConfigurationTests.swift
|
||||||
|
// SwiftDBAI Tests
|
||||||
|
//
|
||||||
|
// Tests for ProviderConfiguration — verifying all cloud provider configurations
|
||||||
|
// produce valid LanguageModel instances with correct settings.
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
import Testing
|
||||||
|
|
||||||
|
@Suite("ProviderConfiguration")
|
||||||
|
struct ProviderConfigurationTests {
|
||||||
|
|
||||||
|
// MARK: - OpenAI Configuration
|
||||||
|
|
||||||
|
@Test("OpenAI configuration stores provider and model")
|
||||||
|
func openAIBasicConfiguration() {
|
||||||
|
let config = ProviderConfiguration.openAI(
|
||||||
|
apiKey: "sk-test-key-123",
|
||||||
|
model: "gpt-4o"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.provider == .openAI)
|
||||||
|
#expect(config.model == "gpt-4o")
|
||||||
|
#expect(config.apiKey == "sk-test-key-123")
|
||||||
|
#expect(config.hasValidAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("OpenAI configuration produces a valid LanguageModel")
|
||||||
|
func openAIMakeModel() {
|
||||||
|
let config = ProviderConfiguration.openAI(
|
||||||
|
apiKey: "sk-test-key",
|
||||||
|
model: "gpt-4o-mini"
|
||||||
|
)
|
||||||
|
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is OpenAILanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("OpenAI with custom base URL for compatible services")
|
||||||
|
func openAICustomBaseURL() {
|
||||||
|
let customURL = URL(string: "https://my-proxy.example.com/v1/")!
|
||||||
|
let config = ProviderConfiguration.openAI(
|
||||||
|
apiKey: "sk-proxy-key",
|
||||||
|
model: "gpt-4o",
|
||||||
|
baseURL: customURL
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.baseURL == customURL)
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is OpenAILanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("OpenAI with Responses API variant")
|
||||||
|
func openAIResponsesVariant() {
|
||||||
|
let config = ProviderConfiguration.openAI(
|
||||||
|
apiKey: "sk-test",
|
||||||
|
model: "gpt-4o",
|
||||||
|
variant: .responses
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.openAIVariant == .responses)
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is OpenAILanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("OpenAI with dynamic key provider captures key by reference")
|
||||||
|
func openAIDynamicKeyProvider() {
|
||||||
|
nonisolated(unsafe) var currentKey = "sk-initial"
|
||||||
|
let config = ProviderConfiguration.openAI(
|
||||||
|
apiKeyProvider: { currentKey },
|
||||||
|
model: "gpt-4o"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.apiKey == "sk-initial")
|
||||||
|
currentKey = "sk-rotated"
|
||||||
|
#expect(config.apiKey == "sk-rotated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Anthropic Configuration
|
||||||
|
|
||||||
|
@Test("Anthropic configuration stores provider and model")
|
||||||
|
func anthropicBasicConfiguration() {
|
||||||
|
let config = ProviderConfiguration.anthropic(
|
||||||
|
apiKey: "sk-ant-test-key",
|
||||||
|
model: "claude-sonnet-4-20250514"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.provider == .anthropic)
|
||||||
|
#expect(config.model == "claude-sonnet-4-20250514")
|
||||||
|
#expect(config.apiKey == "sk-ant-test-key")
|
||||||
|
#expect(config.hasValidAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Anthropic configuration produces a valid LanguageModel")
|
||||||
|
func anthropicMakeModel() {
|
||||||
|
let config = ProviderConfiguration.anthropic(
|
||||||
|
apiKey: "sk-ant-test",
|
||||||
|
model: "claude-sonnet-4-20250514"
|
||||||
|
)
|
||||||
|
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is AnthropicLanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Anthropic with API version and betas")
|
||||||
|
func anthropicWithVersionAndBetas() {
|
||||||
|
let config = ProviderConfiguration.anthropic(
|
||||||
|
apiKey: "sk-ant-test",
|
||||||
|
model: "claude-sonnet-4-20250514",
|
||||||
|
apiVersion: "2024-01-01",
|
||||||
|
betas: ["computer-use"]
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.apiVersion == "2024-01-01")
|
||||||
|
#expect(config.betas == ["computer-use"])
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is AnthropicLanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Anthropic with dynamic key provider captures key by reference")
|
||||||
|
func anthropicDynamicKeyProvider() {
|
||||||
|
nonisolated(unsafe) var currentKey = "sk-ant-initial"
|
||||||
|
let config = ProviderConfiguration.anthropic(
|
||||||
|
apiKeyProvider: { currentKey },
|
||||||
|
model: "claude-sonnet-4-20250514"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.apiKey == "sk-ant-initial")
|
||||||
|
currentKey = "sk-ant-rotated"
|
||||||
|
#expect(config.apiKey == "sk-ant-rotated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Gemini Configuration
|
||||||
|
|
||||||
|
@Test("Gemini configuration stores provider and model")
|
||||||
|
func geminiBasicConfiguration() {
|
||||||
|
let config = ProviderConfiguration.gemini(
|
||||||
|
apiKey: "AIzaSyTest123",
|
||||||
|
model: "gemini-2.0-flash"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.provider == .gemini)
|
||||||
|
#expect(config.model == "gemini-2.0-flash")
|
||||||
|
#expect(config.apiKey == "AIzaSyTest123")
|
||||||
|
#expect(config.hasValidAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Gemini configuration produces a valid LanguageModel")
|
||||||
|
func geminiMakeModel() {
|
||||||
|
let config = ProviderConfiguration.gemini(
|
||||||
|
apiKey: "AIzaSyTest",
|
||||||
|
model: "gemini-2.0-flash"
|
||||||
|
)
|
||||||
|
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is GeminiLanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Gemini with custom API version")
|
||||||
|
func geminiCustomVersion() {
|
||||||
|
let config = ProviderConfiguration.gemini(
|
||||||
|
apiKey: "AIzaSyTest",
|
||||||
|
model: "gemini-2.0-flash",
|
||||||
|
apiVersion: "v1"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.apiVersion == "v1")
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is GeminiLanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Gemini with dynamic key provider captures key by reference")
|
||||||
|
func geminiDynamicKeyProvider() {
|
||||||
|
nonisolated(unsafe) var currentKey = "AIza-initial"
|
||||||
|
let config = ProviderConfiguration.gemini(
|
||||||
|
apiKeyProvider: { currentKey },
|
||||||
|
model: "gemini-2.0-flash"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.apiKey == "AIza-initial")
|
||||||
|
currentKey = "AIza-rotated"
|
||||||
|
#expect(config.apiKey == "AIza-rotated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - OpenAI-Compatible Configuration
|
||||||
|
|
||||||
|
@Test("OpenAI-compatible configuration with custom base URL")
|
||||||
|
func openAICompatibleConfiguration() {
|
||||||
|
let baseURL = URL(string: "https://api.together.xyz/v1/")!
|
||||||
|
let config = ProviderConfiguration.openAICompatible(
|
||||||
|
apiKey: "together-key",
|
||||||
|
model: "meta-llama/Llama-3.1-70B",
|
||||||
|
baseURL: baseURL
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.provider == .openAICompatible)
|
||||||
|
#expect(config.model == "meta-llama/Llama-3.1-70B")
|
||||||
|
#expect(config.baseURL == baseURL)
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model is OpenAILanguageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("OpenAI-compatible with dynamic key provider")
|
||||||
|
func openAICompatibleDynamicKey() {
|
||||||
|
let baseURL = URL(string: "http://localhost:1234/v1/")!
|
||||||
|
nonisolated(unsafe) var currentKey = "local-key"
|
||||||
|
let config = ProviderConfiguration.openAICompatible(
|
||||||
|
apiKeyProvider: { currentKey },
|
||||||
|
model: "local-model",
|
||||||
|
baseURL: baseURL
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.apiKey == "local-key")
|
||||||
|
currentKey = "new-local-key"
|
||||||
|
#expect(config.apiKey == "new-local-key")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - API Key Validation
|
||||||
|
|
||||||
|
@Test("Empty API key reports invalid")
|
||||||
|
func emptyAPIKeyInvalid() {
|
||||||
|
let config = ProviderConfiguration.openAI(
|
||||||
|
apiKey: "",
|
||||||
|
model: "gpt-4o"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(!config.hasValidAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Whitespace-only API key reports invalid")
|
||||||
|
func whitespaceAPIKeyInvalid() {
|
||||||
|
let config = ProviderConfiguration.openAI(
|
||||||
|
apiKey: " \n\t ",
|
||||||
|
model: "gpt-4o"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(!config.hasValidAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Non-empty API key reports valid")
|
||||||
|
func nonEmptyAPIKeyValid() {
|
||||||
|
let config = ProviderConfiguration.openAI(
|
||||||
|
apiKey: "x",
|
||||||
|
model: "gpt-4o"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(config.hasValidAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Environment Variable Configuration
|
||||||
|
|
||||||
|
@Test("fromEnvironment creates configuration for each provider")
|
||||||
|
func fromEnvironmentCreatesConfig() {
|
||||||
|
let openAI = ProviderConfiguration.fromEnvironment(
|
||||||
|
provider: .openAI,
|
||||||
|
environmentVariable: "SWIFTDAI_TEST_OPENAI_KEY",
|
||||||
|
model: "gpt-4o"
|
||||||
|
)
|
||||||
|
#expect(openAI.provider == .openAI)
|
||||||
|
#expect(openAI.model == "gpt-4o")
|
||||||
|
|
||||||
|
let anthropic = ProviderConfiguration.fromEnvironment(
|
||||||
|
provider: .anthropic,
|
||||||
|
environmentVariable: "SWIFTDAI_TEST_ANTHROPIC_KEY",
|
||||||
|
model: "claude-sonnet-4-20250514"
|
||||||
|
)
|
||||||
|
#expect(anthropic.provider == .anthropic)
|
||||||
|
|
||||||
|
let gemini = ProviderConfiguration.fromEnvironment(
|
||||||
|
provider: .gemini,
|
||||||
|
environmentVariable: "SWIFTDAI_TEST_GEMINI_KEY",
|
||||||
|
model: "gemini-2.0-flash"
|
||||||
|
)
|
||||||
|
#expect(gemini.provider == .gemini)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("fromEnvironment returns empty key when variable not set")
|
||||||
|
func fromEnvironmentMissingVariable() {
|
||||||
|
let config = ProviderConfiguration.fromEnvironment(
|
||||||
|
provider: .openAI,
|
||||||
|
environmentVariable: "NONEXISTENT_KEY_VAR_SWIFTDBAI_TEST",
|
||||||
|
model: "gpt-4o"
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(!config.hasValidAPIKey)
|
||||||
|
#expect(config.apiKey == "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Provider Enum
|
||||||
|
|
||||||
|
@Test("Provider enum has all expected cases")
|
||||||
|
func providerCases() {
|
||||||
|
let cases = ProviderConfiguration.Provider.allCases
|
||||||
|
#expect(cases.count == 6)
|
||||||
|
#expect(cases.contains(.openAI))
|
||||||
|
#expect(cases.contains(.anthropic))
|
||||||
|
#expect(cases.contains(.gemini))
|
||||||
|
#expect(cases.contains(.openAICompatible))
|
||||||
|
#expect(cases.contains(.ollama))
|
||||||
|
#expect(cases.contains(.llamaCpp))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Cross-Provider Model Creation
|
||||||
|
|
||||||
|
@Test("All providers produce available models")
|
||||||
|
func allProvidersCreateAvailableModels() {
|
||||||
|
let configs: [ProviderConfiguration] = [
|
||||||
|
.openAI(apiKey: "test", model: "gpt-4o"),
|
||||||
|
.anthropic(apiKey: "test", model: "claude-sonnet-4-20250514"),
|
||||||
|
.gemini(apiKey: "test", model: "gemini-2.0-flash"),
|
||||||
|
.openAICompatible(
|
||||||
|
apiKey: "test",
|
||||||
|
model: "local",
|
||||||
|
baseURL: URL(string: "http://localhost:8080/v1/")!
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for config in configs {
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model.isAvailable, "Model for \(config.provider) should be available")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
397
Tests/SwiftDBAITests/SQLQueryParserTests.swift
Normal file
397
Tests/SwiftDBAITests/SQLQueryParserTests.swift
Normal file
@@ -0,0 +1,397 @@
|
|||||||
|
// SQLQueryParserTests.swift
|
||||||
|
// SwiftDBAITests
|
||||||
|
|
||||||
|
import Testing
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
@Suite("SQLQueryParser")
|
||||||
|
struct SQLQueryParserTests {
|
||||||
|
|
||||||
|
let readOnlyParser = SQLQueryParser(allowlist: .readOnly)
|
||||||
|
let standardParser = SQLQueryParser(allowlist: .standard)
|
||||||
|
let unrestrictedParser = SQLQueryParser(allowlist: .unrestricted)
|
||||||
|
|
||||||
|
// MARK: - Extraction from code blocks
|
||||||
|
|
||||||
|
@Test("Extracts SQL from markdown sql code block")
|
||||||
|
func extractFromSQLCodeBlock() throws {
|
||||||
|
let text = """
|
||||||
|
Here's the query to find the top users:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
SELECT name, COUNT(*) as count FROM users GROUP BY name ORDER BY count DESC
|
||||||
|
```
|
||||||
|
|
||||||
|
This will give you the results.
|
||||||
|
"""
|
||||||
|
let result = try readOnlyParser.parse(text)
|
||||||
|
#expect(result.sql == "SELECT name, COUNT(*) as count FROM users GROUP BY name ORDER BY count DESC")
|
||||||
|
#expect(result.operation == .select)
|
||||||
|
#expect(result.requiresConfirmation == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Extracts SQL from generic code block")
|
||||||
|
func extractFromGenericCodeBlock() throws {
|
||||||
|
let text = """
|
||||||
|
Here you go:
|
||||||
|
|
||||||
|
```
|
||||||
|
SELECT * FROM products WHERE price > 100
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
let result = try readOnlyParser.parse(text)
|
||||||
|
#expect(result.sql == "SELECT * FROM products WHERE price > 100")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Extracts SQL from labeled text")
|
||||||
|
func extractFromLabel() throws {
|
||||||
|
let text = """
|
||||||
|
I can help with that.
|
||||||
|
SQL: SELECT id, name FROM categories WHERE active = 1
|
||||||
|
That should work.
|
||||||
|
"""
|
||||||
|
let result = try readOnlyParser.parse(text)
|
||||||
|
#expect(result.sql == "SELECT id, name FROM categories WHERE active = 1")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Extracts direct SQL from plain text")
|
||||||
|
func extractDirectSQL() throws {
|
||||||
|
let text = "SELECT COUNT(*) FROM orders WHERE status = 'shipped'"
|
||||||
|
let result = try readOnlyParser.parse(text)
|
||||||
|
#expect(result.sql == "SELECT COUNT(*) FROM orders WHERE status = 'shipped'")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles SQL with trailing semicolons")
|
||||||
|
func trailingSemicolon() throws {
|
||||||
|
let text = "```sql\nSELECT * FROM users;\n```"
|
||||||
|
let result = try readOnlyParser.parse(text)
|
||||||
|
#expect(result.sql == "SELECT * FROM users")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles multiline SQL in code block")
|
||||||
|
func multilineSQL() throws {
|
||||||
|
let text = """
|
||||||
|
```sql
|
||||||
|
SELECT u.name, COUNT(o.id) as order_count
|
||||||
|
FROM users u
|
||||||
|
JOIN orders o ON u.id = o.user_id
|
||||||
|
GROUP BY u.name
|
||||||
|
ORDER BY order_count DESC
|
||||||
|
LIMIT 10
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
let result = try readOnlyParser.parse(text)
|
||||||
|
#expect(result.sql.contains("SELECT u.name"))
|
||||||
|
#expect(result.sql.contains("LIMIT 10"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles WITH (CTE) queries as SELECT")
|
||||||
|
func cteQuery() throws {
|
||||||
|
let text = """
|
||||||
|
```sql
|
||||||
|
WITH top_users AS (
|
||||||
|
SELECT user_id, COUNT(*) as cnt FROM orders GROUP BY user_id
|
||||||
|
)
|
||||||
|
SELECT * FROM top_users WHERE cnt > 5
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
let result = try readOnlyParser.parse(text)
|
||||||
|
#expect(result.operation == .select)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - No SQL found
|
||||||
|
|
||||||
|
@Test("Throws noSQLFound for text without SQL")
|
||||||
|
func noSQLFound() throws {
|
||||||
|
let text = "I'm sorry, I can't help with that request."
|
||||||
|
#expect(throws: SQLParsingError.noSQLFound) {
|
||||||
|
try readOnlyParser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Throws noSQLFound for empty input")
|
||||||
|
func emptyInput() throws {
|
||||||
|
#expect(throws: SQLParsingError.noSQLFound) {
|
||||||
|
try readOnlyParser.parse("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Operation detection
|
||||||
|
|
||||||
|
@Test("Detects INSERT operation")
|
||||||
|
func detectInsert() throws {
|
||||||
|
let text = "```sql\nINSERT INTO users (name) VALUES ('Alice')\n```"
|
||||||
|
let result = try standardParser.parse(text)
|
||||||
|
#expect(result.operation == .insert)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Detects UPDATE operation")
|
||||||
|
func detectUpdate() throws {
|
||||||
|
let text = "```sql\nUPDATE users SET name = 'Bob' WHERE id = 1\n```"
|
||||||
|
let result = try standardParser.parse(text)
|
||||||
|
#expect(result.operation == .update)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Detects DELETE operation and requires confirmation")
|
||||||
|
func detectDeleteRequiresConfirmation() throws {
|
||||||
|
let text = "```sql\nDELETE FROM users WHERE id = 99\n```"
|
||||||
|
let result = try unrestrictedParser.parse(text)
|
||||||
|
#expect(result.operation == .delete)
|
||||||
|
#expect(result.requiresConfirmation == true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Allowlist enforcement
|
||||||
|
|
||||||
|
@Test("Rejects INSERT on read-only allowlist")
|
||||||
|
func rejectInsertOnReadOnly() throws {
|
||||||
|
let text = "```sql\nINSERT INTO users (name) VALUES ('Mallory')\n```"
|
||||||
|
#expect(throws: SQLParsingError.operationNotAllowed(.insert)) {
|
||||||
|
try readOnlyParser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Rejects UPDATE on read-only allowlist")
|
||||||
|
func rejectUpdateOnReadOnly() {
|
||||||
|
let text = "```sql\nUPDATE users SET name = 'Eve' WHERE id = 1\n```"
|
||||||
|
#expect(throws: SQLParsingError.operationNotAllowed(.update)) {
|
||||||
|
try readOnlyParser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Rejects DELETE on standard allowlist")
|
||||||
|
func rejectDeleteOnStandard() {
|
||||||
|
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
|
||||||
|
#expect(throws: SQLParsingError.operationNotAllowed(.delete)) {
|
||||||
|
try standardParser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Dangerous operations
|
||||||
|
|
||||||
|
@Test("Rejects DROP TABLE")
|
||||||
|
func rejectDrop() {
|
||||||
|
let text = "```sql\nDROP TABLE users\n```"
|
||||||
|
#expect(throws: SQLParsingError.dangerousOperation("DROP")) {
|
||||||
|
try unrestrictedParser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Rejects ALTER TABLE")
|
||||||
|
func rejectAlter() {
|
||||||
|
let text = "```sql\nALTER TABLE users ADD COLUMN age INTEGER\n```"
|
||||||
|
#expect(throws: SQLParsingError.dangerousOperation("ALTER")) {
|
||||||
|
try unrestrictedParser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Rejects PRAGMA")
|
||||||
|
func rejectPragma() {
|
||||||
|
let text = "```sql\nPRAGMA table_info(users)\n```"
|
||||||
|
#expect(throws: SQLParsingError.dangerousOperation("PRAGMA")) {
|
||||||
|
try unrestrictedParser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Does not match dangerous keywords inside identifiers")
|
||||||
|
func noFalsePositiveOnSubstring() throws {
|
||||||
|
// "DROPDOWN" contains "DROP" as substring but is not the keyword
|
||||||
|
let text = "SELECT dropdown_value FROM settings"
|
||||||
|
let result = try readOnlyParser.parse(text)
|
||||||
|
#expect(result.sql.contains("dropdown_value"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Multiple statements
|
||||||
|
|
||||||
|
@Test("Rejects multiple statements separated by semicolons")
|
||||||
|
func rejectMultipleStatements() {
|
||||||
|
let text = "```sql\nSELECT * FROM users; SELECT * FROM orders\n```"
|
||||||
|
#expect(throws: SQLParsingError.multipleStatements) {
|
||||||
|
try readOnlyParser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Allows semicolons inside string literals")
|
||||||
|
func allowSemicolonInString() throws {
|
||||||
|
let text = "SELECT * FROM users WHERE bio = 'hello; world'"
|
||||||
|
let result = try readOnlyParser.parse(text)
|
||||||
|
#expect(result.sql.contains("hello; world"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - ParsedSQL equality
|
||||||
|
|
||||||
|
@Test("ParsedSQL equality works")
|
||||||
|
func parsedSQLEquality() {
|
||||||
|
let a = ParsedSQL(sql: "SELECT 1", operation: .select)
|
||||||
|
let b = ParsedSQL(sql: "SELECT 1", operation: .select)
|
||||||
|
#expect(a == b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Error descriptions
|
||||||
|
|
||||||
|
@Test("Error descriptions are meaningful")
|
||||||
|
func errorDescriptions() {
|
||||||
|
#expect(SQLParsingError.noSQLFound.description.contains("No SQL"))
|
||||||
|
#expect(SQLParsingError.operationNotAllowed(.insert).description.contains("INSERT"))
|
||||||
|
#expect(SQLParsingError.dangerousOperation("DROP").description.contains("DROP"))
|
||||||
|
#expect(SQLParsingError.multipleStatements.description.contains("single"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - MutationPolicy integration
|
||||||
|
|
||||||
|
@Test("MutationPolicy allows INSERT on permitted table")
|
||||||
|
func mutationPolicyAllowsInsertOnPermittedTable() throws {
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update],
|
||||||
|
allowedTables: ["orders", "order_items"]
|
||||||
|
)
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||||
|
let text = "```sql\nINSERT INTO orders (product, qty) VALUES ('Widget', 3)\n```"
|
||||||
|
let result = try parser.parse(text)
|
||||||
|
#expect(result.operation == .insert)
|
||||||
|
#expect(result.requiresConfirmation == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy rejects INSERT on non-permitted table")
|
||||||
|
func mutationPolicyRejectsInsertOnForbiddenTable() {
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update],
|
||||||
|
allowedTables: ["orders"]
|
||||||
|
)
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||||
|
let text = "```sql\nINSERT INTO users (name) VALUES ('Alice')\n```"
|
||||||
|
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .insert)) {
|
||||||
|
try parser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy rejects UPDATE on non-permitted table")
|
||||||
|
func mutationPolicyRejectsUpdateOnForbiddenTable() {
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update],
|
||||||
|
allowedTables: ["orders"]
|
||||||
|
)
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||||
|
let text = "```sql\nUPDATE users SET name = 'Bob' WHERE id = 1\n```"
|
||||||
|
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .update)) {
|
||||||
|
try parser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy rejects DELETE on non-permitted table")
|
||||||
|
func mutationPolicyRejectsDeleteOnForbiddenTable() {
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update, .delete],
|
||||||
|
allowedTables: ["temp_data"]
|
||||||
|
)
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||||
|
let text = "```sql\nDELETE FROM users WHERE id = 99\n```"
|
||||||
|
#expect(throws: SQLParsingError.tableNotAllowed(table: "users", operation: .delete)) {
|
||||||
|
try parser.parse(text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy allows mutation on any table when allowedTables is nil")
|
||||||
|
func mutationPolicyAllowsAllTablesWhenNil() throws {
|
||||||
|
let policy = MutationPolicy(allowedOperations: [.insert, .update])
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||||
|
let text = "```sql\nINSERT INTO any_table (col) VALUES ('val')\n```"
|
||||||
|
let result = try parser.parse(text)
|
||||||
|
#expect(result.operation == .insert)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy SELECT is never restricted by table allowlist")
|
||||||
|
func mutationPolicySelectIgnoresTableRestrictions() throws {
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert],
|
||||||
|
allowedTables: ["orders"]
|
||||||
|
)
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||||
|
// SELECT from a table NOT in allowedTables — should still work
|
||||||
|
let text = "```sql\nSELECT * FROM users\n```"
|
||||||
|
let result = try parser.parse(text)
|
||||||
|
#expect(result.operation == .select)
|
||||||
|
#expect(result.requiresConfirmation == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy DELETE requires confirmation by default")
|
||||||
|
func mutationPolicyDeleteRequiresConfirmation() throws {
|
||||||
|
let policy = MutationPolicy(allowedOperations: [.delete])
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||||
|
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
|
||||||
|
let result = try parser.parse(text)
|
||||||
|
#expect(result.operation == .delete)
|
||||||
|
#expect(result.requiresConfirmation == true)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy DELETE skips confirmation when configured")
|
||||||
|
func mutationPolicyDeleteNoConfirmation() throws {
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.delete],
|
||||||
|
requiresDestructiveConfirmation: false
|
||||||
|
)
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||||
|
let text = "```sql\nDELETE FROM users WHERE id = 1\n```"
|
||||||
|
let result = try parser.parse(text)
|
||||||
|
#expect(result.operation == .delete)
|
||||||
|
#expect(result.requiresConfirmation == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy readOnly preset rejects all mutations")
|
||||||
|
func mutationPolicyReadOnlyRejectsAll() {
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: .readOnly)
|
||||||
|
#expect(throws: SQLParsingError.operationNotAllowed(.insert)) {
|
||||||
|
try parser.parse("INSERT INTO t (a) VALUES (1)")
|
||||||
|
}
|
||||||
|
#expect(throws: SQLParsingError.operationNotAllowed(.update)) {
|
||||||
|
try parser.parse("UPDATE t SET a = 1")
|
||||||
|
}
|
||||||
|
#expect(throws: SQLParsingError.operationNotAllowed(.delete)) {
|
||||||
|
try parser.parse("DELETE FROM t WHERE id = 1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy table matching is case-insensitive")
|
||||||
|
func mutationPolicyTableCaseInsensitive() throws {
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert],
|
||||||
|
allowedTables: ["Orders"]
|
||||||
|
)
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||||
|
let text = "INSERT INTO orders (product) VALUES ('Widget')"
|
||||||
|
let result = try parser.parse(text)
|
||||||
|
#expect(result.operation == .insert)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("MutationPolicy handles quoted table names")
|
||||||
|
func mutationPolicyQuotedTableNames() throws {
|
||||||
|
let policy = MutationPolicy(
|
||||||
|
allowedOperations: [.insert, .update],
|
||||||
|
allowedTables: ["order_items"]
|
||||||
|
)
|
||||||
|
let parser = SQLQueryParser(mutationPolicy: policy)
|
||||||
|
|
||||||
|
// Backtick-quoted
|
||||||
|
let backtick = "INSERT INTO `order_items` (qty) VALUES (5)"
|
||||||
|
let r1 = try parser.parse(backtick)
|
||||||
|
#expect(r1.operation == .insert)
|
||||||
|
|
||||||
|
// Double-quote-quoted
|
||||||
|
let doubleQuote = "UPDATE \"order_items\" SET qty = 10 WHERE id = 1"
|
||||||
|
let r2 = try parser.parse(doubleQuote)
|
||||||
|
#expect(r2.operation == .update)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Error description for tableNotAllowed is meaningful")
|
||||||
|
func tableNotAllowedDescription() {
|
||||||
|
let error = SQLParsingError.tableNotAllowed(table: "secret", operation: .delete)
|
||||||
|
#expect(error.description.contains("secret"))
|
||||||
|
#expect(error.description.contains("DELETE"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Error description for confirmationRequired is meaningful")
|
||||||
|
func confirmationRequiredDescription() {
|
||||||
|
let error = SQLParsingError.confirmationRequired(sql: "DELETE FROM x", operation: .delete)
|
||||||
|
#expect(error.description.contains("DELETE"))
|
||||||
|
#expect(error.description.contains("confirmation"))
|
||||||
|
}
|
||||||
|
}
|
||||||
234
Tests/SwiftDBAITests/SchemaIntrospectorTests.swift
Normal file
234
Tests/SwiftDBAITests/SchemaIntrospectorTests.swift
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
// SchemaIntrospectorTests.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
|
||||||
|
import Testing
|
||||||
|
import GRDB
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
@Suite("SchemaIntrospector")
|
||||||
|
struct SchemaIntrospectorTests {
|
||||||
|
|
||||||
|
// MARK: - Helper
|
||||||
|
|
||||||
|
/// Creates an in-memory database with a sample schema for testing.
|
||||||
|
private func makeTestDatabase() throws -> DatabaseQueue {
|
||||||
|
let db = try DatabaseQueue(configuration: {
|
||||||
|
var config = Configuration()
|
||||||
|
config.foreignKeysEnabled = true
|
||||||
|
return config
|
||||||
|
}())
|
||||||
|
|
||||||
|
try db.write { db in
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE authors (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
email TEXT UNIQUE
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE books (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
author_id INTEGER NOT NULL REFERENCES authors(id) ON DELETE CASCADE,
|
||||||
|
published_date TEXT,
|
||||||
|
price REAL DEFAULT 9.99
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE INDEX idx_books_author ON books(author_id);
|
||||||
|
""")
|
||||||
|
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE INDEX idx_books_title ON books(title);
|
||||||
|
""")
|
||||||
|
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE reviews (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
book_id INTEGER NOT NULL REFERENCES books(id),
|
||||||
|
rating INTEGER NOT NULL,
|
||||||
|
comment TEXT
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Tests
|
||||||
|
|
||||||
|
@Test("Discovers all user tables")
|
||||||
|
func discoversAllTables() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
|
||||||
|
#expect(schema.tableNames.count == 3)
|
||||||
|
#expect(schema.tableNames.contains("authors"))
|
||||||
|
#expect(schema.tableNames.contains("books"))
|
||||||
|
#expect(schema.tableNames.contains("reviews"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Excludes sqlite_ internal tables")
|
||||||
|
func excludesInternalTables() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
|
||||||
|
for name in schema.tableNames {
|
||||||
|
#expect(!name.hasPrefix("sqlite_"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Introspects column names and types")
|
||||||
|
func introspectsColumns() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
|
||||||
|
let books = try #require(schema.tables["books"])
|
||||||
|
#expect(books.columns.count == 5)
|
||||||
|
|
||||||
|
let titleCol = try #require(books.columns.first { $0.name == "title" })
|
||||||
|
#expect(titleCol.type == "TEXT")
|
||||||
|
#expect(titleCol.isNotNull == true)
|
||||||
|
#expect(titleCol.isPrimaryKey == false)
|
||||||
|
|
||||||
|
let priceCol = try #require(books.columns.first { $0.name == "price" })
|
||||||
|
#expect(priceCol.type == "REAL")
|
||||||
|
#expect(priceCol.defaultValue == "9.99")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Detects primary keys")
|
||||||
|
func detectsPrimaryKeys() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
|
||||||
|
let authors = try #require(schema.tables["authors"])
|
||||||
|
#expect(authors.primaryKey == ["id"])
|
||||||
|
|
||||||
|
let idCol = try #require(authors.columns.first { $0.name == "id" })
|
||||||
|
#expect(idCol.isPrimaryKey == true)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Detects foreign keys")
|
||||||
|
func detectsForeignKeys() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
|
||||||
|
let books = try #require(schema.tables["books"])
|
||||||
|
#expect(books.foreignKeys.count == 1)
|
||||||
|
|
||||||
|
let fk = books.foreignKeys[0]
|
||||||
|
#expect(fk.fromColumn == "author_id")
|
||||||
|
#expect(fk.toTable == "authors")
|
||||||
|
#expect(fk.toColumn == "id")
|
||||||
|
#expect(fk.onDelete == "CASCADE")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Detects indexes")
|
||||||
|
func detectsIndexes() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
|
||||||
|
let books = try #require(schema.tables["books"])
|
||||||
|
let indexNames = books.indexes.map(\.name)
|
||||||
|
#expect(indexNames.contains("idx_books_author"))
|
||||||
|
#expect(indexNames.contains("idx_books_title"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Detects NOT NULL constraints")
|
||||||
|
func detectsNotNull() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
|
||||||
|
let reviews = try #require(schema.tables["reviews"])
|
||||||
|
let ratingCol = try #require(reviews.columns.first { $0.name == "rating" })
|
||||||
|
#expect(ratingCol.isNotNull == true)
|
||||||
|
|
||||||
|
let commentCol = try #require(reviews.columns.first { $0.name == "comment" })
|
||||||
|
#expect(commentCol.isNotNull == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Generates LLM-friendly schema description")
|
||||||
|
func generatesSchemaDescription() async throws {
|
||||||
|
let db = try makeTestDatabase()
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
|
||||||
|
let description = schema.schemaDescription
|
||||||
|
#expect(description.contains("TABLE authors"))
|
||||||
|
#expect(description.contains("TABLE books"))
|
||||||
|
#expect(description.contains("FOREIGN KEY"))
|
||||||
|
#expect(description.contains("REFERENCES authors(id)"))
|
||||||
|
#expect(description.contains("INDEX idx_books_author"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles empty database")
|
||||||
|
func handlesEmptyDatabase() async throws {
|
||||||
|
let db = try DatabaseQueue()
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
|
||||||
|
#expect(schema.tables.isEmpty)
|
||||||
|
#expect(schema.tableNames.isEmpty)
|
||||||
|
#expect(schema.schemaDescription.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles composite primary keys")
|
||||||
|
func handlesCompositePrimaryKey() async throws {
|
||||||
|
let db = try DatabaseQueue()
|
||||||
|
try await db.write { db in
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE book_tags (
|
||||||
|
book_id INTEGER NOT NULL,
|
||||||
|
tag_id INTEGER NOT NULL,
|
||||||
|
PRIMARY KEY (book_id, tag_id)
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
let bookTags = try #require(schema.tables["book_tags"])
|
||||||
|
#expect(bookTags.primaryKey.count == 2)
|
||||||
|
#expect(bookTags.primaryKey.contains("book_id"))
|
||||||
|
#expect(bookTags.primaryKey.contains("tag_id"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles tables with no explicit types (SQLite dynamic typing)")
|
||||||
|
func handlesDynamicTyping() async throws {
|
||||||
|
let db = try DatabaseQueue()
|
||||||
|
try await db.write { db in
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE flexible (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
data,
|
||||||
|
info BLOB
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
|
||||||
|
let schema = try await SchemaIntrospector.introspect(database: db)
|
||||||
|
let flexible = try #require(schema.tables["flexible"])
|
||||||
|
|
||||||
|
let dataCol = try #require(flexible.columns.first { $0.name == "data" })
|
||||||
|
#expect(dataCol.type == "") // No declared type
|
||||||
|
|
||||||
|
let infoCol = try #require(flexible.columns.first { $0.name == "info" })
|
||||||
|
#expect(infoCol.type == "BLOB")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Synchronous introspection works within database access")
|
||||||
|
func synchronousIntrospection() async throws {
|
||||||
|
let db = try DatabaseQueue()
|
||||||
|
try await db.write { db in
|
||||||
|
try db.execute(sql: "CREATE TABLE test (id INTEGER PRIMARY KEY, val TEXT);")
|
||||||
|
}
|
||||||
|
|
||||||
|
let schema = try await db.read { db in
|
||||||
|
try SchemaIntrospector.introspect(db: db)
|
||||||
|
}
|
||||||
|
|
||||||
|
#expect(schema.tableNames == ["test"])
|
||||||
|
let table = try #require(schema.tables["test"])
|
||||||
|
#expect(table.columns.count == 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
133
Tests/SwiftDBAITests/ScrollableDataTableViewTests.swift
Normal file
133
Tests/SwiftDBAITests/ScrollableDataTableViewTests.swift
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
// ScrollableDataTableViewTests.swift
|
||||||
|
// SwiftDBAITests
|
||||||
|
//
|
||||||
|
// Tests for the ScrollableDataTableView component.
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import Testing
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
@Suite("ScrollableDataTableView")
|
||||||
|
@MainActor
|
||||||
|
struct ScrollableDataTableViewTests {
|
||||||
|
|
||||||
|
// MARK: - Test Helpers
|
||||||
|
|
||||||
|
private func makeDataTable(
|
||||||
|
columnNames: [String] = ["id", "name", "score"],
|
||||||
|
inferredTypes: [DataTable.InferredType] = [.integer, .text, .real],
|
||||||
|
rowCount: Int = 5
|
||||||
|
) -> DataTable {
|
||||||
|
let columns = columnNames.enumerated().map { idx, name in
|
||||||
|
DataTable.Column(name: name, index: idx, inferredType: inferredTypes[idx])
|
||||||
|
}
|
||||||
|
let rows = (0..<rowCount).map { i in
|
||||||
|
DataTable.Row(
|
||||||
|
id: i,
|
||||||
|
values: [
|
||||||
|
.integer(Int64(i + 1)),
|
||||||
|
.text("Item \(i + 1)"),
|
||||||
|
.real(Double(i) * 10.5),
|
||||||
|
],
|
||||||
|
columnNames: columnNames
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return DataTable(columns: columns, rows: rows, sql: "SELECT * FROM test", executionTime: 0.015)
|
||||||
|
}
|
||||||
|
|
||||||
|
private func makeEmptyDataTable() -> DataTable {
|
||||||
|
DataTable(columns: [], rows: [], sql: "", executionTime: 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Initialization Tests
|
||||||
|
|
||||||
|
@Test("Initializes with default parameters")
|
||||||
|
func initWithDefaults() {
|
||||||
|
let table = makeDataTable()
|
||||||
|
let view = ScrollableDataTableView(dataTable: table)
|
||||||
|
|
||||||
|
#expect(view.minimumColumnWidth == 80)
|
||||||
|
#expect(view.maximumColumnWidth == 250)
|
||||||
|
#expect(view.showAlternatingRows == true)
|
||||||
|
#expect(view.showFooter == true)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Initializes with custom parameters")
|
||||||
|
func initWithCustomParams() {
|
||||||
|
let table = makeDataTable()
|
||||||
|
let view = ScrollableDataTableView(
|
||||||
|
dataTable: table,
|
||||||
|
minimumColumnWidth: 100,
|
||||||
|
maximumColumnWidth: 300,
|
||||||
|
showAlternatingRows: false,
|
||||||
|
showFooter: false
|
||||||
|
)
|
||||||
|
|
||||||
|
#expect(view.minimumColumnWidth == 100)
|
||||||
|
#expect(view.maximumColumnWidth == 300)
|
||||||
|
#expect(view.showAlternatingRows == false)
|
||||||
|
#expect(view.showFooter == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles empty data table")
|
||||||
|
func handlesEmptyTable() {
|
||||||
|
let table = makeEmptyDataTable()
|
||||||
|
let view = ScrollableDataTableView(dataTable: table)
|
||||||
|
#expect(view.dataTable.isEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles single row table")
|
||||||
|
func handlesSingleRow() {
|
||||||
|
let table = makeDataTable(rowCount: 1)
|
||||||
|
let view = ScrollableDataTableView(dataTable: table)
|
||||||
|
#expect(view.dataTable.rowCount == 1)
|
||||||
|
#expect(view.dataTable.columnCount == 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles single column table")
|
||||||
|
func handlesSingleColumn() {
|
||||||
|
let columns = [DataTable.Column(name: "count", index: 0, inferredType: .integer)]
|
||||||
|
let rows = [
|
||||||
|
DataTable.Row(id: 0, values: [.integer(42)], columnNames: ["count"])
|
||||||
|
]
|
||||||
|
let table = DataTable(columns: columns, rows: rows, sql: "SELECT count(*) FROM t", executionTime: 0.001)
|
||||||
|
let view = ScrollableDataTableView(dataTable: table)
|
||||||
|
#expect(view.dataTable.columnCount == 1)
|
||||||
|
#expect(view.dataTable.rowCount == 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles large number of rows")
|
||||||
|
func handlesLargeRowCount() {
|
||||||
|
let table = makeDataTable(rowCount: 1000)
|
||||||
|
let view = ScrollableDataTableView(dataTable: table)
|
||||||
|
#expect(view.dataTable.rowCount == 1000)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles null values in cells")
|
||||||
|
func handlesNullValues() {
|
||||||
|
let columns = [
|
||||||
|
DataTable.Column(name: "name", index: 0, inferredType: .text),
|
||||||
|
DataTable.Column(name: "value", index: 1, inferredType: .null),
|
||||||
|
]
|
||||||
|
let rows = [
|
||||||
|
DataTable.Row(id: 0, values: [.text("test"), .null], columnNames: ["name", "value"])
|
||||||
|
]
|
||||||
|
let table = DataTable(columns: columns, rows: rows)
|
||||||
|
let view = ScrollableDataTableView(dataTable: table)
|
||||||
|
#expect(view.dataTable.rows[0][1] == .null)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Handles blob values in cells")
|
||||||
|
func handlesBlobValues() {
|
||||||
|
let columns = [
|
||||||
|
DataTable.Column(name: "data", index: 0, inferredType: .blob),
|
||||||
|
]
|
||||||
|
let blobData = Data([0x00, 0xFF, 0xAB])
|
||||||
|
let rows = [
|
||||||
|
DataTable.Row(id: 0, values: [.blob(blobData)], columnNames: ["data"])
|
||||||
|
]
|
||||||
|
let table = DataTable(columns: columns, rows: rows)
|
||||||
|
let view = ScrollableDataTableView(dataTable: table)
|
||||||
|
#expect(view.dataTable.rows[0][0] == QueryResult.Value.blob(blobData))
|
||||||
|
}
|
||||||
|
}
|
||||||
301
Tests/SwiftDBAITests/TextSummaryRendererTests.swift
Normal file
301
Tests/SwiftDBAITests/TextSummaryRendererTests.swift
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
// TextSummaryRendererTests.swift
|
||||||
|
// SwiftDBAI
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Testing
|
||||||
|
import Foundation
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
@Suite("TextSummaryRenderer")
|
||||||
|
struct TextSummaryRendererTests {
|
||||||
|
|
||||||
|
// MARK: - QueryResult.Value Tests
|
||||||
|
|
||||||
|
@Test("Value description renders correctly")
|
||||||
|
func valueDescriptions() {
|
||||||
|
#expect(QueryResult.Value.text("hello").description == "hello")
|
||||||
|
#expect(QueryResult.Value.integer(42).description == "42")
|
||||||
|
#expect(QueryResult.Value.real(3.14).description == "3.14")
|
||||||
|
#expect(QueryResult.Value.null.description == "NULL")
|
||||||
|
#expect(QueryResult.Value.blob(Data([0x01, 0x02])).description == "<2 bytes>")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Value doubleValue extracts numeric values")
|
||||||
|
func valueDoubleValues() {
|
||||||
|
#expect(QueryResult.Value.integer(42).doubleValue == 42.0)
|
||||||
|
#expect(QueryResult.Value.real(3.14).doubleValue == 3.14)
|
||||||
|
#expect(QueryResult.Value.text("100").doubleValue == 100.0)
|
||||||
|
#expect(QueryResult.Value.text("not a number").doubleValue == nil)
|
||||||
|
#expect(QueryResult.Value.null.doubleValue == nil)
|
||||||
|
#expect(QueryResult.Value.blob(Data()).doubleValue == nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Value isNull works correctly")
|
||||||
|
func valueIsNull() {
|
||||||
|
#expect(QueryResult.Value.null.isNull == true)
|
||||||
|
#expect(QueryResult.Value.text("").isNull == false)
|
||||||
|
#expect(QueryResult.Value.integer(0).isNull == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - QueryResult Tests
|
||||||
|
|
||||||
|
@Test("Empty result has correct properties")
|
||||||
|
func emptyResult() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: ["id", "name"],
|
||||||
|
rows: [],
|
||||||
|
sql: "SELECT id, name FROM users",
|
||||||
|
executionTime: 0.01
|
||||||
|
)
|
||||||
|
#expect(result.rowCount == 0)
|
||||||
|
#expect(result.isAggregate == false)
|
||||||
|
#expect(result.tabularDescription == "(empty result set)")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Single aggregate result is detected")
|
||||||
|
func aggregateDetection() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: ["COUNT(*)"],
|
||||||
|
rows: [["COUNT(*)": .integer(42)]],
|
||||||
|
sql: "SELECT COUNT(*) FROM users",
|
||||||
|
executionTime: 0.01
|
||||||
|
)
|
||||||
|
#expect(result.isAggregate == true)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Multi-row result is not aggregate")
|
||||||
|
func nonAggregateDetection() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: ["name"],
|
||||||
|
rows: [
|
||||||
|
["name": .text("Alice")],
|
||||||
|
["name": .text("Bob")],
|
||||||
|
],
|
||||||
|
sql: "SELECT name FROM users",
|
||||||
|
executionTime: 0.01
|
||||||
|
)
|
||||||
|
#expect(result.isAggregate == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Tabular description formats correctly")
|
||||||
|
func tabularDescription() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: ["id", "name"],
|
||||||
|
rows: [
|
||||||
|
["id": .integer(1), "name": .text("Alice")],
|
||||||
|
["id": .integer(2), "name": .text("Bob")],
|
||||||
|
],
|
||||||
|
sql: "SELECT id, name FROM users",
|
||||||
|
executionTime: 0.01
|
||||||
|
)
|
||||||
|
let desc = result.tabularDescription
|
||||||
|
#expect(desc.contains("id | name"))
|
||||||
|
#expect(desc.contains("1 | Alice"))
|
||||||
|
#expect(desc.contains("2 | Bob"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("values(forColumn:) extracts column values")
|
||||||
|
func valuesForColumn() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: ["name"],
|
||||||
|
rows: [
|
||||||
|
["name": .text("Alice")],
|
||||||
|
["name": .text("Bob")],
|
||||||
|
],
|
||||||
|
sql: "SELECT name FROM users",
|
||||||
|
executionTime: 0.01
|
||||||
|
)
|
||||||
|
let values = result.values(forColumn: "name")
|
||||||
|
#expect(values.count == 2)
|
||||||
|
#expect(values[0] == .text("Alice"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Local Summary Tests (no LLM required)
|
||||||
|
|
||||||
|
@Test("Local summary for empty result")
|
||||||
|
func localSummaryEmpty() {
|
||||||
|
let result = makeResult(columns: ["id"], rows: [])
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = renderer.localSummary(result: result, userQuestion: "Any users?")
|
||||||
|
#expect(summary == "No results found for your query.")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Local summary for single aggregate")
|
||||||
|
func localSummarySingleAggregate() {
|
||||||
|
let result = makeResult(
|
||||||
|
columns: ["COUNT(*)"],
|
||||||
|
rows: [["COUNT(*)": .integer(42)]]
|
||||||
|
)
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = renderer.localSummary(result: result, userQuestion: "How many?")
|
||||||
|
#expect(summary.contains("42"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Local summary for multiple aggregates")
|
||||||
|
func localSummaryMultipleAggregates() {
|
||||||
|
let result = makeResult(
|
||||||
|
columns: ["COUNT(*)", "AVG(price)"],
|
||||||
|
rows: [["COUNT(*)": .integer(10), "AVG(price)": .real(25.5)]]
|
||||||
|
)
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = renderer.localSummary(result: result, userQuestion: "Stats?")
|
||||||
|
#expect(summary.contains("count"))
|
||||||
|
#expect(summary.contains("average price"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Local summary for single record")
|
||||||
|
func localSummarySingleRecord() {
|
||||||
|
let result = makeResult(
|
||||||
|
columns: ["name", "email"],
|
||||||
|
rows: [["name": .text("Alice"), "email": .text("alice@example.com")]]
|
||||||
|
)
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = renderer.localSummary(result: result, userQuestion: "Who?")
|
||||||
|
#expect(summary.contains("1 result"))
|
||||||
|
#expect(summary.contains("Alice"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Local summary for multiple records with name column")
|
||||||
|
func localSummaryMultipleWithNames() {
|
||||||
|
let result = makeResult(
|
||||||
|
columns: ["name", "age"],
|
||||||
|
rows: [
|
||||||
|
["name": .text("Alice"), "age": .integer(30)],
|
||||||
|
["name": .text("Bob"), "age": .integer(25)],
|
||||||
|
["name": .text("Charlie"), "age": .integer(35)],
|
||||||
|
["name": .text("Diana"), "age": .integer(28)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = renderer.localSummary(result: result, userQuestion: "List users")
|
||||||
|
#expect(summary.contains("4 results"))
|
||||||
|
#expect(summary.contains("Alice"))
|
||||||
|
#expect(summary.contains("1 more"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Local summary for mutation result")
|
||||||
|
func localSummaryMutation() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: [],
|
||||||
|
rows: [],
|
||||||
|
sql: "INSERT INTO users (name) VALUES ('Test')",
|
||||||
|
executionTime: 0.01,
|
||||||
|
rowsAffected: 1
|
||||||
|
)
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = renderer.localSummary(result: result, userQuestion: "Add user")
|
||||||
|
#expect(summary == "Successfully inserted 1 row.")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Local summary for delete mutation")
|
||||||
|
func localSummaryDelete() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: [],
|
||||||
|
rows: [],
|
||||||
|
sql: "DELETE FROM users WHERE id = 5",
|
||||||
|
executionTime: 0.01,
|
||||||
|
rowsAffected: 3
|
||||||
|
)
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = renderer.localSummary(result: result, userQuestion: "Delete old users")
|
||||||
|
#expect(summary == "Successfully deleted 3 rows.")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Local summary for update mutation")
|
||||||
|
func localSummaryUpdate() {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: [],
|
||||||
|
rows: [],
|
||||||
|
sql: "UPDATE users SET active = 0 WHERE id = 1",
|
||||||
|
executionTime: 0.01,
|
||||||
|
rowsAffected: 1
|
||||||
|
)
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = renderer.localSummary(result: result, userQuestion: "Deactivate user")
|
||||||
|
#expect(summary == "Successfully updated 1 row.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - LLM-based Summary Tests (using MockLanguageModel)
|
||||||
|
|
||||||
|
@Test("Summarize with LLM returns mock response for multi-row results")
|
||||||
|
func summarizeWithLLM() async throws {
|
||||||
|
let result = makeResult(
|
||||||
|
columns: ["name", "age"],
|
||||||
|
rows: [
|
||||||
|
["name": .text("Alice"), "age": .integer(30)],
|
||||||
|
["name": .text("Bob"), "age": .integer(25)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let mockModel = MockLanguageModel(responseText: "There are 2 users: Alice (30) and Bob (25).")
|
||||||
|
let renderer = TextSummaryRenderer(model: mockModel)
|
||||||
|
let summary = try await renderer.summarize(result: result, userQuestion: "List all users")
|
||||||
|
#expect(summary == "There are 2 users: Alice (30) and Bob (25).")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Summarize returns empty result message without calling LLM")
|
||||||
|
func summarizeEmptyResult() async throws {
|
||||||
|
let result = makeResult(columns: ["id"], rows: [])
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = try await renderer.summarize(result: result, userQuestion: "Find users")
|
||||||
|
#expect(summary == "No results found for your query.")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Summarize returns direct aggregate without calling LLM")
|
||||||
|
func summarizeAggregate() async throws {
|
||||||
|
let result = makeResult(
|
||||||
|
columns: ["COUNT(*)"],
|
||||||
|
rows: [["COUNT(*)": .integer(42)]]
|
||||||
|
)
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = try await renderer.summarize(result: result, userQuestion: "How many?")
|
||||||
|
#expect(summary.contains("42"))
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Summarize mutation returns template without calling LLM")
|
||||||
|
func summarizeMutation() async throws {
|
||||||
|
let result = QueryResult(
|
||||||
|
columns: [],
|
||||||
|
rows: [],
|
||||||
|
sql: "UPDATE users SET name = 'Test' WHERE id = 1",
|
||||||
|
executionTime: 0.01,
|
||||||
|
rowsAffected: 1
|
||||||
|
)
|
||||||
|
let renderer = makeMockRenderer()
|
||||||
|
let summary = try await renderer.summarize(result: result, userQuestion: "Update user")
|
||||||
|
#expect(summary == "Successfully updated 1 row.")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Summarize passes context to LLM prompt")
|
||||||
|
func summarizeWithContext() async throws {
|
||||||
|
let result = makeResult(
|
||||||
|
columns: ["total"],
|
||||||
|
rows: [
|
||||||
|
["total": .real(100.0)],
|
||||||
|
["total": .real(200.0)],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
let mockModel = MockLanguageModel(responseText: "The totals are 100 and 200.")
|
||||||
|
let renderer = TextSummaryRenderer(model: mockModel)
|
||||||
|
let summary = try await renderer.summarize(
|
||||||
|
result: result,
|
||||||
|
userQuestion: "Show totals",
|
||||||
|
context: "Amounts are in USD"
|
||||||
|
)
|
||||||
|
#expect(summary == "The totals are 100 and 200.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Helpers
|
||||||
|
|
||||||
|
private func makeResult(
|
||||||
|
columns: [String],
|
||||||
|
rows: [[String: QueryResult.Value]],
|
||||||
|
sql: String = "SELECT * FROM test"
|
||||||
|
) -> QueryResult {
|
||||||
|
QueryResult(columns: columns, rows: rows, sql: sql, executionTime: 0.01)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a renderer with a mock model (for localSummary tests that don't hit the LLM).
|
||||||
|
private func makeMockRenderer() -> TextSummaryRenderer {
|
||||||
|
TextSummaryRenderer(model: MockLanguageModel())
|
||||||
|
}
|
||||||
|
}
|
||||||
246
Tests/SwiftDBAITests/ToolExecutionDelegateTests.swift
Normal file
246
Tests/SwiftDBAITests/ToolExecutionDelegateTests.swift
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
// ToolExecutionDelegateTests.swift
|
||||||
|
// SwiftDBAITests
|
||||||
|
|
||||||
|
import Foundation
|
||||||
|
import Testing
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
@Suite("DestructiveClassification")
|
||||||
|
struct DestructiveClassificationTests {
|
||||||
|
|
||||||
|
// MARK: - Safe statements
|
||||||
|
|
||||||
|
@Test("SELECT is classified as safe")
|
||||||
|
func selectIsSafe() {
|
||||||
|
let result = classifySQL("SELECT * FROM users")
|
||||||
|
#expect(result == .safe)
|
||||||
|
#expect(!result.requiresConfirmation)
|
||||||
|
#expect(!result.isMutating)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("WITH (CTE) is classified as safe")
|
||||||
|
func withIsSafe() {
|
||||||
|
let result = classifySQL("WITH cte AS (SELECT 1) SELECT * FROM cte")
|
||||||
|
#expect(result == .safe)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Mutation statements
|
||||||
|
|
||||||
|
@Test("INSERT is classified as mutation")
|
||||||
|
func insertIsMutation() {
|
||||||
|
let result = classifySQL("INSERT INTO users (name) VALUES ('Alice')")
|
||||||
|
#expect(result == .mutation(.insert))
|
||||||
|
#expect(!result.requiresConfirmation)
|
||||||
|
#expect(result.isMutating)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("UPDATE is classified as mutation")
|
||||||
|
func updateIsMutation() {
|
||||||
|
let result = classifySQL("UPDATE users SET name = 'Bob' WHERE id = 1")
|
||||||
|
#expect(result == .mutation(.update))
|
||||||
|
#expect(!result.requiresConfirmation)
|
||||||
|
#expect(result.isMutating)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Destructive statements
|
||||||
|
|
||||||
|
@Test("DELETE is classified as destructive")
|
||||||
|
func deleteIsDestructive() {
|
||||||
|
let result = classifySQL("DELETE FROM users WHERE id = 1")
|
||||||
|
#expect(result == .destructive(.delete))
|
||||||
|
#expect(result.requiresConfirmation)
|
||||||
|
#expect(result.isMutating)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("DROP is classified as destructive")
|
||||||
|
func dropIsDestructive() {
|
||||||
|
let result = classifySQL("DROP TABLE users")
|
||||||
|
#expect(result == .destructive(.drop))
|
||||||
|
#expect(result.requiresConfirmation)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("ALTER is classified as destructive")
|
||||||
|
func alterIsDestructive() {
|
||||||
|
let result = classifySQL("ALTER TABLE users ADD COLUMN age INTEGER")
|
||||||
|
#expect(result == .destructive(.alter))
|
||||||
|
#expect(result.requiresConfirmation)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("TRUNCATE is classified as destructive")
|
||||||
|
func truncateIsDestructive() {
|
||||||
|
let result = classifySQL("TRUNCATE TABLE users")
|
||||||
|
#expect(result == .destructive(.truncate))
|
||||||
|
#expect(result.requiresConfirmation)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Case insensitivity
|
||||||
|
|
||||||
|
@Test("Classification is case-insensitive")
|
||||||
|
func caseInsensitive() {
|
||||||
|
#expect(classifySQL("delete from users") == .destructive(.delete))
|
||||||
|
#expect(classifySQL("Drop Table foo") == .destructive(.drop))
|
||||||
|
#expect(classifySQL("select 1") == .safe)
|
||||||
|
#expect(classifySQL("INSERT into t values (1)") == .mutation(.insert))
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Leading whitespace
|
||||||
|
|
||||||
|
@Test("Classification ignores leading whitespace")
|
||||||
|
func leadingWhitespace() {
|
||||||
|
#expect(classifySQL(" \n DELETE FROM users") == .destructive(.delete))
|
||||||
|
#expect(classifySQL("\t SELECT 1") == .safe)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - SQLStatementKind
|
||||||
|
|
||||||
|
@Test("Destructive kinds are correct")
|
||||||
|
func destructiveKinds() {
|
||||||
|
#expect(SQLStatementKind.delete.isDestructive)
|
||||||
|
#expect(SQLStatementKind.drop.isDestructive)
|
||||||
|
#expect(SQLStatementKind.alter.isDestructive)
|
||||||
|
#expect(SQLStatementKind.truncate.isDestructive)
|
||||||
|
#expect(!SQLStatementKind.select.isDestructive)
|
||||||
|
#expect(!SQLStatementKind.insert.isDestructive)
|
||||||
|
#expect(!SQLStatementKind.update.isDestructive)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Mutation kinds are correct")
|
||||||
|
func mutationKinds() {
|
||||||
|
#expect(SQLStatementKind.insert.isMutation)
|
||||||
|
#expect(SQLStatementKind.update.isMutation)
|
||||||
|
#expect(!SQLStatementKind.select.isMutation)
|
||||||
|
#expect(!SQLStatementKind.delete.isMutation)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suite("ToolExecutionDelegate")
|
||||||
|
struct ToolExecutionDelegateProtocolTests {
|
||||||
|
|
||||||
|
@Test("AutoApproveDelegate approves all operations")
|
||||||
|
func autoApprove() async {
|
||||||
|
let delegate = AutoApproveDelegate()
|
||||||
|
let context = DestructiveOperationContext(
|
||||||
|
sql: "DELETE FROM users",
|
||||||
|
statementKind: .delete,
|
||||||
|
classification: .destructive(.delete),
|
||||||
|
description: "Delete all rows from users"
|
||||||
|
)
|
||||||
|
let result = await delegate.confirmDestructiveOperation(context)
|
||||||
|
#expect(result == true)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("RejectAllDelegate rejects all operations")
|
||||||
|
func rejectAll() async {
|
||||||
|
let delegate = RejectAllDelegate()
|
||||||
|
let context = DestructiveOperationContext(
|
||||||
|
sql: "DROP TABLE users",
|
||||||
|
statementKind: .drop,
|
||||||
|
classification: .destructive(.drop),
|
||||||
|
description: "Drop the users table"
|
||||||
|
)
|
||||||
|
let result = await delegate.confirmDestructiveOperation(context)
|
||||||
|
#expect(result == false)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Default delegate implementation rejects destructive operations")
|
||||||
|
func defaultRejects() async {
|
||||||
|
struct EmptyDelegate: ToolExecutionDelegate {}
|
||||||
|
let delegate = EmptyDelegate()
|
||||||
|
let context = DestructiveOperationContext(
|
||||||
|
sql: "DELETE FROM users",
|
||||||
|
statementKind: .delete,
|
||||||
|
classification: .destructive(.delete),
|
||||||
|
description: "Delete rows"
|
||||||
|
)
|
||||||
|
let result = await delegate.confirmDestructiveOperation(context)
|
||||||
|
#expect(result == false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Tracking Delegate for Integration Tests
|
||||||
|
|
||||||
|
/// A delegate that records all calls for verification in tests.
|
||||||
|
private final class TrackingDelegate: ToolExecutionDelegate, @unchecked Sendable {
|
||||||
|
private let lock = NSLock()
|
||||||
|
|
||||||
|
private var _confirmCalls: [DestructiveOperationContext] = []
|
||||||
|
private var _willExecuteCalls: [(sql: String, classification: DestructiveClassification)] = []
|
||||||
|
private var _didExecuteCalls: [(sql: String, success: Bool)] = []
|
||||||
|
private var _confirmResult: Bool
|
||||||
|
|
||||||
|
var confirmCalls: [DestructiveOperationContext] {
|
||||||
|
lock.withLock { _confirmCalls }
|
||||||
|
}
|
||||||
|
|
||||||
|
var willExecuteCalls: [(sql: String, classification: DestructiveClassification)] {
|
||||||
|
lock.withLock { _willExecuteCalls }
|
||||||
|
}
|
||||||
|
|
||||||
|
var didExecuteCalls: [(sql: String, success: Bool)] {
|
||||||
|
lock.withLock { _didExecuteCalls }
|
||||||
|
}
|
||||||
|
|
||||||
|
init(confirmResult: Bool) {
|
||||||
|
self._confirmResult = confirmResult
|
||||||
|
}
|
||||||
|
|
||||||
|
func confirmDestructiveOperation(_ context: DestructiveOperationContext) async -> Bool {
|
||||||
|
lock.withLock { _confirmCalls.append(context) }
|
||||||
|
return _confirmResult
|
||||||
|
}
|
||||||
|
|
||||||
|
func willExecuteSQL(_ sql: String, classification: DestructiveClassification) async {
|
||||||
|
lock.withLock { _willExecuteCalls.append((sql: sql, classification: classification)) }
|
||||||
|
}
|
||||||
|
|
||||||
|
func didExecuteSQL(_ sql: String, success: Bool) async {
|
||||||
|
lock.withLock { _didExecuteCalls.append((sql: sql, success: success)) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Suite("ToolExecutionDelegate - ChatEngine Integration")
|
||||||
|
struct DelegateIntegrationTests {
|
||||||
|
|
||||||
|
@Test("DestructiveOperationContext captures target table")
|
||||||
|
func contextCapturesTable() {
|
||||||
|
let context = DestructiveOperationContext(
|
||||||
|
sql: "DELETE FROM users WHERE id = 1",
|
||||||
|
statementKind: .delete,
|
||||||
|
classification: .destructive(.delete),
|
||||||
|
description: "Delete from users",
|
||||||
|
targetTable: "users"
|
||||||
|
)
|
||||||
|
#expect(context.targetTable == "users")
|
||||||
|
#expect(context.statementKind == .delete)
|
||||||
|
#expect(context.classification.requiresConfirmation)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("classifySQL returns destructive for DELETE")
|
||||||
|
func classifySQLDestructive() {
|
||||||
|
let result = classifySQL("DELETE FROM orders WHERE id = 5")
|
||||||
|
#expect(result == .destructive(.delete))
|
||||||
|
#expect(result.requiresConfirmation)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("classifySQL returns safe for SELECT")
|
||||||
|
func classifySQLSafe() {
|
||||||
|
let result = classifySQL("SELECT * FROM users")
|
||||||
|
#expect(result == .safe)
|
||||||
|
#expect(!result.requiresConfirmation)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("classifySQL returns mutation for INSERT")
|
||||||
|
func classifySQLMutation() {
|
||||||
|
let result = classifySQL("INSERT INTO users (name) VALUES ('test')")
|
||||||
|
#expect(result == .mutation(.insert))
|
||||||
|
#expect(!result.requiresConfirmation)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("DestructiveClassification.isMutating is true for mutations and destructive")
|
||||||
|
func isMutatingCovers() {
|
||||||
|
#expect(DestructiveClassification.mutation(.insert).isMutating)
|
||||||
|
#expect(DestructiveClassification.mutation(.update).isMutating)
|
||||||
|
#expect(DestructiveClassification.destructive(.delete).isMutating)
|
||||||
|
#expect(!DestructiveClassification.safe.isMutating)
|
||||||
|
}
|
||||||
|
}
|
||||||
617
Tests/SwiftDBAITests/UnifiedProviderTestHarness.swift
Normal file
617
Tests/SwiftDBAITests/UnifiedProviderTestHarness.swift
Normal file
@@ -0,0 +1,617 @@
|
|||||||
|
// UnifiedProviderTestHarness.swift
|
||||||
|
// SwiftDBAI Tests
|
||||||
|
//
|
||||||
|
// A unified test harness that validates all seven provider types
|
||||||
|
// conform to the AnyLanguageModel protocol and produce consistent
|
||||||
|
// ChatEngine-compatible output. Covers: OpenAI, Anthropic, Gemini,
|
||||||
|
// OpenAI-Compatible, Ollama, llama.cpp, and on-device (MLX/CoreML).
|
||||||
|
|
||||||
|
import AnyLanguageModel
|
||||||
|
import Foundation
|
||||||
|
import GRDB
|
||||||
|
import Testing
|
||||||
|
|
||||||
|
@testable import SwiftDBAI
|
||||||
|
|
||||||
|
// MARK: - Provider-Simulating Mock Models
|
||||||
|
|
||||||
|
/// A mock that records which LanguageModel protocol methods were called,
|
||||||
|
/// the arguments passed, and returns configurable responses.
|
||||||
|
/// Used to validate that every provider path through ChatEngine
|
||||||
|
/// exercises the same protocol surface.
|
||||||
|
final class ProviderConformanceMock: LanguageModel, @unchecked Sendable {
|
||||||
|
typealias UnavailableReason = Never
|
||||||
|
|
||||||
|
/// Track calls to verify protocol conformance exercised fully.
|
||||||
|
struct CallRecord: Sendable {
|
||||||
|
let method: String
|
||||||
|
let promptDescription: String
|
||||||
|
let timestamp: Date
|
||||||
|
}
|
||||||
|
|
||||||
|
private let lock = NSLock()
|
||||||
|
private var _calls: [CallRecord] = []
|
||||||
|
private let _responses: [String]
|
||||||
|
private var _callIndex = 0
|
||||||
|
|
||||||
|
/// Label for diagnostics.
|
||||||
|
let providerName: String
|
||||||
|
|
||||||
|
var calls: [CallRecord] {
|
||||||
|
lock.lock()
|
||||||
|
defer { lock.unlock() }
|
||||||
|
return _calls
|
||||||
|
}
|
||||||
|
|
||||||
|
init(providerName: String, responses: [String]) {
|
||||||
|
self.providerName = providerName
|
||||||
|
self._responses = responses
|
||||||
|
}
|
||||||
|
|
||||||
|
private func nextResponse() -> String {
|
||||||
|
lock.lock()
|
||||||
|
defer { lock.unlock() }
|
||||||
|
let idx = _callIndex
|
||||||
|
_callIndex += 1
|
||||||
|
return idx < _responses.count ? _responses[idx] : "fallback response"
|
||||||
|
}
|
||||||
|
|
||||||
|
private func recordCall(method: String, prompt: String) {
|
||||||
|
lock.lock()
|
||||||
|
_calls.append(CallRecord(method: method, promptDescription: prompt, timestamp: Date()))
|
||||||
|
lock.unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func respond<Content>(
|
||||||
|
within session: LanguageModelSession,
|
||||||
|
to prompt: Prompt,
|
||||||
|
generating type: Content.Type,
|
||||||
|
includeSchemaInPrompt: Bool,
|
||||||
|
options: GenerationOptions
|
||||||
|
) async throws -> LanguageModelSession.Response<Content> where Content: Generable {
|
||||||
|
recordCall(method: "respond", prompt: prompt.description)
|
||||||
|
let text = nextResponse()
|
||||||
|
let rawContent = GeneratedContent(kind: .string(text))
|
||||||
|
let content = try Content(rawContent)
|
||||||
|
return LanguageModelSession.Response(
|
||||||
|
content: content,
|
||||||
|
rawContent: rawContent,
|
||||||
|
transcriptEntries: [][...]
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamResponse<Content>(
|
||||||
|
within session: LanguageModelSession,
|
||||||
|
to prompt: Prompt,
|
||||||
|
generating type: Content.Type,
|
||||||
|
includeSchemaInPrompt: Bool,
|
||||||
|
options: GenerationOptions
|
||||||
|
) -> sending LanguageModelSession.ResponseStream<Content> where Content: Generable {
|
||||||
|
recordCall(method: "streamResponse", prompt: prompt.description)
|
||||||
|
let text = nextResponse()
|
||||||
|
let rawContent = GeneratedContent(kind: .string(text))
|
||||||
|
let content = try! Content(rawContent)
|
||||||
|
return LanguageModelSession.ResponseStream(content: content, rawContent: rawContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Test Database Helper
|
||||||
|
|
||||||
|
/// Creates a minimal in-memory database for provider integration tests.
|
||||||
|
private func makeProviderTestDatabase() throws -> DatabaseQueue {
|
||||||
|
let db = try DatabaseQueue(path: ":memory:")
|
||||||
|
try db.write { db in
|
||||||
|
try db.execute(sql: """
|
||||||
|
CREATE TABLE products (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
price REAL NOT NULL,
|
||||||
|
category TEXT NOT NULL
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
try db.execute(sql: """
|
||||||
|
INSERT INTO products (name, price, category) VALUES
|
||||||
|
('Widget', 9.99, 'tools'),
|
||||||
|
('Gadget', 24.99, 'electronics'),
|
||||||
|
('Doohickey', 4.50, 'tools')
|
||||||
|
""")
|
||||||
|
}
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - Unified Provider Test Harness
|
||||||
|
|
||||||
|
@Suite("Unified Provider Test Harness")
|
||||||
|
struct UnifiedProviderTestHarness {
|
||||||
|
|
||||||
|
// MARK: - Provider Configuration Enumeration
|
||||||
|
|
||||||
|
/// All seven provider types that SwiftDBAI supports.
|
||||||
|
enum TestedProvider: String, CaseIterable {
|
||||||
|
case openAI
|
||||||
|
case anthropic
|
||||||
|
case gemini
|
||||||
|
case openAICompatible
|
||||||
|
case ollama
|
||||||
|
case llamaCpp
|
||||||
|
case onDevice
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a ProviderConformanceMock simulating each provider type.
|
||||||
|
private func makeMock(for provider: TestedProvider, responses: [String]) -> ProviderConformanceMock {
|
||||||
|
ProviderConformanceMock(providerName: provider.rawValue, responses: responses)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 1. Protocol Conformance — All Providers Are LanguageModel
|
||||||
|
|
||||||
|
@Test("All provider types produce instances conforming to LanguageModel protocol")
|
||||||
|
func allProvidersConformToLanguageModel() {
|
||||||
|
// Cloud providers via ProviderConfiguration.makeModel()
|
||||||
|
let openAI = ProviderConfiguration.openAI(apiKey: "test-key", model: "gpt-4o").makeModel()
|
||||||
|
let anthropic = ProviderConfiguration.anthropic(apiKey: "test-key", model: "claude-sonnet-4-20250514").makeModel()
|
||||||
|
let gemini = ProviderConfiguration.gemini(apiKey: "test-key", model: "gemini-2.0-flash").makeModel()
|
||||||
|
let openAICompatible = ProviderConfiguration.openAICompatible(
|
||||||
|
apiKey: "test-key",
|
||||||
|
model: "local-model",
|
||||||
|
baseURL: URL(string: "http://localhost:8080/v1/")!
|
||||||
|
).makeModel()
|
||||||
|
let ollama = ProviderConfiguration.ollama(model: "llama3.2").makeModel()
|
||||||
|
let llamaCpp = ProviderConfiguration.llamaCpp(model: "default").makeModel()
|
||||||
|
// On-device MLX (wraps as openAICompatible internally)
|
||||||
|
let onDeviceMLX = ProviderConfiguration.onDeviceMLX(
|
||||||
|
MLXProviderConfiguration(modelId: "test-model")
|
||||||
|
).makeModel()
|
||||||
|
|
||||||
|
// Verify all are LanguageModel
|
||||||
|
let models: [(String, any LanguageModel)] = [
|
||||||
|
("OpenAI", openAI),
|
||||||
|
("Anthropic", anthropic),
|
||||||
|
("Gemini", gemini),
|
||||||
|
("OpenAI-Compatible", openAICompatible),
|
||||||
|
("Ollama", ollama),
|
||||||
|
("llama.cpp", llamaCpp),
|
||||||
|
("On-Device MLX", onDeviceMLX),
|
||||||
|
]
|
||||||
|
|
||||||
|
for (name, model) in models {
|
||||||
|
// Protocol conformance is compile-time, but we verify isAvailable works
|
||||||
|
#expect(model.isAvailable, "\(name) model should report as available")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("All provider configurations produce correct concrete model types")
|
||||||
|
func providerConfigurationsProduceCorrectTypes() {
|
||||||
|
let openAI = ProviderConfiguration.openAI(apiKey: "k", model: "m").makeModel()
|
||||||
|
#expect(openAI is OpenAILanguageModel, "OpenAI config should produce OpenAILanguageModel")
|
||||||
|
|
||||||
|
let anthropic = ProviderConfiguration.anthropic(apiKey: "k", model: "m").makeModel()
|
||||||
|
#expect(anthropic is AnthropicLanguageModel, "Anthropic config should produce AnthropicLanguageModel")
|
||||||
|
|
||||||
|
let gemini = ProviderConfiguration.gemini(apiKey: "k", model: "m").makeModel()
|
||||||
|
#expect(gemini is GeminiLanguageModel, "Gemini config should produce GeminiLanguageModel")
|
||||||
|
|
||||||
|
let openAICompat = ProviderConfiguration.openAICompatible(
|
||||||
|
apiKey: "k", model: "m", baseURL: URL(string: "http://localhost:1234")!
|
||||||
|
).makeModel()
|
||||||
|
#expect(openAICompat is OpenAILanguageModel, "OpenAI-Compatible config should produce OpenAILanguageModel")
|
||||||
|
|
||||||
|
let ollama = ProviderConfiguration.ollama(model: "m").makeModel()
|
||||||
|
#expect(ollama is OllamaLanguageModel, "Ollama config should produce OllamaLanguageModel")
|
||||||
|
|
||||||
|
let llamaCpp = ProviderConfiguration.llamaCpp(model: "m").makeModel()
|
||||||
|
#expect(llamaCpp is OpenAILanguageModel, "llama.cpp config should produce OpenAILanguageModel (OpenAI-compatible)")
|
||||||
|
|
||||||
|
// On-device uses OpenAILanguageModel internally as a wrapper
|
||||||
|
let onDevice = ProviderConfiguration.onDeviceMLX(
|
||||||
|
MLXProviderConfiguration(modelId: "test")
|
||||||
|
).makeModel()
|
||||||
|
#expect(onDevice is OpenAILanguageModel, "On-device MLX config should produce OpenAILanguageModel wrapper")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 2. Consistent ChatEngine-Compatible Output
|
||||||
|
|
||||||
|
@Test("Every provider mock produces valid ChatEngine responses for SELECT queries",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func providerProducesValidChatEngineResponse(provider: TestedProvider) async throws {
|
||||||
|
let db = try makeProviderTestDatabase()
|
||||||
|
let mock = makeMock(for: provider, responses: [
|
||||||
|
"SELECT COUNT(*) FROM products", // SQL generation
|
||||||
|
"There are 3 products in the database.", // Summary (fallback)
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
let response = try await engine.send("How many products are there?")
|
||||||
|
|
||||||
|
// All providers must produce:
|
||||||
|
// 1. Non-empty summary
|
||||||
|
#expect(!response.summary.isEmpty, "\(provider.rawValue): summary must not be empty")
|
||||||
|
|
||||||
|
// 2. Valid SQL that was executed
|
||||||
|
#expect(response.sql == "SELECT COUNT(*) FROM products",
|
||||||
|
"\(provider.rawValue): SQL must match generated query")
|
||||||
|
|
||||||
|
// 3. A QueryResult with data
|
||||||
|
#expect(response.queryResult != nil, "\(provider.rawValue): queryResult must exist")
|
||||||
|
#expect(response.queryResult?.rowCount == 1, "\(provider.rawValue): should have 1 row for COUNT")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Every provider mock produces valid ChatEngine responses for multi-row SELECT",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func providerProducesMultiRowResponse(provider: TestedProvider) async throws {
|
||||||
|
let db = try makeProviderTestDatabase()
|
||||||
|
let mock = makeMock(for: provider, responses: [
|
||||||
|
"SELECT name, price FROM products ORDER BY price DESC",
|
||||||
|
"Here are the products sorted by price.",
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
let response = try await engine.send("List products by price")
|
||||||
|
|
||||||
|
#expect(response.queryResult != nil, "\(provider.rawValue): queryResult must exist")
|
||||||
|
#expect(response.queryResult?.rowCount == 3, "\(provider.rawValue): should return all 3 products")
|
||||||
|
#expect(response.queryResult?.columns.contains("name") == true,
|
||||||
|
"\(provider.rawValue): columns must include 'name'")
|
||||||
|
#expect(response.queryResult?.columns.contains("price") == true,
|
||||||
|
"\(provider.rawValue): columns must include 'price'")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 3. Consistent LanguageModelSession Integration
|
||||||
|
|
||||||
|
@Test("Every provider mock works through LanguageModelSession.respond(to:)",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func providerWorksWithSession(provider: TestedProvider) async throws {
|
||||||
|
let mock = makeMock(for: provider, responses: [
|
||||||
|
"SELECT 1 AS test",
|
||||||
|
])
|
||||||
|
|
||||||
|
let session = LanguageModelSession(
|
||||||
|
model: mock,
|
||||||
|
instructions: "You are a SQL assistant."
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await session.respond(to: "Generate a test query")
|
||||||
|
|
||||||
|
// Verify the response content is the expected string
|
||||||
|
#expect(response.content == "SELECT 1 AS test",
|
||||||
|
"\(provider.rawValue): session response should match mock output")
|
||||||
|
|
||||||
|
// Verify the mock received the call
|
||||||
|
#expect(mock.calls.count == 1, "\(provider.rawValue): should have exactly 1 call")
|
||||||
|
#expect(mock.calls.first?.method == "respond",
|
||||||
|
"\(provider.rawValue): should call respond method")
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Every provider mock works through LanguageModelSession.streamResponse(to:)",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func providerWorksWithStreamSession(provider: TestedProvider) async throws {
|
||||||
|
let mock = makeMock(for: provider, responses: [
|
||||||
|
"SELECT 42 AS answer",
|
||||||
|
])
|
||||||
|
|
||||||
|
let session = LanguageModelSession(
|
||||||
|
model: mock,
|
||||||
|
instructions: "You are a SQL assistant."
|
||||||
|
)
|
||||||
|
|
||||||
|
let stream = session.streamResponse(to: "Give me a number")
|
||||||
|
let collected = try await stream.collect()
|
||||||
|
|
||||||
|
#expect(collected.content == "SELECT 42 AS answer",
|
||||||
|
"\(provider.rawValue): stream collected response should match mock output")
|
||||||
|
#expect(mock.calls.count == 1, "\(provider.rawValue): should have exactly 1 call")
|
||||||
|
#expect(mock.calls.first?.method == "streamResponse",
|
||||||
|
"\(provider.rawValue): should call streamResponse method")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 4. Schema Introspection Works Identically Across Providers
|
||||||
|
|
||||||
|
@Test("Schema introspection returns same schema regardless of provider",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func schemaIntrospectionIsProviderAgnostic(provider: TestedProvider) async throws {
|
||||||
|
let db = try makeProviderTestDatabase()
|
||||||
|
let mock = makeMock(for: provider, responses: ["SELECT 1"])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
let schema = try await engine.prepareSchema()
|
||||||
|
|
||||||
|
#expect(schema.tableNames.contains("products"),
|
||||||
|
"\(provider.rawValue): schema must include 'products' table")
|
||||||
|
#expect(schema.tableNames.count == 1,
|
||||||
|
"\(provider.rawValue): should have exactly 1 table")
|
||||||
|
|
||||||
|
let table = schema.tables["products"]
|
||||||
|
#expect(table != nil, "\(provider.rawValue): must find products table")
|
||||||
|
#expect(table?.columns.count == 4,
|
||||||
|
"\(provider.rawValue): products table must have 4 columns")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 5. Error Handling Consistency
|
||||||
|
|
||||||
|
@Test("All providers handle empty schema consistently",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func emptySchemaHandledConsistently(provider: TestedProvider) async throws {
|
||||||
|
let db = try DatabaseQueue(path: ":memory:")
|
||||||
|
let mock = makeMock(for: provider, responses: ["SELECT 1"])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
do {
|
||||||
|
_ = try await engine.send("Show me data")
|
||||||
|
Issue.record("\(provider.rawValue): should throw for empty schema")
|
||||||
|
} catch let error as SwiftDBAIError {
|
||||||
|
#expect(error == .emptySchema,
|
||||||
|
"\(provider.rawValue): must throw .emptySchema for database with no tables")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("All providers reject disallowed SQL operations consistently",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func disallowedSQLRejectedConsistently(provider: TestedProvider) async throws {
|
||||||
|
let db = try makeProviderTestDatabase()
|
||||||
|
let mock = makeMock(for: provider, responses: [
|
||||||
|
"DELETE FROM products WHERE id = 1",
|
||||||
|
])
|
||||||
|
|
||||||
|
// Default allowlist is readOnly (SELECT only)
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
do {
|
||||||
|
_ = try await engine.send("Delete the first product")
|
||||||
|
Issue.record("\(provider.rawValue): should reject DELETE when allowlist is readOnly")
|
||||||
|
} catch {
|
||||||
|
// All providers must trigger the same error path for disallowed operations
|
||||||
|
#expect(error is SwiftDBAIError,
|
||||||
|
"\(provider.rawValue): error must be SwiftDBAIError")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 6. Conversation History Consistency
|
||||||
|
|
||||||
|
@Test("Conversation history works identically for all providers",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func conversationHistoryConsistent(provider: TestedProvider) async throws {
|
||||||
|
let db = try makeProviderTestDatabase()
|
||||||
|
// ChatEngine calls LLM for SQL generation, then TextSummaryRenderer
|
||||||
|
// may call LLM for summarization. For aggregate queries (COUNT, AVG),
|
||||||
|
// TextSummaryRenderer uses a template and skips the LLM call.
|
||||||
|
// So the mock sequence is: SQL1, SQL2 (each followed by template summary).
|
||||||
|
let mock = makeMock(for: provider, responses: [
|
||||||
|
"SELECT COUNT(*) FROM products",
|
||||||
|
"SELECT AVG(price) FROM products",
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
_ = try await engine.send("How many products?")
|
||||||
|
_ = try await engine.send("What is the average price?")
|
||||||
|
|
||||||
|
let messages = engine.messages
|
||||||
|
#expect(messages.count == 4,
|
||||||
|
"\(provider.rawValue): should have 4 messages (2 user + 2 assistant)")
|
||||||
|
#expect(messages[0].role == .user, "\(provider.rawValue): first message should be user")
|
||||||
|
#expect(messages[1].role == .assistant, "\(provider.rawValue): second message should be assistant")
|
||||||
|
#expect(messages[2].role == .user, "\(provider.rawValue): third message should be user")
|
||||||
|
#expect(messages[3].role == .assistant, "\(provider.rawValue): fourth message should be assistant")
|
||||||
|
|
||||||
|
// Both assistant messages must have SQL
|
||||||
|
#expect(messages[1].sql != nil, "\(provider.rawValue): first response must have SQL")
|
||||||
|
#expect(messages[3].sql != nil, "\(provider.rawValue): second response must have SQL")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 7. ProviderConfiguration Roundtrip
|
||||||
|
|
||||||
|
@Test("All cloud provider configurations roundtrip through makeModel()")
|
||||||
|
func allCloudProvidersRoundtrip() {
|
||||||
|
let configs: [(String, ProviderConfiguration)] = [
|
||||||
|
("OpenAI", .openAI(apiKey: "sk-test", model: "gpt-4o")),
|
||||||
|
("OpenAI Responses", .openAI(apiKey: "sk-test", model: "gpt-4o", variant: .responses)),
|
||||||
|
("Anthropic", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514")),
|
||||||
|
("Anthropic+version", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514", apiVersion: "2024-01-01")),
|
||||||
|
("Anthropic+betas", .anthropic(apiKey: "sk-ant-test", model: "claude-sonnet-4-20250514", betas: ["computer-use"])),
|
||||||
|
("Gemini", .gemini(apiKey: "AIza-test", model: "gemini-2.0-flash")),
|
||||||
|
("Gemini+version", .gemini(apiKey: "AIza-test", model: "gemini-2.0-flash", apiVersion: "v1")),
|
||||||
|
("OpenAI-Compatible", .openAICompatible(
|
||||||
|
apiKey: "key", model: "model", baseURL: URL(string: "http://localhost:1234")!
|
||||||
|
)),
|
||||||
|
("Ollama", .ollama(model: "llama3.2")),
|
||||||
|
("Ollama+custom URL", .ollama(model: "qwen2.5", baseURL: URL(string: "http://192.168.1.100:11434")!)),
|
||||||
|
("llama.cpp", .llamaCpp(model: "default")),
|
||||||
|
("llama.cpp+custom", .llamaCpp(model: "my-model", baseURL: URL(string: "http://localhost:9090")!)),
|
||||||
|
]
|
||||||
|
|
||||||
|
for (name, config) in configs {
|
||||||
|
let model = config.makeModel()
|
||||||
|
#expect(model.isAvailable, "\(name): model must be available after makeModel()")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("On-device provider configurations produce valid models")
|
||||||
|
func onDeviceProvidersRoundtrip() {
|
||||||
|
let mlxConfigs: [MLXProviderConfiguration] = [
|
||||||
|
.llama3_2_3B(),
|
||||||
|
.qwen2_5_coder_3B(),
|
||||||
|
.phi3_5_mini(),
|
||||||
|
MLXProviderConfiguration(modelId: "custom-model", temperature: 0.2),
|
||||||
|
]
|
||||||
|
|
||||||
|
for mlxConfig in mlxConfigs {
|
||||||
|
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
|
||||||
|
let model = providerConfig.makeModel()
|
||||||
|
#expect(model.isAvailable, "MLX model '\(mlxConfig.modelId)' must be available")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 8. Write Operation Allowlist Consistency
|
||||||
|
|
||||||
|
@Test("Write operations require explicit opt-in for all providers",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func writeOperationsRequireOptIn(provider: TestedProvider) async throws {
|
||||||
|
let db = try makeProviderTestDatabase()
|
||||||
|
|
||||||
|
// Mock returns an INSERT statement
|
||||||
|
let mock = makeMock(for: provider, responses: [
|
||||||
|
"INSERT INTO products (name, price, category) VALUES ('New', 1.00, 'misc')",
|
||||||
|
])
|
||||||
|
|
||||||
|
// readOnly allowlist (default)
|
||||||
|
let readOnlyEngine = ChatEngine(database: db, model: mock)
|
||||||
|
|
||||||
|
do {
|
||||||
|
_ = try await readOnlyEngine.send("Add a new product")
|
||||||
|
Issue.record("\(provider.rawValue): INSERT should be rejected with readOnly allowlist")
|
||||||
|
} catch {
|
||||||
|
#expect(error is SwiftDBAIError,
|
||||||
|
"\(provider.rawValue): must throw SwiftDBAIError for disallowed INSERT")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("Allowed write operations work for all providers",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func allowedWriteOperationsWork(provider: TestedProvider) async throws {
|
||||||
|
let db = try makeProviderTestDatabase()
|
||||||
|
|
||||||
|
let mock = makeMock(for: provider, responses: [
|
||||||
|
"INSERT INTO products (name, price, category) VALUES ('NewItem', 1.00, 'misc')",
|
||||||
|
"Successfully added 1 product.",
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(
|
||||||
|
database: db,
|
||||||
|
model: mock,
|
||||||
|
allowlist: .standard
|
||||||
|
)
|
||||||
|
|
||||||
|
let response = try await engine.send("Add a product called NewItem")
|
||||||
|
#expect(response.sql?.uppercased().hasPrefix("INSERT") == true,
|
||||||
|
"\(provider.rawValue): SQL should be an INSERT")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 9. Response Format Consistency
|
||||||
|
|
||||||
|
@Test("ChatResponse structure is identical regardless of provider",
|
||||||
|
arguments: TestedProvider.allCases)
|
||||||
|
func responseStructureConsistent(provider: TestedProvider) async throws {
|
||||||
|
let db = try makeProviderTestDatabase()
|
||||||
|
let mock = makeMock(for: provider, responses: [
|
||||||
|
"SELECT name, price, category FROM products",
|
||||||
|
"Found 3 products across 2 categories.",
|
||||||
|
])
|
||||||
|
|
||||||
|
let engine = ChatEngine(database: db, model: mock)
|
||||||
|
let response = try await engine.send("Show all products")
|
||||||
|
|
||||||
|
// ChatResponse must always have these properties populated
|
||||||
|
#expect(response.summary.count > 0,
|
||||||
|
"\(provider.rawValue): summary must be non-empty")
|
||||||
|
#expect(response.sql != nil,
|
||||||
|
"\(provider.rawValue): sql must be present")
|
||||||
|
#expect(response.queryResult != nil,
|
||||||
|
"\(provider.rawValue): queryResult must be present")
|
||||||
|
|
||||||
|
// QueryResult structure must match the query
|
||||||
|
let qr = response.queryResult!
|
||||||
|
#expect(qr.columns == ["name", "price", "category"],
|
||||||
|
"\(provider.rawValue): columns must match SELECT clause")
|
||||||
|
#expect(qr.rowCount == 3,
|
||||||
|
"\(provider.rawValue): must return all rows")
|
||||||
|
#expect(qr.sql == "SELECT name, price, category FROM products",
|
||||||
|
"\(provider.rawValue): QueryResult.sql must match executed SQL")
|
||||||
|
#expect(qr.executionTime >= 0,
|
||||||
|
"\(provider.rawValue): execution time must be non-negative")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 10. Provider Enum Completeness
|
||||||
|
|
||||||
|
@Test("TestedProvider covers all ProviderConfiguration.Provider cases plus on-device")
|
||||||
|
func testedProviderCoversAllCases() {
|
||||||
|
// ProviderConfiguration.Provider has 6 cases
|
||||||
|
let configProviderCount = ProviderConfiguration.Provider.allCases.count
|
||||||
|
#expect(configProviderCount == 6, "ProviderConfiguration.Provider should have 6 cases")
|
||||||
|
|
||||||
|
// TestedProvider adds on-device for 7 total
|
||||||
|
#expect(TestedProvider.allCases.count == 7, "TestedProvider should cover all 7 provider types")
|
||||||
|
|
||||||
|
// Verify 1:1 mapping for the config providers
|
||||||
|
let configNames = Set(ProviderConfiguration.Provider.allCases.map(\.rawValue))
|
||||||
|
for tested in TestedProvider.allCases where tested != .onDevice {
|
||||||
|
#expect(configNames.contains(tested.rawValue),
|
||||||
|
"\(tested.rawValue) must map to a ProviderConfiguration.Provider case")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 11. ChatEngine Convenience Init Consistency
|
||||||
|
|
||||||
|
@Test("ChatEngine convenience init with ProviderConfiguration works for all cloud providers")
|
||||||
|
func chatEngineConvenienceInitWorks() throws {
|
||||||
|
let db = try makeProviderTestDatabase()
|
||||||
|
|
||||||
|
let configs: [ProviderConfiguration] = [
|
||||||
|
.openAI(apiKey: "test", model: "gpt-4o"),
|
||||||
|
.anthropic(apiKey: "test", model: "claude-sonnet-4-20250514"),
|
||||||
|
.gemini(apiKey: "test", model: "gemini-2.0-flash"),
|
||||||
|
.openAICompatible(apiKey: "test", model: "m", baseURL: URL(string: "http://localhost:1234")!),
|
||||||
|
.ollama(model: "llama3.2"),
|
||||||
|
.llamaCpp(model: "default"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for config in configs {
|
||||||
|
// This should not throw — it only creates the engine, doesn't call the LLM
|
||||||
|
let engine = ChatEngine(database: db, provider: config)
|
||||||
|
#expect(engine.tableCount == nil, "tableCount should be nil before first query")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 12. Availability Reporting
|
||||||
|
|
||||||
|
@Test("All real provider models report available by default")
|
||||||
|
func allModelsReportAvailable() {
|
||||||
|
let models: [(String, any LanguageModel)] = [
|
||||||
|
("OpenAI", OpenAILanguageModel(apiKey: "k", model: "m")),
|
||||||
|
("Anthropic", AnthropicLanguageModel(apiKey: "k", model: "m")),
|
||||||
|
("Gemini", GeminiLanguageModel(apiKey: "k", model: "m")),
|
||||||
|
("Ollama", OllamaLanguageModel(model: "m")),
|
||||||
|
]
|
||||||
|
|
||||||
|
for (name, model) in models {
|
||||||
|
#expect(model.isAvailable, "\(name) should be available by default")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MARK: - 13. On-Device Pipeline Status
|
||||||
|
|
||||||
|
@Test("On-device inference pipeline starts in notLoaded state")
|
||||||
|
func onDevicePipelineInitialState() {
|
||||||
|
let mlxPipeline = OnDeviceInferencePipeline(
|
||||||
|
mlxConfiguration: .llama3_2_3B()
|
||||||
|
)
|
||||||
|
#expect(mlxPipeline.status == .notLoaded)
|
||||||
|
#expect(mlxPipeline.providerType == .mlx)
|
||||||
|
|
||||||
|
let coreMLPipeline = OnDeviceInferencePipeline(
|
||||||
|
coreMLConfiguration: CoreMLProviderConfiguration(
|
||||||
|
modelURL: URL(fileURLWithPath: "/tmp/test.mlmodelc")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
#expect(coreMLPipeline.status == .notLoaded)
|
||||||
|
#expect(coreMLPipeline.providerType == .coreML)
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test("On-device SQL generation hints are populated for both provider types")
|
||||||
|
func onDeviceSQLHints() {
|
||||||
|
let mlxPipeline = OnDeviceInferencePipeline(mlxConfiguration: .llama3_2_3B())
|
||||||
|
let mlxHints = mlxPipeline.recommendedSQLGenerationHints
|
||||||
|
#expect(mlxHints.maxTokens > 0)
|
||||||
|
#expect(mlxHints.temperature >= 0)
|
||||||
|
#expect(!mlxHints.systemPromptSuffix.isEmpty)
|
||||||
|
|
||||||
|
let coreMLPipeline = OnDeviceInferencePipeline(
|
||||||
|
coreMLConfiguration: CoreMLProviderConfiguration(
|
||||||
|
modelURL: URL(fileURLWithPath: "/tmp/test.mlmodelc")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
let coreMLHints = coreMLPipeline.recommendedSQLGenerationHints
|
||||||
|
#expect(coreMLHints.maxTokens > 0)
|
||||||
|
#expect(coreMLHints.temperature >= 0)
|
||||||
|
#expect(!coreMLHints.systemPromptSuffix.isEmpty)
|
||||||
|
}
|
||||||
|
}
|
||||||
191
seed.yaml
Normal file
191
seed.yaml
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
###############################################################################
|
||||||
|
# SwiftDBAI — Seed Specification
|
||||||
|
# Generated from Socratic interview on 2026-04-03
|
||||||
|
###############################################################################
|
||||||
|
|
||||||
|
goal: >
|
||||||
|
Build SwiftDBAI — a Swift package that lets users chat with any SQLite
|
||||||
|
database using natural language. Built on AnyLanguageModel (HuggingFace)
|
||||||
|
for LLM-agnostic provider support and GRDB for SQLite access. Ships a
|
||||||
|
drop-in SwiftUI ChatView, a headless ChatEngine, auto-schema introspection,
|
||||||
|
and result rendering with text, tables, and Swift Charts.
|
||||||
|
|
||||||
|
constraints:
|
||||||
|
- Swift 6.1+ with strict concurrency checking
|
||||||
|
- Swift Package Manager only (no CocoaPods/Carthage)
|
||||||
|
- iOS 17.0+, macOS 14.0+, visionOS 1.0+
|
||||||
|
- GRDB for all SQLite access (no raw sqlite3 C API)
|
||||||
|
- AnyLanguageModel (HuggingFace) as the sole LLM abstraction layer
|
||||||
|
- Any SQLite database supported (not limited to SwiftData)
|
||||||
|
- No UIKit dependency — pure SwiftUI for the view layer
|
||||||
|
- No Apple Intelligence / Foundation Models dependency
|
||||||
|
- Developer provides their own GRDB DatabasePool or DatabaseQueue
|
||||||
|
- SQL operation allowlist configured by developer (SELECT, INSERT, UPDATE, DELETE)
|
||||||
|
- No SQL parser — allowlist check is sufficient for validation
|
||||||
|
- Zero telemetry — the package collects nothing
|
||||||
|
|
||||||
|
acceptance_criteria:
|
||||||
|
- Auto-introspects any SQLite database schema from sqlite_master (tables, columns, types, foreign keys)
|
||||||
|
- LLM generates valid SQL queries from natural language input
|
||||||
|
- Query results render as natural language text summaries
|
||||||
|
- Query results render as scrollable data tables
|
||||||
|
- Aggregate/numeric results render as Swift Charts
|
||||||
|
- Drop-in DataChatView works with minimal setup (< 10 lines of code)
|
||||||
|
- Headless ChatEngine supports programmatic query/response without UI
|
||||||
|
- Multi-turn conversation context maintained (follow-up queries work)
|
||||||
|
- Mutation operations (INSERT, UPDATE, DELETE) work when allowlisted
|
||||||
|
- Destructive operations require confirmation via ToolExecutionDelegate
|
||||||
|
- All AnyLanguageModel providers work (OpenAI, Anthropic, Gemini, Ollama, CoreML, MLX, llama.cpp)
|
||||||
|
- Error states (bad SQL, LLM failure, schema mismatch) handled gracefully in UI
|
||||||
|
- Package binary adds < 2 MB to app (excluding LLM model weights)
|
||||||
|
|
||||||
|
ontology_schema:
|
||||||
|
name: SwiftDBAI
|
||||||
|
description: >
|
||||||
|
Domain model for a conversational SQLite query engine that bridges
|
||||||
|
natural language, LLM tool calls, SQL generation, and result rendering.
|
||||||
|
fields:
|
||||||
|
- name: database_schema
|
||||||
|
type: object
|
||||||
|
description: >
|
||||||
|
Auto-introspected SQLite schema containing tables, columns
|
||||||
|
(name, type, nullable, default), primary keys, foreign keys,
|
||||||
|
and indexes. Derived from sqlite_master at init time.
|
||||||
|
|
||||||
|
- name: chat_engine
|
||||||
|
type: object
|
||||||
|
description: >
|
||||||
|
Orchestrator that manages the conversation loop: accepts user
|
||||||
|
messages, injects schema context into the LLM system prompt,
|
||||||
|
calls AnyLanguageModel LanguageModelSession with registered
|
||||||
|
SQL tools, executes resulting SQL, and formats responses.
|
||||||
|
|
||||||
|
- name: sql_tools
|
||||||
|
type: array
|
||||||
|
description: >
|
||||||
|
AnyLanguageModel Tool conformances: QueryTool (SELECT),
|
||||||
|
InsertTool (INSERT), UpdateTool (UPDATE), DeleteTool (DELETE),
|
||||||
|
AggregateTool (COUNT/SUM/AVG/MIN/MAX). Each uses @Generable
|
||||||
|
arguments. Developer's allowlist controls which tools are active.
|
||||||
|
|
||||||
|
- name: query_result
|
||||||
|
type: object
|
||||||
|
description: >
|
||||||
|
Structured result from SQL execution containing: raw rows
|
||||||
|
(array of dictionaries), column metadata, row count, execution
|
||||||
|
time, and the generated SQL string for transparency.
|
||||||
|
|
||||||
|
- name: chat_message
|
||||||
|
type: object
|
||||||
|
description: >
|
||||||
|
A single message in the conversation. Has a role (user/assistant/system),
|
||||||
|
text content, optional query_result attachment, optional chart_data
|
||||||
|
for Swift Charts rendering, and a timestamp.
|
||||||
|
|
||||||
|
- name: chat_view
|
||||||
|
type: object
|
||||||
|
description: >
|
||||||
|
Drop-in SwiftUI view that renders chat_messages with three
|
||||||
|
result modes: text summary, data table, and Swift Charts.
|
||||||
|
Handles input, loading states, error display, and confirmation
|
||||||
|
dialogs for mutations. Themeable.
|
||||||
|
|
||||||
|
- name: operation_allowlist
|
||||||
|
type: object
|
||||||
|
description: >
|
||||||
|
Developer-configured set of permitted SQL operations. Maps to
|
||||||
|
mutationPolicy: readOnly (SELECT only), standard (SELECT + INSERT +
|
||||||
|
UPDATE), unrestricted (all including DELETE), or custom set.
|
||||||
|
|
||||||
|
- name: llm_provider
|
||||||
|
type: object
|
||||||
|
description: >
|
||||||
|
Any AnyLanguageModel-compatible language model instance passed
|
||||||
|
by the developer. The kit never instantiates providers itself —
|
||||||
|
developer chooses and configures their preferred backend.
|
||||||
|
|
||||||
|
evaluation_principles:
|
||||||
|
- name: zero_config_reads
|
||||||
|
description: >
|
||||||
|
A developer with an existing SQLite database should be able to
|
||||||
|
chat with their data by providing only a GRDB connection and an
|
||||||
|
AnyLanguageModel instance. No schema files, no annotations, no setup.
|
||||||
|
weight: 0.25
|
||||||
|
|
||||||
|
- name: sql_correctness
|
||||||
|
description: >
|
||||||
|
Generated SQL must be valid for the target schema. Column names,
|
||||||
|
table names, and types must match the introspected schema. Queries
|
||||||
|
should not reference non-existent tables or columns.
|
||||||
|
weight: 0.25
|
||||||
|
|
||||||
|
- name: safety_by_default
|
||||||
|
description: >
|
||||||
|
The default configuration must be read-only (SELECT only). Write
|
||||||
|
operations require explicit opt-in via the allowlist. Destructive
|
||||||
|
operations (DELETE, DROP) must require confirmation even when allowed.
|
||||||
|
weight: 0.20
|
||||||
|
|
||||||
|
- name: provider_agnosticism
|
||||||
|
description: >
|
||||||
|
All features must work identically regardless of which AnyLanguageModel
|
||||||
|
provider is used. No provider-specific code paths in the core kit.
|
||||||
|
Tool definitions must use standard AnyLanguageModel Tool protocol.
|
||||||
|
weight: 0.15
|
||||||
|
|
||||||
|
- name: rendering_quality
|
||||||
|
description: >
|
||||||
|
Results must be presented clearly: text summaries are natural and
|
||||||
|
concise, data tables are scrollable and readable, charts are
|
||||||
|
auto-selected based on data shape (bar for categories, line for
|
||||||
|
time series, pie for proportions).
|
||||||
|
weight: 0.15
|
||||||
|
|
||||||
|
exit_conditions:
|
||||||
|
- name: core_query_loop
|
||||||
|
description: User can type a natural language question and get correct SQL results
|
||||||
|
criteria: >
|
||||||
|
End-to-end works: NL input → schema context → LLM tool call →
|
||||||
|
SQL generation → GRDB execution → formatted response in ChatView
|
||||||
|
|
||||||
|
- name: all_result_modes
|
||||||
|
description: All three rendering modes (text, table, chart) work
|
||||||
|
criteria: >
|
||||||
|
Text summaries render for all queries. Data tables render for
|
||||||
|
multi-row results. Swift Charts render for numeric/aggregate data.
|
||||||
|
|
||||||
|
- name: mutation_safety
|
||||||
|
description: Write operations are safe and controllable
|
||||||
|
criteria: >
|
||||||
|
Allowlist correctly blocks disallowed operations. Confirmation
|
||||||
|
dialog appears for destructive ops. Mutations execute correctly
|
||||||
|
when confirmed.
|
||||||
|
|
||||||
|
- name: multi_provider
|
||||||
|
description: Works with at least 3 different AnyLanguageModel providers
|
||||||
|
criteria: >
|
||||||
|
Tested with OpenAI, Anthropic, and Ollama. Same queries produce
|
||||||
|
equivalent results across providers.
|
||||||
|
|
||||||
|
- name: integration_time
|
||||||
|
description: New developer can integrate in under 5 minutes
|
||||||
|
criteria: >
|
||||||
|
README example compiles and runs. DataChatView renders and
|
||||||
|
responds to queries with < 10 lines of setup code.
|
||||||
|
|
||||||
|
metadata:
|
||||||
|
version: "1.0"
|
||||||
|
generated: "2026-04-03"
|
||||||
|
source: socratic_interview
|
||||||
|
ambiguity_score: 0.15
|
||||||
|
interview_decisions:
|
||||||
|
data_layer: "All SQL via GRDB (not SwiftData APIs)"
|
||||||
|
scope: "Any SQLite database"
|
||||||
|
schema_discovery: "Auto-introspect from sqlite_master"
|
||||||
|
safety_model: "Developer-configured operation allowlist"
|
||||||
|
sql_validation: "Allowlist check only (no SQL parser)"
|
||||||
|
connection_model: "Developer passes GRDB DatabasePool/DatabaseQueue"
|
||||||
|
ui_rendering: "Text + data table + Swift Charts"
|
||||||
|
target_audience: "Both app developers (ChatView) and debuggers (headless)"
|
||||||
|
timeline: "6-8 weeks, polished release"
|
||||||
|
llm_layer: "AnyLanguageModel (HuggingFace)"
|
||||||
Reference in New Issue
Block a user