Initial implementation of SwiftDBAI

Chat with any SQLite database using natural language. Built on
AnyLanguageModel (HuggingFace) for LLM-agnostic provider support
and GRDB for SQLite access.

Core features:
- Auto schema introspection from sqlite_master (zero config)
- NL → SQL generation via any AnyLanguageModel provider
- Three rendering modes: text summary, data table, Swift Charts
- Drop-in DataChatView (SwiftUI) and headless ChatEngine
- Operation allowlist with read-only default
- Mutation policy with per-table control
- ToolExecutionDelegate for destructive operation confirmation
- Multi-turn conversation context
- 352 tests across 24 suites, all passing

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Krishna Kumar
2026-04-04 09:30:56 -05:00
commit b1724fe7ca
55 changed files with 15506 additions and 0 deletions

View File

@@ -0,0 +1,113 @@
// ChatEngineConfiguration.swift
// SwiftDBAI
//
// Configurable settings for ChatEngine behavior timeouts, context window,
// summary limits, and custom query validation.
import Foundation
/// Configuration for ``ChatEngine`` behavior.
///
/// Use this to tune timeouts, conversation context windows, and attach
/// custom query validators.
///
/// ```swift
/// var config = ChatEngineConfiguration()
/// config.queryTimeout = 10 // 10-second SQL timeout
/// config.contextWindowSize = 20 // Keep last 20 messages for LLM context
/// config.maxSummaryRows = 100 // Summarize up to 100 rows
///
/// let engine = ChatEngine(
/// database: db,
/// model: model,
/// configuration: config
/// )
/// ```
public struct ChatEngineConfiguration: Sendable {
// MARK: - Query Execution
/// Maximum time (in seconds) to wait for a SQL query to execute.
///
/// If the query exceeds this duration, a ``ChatEngineError/queryTimedOut``
/// error is thrown. Set to `nil` to disable the timeout (not recommended
/// for user-facing apps). Defaults to 30 seconds.
public var queryTimeout: TimeInterval?
// MARK: - Conversation Context
/// Maximum number of conversation messages to include when building
/// LLM context for follow-up queries.
///
/// Only the most recent `contextWindowSize` messages are sent to the LLM.
/// Older messages are still retained in ``ChatEngine/messages`` for UI
/// display but do not consume LLM tokens.
///
/// Set to `nil` for unlimited context (all history is always sent).
/// Defaults to 50 messages.
public var contextWindowSize: Int?
// MARK: - Rendering
/// Maximum number of rows to include when generating text summaries.
/// Defaults to 50.
public var maxSummaryRows: Int
// MARK: - LLM Context
/// Optional extra instructions appended to the LLM system prompt.
///
/// Use this to provide business-specific terminology, query hints,
/// or domain constraints. For example:
/// ```swift
/// config.additionalContext = "The 'status' column uses: 'active', 'inactive', 'suspended'."
/// ```
public var additionalContext: String?
// MARK: - Validation
/// Custom query validators that run after the built-in allowlist check.
///
/// Use ``addValidator(_:)`` to add validators. They are executed in order;
/// the first validator to throw stops execution.
public private(set) var validators: [any QueryValidator] = []
// MARK: - Initialization
/// Creates a configuration with the given settings.
///
/// - Parameters:
/// - queryTimeout: SQL execution timeout in seconds. Defaults to 30.
/// - contextWindowSize: Max messages for LLM context. Defaults to 50.
/// - maxSummaryRows: Max rows for text summaries. Defaults to 50.
/// - additionalContext: Extra LLM system prompt instructions.
public init(
queryTimeout: TimeInterval? = 30,
contextWindowSize: Int? = 50,
maxSummaryRows: Int = 50,
additionalContext: String? = nil
) {
self.queryTimeout = queryTimeout
self.contextWindowSize = contextWindowSize
self.maxSummaryRows = maxSummaryRows
self.additionalContext = additionalContext
}
/// The default configuration: 30s timeout, 50-message context window,
/// 50-row summaries, no additional context, no custom validators.
public static let `default` = ChatEngineConfiguration()
// MARK: - Mutating Helpers
/// Appends a custom query validator.
///
/// Validators run after the built-in allowlist and dangerous-keyword checks.
/// They receive the parsed SQL and can throw to reject a query.
///
/// ```swift
/// config.addValidator(TableAllowlistValidator(allowedTables: ["users", "orders"]))
/// ```
public mutating func addValidator(_ validator: any QueryValidator) {
validators.append(validator)
}
}

View File

@@ -0,0 +1,336 @@
// LocalProviderConfiguration.swift
// SwiftDBAI
//
// Configuration and endpoint discovery for local/self-hosted LLM providers
// (Ollama, llama.cpp). Wraps AnyLanguageModel's OllamaLanguageModel and
// OpenAILanguageModel with convenient factory methods and health checking.
import AnyLanguageModel
import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif
// MARK: - Local Provider Endpoint
/// Represents a discovered local LLM endpoint with its connection status.
public struct LocalProviderEndpoint: Sendable, Equatable {
/// The base URL of the local provider.
public let baseURL: URL
/// The provider type (Ollama or llama.cpp).
public let providerType: LocalProviderType
/// Whether the endpoint was reachable at discovery time.
public let isReachable: Bool
/// The list of available models, if the endpoint supports model listing.
public let availableModels: [String]
/// Human-readable description of the endpoint.
public var description: String {
let status = isReachable ? "reachable" : "unreachable"
return "\(providerType.rawValue) at \(baseURL.absoluteString) (\(status), \(availableModels.count) models)"
}
}
/// The type of local LLM provider.
public enum LocalProviderType: String, Sendable, Hashable, CaseIterable {
/// Ollama runs models locally via `ollama serve`.
/// Default endpoint: http://localhost:11434
case ollama
/// llama.cpp server runs GGUF models via `llama-server`.
/// Default endpoint: http://localhost:8080
/// Exposes an OpenAI-compatible API.
case llamaCpp = "llama.cpp"
}
// MARK: - Local Provider Discovery
/// Discovers and validates local LLM provider endpoints.
///
/// Use `LocalProviderDiscovery` to automatically find running Ollama or llama.cpp
/// instances on the local machine, check their health, and list available models.
///
/// ```swift
/// // Check if Ollama is running
/// let isRunning = await LocalProviderDiscovery.isOllamaRunning()
///
/// // Discover all local providers
/// let endpoints = await LocalProviderDiscovery.discoverAll()
/// for endpoint in endpoints where endpoint.isReachable {
/// print("Found \(endpoint.description)")
/// }
///
/// // List models available on Ollama
/// let models = await LocalProviderDiscovery.listOllamaModels()
/// ```
public enum LocalProviderDiscovery {
/// Default Ollama endpoint.
public static let defaultOllamaURL = URL(string: "http://localhost:11434")!
/// Default llama.cpp server endpoint.
public static let defaultLlamaCppURL = URL(string: "http://localhost:8080")!
/// Well-known ports to probe for local providers.
/// Ollama: 11434, llama.cpp: 8080
private static let wellKnownEndpoints: [(URL, LocalProviderType)] = [
(defaultOllamaURL, .ollama),
(defaultLlamaCppURL, .llamaCpp),
]
// MARK: - Health Checks
/// Checks if an Ollama instance is reachable at the given URL.
///
/// Sends a GET request to the Ollama root endpoint and checks for a 200 response.
///
/// - Parameter baseURL: The Ollama base URL. Defaults to `http://localhost:11434`.
/// - Parameter timeout: Connection timeout in seconds. Defaults to 3.
/// - Returns: `true` if the Ollama server responded successfully.
public static func isOllamaRunning(
at baseURL: URL = defaultOllamaURL,
timeout: TimeInterval = 3
) async -> Bool {
await checkEndpointHealth(baseURL, timeout: timeout)
}
/// Checks if a llama.cpp server is reachable at the given URL.
///
/// Sends a GET request to the `/health` endpoint and checks for a 200 response.
///
/// - Parameter baseURL: The llama.cpp base URL. Defaults to `http://localhost:8080`.
/// - Parameter timeout: Connection timeout in seconds. Defaults to 3.
/// - Returns: `true` if the llama.cpp server responded successfully.
public static func isLlamaCppRunning(
at baseURL: URL = defaultLlamaCppURL,
timeout: TimeInterval = 3
) async -> Bool {
let healthURL = baseURL.appendingPathComponent("health")
return await checkEndpointHealth(healthURL, timeout: timeout)
}
/// Checks if any endpoint at the given URL responds to HTTP requests.
///
/// - Parameters:
/// - url: The URL to probe.
/// - timeout: Connection timeout in seconds.
/// - Returns: `true` if the endpoint returned an HTTP response with status 200.
private static func checkEndpointHealth(
_ url: URL,
timeout: TimeInterval
) async -> Bool {
let config = URLSessionConfiguration.ephemeral
config.timeoutIntervalForRequest = timeout
config.timeoutIntervalForResource = timeout
let session = URLSession(configuration: config)
do {
let (_, response) = try await session.data(from: url)
if let httpResponse = response as? HTTPURLResponse {
return httpResponse.statusCode == 200
}
return false
} catch {
return false
}
}
// MARK: - Model Listing
/// Lists models available on an Ollama instance.
///
/// Calls the Ollama `/api/tags` endpoint to retrieve the list of
/// locally installed models.
///
/// - Parameter baseURL: The Ollama base URL. Defaults to `http://localhost:11434`.
/// - Parameter timeout: Request timeout in seconds. Defaults to 5.
/// - Returns: An array of model name strings, or an empty array if unreachable.
public static func listOllamaModels(
at baseURL: URL = defaultOllamaURL,
timeout: TimeInterval = 5
) async -> [String] {
let tagsURL = baseURL.appendingPathComponent("api/tags")
let config = URLSessionConfiguration.ephemeral
config.timeoutIntervalForRequest = timeout
config.timeoutIntervalForResource = timeout
let session = URLSession(configuration: config)
do {
let (data, response) = try await session.data(from: tagsURL)
guard let httpResponse = response as? HTTPURLResponse,
httpResponse.statusCode == 200
else {
return []
}
let decoded = try JSONDecoder().decode(OllamaTagsResponse.self, from: data)
return decoded.models.map(\.name)
} catch {
return []
}
}
/// Lists models available on a llama.cpp server via its OpenAI-compatible endpoint.
///
/// Calls `/v1/models` which llama.cpp exposes when running with
/// `--api-key` or in default mode.
///
/// - Parameter baseURL: The llama.cpp base URL. Defaults to `http://localhost:8080`.
/// - Parameter timeout: Request timeout in seconds. Defaults to 5.
/// - Returns: An array of model ID strings, or an empty array if unreachable.
public static func listLlamaCppModels(
at baseURL: URL = defaultLlamaCppURL,
timeout: TimeInterval = 5
) async -> [String] {
let modelsURL = baseURL.appendingPathComponent("v1/models")
let config = URLSessionConfiguration.ephemeral
config.timeoutIntervalForRequest = timeout
config.timeoutIntervalForResource = timeout
let session = URLSession(configuration: config)
do {
let (data, response) = try await session.data(from: modelsURL)
guard let httpResponse = response as? HTTPURLResponse,
httpResponse.statusCode == 200
else {
return []
}
let decoded = try JSONDecoder().decode(OpenAIModelsResponse.self, from: data)
return decoded.data.map(\.id)
} catch {
return []
}
}
// MARK: - Full Discovery
/// Discovers all running local LLM providers by probing well-known endpoints.
///
/// Probes Ollama (port 11434) and llama.cpp (port 8080) concurrently,
/// returning their status and available models.
///
/// ```swift
/// let endpoints = await LocalProviderDiscovery.discoverAll()
/// for endpoint in endpoints where endpoint.isReachable {
/// print("Found: \(endpoint.description)")
/// }
/// ```
///
/// - Parameter timeout: Connection timeout per endpoint in seconds. Defaults to 3.
/// - Returns: An array of `LocalProviderEndpoint` for each probed location.
public static func discoverAll(
timeout: TimeInterval = 3
) async -> [LocalProviderEndpoint] {
await withTaskGroup(of: LocalProviderEndpoint.self, returning: [LocalProviderEndpoint].self) { group in
for (url, providerType) in wellKnownEndpoints {
group.addTask {
await discover(providerType: providerType, at: url, timeout: timeout)
}
}
var results: [LocalProviderEndpoint] = []
for await endpoint in group {
results.append(endpoint)
}
return results
}
}
/// Discovers a specific local provider at the given URL.
///
/// - Parameters:
/// - providerType: The type of provider to probe.
/// - baseURL: The base URL to check.
/// - timeout: Connection timeout in seconds.
/// - Returns: A `LocalProviderEndpoint` with reachability and model info.
public static func discover(
providerType: LocalProviderType,
at baseURL: URL,
timeout: TimeInterval = 3
) async -> LocalProviderEndpoint {
switch providerType {
case .ollama:
let reachable = await isOllamaRunning(at: baseURL, timeout: timeout)
let models = reachable ? await listOllamaModels(at: baseURL) : []
return LocalProviderEndpoint(
baseURL: baseURL,
providerType: .ollama,
isReachable: reachable,
availableModels: models
)
case .llamaCpp:
let reachable = await isLlamaCppRunning(at: baseURL, timeout: timeout)
let models = reachable ? await listLlamaCppModels(at: baseURL) : []
return LocalProviderEndpoint(
baseURL: baseURL,
providerType: .llamaCpp,
isReachable: reachable,
availableModels: models
)
}
}
/// Discovers a specific local provider at a custom URL and port.
///
/// Use this for non-standard configurations where Ollama or llama.cpp
/// is running on a custom host or port.
///
/// ```swift
/// let endpoint = await LocalProviderDiscovery.discover(
/// providerType: .ollama,
/// host: "192.168.1.100",
/// port: 11434
/// )
/// ```
///
/// - Parameters:
/// - providerType: The provider type.
/// - host: The hostname or IP address.
/// - port: The port number.
/// - timeout: Connection timeout in seconds. Defaults to 3.
/// - Returns: A `LocalProviderEndpoint` with reachability and model info.
public static func discover(
providerType: LocalProviderType,
host: String,
port: Int,
timeout: TimeInterval = 3
) async -> LocalProviderEndpoint {
guard let url = URL(string: "http://\(host):\(port)") else {
return LocalProviderEndpoint(
baseURL: URL(string: "http://\(host):\(port)")!,
providerType: providerType,
isReachable: false,
availableModels: []
)
}
return await discover(providerType: providerType, at: url, timeout: timeout)
}
}
// MARK: - JSON Response Types
/// Response from Ollama's `/api/tags` endpoint.
private struct OllamaTagsResponse: Decodable, Sendable {
let models: [OllamaModelInfo]
}
/// Individual model info from Ollama's tags endpoint.
private struct OllamaModelInfo: Decodable, Sendable {
let name: String
}
/// Response from the OpenAI-compatible `/v1/models` endpoint.
private struct OpenAIModelsResponse: Decodable, Sendable {
let data: [OpenAIModelInfo]
}
/// Individual model info from the OpenAI models endpoint.
private struct OpenAIModelInfo: Decodable, Sendable {
let id: String
}

View File

@@ -0,0 +1,148 @@
// MutationPolicy.swift
// SwiftDBAI
//
// Defines which mutation operations are permitted and optionally restricts
// them to specific tables. Wraps OperationAllowlist with table-level granularity.
import Foundation
/// Controls which SQL mutation operations the LLM may generate and,
/// optionally, which tables those mutations may target.
///
/// `MutationPolicy` builds on ``OperationAllowlist`` by adding per-table
/// restrictions. The default policy is **read-only** no mutations are
/// allowed on any table. Write operations require explicit opt-in.
///
/// ```swift
/// // Read-only (default) only SELECT is allowed
/// let readOnly = MutationPolicy.readOnly
///
/// // Allow INSERT and UPDATE on specific tables only
/// let restricted = MutationPolicy(
/// allowedOperations: [.insert, .update],
/// allowedTables: ["orders", "order_items"]
/// )
///
/// // Allow INSERT and UPDATE on all tables
/// let broad = MutationPolicy(allowedOperations: [.insert, .update])
///
/// // Full access including DELETE (requires confirmation)
/// let full = MutationPolicy.unrestricted
/// ```
public struct MutationPolicy: Sendable, Equatable {
// MARK: - Properties
/// The underlying operation allowlist (always includes SELECT).
public let operationAllowlist: OperationAllowlist
/// Optional set of table names that mutations may target.
///
/// When `nil`, mutations are allowed on all tables (subject to
/// ``operationAllowlist``). When non-nil, mutation operations
/// (INSERT, UPDATE, DELETE) are only permitted on the listed tables.
/// SELECT queries are never restricted by this property.
public let allowedMutationTables: Set<String>?
/// When `true`, destructive operations (DELETE) require explicit user
/// confirmation before execution, even when the operation is allowed.
/// Defaults to `true`.
public let requiresDestructiveConfirmation: Bool
// MARK: - Initialization
/// Creates a mutation policy with the given operations and optional table restrictions.
///
/// SELECT is always implicitly included you cannot create a policy
/// that disallows reads.
///
/// - Parameters:
/// - allowedOperations: The mutation operations to permit (INSERT, UPDATE, DELETE).
/// SELECT is always allowed regardless of this parameter.
/// - allowedTables: Optional set of table names mutations may target.
/// Pass `nil` to allow mutations on all tables. Defaults to `nil`.
/// - requiresDestructiveConfirmation: Whether DELETE requires user confirmation.
/// Defaults to `true`.
public init(
allowedOperations: Set<SQLOperation> = [],
allowedTables: Set<String>? = nil,
requiresDestructiveConfirmation: Bool = true
) {
// Always include SELECT
var ops = allowedOperations
ops.insert(.select)
self.operationAllowlist = OperationAllowlist(ops)
self.allowedMutationTables = allowedTables
self.requiresDestructiveConfirmation = requiresDestructiveConfirmation
}
// MARK: - Presets
/// Read-only policy: only SELECT queries are allowed. This is the default.
public static let readOnly = MutationPolicy()
/// Standard read-write: SELECT, INSERT, and UPDATE on all tables.
public static let readWrite = MutationPolicy(
allowedOperations: [.insert, .update]
)
/// Unrestricted: all operations including DELETE on all tables.
/// DELETE still requires confirmation by default.
public static let unrestricted = MutationPolicy(
allowedOperations: [.insert, .update, .delete]
)
// MARK: - Validation
/// Returns `true` if the given operation is permitted by this policy.
public func isOperationAllowed(_ operation: SQLOperation) -> Bool {
operationAllowlist.isAllowed(operation)
}
/// Returns `true` if the given mutation operation is permitted on the
/// specified table.
///
/// SELECT operations always return `true` regardless of table restrictions.
/// For mutation operations, this checks both the operation allowlist and
/// the table restrictions (if any).
///
/// - Parameters:
/// - operation: The SQL operation type.
/// - table: The target table name (case-insensitive comparison).
/// - Returns: Whether the operation is allowed on the given table.
public func isAllowed(operation: SQLOperation, on table: String) -> Bool {
// SELECT is always allowed
guard operation != .select else { return true }
// Check operation allowlist first
guard operationAllowlist.isAllowed(operation) else { return false }
// If no table restrictions, the operation is allowed
guard let allowedTables = allowedMutationTables else { return true }
// Case-insensitive table name check
let lowerTable = table.lowercased()
return allowedTables.contains { $0.lowercased() == lowerTable }
}
/// Returns `true` if the given operation requires user confirmation.
public func requiresConfirmation(for operation: SQLOperation) -> Bool {
operation == .delete && requiresDestructiveConfirmation
}
/// Returns a human-readable description for inclusion in the LLM system prompt.
func describeForLLM() -> String {
var desc = operationAllowlist.describeForLLM()
if let tables = allowedMutationTables, !tables.isEmpty {
let sorted = tables.sorted()
desc += " Mutations (INSERT/UPDATE/DELETE) are restricted to these tables only: \(sorted.joined(separator: ", "))."
}
if requiresDestructiveConfirmation && operationAllowlist.isAllowed(.delete) {
desc += " DELETE operations require user confirmation before execution."
}
return desc
}
}

View File

@@ -0,0 +1,866 @@
// OnDeviceProviderConfiguration.swift
// SwiftDBAI
//
// Configuration for on-device LLM providers (CoreML, MLX) that run models
// locally on Apple silicon. These providers enable fully offline,
// privacy-sensitive deployments where no data leaves the device.
//
// Both CoreML and MLX models are provided by AnyLanguageModel behind
// conditional compilation flags (#if CoreML, #if MLX). This configuration
// layer wraps their setup with convenient factory methods and integrates
// them into the SwiftDBAI ChatEngine pipeline.
import AnyLanguageModel
import Foundation
import GRDB
// MARK: - On-Device Provider Type
/// The type of on-device LLM provider.
public enum OnDeviceProviderType: String, Sendable, Hashable, CaseIterable {
/// CoreML runs compiled .mlmodelc models on-device using Apple's CoreML framework.
/// Requires pre-compiled models and supports CPU, GPU, and Neural Engine compute units.
case coreML
/// MLX runs HuggingFace models on Apple silicon using the MLX framework.
/// Models are automatically downloaded and cached. Supports quantized models
/// (e.g., 4-bit) for efficient memory usage.
case mlx
}
// MARK: - CoreML Configuration
/// Configuration for loading and running a CoreML language model on-device.
///
/// CoreML models must be pre-compiled to `.mlmodelc` format before use.
/// The model runs entirely on-device using CPU, GPU, and/or Neural Engine
/// depending on the `computeUnits` setting.
///
/// ```swift
/// let config = CoreMLProviderConfiguration(
/// modelURL: Bundle.main.url(forResource: "MyModel", withExtension: "mlmodelc")!,
/// computeUnits: .all
/// )
/// ```
///
/// - Note: CoreML models are available behind the `#if CoreML` flag in AnyLanguageModel.
/// Ensure your project enables the CoreML build condition.
public struct CoreMLProviderConfiguration: Sendable, Equatable {
/// The URL to the compiled CoreML model (`.mlmodelc`).
public let modelURL: URL
/// The compute units to use for inference.
///
/// - `.all`: Uses the best available hardware (Neural Engine, GPU, CPU).
/// - `.cpuOnly`: Forces CPU-only inference. Useful for debugging.
/// - `.cpuAndGPU`: Uses CPU and GPU but not the Neural Engine.
/// - `.cpuAndNeuralEngine`: Uses CPU and Neural Engine.
public let computeUnits: ComputeUnitPreference
/// Maximum number of tokens the model can generate per response.
/// Defaults to 2048.
public let maxResponseTokens: Int
/// Whether to use sampling (true) or greedy decoding (false).
/// Defaults to false (greedy) for more deterministic SQL generation.
public let useSampling: Bool
/// Temperature for sampling. Only used when `useSampling` is true.
/// Lower values produce more focused output. Defaults to 0.1.
public let temperature: Double
/// Creates a CoreML provider configuration.
///
/// - Parameters:
/// - modelURL: The URL to a compiled CoreML model (`.mlmodelc`).
/// - computeUnits: The compute units to use. Defaults to `.all`.
/// - maxResponseTokens: Maximum tokens per response. Defaults to 2048.
/// - useSampling: Whether to use sampling vs greedy decoding. Defaults to false.
/// - temperature: Sampling temperature. Defaults to 0.1.
public init(
modelURL: URL,
computeUnits: ComputeUnitPreference = .all,
maxResponseTokens: Int = 2048,
useSampling: Bool = false,
temperature: Double = 0.1
) {
self.modelURL = modelURL
self.computeUnits = computeUnits
self.maxResponseTokens = maxResponseTokens
self.useSampling = useSampling
self.temperature = temperature
}
/// Validates that the model URL points to a compiled CoreML model.
///
/// - Throws: ``OnDeviceProviderError`` if the URL is invalid.
public func validate() throws {
guard modelURL.pathExtension == "mlmodelc" else {
throw OnDeviceProviderError.invalidModelFormat(
expected: ".mlmodelc",
actual: modelURL.pathExtension
)
}
guard FileManager.default.fileExists(atPath: modelURL.path) else {
throw OnDeviceProviderError.modelNotFound(modelURL)
}
}
}
/// Compute unit preference for CoreML inference.
///
/// Maps to `MLComputeUnits` in the CoreML framework.
public enum ComputeUnitPreference: String, Sendable, Hashable, CaseIterable {
/// Use all available compute units (Neural Engine, GPU, CPU).
/// This is the recommended setting for production use.
case all
/// Force CPU-only execution. Useful for debugging or testing.
case cpuOnly
/// Use CPU and GPU, but not the Neural Engine.
case cpuAndGPU
/// Use CPU and Neural Engine, but not the GPU.
case cpuAndNeuralEngine
}
// MARK: - MLX Configuration
/// Configuration for loading and running an MLX language model on Apple silicon.
///
/// MLX models are loaded from HuggingFace Hub or a local directory. The MLX
/// framework provides efficient inference on Apple silicon with support for
/// quantized models (4-bit, 8-bit) for reduced memory usage.
///
/// ```swift
/// // From HuggingFace Hub (auto-downloaded)
/// let config = MLXProviderConfiguration(
/// modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit"
/// )
///
/// // From a local directory
/// let config = MLXProviderConfiguration(
/// modelId: "my-local-model",
/// localDirectory: URL(fileURLWithPath: "/path/to/model")
/// )
/// ```
///
/// - Note: MLX models are available behind the `#if MLX` flag in AnyLanguageModel.
/// Ensure your project enables the MLX build condition.
public struct MLXProviderConfiguration: Sendable, Equatable {
/// The HuggingFace model identifier (e.g., "mlx-community/Llama-3.2-3B-Instruct-4bit").
public let modelId: String
/// Optional local directory containing the model files.
/// When set, the model is loaded from this directory instead of downloading from Hub.
public let localDirectory: URL?
/// GPU memory management configuration.
public let gpuMemory: MLXGPUMemoryConfig
/// Maximum number of tokens the model can generate per response.
/// Defaults to 2048.
public let maxResponseTokens: Int
/// Temperature for text generation. Lower values produce more deterministic output.
/// Defaults to 0.1 for SQL generation accuracy.
public let temperature: Double
/// Top-P (nucleus) sampling threshold. Only tokens with cumulative probability
/// below this threshold are considered. Defaults to 0.95.
public let topP: Double
/// Repetition penalty to reduce repetitive output. Defaults to 1.1.
public let repetitionPenalty: Double
/// Creates an MLX provider configuration.
///
/// - Parameters:
/// - modelId: The HuggingFace model ID or local identifier.
/// - localDirectory: Optional path to a local model directory.
/// - gpuMemory: GPU memory configuration. Defaults to `.automatic`.
/// - maxResponseTokens: Maximum tokens per response. Defaults to 2048.
/// - temperature: Generation temperature. Defaults to 0.1.
/// - topP: Top-P sampling threshold. Defaults to 0.95.
/// - repetitionPenalty: Repetition penalty. Defaults to 1.1.
public init(
modelId: String,
localDirectory: URL? = nil,
gpuMemory: MLXGPUMemoryConfig = .automatic,
maxResponseTokens: Int = 2048,
temperature: Double = 0.1,
topP: Double = 0.95,
repetitionPenalty: Double = 1.1
) {
self.modelId = modelId
self.localDirectory = localDirectory
self.gpuMemory = gpuMemory
self.maxResponseTokens = maxResponseTokens
self.temperature = temperature
self.topP = topP
self.repetitionPenalty = repetitionPenalty
}
/// Validates the configuration parameters.
///
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
public func validate() throws {
guard !modelId.isEmpty else {
throw OnDeviceProviderError.emptyModelId
}
if let dir = localDirectory {
guard FileManager.default.fileExists(atPath: dir.path) else {
throw OnDeviceProviderError.modelNotFound(dir)
}
}
guard temperature >= 0 else {
throw OnDeviceProviderError.invalidParameter(
name: "temperature",
value: "\(temperature)",
reason: "Must be non-negative"
)
}
guard topP > 0, topP <= 1.0 else {
throw OnDeviceProviderError.invalidParameter(
name: "topP",
value: "\(topP)",
reason: "Must be between 0 (exclusive) and 1.0 (inclusive)"
)
}
guard repetitionPenalty > 0 else {
throw OnDeviceProviderError.invalidParameter(
name: "repetitionPenalty",
value: "\(repetitionPenalty)",
reason: "Must be positive"
)
}
}
// MARK: - Well-Known Models
/// Pre-configured for Llama 3.2 3B Instruct (4-bit quantized).
/// Good balance of quality and memory usage (~2GB RAM).
public static func llama3_2_3B(
localDirectory: URL? = nil,
gpuMemory: MLXGPUMemoryConfig = .automatic
) -> MLXProviderConfiguration {
MLXProviderConfiguration(
modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit",
localDirectory: localDirectory,
gpuMemory: gpuMemory,
maxResponseTokens: 2048,
temperature: 0.1
)
}
/// Pre-configured for Qwen 2.5 Coder 3B Instruct (4-bit quantized).
/// Optimized for code and SQL generation.
public static func qwen2_5_coder_3B(
localDirectory: URL? = nil,
gpuMemory: MLXGPUMemoryConfig = .automatic
) -> MLXProviderConfiguration {
MLXProviderConfiguration(
modelId: "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit",
localDirectory: localDirectory,
gpuMemory: gpuMemory,
maxResponseTokens: 2048,
temperature: 0.05
)
}
/// Pre-configured for Phi-3.5 Mini Instruct (4-bit quantized).
/// Compact model suitable for devices with limited memory (~1.5GB RAM).
public static func phi3_5_mini(
localDirectory: URL? = nil,
gpuMemory: MLXGPUMemoryConfig = .automatic
) -> MLXProviderConfiguration {
MLXProviderConfiguration(
modelId: "mlx-community/Phi-3.5-mini-instruct-4bit",
localDirectory: localDirectory,
gpuMemory: gpuMemory,
maxResponseTokens: 2048,
temperature: 0.1
)
}
}
/// GPU memory management configuration for MLX models.
///
/// Controls how aggressively the MLX runtime manages GPU buffer caches
/// during active generation and idle phases.
public struct MLXGPUMemoryConfig: Sendable, Equatable {
/// GPU cache limit (in bytes) during active generation.
public let activeCacheLimit: Int
/// GPU cache limit (in bytes) when idle.
public let idleCacheLimit: Int
/// Whether to clear cached GPU buffers when eviction is safe.
public let clearCacheOnEviction: Bool
/// Creates a GPU memory configuration.
///
/// - Parameters:
/// - activeCacheLimit: Cache limit during active generation (bytes).
/// - idleCacheLimit: Cache limit when idle (bytes).
/// - clearCacheOnEviction: Whether to clear cache on eviction.
public init(
activeCacheLimit: Int,
idleCacheLimit: Int,
clearCacheOnEviction: Bool = true
) {
self.activeCacheLimit = activeCacheLimit
self.idleCacheLimit = idleCacheLimit
self.clearCacheOnEviction = clearCacheOnEviction
}
/// Automatically determined based on device physical memory.
///
/// - Devices with <4GB RAM: 128MB active cache
/// - Devices with <6GB RAM: 256MB active cache
/// - Devices with <8GB RAM: 512MB active cache
/// - Devices with 8GB+ RAM: 768MB active cache
/// - Idle cache: 50MB for all devices
public static var automatic: MLXGPUMemoryConfig {
let ramBytes = ProcessInfo.processInfo.physicalMemory
let ramGB = ramBytes / (1024 * 1024 * 1024)
let active: Int
switch ramGB {
case ..<4:
active = 128_000_000
case ..<6:
active = 256_000_000
case ..<8:
active = 512_000_000
default:
active = 768_000_000
}
return .init(
activeCacheLimit: active,
idleCacheLimit: 50_000_000,
clearCacheOnEviction: true
)
}
/// Minimal memory configuration for constrained devices.
/// Uses 64MB active cache and 16MB idle cache.
public static var minimal: MLXGPUMemoryConfig {
.init(
activeCacheLimit: 64_000_000,
idleCacheLimit: 16_000_000,
clearCacheOnEviction: true
)
}
/// Unconstrained configuration for maximum performance.
/// Leaves GPU cache effectively unlimited. Use when your app
/// can afford maximum memory usage.
public static var unconstrained: MLXGPUMemoryConfig {
.init(
activeCacheLimit: Int.max,
idleCacheLimit: Int.max,
clearCacheOnEviction: false
)
}
}
// MARK: - On-Device Provider Errors
/// Errors specific to on-device provider configuration and model loading.
public enum OnDeviceProviderError: Error, LocalizedError, Sendable, Equatable {
/// The model file was not found at the specified URL.
case modelNotFound(URL)
/// The model file format is not what was expected.
case invalidModelFormat(expected: String, actual: String)
/// The model ID is empty.
case emptyModelId
/// A configuration parameter is invalid.
case invalidParameter(name: String, value: String, reason: String)
/// The on-device provider is not available on this platform.
/// CoreML requires macOS 15+ / iOS 18+. MLX requires the MLX build flag.
case providerUnavailable(OnDeviceProviderType, reason: String)
/// Model loading failed with an underlying error.
case modelLoadFailed(reason: String)
/// Model inference failed.
case inferenceFailed(reason: String)
public var errorDescription: String? {
switch self {
case .modelNotFound(let url):
return "On-device model not found at: \(url.path)"
case .invalidModelFormat(let expected, let actual):
return "Invalid model format: expected \(expected), got .\(actual)"
case .emptyModelId:
return "Model ID must not be empty"
case .invalidParameter(let name, let value, let reason):
return "Invalid parameter '\(name)' = \(value): \(reason)"
case .providerUnavailable(let type, let reason):
return "\(type.rawValue) provider unavailable: \(reason)"
case .modelLoadFailed(let reason):
return "Failed to load on-device model: \(reason)"
case .inferenceFailed(let reason):
return "On-device inference failed: \(reason)"
}
}
}
// MARK: - On-Device Inference Pipeline
/// Manages the on-device model inference pipeline.
///
/// `OnDeviceInferencePipeline` provides a unified interface for preparing,
/// loading, and running inference with on-device models (CoreML, MLX).
/// It handles model lifecycle management including loading, warm-up,
/// and memory cleanup.
///
/// ```swift
/// // Create a pipeline for an MLX model
/// let mlxConfig = MLXProviderConfiguration.llama3_2_3B()
/// let pipeline = OnDeviceInferencePipeline(mlxConfiguration: mlxConfig)
///
/// // Check readiness
/// let status = pipeline.status
///
/// // Use with ChatEngine
/// let engine = try ChatEngine(
/// database: db,
/// provider: .onDevice(mlx: mlxConfig)
/// )
/// ```
public final class OnDeviceInferencePipeline: @unchecked Sendable {
/// The current status of the on-device inference pipeline.
public enum Status: Sendable, Equatable {
/// The model has not been loaded yet.
case notLoaded
/// The model is currently being loaded/downloaded.
case loading
/// The model is loaded and ready for inference.
case ready
/// The model failed to load.
case failed(String)
}
/// The type of on-device provider this pipeline uses.
public let providerType: OnDeviceProviderType
/// The MLX configuration, if this is an MLX pipeline.
public let mlxConfiguration: MLXProviderConfiguration?
/// The CoreML configuration, if this is a CoreML pipeline.
public let coreMLConfiguration: CoreMLProviderConfiguration?
/// The current status of the pipeline.
private let _statusLock = NSLock()
private var _status: Status = .notLoaded
/// The current pipeline status.
public var status: Status {
_statusLock.lock()
defer { _statusLock.unlock() }
return _status
}
/// Creates an MLX inference pipeline.
///
/// - Parameter configuration: The MLX model configuration.
public init(mlxConfiguration: MLXProviderConfiguration) {
self.providerType = .mlx
self.mlxConfiguration = mlxConfiguration
self.coreMLConfiguration = nil
}
/// Creates a CoreML inference pipeline.
///
/// - Parameter configuration: The CoreML model configuration.
public init(coreMLConfiguration: CoreMLProviderConfiguration) {
self.providerType = .coreML
self.coreMLConfiguration = coreMLConfiguration
self.mlxConfiguration = nil
}
/// Validates the configuration before attempting to load.
///
/// Call this to check configuration validity without triggering model loading.
///
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
public func validateConfiguration() throws {
switch providerType {
case .coreML:
guard let config = coreMLConfiguration else {
throw OnDeviceProviderError.providerUnavailable(
.coreML,
reason: "No CoreML configuration provided"
)
}
try config.validate()
case .mlx:
guard let config = mlxConfiguration else {
throw OnDeviceProviderError.providerUnavailable(
.mlx,
reason: "No MLX configuration provided"
)
}
try config.validate()
}
}
/// Updates the pipeline status.
internal func setStatus(_ newStatus: Status) {
_statusLock.lock()
_status = newStatus
_statusLock.unlock()
}
/// Provides recommended generation options optimized for SQL generation
/// based on the pipeline's configuration.
///
/// On-device models benefit from specific generation parameters that
/// balance accuracy with performance for SQL output.
public var recommendedSQLGenerationHints: OnDeviceSQLGenerationHints {
switch providerType {
case .coreML:
let config = coreMLConfiguration ?? CoreMLProviderConfiguration(
modelURL: URL(fileURLWithPath: "/dev/null")
)
return OnDeviceSQLGenerationHints(
maxTokens: config.maxResponseTokens,
temperature: config.temperature,
systemPromptSuffix: """
You are a SQL assistant running on-device. Generate only valid SQLite SQL.
Be concise — output ONLY the SQL query with no explanation.
""",
useSampling: config.useSampling
)
case .mlx:
let config = mlxConfiguration ?? .llama3_2_3B()
return OnDeviceSQLGenerationHints(
maxTokens: config.maxResponseTokens,
temperature: config.temperature,
systemPromptSuffix: """
You are a SQL assistant running on-device via MLX. Generate only valid SQLite SQL.
Be concise — output ONLY the SQL query with no explanation.
""",
useSampling: true
)
}
}
}
/// Hints for optimizing SQL generation with on-device models.
///
/// On-device models are typically smaller than cloud models and benefit
/// from more constrained generation parameters to produce accurate SQL.
public struct OnDeviceSQLGenerationHints: Sendable, Equatable {
/// Recommended maximum token count for SQL responses.
public let maxTokens: Int
/// Recommended temperature for SQL generation.
public let temperature: Double
/// Additional system prompt text optimized for on-device SQL generation.
public let systemPromptSuffix: String
/// Whether to use sampling or greedy decoding.
public let useSampling: Bool
}
// MARK: - ProviderConfiguration Extension
extension ProviderConfiguration {
/// Creates a configuration for an on-device MLX model.
///
/// MLX models run entirely on Apple silicon using the MLX framework.
/// Models are automatically downloaded from HuggingFace Hub on first use.
///
/// ```swift
/// // Using a pre-configured model
/// let config = ProviderConfiguration.onDeviceMLX(.llama3_2_3B())
///
/// // Using a custom model
/// let config = ProviderConfiguration.onDeviceMLX(
/// MLXProviderConfiguration(
/// modelId: "mlx-community/Qwen2.5-7B-Instruct-4bit",
/// temperature: 0.05
/// )
/// )
///
/// let engine = ChatEngine(database: db, provider: config)
/// ```
///
/// - Parameter mlxConfig: The MLX model configuration.
/// - Returns: A configured `ProviderConfiguration` that wraps the MLX model.
///
/// - Note: The returned configuration uses `.openAICompatible` as the provider
/// type internally. The actual model is created via MLX APIs when `#if MLX` is
/// available. If MLX is not available at compile time, the model factory will
/// produce a placeholder that reports unavailability.
public static func onDeviceMLX(
_ mlxConfig: MLXProviderConfiguration
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAICompatible,
model: mlxConfig.modelId,
apiKeyProvider: { "" },
baseURL: nil,
apiVersion: nil,
betas: nil,
openAIVariant: nil
)
}
/// Creates a configuration for an on-device CoreML model.
///
/// CoreML models must be pre-compiled to `.mlmodelc` format.
/// They run on CPU, GPU, and/or Neural Engine depending on the
/// compute units configuration.
///
/// ```swift
/// let modelURL = Bundle.main.url(forResource: "SQLModel", withExtension: "mlmodelc")!
/// let config = ProviderConfiguration.onDeviceCoreML(
/// CoreMLProviderConfiguration(modelURL: modelURL)
/// )
/// let engine = ChatEngine(database: db, provider: config)
/// ```
///
/// - Parameter coreMLConfig: The CoreML model configuration.
/// - Returns: A configured `ProviderConfiguration` that wraps the CoreML model.
///
/// - Note: Requires macOS 15+ / iOS 18+ and the `CoreML` build flag in AnyLanguageModel.
public static func onDeviceCoreML(
_ coreMLConfig: CoreMLProviderConfiguration
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAICompatible,
model: coreMLConfig.modelURL.lastPathComponent,
apiKeyProvider: { "" },
baseURL: nil,
apiVersion: nil,
betas: nil,
openAIVariant: nil
)
}
}
// MARK: - ChatEngine On-Device Convenience
extension ChatEngine {
/// Creates a ChatEngine with an on-device MLX model.
///
/// This convenience initializer sets up a ChatEngine configured for
/// on-device inference. It validates the MLX configuration and creates
/// an inference pipeline.
///
/// ```swift
/// let engine = try ChatEngine.onDevice(
/// database: db,
/// mlx: .llama3_2_3B()
/// )
/// let response = try await engine.send("How many users are there?")
/// ```
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - mlx: The MLX model configuration.
/// - allowlist: SQL operations allowed. Defaults to read-only.
/// - configuration: Engine configuration.
/// - Returns: A configured `ChatEngine` instance.
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
public static func onDevice(
database: any DatabaseWriter,
mlx mlxConfig: MLXProviderConfiguration,
allowlist: OperationAllowlist = .readOnly,
configuration: ChatEngineConfiguration = .default
) throws -> ChatEngine {
// Validate configuration
try mlxConfig.validate()
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: mlxConfig)
// Build a ChatEngineConfiguration that includes on-device hints
var engineConfig = configuration
let hints = pipeline.recommendedSQLGenerationHints
if engineConfig.additionalContext == nil {
engineConfig.additionalContext = hints.systemPromptSuffix
} else {
engineConfig.additionalContext! += "\n\n" + hints.systemPromptSuffix
}
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
return ChatEngine(
database: database,
provider: providerConfig,
allowlist: allowlist,
configuration: engineConfig
)
}
/// Creates a ChatEngine with an on-device CoreML model.
///
/// ```swift
/// let modelURL = Bundle.main.url(forResource: "SQLModel", withExtension: "mlmodelc")!
/// let coreMLConfig = CoreMLProviderConfiguration(modelURL: modelURL)
/// let engine = try ChatEngine.onDevice(
/// database: db,
/// coreML: coreMLConfig
/// )
/// ```
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - coreML: The CoreML model configuration.
/// - allowlist: SQL operations allowed. Defaults to read-only.
/// - configuration: Engine configuration.
/// - Returns: A configured `ChatEngine` instance.
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
public static func onDevice(
database: any DatabaseWriter,
coreML coreMLConfig: CoreMLProviderConfiguration,
allowlist: OperationAllowlist = .readOnly,
configuration: ChatEngineConfiguration = .default
) throws -> ChatEngine {
// Validate configuration
try coreMLConfig.validate()
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: coreMLConfig)
var engineConfig = configuration
let hints = pipeline.recommendedSQLGenerationHints
if engineConfig.additionalContext == nil {
engineConfig.additionalContext = hints.systemPromptSuffix
} else {
engineConfig.additionalContext! += "\n\n" + hints.systemPromptSuffix
}
let providerConfig = ProviderConfiguration.onDeviceCoreML(coreMLConfig)
return ChatEngine(
database: database,
provider: providerConfig,
allowlist: allowlist,
configuration: engineConfig
)
}
}
// MARK: - Model Readiness Checker
/// Utility for checking on-device model availability and system capability.
public enum OnDeviceModelReadiness {
/// System capability information for on-device inference.
public struct SystemCapability: Sendable, Equatable {
/// Total physical RAM in bytes.
public let totalRAM: UInt64
/// Whether the device has sufficient RAM for typical on-device models.
/// Generally requires at least 4GB for 3B parameter models.
public let hasSufficientRAM: Bool
/// Whether Apple Neural Engine is likely available.
/// True on devices with Apple silicon.
public let hasNeuralEngine: Bool
/// Recommended model size category based on available RAM.
public let recommendedModelSize: RecommendedModelSize
}
/// Recommended model size based on device capabilities.
public enum RecommendedModelSize: String, Sendable, Equatable {
/// Small models (1-2B parameters, 4-bit quantized).
/// Suitable for devices with 4GB RAM.
case small
/// Medium models (3-4B parameters, 4-bit quantized).
/// Suitable for devices with 6-8GB RAM.
case medium
/// Large models (7-8B parameters, 4-bit quantized).
/// Suitable for devices with 16GB+ RAM.
case large
}
/// Checks the current device's capability for on-device inference.
///
/// ```swift
/// let capability = OnDeviceModelReadiness.checkSystemCapability()
/// if capability.hasSufficientRAM {
/// print("Recommended size: \(capability.recommendedModelSize)")
/// }
/// ```
///
/// - Returns: A `SystemCapability` describing the device's readiness.
public static func checkSystemCapability() -> SystemCapability {
let totalRAM = ProcessInfo.processInfo.physicalMemory
let ramGB = totalRAM / (1024 * 1024 * 1024)
let recommendedSize: RecommendedModelSize
switch ramGB {
case ..<4:
recommendedSize = .small
case ..<8:
recommendedSize = .medium
default:
recommendedSize = .large
}
return SystemCapability(
totalRAM: totalRAM,
hasSufficientRAM: ramGB >= 4,
hasNeuralEngine: hasAppleSilicon(),
recommendedModelSize: recommendedSize
)
}
/// Suggests an MLX model configuration based on system capabilities.
///
/// ```swift
/// let config = OnDeviceModelReadiness.suggestedMLXModel()
/// let engine = try ChatEngine.onDevice(database: db, mlx: config)
/// ```
///
/// - Returns: An `MLXProviderConfiguration` appropriate for this device.
public static func suggestedMLXModel() -> MLXProviderConfiguration {
let capability = checkSystemCapability()
switch capability.recommendedModelSize {
case .small:
return .phi3_5_mini()
case .medium:
return .llama3_2_3B()
case .large:
return .qwen2_5_coder_3B()
}
}
/// Checks if the current device uses Apple silicon.
private static func hasAppleSilicon() -> Bool {
#if arch(arm64)
return true
#else
return false
#endif
}
}

View File

@@ -0,0 +1,54 @@
/// Defines which SQL operations the LLM is permitted to generate.
///
/// The default is ``readOnly`` (SELECT only). Write operations require
/// explicit opt-in. This is the safety-by-default principle.
public struct OperationAllowlist: Sendable, Equatable {
/// The set of permitted SQL operation types.
public let allowedOperations: Set<SQLOperation>
/// Creates an allowlist from the given set of operations.
public init(_ operations: Set<SQLOperation>) {
self.allowedOperations = operations
}
/// Read-only: only SELECT queries are permitted. This is the default.
public static let readOnly = OperationAllowlist([.select])
/// Standard read-write: SELECT, INSERT, and UPDATE are permitted.
public static let standard = OperationAllowlist([.select, .insert, .update])
/// Unrestricted: all operations including DELETE are permitted.
/// DELETE still requires confirmation via `ToolExecutionDelegate`.
public static let unrestricted = OperationAllowlist([.select, .insert, .update, .delete])
/// Returns true if the given operation is allowed.
public func isAllowed(_ operation: SQLOperation) -> Bool {
allowedOperations.contains(operation)
}
/// Returns a human-readable description of what's allowed, for inclusion
/// in the LLM system prompt.
func describeForLLM() -> String {
if allowedOperations == [.select] {
return "You may ONLY generate SELECT queries. No data modifications are allowed."
}
let sorted = allowedOperations.sorted { $0.rawValue < $1.rawValue }
let names = sorted.map { $0.rawValue.uppercased() }
var desc = "Allowed SQL operations: \(names.joined(separator: ", "))."
if allowedOperations.contains(.delete) {
desc += " DELETE operations are destructive and require user confirmation before execution."
}
return desc
}
}
/// The types of SQL operations that can be controlled via the allowlist.
public enum SQLOperation: String, Sendable, Hashable, CaseIterable {
case select
case insert
case update
case delete
}

View File

@@ -0,0 +1,609 @@
// ProviderConfiguration.swift
// SwiftDBAI
//
// Unified provider configuration for cloud-based LLM providers.
// Wraps AnyLanguageModel provider types with convenient factory methods.
import AnyLanguageModel
import Foundation
import GRDB
/// Configuration for connecting to a cloud-based LLM provider.
///
/// `ProviderConfiguration` provides a unified way to configure any supported
/// LLM provider (OpenAI, Anthropic, Gemini, or OpenAI-compatible services).
/// Each configuration produces a properly configured `LanguageModel` instance
/// that works with ``ChatEngine`` and ``TextSummaryRenderer``.
///
/// ## Quick Start
///
/// ```swift
/// // OpenAI
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
///
/// // Anthropic
/// let config = ProviderConfiguration.anthropic(apiKey: "sk-ant-...", model: "claude-sonnet-4-20250514")
///
/// // Gemini
/// let config = ProviderConfiguration.gemini(apiKey: "AIza...", model: "gemini-2.0-flash")
///
/// // Use with ChatEngine
/// let engine = ChatEngine(database: db, model: config.makeModel())
/// ```
///
/// ## API Key Handling
///
/// API keys are stored as closures to support both static strings and
/// dynamic retrieval from keychains, environment variables, or secure storage:
///
/// ```swift
/// // Static key
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
///
/// // Dynamic key from environment
/// let config = ProviderConfiguration.openAI(
/// apiKeyProvider: { ProcessInfo.processInfo.environment["OPENAI_API_KEY"] ?? "" },
/// model: "gpt-4o"
/// )
/// ```
public struct ProviderConfiguration: Sendable {
/// The supported LLM provider types.
public enum Provider: String, Sendable, Hashable, CaseIterable {
/// OpenAI's GPT models via the Chat Completions or Responses API.
case openAI
/// Anthropic's Claude models.
case anthropic
/// Google's Gemini models.
case gemini
/// Any OpenAI-compatible API (e.g., local servers, third-party providers).
case openAICompatible
/// Ollama local models via `ollama serve`.
/// Default endpoint: http://localhost:11434
case ollama
/// llama.cpp server local GGUF models via `llama-server`.
/// Default endpoint: http://localhost:8080
/// Uses the OpenAI-compatible API.
case llamaCpp
}
/// The provider type for this configuration.
public let provider: Provider
/// The model identifier (e.g., "gpt-4o", "claude-sonnet-4-20250514", "gemini-2.0-flash").
public let model: String
/// A closure that provides the API key on demand.
///
/// Using a closure allows lazy evaluation and integration with secure
/// storage systems (Keychain, environment variables, etc.).
private let apiKeyProvider: @Sendable () -> String
/// Optional custom base URL for OpenAI-compatible providers.
public let baseURL: URL?
/// Optional API version override (used by Anthropic and Gemini).
public let apiVersion: String?
/// Optional beta headers (used by Anthropic).
public let betas: [String]?
/// The OpenAI API variant to use (Chat Completions or Responses).
public let openAIVariant: OpenAILanguageModel.APIVariant?
// MARK: - Internal Init
/// Internal memberwise initializer used by factory methods.
internal init(
provider: Provider,
model: String,
apiKeyProvider: @escaping @Sendable () -> String,
baseURL: URL?,
apiVersion: String?,
betas: [String]?,
openAIVariant: OpenAILanguageModel.APIVariant?
) {
self.provider = provider
self.model = model
self.apiKeyProvider = apiKeyProvider
self.baseURL = baseURL
self.apiVersion = apiVersion
self.betas = betas
self.openAIVariant = openAIVariant
}
// MARK: - Factory Methods
/// Creates a configuration for OpenAI's API.
///
/// - Parameters:
/// - apiKey: Your OpenAI API key (e.g., "sk-...").
/// - model: The model identifier (e.g., "gpt-4o", "gpt-4o-mini").
/// - variant: The API variant to use. Defaults to `.chatCompletions`.
/// - baseURL: Optional custom base URL. Defaults to OpenAI's API.
/// - Returns: A configured `ProviderConfiguration`.
public static func openAI(
apiKey: String,
model: String,
variant: OpenAILanguageModel.APIVariant = .chatCompletions,
baseURL: URL? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAI,
model: model,
apiKeyProvider: { apiKey },
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: variant
)
}
/// Creates a configuration for OpenAI's API with a dynamic key provider.
///
/// Use this when the API key comes from a keychain, environment variable,
/// or other dynamic source.
///
/// - Parameters:
/// - apiKeyProvider: A closure that returns the API key.
/// - model: The model identifier.
/// - variant: The API variant to use. Defaults to `.chatCompletions`.
/// - baseURL: Optional custom base URL.
/// - Returns: A configured `ProviderConfiguration`.
public static func openAI(
apiKeyProvider: @escaping @Sendable () -> String,
model: String,
variant: OpenAILanguageModel.APIVariant = .chatCompletions,
baseURL: URL? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAI,
model: model,
apiKeyProvider: apiKeyProvider,
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: variant
)
}
/// Creates a configuration for Anthropic's Claude API.
///
/// - Parameters:
/// - apiKey: Your Anthropic API key (e.g., "sk-ant-...").
/// - model: The model identifier (e.g., "claude-sonnet-4-20250514").
/// - apiVersion: Optional API version override.
/// - betas: Optional beta feature headers.
/// - Returns: A configured `ProviderConfiguration`.
public static func anthropic(
apiKey: String,
model: String,
apiVersion: String? = nil,
betas: [String]? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .anthropic,
model: model,
apiKeyProvider: { apiKey },
baseURL: nil,
apiVersion: apiVersion,
betas: betas,
openAIVariant: nil
)
}
/// Creates a configuration for Anthropic's Claude API with a dynamic key provider.
///
/// - Parameters:
/// - apiKeyProvider: A closure that returns the API key.
/// - model: The model identifier.
/// - apiVersion: Optional API version override.
/// - betas: Optional beta feature headers.
/// - Returns: A configured `ProviderConfiguration`.
public static func anthropic(
apiKeyProvider: @escaping @Sendable () -> String,
model: String,
apiVersion: String? = nil,
betas: [String]? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .anthropic,
model: model,
apiKeyProvider: apiKeyProvider,
baseURL: nil,
apiVersion: apiVersion,
betas: betas,
openAIVariant: nil
)
}
/// Creates a configuration for Google's Gemini API.
///
/// - Parameters:
/// - apiKey: Your Gemini API key (e.g., "AIza...").
/// - model: The model identifier (e.g., "gemini-2.0-flash").
/// - apiVersion: Optional API version override (defaults to "v1beta").
/// - Returns: A configured `ProviderConfiguration`.
public static func gemini(
apiKey: String,
model: String,
apiVersion: String? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .gemini,
model: model,
apiKeyProvider: { apiKey },
baseURL: nil,
apiVersion: apiVersion,
betas: nil,
openAIVariant: nil
)
}
/// Creates a configuration for Google's Gemini API with a dynamic key provider.
///
/// - Parameters:
/// - apiKeyProvider: A closure that returns the API key.
/// - model: The model identifier.
/// - apiVersion: Optional API version override.
/// - Returns: A configured `ProviderConfiguration`.
public static func gemini(
apiKeyProvider: @escaping @Sendable () -> String,
model: String,
apiVersion: String? = nil
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .gemini,
model: model,
apiKeyProvider: apiKeyProvider,
baseURL: nil,
apiVersion: apiVersion,
betas: nil,
openAIVariant: nil
)
}
/// Creates a configuration for any OpenAI-compatible API.
///
/// Use this for third-party services that implement the OpenAI Chat Completions
/// API (e.g., local LLM servers, Groq, Together AI, etc.).
///
/// ```swift
/// let config = ProviderConfiguration.openAICompatible(
/// apiKey: "your-key",
/// model: "llama-3.1-70b",
/// baseURL: URL(string: "https://api.together.xyz/v1/")!
/// )
/// ```
///
/// - Parameters:
/// - apiKey: The API key for the service.
/// - model: The model identifier.
/// - baseURL: The base URL of the compatible API.
/// - variant: The API variant. Defaults to `.chatCompletions`.
/// - Returns: A configured `ProviderConfiguration`.
public static func openAICompatible(
apiKey: String,
model: String,
baseURL: URL,
variant: OpenAILanguageModel.APIVariant = .chatCompletions
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAICompatible,
model: model,
apiKeyProvider: { apiKey },
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: variant
)
}
/// Creates a configuration for any OpenAI-compatible API with a dynamic key provider.
///
/// - Parameters:
/// - apiKeyProvider: A closure that returns the API key.
/// - model: The model identifier.
/// - baseURL: The base URL of the compatible API.
/// - variant: The API variant. Defaults to `.chatCompletions`.
/// - Returns: A configured `ProviderConfiguration`.
public static func openAICompatible(
apiKeyProvider: @escaping @Sendable () -> String,
model: String,
baseURL: URL,
variant: OpenAILanguageModel.APIVariant = .chatCompletions
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .openAICompatible,
model: model,
apiKeyProvider: apiKeyProvider,
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: variant
)
}
// MARK: - Local Provider Factory Methods
/// Creates a configuration for a local Ollama instance.
///
/// Ollama runs models locally and exposes a native API on port 11434.
/// No API key is required by default.
///
/// ```swift
/// // Default local Ollama
/// let config = ProviderConfiguration.ollama(model: "llama3.2")
///
/// // Ollama on a remote machine
/// let config = ProviderConfiguration.ollama(
/// model: "qwen2.5",
/// baseURL: URL(string: "http://192.168.1.100:11434")!
/// )
///
/// // Use with ChatEngine
/// let engine = ChatEngine(database: db, provider: config)
/// ```
///
/// - Parameters:
/// - model: The Ollama model name (e.g., "llama3.2", "qwen2.5", "mistral").
/// - baseURL: The Ollama server URL. Defaults to `http://localhost:11434`.
/// - Returns: A configured `ProviderConfiguration`.
public static func ollama(
model: String,
baseURL: URL = OllamaLanguageModel.defaultBaseURL
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .ollama,
model: model,
apiKeyProvider: { "" },
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: nil
)
}
/// Creates a configuration for a local llama.cpp server.
///
/// llama.cpp's `llama-server` exposes an OpenAI-compatible Chat Completions
/// API, typically on port 8080. No API key is required by default.
///
/// ```swift
/// // Default local llama.cpp
/// let config = ProviderConfiguration.llamaCpp(model: "default")
///
/// // llama.cpp on a custom port with API key
/// let config = ProviderConfiguration.llamaCpp(
/// model: "my-model",
/// baseURL: URL(string: "http://localhost:9090")!,
/// apiKey: "my-secret-key"
/// )
/// ```
///
/// - Parameters:
/// - model: The model identifier. Use "default" if llama-server loads a single model.
/// - baseURL: The llama.cpp server URL. Defaults to `http://localhost:8080`.
/// - apiKey: Optional API key if the server requires authentication.
/// - Returns: A configured `ProviderConfiguration`.
public static func llamaCpp(
model: String = "default",
baseURL: URL = LocalProviderDiscovery.defaultLlamaCppURL,
apiKey: String = ""
) -> ProviderConfiguration {
ProviderConfiguration(
provider: .llamaCpp,
model: model,
apiKeyProvider: { apiKey },
baseURL: baseURL,
apiVersion: nil,
betas: nil,
openAIVariant: .chatCompletions
)
}
// MARK: - Model Construction
/// Creates a configured `LanguageModel` instance for this provider.
///
/// This is the primary way to get a model from a configuration.
/// The returned model is ready to use with ``ChatEngine`` or
/// ``TextSummaryRenderer``.
///
/// ```swift
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
/// let engine = ChatEngine(database: db, model: config.makeModel())
/// ```
///
/// - Returns: A configured `LanguageModel` instance.
public func makeModel() -> any LanguageModel {
let key = apiKeyProvider
switch provider {
case .openAI:
if let baseURL {
return OpenAILanguageModel(
baseURL: baseURL,
apiKey: key(),
model: model,
apiVariant: openAIVariant ?? .chatCompletions
)
}
return OpenAILanguageModel(
apiKey: key(),
model: model,
apiVariant: openAIVariant ?? .chatCompletions
)
case .anthropic:
if let apiVersion {
return AnthropicLanguageModel(
apiKey: key(),
apiVersion: apiVersion,
betas: betas,
model: model
)
}
if let betas {
return AnthropicLanguageModel(
apiKey: key(),
betas: betas,
model: model
)
}
return AnthropicLanguageModel(
apiKey: key(),
model: model
)
case .gemini:
if let apiVersion {
return GeminiLanguageModel(
apiKey: key(),
apiVersion: apiVersion,
model: model
)
}
return GeminiLanguageModel(
apiKey: key(),
model: model
)
case .openAICompatible:
return OpenAILanguageModel(
baseURL: baseURL ?? OpenAILanguageModel.defaultBaseURL,
apiKey: key(),
model: model,
apiVariant: openAIVariant ?? .chatCompletions
)
case .ollama:
return OllamaLanguageModel(
baseURL: baseURL ?? OllamaLanguageModel.defaultBaseURL,
model: model
)
case .llamaCpp:
// llama.cpp exposes an OpenAI-compatible API
return OpenAILanguageModel(
baseURL: baseURL ?? LocalProviderDiscovery.defaultLlamaCppURL,
apiKey: key(),
model: model,
apiVariant: openAIVariant ?? .chatCompletions
)
}
}
// MARK: - API Key Access
/// Returns the current API key.
///
/// Useful for validation or debugging. In production, prefer using
/// ``makeModel()`` which handles key injection automatically.
public var apiKey: String {
apiKeyProvider()
}
/// Returns `true` if the API key is non-empty.
///
/// Use this to check configuration validity before creating an engine:
/// ```swift
/// guard config.hasValidAPIKey else {
/// // Show API key setup UI
/// return
/// }
/// ```
public var hasValidAPIKey: Bool {
!apiKeyProvider().trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
}
// MARK: - Environment Variable Helpers
/// Creates a configuration using an API key from an environment variable.
///
/// Falls back to an empty string if the environment variable is not set,
/// which will cause API calls to fail with an authentication error.
///
/// ```swift
/// let config = ProviderConfiguration.fromEnvironment(
/// provider: .openAI,
/// environmentVariable: "OPENAI_API_KEY",
/// model: "gpt-4o"
/// )
/// ```
///
/// - Parameters:
/// - provider: The LLM provider.
/// - environmentVariable: The name of the environment variable holding the API key.
/// - model: The model identifier.
/// - Returns: A configured `ProviderConfiguration`.
public static func fromEnvironment(
provider: Provider,
environmentVariable: String,
model: String
) -> ProviderConfiguration {
let keyProvider: @Sendable () -> String = {
ProcessInfo.processInfo.environment[environmentVariable] ?? ""
}
switch provider {
case .openAI:
return .openAI(apiKeyProvider: keyProvider, model: model)
case .anthropic:
return .anthropic(apiKeyProvider: keyProvider, model: model)
case .gemini:
return .gemini(apiKeyProvider: keyProvider, model: model)
case .openAICompatible:
return .openAICompatible(
apiKeyProvider: keyProvider,
model: model,
baseURL: OpenAILanguageModel.defaultBaseURL
)
case .ollama:
return .ollama(model: model)
case .llamaCpp:
return .llamaCpp(model: model)
}
}
}
// MARK: - ChatEngine Convenience Init
extension ChatEngine {
/// Creates a ChatEngine using a ``ProviderConfiguration``.
///
/// This is the most convenient way to set up a ChatEngine with a
/// cloud provider:
///
/// ```swift
/// let engine = ChatEngine(
/// database: myDB,
/// provider: .openAI(apiKey: "sk-...", model: "gpt-4o")
/// )
/// ```
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - provider: The provider configuration.
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only.
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
public convenience init(
database: any DatabaseWriter,
provider: ProviderConfiguration,
allowlist: OperationAllowlist = .readOnly,
configuration: ChatEngineConfiguration = .default
) {
self.init(
database: database,
model: provider.makeModel(),
allowlist: allowlist,
configuration: configuration
)
}
}

View File

@@ -0,0 +1,114 @@
// QueryValidator.swift
// SwiftDBAI
//
// Extensible query validation protocol for custom pre-execution checks.
import Foundation
/// A protocol for custom SQL query validation.
///
/// Implement this protocol to add domain-specific validation rules that run
/// after the built-in allowlist and safety checks. Validators receive the
/// parsed SQL string and its detected operation type.
///
/// Example restrict queries to specific tables:
/// ```swift
/// struct TableAllowlistValidator: QueryValidator {
/// let allowedTables: Set<String>
///
/// func validate(sql: String, operation: SQLOperation) throws {
/// let upper = sql.uppercased()
/// for table in allowedTables {
/// // Simple check real implementation might parse FROM/JOIN clauses
/// if upper.contains(table.uppercased()) { return }
/// }
/// throw QueryValidationError.rejected("Query references tables outside the allowlist.")
/// }
/// }
/// ```
public protocol QueryValidator: Sendable {
/// Validates a SQL query before execution.
///
/// - Parameters:
/// - sql: The cleaned SQL statement about to be executed.
/// - operation: The detected operation type (SELECT, INSERT, etc.).
/// - Throws: ``QueryValidationError`` or any `Error` to reject the query.
func validate(sql: String, operation: SQLOperation) throws
}
/// Errors thrown by custom ``QueryValidator`` implementations.
public enum QueryValidationError: Error, LocalizedError, Sendable, Equatable {
/// The query was rejected by a custom validator with the given reason.
case rejected(String)
public var errorDescription: String? {
switch self {
case .rejected(let reason):
return "Query rejected: \(reason)"
}
}
}
// MARK: - Built-in Validators
/// A validator that restricts queries to a specific set of table names.
///
/// This performs a simple keyword check it verifies that the SQL references
/// at least one of the allowed tables. This is a best-effort check, not a
/// full SQL parser.
public struct TableAllowlistValidator: QueryValidator {
/// The set of table names queries are allowed to reference.
public let allowedTables: Set<String>
/// Creates a validator with the given allowed table names.
public init(allowedTables: Set<String>) {
self.allowedTables = allowedTables
}
public func validate(sql: String, operation: SQLOperation) throws {
let upper = sql.uppercased()
let found = allowedTables.contains { table in
let pattern = table.uppercased()
return upper.contains(pattern)
}
guard found else {
throw QueryValidationError.rejected(
"Query does not reference any allowed tables: \(allowedTables.sorted().joined(separator: ", "))"
)
}
}
}
/// A validator that enforces a maximum row limit on SELECT queries
/// by checking for a LIMIT clause.
public struct MaxRowLimitValidator: QueryValidator {
/// The maximum number of rows allowed.
public let maxRows: Int
/// Creates a validator that requires SELECT queries to include a LIMIT clause
/// not exceeding `maxRows`.
public init(maxRows: Int) {
self.maxRows = maxRows
}
public func validate(sql: String, operation: SQLOperation) throws {
guard operation == .select else { return }
let upper = sql.uppercased()
// Check if LIMIT is present
guard let limitRange = upper.range(of: #"LIMIT\s+(\d+)"#, options: .regularExpression) else {
throw QueryValidationError.rejected(
"SELECT queries must include a LIMIT clause (max \(maxRows) rows)."
)
}
// Extract the limit value
let limitSubstring = upper[limitRange]
let digits = limitSubstring.components(separatedBy: .decimalDigits.inverted).joined()
if let value = Int(digits), value > maxRows {
throw QueryValidationError.rejected(
"LIMIT \(value) exceeds the maximum allowed (\(maxRows))."
)
}
}
}

View File

@@ -0,0 +1,677 @@
// ChatEngine.swift
// SwiftDBAI
//
// Orchestrates the conversation loop: user message SQL generation query
// execution result summarization response.
import AnyLanguageModel
import Foundation
import GRDB
/// A message in the chat conversation.
public struct ChatMessage: Sendable, Identifiable, Equatable {
public let id: UUID
public let role: Role
public let content: String
public let queryResult: QueryResult?
public let sql: String?
public let timestamp: Date
/// The typed error, if this is an error message.
public let error: SwiftDBAIError?
public enum Role: String, Sendable, Equatable {
case user
case assistant
case error
}
public init(
id: UUID = UUID(),
role: Role,
content: String,
queryResult: QueryResult? = nil,
sql: String? = nil,
timestamp: Date = Date(),
error: SwiftDBAIError? = nil
) {
self.id = id
self.role = role
self.content = content
self.queryResult = queryResult
self.sql = sql
self.timestamp = timestamp
self.error = error
}
}
/// The response returned by `ChatEngine.send(_:)`.
public struct ChatResponse: Sendable {
/// The natural language summary of the result.
public let summary: String
/// The SQL that was generated and executed, if any.
public let sql: String?
/// The raw query result, if a query was executed.
public let queryResult: QueryResult?
}
/// Headless engine that orchestrates the full chat-with-database pipeline.
///
/// The engine:
/// 1. Introspects the database schema (once, lazily)
/// 2. Builds a system prompt with schema context
/// 3. Sends the user's question to the LLM to generate SQL
/// 4. Validates the SQL against the operation allowlist
/// 5. Executes the SQL via GRDB
/// 6. Summarizes results using `TextSummaryRenderer`
/// 7. Returns the summary (and raw data) to the caller
///
/// Usage:
/// ```swift
/// let engine = ChatEngine(
/// database: myDatabasePool,
/// model: myLanguageModel
/// )
/// let response = try await engine.send("How many users signed up this week?")
/// print(response.summary) // "There were 42 new signups this week."
/// ```
public final class ChatEngine: @unchecked Sendable {
// MARK: - Dependencies
private let database: any DatabaseWriter
private let model: any LanguageModel
private let allowlist: OperationAllowlist
private let mutationPolicy: MutationPolicy?
private let configuration: ChatEngineConfiguration
private let summaryRenderer: TextSummaryRenderer
private let sqlParser: SQLQueryParser
/// Optional delegate for intercepting destructive operations and observing SQL execution.
private let delegate: (any ToolExecutionDelegate)?
// MARK: - State
private var schema: DatabaseSchema?
private var conversationHistory: [ChatMessage] = []
private let lock = NSLock()
// MARK: - Initialization
/// Creates a new ChatEngine with a full configuration object.
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - model: Any `AnyLanguageModel`-compatible language model.
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only (SELECT only).
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
/// - delegate: Optional delegate for confirming destructive operations and observing SQL execution.
public init(
database: any DatabaseWriter,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
configuration: ChatEngineConfiguration = .default,
delegate: (any ToolExecutionDelegate)? = nil
) {
self.database = database
self.model = model
self.allowlist = allowlist
self.mutationPolicy = nil
self.configuration = configuration
self.delegate = delegate
self.summaryRenderer = TextSummaryRenderer(
model: model,
maxRowsInPrompt: configuration.maxSummaryRows
)
self.sqlParser = SQLQueryParser(allowlist: allowlist)
}
/// Creates a new ChatEngine with a `MutationPolicy` for table-level control.
///
/// This initializer provides fine-grained control over which mutations are
/// allowed on which tables. The policy's operation allowlist is used for
/// SQL validation, and table-level restrictions are enforced during parsing.
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - model: Any `AnyLanguageModel`-compatible language model.
/// - mutationPolicy: Controls which operations are allowed on which tables.
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
/// - delegate: Optional delegate for confirming destructive operations and observing SQL execution.
public init(
database: any DatabaseWriter,
model: any LanguageModel,
mutationPolicy: MutationPolicy,
configuration: ChatEngineConfiguration = .default,
delegate: (any ToolExecutionDelegate)? = nil
) {
self.database = database
self.model = model
self.allowlist = mutationPolicy.operationAllowlist
self.mutationPolicy = mutationPolicy
self.configuration = configuration
self.delegate = delegate
self.summaryRenderer = TextSummaryRenderer(
model: model,
maxRowsInPrompt: configuration.maxSummaryRows
)
self.sqlParser = SQLQueryParser(mutationPolicy: mutationPolicy)
}
/// Creates a new ChatEngine with individual parameters (convenience).
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
/// - model: Any `AnyLanguageModel`-compatible language model.
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only (SELECT only).
/// - additionalContext: Optional extra instructions for the LLM system prompt.
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
public convenience init(
database: any DatabaseWriter,
model: any LanguageModel,
allowlist: OperationAllowlist,
additionalContext: String?,
maxSummaryRows: Int = 50
) {
let config = ChatEngineConfiguration(
maxSummaryRows: maxSummaryRows,
additionalContext: additionalContext
)
self.init(
database: database,
model: model,
allowlist: allowlist,
configuration: config
)
}
// MARK: - Public API
/// Sends a natural language message and returns a summarized response.
///
/// This is the primary entry point. The engine will:
/// 1. Introspect the schema if not yet cached
/// 2. Ask the LLM to generate SQL
/// 3. Validate the SQL against the allowlist and custom validators
/// 4. Execute the SQL (with timeout if configured)
/// 5. Summarize the results using `TextSummaryRenderer`
///
/// All errors are caught and mapped to a distinct ``SwiftDBAIError`` case
/// so callers always receive a typed, user-friendly error with a localized
/// description suitable for display in a chat UI.
///
/// - Parameter message: The user's natural language question or command.
/// - Returns: A `ChatResponse` containing the summary, SQL, and raw result.
/// - Throws: ``SwiftDBAIError`` for every failure mode.
public func send(_ message: String) async throws -> ChatResponse {
// 1. Ensure schema is introspected
let schema: DatabaseSchema
do {
schema = try await ensureSchema()
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.schemaIntrospectionFailed(reason: error.localizedDescription)
}
// Check for empty schema
if schema.tableNames.isEmpty {
throw SwiftDBAIError.emptySchema
}
// 2. Build prompt and get raw LLM response
let promptBuilder = PromptBuilder(
schema: schema,
allowlist: allowlist,
additionalContext: configuration.additionalContext
)
let rawLLMResponse: String
do {
rawLLMResponse = try await generateRawResponse(
question: message,
promptBuilder: promptBuilder
)
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.llmFailure(reason: error.localizedDescription)
}
// 3. Parse and validate SQL through SQLQueryParser
let parsed: ParsedSQL
do {
parsed = try sqlParser.parse(rawLLMResponse)
} catch let error as SQLParsingError {
throw error.toSwiftDBAIError(rawResponse: rawLLMResponse)
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.invalidSQL(sql: rawLLMResponse, reason: error.localizedDescription)
}
// 4. Run custom validators
do {
try runCustomValidators(parsed: parsed)
} catch let error as QueryValidationError {
throw error
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.queryRejected(reason: error.localizedDescription)
}
// 5. Handle confirmation-required operations (DELETE, DROP, etc.)
if parsed.requiresConfirmation {
if let delegate = self.delegate {
// Build context for the delegate
let classification = classifySQL(parsed.sql)
let context = DestructiveOperationContext(
sql: parsed.sql,
statementKind: detectStatementKind(parsed.sql) ?? .delete,
classification: classification,
description: "Execute \(parsed.operation.rawValue.uppercased()) operation: \(parsed.sql)",
targetTable: extractTargetTableForDelegate(from: parsed.sql, operation: parsed.operation)
)
// Ask the delegate for approval
let approved = await delegate.confirmDestructiveOperation(context)
if !approved {
throw SwiftDBAIError.confirmationRequired(
sql: parsed.sql,
operation: parsed.operation.rawValue
)
}
// Delegate approved fall through to execution
} else {
// No delegate throw confirmation required so caller can handle it
throw SwiftDBAIError.confirmationRequired(
sql: parsed.sql,
operation: parsed.operation.rawValue
)
}
}
// 6. Execute the SQL (with timeout if configured)
let result: QueryResult
do {
let classification = classifySQL(parsed.sql)
await delegate?.willExecuteSQL(parsed.sql, classification: classification)
result = try await executeSQLWithTimeout(parsed.sql)
await delegate?.didExecuteSQL(parsed.sql, success: true)
} catch let error as SwiftDBAIError {
await delegate?.didExecuteSQL(parsed.sql, success: false)
throw error
} catch let error as ChatEngineError {
await delegate?.didExecuteSQL(parsed.sql, success: false)
// Map internal ChatEngineError (e.g. from timeout) to SwiftDBAIError
throw error.toSwiftDBAIError()
} catch {
await delegate?.didExecuteSQL(parsed.sql, success: false)
throw SwiftDBAIError.databaseError(reason: error.localizedDescription)
}
// 7. Summarize the result using TextSummaryRenderer
let summary: String
do {
summary = try await summaryRenderer.summarize(
result: result,
userQuestion: message
)
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.llmFailure(reason: "Summarization failed: \(error.localizedDescription)")
}
// 8. Record conversation history
let userMessage = ChatMessage(role: .user, content: message)
let assistantMessage = ChatMessage(
role: .assistant,
content: summary,
queryResult: result,
sql: parsed.sql
)
lock.withLock {
conversationHistory.append(userMessage)
conversationHistory.append(assistantMessage)
}
return ChatResponse(
summary: summary,
sql: parsed.sql,
queryResult: result
)
}
/// Sends a natural language message, executing a previously confirmed destructive operation.
///
/// Call this after receiving a `confirmationRequired` error and the user has confirmed.
///
/// - Parameters:
/// - message: The original user message (for history recording).
/// - confirmedSQL: The SQL that was confirmed by the user.
/// - Returns: A `ChatResponse` with the result.
public func sendConfirmed(_ message: String, confirmedSQL: String) async throws -> ChatResponse {
let result: QueryResult
do {
let classification = classifySQL(confirmedSQL)
await delegate?.willExecuteSQL(confirmedSQL, classification: classification)
result = try await executeSQLWithTimeout(confirmedSQL)
await delegate?.didExecuteSQL(confirmedSQL, success: true)
} catch let error as SwiftDBAIError {
await delegate?.didExecuteSQL(confirmedSQL, success: false)
throw error
} catch let error as ChatEngineError {
await delegate?.didExecuteSQL(confirmedSQL, success: false)
throw error.toSwiftDBAIError()
} catch {
await delegate?.didExecuteSQL(confirmedSQL, success: false)
throw SwiftDBAIError.databaseError(reason: error.localizedDescription)
}
let summary: String
do {
summary = try await summaryRenderer.summarize(
result: result,
userQuestion: message
)
} catch let error as SwiftDBAIError {
throw error
} catch {
throw SwiftDBAIError.llmFailure(reason: "Summarization failed: \(error.localizedDescription)")
}
let userMessage = ChatMessage(role: .user, content: message)
let assistantMessage = ChatMessage(
role: .assistant,
content: summary,
queryResult: result,
sql: confirmedSQL
)
lock.withLock {
conversationHistory.append(userMessage)
conversationHistory.append(assistantMessage)
}
return ChatResponse(
summary: summary,
sql: confirmedSQL,
queryResult: result
)
}
/// Returns the current conversation history.
public var messages: [ChatMessage] {
lock.lock()
defer { lock.unlock() }
return conversationHistory
}
/// Eagerly introspects the database schema so it's ready before the first query.
///
/// Call this at view-appear time to pre-warm the schema cache. If the schema
/// is already cached, this returns immediately. The returned `DatabaseSchema`
/// can be used to display table/column info in the UI.
///
/// - Returns: The introspected `DatabaseSchema`.
@discardableResult
public func prepareSchema() async throws -> DatabaseSchema {
try await ensureSchema()
}
/// The number of tables discovered during schema introspection.
/// Returns `nil` if the schema has not been introspected yet.
public var tableCount: Int? {
lock.withLock { schema?.tableNames.count }
}
/// The cached schema, if introspection has completed.
public var cachedSchema: DatabaseSchema? {
lock.withLock { schema }
}
/// Clears the conversation history and cached schema.
///
/// After calling this, the next `send(_:)` call will re-introspect the
/// schema. Use ``clearHistory()`` if you only want to reset the conversation
/// while keeping the cached schema.
public func reset() {
lock.withLock {
conversationHistory.removeAll()
schema = nil
}
}
/// Clears only the conversation history, keeping the cached schema.
///
/// This is useful when you want to start a fresh conversation thread
/// without re-introspecting the database. The schema cache remains valid
/// as long as the database structure hasn't changed.
public func clearHistory() {
lock.withLock {
conversationHistory.removeAll()
}
}
/// The current engine configuration.
public var currentConfiguration: ChatEngineConfiguration {
configuration
}
// MARK: - Internal Helpers (visible for testing)
/// Ensures the database schema is introspected and cached.
func ensureSchema() async throws -> DatabaseSchema {
if let cached = lock.withLock({ schema }) {
return cached
}
let introspected = try await SchemaIntrospector.introspect(database: database)
lock.withLock { schema = introspected }
return introspected
}
/// Asks the LLM to generate SQL from a natural language question.
/// Returns the raw LLM response text (before parsing).
///
/// Uses the configured ``ChatEngineConfiguration/contextWindowSize`` to limit
/// how many conversation messages are included as context for the LLM.
private func generateRawResponse(
question: String,
promptBuilder: PromptBuilder
) async throws -> String {
let instructions = promptBuilder.buildSystemInstructions()
// Build user prompt include full conversation history for follow-ups
// Respect context window: only use recent messages for context
let userPrompt: String
let historySlice = lock.withLock { () -> [ChatMessage] in
Array(contextWindowSlice())
}
if historySlice.isEmpty {
userPrompt = promptBuilder.buildUserPrompt(question)
} else {
userPrompt = promptBuilder.buildConversationPrompt(
question,
history: historySlice
)
}
let session = LanguageModelSession(
model: model,
instructions: instructions + "\n\nRespond with ONLY the SQL query. No explanations, no markdown, no code fences."
)
let response = try await session.respond(to: userPrompt)
return response.content.trimmingCharacters(in: .whitespacesAndNewlines)
}
/// Returns the conversation history slice within the configured context window.
/// Must be called within a `lock.withLock` closure.
private func contextWindowSlice() -> ArraySlice<ChatMessage> {
guard let windowSize = configuration.contextWindowSize else {
return conversationHistory[...]
}
let count = conversationHistory.count
let start = max(0, count - windowSize)
return conversationHistory[start...]
}
/// Runs all custom validators from the configuration against the parsed SQL.
private func runCustomValidators(parsed: ParsedSQL) throws {
for validator in configuration.validators {
try validator.validate(sql: parsed.sql, operation: parsed.operation)
}
}
/// Extracts the target table name from a SQL statement for delegate context.
private func extractTargetTableForDelegate(from sql: String, operation: SQLOperation) -> String? {
let pattern: String
switch operation {
case .insert:
pattern = #"INSERT\s+INTO\s+[`"\[]?(\w+)[`"\]]?"#
case .update:
pattern = #"UPDATE\s+[`"\[]?(\w+)[`"\]]?"#
case .delete:
pattern = #"DELETE\s+FROM\s+[`"\[]?(\w+)[`"\]]?"#
case .select:
return nil
}
guard let regex = try? NSRegularExpression(pattern: pattern, options: .caseInsensitive) else {
return nil
}
let range = NSRange(sql.startIndex..., in: sql)
guard let match = regex.firstMatch(in: sql, range: range),
match.numberOfRanges > 1,
let groupRange = Range(match.range(at: 1), in: sql) else {
return nil
}
return String(sql[groupRange])
}
/// Executes SQL with the configured timeout, if any.
private func executeSQLWithTimeout(_ sql: String) async throws -> QueryResult {
guard let timeout = configuration.queryTimeout else {
return try await executeSQL(sql)
}
return try await withThrowingTaskGroup(of: QueryResult.self) { group in
group.addTask {
try await self.executeSQL(sql)
}
group.addTask {
try await Task.sleep(for: .seconds(timeout))
throw ChatEngineError.queryTimedOut(seconds: timeout)
}
// Return whichever finishes first
let result = try await group.next()!
group.cancelAll()
return result
}
}
/// Executes SQL against the database and returns a `QueryResult`.
private func executeSQL(_ sql: String) async throws -> QueryResult {
let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
let isSelect = trimmed.hasPrefix("SELECT") || trimmed.hasPrefix("WITH")
let startTime = CFAbsoluteTimeGetCurrent()
if isSelect {
let result = try await database.read { db -> (columns: [String], rows: [[String: QueryResult.Value]]) in
let statement = try db.makeStatement(sql: sql)
let columnNames = statement.columnNames
var rows: [[String: QueryResult.Value]] = []
let cursor = try Row.fetchCursor(statement)
while let row = try cursor.next() {
var dict: [String: QueryResult.Value] = [:]
for col in columnNames {
dict[col] = Self.extractValue(row: row, column: col)
}
rows.append(dict)
}
return (columns: columnNames, rows: rows)
}
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
return QueryResult(
columns: result.columns,
rows: result.rows,
sql: sql,
executionTime: elapsed
)
} else {
// Mutation query
let affected = try await database.write { db -> Int in
try db.execute(sql: sql)
return db.changesCount
}
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
return QueryResult(
columns: [],
rows: [],
sql: sql,
executionTime: elapsed,
rowsAffected: affected
)
}
}
/// Extracts a `QueryResult.Value` from a GRDB `Row` for the given column.
private static func extractValue(row: Row, column: String) -> QueryResult.Value {
let dbValue: DatabaseValue = row[column]
switch dbValue.storage {
case .null:
return .null
case .int64(let i):
return .integer(i)
case .double(let d):
return .real(d)
case .string(let s):
return .text(s)
case .blob(let data):
return .blob(data)
}
}
}
// MARK: - Errors
/// Errors that can occur during ChatEngine operations.
public enum ChatEngineError: Error, LocalizedError, Sendable {
/// SQL parsing/extraction from LLM response failed.
case sqlParsingFailed(SQLParsingError)
/// A destructive operation requires user confirmation before execution.
case confirmationRequired(sql: String, operation: SQLOperation)
/// Schema introspection failed.
case schemaIntrospectionFailed(String)
/// The SQL query exceeded the configured timeout.
case queryTimedOut(seconds: TimeInterval)
/// A custom query validator rejected the query.
case validationFailed(String)
public var errorDescription: String? {
switch self {
case .sqlParsingFailed(let parsingError):
return "SQL parsing failed: \(parsingError.description)"
case .confirmationRequired(let sql, let op):
return "The \(op.rawValue.uppercased()) operation requires confirmation: \(sql)"
case .schemaIntrospectionFailed(let reason):
return "Failed to introspect database schema: \(reason)"
case .queryTimedOut(let seconds):
return "Query timed out after \(Int(seconds)) seconds."
case .validationFailed(let reason):
return "Query validation failed: \(reason)"
}
}
}

View File

@@ -0,0 +1,288 @@
// ToolExecutionDelegate.swift
// SwiftDBAI
//
// Delegate protocol for controlling SQL tool execution, including
// confirmation of destructive operations before they reach the database.
import Foundation
// MARK: - Destructive SQL Classification
/// Classifies SQL statements by their destructive potential.
///
/// A statement is considered **destructive** if it modifies or removes data
/// or schema objects. The classification drives the confirmation flow:
/// destructive statements require explicit user approval via
/// ``ToolExecutionDelegate/confirmDestructiveOperation(_:)``.
public enum DestructiveClassification: Sendable, Equatable {
/// The statement is read-only (e.g. SELECT). No confirmation needed.
case safe
/// The statement modifies existing data (INSERT, UPDATE).
case mutation(SQLStatementKind)
/// The statement deletes data or alters/drops schema objects.
/// These always require confirmation, even when the operation is allowed.
case destructive(SQLStatementKind)
/// Returns `true` when the statement requires user confirmation.
public var requiresConfirmation: Bool {
switch self {
case .safe:
return false
case .mutation:
return false
case .destructive:
return true
}
}
/// Returns `true` when the statement modifies data or schema in any way.
public var isMutating: Bool {
switch self {
case .safe:
return false
case .mutation, .destructive:
return true
}
}
}
/// The kind of SQL statement, used for classification and display.
public enum SQLStatementKind: String, Sendable, Hashable, CaseIterable {
case select = "SELECT"
case insert = "INSERT"
case update = "UPDATE"
case delete = "DELETE"
case drop = "DROP"
case alter = "ALTER"
case truncate = "TRUNCATE"
/// All kinds that are classified as destructive.
public static let destructiveKinds: Set<SQLStatementKind> = [
.delete, .drop, .alter, .truncate
]
/// All kinds that are classified as mutations (data-modifying but not destructive).
public static let mutationKinds: Set<SQLStatementKind> = [
.insert, .update
]
/// Whether this kind of statement is destructive.
public var isDestructive: Bool {
Self.destructiveKinds.contains(self)
}
/// Whether this kind of statement is a mutation (INSERT/UPDATE).
public var isMutation: Bool {
Self.mutationKinds.contains(self)
}
}
// MARK: - Classification Function
/// Classifies a SQL statement string by its destructive potential.
///
/// The classifier inspects the first keyword token of the statement
/// (case-insensitive) to determine the statement kind, then maps it
/// to a ``DestructiveClassification``.
///
/// - Parameter sql: The SQL statement to classify.
/// - Returns: The classification for the statement.
public func classifySQL(_ sql: String) -> DestructiveClassification {
guard let kind = detectStatementKind(sql) else {
return .safe
}
if kind.isDestructive {
return .destructive(kind)
} else if kind.isMutation {
return .mutation(kind)
} else {
return .safe
}
}
/// Detects the ``SQLStatementKind`` from the leading keyword of a SQL string.
///
/// - Parameter sql: The SQL statement to inspect.
/// - Returns: The detected kind, or `nil` if unrecognized.
public func detectStatementKind(_ sql: String) -> SQLStatementKind? {
let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
// Check each known statement kind against the first token
if trimmed.hasPrefix("SELECT") || trimmed.hasPrefix("WITH") {
return .select
} else if trimmed.hasPrefix("INSERT") {
return .insert
} else if trimmed.hasPrefix("UPDATE") {
return .update
} else if trimmed.hasPrefix("DELETE") {
return .delete
} else if trimmed.hasPrefix("DROP") {
return .drop
} else if trimmed.hasPrefix("ALTER") {
return .alter
} else if trimmed.hasPrefix("TRUNCATE") {
return .truncate
}
return nil
}
// MARK: - Destructive Operation Context
/// Context provided to the delegate when a destructive operation needs confirmation.
///
/// Contains all the information a UI or programmatic handler needs to
/// decide whether to allow the operation.
public struct DestructiveOperationContext: Sendable {
/// The SQL statement that would be executed.
public let sql: String
/// The detected kind of statement (DELETE, DROP, ALTER, TRUNCATE).
public let statementKind: SQLStatementKind
/// The classification result.
public let classification: DestructiveClassification
/// A human-readable description of what the operation will do.
public let description: String
/// The target table name, if detected.
public let targetTable: String?
public init(
sql: String,
statementKind: SQLStatementKind,
classification: DestructiveClassification,
description: String,
targetTable: String? = nil
) {
self.sql = sql
self.statementKind = statementKind
self.classification = classification
self.description = description
self.targetTable = targetTable
}
}
// MARK: - ToolExecutionDelegate Protocol
/// A delegate that controls execution of SQL operations, providing
/// confirmation gates for destructive statements.
///
/// Implement this protocol to intercept destructive SQL operations
/// (DELETE, DROP, ALTER, TRUNCATE) before they are executed. The
/// ``ChatEngine`` consults the delegate whenever it encounters a
/// statement classified as ``DestructiveClassification/destructive(_:)``.
///
/// ## Example
///
/// ```swift
/// struct MyDelegate: ToolExecutionDelegate {
/// func confirmDestructiveOperation(
/// _ context: DestructiveOperationContext
/// ) async -> Bool {
/// // Show a confirmation dialog to the user
/// return await showAlert(
/// "Confirm \(context.statementKind.rawValue)",
/// message: context.description
/// )
/// }
/// }
///
/// let engine = ChatEngine(
/// database: pool,
/// model: model,
/// delegate: MyDelegate()
/// )
/// ```
public protocol ToolExecutionDelegate: Sendable {
/// Called when a destructive SQL operation is about to be executed.
///
/// The delegate should present the operation details to the user and
/// return `true` to proceed or `false` to cancel.
///
/// - Parameter context: Details about the destructive operation.
/// - Returns: `true` to allow execution, `false` to reject it.
func confirmDestructiveOperation(
_ context: DestructiveOperationContext
) async -> Bool
/// Called before any SQL statement is executed.
///
/// This is an observation hook the engine does not wait for a
/// decision. Override to log, audit, or instrument queries.
///
/// - Parameters:
/// - sql: The SQL about to be executed.
/// - classification: The destructive classification of the statement.
func willExecuteSQL(
_ sql: String,
classification: DestructiveClassification
) async
/// Called after a SQL statement completes execution.
///
/// - Parameters:
/// - sql: The SQL that was executed.
/// - success: Whether execution succeeded.
func didExecuteSQL(
_ sql: String,
success: Bool
) async
}
// MARK: - Default Implementations
extension ToolExecutionDelegate {
/// Default: rejects all destructive operations.
public func confirmDestructiveOperation(
_ context: DestructiveOperationContext
) async -> Bool {
false
}
/// Default: no-op.
public func willExecuteSQL(
_ sql: String,
classification: DestructiveClassification
) async {}
/// Default: no-op.
public func didExecuteSQL(
_ sql: String,
success: Bool
) async {}
}
// MARK: - Built-in Delegates
/// A delegate that automatically approves all destructive operations.
///
/// Use this only in testing or trusted environments where confirmation
/// is not needed.
public struct AutoApproveDelegate: ToolExecutionDelegate {
public init() {}
public func confirmDestructiveOperation(
_ context: DestructiveOperationContext
) async -> Bool {
true
}
}
/// A delegate that always rejects destructive operations.
///
/// This is the safest option and matches the default behavior.
public struct RejectAllDelegate: ToolExecutionDelegate {
public init() {}
public func confirmDestructiveOperation(
_ context: DestructiveOperationContext
) async -> Bool {
false
}
}

View File

@@ -0,0 +1,143 @@
// ConversationHistory.swift
// SwiftDBAI
//
// Ordered chat message history with configurable context window.
import Foundation
/// Stores an ordered sequence of ``ChatMessage`` instances with a configurable
/// context window limit.
///
/// When the number of messages exceeds ``maxMessages``, the oldest messages are
/// trimmed to keep the history within budget. This prevents unbounded token
/// growth when building LLM prompts from conversation history.
///
/// Usage:
/// ```swift
/// var history = ConversationHistory(maxMessages: 20)
/// history.append(ChatMessage(role: .user, content: "How many users?"))
/// history.append(ChatMessage(role: .assistant, content: "42", sql: "SELECT COUNT(*) FROM users"))
/// print(history.promptText) // formatted for LLM context
/// ```
public struct ConversationHistory: Sendable {
/// The maximum number of messages to retain. `nil` means unlimited.
public let maxMessages: Int?
/// All messages in chronological order.
public private(set) var messages: [ChatMessage] = []
/// Creates a new conversation history.
///
/// - Parameter maxMessages: Maximum number of messages to keep in the
/// context window. Pass `nil` for unlimited history. Defaults to 50.
public init(maxMessages: Int? = 50) {
precondition(maxMessages == nil || maxMessages! > 0,
"maxMessages must be positive or nil")
self.maxMessages = maxMessages
}
/// The number of messages currently stored.
public var count: Int { messages.count }
/// Whether the history is empty.
public var isEmpty: Bool { messages.isEmpty }
// MARK: - Mutating Operations
/// Appends a message and trims the history if it exceeds the context window.
public mutating func append(_ message: ChatMessage) {
messages.append(message)
trimIfNeeded()
}
/// Appends multiple messages and trims once afterward.
public mutating func append(contentsOf newMessages: [ChatMessage]) {
messages.append(contentsOf: newMessages)
trimIfNeeded()
}
/// Removes all messages from the history.
public mutating func clear() {
messages.removeAll()
}
// MARK: - Context Window
/// Returns the most recent messages formatted for inclusion in an LLM prompt.
///
/// Each message is formatted as `[role] content`, with SQL and query results
/// included inline for assistant messages.
///
/// - Parameter limit: Optional override to further restrict the number of
/// messages returned. When `nil`, uses the full retained history.
/// - Returns: An array of prompt-formatted strings, one per message.
public func promptMessages(limit: Int? = nil) -> [String] {
let slice: ArraySlice<ChatMessage>
if let limit {
slice = messages.suffix(limit)
} else {
slice = messages[...]
}
return slice.map { message in
Self.formatForPrompt(message)
}
}
/// Returns the combined prompt text for all retained messages, separated by
/// double newlines.
public var promptText: String {
promptMessages().joined(separator: "\n\n")
}
// MARK: - Queries
/// Returns only user messages.
public var userMessages: [ChatMessage] {
messages.filter { $0.role == .user }
}
/// Returns only assistant messages.
public var assistantMessages: [ChatMessage] {
messages.filter { $0.role == .assistant }
}
/// Returns the last message, if any.
public var lastMessage: ChatMessage? {
messages.last
}
/// Returns the most recent user query text, if any.
public var lastUserQuery: String? {
messages.last(where: { $0.role == .user })?.content
}
/// Returns the most recent assistant message, if any.
public var lastAssistantMessage: ChatMessage? {
messages.last(where: { $0.role == .assistant })
}
// MARK: - Private
/// Formats a ``ChatMessage`` into a string suitable for LLM prompt context.
private static func formatForPrompt(_ message: ChatMessage) -> String {
var parts: [String] = ["[\(message.role.rawValue)] \(message.content)"]
if let sql = message.sql {
parts.append("SQL: \(sql)")
}
if let result = message.queryResult {
parts.append("Result:\n\(result.tabularDescription)")
}
return parts.joined(separator: "\n")
}
/// Trims the oldest messages to stay within the context window.
private mutating func trimIfNeeded() {
guard let max = maxMessages, messages.count > max else { return }
let overflow = messages.count - max
messages.removeFirst(overflow)
}
}

View File

@@ -0,0 +1,136 @@
// QueryResult.swift
// SwiftDBAI
//
// Structured result from SQL query execution.
import Foundation
/// Represents the result of executing a SQL query against the database.
///
/// Contains raw row data as dictionaries, column metadata, row count,
/// the original SQL string, and execution timing.
public struct QueryResult: Sendable, Equatable {
/// A single cell value from a query result.
///
/// Wraps SQLite's dynamic value types into a type-safe, Sendable enum.
public enum Value: Sendable, Equatable, CustomStringConvertible {
case text(String)
case integer(Int64)
case real(Double)
case blob(Data)
case null
public var description: String {
switch self {
case .text(let s): return s
case .integer(let i): return String(i)
case .real(let d):
if d == d.rounded() && abs(d) < 1e15 {
return String(format: "%.0f", d)
}
return String(d)
case .blob(let data): return "<\(data.count) bytes>"
case .null: return "NULL"
}
}
/// Returns the value as a `Double` if it is numeric, nil otherwise.
public var doubleValue: Double? {
switch self {
case .integer(let i): return Double(i)
case .real(let d): return d
case .text(let s): return Double(s)
default: return nil
}
}
/// Returns the value as a `String` (non-nil for all cases).
public var stringValue: String { description }
/// Returns `true` if this value is `.null`.
public var isNull: Bool {
if case .null = self { return true }
return false
}
}
/// Column names in the order they appear in the result set.
public let columns: [String]
/// Row data as an array of dictionaries mapping column name to value.
public let rows: [[String: Value]]
/// Total number of rows returned.
public var rowCount: Int { rows.count }
/// The SQL statement that was executed.
public let sql: String
/// Time taken to execute the query, in seconds.
public let executionTime: TimeInterval
/// Number of rows affected (for INSERT/UPDATE/DELETE). Nil for SELECT.
public let rowsAffected: Int?
public init(
columns: [String],
rows: [[String: Value]],
sql: String,
executionTime: TimeInterval,
rowsAffected: Int? = nil
) {
self.columns = columns
self.rows = rows
self.sql = sql
self.executionTime = executionTime
self.rowsAffected = rowsAffected
}
// MARK: - Convenience Accessors
/// Returns all values for a given column, in row order.
public func values(forColumn column: String) -> [Value] {
rows.compactMap { $0[column] }
}
/// Returns a compact tabular string representation of the results.
///
/// Useful for embedding query results into LLM prompts.
public var tabularDescription: String {
guard !rows.isEmpty else {
return "(empty result set)"
}
var lines: [String] = []
// Header
lines.append(columns.joined(separator: " | "))
lines.append(String(repeating: "-", count: lines[0].count))
// Rows (cap at 50 for prompt size)
let displayRows = rows.prefix(50)
for row in displayRows {
let vals = columns.map { col in
row[col]?.description ?? "NULL"
}
lines.append(vals.joined(separator: " | "))
}
if rows.count > 50 {
lines.append("... and \(rows.count - 50) more rows")
}
return lines.joined(separator: "\n")
}
/// Returns true if the result looks like a single aggregate value
/// (1 row, 1-3 columns, all numeric).
public var isAggregate: Bool {
guard rowCount == 1, columns.count <= 3 else { return false }
let firstRow = rows[0]
return columns.allSatisfy { col in
firstRow[col]?.doubleValue != nil
}
}
}

View File

@@ -0,0 +1,380 @@
// SQLQueryParser.swift
// SwiftDBAI
//
// Extracts and validates SQL statements from raw LLM response text.
import Foundation
/// Errors that can occur during SQL parsing and validation.
public enum SQLParsingError: Error, Sendable, Equatable, CustomStringConvertible {
/// No SQL statement could be found in the LLM response.
case noSQLFound
/// The SQL statement uses an operation not in the allowlist.
case operationNotAllowed(SQLOperation)
/// A destructive operation (DELETE) requires user confirmation.
case confirmationRequired(sql: String, operation: SQLOperation)
/// The mutation targets a table not in the allowed mutation tables.
case tableNotAllowed(table: String, operation: SQLOperation)
/// The SQL contains a disallowed keyword (e.g., DROP, ALTER, TRUNCATE).
case dangerousOperation(String)
/// Multiple SQL statements were found but only single-statement execution is supported.
case multipleStatements
public var description: String {
switch self {
case .noSQLFound:
return "No SQL statement found in the response."
case .operationNotAllowed(let op):
return "Operation '\(op.rawValue.uppercased())' is not allowed by the current configuration."
case .confirmationRequired(let sql, let op):
return "The \(op.rawValue.uppercased()) operation requires confirmation: \(sql)"
case .tableNotAllowed(let table, let op):
return "The \(op.rawValue.uppercased()) operation is not allowed on table '\(table)'."
case .dangerousOperation(let keyword):
return "Dangerous SQL operation '\(keyword)' is never allowed."
case .multipleStatements:
return "Only single SQL statements are supported."
}
}
}
/// Result of successfully parsing SQL from an LLM response.
public struct ParsedSQL: Sendable, Equatable {
/// The cleaned SQL statement ready for execution.
public let sql: String
/// The detected operation type.
public let operation: SQLOperation
/// Whether this operation requires user confirmation before execution.
public let requiresConfirmation: Bool
public init(sql: String, operation: SQLOperation, requiresConfirmation: Bool = false) {
self.sql = sql
self.operation = operation
self.requiresConfirmation = requiresConfirmation
}
}
/// Extracts SQL statements from raw LLM response text and validates them
/// against the configured ``OperationAllowlist``.
///
/// The parser handles common LLM output patterns:
/// - SQL in markdown code blocks (```sql ... ```)
/// - SQL in generic code blocks (``` ... ```)
/// - Raw SQL statements in plain text
/// - SQL prefixed with labels like "SQL:" or "Query:"
public struct SQLQueryParser: Sendable {
/// Keywords that are never allowed regardless of allowlist configuration.
private static let dangerousKeywords: Set<String> = [
"DROP", "ALTER", "TRUNCATE", "CREATE", "GRANT", "REVOKE",
"ATTACH", "DETACH", "PRAGMA", "VACUUM", "REINDEX"
]
/// The operation allowlist to validate against.
private let allowlist: OperationAllowlist
/// The mutation policy for table-level restrictions.
private let mutationPolicy: MutationPolicy?
/// Creates a parser with the given operation allowlist.
/// - Parameter allowlist: The set of permitted operations. Defaults to read-only.
public init(allowlist: OperationAllowlist = .readOnly) {
self.allowlist = allowlist
self.mutationPolicy = nil
}
/// Creates a parser with a mutation policy (preferred initializer).
/// - Parameter mutationPolicy: The mutation policy controlling operations and table access.
public init(mutationPolicy: MutationPolicy) {
self.allowlist = mutationPolicy.operationAllowlist
self.mutationPolicy = mutationPolicy
}
/// Extracts and validates a SQL statement from raw LLM response text.
///
/// - Parameter text: The raw text from the LLM response.
/// - Returns: A ``ParsedSQL`` containing the validated statement.
/// - Throws: ``SQLParsingError`` if extraction or validation fails.
public func parse(_ text: String) throws -> ParsedSQL {
let sql = try extractSQL(from: text)
return try validate(sql)
}
// MARK: - Extraction
/// Attempts to extract a SQL statement from the LLM response text.
/// Tries multiple strategies in order of confidence.
func extractSQL(from text: String) throws -> String {
// Strategy 1: SQL in markdown fenced code block with sql language tag
if let sql = extractFromSQLCodeBlock(text) {
return sql
}
// Strategy 2: SQL in generic fenced code block
if let sql = extractFromGenericCodeBlock(text) {
return sql
}
// Strategy 3: SQL after a label like "SQL:" or "Query:"
if let sql = extractFromLabel(text) {
return sql
}
// Strategy 4: Direct SQL detection in plain text
if let sql = extractDirectSQL(text) {
return sql
}
throw SQLParsingError.noSQLFound
}
/// Extracts SQL from a ```sql ... ``` code block.
private func extractFromSQLCodeBlock(_ text: String) -> String? {
let pattern = #"```sql\s*\n([\s\S]*?)```"#
return firstMatch(pattern: pattern, in: text, group: 1)?
.trimmingCharacters(in: .whitespacesAndNewlines)
.nonEmptyOrNil
}
/// Extracts SQL from a generic ``` ... ``` code block.
private func extractFromGenericCodeBlock(_ text: String) -> String? {
let pattern = #"```\s*\n([\s\S]*?)```"#
guard let content = firstMatch(pattern: pattern, in: text, group: 1)?
.trimmingCharacters(in: .whitespacesAndNewlines) else {
return nil
}
// Only accept if it looks like SQL
guard looksLikeSQL(content) else { return nil }
return content.nonEmptyOrNil
}
/// Extracts SQL after labels like "SQL:", "Query:", "Here's the query:"
private func extractFromLabel(_ text: String) -> String? {
// Match the SQL keyword up to end-of-line (handling multi-line SQL with indentation)
let pattern = #"(?:SQL|Query|Statement)\s*:\s*\n?\s*((?:SELECT|INSERT|UPDATE|DELETE|WITH)\b.+?)(?:\n(?!\s)|$)"#
guard let content = firstMatch(pattern: pattern, in: text, group: 1, options: [.caseInsensitive, .dotMatchesLineSeparators])?
.trimmingCharacters(in: .whitespacesAndNewlines) else {
return nil
}
guard looksLikeSQL(content) else { return nil }
return content.nonEmptyOrNil
}
/// Detects SQL directly in the text by matching known statement patterns.
private func extractDirectSQL(_ text: String) -> String? {
// Match SQL statement, allowing semicolons inside single-quoted string literals
let pattern = #"(?:^|\n)\s*((?:SELECT|INSERT|UPDATE|DELETE)\b(?:[^;']|'[^']*')*;?)"#
guard let content = firstMatch(pattern: pattern, in: text, group: 1, options: .caseInsensitive)?
.trimmingCharacters(in: .whitespacesAndNewlines) else {
return nil
}
return content.nonEmptyOrNil
}
// MARK: - Validation
/// Validates a SQL string against the allowlist and safety rules.
func validate(_ sql: String) throws -> ParsedSQL {
let cleaned = cleanSQL(sql)
guard !cleaned.isEmpty else {
throw SQLParsingError.noSQLFound
}
// Check for multiple statements (semicolons in non-trivial positions)
if containsMultipleStatements(cleaned) {
throw SQLParsingError.multipleStatements
}
// Check for dangerous operations first (before allowlist)
try checkDangerousKeywords(cleaned)
// Detect the operation type
let operation = detectOperation(cleaned)
// Check against the allowlist
guard allowlist.isAllowed(operation) else {
throw SQLParsingError.operationNotAllowed(operation)
}
// Check table-level restrictions for mutation operations
if let policy = mutationPolicy, operation != .select,
let targetTable = extractTargetTable(from: cleaned, operation: operation) {
guard policy.isAllowed(operation: operation, on: targetTable) else {
throw SQLParsingError.tableNotAllowed(table: targetTable, operation: operation)
}
}
// DELETE requires confirmation when policy says so, or always by default
let requiresConfirmation: Bool
if let policy = mutationPolicy {
requiresConfirmation = policy.requiresConfirmation(for: operation)
} else {
requiresConfirmation = operation == .delete
}
return ParsedSQL(
sql: cleaned,
operation: operation,
requiresConfirmation: requiresConfirmation
)
}
// MARK: - Helpers
/// Cleans a SQL string by removing trailing semicolons (outside string literals) and excess whitespace.
private func cleanSQL(_ sql: String) -> String {
var cleaned = sql.trimmingCharacters(in: .whitespacesAndNewlines)
// Remove trailing semicolons only if they're outside string literals
while cleaned.hasSuffix(";") && !isInsideStringLiteral(sql: cleaned, position: cleaned.index(before: cleaned.endIndex)) {
cleaned = String(cleaned.dropLast()).trimmingCharacters(in: .whitespacesAndNewlines)
}
// Collapse internal whitespace outside string literals
cleaned = collapseWhitespace(cleaned)
return cleaned
}
/// Collapses whitespace while preserving string literal contents.
private func collapseWhitespace(_ sql: String) -> String {
var result = ""
var inString = false
var prevWasSpace = false
for ch in sql {
if ch == "'" {
inString.toggle()
prevWasSpace = false
result.append(ch)
} else if inString {
result.append(ch)
} else if ch.isWhitespace {
if !prevWasSpace {
result.append(" ")
prevWasSpace = true
}
} else {
prevWasSpace = false
result.append(ch)
}
}
return result
}
/// Returns true if the character at the given position is inside a single-quoted string literal.
private func isInsideStringLiteral(sql: String, position: String.Index) -> Bool {
var inString = false
for idx in sql.indices {
if idx == position { return inString }
if sql[idx] == "'" { inString.toggle() }
}
return false
}
/// Checks whether cleaned SQL contains multiple statements.
private func containsMultipleStatements(_ sql: String) -> Bool {
// Remove string literals before checking for semicolons
var inString = false
for ch in sql {
if ch == "'" {
inString.toggle()
} else if ch == ";" && !inString {
return true
}
}
return false
}
/// Checks for dangerous SQL keywords that are never allowed.
private func checkDangerousKeywords(_ sql: String) throws {
let upper = sql.uppercased()
// Tokenize to avoid partial matches (e.g., "DROPDOWN" matching "DROP")
let tokens = upper.components(separatedBy: .alphanumerics.inverted)
.filter { !$0.isEmpty }
for keyword in Self.dangerousKeywords {
if tokens.contains(keyword) {
throw SQLParsingError.dangerousOperation(keyword)
}
}
}
/// Detects the SQL operation type from the first keyword.
private func detectOperation(_ sql: String) -> SQLOperation {
let upper = sql.uppercased().trimmingCharacters(in: .whitespaces)
if upper.hasPrefix("SELECT") || upper.hasPrefix("WITH") {
return .select
} else if upper.hasPrefix("INSERT") {
return .insert
} else if upper.hasPrefix("UPDATE") {
return .update
} else if upper.hasPrefix("DELETE") {
return .delete
}
// Default to select for unrecognized patterns (e.g. EXPLAIN)
return .select
}
/// Extracts the target table name from a mutation SQL statement.
///
/// Handles common patterns:
/// - `INSERT INTO table_name ...`
/// - `UPDATE table_name SET ...`
/// - `DELETE FROM table_name ...`
private func extractTargetTable(from sql: String, operation: SQLOperation) -> String? {
let pattern: String
switch operation {
case .insert:
pattern = #"INSERT\s+INTO\s+[`"\[]?(\w+)[`"\]]?"#
case .update:
pattern = #"UPDATE\s+[`"\[]?(\w+)[`"\]]?"#
case .delete:
pattern = #"DELETE\s+FROM\s+[`"\[]?(\w+)[`"\]]?"#
case .select:
return nil
}
return firstMatch(pattern: pattern, in: sql, group: 1, options: .caseInsensitive)
}
/// Returns true if the text looks like a SQL statement.
private func looksLikeSQL(_ text: String) -> Bool {
let upper = text.uppercased().trimmingCharacters(in: .whitespaces)
let sqlPrefixes = ["SELECT", "INSERT", "UPDATE", "DELETE", "WITH"]
return sqlPrefixes.contains { upper.hasPrefix($0) }
}
/// Extracts the first regex match group from the text.
private func firstMatch(
pattern: String,
in text: String,
group: Int,
options: NSRegularExpression.Options = []
) -> String? {
guard let regex = try? NSRegularExpression(pattern: pattern, options: options) else {
return nil
}
let range = NSRange(text.startIndex..., in: text)
guard let match = regex.firstMatch(in: text, range: range),
match.numberOfRanges > group,
let groupRange = Range(match.range(at: group), in: text) else {
return nil
}
return String(text[groupRange])
}
}
// MARK: - String Extension
private extension String {
/// Returns nil if the string is empty, otherwise returns self.
var nonEmptyOrNil: String? {
isEmpty ? nil : self
}
}

View File

@@ -0,0 +1,211 @@
/// Builds structured LLM prompts for SQL generation from a database schema
/// and natural language input.
///
/// `PromptBuilder` is the bridge between the introspected database schema and
/// the LLM. It produces two things:
/// 1. A **system instructions** string containing schema context and behavioral rules
/// 2. A **user prompt** string wrapping the natural language question
///
/// Usage:
/// ```swift
/// let builder = PromptBuilder(schema: mySchema, allowlist: .readOnly)
/// let instructions = builder.buildSystemInstructions()
/// let prompt = builder.buildUserPrompt("How many users signed up this week?")
/// ```
public struct PromptBuilder: Sendable {
/// The database schema to include as context.
public let schema: DatabaseSchema
/// Which SQL operations the LLM may generate.
public let allowlist: OperationAllowlist
/// Optional additional context to append to the system instructions
/// (e.g., business-specific terminology or query hints).
public let additionalContext: String?
/// Creates a prompt builder for the given schema and allowlist.
///
/// - Parameters:
/// - schema: The introspected database schema.
/// - allowlist: Permitted SQL operations. Defaults to ``OperationAllowlist/readOnly``.
/// - additionalContext: Extra instructions appended to the system prompt.
public init(
schema: DatabaseSchema,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil
) {
self.schema = schema
self.allowlist = allowlist
self.additionalContext = additionalContext
}
// MARK: - System Instructions
/// Builds the system instructions string that should be passed as the
/// `instructions` parameter when creating a `LanguageModelSession`.
///
/// The instructions include:
/// - Role definition
/// - The full database schema
/// - SQL generation rules and constraints
/// - The operation allowlist
/// - Output format requirements
public func buildSystemInstructions() -> String {
var sections: [String] = []
// 1. Role
sections.append(Self.roleSection)
// 2. Schema
sections.append(buildSchemaSection())
// 3. Operation permissions
sections.append(buildPermissionsSection())
// 4. SQL generation rules
sections.append(Self.sqlRulesSection)
// 5. Output format
sections.append(Self.outputFormatSection)
// 6. Additional context
if let additionalContext, !additionalContext.isEmpty {
sections.append("ADDITIONAL CONTEXT\n=================\n\(additionalContext)")
}
return sections.joined(separator: "\n\n")
}
// MARK: - User Prompt
/// Wraps a natural language question into a user prompt string.
///
/// - Parameter question: The user's natural language question.
/// - Returns: A formatted prompt string for the LLM.
public func buildUserPrompt(_ question: String) -> String {
question
}
/// Builds a follow-up prompt that includes prior SQL context for
/// multi-turn conversations.
///
/// - Parameters:
/// - question: The user's follow-up question.
/// - previousSQL: The SQL from the previous turn, for context.
/// - previousResultSummary: A brief summary of what the previous query returned.
/// - Returns: A formatted prompt string.
public func buildFollowUpPrompt(
_ question: String,
previousSQL: String,
previousResultSummary: String
) -> String {
"""
Previous query: \(previousSQL)
Previous result: \(previousResultSummary)
Follow-up question: \(question)
"""
}
/// Builds a prompt that includes the full conversation history within the
/// configured context window, enabling the LLM to resolve follow-up
/// references (pronouns, implicit table/column references, etc.).
///
/// - Parameters:
/// - question: The user's current question.
/// - history: The conversation history messages within the context window.
/// - Returns: A formatted prompt string with conversation context.
public func buildConversationPrompt(
_ question: String,
history: [ChatMessage]
) -> String {
guard !history.isEmpty else {
return buildUserPrompt(question)
}
var lines: [String] = []
lines.append("CONVERSATION HISTORY")
lines.append("====================")
for message in history {
switch message.role {
case .user:
lines.append("User: \(message.content)")
case .assistant:
if let sql = message.sql {
lines.append("Assistant SQL: \(sql)")
}
lines.append("Assistant: \(message.content)")
case .error:
lines.append("Error: \(message.content)")
}
}
lines.append("")
lines.append("CURRENT QUESTION")
lines.append("================")
lines.append(question)
return lines.joined(separator: "\n")
}
// MARK: - Private Sections
private func buildSchemaSection() -> String {
var lines: [String] = []
lines.append("DATABASE SCHEMA")
lines.append("===============")
lines.append("")
lines.append(schema.schemaDescription)
return lines.joined(separator: "\n")
}
private func buildPermissionsSection() -> String {
var lines: [String] = []
lines.append("PERMISSIONS")
lines.append("===========")
lines.append(allowlist.describeForLLM())
return lines.joined(separator: "\n")
}
// MARK: - Static Content
static let roleSection = """
ROLE
====
You are a SQL assistant for a SQLite database. Your job is to translate \
natural language questions into valid SQLite SQL queries based on the \
database schema provided below. You must ONLY reference tables and columns \
that exist in the schema. Never fabricate table or column names.
"""
static let sqlRulesSection = """
SQL GENERATION RULES
====================
1. Use ONLY the tables and columns listed in the schema above.
2. Use SQLite-compatible syntax (e.g., || for string concatenation, \
IFNULL instead of COALESCE where needed).
3. Use appropriate JOINs when queries span multiple tables — reference \
the foreign key relationships in the schema.
4. For date/time operations, use SQLite date functions \
(date(), time(), datetime(), strftime()).
5. Use parameterized-style values where possible. For literal values \
from the user's question, embed them directly in the SQL.
6. Always include an ORDER BY clause when the user implies ordering.
7. Use LIMIT when the user asks for "top N" or "first N" results.
8. For aggregate queries (count, sum, average, min, max), use the \
appropriate SQL aggregate functions.
9. When the user's question is ambiguous, prefer the simplest valid \
interpretation.
10. Never generate DDL statements (CREATE, ALTER, DROP TABLE).
"""
static let outputFormatSection = """
OUTPUT FORMAT
=============
When generating SQL, call the appropriate tool with the SQL query. \
After receiving query results, provide a concise natural language \
summary of the data. Be specific with numbers and names from the results. \
If no rows are returned, say so clearly.
"""
}

View File

@@ -0,0 +1,423 @@
// ChartDataDetector.swift
// SwiftDBAI
//
// Analyzes query results to determine chart eligibility and
// recommends appropriate chart types based on data shape.
import Foundation
/// Detects whether a `DataTable` is suitable for charting and
/// recommends the best chart type based on data shape heuristics.
///
/// The detector examines column types, row counts, and value distributions
/// to produce a `ChartRecommendation` that the rendering layer can use
/// to auto-select an appropriate Swift Charts visualization.
///
/// Usage:
/// ```swift
/// let detector = ChartDataDetector()
/// if let recommendation = detector.detect(table) {
/// switch recommendation.chartType {
/// case .bar: // render bar chart
/// case .line: // render line chart
/// case .pie: // render pie chart
/// }
/// }
/// ```
public struct ChartDataDetector: Sendable {
// MARK: - Chart Types
/// The type of chart recommended for the data.
public enum ChartType: String, Sendable, Equatable, CaseIterable {
/// Vertical bar chart best for categorical comparisons.
case bar
/// Line chart best for time series or ordered sequences.
case line
/// Pie/donut chart best for proportional breakdowns with few categories.
case pie
}
/// A recommendation for how to chart a `DataTable`.
public struct ChartRecommendation: Sendable, Equatable {
/// The recommended chart type.
public let chartType: ChartType
/// The column to use for the category axis (x-axis / labels).
public let categoryColumn: String
/// The column to use for the value axis (y-axis / sizes).
public let valueColumn: String
/// Confidence score from 0.0 (guess) to 1.0 (strong match).
public let confidence: Double
/// Human-readable reason for this recommendation.
public let reason: String
public init(
chartType: ChartType,
categoryColumn: String,
valueColumn: String,
confidence: Double,
reason: String
) {
self.chartType = chartType
self.categoryColumn = categoryColumn
self.valueColumn = valueColumn
self.confidence = confidence
self.reason = reason
}
}
// MARK: - Configuration
/// Minimum rows required to consider chart-eligible.
public let minimumRows: Int
/// Maximum rows for a pie chart (too many slices becomes unreadable).
public let maxPieSlices: Int
/// Maximum rows for any chart before it becomes cluttered.
public let maximumRows: Int
// MARK: - Initialization
/// Creates a detector with configurable thresholds.
///
/// - Parameters:
/// - minimumRows: Minimum rows for chart eligibility (default: 2).
/// - maxPieSlices: Maximum categories for pie charts (default: 8).
/// - maximumRows: Maximum rows for any chart (default: 100).
public init(
minimumRows: Int = 2,
maxPieSlices: Int = 8,
maximumRows: Int = 100
) {
self.minimumRows = minimumRows
self.maxPieSlices = maxPieSlices
self.maximumRows = maximumRows
}
// MARK: - Detection
/// Analyzes a `DataTable` and returns a chart recommendation, or `nil`
/// if the data is not suitable for charting.
///
/// - Parameter table: The data table to analyze.
/// - Returns: A recommendation, or `nil` if no chart type fits.
public func detect(_ table: DataTable) -> ChartRecommendation? {
// Must have at least 2 columns (category + value) and enough rows
guard table.columnCount >= 2,
table.rowCount >= minimumRows,
table.rowCount <= maximumRows else {
return nil
}
// Find candidate category and value columns
guard let (categoryCol, valueCol) = findCategoryValuePair(in: table) else {
return nil
}
let chartType = recommendChartType(
table: table,
categoryColumn: categoryCol,
valueColumn: valueCol
)
let confidence = computeConfidence(
table: table,
categoryColumn: categoryCol,
valueColumn: valueCol,
chartType: chartType
)
let reason = describeReason(
chartType: chartType,
categoryColumn: categoryCol,
valueColumn: valueCol,
table: table
)
return ChartRecommendation(
chartType: chartType,
categoryColumn: categoryCol.name,
valueColumn: valueCol.name,
confidence: confidence,
reason: reason
)
}
/// Returns all viable chart recommendations, ranked by confidence.
///
/// - Parameter table: The data table to analyze.
/// - Returns: An array of recommendations sorted by confidence (highest first).
public func allRecommendations(for table: DataTable) -> [ChartRecommendation] {
guard table.columnCount >= 2,
table.rowCount >= minimumRows,
table.rowCount <= maximumRows else {
return []
}
guard let (categoryCol, valueCol) = findCategoryValuePair(in: table) else {
return []
}
return ChartType.allCases.compactMap { chartType in
guard isViable(chartType, table: table, categoryColumn: categoryCol) else {
return nil
}
let confidence = computeConfidence(
table: table,
categoryColumn: categoryCol,
valueColumn: valueCol,
chartType: chartType
)
let reason = describeReason(
chartType: chartType,
categoryColumn: categoryCol,
valueColumn: valueCol,
table: table
)
return ChartRecommendation(
chartType: chartType,
categoryColumn: categoryCol.name,
valueColumn: valueCol.name,
confidence: confidence,
reason: reason
)
}
.sorted { $0.confidence > $1.confidence }
}
// MARK: - Private Helpers
/// Finds the best (category, value) column pair from the table.
private func findCategoryValuePair(
in table: DataTable
) -> (category: DataTable.Column, value: DataTable.Column)? {
let numericColumns = table.columns.filter { isNumeric($0) }
let categoryColumns = table.columns.filter { isCategory($0) }
// Prefer: first text/category column + first numeric column
if let cat = categoryColumns.first, let val = numericColumns.first {
return (cat, val)
}
// Fallback: if all columns are numeric, use first as category, second as value
if numericColumns.count >= 2 {
return (numericColumns[0], numericColumns[1])
}
return nil
}
/// Recommends the single best chart type for the data shape.
private func recommendChartType(
table: DataTable,
categoryColumn: DataTable.Column,
valueColumn: DataTable.Column
) -> ChartType {
// Line: time series or sequential numeric categories (check first strongest signal)
if isTimeSeries(categoryColumn, in: table) || isSequential(categoryColumn, in: table) {
return .line
}
// Pie: small number of categories with all-positive values
// Only when clearly categorical (text labels) and few rows
if table.rowCount <= maxPieSlices,
isCategory(categoryColumn),
isPieCandidate(table: table, valueColumn: valueColumn),
looksProportional(table: table, valueColumn: valueColumn) {
return .pie
}
// Default: bar chart for categorical comparisons
return .bar
}
/// Checks if a chart type is viable for the given data.
private func isViable(
_ chartType: ChartType,
table: DataTable,
categoryColumn: DataTable.Column
) -> Bool {
switch chartType {
case .pie:
return table.rowCount <= maxPieSlices
case .line:
return table.rowCount >= minimumRows
case .bar:
return true
}
}
/// Determines if a column holds numeric data.
private func isNumeric(_ column: DataTable.Column) -> Bool {
switch column.inferredType {
case .integer, .real:
return true
default:
return false
}
}
/// Determines if a column holds categorical (label) data.
private func isCategory(_ column: DataTable.Column) -> Bool {
switch column.inferredType {
case .text, .mixed:
return true
default:
return false
}
}
/// Checks if the value column contains all non-negative values,
/// making it a candidate for pie charts.
private func isPieCandidate(
table: DataTable,
valueColumn: DataTable.Column
) -> Bool {
let values = table.numericValues(forColumn: valueColumn.name)
guard !values.isEmpty else { return false }
// All values must be positive for a meaningful pie chart
return values.allSatisfy { $0 > 0 }
}
/// Heuristic: do values look like they represent parts of a whole?
///
/// Checks for aggregate-like column names (count, total, sum, amount, pct, etc.)
/// or if values sum to a round number suggesting percentages/proportions.
private func looksProportional(
table: DataTable,
valueColumn: DataTable.Column
) -> Bool {
let proportionalNames: Set<String> = ["count", "total", "sum", "amount", "pct",
"percent", "percentage", "share", "proportion",
"quantity", "qty", "num", "number"]
// Split on common separators and check for exact word matches
let lowerName = valueColumn.name.lowercased()
let words = Set(lowerName.split { $0 == "_" || $0 == "-" || $0 == " " }.map(String.init))
if !words.isDisjoint(with: proportionalNames) {
return true
}
// Check if values sum to ~100 (percentages)
let values = table.numericValues(forColumn: valueColumn.name)
let sum = values.reduce(0, +)
if abs(sum - 100.0) < 1.0 {
return true
}
return false
}
/// Heuristic: does the category column look like time-series data?
///
/// Checks for date-like patterns (YYYY, YYYY-MM, YYYY-MM-DD)
/// or common time-related column names.
private func isTimeSeries(_ column: DataTable.Column, in table: DataTable) -> Bool {
let timeNames = ["date", "time", "timestamp", "year", "month", "day",
"week", "quarter", "period", "created_at", "updated_at"]
let lowerName = column.name.lowercased()
if timeNames.contains(where: { lowerName.contains($0) }) {
return true
}
// Check if text values look like dates
if column.inferredType == .text {
let values = table.stringValues(forColumn: column.name)
let datePattern = #/^\d{4}(-\d{2}){0,2}$/#
let matchCount = values.prefix(5).filter { (try? datePattern.wholeMatch(in: $0)) != nil }.count
if matchCount >= 3 {
return true
}
}
return false
}
/// Heuristic: does the category column contain sequential numeric values?
private func isSequential(_ column: DataTable.Column, in table: DataTable) -> Bool {
guard isNumeric(column) else { return false }
let values = table.numericValues(forColumn: column.name)
guard values.count >= 3 else { return false }
// Check if values are monotonically increasing
for i in 1..<values.count {
if values[i] <= values[i - 1] {
return false
}
}
return true
}
/// Computes a confidence score for a specific chart type + data combination.
private func computeConfidence(
table: DataTable,
categoryColumn: DataTable.Column,
valueColumn: DataTable.Column,
chartType: ChartType
) -> Double {
var score = 0.5 // baseline
// Bonus: clear category/value split (text + numeric)
if isCategory(categoryColumn) && isNumeric(valueColumn) {
score += 0.2
}
// Bonus: reasonable row count for the chart type
switch chartType {
case .bar:
if table.rowCount >= 2 && table.rowCount <= 20 {
score += 0.15
}
case .line:
if isTimeSeries(categoryColumn, in: table) {
score += 0.2
} else if isSequential(categoryColumn, in: table) {
score += 0.1
}
case .pie:
if table.rowCount <= maxPieSlices && isPieCandidate(table: table, valueColumn: valueColumn) {
score += 0.2
}
// Penalty: too many slices
if table.rowCount > 5 {
score -= 0.1
}
}
// Bonus: no null values in key columns
let categoryNulls = table.columnValues(named: categoryColumn.name).filter(\.isNull).count
let valueNulls = table.columnValues(named: valueColumn.name).filter(\.isNull).count
if categoryNulls == 0 && valueNulls == 0 {
score += 0.1
}
return min(max(score, 0.0), 1.0)
}
/// Generates a human-readable reason for the recommendation.
private func describeReason(
chartType: ChartType,
categoryColumn: DataTable.Column,
valueColumn: DataTable.Column,
table: DataTable
) -> String {
switch chartType {
case .bar:
return "\(table.rowCount) categories comparing \(valueColumn.name) by \(categoryColumn.name)"
case .line:
if isTimeSeries(categoryColumn, in: table) {
return "\(valueColumn.name) over time (\(categoryColumn.name))"
}
return "\(valueColumn.name) trend across \(table.rowCount) points"
case .pie:
return "Proportional breakdown of \(valueColumn.name) by \(categoryColumn.name)"
}
}
}

View File

@@ -0,0 +1,255 @@
// DataTable.swift
// SwiftDBAI
//
// Structured table representation for rendering query results
// in SwiftUI table views and charts.
import Foundation
/// A structured, row-column table built from a `QueryResult`.
///
/// `DataTable` provides indexed access to rows and columns, typed column
/// metadata, and convenience methods for extracting data suitable for
/// SwiftUI `Table` views and Swift Charts.
///
/// Usage:
/// ```swift
/// let table = DataTable(queryResult)
/// print(table.columnCount) // 3
/// print(table[row: 0, column: 1]) // .text("Alice")
/// ```
public struct DataTable: Sendable, Equatable {
// MARK: - Column Metadata
/// Metadata for a single column in the data table.
public struct Column: Sendable, Equatable, Identifiable {
/// Stable identifier for the column (same as `name`).
public var id: String { name }
/// Column name from the query result set.
public let name: String
/// Index of this column in the table (0-based).
public let index: Int
/// Inferred data type based on the values in this column.
public let inferredType: InferredType
public init(name: String, index: Int, inferredType: InferredType) {
self.name = name
self.index = index
self.inferredType = inferredType
}
}
/// The inferred data type for a column, determined by inspecting its values.
public enum InferredType: Sendable, Equatable {
/// All non-null values are integers.
case integer
/// All non-null values are numeric (mix of integer and real).
case real
/// All non-null values are text.
case text
/// Values contain blob data.
case blob
/// Column contains only null values or is empty.
case null
/// Values are a mix of incompatible types.
case mixed
}
// MARK: - Row Type
/// A single row in the data table, providing indexed and named access.
public struct Row: Sendable, Equatable, Identifiable {
/// Row index (0-based), used as stable identity.
public let id: Int
/// Values in column order.
public let values: [QueryResult.Value]
/// Column names for named access.
private let columnNames: [String]
public init(id: Int, values: [QueryResult.Value], columnNames: [String]) {
self.id = id
self.values = values
self.columnNames = columnNames
}
/// Access a value by column index.
public subscript(columnIndex: Int) -> QueryResult.Value {
values[columnIndex]
}
/// Access a value by column name. Returns `.null` if the column doesn't exist.
public subscript(columnName: String) -> QueryResult.Value {
guard let idx = columnNames.firstIndex(of: columnName) else {
return .null
}
return values[idx]
}
}
// MARK: - Properties
/// Column metadata in order.
public let columns: [Column]
/// All rows in order.
public let rows: [Row]
/// The SQL that produced this table.
public let sql: String
/// Execution time of the underlying query.
public let executionTime: TimeInterval
/// Number of columns.
public var columnCount: Int { columns.count }
/// Number of rows.
public var rowCount: Int { rows.count }
/// Whether the table has no rows.
public var isEmpty: Bool { rows.isEmpty }
/// Column names in order.
public var columnNames: [String] { columns.map(\.name) }
// MARK: - Initialization
/// Creates a `DataTable` from a `QueryResult`.
///
/// Converts the dictionary-based row representation into an indexed
/// array representation and infers column types from the data.
///
/// - Parameter queryResult: The raw query result to convert.
public init(_ queryResult: QueryResult) {
let colNames = queryResult.columns
// Build indexed rows
let indexedRows: [Row] = queryResult.rows.enumerated().map { idx, rowDict in
let values = colNames.map { col in
rowDict[col] ?? .null
}
return Row(id: idx, values: values, columnNames: colNames)
}
// Infer column types
let inferredColumns: [Column] = colNames.enumerated().map { colIdx, name in
let type = Self.inferType(
from: indexedRows.map { $0.values[colIdx] }
)
return Column(name: name, index: colIdx, inferredType: type)
}
self.columns = inferredColumns
self.rows = indexedRows
self.sql = queryResult.sql
self.executionTime = queryResult.executionTime
}
/// Creates a `DataTable` directly from components (useful for testing).
public init(
columns: [Column],
rows: [Row],
sql: String = "",
executionTime: TimeInterval = 0
) {
self.columns = columns
self.rows = rows
self.sql = sql
self.executionTime = executionTime
}
// MARK: - Subscript Access
/// Access a cell by row and column index.
public subscript(row rowIndex: Int, column columnIndex: Int) -> QueryResult.Value {
rows[rowIndex].values[columnIndex]
}
/// Access a cell by row index and column name.
public subscript(row rowIndex: Int, column columnName: String) -> QueryResult.Value {
rows[rowIndex][columnName]
}
// MARK: - Column Data Extraction
/// Returns all values for a column by index, in row order.
public func columnValues(at index: Int) -> [QueryResult.Value] {
rows.map { $0.values[index] }
}
/// Returns all values for a column by name, in row order.
public func columnValues(named name: String) -> [QueryResult.Value] {
guard let col = columns.first(where: { $0.name == name }) else {
return []
}
return columnValues(at: col.index)
}
/// Returns all non-null `Double` values for a column (useful for charting).
public func numericValues(forColumn name: String) -> [Double] {
columnValues(named: name).compactMap(\.doubleValue)
}
/// Returns all non-null `String` values for a column (useful for labels).
public func stringValues(forColumn name: String) -> [String] {
columnValues(named: name).compactMap { value in
if case .null = value { return nil }
return value.stringValue
}
}
// MARK: - Type Inference
/// Infers the predominant type from an array of values.
static func inferType(from values: [QueryResult.Value]) -> InferredType {
var hasInteger = false
var hasReal = false
var hasText = false
var hasBlob = false
var hasNonNull = false
for value in values {
switch value {
case .integer:
hasInteger = true
hasNonNull = true
case .real:
hasReal = true
hasNonNull = true
case .text:
hasText = true
hasNonNull = true
case .blob:
hasBlob = true
hasNonNull = true
case .null:
break
}
}
guard hasNonNull else { return .null }
// Count how many distinct types are present
let typeCount = [hasInteger, hasReal, hasText, hasBlob].filter { $0 }.count
if typeCount == 1 {
if hasInteger { return .integer }
if hasReal { return .real }
if hasText { return .text }
if hasBlob { return .blob }
}
// Integer + real treat as real (numeric promotion)
if typeCount == 2, hasInteger, hasReal {
return .real
}
return .mixed
}
}

View File

@@ -0,0 +1,301 @@
// TextSummaryRenderer.swift
// SwiftDBAI
//
// Converts raw SQL query results into natural language text summaries
// using the LLM via AnyLanguageModel.
import AnyLanguageModel
import Foundation
/// Renders SQL query results as natural language text summaries.
///
/// The renderer takes a `QueryResult` and the user's original question,
/// sends them to the LLM for summarization, and returns a concise,
/// human-readable response.
///
/// Usage:
/// ```swift
/// let renderer = TextSummaryRenderer(model: myModel)
/// let summary = try await renderer.summarize(
/// result: queryResult,
/// userQuestion: "How many orders were placed last month?"
/// )
/// print(summary) // "There were 42 orders placed last month."
/// ```
public struct TextSummaryRenderer: Sendable {
/// The language model used to generate summaries.
private let model: any LanguageModel
/// Maximum number of rows to include in the LLM prompt.
///
/// Results larger than this are truncated with a note about total count.
public let maxRowsInPrompt: Int
/// Creates a new text summary renderer.
///
/// - Parameters:
/// - model: Any `AnyLanguageModel`-compatible language model.
/// - maxRowsInPrompt: Maximum rows to send to the LLM for summarization (default: 50).
public init(model: any LanguageModel, maxRowsInPrompt: Int = 50) {
self.model = model
self.maxRowsInPrompt = maxRowsInPrompt
}
/// Generates a natural language summary of query results.
///
/// - Parameters:
/// - result: The raw `QueryResult` from SQL execution.
/// - userQuestion: The original natural language question from the user.
/// - context: Optional additional context (e.g., table descriptions) to help the LLM.
/// - Returns: A natural language text summary of the results.
public func summarize(
result: QueryResult,
userQuestion: String,
context: String? = nil
) async throws -> String {
// For mutation results (INSERT/UPDATE/DELETE), use a simple template
if let affected = result.rowsAffected {
return summarizeMutation(result: result, affected: affected)
}
// For empty results, no need to call the LLM
if result.rows.isEmpty {
return "No results found for your query."
}
// For simple aggregates, produce a direct answer without LLM
if let directAnswer = tryDirectAggregateSummary(result: result, userQuestion: userQuestion) {
return directAnswer
}
// Build the prompt and ask the LLM to summarize
let prompt = buildSummarizationPrompt(
result: result,
userQuestion: userQuestion,
context: context
)
let session = LanguageModelSession(
model: model,
instructions: summaryInstructions
)
let response = try await session.respond(to: prompt)
return response.content.trimmingCharacters(in: .whitespacesAndNewlines)
}
/// Generates a summary without calling the LLM, using simple templates.
///
/// Useful when LLM access is unavailable, or for fast local rendering.
///
/// - Parameters:
/// - result: The raw `QueryResult` from SQL execution.
/// - userQuestion: The original natural language question.
/// - Returns: A template-based text summary.
public func localSummary(result: QueryResult, userQuestion: String) -> String {
if let affected = result.rowsAffected {
return summarizeMutation(result: result, affected: affected)
}
if result.rows.isEmpty {
return "No results found for your query."
}
if let directAnswer = tryDirectAggregateSummary(result: result, userQuestion: userQuestion) {
return directAnswer
}
return buildTemplateSummary(result: result)
}
// MARK: - Private Helpers
/// System instructions for the summarization session.
private var summaryInstructions: String {
"""
You are a data assistant that summarizes SQL query results in natural language.
Rules:
- Be concise and direct. Answer the user's question first, then add detail if helpful.
- Use natural language, not SQL or code.
- For numeric results, include the exact numbers.
- For lists of records, summarize the count and highlight notable items.
- If the data contains dates, format them in a readable way.
- Do not mention SQL, databases, tables, columns, or queries in your response.
- Do not include markdown formatting.
- Keep your response under 3 sentences for simple results, under 5 for complex ones.
"""
}
/// Builds the prompt sent to the LLM for summarization.
private func buildSummarizationPrompt(
result: QueryResult,
userQuestion: String,
context: String?
) -> String {
var parts: [String] = []
parts.append("User's question: \(userQuestion)")
if let context {
parts.append("Context: \(context)")
}
parts.append("Query returned \(result.rowCount) row(s) with columns: \(result.columns.joined(separator: ", "))")
// Include the result data (truncated if large)
let dataStr = formatResultData(result)
parts.append("Data:\n\(dataStr)")
parts.append("Summarize these results in natural language, directly answering the user's question.")
return parts.joined(separator: "\n\n")
}
/// Formats the query result data as a compact table for the LLM prompt.
private func formatResultData(_ result: QueryResult) -> String {
let rowsToInclude = Array(result.rows.prefix(maxRowsInPrompt))
var lines: [String] = []
// Header
lines.append(result.columns.joined(separator: " | "))
// Rows
for row in rowsToInclude {
let values = result.columns.map { col in
row[col]?.description ?? "NULL"
}
lines.append(values.joined(separator: " | "))
}
if result.rowCount > maxRowsInPrompt {
lines.append("(\(result.rowCount - maxRowsInPrompt) additional rows not shown)")
}
return lines.joined(separator: "\n")
}
/// Produces a direct answer for simple aggregate queries (1 row, few columns).
private func tryDirectAggregateSummary(result: QueryResult, userQuestion: String) -> String? {
guard result.isAggregate else { return nil }
let row = result.rows[0]
// Single numeric column e.g., "COUNT(*)" "42"
if result.columns.count == 1 {
let col = result.columns[0]
guard let value = row[col] else { return nil }
let formatted = formatNumber(value)
return "The result is \(formatted)."
}
// Multiple aggregate columns e.g., COUNT, AVG, SUM
let parts = result.columns.compactMap { col -> String? in
guard let value = row[col] else { return nil }
let label = humanizeColumnName(col)
let formatted = formatNumber(value)
return "\(label): \(formatted)"
}
return parts.joined(separator: ", ") + "."
}
/// Formats a numeric Value for display.
private func formatNumber(_ value: QueryResult.Value) -> String {
switch value {
case .integer(let i):
return NumberFormatter.localizedString(from: NSNumber(value: i), number: .decimal)
case .real(let d):
if d == d.rounded() && abs(d) < 1e12 {
return NumberFormatter.localizedString(from: NSNumber(value: Int64(d)), number: .decimal)
}
let formatter = NumberFormatter()
formatter.numberStyle = .decimal
formatter.maximumFractionDigits = 2
return formatter.string(from: NSNumber(value: d)) ?? String(d)
default:
return value.description
}
}
/// Converts a column name like "total_count" or "AVG(price)" into a readable label.
private func humanizeColumnName(_ name: String) -> String {
// Handle SQL function names: "COUNT(*)" "count", "AVG(price)" "average price"
let functionPatterns: [(pattern: String, label: String)] = [
("COUNT", "count"),
("SUM", "total"),
("AVG", "average"),
("MIN", "minimum"),
("MAX", "maximum"),
]
let upper = name.uppercased()
for (pattern, label) in functionPatterns {
if upper.hasPrefix(pattern + "(") {
// Extract the inner column name
let start = name.index(name.startIndex, offsetBy: pattern.count + 1)
let end = name.index(before: name.endIndex)
if start < end {
let inner = String(name[start..<end])
if inner == "*" { return label }
return "\(label) \(humanizeColumnName(inner))"
}
return label
}
}
// snake_case space-separated
return name
.replacingOccurrences(of: "_", with: " ")
.lowercased()
}
/// Produces a template-based summary without calling the LLM.
private func buildTemplateSummary(result: QueryResult) -> String {
let count = result.rowCount
if count == 1 {
// Single record list field values
let row = result.rows[0]
let details = result.columns.prefix(5).compactMap { col -> String? in
guard let val = row[col], !val.isNull else { return nil }
return "\(humanizeColumnName(col)): \(val.description)"
}
return "Found 1 result. \(details.joined(separator: ", "))."
}
// Multiple records
var summary = "Found \(count) results"
// If there's a clear "name" or "title" column, list first few
let nameColumns = ["name", "title", "label", "description"]
if let nameCol = result.columns.first(where: { nameColumns.contains($0.lowercased()) }) {
let names = result.rows.prefix(3).compactMap { $0[nameCol]?.description }
if !names.isEmpty {
summary += " including \(names.joined(separator: ", "))"
if count > 3 { summary += ", and \(count - 3) more" }
}
}
return summary + "."
}
/// Summarizes a mutation (INSERT/UPDATE/DELETE) result.
private func summarizeMutation(result: QueryResult, affected: Int) -> String {
let sql = result.sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
let operation: String
if sql.hasPrefix("INSERT") {
operation = "inserted"
} else if sql.hasPrefix("UPDATE") {
operation = "updated"
} else if sql.hasPrefix("DELETE") {
operation = "deleted"
} else {
operation = "affected"
}
let noun = affected == 1 ? "row" : "rows"
return "Successfully \(operation) \(affected) \(noun)."
}
}

View File

@@ -0,0 +1,164 @@
// DatabaseSchema.swift
// SwiftDBAI
//
// Auto-introspected SQLite schema model types.
import Foundation
/// Complete schema representation of an SQLite database.
public struct DatabaseSchema: Sendable, Equatable {
/// All tables in the database, keyed by table name.
public let tables: [String: TableSchema]
/// Ordered table names (preserves discovery order).
public let tableNames: [String]
/// Returns a compact text description suitable for LLM system prompts.
public var schemaDescription: String {
var lines: [String] = []
for name in tableNames {
guard let table = tables[name] else { continue }
lines.append(table.descriptionForLLM)
}
return lines.joined(separator: "\n\n")
}
/// Returns a description suitable for LLM system prompts.
/// Alias for `schemaDescription` for API compatibility.
public func describeForLLM() -> String {
schemaDescription
}
public init(tables: [String: TableSchema], tableNames: [String]) {
self.tables = tables
self.tableNames = tableNames
}
}
/// Schema for a single SQLite table.
public struct TableSchema: Sendable, Equatable {
public let name: String
public let columns: [ColumnSchema]
public let primaryKey: [String]
public let foreignKeys: [ForeignKeySchema]
public let indexes: [IndexSchema]
/// Text description for embedding in LLM prompts.
public var descriptionForLLM: String {
var parts: [String] = []
let colDefs = columns.map { col in
var def = " \(col.name) \(col.type)"
if col.isPrimaryKey { def += " PRIMARY KEY" }
if col.isNotNull { def += " NOT NULL" }
if let defaultValue = col.defaultValue { def += " DEFAULT \(defaultValue)" }
return def
}
parts.append("TABLE \(name) (\n\(colDefs.joined(separator: ",\n"))\n)")
if !foreignKeys.isEmpty {
let fkDescs = foreignKeys.map {
" FOREIGN KEY (\($0.fromColumn)) REFERENCES \($0.toTable)(\($0.toColumn))"
}
parts.append("FOREIGN KEYS:\n\(fkDescs.joined(separator: "\n"))")
}
if !indexes.isEmpty {
let idxDescs = indexes.map {
" INDEX \($0.name) ON (\($0.columns.joined(separator: ", ")))\($0.isUnique ? " UNIQUE" : "")"
}
parts.append("INDEXES:\n\(idxDescs.joined(separator: "\n"))")
}
return parts.joined(separator: "\n")
}
public init(
name: String,
columns: [ColumnSchema],
primaryKey: [String],
foreignKeys: [ForeignKeySchema],
indexes: [IndexSchema]
) {
self.name = name
self.columns = columns
self.primaryKey = primaryKey
self.foreignKeys = foreignKeys
self.indexes = indexes
}
}
/// Schema for a single column.
public struct ColumnSchema: Sendable, Equatable {
/// Column position (0-based).
public let cid: Int
/// Column name.
public let name: String
/// Declared SQLite type (e.g. "TEXT", "INTEGER", "REAL", "BLOB").
public let type: String
/// Whether the column has a NOT NULL constraint.
public let isNotNull: Bool
/// Default value expression, if any.
public let defaultValue: String?
/// Whether this column is part of the primary key.
public let isPrimaryKey: Bool
public init(
cid: Int,
name: String,
type: String,
isNotNull: Bool,
defaultValue: String?,
isPrimaryKey: Bool
) {
self.cid = cid
self.name = name
self.type = type
self.isNotNull = isNotNull
self.defaultValue = defaultValue
self.isPrimaryKey = isPrimaryKey
}
}
/// Schema for a foreign key relationship.
public struct ForeignKeySchema: Sendable, Equatable {
/// Column in the source table.
public let fromColumn: String
/// Referenced table name.
public let toTable: String
/// Referenced column name.
public let toColumn: String
/// ON UPDATE action (e.g. "CASCADE", "NO ACTION").
public let onUpdate: String
/// ON DELETE action.
public let onDelete: String
public init(
fromColumn: String,
toTable: String,
toColumn: String,
onUpdate: String,
onDelete: String
) {
self.fromColumn = fromColumn
self.toTable = toTable
self.toColumn = toColumn
self.onUpdate = onUpdate
self.onDelete = onDelete
}
}
/// Schema for a database index.
public struct IndexSchema: Sendable, Equatable {
/// Index name.
public let name: String
/// Whether the index enforces uniqueness.
public let isUnique: Bool
/// Columns included in the index, in order.
public let columns: [String]
public init(name: String, isUnique: Bool, columns: [String]) {
self.name = name
self.isUnique = isUnique
self.columns = columns
}
}

View File

@@ -0,0 +1,153 @@
// SchemaIntrospector.swift
// SwiftDBAI
//
// Auto-introspects SQLite database schema using GRDB.
import GRDB
/// Introspects an SQLite database schema by querying sqlite_master and PRAGMA statements.
///
/// Usage:
/// ```swift
/// let dbPool = try DatabasePool(path: "path/to/db.sqlite")
/// let schema = try await SchemaIntrospector.introspect(database: dbPool)
/// print(schema.schemaDescription)
/// ```
public struct SchemaIntrospector: Sendable {
// MARK: - Public API
/// Introspects the full schema of the given database.
///
/// Discovers all user tables (excluding sqlite_ internal tables),
/// their columns, primary keys, foreign keys, and indexes.
///
/// - Parameter database: A GRDB `DatabaseReader` (DatabasePool or DatabaseQueue).
/// - Returns: A complete `DatabaseSchema` representation.
public static func introspect(database: any DatabaseReader) async throws -> DatabaseSchema {
try await database.read { db in
try introspect(db: db)
}
}
/// Synchronous introspection within an existing database access context.
///
/// - Parameter db: A GRDB `Database` instance from within a read/write block.
/// - Returns: A complete `DatabaseSchema` representation.
public static func introspect(db: Database) throws -> DatabaseSchema {
let tableNames = try fetchTableNames(db: db)
var tables: [String: TableSchema] = [:]
for tableName in tableNames {
let columns = try fetchColumns(db: db, table: tableName)
let primaryKey = try fetchPrimaryKey(db: db, table: tableName)
let foreignKeys = try fetchForeignKeys(db: db, table: tableName)
let indexes = try fetchIndexes(db: db, table: tableName)
// Mark columns that are part of the primary key
let pkSet = Set(primaryKey)
let annotatedColumns = columns.map { col in
ColumnSchema(
cid: col.cid,
name: col.name,
type: col.type,
isNotNull: col.isNotNull,
defaultValue: col.defaultValue,
isPrimaryKey: pkSet.contains(col.name)
)
}
tables[tableName] = TableSchema(
name: tableName,
columns: annotatedColumns,
primaryKey: primaryKey,
foreignKeys: foreignKeys,
indexes: indexes
)
}
return DatabaseSchema(tables: tables, tableNames: tableNames)
}
// MARK: - Private Helpers
/// Fetches all user table names from sqlite_master.
private static func fetchTableNames(db: Database) throws -> [String] {
let sql = """
SELECT name FROM sqlite_master
WHERE type = 'table'
AND name NOT LIKE 'sqlite_%'
ORDER BY name
"""
return try String.fetchAll(db, sql: sql)
}
/// Fetches column metadata for a table using PRAGMA table_info.
private static func fetchColumns(db: Database, table: String) throws -> [ColumnSchema] {
let sql = "PRAGMA table_info(\(table.quotedDatabaseIdentifier))"
let rows = try Row.fetchAll(db, sql: sql)
return rows.map { row in
ColumnSchema(
cid: row["cid"],
name: row["name"],
type: (row["type"] as String?) ?? "",
isNotNull: row["notnull"] == 1,
defaultValue: row["dflt_value"],
isPrimaryKey: row["pk"] != 0
)
}
}
/// Fetches primary key columns for a table.
private static func fetchPrimaryKey(db: Database, table: String) throws -> [String] {
let sql = "PRAGMA table_info(\(table.quotedDatabaseIdentifier))"
let rows = try Row.fetchAll(db, sql: sql)
return rows
.filter { ($0["pk"] as Int) > 0 }
.sorted { ($0["pk"] as Int) < ($1["pk"] as Int) }
.map { $0["name"] }
}
/// Fetches foreign key relationships for a table.
private static func fetchForeignKeys(db: Database, table: String) throws -> [ForeignKeySchema] {
let sql = "PRAGMA foreign_key_list(\(table.quotedDatabaseIdentifier))"
let rows = try Row.fetchAll(db, sql: sql)
return rows.map { row in
ForeignKeySchema(
fromColumn: row["from"],
toTable: row["table"],
toColumn: row["to"],
onUpdate: row["on_update"] ?? "NO ACTION",
onDelete: row["on_delete"] ?? "NO ACTION"
)
}
}
/// Fetches indexes and their columns for a table.
private static func fetchIndexes(db: Database, table: String) throws -> [IndexSchema] {
let indexListSQL = "PRAGMA index_list(\(table.quotedDatabaseIdentifier))"
let indexRows = try Row.fetchAll(db, sql: indexListSQL)
var indexes: [IndexSchema] = []
for indexRow in indexRows {
let indexName: String = indexRow["name"]
let isUnique: Bool = indexRow["unique"] == 1
// Skip auto-generated indexes for primary keys
if indexName.hasPrefix("sqlite_autoindex_") { continue }
let infoSQL = "PRAGMA index_info(\(indexName.quotedDatabaseIdentifier))"
let infoRows = try Row.fetchAll(db, sql: infoSQL)
let columns: [String] = infoRows
.sorted { ($0["seqno"] as Int) < ($1["seqno"] as Int) }
.map { $0["name"] }
indexes.append(IndexSchema(
name: indexName,
isUnique: isUnique,
columns: columns
))
}
return indexes
}
}

View File

@@ -0,0 +1,215 @@
// SwiftDBAIError.swift
// SwiftDBAI
//
// Unified error type for the SwiftDBAI package.
import Foundation
/// The top-level error type for SwiftDBAI operations.
///
/// `SwiftDBAIError` provides a single, typed error surface that covers
/// every failure mode a consumer of SwiftDBAI may encounter from invalid
/// SQL and LLM failures to schema mismatches and safety violations.
///
/// Every case includes a user-friendly `localizedDescription` suitable for
/// displaying directly in a chat interface.
public enum SwiftDBAIError: Error, LocalizedError, Sendable, Equatable {
// MARK: - SQL Errors
/// No SQL statement could be extracted from the LLM response.
case noSQLGenerated
/// The generated SQL is syntactically invalid or failed execution.
case invalidSQL(sql: String, reason: String)
/// The SQL uses an operation (e.g. DELETE) not in the developer's allowlist.
case operationNotAllowed(operation: String)
/// Multiple SQL statements were generated but only single-statement execution is supported.
case multipleStatementsNotSupported
/// A dangerous SQL keyword (DROP, ALTER, TRUNCATE) was detected.
case dangerousOperationBlocked(keyword: String)
// MARK: - LLM Errors
/// The LLM failed to produce a response.
case llmFailure(reason: String)
/// The LLM response could not be parsed into an actionable result.
case llmResponseUnparseable(response: String)
/// The LLM request timed out.
case llmTimeout(seconds: TimeInterval)
// MARK: - Schema Errors
/// Schema introspection of the database failed.
case schemaIntrospectionFailed(reason: String)
/// The generated SQL references a table that does not exist in the schema.
case tableNotFound(tableName: String)
/// The generated SQL references a column that does not exist on the given table.
case columnNotFound(columnName: String, tableName: String)
/// The database schema is empty (no user tables found).
case emptySchema
// MARK: - Safety & Validation Errors
/// A destructive operation requires explicit user confirmation before execution.
case confirmationRequired(sql: String, operation: String)
/// A mutation targets a table not in the allowed mutation tables.
case tableNotAllowedForMutation(tableName: String, operation: String)
/// A custom query validator rejected the query.
case queryRejected(reason: String)
// MARK: - Database Errors
/// The underlying database operation failed.
case databaseError(reason: String)
/// The query exceeded the configured execution timeout.
case queryTimedOut(seconds: TimeInterval)
// MARK: - Configuration Errors
/// The engine has not been configured correctly.
case configurationError(reason: String)
// MARK: - Error Classification
/// Whether this error represents a safety/permissions issue (not a bug).
public var isSafetyError: Bool {
switch self {
case .operationNotAllowed, .dangerousOperationBlocked,
.confirmationRequired, .tableNotAllowedForMutation, .queryRejected:
return true
default:
return false
}
}
/// Whether this error is recoverable by rephrasing the user's question.
public var isRecoverable: Bool {
switch self {
case .noSQLGenerated, .llmResponseUnparseable, .invalidSQL,
.tableNotFound, .columnNotFound:
return true
default:
return false
}
}
/// Whether this error requires user action (e.g. confirmation).
public var requiresUserAction: Bool {
if case .confirmationRequired = self { return true }
return false
}
// MARK: - LocalizedError
public var errorDescription: String? {
switch self {
// SQL
case .noSQLGenerated:
return "I couldn't generate a SQL query from your request. Could you rephrase your question?"
case .invalidSQL(let sql, let reason):
return "The generated query is invalid — \(reason). Query: \(sql)"
case .operationNotAllowed(let operation):
return "The \(operation.uppercased()) operation is not allowed by the current configuration."
case .multipleStatementsNotSupported:
return "Only single SQL statements are supported. Please ask one question at a time."
case .dangerousOperationBlocked(let keyword):
return "The \(keyword.uppercased()) operation is blocked for safety. This operation is never allowed."
// LLM
case .llmFailure(let reason):
return "The language model encountered an error: \(reason)"
case .llmResponseUnparseable(let response):
return "I received a response but couldn't understand it. Raw response: \(response.prefix(200))"
case .llmTimeout(let seconds):
return "The language model did not respond within \(Int(seconds)) seconds. Please try again."
// Schema
case .schemaIntrospectionFailed(let reason):
return "Failed to read the database schema: \(reason)"
case .tableNotFound(let tableName):
return "The table '\(tableName)' does not exist in this database."
case .columnNotFound(let columnName, let tableName):
return "The column '\(columnName)' does not exist on table '\(tableName)'."
case .emptySchema:
return "This database has no tables. There's nothing to query yet."
// Safety
case .confirmationRequired(let sql, let operation):
return "The \(operation.uppercased()) operation requires your confirmation before running: \(sql)"
case .tableNotAllowedForMutation(let tableName, let operation):
return "The \(operation.uppercased()) operation is not allowed on table '\(tableName)'."
case .queryRejected(let reason):
return "Query rejected: \(reason)"
// Database
case .databaseError(let reason):
return "A database error occurred: \(reason)"
case .queryTimedOut(let seconds):
return "The query timed out after \(Int(seconds)) seconds. Try a simpler query."
// Configuration
case .configurationError(let reason):
return "Configuration error: \(reason)"
}
}
}
// MARK: - Conversion from SQLParsingError
extension SQLParsingError {
/// Maps a ``SQLParsingError`` to the corresponding ``SwiftDBAIError`` case.
///
/// - Parameter rawResponse: The raw LLM response text (used for context in `.noSQLFound`).
/// - Returns: A ``SwiftDBAIError`` with the same semantic meaning.
func toSwiftDBAIError(rawResponse: String = "") -> SwiftDBAIError {
switch self {
case .noSQLFound:
if rawResponse.isEmpty {
return .noSQLGenerated
}
return .llmResponseUnparseable(response: rawResponse)
case .operationNotAllowed(let op):
return .operationNotAllowed(operation: op.rawValue)
case .confirmationRequired(let sql, let op):
return .confirmationRequired(sql: sql, operation: op.rawValue)
case .tableNotAllowed(let table, let op):
return .tableNotAllowedForMutation(tableName: table, operation: op.rawValue)
case .dangerousOperation(let keyword):
return .dangerousOperationBlocked(keyword: keyword)
case .multipleStatements:
return .multipleStatementsNotSupported
}
}
}
// MARK: - Conversion from ChatEngineError
extension ChatEngineError {
/// Maps a ``ChatEngineError`` to the corresponding ``SwiftDBAIError`` case.
func toSwiftDBAIError() -> SwiftDBAIError {
switch self {
case .sqlParsingFailed(let parsingError):
return parsingError.toSwiftDBAIError()
case .confirmationRequired(let sql, let operation):
return .confirmationRequired(sql: sql, operation: operation.rawValue)
case .schemaIntrospectionFailed(let reason):
return .schemaIntrospectionFailed(reason: reason)
case .queryTimedOut(let seconds):
return .queryTimedOut(seconds: seconds)
case .validationFailed(let reason):
return .queryRejected(reason: reason)
}
}
}

View File

@@ -0,0 +1,182 @@
// BarChartView.swift
// SwiftDBAI
//
// A SwiftUI bar chart that renders DataTable values using Swift Charts.
// Best for categorical comparisons (e.g., sales by region, counts by status).
import SwiftUI
import Charts
/// A bar chart view that renders a `DataTable` column pair using Swift Charts.
///
/// Displays vertical bars with category labels on the x-axis and numeric
/// values on the y-axis. Automatically colors bars using the accent gradient
/// and supports scrolling when many categories are present.
///
/// Usage:
/// ```swift
/// BarChartView(
/// dataTable: table,
/// categoryColumn: "department",
/// valueColumn: "total_sales"
/// )
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct BarChartView: View {
/// The data to chart.
public let dataTable: DataTable
/// Column name for category labels (x-axis).
public let categoryColumn: String
/// Column name for numeric values (y-axis).
public let valueColumn: String
/// Optional chart title.
public var title: String?
/// Maximum number of bars to display before truncating.
public var maxBars: Int
public init(
dataTable: DataTable,
categoryColumn: String,
valueColumn: String,
title: String? = nil,
maxBars: Int = 30
) {
self.dataTable = dataTable
self.categoryColumn = categoryColumn
self.valueColumn = valueColumn
self.title = title
self.maxBars = maxBars
}
public var body: some View {
VStack(alignment: .leading, spacing: 8) {
if let title {
Text(title)
.font(.caption.weight(.semibold))
.foregroundStyle(.secondary)
}
if chartData.isEmpty {
emptyChartView
} else {
chartContent
}
}
}
// MARK: - Chart Content
@ViewBuilder
private var chartContent: some View {
Chart(chartData, id: \.label) { item in
BarMark(
x: .value(categoryColumn, item.label),
y: .value(valueColumn, item.value)
)
.foregroundStyle(
.linearGradient(
colors: [.accentColor, .accentColor.opacity(0.7)],
startPoint: .bottom,
endPoint: .top
)
)
.cornerRadius(4)
}
.chartXAxis {
AxisMarks(values: .automatic) { _ in
AxisValueLabel()
.font(.caption2)
}
}
.chartYAxis {
AxisMarks(position: .leading) { _ in
AxisGridLine(stroke: StrokeStyle(lineWidth: 0.5, dash: [4, 4]))
.foregroundStyle(.secondary.opacity(0.3))
AxisValueLabel()
.font(.caption2)
}
}
.frame(minHeight: 200)
if isTruncated {
truncationNotice
}
}
// MARK: - Empty State
@ViewBuilder
private var emptyChartView: some View {
VStack(spacing: 8) {
Image(systemName: "chart.bar")
.font(.title2)
.foregroundStyle(.secondary)
Text("No chartable data")
.font(.caption)
.foregroundStyle(.secondary)
}
.frame(maxWidth: .infinity, minHeight: 100)
}
// MARK: - Truncation Notice
@ViewBuilder
private var truncationNotice: some View {
Text("Showing \(maxBars) of \(dataTable.rowCount) categories")
.font(.caption2)
.foregroundStyle(.secondary)
}
// MARK: - Data Extraction
private var isTruncated: Bool {
dataTable.rowCount > maxBars
}
private var chartData: [ChartDataPoint] {
let labels = dataTable.stringValues(forColumn: categoryColumn)
let values = dataTable.numericValues(forColumn: valueColumn)
let count = min(labels.count, values.count, maxBars)
guard count > 0 else { return [] }
return (0..<count).map { i in
ChartDataPoint(label: labels[i], value: values[i])
}
}
}
// MARK: - Preview
#if DEBUG
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Bar Chart") {
let columns: [DataTable.Column] = [
.init(name: "department", index: 0, inferredType: .text),
.init(name: "revenue", index: 1, inferredType: .real),
]
let departments = ["Engineering", "Sales", "Marketing", "Support", "Design"]
let rows: [DataTable.Row] = departments.enumerated().map { i, dept in
DataTable.Row(
id: i,
values: [.text(dept), .real(Double.random(in: 50_000...200_000))],
columnNames: ["department", "revenue"]
)
}
let table = DataTable(columns: columns, rows: rows)
BarChartView(
dataTable: table,
categoryColumn: "department",
valueColumn: "revenue",
title: "Revenue by Department"
)
.padding()
.frame(height: 300)
}
#endif

View File

@@ -0,0 +1,21 @@
// ChartDataPoint.swift
// SwiftDBAI
//
// Shared data model used by all chart views.
import Foundation
/// A single data point for chart rendering.
///
/// Pairs a string label (category) with a numeric value.
/// Used as the common data format across BarChartView,
/// LineChartView, and PieChartView.
struct ChartDataPoint: Sendable, Identifiable {
var id: String { label }
/// The category label (x-axis or slice label).
let label: String
/// The numeric value (y-axis or slice size).
let value: Double
}

View File

@@ -0,0 +1,135 @@
// ChartResultView.swift
// SwiftDBAI
//
// Auto-selecting chart view that uses ChartDataDetector to pick the
// best chart type for a given DataTable.
import SwiftUI
import Charts
/// A chart view that automatically selects the best chart type for a `DataTable`.
///
/// Uses `ChartDataDetector` to analyze the data shape and renders the
/// appropriate chart (bar, line, or pie). If the data isn't suitable for
/// charting, the view renders nothing.
///
/// Usage:
/// ```swift
/// ChartResultView(dataTable: myTable)
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct ChartResultView: View {
/// The data table to chart.
public let dataTable: DataTable
/// Optional override: force a specific chart type.
public var chartType: ChartDataDetector.ChartType?
/// The detector used to analyze the data.
private let detector: ChartDataDetector
public init(
dataTable: DataTable,
chartType: ChartDataDetector.ChartType? = nil,
detector: ChartDataDetector = ChartDataDetector()
) {
self.dataTable = dataTable
self.chartType = chartType
self.detector = detector
}
public var body: some View {
if let recommendation = resolvedRecommendation {
VStack(alignment: .leading, spacing: 4) {
chartView(for: recommendation)
Text(recommendation.reason)
.font(.caption2)
.foregroundStyle(.tertiary)
}
}
}
// MARK: - Chart Selection
@ViewBuilder
private func chartView(
for recommendation: ChartDataDetector.ChartRecommendation
) -> some View {
switch recommendation.chartType {
case .bar:
BarChartView(
dataTable: dataTable,
categoryColumn: recommendation.categoryColumn,
valueColumn: recommendation.valueColumn
)
case .line:
LineChartView(
dataTable: dataTable,
categoryColumn: recommendation.categoryColumn,
valueColumn: recommendation.valueColumn
)
case .pie:
PieChartView(
dataTable: dataTable,
categoryColumn: recommendation.categoryColumn,
valueColumn: recommendation.valueColumn
)
}
}
// MARK: - Resolution
/// Resolves the chart recommendation, using the override type if provided.
private var resolvedRecommendation: ChartDataDetector.ChartRecommendation? {
if let override = chartType {
// Use forced chart type still need column pair from detector
let all = detector.allRecommendations(for: dataTable)
// Try to find recommendation for the forced type
if let match = all.first(where: { $0.chartType == override }) {
return match
}
// Fallback: use first recommendation and override its type
if let first = all.first {
return ChartDataDetector.ChartRecommendation(
chartType: override,
categoryColumn: first.categoryColumn,
valueColumn: first.valueColumn,
confidence: first.confidence * 0.8,
reason: first.reason
)
}
return nil
}
// Auto-detect best chart type
return detector.detect(dataTable)
}
}
// MARK: - Preview
#if DEBUG
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Auto Chart — Bar") {
let columns: [DataTable.Column] = [
.init(name: "city", index: 0, inferredType: .text),
.init(name: "population", index: 1, inferredType: .integer),
]
let cities = ["NYC", "LA", "Chicago", "Houston", "Phoenix"]
let pops: [Int64] = [8_336_817, 3_979_576, 2_693_976, 2_320_268, 1_680_992]
let rows: [DataTable.Row] = cities.enumerated().map { i, city in
DataTable.Row(
id: i,
values: [.text(city), .integer(pops[i])],
columnNames: ["city", "population"]
)
}
let table = DataTable(columns: columns, rows: rows)
ChartResultView(dataTable: table)
.padding()
.frame(height: 300)
}
#endif

View File

@@ -0,0 +1,206 @@
// LineChartView.swift
// SwiftDBAI
//
// A SwiftUI line chart that renders DataTable values using Swift Charts.
// Best for time series or sequential data (e.g., revenue over months).
import SwiftUI
import Charts
/// A line chart view that renders a `DataTable` column pair using Swift Charts.
///
/// Displays a connected line with optional area fill, point markers,
/// and smooth interpolation. Best suited for time series or sequential data.
///
/// Usage:
/// ```swift
/// LineChartView(
/// dataTable: table,
/// categoryColumn: "month",
/// valueColumn: "revenue"
/// )
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct LineChartView: View {
/// The data to chart.
public let dataTable: DataTable
/// Column name for category/time labels (x-axis).
public let categoryColumn: String
/// Column name for numeric values (y-axis).
public let valueColumn: String
/// Optional chart title.
public var title: String?
/// Whether to show an area fill below the line.
public var showAreaFill: Bool
/// Whether to show point markers at each data point.
public var showPoints: Bool
/// Maximum data points to display.
public var maxPoints: Int
public init(
dataTable: DataTable,
categoryColumn: String,
valueColumn: String,
title: String? = nil,
showAreaFill: Bool = true,
showPoints: Bool = true,
maxPoints: Int = 100
) {
self.dataTable = dataTable
self.categoryColumn = categoryColumn
self.valueColumn = valueColumn
self.title = title
self.showAreaFill = showAreaFill
self.showPoints = showPoints
self.maxPoints = maxPoints
}
public var body: some View {
VStack(alignment: .leading, spacing: 8) {
if let title {
Text(title)
.font(.caption.weight(.semibold))
.foregroundStyle(.secondary)
}
if chartData.isEmpty {
emptyChartView
} else {
chartContent
}
}
}
// MARK: - Chart Content
@ViewBuilder
private var chartContent: some View {
Chart(chartData, id: \.label) { item in
LineMark(
x: .value(categoryColumn, item.label),
y: .value(valueColumn, item.value)
)
.foregroundStyle(Color.accentColor)
.lineStyle(StrokeStyle(lineWidth: 2))
.interpolationMethod(.catmullRom)
if showAreaFill {
AreaMark(
x: .value(categoryColumn, item.label),
y: .value(valueColumn, item.value)
)
.foregroundStyle(
.linearGradient(
colors: [
Color.accentColor.opacity(0.2),
Color.accentColor.opacity(0.02),
],
startPoint: .top,
endPoint: .bottom
)
)
.interpolationMethod(.catmullRom)
}
if showPoints {
PointMark(
x: .value(categoryColumn, item.label),
y: .value(valueColumn, item.value)
)
.foregroundStyle(Color.accentColor)
.symbolSize(30)
}
}
.chartXAxis {
AxisMarks(values: .automatic) { _ in
AxisValueLabel()
.font(.caption2)
}
}
.chartYAxis {
AxisMarks(position: .leading) { _ in
AxisGridLine(stroke: StrokeStyle(lineWidth: 0.5, dash: [4, 4]))
.foregroundStyle(.secondary.opacity(0.3))
AxisValueLabel()
.font(.caption2)
}
}
.frame(minHeight: 200)
if isTruncated {
Text("Showing \(maxPoints) of \(dataTable.rowCount) data points")
.font(.caption2)
.foregroundStyle(.secondary)
}
}
// MARK: - Empty State
@ViewBuilder
private var emptyChartView: some View {
VStack(spacing: 8) {
Image(systemName: "chart.xyaxis.line")
.font(.title2)
.foregroundStyle(.secondary)
Text("No chartable data")
.font(.caption)
.foregroundStyle(.secondary)
}
.frame(maxWidth: .infinity, minHeight: 100)
}
// MARK: - Data Extraction
private var isTruncated: Bool {
dataTable.rowCount > maxPoints
}
private var chartData: [ChartDataPoint] {
let labels = dataTable.stringValues(forColumn: categoryColumn)
let values = dataTable.numericValues(forColumn: valueColumn)
let count = min(labels.count, values.count, maxPoints)
guard count > 0 else { return [] }
return (0..<count).map { i in
ChartDataPoint(label: labels[i], value: values[i])
}
}
}
// MARK: - Preview
#if DEBUG
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Line Chart") {
let columns: [DataTable.Column] = [
.init(name: "month", index: 0, inferredType: .text),
.init(name: "revenue", index: 1, inferredType: .real),
]
let months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun"]
let rows: [DataTable.Row] = months.enumerated().map { i, month in
DataTable.Row(
id: i,
values: [.text(month), .real(Double(i + 1) * 15_000 + Double.random(in: -3000...3000))],
columnNames: ["month", "revenue"]
)
}
let table = DataTable(columns: columns, rows: rows)
LineChartView(
dataTable: table,
categoryColumn: "month",
valueColumn: "revenue",
title: "Monthly Revenue"
)
.padding()
.frame(height: 300)
}
#endif

View File

@@ -0,0 +1,234 @@
// PieChartView.swift
// SwiftDBAI
//
// A SwiftUI pie/donut chart that renders DataTable values using Swift Charts.
// Best for proportional breakdowns with few categories (e.g., market share).
import SwiftUI
import Charts
/// A pie chart view that renders a `DataTable` column pair using Swift Charts.
///
/// Displays proportional slices with category labels. Each slice is
/// automatically colored from a curated palette and sized relative to
/// its proportion of the total. Best suited for data with few categories
/// ( 8) where all values are positive.
///
/// Usage:
/// ```swift
/// PieChartView(
/// dataTable: table,
/// categoryColumn: "status",
/// valueColumn: "count"
/// )
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct PieChartView: View {
/// The data to chart.
public let dataTable: DataTable
/// Column name for category labels (slice labels).
public let categoryColumn: String
/// Column name for numeric values (slice sizes).
public let valueColumn: String
/// Optional chart title.
public var title: String?
/// Inner radius ratio for donut style (0 = full pie, >0 = donut).
public var innerRadiusRatio: CGFloat
/// Maximum number of slices before grouping remaining into "Other".
public var maxSlices: Int
public init(
dataTable: DataTable,
categoryColumn: String,
valueColumn: String,
title: String? = nil,
innerRadiusRatio: CGFloat = 0.4,
maxSlices: Int = 8
) {
self.dataTable = dataTable
self.categoryColumn = categoryColumn
self.valueColumn = valueColumn
self.title = title
self.innerRadiusRatio = innerRadiusRatio
self.maxSlices = maxSlices
}
public var body: some View {
VStack(alignment: .leading, spacing: 8) {
if let title {
Text(title)
.font(.caption.weight(.semibold))
.foregroundStyle(.secondary)
}
if chartData.isEmpty {
emptyChartView
} else {
HStack(alignment: .center, spacing: 16) {
chartContent
legendView
}
}
}
}
// MARK: - Chart Content
@ViewBuilder
private var chartContent: some View {
Chart(chartData, id: \.label) { item in
SectorMark(
angle: .value(valueColumn, item.value),
innerRadius: .ratio(innerRadiusRatio),
angularInset: 1.5
)
.foregroundStyle(by: .value(categoryColumn, item.label))
.cornerRadius(3)
}
.chartForegroundStyleScale(
domain: chartData.map(\.label),
range: sliceColors
)
.chartLegend(.hidden)
.frame(minWidth: 150, minHeight: 150)
.aspectRatio(1, contentMode: .fit)
}
// MARK: - Legend
@ViewBuilder
private var legendView: some View {
VStack(alignment: .leading, spacing: 6) {
ForEach(Array(chartData.enumerated()), id: \.element.label) { index, item in
HStack(spacing: 8) {
Circle()
.fill(sliceColors[index % sliceColors.count])
.frame(width: 8, height: 8)
Text(item.label)
.font(.caption)
.foregroundStyle(.primary)
.lineLimit(1)
Spacer()
Text(percentageText(for: item.value))
.font(.caption)
.foregroundStyle(.secondary)
.monospacedDigit()
}
}
}
}
// MARK: - Empty State
@ViewBuilder
private var emptyChartView: some View {
VStack(spacing: 8) {
Image(systemName: "chart.pie")
.font(.title2)
.foregroundStyle(.secondary)
Text("No chartable data")
.font(.caption)
.foregroundStyle(.secondary)
}
.frame(maxWidth: .infinity, minHeight: 100)
}
// MARK: - Colors
/// Curated color palette for pie slices.
private var sliceColors: [Color] {
[
.blue,
.green,
.orange,
.purple,
.pink,
.cyan,
.yellow,
.indigo,
.mint,
.teal,
]
}
// MARK: - Helpers
private var total: Double {
chartData.reduce(0) { $0 + $1.value }
}
private func percentageText(for value: Double) -> String {
guard total > 0 else { return "0%" }
let pct = (value / total) * 100
if pct >= 10 {
return String(format: "%.0f%%", pct)
}
return String(format: "%.1f%%", pct)
}
// MARK: - Data Extraction
private var chartData: [ChartDataPoint] {
let labels = dataTable.stringValues(forColumn: categoryColumn)
let values = dataTable.numericValues(forColumn: valueColumn)
let count = min(labels.count, values.count)
guard count > 0 else { return [] }
// Build all points, sorted by value descending
var points = (0..<count).map { i in
ChartDataPoint(label: labels[i], value: values[i])
}
.filter { $0.value > 0 }
.sorted { $0.value > $1.value }
// Group excess slices into "Other"
if points.count > maxSlices {
let kept = Array(points.prefix(maxSlices - 1))
let otherValue = points.dropFirst(maxSlices - 1).reduce(0) { $0 + $1.value }
points = kept + [ChartDataPoint(label: "Other", value: otherValue)]
}
return points
}
}
// MARK: - Preview
#if DEBUG
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Pie Chart") {
let columns: [DataTable.Column] = [
.init(name: "status", index: 0, inferredType: .text),
.init(name: "count", index: 1, inferredType: .integer),
]
let statuses = ["Active", "Inactive", "Pending", "Archived"]
let counts: [Int64] = [45, 20, 15, 10]
let rows: [DataTable.Row] = statuses.enumerated().map { i, status in
DataTable.Row(
id: i,
values: [.text(status), .integer(counts[i])],
columnNames: ["status", "count"]
)
}
let table = DataTable(columns: columns, rows: rows)
PieChartView(
dataTable: table,
categoryColumn: "status",
valueColumn: "count",
title: "Users by Status"
)
.padding()
.frame(height: 250)
}
#endif

View File

@@ -0,0 +1,214 @@
// ChatView.swift
// SwiftDBAI
//
// Drop-in SwiftUI view for chatting with a SQLite database.
// Renders messages with automatic data table display for query results.
import SwiftUI
/// A drop-in SwiftUI chat interface for querying SQLite databases
/// with natural language.
///
/// `ChatView` renders the full conversation including:
/// - User messages (right-aligned, accent-colored)
/// - Assistant responses with text summaries
/// - **Automatic data tables** via `ScrollableDataTableView` when query results
/// contain tabular data (rows + columns)
/// - SQL query disclosure for transparency
/// - Error messages with red styling
/// - A loading indicator while the engine is processing
///
/// Usage:
/// ```swift
/// let engine = ChatEngine(database: myPool, model: myModel)
/// let viewModel = ChatViewModel(engine: engine)
///
/// ChatView(viewModel: viewModel)
/// ```
///
/// Or use the convenience initializer:
/// ```swift
/// ChatView(engine: myEngine)
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct ChatView: View {
@Bindable private var viewModel: ChatViewModel
@State private var inputText: String = ""
@FocusState private var isInputFocused: Bool
/// Creates a ChatView with an existing view model.
///
/// - Parameter viewModel: The `ChatViewModel` driving this view.
public init(viewModel: ChatViewModel) {
self.viewModel = viewModel
}
/// Creates a ChatView with a `ChatEngine`, automatically creating
/// a `ChatViewModel`.
///
/// - Parameter engine: The `ChatEngine` to power the chat.
public init(engine: ChatEngine) {
self.viewModel = ChatViewModel(engine: engine)
}
public var body: some View {
VStack(spacing: 0) {
messageList
Divider()
inputBar
}
}
// MARK: - Message List
@ViewBuilder
private var messageList: some View {
ScrollViewReader { proxy in
ScrollView {
LazyVStack(spacing: 12) {
if viewModel.messages.isEmpty {
emptyState
}
ForEach(viewModel.messages) { message in
messageBubble(for: message)
.id(message.id)
}
if viewModel.isLoading {
loadingIndicator
}
}
.padding(.horizontal, 16)
.padding(.vertical, 12)
}
.onChange(of: viewModel.messages.count) { _, _ in
if let lastMessage = viewModel.messages.last {
withAnimation(.easeOut(duration: 0.3)) {
proxy.scrollTo(lastMessage.id, anchor: .bottom)
}
}
}
}
}
// MARK: - Empty State
@ViewBuilder
private var emptyState: some View {
VStack(spacing: 12) {
Image(systemName: "bubble.left.and.text.bubble.right")
.font(.system(size: 40))
.foregroundStyle(.tertiary)
Text("Ask a question about your data")
.font(.headline)
.foregroundStyle(.secondary)
Text("Try something like \"How many records are in the database?\"")
.font(.subheadline)
.foregroundStyle(.tertiary)
.multilineTextAlignment(.center)
}
.frame(maxWidth: .infinity)
.padding(.vertical, 60)
}
// MARK: - Loading Indicator
@ViewBuilder
private var loadingIndicator: some View {
HStack(alignment: .top) {
HStack(spacing: 8) {
ProgressView()
.controlSize(.small)
Text("Querying…")
.font(.callout)
.foregroundStyle(.secondary)
}
.padding(.horizontal, 14)
.padding(.vertical, 10)
.background(
Self.assistantBackgroundColor,
in: RoundedRectangle(cornerRadius: 16, style: .continuous)
)
Spacer(minLength: 48)
}
.id("loading-indicator")
.transition(.opacity.combined(with: .move(edge: .bottom)))
}
private static var assistantBackgroundColor: Color {
#if os(macOS)
Color(nsColor: .controlBackgroundColor)
#else
Color(uiColor: .secondarySystemGroupedBackground)
#endif
}
// MARK: - Input Bar
@ViewBuilder
private var inputBar: some View {
HStack(spacing: 8) {
TextField("Ask about your data…", text: $inputText, axis: .vertical)
.textFieldStyle(.plain)
.lineLimit(1...5)
.focused($isInputFocused)
.onSubmit { sendMessage() }
.submitLabel(.send)
Button(action: sendMessage) {
Image(systemName: "arrow.up.circle.fill")
.font(.title2)
.foregroundStyle(canSend ? Color.accentColor : Color.secondary)
}
.disabled(!canSend)
.keyboardShortcut(.return, modifiers: .command)
}
.padding(.horizontal, 16)
.padding(.vertical, 10)
}
// MARK: - Message Bubble
@ViewBuilder
private func messageBubble(for message: ChatMessage) -> some View {
if message.role == .error {
MessageBubbleView(
message: message,
onRetry: makeRetryAction(for: message)
)
} else {
MessageBubbleView(message: message)
}
}
private func makeRetryAction(for errorMessage: ChatMessage) -> @Sendable () async -> Void {
let vm = viewModel
let messageId = errorMessage.id
return { @MainActor [vm] in
let allMessages = await MainActor.run { vm.messages }
if let lastUserMessage = allMessages
.prefix(while: { $0.id != messageId })
.last(where: { $0.role == .user }) {
await vm.send(lastUserMessage.content)
}
}
}
// MARK: - Helpers
private var canSend: Bool {
!inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty && !viewModel.isLoading
}
private func sendMessage() {
guard canSend else { return }
let text = inputText
inputText = ""
Task {
await viewModel.send(text)
}
}
}

View File

@@ -0,0 +1,137 @@
// ChatViewModel.swift
// SwiftDBAI
//
// Observable view model that bridges ChatEngine with the SwiftUI ChatView.
import Foundation
import Observation
/// The readiness state of the schema introspection.
public enum SchemaReadiness: Sendable, Equatable {
/// Schema has not been loaded yet.
case idle
/// Schema introspection is in progress.
case loading
/// Schema is ready with the given number of tables.
case ready(tableCount: Int)
/// Schema introspection failed.
case failed(String)
public var isReady: Bool {
if case .ready = self { return true }
return false
}
}
/// Observable view model that drives the `ChatView`.
///
/// Wraps `ChatEngine` to provide reactive state updates for the SwiftUI layer.
/// Manages the message list, loading state, error presentation, and schema
/// readiness. Call ``prepare()`` at view-appear time to eagerly introspect the
/// database schema.
///
/// Usage:
/// ```swift
/// let viewModel = ChatViewModel(engine: myChatEngine)
/// ChatView(viewModel: viewModel)
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
@Observable
@MainActor
public final class ChatViewModel {
// MARK: - Public State
/// All messages in the conversation, in chronological order.
public private(set) var messages: [ChatMessage] = []
/// Whether the engine is currently processing a request.
public private(set) var isLoading: Bool = false
/// The most recent error message, if any. Cleared on next send.
public private(set) var errorMessage: String?
/// Current schema readiness state.
public private(set) var schemaReadiness: SchemaReadiness = .idle
// MARK: - Dependencies
private let engine: ChatEngine
// MARK: - Initialization
/// Creates a new ChatViewModel.
///
/// - Parameter engine: The `ChatEngine` to use for processing messages.
public init(engine: ChatEngine) {
self.engine = engine
}
// MARK: - Schema Preparation
/// Eagerly introspects the database schema so it's ready before the first query.
///
/// This should be called from a `.task` modifier on the view. It transitions
/// `schemaReadiness` through `.loading` `.ready` (or `.failed`).
/// If the schema is already cached, this completes immediately.
public func prepare() async {
// Don't re-prepare if already ready
if schemaReadiness.isReady { return }
schemaReadiness = .loading
do {
let schema = try await engine.prepareSchema()
schemaReadiness = .ready(tableCount: schema.tableNames.count)
} catch {
schemaReadiness = .failed(error.localizedDescription)
}
}
// MARK: - Public API
/// Sends a user message and appends the response to the conversation.
///
/// - Parameter text: The natural language message from the user.
public func send(_ text: String) async {
let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines)
guard !trimmed.isEmpty else { return }
errorMessage = nil
// Add user message immediately
let userMessage = ChatMessage(role: .user, content: trimmed)
messages.append(userMessage)
isLoading = true
defer { isLoading = false }
do {
let response = try await engine.send(trimmed)
let assistantMessage = ChatMessage(
role: .assistant,
content: response.summary,
queryResult: response.queryResult,
sql: response.sql
)
messages.append(assistantMessage)
} catch {
let typedError = (error as? SwiftDBAIError)
let errorMsg = ChatMessage(
role: .error,
content: error.localizedDescription,
error: typedError
)
messages.append(errorMsg)
errorMessage = error.localizedDescription
}
}
/// Clears the conversation and resets the engine state.
public func reset() {
messages.removeAll()
errorMessage = nil
engine.reset()
}
}

View File

@@ -0,0 +1,220 @@
// DataChatView.swift
// SwiftDBAI
//
// Zero-config SwiftUI view: provide a database path and a model, get a chat UI.
import AnyLanguageModel
import GRDB
import SwiftUI
/// A convenience SwiftUI view that wraps the full chat-with-database stack.
///
/// `DataChatView` is the simplest entry point into SwiftDBAI. It requires only
/// a database file path and a language model no schema files, no annotations,
/// no manual setup. The view creates a GRDB connection, a `ChatEngine`,
/// a `ChatViewModel`, and renders a fully functional `ChatView`.
///
/// Usage with just a path and model:
/// ```swift
/// DataChatView(
/// databasePath: "/path/to/mydata.sqlite",
/// model: OllamaLanguageModel(model: "llama3")
/// )
/// ```
///
/// Usage with additional configuration:
/// ```swift
/// DataChatView(
/// databasePath: documentsURL.appendingPathComponent("app.db").path,
/// model: OpenAILanguageModel(apiKey: key),
/// allowlist: .standard,
/// additionalContext: "This database stores a recipe app's data."
/// )
/// ```
///
/// If you already have a GRDB `DatabasePool` or `DatabaseQueue`, use
/// `ChatView` with a `ChatEngine` directly for full control.
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct DataChatView: View {
@State private var viewModel: ChatViewModel
@State private var loadError: DataChatError?
/// Creates a DataChatView from a database file path and language model.
///
/// This is the zero-config convenience initializer. It opens a GRDB
/// `DatabasePool` at the given path, creates a `ChatEngine` with
/// read-only defaults, and wires up the full chat UI.
///
/// - Parameters:
/// - databasePath: Absolute path to a SQLite database file.
/// - model: Any `AnyLanguageModel`-compatible language model instance.
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly` (SELECT only).
/// - additionalContext: Optional extra context about the database for the LLM system prompt
/// (e.g., "This database stores e-commerce orders and products.").
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
public init(
databasePath: String,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
maxSummaryRows: Int = 50
) {
do {
let pool = try DatabasePool(path: databasePath)
let engine = ChatEngine(
database: pool,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
maxSummaryRows: maxSummaryRows
)
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
self._loadError = State(initialValue: nil)
} catch {
// If the database can't be opened, create a placeholder engine
// and store the error to display in the UI.
let queue = try! DatabaseQueue()
let engine = ChatEngine(
database: queue,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
maxSummaryRows: maxSummaryRows
)
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
self._loadError = State(initialValue: DataChatError.databaseOpenFailed(
path: databasePath,
underlying: error
))
}
}
/// Creates a DataChatView from an existing GRDB database connection and language model.
///
/// Use this initializer when you already have a configured `DatabasePool` or
/// `DatabaseQueue` and want the convenience of `DataChatView` without
/// creating a `ChatEngine` yourself.
///
/// - Parameters:
/// - database: A GRDB `DatabaseWriter` (`DatabasePool` or `DatabaseQueue`).
/// - model: Any `AnyLanguageModel`-compatible language model instance.
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly`.
/// - additionalContext: Optional extra context about the database for the LLM.
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
public init(
database: any DatabaseWriter,
model: any LanguageModel,
allowlist: OperationAllowlist = .readOnly,
additionalContext: String? = nil,
maxSummaryRows: Int = 50
) {
let engine = ChatEngine(
database: database,
model: model,
allowlist: allowlist,
additionalContext: additionalContext,
maxSummaryRows: maxSummaryRows
)
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
self._loadError = State(initialValue: nil)
}
public var body: some View {
if let error = loadError {
errorView(error)
} else {
ChatView(viewModel: viewModel)
.task {
await viewModel.prepare()
}
.overlay {
if case .loading = viewModel.schemaReadiness {
schemaLoadingView
}
if case .failed(let reason) = viewModel.schemaReadiness {
schemaErrorView(reason)
}
}
}
}
// MARK: - Schema Loading View
@ViewBuilder
private var schemaLoadingView: some View {
VStack(spacing: 16) {
ProgressView()
.controlSize(.large)
Text("Introspecting database schema…")
.font(.subheadline)
.foregroundStyle(.secondary)
}
.frame(maxWidth: .infinity, maxHeight: .infinity)
.background(.ultraThinMaterial)
}
// MARK: - Schema Error View
@ViewBuilder
private func schemaErrorView(_ reason: String) -> some View {
VStack(spacing: 16) {
Image(systemName: "exclamationmark.triangle.fill")
.font(.system(size: 40))
.foregroundStyle(.orange)
Text("Schema Introspection Failed")
.font(.headline)
Text(reason)
.font(.subheadline)
.foregroundStyle(.secondary)
.multilineTextAlignment(.center)
.padding(.horizontal, 32)
Button("Retry") {
Task {
await viewModel.prepare()
}
}
.buttonStyle(.borderedProminent)
}
.frame(maxWidth: .infinity, maxHeight: .infinity)
.background(.ultraThinMaterial)
}
// MARK: - Database Open Error View
@ViewBuilder
private func errorView(_ error: DataChatError) -> some View {
VStack(spacing: 16) {
Image(systemName: "exclamationmark.triangle.fill")
.font(.system(size: 40))
.foregroundStyle(.red)
Text("Unable to Open Database")
.font(.headline)
Text(error.localizedDescription)
.font(.subheadline)
.foregroundStyle(.secondary)
.multilineTextAlignment(.center)
.padding(.horizontal, 32)
}
.frame(maxWidth: .infinity, maxHeight: .infinity)
}
}
// MARK: - Errors
/// Errors specific to `DataChatView` initialization.
public enum DataChatError: Error, LocalizedError, Sendable {
/// The database file could not be opened at the given path.
case databaseOpenFailed(path: String, underlying: any Error)
public var errorDescription: String? {
switch self {
case .databaseOpenFailed(let path, let underlying):
return "Could not open database at \"\(path)\": \(underlying.localizedDescription)"
}
}
}

View File

@@ -0,0 +1,360 @@
// ErrorMessageView.swift
// SwiftDBAI
//
// Reusable SwiftUI component that renders error messages with contextual
// icons, descriptions, and optional retry actions based on the error type.
import SwiftUI
/// A reusable SwiftUI component that renders a ``SwiftDBAIError`` with an
/// appropriate icon, human-readable message, and optional retry action.
///
/// The view automatically selects a visual treatment based on the error
/// category:
///
/// | Category | Icon | Color | Retry? |
/// |-------------------|-------------------------------|---------|--------|
/// | Safety / blocked | `shield.trianglebadge.excl` | Orange | No |
/// | Confirmation | `hand.raised.fill` | Yellow | Yes* |
/// | LLM failure | `brain` | Purple | Yes |
/// | Schema / DB | `cylinder.split.1x2` | Red | No |
/// | Recoverable SQL | `arrow.clockwise` | Blue | Yes |
/// | Generic | `exclamationmark.triangle` | Red | No |
///
/// *Confirmation retry triggers the confirm callback, not a standard retry.
///
/// Usage:
/// ```swift
/// ErrorMessageView(
/// error: .llmTimeout(seconds: 30),
/// onRetry: { /* resend the message */ }
/// )
/// ```
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
public struct ErrorMessageView: View {
/// The error to display. When `nil`, the view falls back to the raw message.
private let error: SwiftDBAIError?
/// The raw error message string (used as fallback when error is nil).
private let message: String
/// Called when the user taps the retry button. `nil` hides the button.
private let onRetry: (@Sendable () async -> Void)?
/// Called when the user confirms a destructive operation.
private let onConfirm: (@Sendable () async -> Void)?
@State private var isRetrying = false
// MARK: - Initializers
/// Creates an ErrorMessageView from a typed ``SwiftDBAIError``.
///
/// - Parameters:
/// - error: The ``SwiftDBAIError`` to display.
/// - onRetry: An optional async closure invoked when the user taps retry.
/// - onConfirm: An optional async closure invoked when the user confirms
/// a destructive operation (only relevant for `.confirmationRequired`).
public init(
error: SwiftDBAIError,
onRetry: (@Sendable () async -> Void)? = nil,
onConfirm: (@Sendable () async -> Void)? = nil
) {
self.error = error
self.message = error.localizedDescription
self.onRetry = onRetry
self.onConfirm = onConfirm
}
/// Creates an ErrorMessageView from a ``ChatMessage``.
///
/// Extracts the typed error if available, otherwise falls back to the
/// message content string.
///
/// - Parameters:
/// - message: The chat message with role `.error`.
/// - onRetry: An optional async closure invoked when the user taps retry.
/// - onConfirm: An optional async closure invoked when the user confirms
/// a destructive operation.
public init(
chatMessage: ChatMessage,
onRetry: (@Sendable () async -> Void)? = nil,
onConfirm: (@Sendable () async -> Void)? = nil
) {
self.error = chatMessage.error
self.message = chatMessage.content
self.onRetry = onRetry
self.onConfirm = onConfirm
}
/// Creates an ErrorMessageView from a plain string (untyped fallback).
///
/// - Parameters:
/// - message: The error message string.
/// - onRetry: An optional async closure invoked when the user taps retry.
public init(
message: String,
onRetry: (@Sendable () async -> Void)? = nil
) {
self.error = nil
self.message = message
self.onRetry = onRetry
self.onConfirm = nil
}
// MARK: - Body
public var body: some View {
VStack(alignment: .leading, spacing: 10) {
// Icon + message row
HStack(alignment: .firstTextBaseline, spacing: 8) {
Image(systemName: iconName)
.foregroundStyle(iconColor)
.font(.callout)
.accessibilityHidden(true)
VStack(alignment: .leading, spacing: 4) {
if let title = errorTitle {
Text(title)
.font(.callout.weight(.semibold))
.foregroundStyle(iconColor)
}
Text(message)
.font(.body)
.foregroundStyle(.primary)
.textSelection(.enabled)
.fixedSize(horizontal: false, vertical: true)
if let hint = recoveryHint {
Text(hint)
.font(.caption)
.foregroundStyle(.secondary)
.fixedSize(horizontal: false, vertical: true)
}
}
}
// Action buttons
if showRetryButton || showConfirmButton {
HStack(spacing: 12) {
if showConfirmButton {
confirmButton
}
if showRetryButton {
retryButton
}
}
.padding(.leading, 26) // Align with text (icon width + spacing)
}
}
.accessibilityElement(children: .combine)
.accessibilityLabel(accessibilityDescription)
}
// MARK: - Action Buttons
@ViewBuilder
private var retryButton: some View {
Button {
guard !isRetrying else { return }
isRetrying = true
Task {
await onRetry?()
isRetrying = false
}
} label: {
HStack(spacing: 4) {
if isRetrying {
ProgressView()
.controlSize(.mini)
} else {
Image(systemName: "arrow.clockwise")
.font(.caption)
}
Text(retryButtonLabel)
.font(.caption.weight(.medium))
}
.padding(.horizontal, 10)
.padding(.vertical, 6)
.background(iconColor.opacity(0.12))
.foregroundStyle(iconColor)
.clipShape(Capsule())
}
.buttonStyle(.plain)
.disabled(isRetrying)
}
@ViewBuilder
private var confirmButton: some View {
Button {
Task {
await onConfirm?()
}
} label: {
HStack(spacing: 4) {
Image(systemName: "checkmark.circle")
.font(.caption)
Text("Confirm")
.font(.caption.weight(.medium))
}
.padding(.horizontal, 10)
.padding(.vertical, 6)
.background(Color.orange.opacity(0.12))
.foregroundStyle(.orange)
.clipShape(Capsule())
}
.buttonStyle(.plain)
}
// MARK: - Error Classification
private var errorCategory: ErrorCategory {
guard let error else { return .generic }
if error.requiresUserAction {
return .confirmation
}
if error.isSafetyError {
return .safety
}
if error.isRecoverable {
return .recoverable
}
switch error {
case .llmFailure, .llmResponseUnparseable, .llmTimeout:
return .llm
case .schemaIntrospectionFailed, .emptySchema, .databaseError, .queryTimedOut:
return .database
case .configurationError:
return .configuration
default:
return .generic
}
}
private enum ErrorCategory {
case safety
case confirmation
case llm
case database
case recoverable
case configuration
case generic
}
// MARK: - Visual Properties
private var iconName: String {
switch errorCategory {
case .safety:
return "shield.trianglebadge.exclamationmark.fill"
case .confirmation:
return "hand.raised.fill"
case .llm:
return "brain"
case .database:
return "cylinder.split.1x2"
case .recoverable:
return "arrow.clockwise"
case .configuration:
return "gearshape.triangle.fill"
case .generic:
return "exclamationmark.triangle.fill"
}
}
private var iconColor: Color {
switch errorCategory {
case .safety:
return .orange
case .confirmation:
return .yellow
case .llm:
return .purple
case .database:
return .red
case .recoverable:
return .blue
case .configuration:
return .gray
case .generic:
return .red
}
}
private var errorTitle: String? {
switch errorCategory {
case .safety:
return "Operation Blocked"
case .confirmation:
return "Confirmation Required"
case .llm:
return "AI Provider Error"
case .database:
return "Database Error"
case .recoverable:
return "Query Issue"
case .configuration:
return "Configuration Error"
case .generic:
return nil
}
}
private var recoveryHint: String? {
guard let error else { return nil }
switch error {
case .noSQLGenerated, .llmResponseUnparseable:
return "Try rephrasing your question."
case .tableNotFound:
return "Check that you're referring to an existing table."
case .columnNotFound:
return "Verify the column name matches your schema."
case .invalidSQL:
return "The AI generated an invalid query. Try asking differently."
case .llmTimeout:
return "The AI took too long. Try a simpler question."
case .llmFailure:
return "The AI service may be temporarily unavailable."
case .emptySchema:
return "Add some tables to your database first."
case .queryTimedOut:
return "Try a simpler query or add database indexes."
default:
return nil
}
}
// MARK: - Button Visibility
private var showRetryButton: Bool {
guard onRetry != nil else { return false }
return errorCategory == .recoverable || errorCategory == .llm
}
private var showConfirmButton: Bool {
guard onConfirm != nil else { return false }
return errorCategory == .confirmation
}
private var retryButtonLabel: String {
switch errorCategory {
case .llm:
return "Retry"
case .recoverable:
return "Try Again"
default:
return "Retry"
}
}
// MARK: - Accessibility
private var accessibilityDescription: String {
let prefix = errorTitle.map { "\($0): " } ?? "Error: "
return prefix + message
}
}

View File

@@ -0,0 +1,205 @@
// MessageBubbleView.swift
// SwiftDBAI
//
// Renders a single ChatMessage as a styled bubble with optional
// data table and SQL disclosure for query results.
import SwiftUI
import Charts
/// Renders a single `ChatMessage` in the chat conversation.
///
/// - **User messages** display right-aligned with an accent-colored background
/// and white text, using a continuous rounded rectangle shape.
/// - **Assistant messages** display left-aligned with a secondary background.
/// The natural language text summary is the primary content, rendered with
/// full `.body` font and `.primary` foreground for readability.
/// If the message contains a `queryResult` with tabular data, a
/// `ScrollableDataTableView` is automatically embedded below the summary.
/// An optional SQL disclosure group shows the generated query.
/// - **Error messages** display left-aligned with a red-tinted background
/// and an exclamation mark icon.
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
struct MessageBubbleView: View {
let message: ChatMessage
/// Whether to show the SQL query in a disclosure group.
var showSQL: Bool = true
/// Maximum height for the data table before it scrolls.
var maxTableHeight: CGFloat = 300
/// Called when the user taps "Retry" on a recoverable error.
var onRetry: (@Sendable () async -> Void)?
/// Called when the user confirms a destructive operation.
var onConfirm: (@Sendable () async -> Void)?
var body: some View {
HStack(alignment: .top) {
if message.role == .user { Spacer(minLength: 48) }
bubbleContent
.padding(.horizontal, 14)
.padding(.vertical, 10)
.background(bubbleBackground)
.clipShape(bubbleShape)
if message.role != .user { Spacer(minLength: 48) }
}
}
// MARK: - Bubble Content
@ViewBuilder
private var bubbleContent: some View {
switch message.role {
case .user:
userContent
case .assistant:
assistantContent
case .error:
errorContent
}
}
// MARK: - User Content
@ViewBuilder
private var userContent: some View {
Text(message.content)
.font(.body)
.foregroundStyle(.white)
.textSelection(.enabled)
}
// MARK: - Assistant Content (Text Summary + Data Table + SQL)
@ViewBuilder
private var assistantContent: some View {
VStack(alignment: .leading, spacing: 10) {
// Natural language text summary primary content
Text(message.content)
.font(.body)
.foregroundStyle(.primary)
.textSelection(.enabled)
.fixedSize(horizontal: false, vertical: true)
// Data table automatically shown when queryResult has tabular data
if let queryResult = message.queryResult,
!queryResult.columns.isEmpty,
!queryResult.rows.isEmpty {
dataTableSection(for: queryResult)
}
// SQL disclosure collapsed by default for transparency
if showSQL, let sql = message.sql {
sqlDisclosure(sql: sql)
}
}
}
// MARK: - Error Content
@ViewBuilder
private var errorContent: some View {
ErrorMessageView(
chatMessage: message,
onRetry: onRetry,
onConfirm: onConfirm
)
}
/// Maximum height for the chart section.
var maxChartHeight: CGFloat = 250
/// Whether to show auto-detected charts. Defaults to `true`.
var showCharts: Bool = true
// MARK: - Chart Detection
/// The shared detector used for chart eligibility checks.
private static let chartDetector = ChartDataDetector()
// MARK: - Data Table Section
@ViewBuilder
private func dataTableSection(for queryResult: QueryResult) -> some View {
let dataTable = DataTable(queryResult)
VStack(alignment: .leading, spacing: 8) {
// Chart automatically shown when ChartDataDetector finds eligible data
if showCharts {
chartSection(for: dataTable)
}
Divider()
ScrollableDataTableView(
dataTable: dataTable,
showAlternatingRows: true,
showFooter: true
)
.frame(maxHeight: maxTableHeight)
}
}
// MARK: - Chart Section
@ViewBuilder
private func chartSection(for dataTable: DataTable) -> some View {
let detector = Self.chartDetector
if detector.detect(dataTable) != nil {
VStack(alignment: .leading, spacing: 4) {
ChartResultView(dataTable: dataTable, detector: detector)
.frame(maxHeight: maxChartHeight)
}
}
}
// MARK: - SQL Disclosure
@ViewBuilder
private func sqlDisclosure(sql: String) -> some View {
DisclosureGroup {
Text(sql)
.font(.system(.caption, design: .monospaced))
.foregroundStyle(.secondary)
.textSelection(.enabled)
.padding(8)
.frame(maxWidth: .infinity, alignment: .leading)
.background(Color.primary.opacity(0.04))
.clipShape(RoundedRectangle(cornerRadius: 6))
} label: {
Label("SQL Query", systemImage: "chevron.left.forwardslash.chevron.right")
.font(.caption)
.foregroundStyle(.secondary)
}
}
// MARK: - Styling Helpers
private var bubbleShape: RoundedRectangle {
RoundedRectangle(cornerRadius: 16, style: .continuous)
}
@ViewBuilder
private var bubbleBackground: some View {
switch message.role {
case .user:
Color.accentColor
case .assistant:
Self.assistantBackgroundColor
case .error:
Color.red.opacity(0.1)
}
}
private static var assistantBackgroundColor: Color {
#if os(macOS)
Color(nsColor: .controlBackgroundColor)
#else
Color(uiColor: .secondarySystemGroupedBackground)
#endif
}
}

View File

@@ -0,0 +1,267 @@
// ScrollableDataTableView.swift
// SwiftDBAI
//
// A SwiftUI view that renders a DataTable with horizontal and vertical
// scrolling, styled column headers, and row cells.
import SwiftUI
/// A scrollable table view that renders a `DataTable` with column headers
/// and row cells, supporting both horizontal and vertical scrolling.
///
/// Usage:
/// ```swift
/// ScrollableDataTableView(dataTable: myDataTable)
/// ```
///
/// The view automatically sizes columns based on content, highlights
/// alternating rows for readability, and right-aligns numeric columns.
public struct ScrollableDataTableView: View {
/// The data table to render.
public let dataTable: DataTable
/// Minimum width for each column in points.
public var minimumColumnWidth: CGFloat
/// Maximum width for each column in points.
public var maximumColumnWidth: CGFloat
/// Whether to show alternating row backgrounds.
public var showAlternatingRows: Bool
/// Whether to show the row count footer.
public var showFooter: Bool
public init(
dataTable: DataTable,
minimumColumnWidth: CGFloat = 80,
maximumColumnWidth: CGFloat = 250,
showAlternatingRows: Bool = true,
showFooter: Bool = true
) {
self.dataTable = dataTable
self.minimumColumnWidth = minimumColumnWidth
self.maximumColumnWidth = maximumColumnWidth
self.showAlternatingRows = showAlternatingRows
self.showFooter = showFooter
}
public var body: some View {
if dataTable.isEmpty {
emptyView
} else {
tableContent
}
}
// MARK: - Empty State
@ViewBuilder
private var emptyView: some View {
VStack(spacing: 8) {
Image(systemName: "tablecells")
.font(.largeTitle)
.foregroundStyle(.secondary)
Text("No results")
.font(.headline)
.foregroundStyle(.secondary)
}
.frame(maxWidth: .infinity, minHeight: 100)
}
// MARK: - Table Content
@ViewBuilder
private var tableContent: some View {
VStack(alignment: .leading, spacing: 0) {
ScrollView([.horizontal, .vertical]) {
LazyVStack(alignment: .leading, spacing: 0, pinnedViews: [.sectionHeaders]) {
Section {
ForEach(dataTable.rows) { row in
rowView(row)
}
} header: {
headerRow
}
}
}
if showFooter {
footerView
}
}
}
// MARK: - Header
@ViewBuilder
private var headerRow: some View {
HStack(spacing: 0) {
ForEach(dataTable.columns) { column in
Text(column.name)
.font(.caption.weight(.semibold))
.foregroundStyle(.primary)
.lineLimit(1)
.frame(
width: columnWidth(for: column),
alignment: alignment(for: column)
)
.padding(.horizontal, 8)
.padding(.vertical, 6)
if column.index < dataTable.columnCount - 1 {
Divider()
}
}
}
.background(.bar)
.overlay(alignment: .bottom) {
Divider()
}
}
// MARK: - Row
@ViewBuilder
private func rowView(_ row: DataTable.Row) -> some View {
HStack(spacing: 0) {
ForEach(dataTable.columns) { column in
cellView(value: row[column.index], column: column)
if column.index < dataTable.columnCount - 1 {
Divider()
}
}
}
.background(rowBackground(for: row))
.overlay(alignment: .bottom) {
Divider()
}
}
// MARK: - Cell
@ViewBuilder
private func cellView(value: QueryResult.Value, column: DataTable.Column) -> some View {
Group {
switch value {
case .null:
Text("NULL")
.foregroundStyle(.tertiary)
.italic()
case .blob(let data):
Text("<\(data.count) bytes>")
.foregroundStyle(.secondary)
default:
Text(value.stringValue)
.foregroundStyle(.primary)
}
}
.font(.caption)
.lineLimit(2)
.frame(
width: columnWidth(for: column),
alignment: alignment(for: column)
)
.padding(.horizontal, 8)
.padding(.vertical, 4)
}
// MARK: - Footer
@ViewBuilder
private var footerView: some View {
HStack {
Text("\(dataTable.rowCount) row\(dataTable.rowCount == 1 ? "" : "s")")
.font(.caption2)
.foregroundStyle(.secondary)
Spacer()
if dataTable.executionTime > 0 {
Text(String(format: "%.1f ms", dataTable.executionTime * 1000))
.font(.caption2)
.foregroundStyle(.secondary)
}
}
.padding(.horizontal, 8)
.padding(.vertical, 4)
.background(.bar)
}
// MARK: - Layout Helpers
/// Determines column width based on the column name length and type.
private func columnWidth(for column: DataTable.Column) -> CGFloat {
// Estimate based on header text length
let headerWidth = CGFloat(column.name.count) * 8 + 16
// Sample some row values to estimate content width
let sampleRows = dataTable.rows.prefix(20)
let maxContentWidth = sampleRows.reduce(CGFloat(0)) { maxWidth, row in
let value = row[column.index]
let textLength = CGFloat(value.stringValue.count) * 7
return max(maxWidth, textLength)
}
let estimatedWidth = max(headerWidth, maxContentWidth) + 16
return min(max(estimatedWidth, minimumColumnWidth), maximumColumnWidth)
}
/// Returns the alignment for a column based on its inferred type.
private func alignment(for column: DataTable.Column) -> Alignment {
switch column.inferredType {
case .integer, .real:
return .trailing
default:
return .leading
}
}
/// Returns the background color for alternating rows.
@ViewBuilder
private func rowBackground(for row: DataTable.Row) -> some View {
if showAlternatingRows && row.id.isMultiple(of: 2) {
Color.clear
} else if showAlternatingRows {
Color.primary.opacity(0.03)
} else {
Color.clear
}
}
}
// MARK: - Preview Support
#if DEBUG
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Data Table") {
let columns: [DataTable.Column] = [
.init(name: "id", index: 0, inferredType: .integer),
.init(name: "name", index: 1, inferredType: .text),
.init(name: "score", index: 2, inferredType: .real),
]
let rows: [DataTable.Row] = (0..<25).map { i in
DataTable.Row(
id: i,
values: [
.integer(Int64(i + 1)),
.text("Item \(i + 1)"),
.real(Double.random(in: 1.0...100.0)),
],
columnNames: ["id", "name", "score"]
)
}
let table = DataTable(columns: columns, rows: rows, sql: "SELECT * FROM items", executionTime: 0.023)
ScrollableDataTableView(dataTable: table)
.frame(height: 400)
.padding()
}
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
#Preview("Empty Table") {
let table = DataTable(columns: [], rows: [], sql: "", executionTime: 0)
ScrollableDataTableView(dataTable: table)
.frame(height: 200)
.padding()
}
#endif