Initial implementation of SwiftDBAI
Chat with any SQLite database using natural language. Built on AnyLanguageModel (HuggingFace) for LLM-agnostic provider support and GRDB for SQLite access. Core features: - Auto schema introspection from sqlite_master (zero config) - NL → SQL generation via any AnyLanguageModel provider - Three rendering modes: text summary, data table, Swift Charts - Drop-in DataChatView (SwiftUI) and headless ChatEngine - Operation allowlist with read-only default - Mutation policy with per-table control - ToolExecutionDelegate for destructive operation confirmation - Multi-turn conversation context - 352 tests across 24 suites, all passing Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
113
Sources/SwiftDBAI/Config/ChatEngineConfiguration.swift
Normal file
113
Sources/SwiftDBAI/Config/ChatEngineConfiguration.swift
Normal file
@@ -0,0 +1,113 @@
|
||||
// ChatEngineConfiguration.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Configurable settings for ChatEngine behavior — timeouts, context window,
|
||||
// summary limits, and custom query validation.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Configuration for ``ChatEngine`` behavior.
|
||||
///
|
||||
/// Use this to tune timeouts, conversation context windows, and attach
|
||||
/// custom query validators.
|
||||
///
|
||||
/// ```swift
|
||||
/// var config = ChatEngineConfiguration()
|
||||
/// config.queryTimeout = 10 // 10-second SQL timeout
|
||||
/// config.contextWindowSize = 20 // Keep last 20 messages for LLM context
|
||||
/// config.maxSummaryRows = 100 // Summarize up to 100 rows
|
||||
///
|
||||
/// let engine = ChatEngine(
|
||||
/// database: db,
|
||||
/// model: model,
|
||||
/// configuration: config
|
||||
/// )
|
||||
/// ```
|
||||
public struct ChatEngineConfiguration: Sendable {
|
||||
|
||||
// MARK: - Query Execution
|
||||
|
||||
/// Maximum time (in seconds) to wait for a SQL query to execute.
|
||||
///
|
||||
/// If the query exceeds this duration, a ``ChatEngineError/queryTimedOut``
|
||||
/// error is thrown. Set to `nil` to disable the timeout (not recommended
|
||||
/// for user-facing apps). Defaults to 30 seconds.
|
||||
public var queryTimeout: TimeInterval?
|
||||
|
||||
// MARK: - Conversation Context
|
||||
|
||||
/// Maximum number of conversation messages to include when building
|
||||
/// LLM context for follow-up queries.
|
||||
///
|
||||
/// Only the most recent `contextWindowSize` messages are sent to the LLM.
|
||||
/// Older messages are still retained in ``ChatEngine/messages`` for UI
|
||||
/// display but do not consume LLM tokens.
|
||||
///
|
||||
/// Set to `nil` for unlimited context (all history is always sent).
|
||||
/// Defaults to 50 messages.
|
||||
public var contextWindowSize: Int?
|
||||
|
||||
// MARK: - Rendering
|
||||
|
||||
/// Maximum number of rows to include when generating text summaries.
|
||||
/// Defaults to 50.
|
||||
public var maxSummaryRows: Int
|
||||
|
||||
// MARK: - LLM Context
|
||||
|
||||
/// Optional extra instructions appended to the LLM system prompt.
|
||||
///
|
||||
/// Use this to provide business-specific terminology, query hints,
|
||||
/// or domain constraints. For example:
|
||||
/// ```swift
|
||||
/// config.additionalContext = "The 'status' column uses: 'active', 'inactive', 'suspended'."
|
||||
/// ```
|
||||
public var additionalContext: String?
|
||||
|
||||
// MARK: - Validation
|
||||
|
||||
/// Custom query validators that run after the built-in allowlist check.
|
||||
///
|
||||
/// Use ``addValidator(_:)`` to add validators. They are executed in order;
|
||||
/// the first validator to throw stops execution.
|
||||
public private(set) var validators: [any QueryValidator] = []
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
/// Creates a configuration with the given settings.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - queryTimeout: SQL execution timeout in seconds. Defaults to 30.
|
||||
/// - contextWindowSize: Max messages for LLM context. Defaults to 50.
|
||||
/// - maxSummaryRows: Max rows for text summaries. Defaults to 50.
|
||||
/// - additionalContext: Extra LLM system prompt instructions.
|
||||
public init(
|
||||
queryTimeout: TimeInterval? = 30,
|
||||
contextWindowSize: Int? = 50,
|
||||
maxSummaryRows: Int = 50,
|
||||
additionalContext: String? = nil
|
||||
) {
|
||||
self.queryTimeout = queryTimeout
|
||||
self.contextWindowSize = contextWindowSize
|
||||
self.maxSummaryRows = maxSummaryRows
|
||||
self.additionalContext = additionalContext
|
||||
}
|
||||
|
||||
/// The default configuration: 30s timeout, 50-message context window,
|
||||
/// 50-row summaries, no additional context, no custom validators.
|
||||
public static let `default` = ChatEngineConfiguration()
|
||||
|
||||
// MARK: - Mutating Helpers
|
||||
|
||||
/// Appends a custom query validator.
|
||||
///
|
||||
/// Validators run after the built-in allowlist and dangerous-keyword checks.
|
||||
/// They receive the parsed SQL and can throw to reject a query.
|
||||
///
|
||||
/// ```swift
|
||||
/// config.addValidator(TableAllowlistValidator(allowedTables: ["users", "orders"]))
|
||||
/// ```
|
||||
public mutating func addValidator(_ validator: any QueryValidator) {
|
||||
validators.append(validator)
|
||||
}
|
||||
}
|
||||
336
Sources/SwiftDBAI/Config/LocalProviderConfiguration.swift
Normal file
336
Sources/SwiftDBAI/Config/LocalProviderConfiguration.swift
Normal file
@@ -0,0 +1,336 @@
|
||||
// LocalProviderConfiguration.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Configuration and endpoint discovery for local/self-hosted LLM providers
|
||||
// (Ollama, llama.cpp). Wraps AnyLanguageModel's OllamaLanguageModel and
|
||||
// OpenAILanguageModel with convenient factory methods and health checking.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
|
||||
#if canImport(FoundationNetworking)
|
||||
import FoundationNetworking
|
||||
#endif
|
||||
|
||||
// MARK: - Local Provider Endpoint
|
||||
|
||||
/// Represents a discovered local LLM endpoint with its connection status.
|
||||
public struct LocalProviderEndpoint: Sendable, Equatable {
|
||||
/// The base URL of the local provider.
|
||||
public let baseURL: URL
|
||||
|
||||
/// The provider type (Ollama or llama.cpp).
|
||||
public let providerType: LocalProviderType
|
||||
|
||||
/// Whether the endpoint was reachable at discovery time.
|
||||
public let isReachable: Bool
|
||||
|
||||
/// The list of available models, if the endpoint supports model listing.
|
||||
public let availableModels: [String]
|
||||
|
||||
/// Human-readable description of the endpoint.
|
||||
public var description: String {
|
||||
let status = isReachable ? "reachable" : "unreachable"
|
||||
return "\(providerType.rawValue) at \(baseURL.absoluteString) (\(status), \(availableModels.count) models)"
|
||||
}
|
||||
}
|
||||
|
||||
/// The type of local LLM provider.
|
||||
public enum LocalProviderType: String, Sendable, Hashable, CaseIterable {
|
||||
/// Ollama — runs models locally via `ollama serve`.
|
||||
/// Default endpoint: http://localhost:11434
|
||||
case ollama
|
||||
|
||||
/// llama.cpp server — runs GGUF models via `llama-server`.
|
||||
/// Default endpoint: http://localhost:8080
|
||||
/// Exposes an OpenAI-compatible API.
|
||||
case llamaCpp = "llama.cpp"
|
||||
}
|
||||
|
||||
// MARK: - Local Provider Discovery
|
||||
|
||||
/// Discovers and validates local LLM provider endpoints.
|
||||
///
|
||||
/// Use `LocalProviderDiscovery` to automatically find running Ollama or llama.cpp
|
||||
/// instances on the local machine, check their health, and list available models.
|
||||
///
|
||||
/// ```swift
|
||||
/// // Check if Ollama is running
|
||||
/// let isRunning = await LocalProviderDiscovery.isOllamaRunning()
|
||||
///
|
||||
/// // Discover all local providers
|
||||
/// let endpoints = await LocalProviderDiscovery.discoverAll()
|
||||
/// for endpoint in endpoints where endpoint.isReachable {
|
||||
/// print("Found \(endpoint.description)")
|
||||
/// }
|
||||
///
|
||||
/// // List models available on Ollama
|
||||
/// let models = await LocalProviderDiscovery.listOllamaModels()
|
||||
/// ```
|
||||
public enum LocalProviderDiscovery {
|
||||
|
||||
/// Default Ollama endpoint.
|
||||
public static let defaultOllamaURL = URL(string: "http://localhost:11434")!
|
||||
|
||||
/// Default llama.cpp server endpoint.
|
||||
public static let defaultLlamaCppURL = URL(string: "http://localhost:8080")!
|
||||
|
||||
/// Well-known ports to probe for local providers.
|
||||
/// Ollama: 11434, llama.cpp: 8080
|
||||
private static let wellKnownEndpoints: [(URL, LocalProviderType)] = [
|
||||
(defaultOllamaURL, .ollama),
|
||||
(defaultLlamaCppURL, .llamaCpp),
|
||||
]
|
||||
|
||||
// MARK: - Health Checks
|
||||
|
||||
/// Checks if an Ollama instance is reachable at the given URL.
|
||||
///
|
||||
/// Sends a GET request to the Ollama root endpoint and checks for a 200 response.
|
||||
///
|
||||
/// - Parameter baseURL: The Ollama base URL. Defaults to `http://localhost:11434`.
|
||||
/// - Parameter timeout: Connection timeout in seconds. Defaults to 3.
|
||||
/// - Returns: `true` if the Ollama server responded successfully.
|
||||
public static func isOllamaRunning(
|
||||
at baseURL: URL = defaultOllamaURL,
|
||||
timeout: TimeInterval = 3
|
||||
) async -> Bool {
|
||||
await checkEndpointHealth(baseURL, timeout: timeout)
|
||||
}
|
||||
|
||||
/// Checks if a llama.cpp server is reachable at the given URL.
|
||||
///
|
||||
/// Sends a GET request to the `/health` endpoint and checks for a 200 response.
|
||||
///
|
||||
/// - Parameter baseURL: The llama.cpp base URL. Defaults to `http://localhost:8080`.
|
||||
/// - Parameter timeout: Connection timeout in seconds. Defaults to 3.
|
||||
/// - Returns: `true` if the llama.cpp server responded successfully.
|
||||
public static func isLlamaCppRunning(
|
||||
at baseURL: URL = defaultLlamaCppURL,
|
||||
timeout: TimeInterval = 3
|
||||
) async -> Bool {
|
||||
let healthURL = baseURL.appendingPathComponent("health")
|
||||
return await checkEndpointHealth(healthURL, timeout: timeout)
|
||||
}
|
||||
|
||||
/// Checks if any endpoint at the given URL responds to HTTP requests.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - url: The URL to probe.
|
||||
/// - timeout: Connection timeout in seconds.
|
||||
/// - Returns: `true` if the endpoint returned an HTTP response with status 200.
|
||||
private static func checkEndpointHealth(
|
||||
_ url: URL,
|
||||
timeout: TimeInterval
|
||||
) async -> Bool {
|
||||
let config = URLSessionConfiguration.ephemeral
|
||||
config.timeoutIntervalForRequest = timeout
|
||||
config.timeoutIntervalForResource = timeout
|
||||
let session = URLSession(configuration: config)
|
||||
|
||||
do {
|
||||
let (_, response) = try await session.data(from: url)
|
||||
if let httpResponse = response as? HTTPURLResponse {
|
||||
return httpResponse.statusCode == 200
|
||||
}
|
||||
return false
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Model Listing
|
||||
|
||||
/// Lists models available on an Ollama instance.
|
||||
///
|
||||
/// Calls the Ollama `/api/tags` endpoint to retrieve the list of
|
||||
/// locally installed models.
|
||||
///
|
||||
/// - Parameter baseURL: The Ollama base URL. Defaults to `http://localhost:11434`.
|
||||
/// - Parameter timeout: Request timeout in seconds. Defaults to 5.
|
||||
/// - Returns: An array of model name strings, or an empty array if unreachable.
|
||||
public static func listOllamaModels(
|
||||
at baseURL: URL = defaultOllamaURL,
|
||||
timeout: TimeInterval = 5
|
||||
) async -> [String] {
|
||||
let tagsURL = baseURL.appendingPathComponent("api/tags")
|
||||
let config = URLSessionConfiguration.ephemeral
|
||||
config.timeoutIntervalForRequest = timeout
|
||||
config.timeoutIntervalForResource = timeout
|
||||
let session = URLSession(configuration: config)
|
||||
|
||||
do {
|
||||
let (data, response) = try await session.data(from: tagsURL)
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
httpResponse.statusCode == 200
|
||||
else {
|
||||
return []
|
||||
}
|
||||
|
||||
let decoded = try JSONDecoder().decode(OllamaTagsResponse.self, from: data)
|
||||
return decoded.models.map(\.name)
|
||||
} catch {
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
/// Lists models available on a llama.cpp server via its OpenAI-compatible endpoint.
|
||||
///
|
||||
/// Calls `/v1/models` which llama.cpp exposes when running with
|
||||
/// `--api-key` or in default mode.
|
||||
///
|
||||
/// - Parameter baseURL: The llama.cpp base URL. Defaults to `http://localhost:8080`.
|
||||
/// - Parameter timeout: Request timeout in seconds. Defaults to 5.
|
||||
/// - Returns: An array of model ID strings, or an empty array if unreachable.
|
||||
public static func listLlamaCppModels(
|
||||
at baseURL: URL = defaultLlamaCppURL,
|
||||
timeout: TimeInterval = 5
|
||||
) async -> [String] {
|
||||
let modelsURL = baseURL.appendingPathComponent("v1/models")
|
||||
let config = URLSessionConfiguration.ephemeral
|
||||
config.timeoutIntervalForRequest = timeout
|
||||
config.timeoutIntervalForResource = timeout
|
||||
let session = URLSession(configuration: config)
|
||||
|
||||
do {
|
||||
let (data, response) = try await session.data(from: modelsURL)
|
||||
guard let httpResponse = response as? HTTPURLResponse,
|
||||
httpResponse.statusCode == 200
|
||||
else {
|
||||
return []
|
||||
}
|
||||
|
||||
let decoded = try JSONDecoder().decode(OpenAIModelsResponse.self, from: data)
|
||||
return decoded.data.map(\.id)
|
||||
} catch {
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Full Discovery
|
||||
|
||||
/// Discovers all running local LLM providers by probing well-known endpoints.
|
||||
///
|
||||
/// Probes Ollama (port 11434) and llama.cpp (port 8080) concurrently,
|
||||
/// returning their status and available models.
|
||||
///
|
||||
/// ```swift
|
||||
/// let endpoints = await LocalProviderDiscovery.discoverAll()
|
||||
/// for endpoint in endpoints where endpoint.isReachable {
|
||||
/// print("Found: \(endpoint.description)")
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// - Parameter timeout: Connection timeout per endpoint in seconds. Defaults to 3.
|
||||
/// - Returns: An array of `LocalProviderEndpoint` for each probed location.
|
||||
public static func discoverAll(
|
||||
timeout: TimeInterval = 3
|
||||
) async -> [LocalProviderEndpoint] {
|
||||
await withTaskGroup(of: LocalProviderEndpoint.self, returning: [LocalProviderEndpoint].self) { group in
|
||||
for (url, providerType) in wellKnownEndpoints {
|
||||
group.addTask {
|
||||
await discover(providerType: providerType, at: url, timeout: timeout)
|
||||
}
|
||||
}
|
||||
|
||||
var results: [LocalProviderEndpoint] = []
|
||||
for await endpoint in group {
|
||||
results.append(endpoint)
|
||||
}
|
||||
return results
|
||||
}
|
||||
}
|
||||
|
||||
/// Discovers a specific local provider at the given URL.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - providerType: The type of provider to probe.
|
||||
/// - baseURL: The base URL to check.
|
||||
/// - timeout: Connection timeout in seconds.
|
||||
/// - Returns: A `LocalProviderEndpoint` with reachability and model info.
|
||||
public static func discover(
|
||||
providerType: LocalProviderType,
|
||||
at baseURL: URL,
|
||||
timeout: TimeInterval = 3
|
||||
) async -> LocalProviderEndpoint {
|
||||
switch providerType {
|
||||
case .ollama:
|
||||
let reachable = await isOllamaRunning(at: baseURL, timeout: timeout)
|
||||
let models = reachable ? await listOllamaModels(at: baseURL) : []
|
||||
return LocalProviderEndpoint(
|
||||
baseURL: baseURL,
|
||||
providerType: .ollama,
|
||||
isReachable: reachable,
|
||||
availableModels: models
|
||||
)
|
||||
|
||||
case .llamaCpp:
|
||||
let reachable = await isLlamaCppRunning(at: baseURL, timeout: timeout)
|
||||
let models = reachable ? await listLlamaCppModels(at: baseURL) : []
|
||||
return LocalProviderEndpoint(
|
||||
baseURL: baseURL,
|
||||
providerType: .llamaCpp,
|
||||
isReachable: reachable,
|
||||
availableModels: models
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Discovers a specific local provider at a custom URL and port.
|
||||
///
|
||||
/// Use this for non-standard configurations where Ollama or llama.cpp
|
||||
/// is running on a custom host or port.
|
||||
///
|
||||
/// ```swift
|
||||
/// let endpoint = await LocalProviderDiscovery.discover(
|
||||
/// providerType: .ollama,
|
||||
/// host: "192.168.1.100",
|
||||
/// port: 11434
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - providerType: The provider type.
|
||||
/// - host: The hostname or IP address.
|
||||
/// - port: The port number.
|
||||
/// - timeout: Connection timeout in seconds. Defaults to 3.
|
||||
/// - Returns: A `LocalProviderEndpoint` with reachability and model info.
|
||||
public static func discover(
|
||||
providerType: LocalProviderType,
|
||||
host: String,
|
||||
port: Int,
|
||||
timeout: TimeInterval = 3
|
||||
) async -> LocalProviderEndpoint {
|
||||
guard let url = URL(string: "http://\(host):\(port)") else {
|
||||
return LocalProviderEndpoint(
|
||||
baseURL: URL(string: "http://\(host):\(port)")!,
|
||||
providerType: providerType,
|
||||
isReachable: false,
|
||||
availableModels: []
|
||||
)
|
||||
}
|
||||
return await discover(providerType: providerType, at: url, timeout: timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - JSON Response Types
|
||||
|
||||
/// Response from Ollama's `/api/tags` endpoint.
|
||||
private struct OllamaTagsResponse: Decodable, Sendable {
|
||||
let models: [OllamaModelInfo]
|
||||
}
|
||||
|
||||
/// Individual model info from Ollama's tags endpoint.
|
||||
private struct OllamaModelInfo: Decodable, Sendable {
|
||||
let name: String
|
||||
}
|
||||
|
||||
/// Response from the OpenAI-compatible `/v1/models` endpoint.
|
||||
private struct OpenAIModelsResponse: Decodable, Sendable {
|
||||
let data: [OpenAIModelInfo]
|
||||
}
|
||||
|
||||
/// Individual model info from the OpenAI models endpoint.
|
||||
private struct OpenAIModelInfo: Decodable, Sendable {
|
||||
let id: String
|
||||
}
|
||||
148
Sources/SwiftDBAI/Config/MutationPolicy.swift
Normal file
148
Sources/SwiftDBAI/Config/MutationPolicy.swift
Normal file
@@ -0,0 +1,148 @@
|
||||
// MutationPolicy.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Defines which mutation operations are permitted and optionally restricts
|
||||
// them to specific tables. Wraps OperationAllowlist with table-level granularity.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Controls which SQL mutation operations the LLM may generate and,
|
||||
/// optionally, which tables those mutations may target.
|
||||
///
|
||||
/// `MutationPolicy` builds on ``OperationAllowlist`` by adding per-table
|
||||
/// restrictions. The default policy is **read-only** — no mutations are
|
||||
/// allowed on any table. Write operations require explicit opt-in.
|
||||
///
|
||||
/// ```swift
|
||||
/// // Read-only (default) — only SELECT is allowed
|
||||
/// let readOnly = MutationPolicy.readOnly
|
||||
///
|
||||
/// // Allow INSERT and UPDATE on specific tables only
|
||||
/// let restricted = MutationPolicy(
|
||||
/// allowedOperations: [.insert, .update],
|
||||
/// allowedTables: ["orders", "order_items"]
|
||||
/// )
|
||||
///
|
||||
/// // Allow INSERT and UPDATE on all tables
|
||||
/// let broad = MutationPolicy(allowedOperations: [.insert, .update])
|
||||
///
|
||||
/// // Full access including DELETE (requires confirmation)
|
||||
/// let full = MutationPolicy.unrestricted
|
||||
/// ```
|
||||
public struct MutationPolicy: Sendable, Equatable {
|
||||
|
||||
// MARK: - Properties
|
||||
|
||||
/// The underlying operation allowlist (always includes SELECT).
|
||||
public let operationAllowlist: OperationAllowlist
|
||||
|
||||
/// Optional set of table names that mutations may target.
|
||||
///
|
||||
/// When `nil`, mutations are allowed on all tables (subject to
|
||||
/// ``operationAllowlist``). When non-nil, mutation operations
|
||||
/// (INSERT, UPDATE, DELETE) are only permitted on the listed tables.
|
||||
/// SELECT queries are never restricted by this property.
|
||||
public let allowedMutationTables: Set<String>?
|
||||
|
||||
/// When `true`, destructive operations (DELETE) require explicit user
|
||||
/// confirmation before execution, even when the operation is allowed.
|
||||
/// Defaults to `true`.
|
||||
public let requiresDestructiveConfirmation: Bool
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
/// Creates a mutation policy with the given operations and optional table restrictions.
|
||||
///
|
||||
/// SELECT is always implicitly included — you cannot create a policy
|
||||
/// that disallows reads.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - allowedOperations: The mutation operations to permit (INSERT, UPDATE, DELETE).
|
||||
/// SELECT is always allowed regardless of this parameter.
|
||||
/// - allowedTables: Optional set of table names mutations may target.
|
||||
/// Pass `nil` to allow mutations on all tables. Defaults to `nil`.
|
||||
/// - requiresDestructiveConfirmation: Whether DELETE requires user confirmation.
|
||||
/// Defaults to `true`.
|
||||
public init(
|
||||
allowedOperations: Set<SQLOperation> = [],
|
||||
allowedTables: Set<String>? = nil,
|
||||
requiresDestructiveConfirmation: Bool = true
|
||||
) {
|
||||
// Always include SELECT
|
||||
var ops = allowedOperations
|
||||
ops.insert(.select)
|
||||
self.operationAllowlist = OperationAllowlist(ops)
|
||||
self.allowedMutationTables = allowedTables
|
||||
self.requiresDestructiveConfirmation = requiresDestructiveConfirmation
|
||||
}
|
||||
|
||||
// MARK: - Presets
|
||||
|
||||
/// Read-only policy: only SELECT queries are allowed. This is the default.
|
||||
public static let readOnly = MutationPolicy()
|
||||
|
||||
/// Standard read-write: SELECT, INSERT, and UPDATE on all tables.
|
||||
public static let readWrite = MutationPolicy(
|
||||
allowedOperations: [.insert, .update]
|
||||
)
|
||||
|
||||
/// Unrestricted: all operations including DELETE on all tables.
|
||||
/// DELETE still requires confirmation by default.
|
||||
public static let unrestricted = MutationPolicy(
|
||||
allowedOperations: [.insert, .update, .delete]
|
||||
)
|
||||
|
||||
// MARK: - Validation
|
||||
|
||||
/// Returns `true` if the given operation is permitted by this policy.
|
||||
public func isOperationAllowed(_ operation: SQLOperation) -> Bool {
|
||||
operationAllowlist.isAllowed(operation)
|
||||
}
|
||||
|
||||
/// Returns `true` if the given mutation operation is permitted on the
|
||||
/// specified table.
|
||||
///
|
||||
/// SELECT operations always return `true` regardless of table restrictions.
|
||||
/// For mutation operations, this checks both the operation allowlist and
|
||||
/// the table restrictions (if any).
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - operation: The SQL operation type.
|
||||
/// - table: The target table name (case-insensitive comparison).
|
||||
/// - Returns: Whether the operation is allowed on the given table.
|
||||
public func isAllowed(operation: SQLOperation, on table: String) -> Bool {
|
||||
// SELECT is always allowed
|
||||
guard operation != .select else { return true }
|
||||
|
||||
// Check operation allowlist first
|
||||
guard operationAllowlist.isAllowed(operation) else { return false }
|
||||
|
||||
// If no table restrictions, the operation is allowed
|
||||
guard let allowedTables = allowedMutationTables else { return true }
|
||||
|
||||
// Case-insensitive table name check
|
||||
let lowerTable = table.lowercased()
|
||||
return allowedTables.contains { $0.lowercased() == lowerTable }
|
||||
}
|
||||
|
||||
/// Returns `true` if the given operation requires user confirmation.
|
||||
public func requiresConfirmation(for operation: SQLOperation) -> Bool {
|
||||
operation == .delete && requiresDestructiveConfirmation
|
||||
}
|
||||
|
||||
/// Returns a human-readable description for inclusion in the LLM system prompt.
|
||||
func describeForLLM() -> String {
|
||||
var desc = operationAllowlist.describeForLLM()
|
||||
|
||||
if let tables = allowedMutationTables, !tables.isEmpty {
|
||||
let sorted = tables.sorted()
|
||||
desc += " Mutations (INSERT/UPDATE/DELETE) are restricted to these tables only: \(sorted.joined(separator: ", "))."
|
||||
}
|
||||
|
||||
if requiresDestructiveConfirmation && operationAllowlist.isAllowed(.delete) {
|
||||
desc += " DELETE operations require user confirmation before execution."
|
||||
}
|
||||
|
||||
return desc
|
||||
}
|
||||
}
|
||||
866
Sources/SwiftDBAI/Config/OnDeviceProviderConfiguration.swift
Normal file
866
Sources/SwiftDBAI/Config/OnDeviceProviderConfiguration.swift
Normal file
@@ -0,0 +1,866 @@
|
||||
// OnDeviceProviderConfiguration.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Configuration for on-device LLM providers (CoreML, MLX) that run models
|
||||
// locally on Apple silicon. These providers enable fully offline,
|
||||
// privacy-sensitive deployments where no data leaves the device.
|
||||
//
|
||||
// Both CoreML and MLX models are provided by AnyLanguageModel behind
|
||||
// conditional compilation flags (#if CoreML, #if MLX). This configuration
|
||||
// layer wraps their setup with convenient factory methods and integrates
|
||||
// them into the SwiftDBAI ChatEngine pipeline.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
// MARK: - On-Device Provider Type
|
||||
|
||||
/// The type of on-device LLM provider.
|
||||
public enum OnDeviceProviderType: String, Sendable, Hashable, CaseIterable {
|
||||
/// CoreML — runs compiled .mlmodelc models on-device using Apple's CoreML framework.
|
||||
/// Requires pre-compiled models and supports CPU, GPU, and Neural Engine compute units.
|
||||
case coreML
|
||||
|
||||
/// MLX — runs HuggingFace models on Apple silicon using the MLX framework.
|
||||
/// Models are automatically downloaded and cached. Supports quantized models
|
||||
/// (e.g., 4-bit) for efficient memory usage.
|
||||
case mlx
|
||||
}
|
||||
|
||||
// MARK: - CoreML Configuration
|
||||
|
||||
/// Configuration for loading and running a CoreML language model on-device.
|
||||
///
|
||||
/// CoreML models must be pre-compiled to `.mlmodelc` format before use.
|
||||
/// The model runs entirely on-device using CPU, GPU, and/or Neural Engine
|
||||
/// depending on the `computeUnits` setting.
|
||||
///
|
||||
/// ```swift
|
||||
/// let config = CoreMLProviderConfiguration(
|
||||
/// modelURL: Bundle.main.url(forResource: "MyModel", withExtension: "mlmodelc")!,
|
||||
/// computeUnits: .all
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// - Note: CoreML models are available behind the `#if CoreML` flag in AnyLanguageModel.
|
||||
/// Ensure your project enables the CoreML build condition.
|
||||
public struct CoreMLProviderConfiguration: Sendable, Equatable {
|
||||
|
||||
/// The URL to the compiled CoreML model (`.mlmodelc`).
|
||||
public let modelURL: URL
|
||||
|
||||
/// The compute units to use for inference.
|
||||
///
|
||||
/// - `.all`: Uses the best available hardware (Neural Engine, GPU, CPU).
|
||||
/// - `.cpuOnly`: Forces CPU-only inference. Useful for debugging.
|
||||
/// - `.cpuAndGPU`: Uses CPU and GPU but not the Neural Engine.
|
||||
/// - `.cpuAndNeuralEngine`: Uses CPU and Neural Engine.
|
||||
public let computeUnits: ComputeUnitPreference
|
||||
|
||||
/// Maximum number of tokens the model can generate per response.
|
||||
/// Defaults to 2048.
|
||||
public let maxResponseTokens: Int
|
||||
|
||||
/// Whether to use sampling (true) or greedy decoding (false).
|
||||
/// Defaults to false (greedy) for more deterministic SQL generation.
|
||||
public let useSampling: Bool
|
||||
|
||||
/// Temperature for sampling. Only used when `useSampling` is true.
|
||||
/// Lower values produce more focused output. Defaults to 0.1.
|
||||
public let temperature: Double
|
||||
|
||||
/// Creates a CoreML provider configuration.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - modelURL: The URL to a compiled CoreML model (`.mlmodelc`).
|
||||
/// - computeUnits: The compute units to use. Defaults to `.all`.
|
||||
/// - maxResponseTokens: Maximum tokens per response. Defaults to 2048.
|
||||
/// - useSampling: Whether to use sampling vs greedy decoding. Defaults to false.
|
||||
/// - temperature: Sampling temperature. Defaults to 0.1.
|
||||
public init(
|
||||
modelURL: URL,
|
||||
computeUnits: ComputeUnitPreference = .all,
|
||||
maxResponseTokens: Int = 2048,
|
||||
useSampling: Bool = false,
|
||||
temperature: Double = 0.1
|
||||
) {
|
||||
self.modelURL = modelURL
|
||||
self.computeUnits = computeUnits
|
||||
self.maxResponseTokens = maxResponseTokens
|
||||
self.useSampling = useSampling
|
||||
self.temperature = temperature
|
||||
}
|
||||
|
||||
/// Validates that the model URL points to a compiled CoreML model.
|
||||
///
|
||||
/// - Throws: ``OnDeviceProviderError`` if the URL is invalid.
|
||||
public func validate() throws {
|
||||
guard modelURL.pathExtension == "mlmodelc" else {
|
||||
throw OnDeviceProviderError.invalidModelFormat(
|
||||
expected: ".mlmodelc",
|
||||
actual: modelURL.pathExtension
|
||||
)
|
||||
}
|
||||
|
||||
guard FileManager.default.fileExists(atPath: modelURL.path) else {
|
||||
throw OnDeviceProviderError.modelNotFound(modelURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute unit preference for CoreML inference.
|
||||
///
|
||||
/// Maps to `MLComputeUnits` in the CoreML framework.
|
||||
public enum ComputeUnitPreference: String, Sendable, Hashable, CaseIterable {
|
||||
/// Use all available compute units (Neural Engine, GPU, CPU).
|
||||
/// This is the recommended setting for production use.
|
||||
case all
|
||||
|
||||
/// Force CPU-only execution. Useful for debugging or testing.
|
||||
case cpuOnly
|
||||
|
||||
/// Use CPU and GPU, but not the Neural Engine.
|
||||
case cpuAndGPU
|
||||
|
||||
/// Use CPU and Neural Engine, but not the GPU.
|
||||
case cpuAndNeuralEngine
|
||||
}
|
||||
|
||||
// MARK: - MLX Configuration
|
||||
|
||||
/// Configuration for loading and running an MLX language model on Apple silicon.
|
||||
///
|
||||
/// MLX models are loaded from HuggingFace Hub or a local directory. The MLX
|
||||
/// framework provides efficient inference on Apple silicon with support for
|
||||
/// quantized models (4-bit, 8-bit) for reduced memory usage.
|
||||
///
|
||||
/// ```swift
|
||||
/// // From HuggingFace Hub (auto-downloaded)
|
||||
/// let config = MLXProviderConfiguration(
|
||||
/// modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||
/// )
|
||||
///
|
||||
/// // From a local directory
|
||||
/// let config = MLXProviderConfiguration(
|
||||
/// modelId: "my-local-model",
|
||||
/// localDirectory: URL(fileURLWithPath: "/path/to/model")
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// - Note: MLX models are available behind the `#if MLX` flag in AnyLanguageModel.
|
||||
/// Ensure your project enables the MLX build condition.
|
||||
public struct MLXProviderConfiguration: Sendable, Equatable {
|
||||
|
||||
/// The HuggingFace model identifier (e.g., "mlx-community/Llama-3.2-3B-Instruct-4bit").
|
||||
public let modelId: String
|
||||
|
||||
/// Optional local directory containing the model files.
|
||||
/// When set, the model is loaded from this directory instead of downloading from Hub.
|
||||
public let localDirectory: URL?
|
||||
|
||||
/// GPU memory management configuration.
|
||||
public let gpuMemory: MLXGPUMemoryConfig
|
||||
|
||||
/// Maximum number of tokens the model can generate per response.
|
||||
/// Defaults to 2048.
|
||||
public let maxResponseTokens: Int
|
||||
|
||||
/// Temperature for text generation. Lower values produce more deterministic output.
|
||||
/// Defaults to 0.1 for SQL generation accuracy.
|
||||
public let temperature: Double
|
||||
|
||||
/// Top-P (nucleus) sampling threshold. Only tokens with cumulative probability
|
||||
/// below this threshold are considered. Defaults to 0.95.
|
||||
public let topP: Double
|
||||
|
||||
/// Repetition penalty to reduce repetitive output. Defaults to 1.1.
|
||||
public let repetitionPenalty: Double
|
||||
|
||||
/// Creates an MLX provider configuration.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - modelId: The HuggingFace model ID or local identifier.
|
||||
/// - localDirectory: Optional path to a local model directory.
|
||||
/// - gpuMemory: GPU memory configuration. Defaults to `.automatic`.
|
||||
/// - maxResponseTokens: Maximum tokens per response. Defaults to 2048.
|
||||
/// - temperature: Generation temperature. Defaults to 0.1.
|
||||
/// - topP: Top-P sampling threshold. Defaults to 0.95.
|
||||
/// - repetitionPenalty: Repetition penalty. Defaults to 1.1.
|
||||
public init(
|
||||
modelId: String,
|
||||
localDirectory: URL? = nil,
|
||||
gpuMemory: MLXGPUMemoryConfig = .automatic,
|
||||
maxResponseTokens: Int = 2048,
|
||||
temperature: Double = 0.1,
|
||||
topP: Double = 0.95,
|
||||
repetitionPenalty: Double = 1.1
|
||||
) {
|
||||
self.modelId = modelId
|
||||
self.localDirectory = localDirectory
|
||||
self.gpuMemory = gpuMemory
|
||||
self.maxResponseTokens = maxResponseTokens
|
||||
self.temperature = temperature
|
||||
self.topP = topP
|
||||
self.repetitionPenalty = repetitionPenalty
|
||||
}
|
||||
|
||||
/// Validates the configuration parameters.
|
||||
///
|
||||
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
|
||||
public func validate() throws {
|
||||
guard !modelId.isEmpty else {
|
||||
throw OnDeviceProviderError.emptyModelId
|
||||
}
|
||||
|
||||
if let dir = localDirectory {
|
||||
guard FileManager.default.fileExists(atPath: dir.path) else {
|
||||
throw OnDeviceProviderError.modelNotFound(dir)
|
||||
}
|
||||
}
|
||||
|
||||
guard temperature >= 0 else {
|
||||
throw OnDeviceProviderError.invalidParameter(
|
||||
name: "temperature",
|
||||
value: "\(temperature)",
|
||||
reason: "Must be non-negative"
|
||||
)
|
||||
}
|
||||
|
||||
guard topP > 0, topP <= 1.0 else {
|
||||
throw OnDeviceProviderError.invalidParameter(
|
||||
name: "topP",
|
||||
value: "\(topP)",
|
||||
reason: "Must be between 0 (exclusive) and 1.0 (inclusive)"
|
||||
)
|
||||
}
|
||||
|
||||
guard repetitionPenalty > 0 else {
|
||||
throw OnDeviceProviderError.invalidParameter(
|
||||
name: "repetitionPenalty",
|
||||
value: "\(repetitionPenalty)",
|
||||
reason: "Must be positive"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Well-Known Models
|
||||
|
||||
/// Pre-configured for Llama 3.2 3B Instruct (4-bit quantized).
|
||||
/// Good balance of quality and memory usage (~2GB RAM).
|
||||
public static func llama3_2_3B(
|
||||
localDirectory: URL? = nil,
|
||||
gpuMemory: MLXGPUMemoryConfig = .automatic
|
||||
) -> MLXProviderConfiguration {
|
||||
MLXProviderConfiguration(
|
||||
modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit",
|
||||
localDirectory: localDirectory,
|
||||
gpuMemory: gpuMemory,
|
||||
maxResponseTokens: 2048,
|
||||
temperature: 0.1
|
||||
)
|
||||
}
|
||||
|
||||
/// Pre-configured for Qwen 2.5 Coder 3B Instruct (4-bit quantized).
|
||||
/// Optimized for code and SQL generation.
|
||||
public static func qwen2_5_coder_3B(
|
||||
localDirectory: URL? = nil,
|
||||
gpuMemory: MLXGPUMemoryConfig = .automatic
|
||||
) -> MLXProviderConfiguration {
|
||||
MLXProviderConfiguration(
|
||||
modelId: "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit",
|
||||
localDirectory: localDirectory,
|
||||
gpuMemory: gpuMemory,
|
||||
maxResponseTokens: 2048,
|
||||
temperature: 0.05
|
||||
)
|
||||
}
|
||||
|
||||
/// Pre-configured for Phi-3.5 Mini Instruct (4-bit quantized).
|
||||
/// Compact model suitable for devices with limited memory (~1.5GB RAM).
|
||||
public static func phi3_5_mini(
|
||||
localDirectory: URL? = nil,
|
||||
gpuMemory: MLXGPUMemoryConfig = .automatic
|
||||
) -> MLXProviderConfiguration {
|
||||
MLXProviderConfiguration(
|
||||
modelId: "mlx-community/Phi-3.5-mini-instruct-4bit",
|
||||
localDirectory: localDirectory,
|
||||
gpuMemory: gpuMemory,
|
||||
maxResponseTokens: 2048,
|
||||
temperature: 0.1
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// GPU memory management configuration for MLX models.
|
||||
///
|
||||
/// Controls how aggressively the MLX runtime manages GPU buffer caches
|
||||
/// during active generation and idle phases.
|
||||
public struct MLXGPUMemoryConfig: Sendable, Equatable {
|
||||
/// GPU cache limit (in bytes) during active generation.
|
||||
public let activeCacheLimit: Int
|
||||
|
||||
/// GPU cache limit (in bytes) when idle.
|
||||
public let idleCacheLimit: Int
|
||||
|
||||
/// Whether to clear cached GPU buffers when eviction is safe.
|
||||
public let clearCacheOnEviction: Bool
|
||||
|
||||
/// Creates a GPU memory configuration.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - activeCacheLimit: Cache limit during active generation (bytes).
|
||||
/// - idleCacheLimit: Cache limit when idle (bytes).
|
||||
/// - clearCacheOnEviction: Whether to clear cache on eviction.
|
||||
public init(
|
||||
activeCacheLimit: Int,
|
||||
idleCacheLimit: Int,
|
||||
clearCacheOnEviction: Bool = true
|
||||
) {
|
||||
self.activeCacheLimit = activeCacheLimit
|
||||
self.idleCacheLimit = idleCacheLimit
|
||||
self.clearCacheOnEviction = clearCacheOnEviction
|
||||
}
|
||||
|
||||
/// Automatically determined based on device physical memory.
|
||||
///
|
||||
/// - Devices with <4GB RAM: 128MB active cache
|
||||
/// - Devices with <6GB RAM: 256MB active cache
|
||||
/// - Devices with <8GB RAM: 512MB active cache
|
||||
/// - Devices with 8GB+ RAM: 768MB active cache
|
||||
/// - Idle cache: 50MB for all devices
|
||||
public static var automatic: MLXGPUMemoryConfig {
|
||||
let ramBytes = ProcessInfo.processInfo.physicalMemory
|
||||
let ramGB = ramBytes / (1024 * 1024 * 1024)
|
||||
let active: Int
|
||||
switch ramGB {
|
||||
case ..<4:
|
||||
active = 128_000_000
|
||||
case ..<6:
|
||||
active = 256_000_000
|
||||
case ..<8:
|
||||
active = 512_000_000
|
||||
default:
|
||||
active = 768_000_000
|
||||
}
|
||||
|
||||
return .init(
|
||||
activeCacheLimit: active,
|
||||
idleCacheLimit: 50_000_000,
|
||||
clearCacheOnEviction: true
|
||||
)
|
||||
}
|
||||
|
||||
/// Minimal memory configuration for constrained devices.
|
||||
/// Uses 64MB active cache and 16MB idle cache.
|
||||
public static var minimal: MLXGPUMemoryConfig {
|
||||
.init(
|
||||
activeCacheLimit: 64_000_000,
|
||||
idleCacheLimit: 16_000_000,
|
||||
clearCacheOnEviction: true
|
||||
)
|
||||
}
|
||||
|
||||
/// Unconstrained configuration for maximum performance.
|
||||
/// Leaves GPU cache effectively unlimited. Use when your app
|
||||
/// can afford maximum memory usage.
|
||||
public static var unconstrained: MLXGPUMemoryConfig {
|
||||
.init(
|
||||
activeCacheLimit: Int.max,
|
||||
idleCacheLimit: Int.max,
|
||||
clearCacheOnEviction: false
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - On-Device Provider Errors
|
||||
|
||||
/// Errors specific to on-device provider configuration and model loading.
|
||||
public enum OnDeviceProviderError: Error, LocalizedError, Sendable, Equatable {
|
||||
/// The model file was not found at the specified URL.
|
||||
case modelNotFound(URL)
|
||||
|
||||
/// The model file format is not what was expected.
|
||||
case invalidModelFormat(expected: String, actual: String)
|
||||
|
||||
/// The model ID is empty.
|
||||
case emptyModelId
|
||||
|
||||
/// A configuration parameter is invalid.
|
||||
case invalidParameter(name: String, value: String, reason: String)
|
||||
|
||||
/// The on-device provider is not available on this platform.
|
||||
/// CoreML requires macOS 15+ / iOS 18+. MLX requires the MLX build flag.
|
||||
case providerUnavailable(OnDeviceProviderType, reason: String)
|
||||
|
||||
/// Model loading failed with an underlying error.
|
||||
case modelLoadFailed(reason: String)
|
||||
|
||||
/// Model inference failed.
|
||||
case inferenceFailed(reason: String)
|
||||
|
||||
public var errorDescription: String? {
|
||||
switch self {
|
||||
case .modelNotFound(let url):
|
||||
return "On-device model not found at: \(url.path)"
|
||||
case .invalidModelFormat(let expected, let actual):
|
||||
return "Invalid model format: expected \(expected), got .\(actual)"
|
||||
case .emptyModelId:
|
||||
return "Model ID must not be empty"
|
||||
case .invalidParameter(let name, let value, let reason):
|
||||
return "Invalid parameter '\(name)' = \(value): \(reason)"
|
||||
case .providerUnavailable(let type, let reason):
|
||||
return "\(type.rawValue) provider unavailable: \(reason)"
|
||||
case .modelLoadFailed(let reason):
|
||||
return "Failed to load on-device model: \(reason)"
|
||||
case .inferenceFailed(let reason):
|
||||
return "On-device inference failed: \(reason)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - On-Device Inference Pipeline
|
||||
|
||||
/// Manages the on-device model inference pipeline.
|
||||
///
|
||||
/// `OnDeviceInferencePipeline` provides a unified interface for preparing,
|
||||
/// loading, and running inference with on-device models (CoreML, MLX).
|
||||
/// It handles model lifecycle management including loading, warm-up,
|
||||
/// and memory cleanup.
|
||||
///
|
||||
/// ```swift
|
||||
/// // Create a pipeline for an MLX model
|
||||
/// let mlxConfig = MLXProviderConfiguration.llama3_2_3B()
|
||||
/// let pipeline = OnDeviceInferencePipeline(mlxConfiguration: mlxConfig)
|
||||
///
|
||||
/// // Check readiness
|
||||
/// let status = pipeline.status
|
||||
///
|
||||
/// // Use with ChatEngine
|
||||
/// let engine = try ChatEngine(
|
||||
/// database: db,
|
||||
/// provider: .onDevice(mlx: mlxConfig)
|
||||
/// )
|
||||
/// ```
|
||||
public final class OnDeviceInferencePipeline: @unchecked Sendable {
|
||||
|
||||
/// The current status of the on-device inference pipeline.
|
||||
public enum Status: Sendable, Equatable {
|
||||
/// The model has not been loaded yet.
|
||||
case notLoaded
|
||||
|
||||
/// The model is currently being loaded/downloaded.
|
||||
case loading
|
||||
|
||||
/// The model is loaded and ready for inference.
|
||||
case ready
|
||||
|
||||
/// The model failed to load.
|
||||
case failed(String)
|
||||
}
|
||||
|
||||
/// The type of on-device provider this pipeline uses.
|
||||
public let providerType: OnDeviceProviderType
|
||||
|
||||
/// The MLX configuration, if this is an MLX pipeline.
|
||||
public let mlxConfiguration: MLXProviderConfiguration?
|
||||
|
||||
/// The CoreML configuration, if this is a CoreML pipeline.
|
||||
public let coreMLConfiguration: CoreMLProviderConfiguration?
|
||||
|
||||
/// The current status of the pipeline.
|
||||
private let _statusLock = NSLock()
|
||||
private var _status: Status = .notLoaded
|
||||
|
||||
/// The current pipeline status.
|
||||
public var status: Status {
|
||||
_statusLock.lock()
|
||||
defer { _statusLock.unlock() }
|
||||
return _status
|
||||
}
|
||||
|
||||
/// Creates an MLX inference pipeline.
|
||||
///
|
||||
/// - Parameter configuration: The MLX model configuration.
|
||||
public init(mlxConfiguration: MLXProviderConfiguration) {
|
||||
self.providerType = .mlx
|
||||
self.mlxConfiguration = mlxConfiguration
|
||||
self.coreMLConfiguration = nil
|
||||
}
|
||||
|
||||
/// Creates a CoreML inference pipeline.
|
||||
///
|
||||
/// - Parameter configuration: The CoreML model configuration.
|
||||
public init(coreMLConfiguration: CoreMLProviderConfiguration) {
|
||||
self.providerType = .coreML
|
||||
self.coreMLConfiguration = coreMLConfiguration
|
||||
self.mlxConfiguration = nil
|
||||
}
|
||||
|
||||
/// Validates the configuration before attempting to load.
|
||||
///
|
||||
/// Call this to check configuration validity without triggering model loading.
|
||||
///
|
||||
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
|
||||
public func validateConfiguration() throws {
|
||||
switch providerType {
|
||||
case .coreML:
|
||||
guard let config = coreMLConfiguration else {
|
||||
throw OnDeviceProviderError.providerUnavailable(
|
||||
.coreML,
|
||||
reason: "No CoreML configuration provided"
|
||||
)
|
||||
}
|
||||
try config.validate()
|
||||
|
||||
case .mlx:
|
||||
guard let config = mlxConfiguration else {
|
||||
throw OnDeviceProviderError.providerUnavailable(
|
||||
.mlx,
|
||||
reason: "No MLX configuration provided"
|
||||
)
|
||||
}
|
||||
try config.validate()
|
||||
}
|
||||
}
|
||||
|
||||
/// Updates the pipeline status.
|
||||
internal func setStatus(_ newStatus: Status) {
|
||||
_statusLock.lock()
|
||||
_status = newStatus
|
||||
_statusLock.unlock()
|
||||
}
|
||||
|
||||
/// Provides recommended generation options optimized for SQL generation
|
||||
/// based on the pipeline's configuration.
|
||||
///
|
||||
/// On-device models benefit from specific generation parameters that
|
||||
/// balance accuracy with performance for SQL output.
|
||||
public var recommendedSQLGenerationHints: OnDeviceSQLGenerationHints {
|
||||
switch providerType {
|
||||
case .coreML:
|
||||
let config = coreMLConfiguration ?? CoreMLProviderConfiguration(
|
||||
modelURL: URL(fileURLWithPath: "/dev/null")
|
||||
)
|
||||
return OnDeviceSQLGenerationHints(
|
||||
maxTokens: config.maxResponseTokens,
|
||||
temperature: config.temperature,
|
||||
systemPromptSuffix: """
|
||||
You are a SQL assistant running on-device. Generate only valid SQLite SQL.
|
||||
Be concise — output ONLY the SQL query with no explanation.
|
||||
""",
|
||||
useSampling: config.useSampling
|
||||
)
|
||||
|
||||
case .mlx:
|
||||
let config = mlxConfiguration ?? .llama3_2_3B()
|
||||
return OnDeviceSQLGenerationHints(
|
||||
maxTokens: config.maxResponseTokens,
|
||||
temperature: config.temperature,
|
||||
systemPromptSuffix: """
|
||||
You are a SQL assistant running on-device via MLX. Generate only valid SQLite SQL.
|
||||
Be concise — output ONLY the SQL query with no explanation.
|
||||
""",
|
||||
useSampling: true
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Hints for optimizing SQL generation with on-device models.
|
||||
///
|
||||
/// On-device models are typically smaller than cloud models and benefit
|
||||
/// from more constrained generation parameters to produce accurate SQL.
|
||||
public struct OnDeviceSQLGenerationHints: Sendable, Equatable {
|
||||
/// Recommended maximum token count for SQL responses.
|
||||
public let maxTokens: Int
|
||||
|
||||
/// Recommended temperature for SQL generation.
|
||||
public let temperature: Double
|
||||
|
||||
/// Additional system prompt text optimized for on-device SQL generation.
|
||||
public let systemPromptSuffix: String
|
||||
|
||||
/// Whether to use sampling or greedy decoding.
|
||||
public let useSampling: Bool
|
||||
}
|
||||
|
||||
// MARK: - ProviderConfiguration Extension
|
||||
|
||||
extension ProviderConfiguration {
|
||||
|
||||
/// Creates a configuration for an on-device MLX model.
|
||||
///
|
||||
/// MLX models run entirely on Apple silicon using the MLX framework.
|
||||
/// Models are automatically downloaded from HuggingFace Hub on first use.
|
||||
///
|
||||
/// ```swift
|
||||
/// // Using a pre-configured model
|
||||
/// let config = ProviderConfiguration.onDeviceMLX(.llama3_2_3B())
|
||||
///
|
||||
/// // Using a custom model
|
||||
/// let config = ProviderConfiguration.onDeviceMLX(
|
||||
/// MLXProviderConfiguration(
|
||||
/// modelId: "mlx-community/Qwen2.5-7B-Instruct-4bit",
|
||||
/// temperature: 0.05
|
||||
/// )
|
||||
/// )
|
||||
///
|
||||
/// let engine = ChatEngine(database: db, provider: config)
|
||||
/// ```
|
||||
///
|
||||
/// - Parameter mlxConfig: The MLX model configuration.
|
||||
/// - Returns: A configured `ProviderConfiguration` that wraps the MLX model.
|
||||
///
|
||||
/// - Note: The returned configuration uses `.openAICompatible` as the provider
|
||||
/// type internally. The actual model is created via MLX APIs when `#if MLX` is
|
||||
/// available. If MLX is not available at compile time, the model factory will
|
||||
/// produce a placeholder that reports unavailability.
|
||||
public static func onDeviceMLX(
|
||||
_ mlxConfig: MLXProviderConfiguration
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .openAICompatible,
|
||||
model: mlxConfig.modelId,
|
||||
apiKeyProvider: { "" },
|
||||
baseURL: nil,
|
||||
apiVersion: nil,
|
||||
betas: nil,
|
||||
openAIVariant: nil
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a configuration for an on-device CoreML model.
|
||||
///
|
||||
/// CoreML models must be pre-compiled to `.mlmodelc` format.
|
||||
/// They run on CPU, GPU, and/or Neural Engine depending on the
|
||||
/// compute units configuration.
|
||||
///
|
||||
/// ```swift
|
||||
/// let modelURL = Bundle.main.url(forResource: "SQLModel", withExtension: "mlmodelc")!
|
||||
/// let config = ProviderConfiguration.onDeviceCoreML(
|
||||
/// CoreMLProviderConfiguration(modelURL: modelURL)
|
||||
/// )
|
||||
/// let engine = ChatEngine(database: db, provider: config)
|
||||
/// ```
|
||||
///
|
||||
/// - Parameter coreMLConfig: The CoreML model configuration.
|
||||
/// - Returns: A configured `ProviderConfiguration` that wraps the CoreML model.
|
||||
///
|
||||
/// - Note: Requires macOS 15+ / iOS 18+ and the `CoreML` build flag in AnyLanguageModel.
|
||||
public static func onDeviceCoreML(
|
||||
_ coreMLConfig: CoreMLProviderConfiguration
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .openAICompatible,
|
||||
model: coreMLConfig.modelURL.lastPathComponent,
|
||||
apiKeyProvider: { "" },
|
||||
baseURL: nil,
|
||||
apiVersion: nil,
|
||||
betas: nil,
|
||||
openAIVariant: nil
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - ChatEngine On-Device Convenience
|
||||
|
||||
extension ChatEngine {
|
||||
|
||||
/// Creates a ChatEngine with an on-device MLX model.
|
||||
///
|
||||
/// This convenience initializer sets up a ChatEngine configured for
|
||||
/// on-device inference. It validates the MLX configuration and creates
|
||||
/// an inference pipeline.
|
||||
///
|
||||
/// ```swift
|
||||
/// let engine = try ChatEngine.onDevice(
|
||||
/// database: db,
|
||||
/// mlx: .llama3_2_3B()
|
||||
/// )
|
||||
/// let response = try await engine.send("How many users are there?")
|
||||
/// ```
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||
/// - mlx: The MLX model configuration.
|
||||
/// - allowlist: SQL operations allowed. Defaults to read-only.
|
||||
/// - configuration: Engine configuration.
|
||||
/// - Returns: A configured `ChatEngine` instance.
|
||||
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
|
||||
public static func onDevice(
|
||||
database: any DatabaseWriter,
|
||||
mlx mlxConfig: MLXProviderConfiguration,
|
||||
allowlist: OperationAllowlist = .readOnly,
|
||||
configuration: ChatEngineConfiguration = .default
|
||||
) throws -> ChatEngine {
|
||||
// Validate configuration
|
||||
try mlxConfig.validate()
|
||||
|
||||
let pipeline = OnDeviceInferencePipeline(mlxConfiguration: mlxConfig)
|
||||
|
||||
// Build a ChatEngineConfiguration that includes on-device hints
|
||||
var engineConfig = configuration
|
||||
let hints = pipeline.recommendedSQLGenerationHints
|
||||
if engineConfig.additionalContext == nil {
|
||||
engineConfig.additionalContext = hints.systemPromptSuffix
|
||||
} else {
|
||||
engineConfig.additionalContext! += "\n\n" + hints.systemPromptSuffix
|
||||
}
|
||||
|
||||
let providerConfig = ProviderConfiguration.onDeviceMLX(mlxConfig)
|
||||
|
||||
return ChatEngine(
|
||||
database: database,
|
||||
provider: providerConfig,
|
||||
allowlist: allowlist,
|
||||
configuration: engineConfig
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a ChatEngine with an on-device CoreML model.
|
||||
///
|
||||
/// ```swift
|
||||
/// let modelURL = Bundle.main.url(forResource: "SQLModel", withExtension: "mlmodelc")!
|
||||
/// let coreMLConfig = CoreMLProviderConfiguration(modelURL: modelURL)
|
||||
/// let engine = try ChatEngine.onDevice(
|
||||
/// database: db,
|
||||
/// coreML: coreMLConfig
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||
/// - coreML: The CoreML model configuration.
|
||||
/// - allowlist: SQL operations allowed. Defaults to read-only.
|
||||
/// - configuration: Engine configuration.
|
||||
/// - Returns: A configured `ChatEngine` instance.
|
||||
/// - Throws: ``OnDeviceProviderError`` if the configuration is invalid.
|
||||
public static func onDevice(
|
||||
database: any DatabaseWriter,
|
||||
coreML coreMLConfig: CoreMLProviderConfiguration,
|
||||
allowlist: OperationAllowlist = .readOnly,
|
||||
configuration: ChatEngineConfiguration = .default
|
||||
) throws -> ChatEngine {
|
||||
// Validate configuration
|
||||
try coreMLConfig.validate()
|
||||
|
||||
let pipeline = OnDeviceInferencePipeline(coreMLConfiguration: coreMLConfig)
|
||||
|
||||
var engineConfig = configuration
|
||||
let hints = pipeline.recommendedSQLGenerationHints
|
||||
if engineConfig.additionalContext == nil {
|
||||
engineConfig.additionalContext = hints.systemPromptSuffix
|
||||
} else {
|
||||
engineConfig.additionalContext! += "\n\n" + hints.systemPromptSuffix
|
||||
}
|
||||
|
||||
let providerConfig = ProviderConfiguration.onDeviceCoreML(coreMLConfig)
|
||||
|
||||
return ChatEngine(
|
||||
database: database,
|
||||
provider: providerConfig,
|
||||
allowlist: allowlist,
|
||||
configuration: engineConfig
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Model Readiness Checker
|
||||
|
||||
/// Utility for checking on-device model availability and system capability.
|
||||
public enum OnDeviceModelReadiness {
|
||||
|
||||
/// System capability information for on-device inference.
|
||||
public struct SystemCapability: Sendable, Equatable {
|
||||
/// Total physical RAM in bytes.
|
||||
public let totalRAM: UInt64
|
||||
|
||||
/// Whether the device has sufficient RAM for typical on-device models.
|
||||
/// Generally requires at least 4GB for 3B parameter models.
|
||||
public let hasSufficientRAM: Bool
|
||||
|
||||
/// Whether Apple Neural Engine is likely available.
|
||||
/// True on devices with Apple silicon.
|
||||
public let hasNeuralEngine: Bool
|
||||
|
||||
/// Recommended model size category based on available RAM.
|
||||
public let recommendedModelSize: RecommendedModelSize
|
||||
}
|
||||
|
||||
/// Recommended model size based on device capabilities.
|
||||
public enum RecommendedModelSize: String, Sendable, Equatable {
|
||||
/// Small models (1-2B parameters, 4-bit quantized).
|
||||
/// Suitable for devices with 4GB RAM.
|
||||
case small
|
||||
|
||||
/// Medium models (3-4B parameters, 4-bit quantized).
|
||||
/// Suitable for devices with 6-8GB RAM.
|
||||
case medium
|
||||
|
||||
/// Large models (7-8B parameters, 4-bit quantized).
|
||||
/// Suitable for devices with 16GB+ RAM.
|
||||
case large
|
||||
}
|
||||
|
||||
/// Checks the current device's capability for on-device inference.
|
||||
///
|
||||
/// ```swift
|
||||
/// let capability = OnDeviceModelReadiness.checkSystemCapability()
|
||||
/// if capability.hasSufficientRAM {
|
||||
/// print("Recommended size: \(capability.recommendedModelSize)")
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// - Returns: A `SystemCapability` describing the device's readiness.
|
||||
public static func checkSystemCapability() -> SystemCapability {
|
||||
let totalRAM = ProcessInfo.processInfo.physicalMemory
|
||||
let ramGB = totalRAM / (1024 * 1024 * 1024)
|
||||
|
||||
let recommendedSize: RecommendedModelSize
|
||||
switch ramGB {
|
||||
case ..<4:
|
||||
recommendedSize = .small
|
||||
case ..<8:
|
||||
recommendedSize = .medium
|
||||
default:
|
||||
recommendedSize = .large
|
||||
}
|
||||
|
||||
return SystemCapability(
|
||||
totalRAM: totalRAM,
|
||||
hasSufficientRAM: ramGB >= 4,
|
||||
hasNeuralEngine: hasAppleSilicon(),
|
||||
recommendedModelSize: recommendedSize
|
||||
)
|
||||
}
|
||||
|
||||
/// Suggests an MLX model configuration based on system capabilities.
|
||||
///
|
||||
/// ```swift
|
||||
/// let config = OnDeviceModelReadiness.suggestedMLXModel()
|
||||
/// let engine = try ChatEngine.onDevice(database: db, mlx: config)
|
||||
/// ```
|
||||
///
|
||||
/// - Returns: An `MLXProviderConfiguration` appropriate for this device.
|
||||
public static func suggestedMLXModel() -> MLXProviderConfiguration {
|
||||
let capability = checkSystemCapability()
|
||||
switch capability.recommendedModelSize {
|
||||
case .small:
|
||||
return .phi3_5_mini()
|
||||
case .medium:
|
||||
return .llama3_2_3B()
|
||||
case .large:
|
||||
return .qwen2_5_coder_3B()
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if the current device uses Apple silicon.
|
||||
private static func hasAppleSilicon() -> Bool {
|
||||
#if arch(arm64)
|
||||
return true
|
||||
#else
|
||||
return false
|
||||
#endif
|
||||
}
|
||||
}
|
||||
54
Sources/SwiftDBAI/Config/OperationAllowlist.swift
Normal file
54
Sources/SwiftDBAI/Config/OperationAllowlist.swift
Normal file
@@ -0,0 +1,54 @@
|
||||
/// Defines which SQL operations the LLM is permitted to generate.
|
||||
///
|
||||
/// The default is ``readOnly`` (SELECT only). Write operations require
|
||||
/// explicit opt-in. This is the safety-by-default principle.
|
||||
public struct OperationAllowlist: Sendable, Equatable {
|
||||
/// The set of permitted SQL operation types.
|
||||
public let allowedOperations: Set<SQLOperation>
|
||||
|
||||
/// Creates an allowlist from the given set of operations.
|
||||
public init(_ operations: Set<SQLOperation>) {
|
||||
self.allowedOperations = operations
|
||||
}
|
||||
|
||||
/// Read-only: only SELECT queries are permitted. This is the default.
|
||||
public static let readOnly = OperationAllowlist([.select])
|
||||
|
||||
/// Standard read-write: SELECT, INSERT, and UPDATE are permitted.
|
||||
public static let standard = OperationAllowlist([.select, .insert, .update])
|
||||
|
||||
/// Unrestricted: all operations including DELETE are permitted.
|
||||
/// DELETE still requires confirmation via `ToolExecutionDelegate`.
|
||||
public static let unrestricted = OperationAllowlist([.select, .insert, .update, .delete])
|
||||
|
||||
/// Returns true if the given operation is allowed.
|
||||
public func isAllowed(_ operation: SQLOperation) -> Bool {
|
||||
allowedOperations.contains(operation)
|
||||
}
|
||||
|
||||
/// Returns a human-readable description of what's allowed, for inclusion
|
||||
/// in the LLM system prompt.
|
||||
func describeForLLM() -> String {
|
||||
if allowedOperations == [.select] {
|
||||
return "You may ONLY generate SELECT queries. No data modifications are allowed."
|
||||
}
|
||||
|
||||
let sorted = allowedOperations.sorted { $0.rawValue < $1.rawValue }
|
||||
let names = sorted.map { $0.rawValue.uppercased() }
|
||||
var desc = "Allowed SQL operations: \(names.joined(separator: ", "))."
|
||||
|
||||
if allowedOperations.contains(.delete) {
|
||||
desc += " DELETE operations are destructive and require user confirmation before execution."
|
||||
}
|
||||
|
||||
return desc
|
||||
}
|
||||
}
|
||||
|
||||
/// The types of SQL operations that can be controlled via the allowlist.
|
||||
public enum SQLOperation: String, Sendable, Hashable, CaseIterable {
|
||||
case select
|
||||
case insert
|
||||
case update
|
||||
case delete
|
||||
}
|
||||
609
Sources/SwiftDBAI/Config/ProviderConfiguration.swift
Normal file
609
Sources/SwiftDBAI/Config/ProviderConfiguration.swift
Normal file
@@ -0,0 +1,609 @@
|
||||
// ProviderConfiguration.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Unified provider configuration for cloud-based LLM providers.
|
||||
// Wraps AnyLanguageModel provider types with convenient factory methods.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
/// Configuration for connecting to a cloud-based LLM provider.
|
||||
///
|
||||
/// `ProviderConfiguration` provides a unified way to configure any supported
|
||||
/// LLM provider (OpenAI, Anthropic, Gemini, or OpenAI-compatible services).
|
||||
/// Each configuration produces a properly configured `LanguageModel` instance
|
||||
/// that works with ``ChatEngine`` and ``TextSummaryRenderer``.
|
||||
///
|
||||
/// ## Quick Start
|
||||
///
|
||||
/// ```swift
|
||||
/// // OpenAI
|
||||
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
|
||||
///
|
||||
/// // Anthropic
|
||||
/// let config = ProviderConfiguration.anthropic(apiKey: "sk-ant-...", model: "claude-sonnet-4-20250514")
|
||||
///
|
||||
/// // Gemini
|
||||
/// let config = ProviderConfiguration.gemini(apiKey: "AIza...", model: "gemini-2.0-flash")
|
||||
///
|
||||
/// // Use with ChatEngine
|
||||
/// let engine = ChatEngine(database: db, model: config.makeModel())
|
||||
/// ```
|
||||
///
|
||||
/// ## API Key Handling
|
||||
///
|
||||
/// API keys are stored as closures to support both static strings and
|
||||
/// dynamic retrieval from keychains, environment variables, or secure storage:
|
||||
///
|
||||
/// ```swift
|
||||
/// // Static key
|
||||
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
|
||||
///
|
||||
/// // Dynamic key from environment
|
||||
/// let config = ProviderConfiguration.openAI(
|
||||
/// apiKeyProvider: { ProcessInfo.processInfo.environment["OPENAI_API_KEY"] ?? "" },
|
||||
/// model: "gpt-4o"
|
||||
/// )
|
||||
/// ```
|
||||
public struct ProviderConfiguration: Sendable {
|
||||
|
||||
/// The supported LLM provider types.
|
||||
public enum Provider: String, Sendable, Hashable, CaseIterable {
|
||||
/// OpenAI's GPT models via the Chat Completions or Responses API.
|
||||
case openAI
|
||||
|
||||
/// Anthropic's Claude models.
|
||||
case anthropic
|
||||
|
||||
/// Google's Gemini models.
|
||||
case gemini
|
||||
|
||||
/// Any OpenAI-compatible API (e.g., local servers, third-party providers).
|
||||
case openAICompatible
|
||||
|
||||
/// Ollama — local models via `ollama serve`.
|
||||
/// Default endpoint: http://localhost:11434
|
||||
case ollama
|
||||
|
||||
/// llama.cpp server — local GGUF models via `llama-server`.
|
||||
/// Default endpoint: http://localhost:8080
|
||||
/// Uses the OpenAI-compatible API.
|
||||
case llamaCpp
|
||||
}
|
||||
|
||||
/// The provider type for this configuration.
|
||||
public let provider: Provider
|
||||
|
||||
/// The model identifier (e.g., "gpt-4o", "claude-sonnet-4-20250514", "gemini-2.0-flash").
|
||||
public let model: String
|
||||
|
||||
/// A closure that provides the API key on demand.
|
||||
///
|
||||
/// Using a closure allows lazy evaluation and integration with secure
|
||||
/// storage systems (Keychain, environment variables, etc.).
|
||||
private let apiKeyProvider: @Sendable () -> String
|
||||
|
||||
/// Optional custom base URL for OpenAI-compatible providers.
|
||||
public let baseURL: URL?
|
||||
|
||||
/// Optional API version override (used by Anthropic and Gemini).
|
||||
public let apiVersion: String?
|
||||
|
||||
/// Optional beta headers (used by Anthropic).
|
||||
public let betas: [String]?
|
||||
|
||||
/// The OpenAI API variant to use (Chat Completions or Responses).
|
||||
public let openAIVariant: OpenAILanguageModel.APIVariant?
|
||||
|
||||
// MARK: - Internal Init
|
||||
|
||||
/// Internal memberwise initializer used by factory methods.
|
||||
internal init(
|
||||
provider: Provider,
|
||||
model: String,
|
||||
apiKeyProvider: @escaping @Sendable () -> String,
|
||||
baseURL: URL?,
|
||||
apiVersion: String?,
|
||||
betas: [String]?,
|
||||
openAIVariant: OpenAILanguageModel.APIVariant?
|
||||
) {
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.apiKeyProvider = apiKeyProvider
|
||||
self.baseURL = baseURL
|
||||
self.apiVersion = apiVersion
|
||||
self.betas = betas
|
||||
self.openAIVariant = openAIVariant
|
||||
}
|
||||
|
||||
// MARK: - Factory Methods
|
||||
|
||||
/// Creates a configuration for OpenAI's API.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - apiKey: Your OpenAI API key (e.g., "sk-...").
|
||||
/// - model: The model identifier (e.g., "gpt-4o", "gpt-4o-mini").
|
||||
/// - variant: The API variant to use. Defaults to `.chatCompletions`.
|
||||
/// - baseURL: Optional custom base URL. Defaults to OpenAI's API.
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func openAI(
|
||||
apiKey: String,
|
||||
model: String,
|
||||
variant: OpenAILanguageModel.APIVariant = .chatCompletions,
|
||||
baseURL: URL? = nil
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .openAI,
|
||||
model: model,
|
||||
apiKeyProvider: { apiKey },
|
||||
baseURL: baseURL,
|
||||
apiVersion: nil,
|
||||
betas: nil,
|
||||
openAIVariant: variant
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a configuration for OpenAI's API with a dynamic key provider.
|
||||
///
|
||||
/// Use this when the API key comes from a keychain, environment variable,
|
||||
/// or other dynamic source.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - apiKeyProvider: A closure that returns the API key.
|
||||
/// - model: The model identifier.
|
||||
/// - variant: The API variant to use. Defaults to `.chatCompletions`.
|
||||
/// - baseURL: Optional custom base URL.
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func openAI(
|
||||
apiKeyProvider: @escaping @Sendable () -> String,
|
||||
model: String,
|
||||
variant: OpenAILanguageModel.APIVariant = .chatCompletions,
|
||||
baseURL: URL? = nil
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .openAI,
|
||||
model: model,
|
||||
apiKeyProvider: apiKeyProvider,
|
||||
baseURL: baseURL,
|
||||
apiVersion: nil,
|
||||
betas: nil,
|
||||
openAIVariant: variant
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a configuration for Anthropic's Claude API.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - apiKey: Your Anthropic API key (e.g., "sk-ant-...").
|
||||
/// - model: The model identifier (e.g., "claude-sonnet-4-20250514").
|
||||
/// - apiVersion: Optional API version override.
|
||||
/// - betas: Optional beta feature headers.
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func anthropic(
|
||||
apiKey: String,
|
||||
model: String,
|
||||
apiVersion: String? = nil,
|
||||
betas: [String]? = nil
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .anthropic,
|
||||
model: model,
|
||||
apiKeyProvider: { apiKey },
|
||||
baseURL: nil,
|
||||
apiVersion: apiVersion,
|
||||
betas: betas,
|
||||
openAIVariant: nil
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a configuration for Anthropic's Claude API with a dynamic key provider.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - apiKeyProvider: A closure that returns the API key.
|
||||
/// - model: The model identifier.
|
||||
/// - apiVersion: Optional API version override.
|
||||
/// - betas: Optional beta feature headers.
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func anthropic(
|
||||
apiKeyProvider: @escaping @Sendable () -> String,
|
||||
model: String,
|
||||
apiVersion: String? = nil,
|
||||
betas: [String]? = nil
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .anthropic,
|
||||
model: model,
|
||||
apiKeyProvider: apiKeyProvider,
|
||||
baseURL: nil,
|
||||
apiVersion: apiVersion,
|
||||
betas: betas,
|
||||
openAIVariant: nil
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a configuration for Google's Gemini API.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - apiKey: Your Gemini API key (e.g., "AIza...").
|
||||
/// - model: The model identifier (e.g., "gemini-2.0-flash").
|
||||
/// - apiVersion: Optional API version override (defaults to "v1beta").
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func gemini(
|
||||
apiKey: String,
|
||||
model: String,
|
||||
apiVersion: String? = nil
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .gemini,
|
||||
model: model,
|
||||
apiKeyProvider: { apiKey },
|
||||
baseURL: nil,
|
||||
apiVersion: apiVersion,
|
||||
betas: nil,
|
||||
openAIVariant: nil
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a configuration for Google's Gemini API with a dynamic key provider.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - apiKeyProvider: A closure that returns the API key.
|
||||
/// - model: The model identifier.
|
||||
/// - apiVersion: Optional API version override.
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func gemini(
|
||||
apiKeyProvider: @escaping @Sendable () -> String,
|
||||
model: String,
|
||||
apiVersion: String? = nil
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .gemini,
|
||||
model: model,
|
||||
apiKeyProvider: apiKeyProvider,
|
||||
baseURL: nil,
|
||||
apiVersion: apiVersion,
|
||||
betas: nil,
|
||||
openAIVariant: nil
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a configuration for any OpenAI-compatible API.
|
||||
///
|
||||
/// Use this for third-party services that implement the OpenAI Chat Completions
|
||||
/// API (e.g., local LLM servers, Groq, Together AI, etc.).
|
||||
///
|
||||
/// ```swift
|
||||
/// let config = ProviderConfiguration.openAICompatible(
|
||||
/// apiKey: "your-key",
|
||||
/// model: "llama-3.1-70b",
|
||||
/// baseURL: URL(string: "https://api.together.xyz/v1/")!
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - apiKey: The API key for the service.
|
||||
/// - model: The model identifier.
|
||||
/// - baseURL: The base URL of the compatible API.
|
||||
/// - variant: The API variant. Defaults to `.chatCompletions`.
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func openAICompatible(
|
||||
apiKey: String,
|
||||
model: String,
|
||||
baseURL: URL,
|
||||
variant: OpenAILanguageModel.APIVariant = .chatCompletions
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .openAICompatible,
|
||||
model: model,
|
||||
apiKeyProvider: { apiKey },
|
||||
baseURL: baseURL,
|
||||
apiVersion: nil,
|
||||
betas: nil,
|
||||
openAIVariant: variant
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a configuration for any OpenAI-compatible API with a dynamic key provider.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - apiKeyProvider: A closure that returns the API key.
|
||||
/// - model: The model identifier.
|
||||
/// - baseURL: The base URL of the compatible API.
|
||||
/// - variant: The API variant. Defaults to `.chatCompletions`.
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func openAICompatible(
|
||||
apiKeyProvider: @escaping @Sendable () -> String,
|
||||
model: String,
|
||||
baseURL: URL,
|
||||
variant: OpenAILanguageModel.APIVariant = .chatCompletions
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .openAICompatible,
|
||||
model: model,
|
||||
apiKeyProvider: apiKeyProvider,
|
||||
baseURL: baseURL,
|
||||
apiVersion: nil,
|
||||
betas: nil,
|
||||
openAIVariant: variant
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Local Provider Factory Methods
|
||||
|
||||
/// Creates a configuration for a local Ollama instance.
|
||||
///
|
||||
/// Ollama runs models locally and exposes a native API on port 11434.
|
||||
/// No API key is required by default.
|
||||
///
|
||||
/// ```swift
|
||||
/// // Default local Ollama
|
||||
/// let config = ProviderConfiguration.ollama(model: "llama3.2")
|
||||
///
|
||||
/// // Ollama on a remote machine
|
||||
/// let config = ProviderConfiguration.ollama(
|
||||
/// model: "qwen2.5",
|
||||
/// baseURL: URL(string: "http://192.168.1.100:11434")!
|
||||
/// )
|
||||
///
|
||||
/// // Use with ChatEngine
|
||||
/// let engine = ChatEngine(database: db, provider: config)
|
||||
/// ```
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - model: The Ollama model name (e.g., "llama3.2", "qwen2.5", "mistral").
|
||||
/// - baseURL: The Ollama server URL. Defaults to `http://localhost:11434`.
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func ollama(
|
||||
model: String,
|
||||
baseURL: URL = OllamaLanguageModel.defaultBaseURL
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .ollama,
|
||||
model: model,
|
||||
apiKeyProvider: { "" },
|
||||
baseURL: baseURL,
|
||||
apiVersion: nil,
|
||||
betas: nil,
|
||||
openAIVariant: nil
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a configuration for a local llama.cpp server.
|
||||
///
|
||||
/// llama.cpp's `llama-server` exposes an OpenAI-compatible Chat Completions
|
||||
/// API, typically on port 8080. No API key is required by default.
|
||||
///
|
||||
/// ```swift
|
||||
/// // Default local llama.cpp
|
||||
/// let config = ProviderConfiguration.llamaCpp(model: "default")
|
||||
///
|
||||
/// // llama.cpp on a custom port with API key
|
||||
/// let config = ProviderConfiguration.llamaCpp(
|
||||
/// model: "my-model",
|
||||
/// baseURL: URL(string: "http://localhost:9090")!,
|
||||
/// apiKey: "my-secret-key"
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - model: The model identifier. Use "default" if llama-server loads a single model.
|
||||
/// - baseURL: The llama.cpp server URL. Defaults to `http://localhost:8080`.
|
||||
/// - apiKey: Optional API key if the server requires authentication.
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func llamaCpp(
|
||||
model: String = "default",
|
||||
baseURL: URL = LocalProviderDiscovery.defaultLlamaCppURL,
|
||||
apiKey: String = ""
|
||||
) -> ProviderConfiguration {
|
||||
ProviderConfiguration(
|
||||
provider: .llamaCpp,
|
||||
model: model,
|
||||
apiKeyProvider: { apiKey },
|
||||
baseURL: baseURL,
|
||||
apiVersion: nil,
|
||||
betas: nil,
|
||||
openAIVariant: .chatCompletions
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Model Construction
|
||||
|
||||
/// Creates a configured `LanguageModel` instance for this provider.
|
||||
///
|
||||
/// This is the primary way to get a model from a configuration.
|
||||
/// The returned model is ready to use with ``ChatEngine`` or
|
||||
/// ``TextSummaryRenderer``.
|
||||
///
|
||||
/// ```swift
|
||||
/// let config = ProviderConfiguration.openAI(apiKey: "sk-...", model: "gpt-4o")
|
||||
/// let engine = ChatEngine(database: db, model: config.makeModel())
|
||||
/// ```
|
||||
///
|
||||
/// - Returns: A configured `LanguageModel` instance.
|
||||
public func makeModel() -> any LanguageModel {
|
||||
let key = apiKeyProvider
|
||||
|
||||
switch provider {
|
||||
case .openAI:
|
||||
if let baseURL {
|
||||
return OpenAILanguageModel(
|
||||
baseURL: baseURL,
|
||||
apiKey: key(),
|
||||
model: model,
|
||||
apiVariant: openAIVariant ?? .chatCompletions
|
||||
)
|
||||
}
|
||||
return OpenAILanguageModel(
|
||||
apiKey: key(),
|
||||
model: model,
|
||||
apiVariant: openAIVariant ?? .chatCompletions
|
||||
)
|
||||
|
||||
case .anthropic:
|
||||
if let apiVersion {
|
||||
return AnthropicLanguageModel(
|
||||
apiKey: key(),
|
||||
apiVersion: apiVersion,
|
||||
betas: betas,
|
||||
model: model
|
||||
)
|
||||
}
|
||||
if let betas {
|
||||
return AnthropicLanguageModel(
|
||||
apiKey: key(),
|
||||
betas: betas,
|
||||
model: model
|
||||
)
|
||||
}
|
||||
return AnthropicLanguageModel(
|
||||
apiKey: key(),
|
||||
model: model
|
||||
)
|
||||
|
||||
case .gemini:
|
||||
if let apiVersion {
|
||||
return GeminiLanguageModel(
|
||||
apiKey: key(),
|
||||
apiVersion: apiVersion,
|
||||
model: model
|
||||
)
|
||||
}
|
||||
return GeminiLanguageModel(
|
||||
apiKey: key(),
|
||||
model: model
|
||||
)
|
||||
|
||||
case .openAICompatible:
|
||||
return OpenAILanguageModel(
|
||||
baseURL: baseURL ?? OpenAILanguageModel.defaultBaseURL,
|
||||
apiKey: key(),
|
||||
model: model,
|
||||
apiVariant: openAIVariant ?? .chatCompletions
|
||||
)
|
||||
|
||||
case .ollama:
|
||||
return OllamaLanguageModel(
|
||||
baseURL: baseURL ?? OllamaLanguageModel.defaultBaseURL,
|
||||
model: model
|
||||
)
|
||||
|
||||
case .llamaCpp:
|
||||
// llama.cpp exposes an OpenAI-compatible API
|
||||
return OpenAILanguageModel(
|
||||
baseURL: baseURL ?? LocalProviderDiscovery.defaultLlamaCppURL,
|
||||
apiKey: key(),
|
||||
model: model,
|
||||
apiVariant: openAIVariant ?? .chatCompletions
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - API Key Access
|
||||
|
||||
/// Returns the current API key.
|
||||
///
|
||||
/// Useful for validation or debugging. In production, prefer using
|
||||
/// ``makeModel()`` which handles key injection automatically.
|
||||
public var apiKey: String {
|
||||
apiKeyProvider()
|
||||
}
|
||||
|
||||
/// Returns `true` if the API key is non-empty.
|
||||
///
|
||||
/// Use this to check configuration validity before creating an engine:
|
||||
/// ```swift
|
||||
/// guard config.hasValidAPIKey else {
|
||||
/// // Show API key setup UI
|
||||
/// return
|
||||
/// }
|
||||
/// ```
|
||||
public var hasValidAPIKey: Bool {
|
||||
!apiKeyProvider().trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
|
||||
}
|
||||
|
||||
// MARK: - Environment Variable Helpers
|
||||
|
||||
/// Creates a configuration using an API key from an environment variable.
|
||||
///
|
||||
/// Falls back to an empty string if the environment variable is not set,
|
||||
/// which will cause API calls to fail with an authentication error.
|
||||
///
|
||||
/// ```swift
|
||||
/// let config = ProviderConfiguration.fromEnvironment(
|
||||
/// provider: .openAI,
|
||||
/// environmentVariable: "OPENAI_API_KEY",
|
||||
/// model: "gpt-4o"
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - provider: The LLM provider.
|
||||
/// - environmentVariable: The name of the environment variable holding the API key.
|
||||
/// - model: The model identifier.
|
||||
/// - Returns: A configured `ProviderConfiguration`.
|
||||
public static func fromEnvironment(
|
||||
provider: Provider,
|
||||
environmentVariable: String,
|
||||
model: String
|
||||
) -> ProviderConfiguration {
|
||||
let keyProvider: @Sendable () -> String = {
|
||||
ProcessInfo.processInfo.environment[environmentVariable] ?? ""
|
||||
}
|
||||
|
||||
switch provider {
|
||||
case .openAI:
|
||||
return .openAI(apiKeyProvider: keyProvider, model: model)
|
||||
case .anthropic:
|
||||
return .anthropic(apiKeyProvider: keyProvider, model: model)
|
||||
case .gemini:
|
||||
return .gemini(apiKeyProvider: keyProvider, model: model)
|
||||
case .openAICompatible:
|
||||
return .openAICompatible(
|
||||
apiKeyProvider: keyProvider,
|
||||
model: model,
|
||||
baseURL: OpenAILanguageModel.defaultBaseURL
|
||||
)
|
||||
case .ollama:
|
||||
return .ollama(model: model)
|
||||
case .llamaCpp:
|
||||
return .llamaCpp(model: model)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - ChatEngine Convenience Init
|
||||
|
||||
extension ChatEngine {
|
||||
|
||||
/// Creates a ChatEngine using a ``ProviderConfiguration``.
|
||||
///
|
||||
/// This is the most convenient way to set up a ChatEngine with a
|
||||
/// cloud provider:
|
||||
///
|
||||
/// ```swift
|
||||
/// let engine = ChatEngine(
|
||||
/// database: myDB,
|
||||
/// provider: .openAI(apiKey: "sk-...", model: "gpt-4o")
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||
/// - provider: The provider configuration.
|
||||
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only.
|
||||
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
|
||||
public convenience init(
|
||||
database: any DatabaseWriter,
|
||||
provider: ProviderConfiguration,
|
||||
allowlist: OperationAllowlist = .readOnly,
|
||||
configuration: ChatEngineConfiguration = .default
|
||||
) {
|
||||
self.init(
|
||||
database: database,
|
||||
model: provider.makeModel(),
|
||||
allowlist: allowlist,
|
||||
configuration: configuration
|
||||
)
|
||||
}
|
||||
}
|
||||
114
Sources/SwiftDBAI/Config/QueryValidator.swift
Normal file
114
Sources/SwiftDBAI/Config/QueryValidator.swift
Normal file
@@ -0,0 +1,114 @@
|
||||
// QueryValidator.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Extensible query validation protocol for custom pre-execution checks.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// A protocol for custom SQL query validation.
|
||||
///
|
||||
/// Implement this protocol to add domain-specific validation rules that run
|
||||
/// after the built-in allowlist and safety checks. Validators receive the
|
||||
/// parsed SQL string and its detected operation type.
|
||||
///
|
||||
/// Example — restrict queries to specific tables:
|
||||
/// ```swift
|
||||
/// struct TableAllowlistValidator: QueryValidator {
|
||||
/// let allowedTables: Set<String>
|
||||
///
|
||||
/// func validate(sql: String, operation: SQLOperation) throws {
|
||||
/// let upper = sql.uppercased()
|
||||
/// for table in allowedTables {
|
||||
/// // Simple check — real implementation might parse FROM/JOIN clauses
|
||||
/// if upper.contains(table.uppercased()) { return }
|
||||
/// }
|
||||
/// throw QueryValidationError.rejected("Query references tables outside the allowlist.")
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
public protocol QueryValidator: Sendable {
|
||||
/// Validates a SQL query before execution.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - sql: The cleaned SQL statement about to be executed.
|
||||
/// - operation: The detected operation type (SELECT, INSERT, etc.).
|
||||
/// - Throws: ``QueryValidationError`` or any `Error` to reject the query.
|
||||
func validate(sql: String, operation: SQLOperation) throws
|
||||
}
|
||||
|
||||
/// Errors thrown by custom ``QueryValidator`` implementations.
|
||||
public enum QueryValidationError: Error, LocalizedError, Sendable, Equatable {
|
||||
/// The query was rejected by a custom validator with the given reason.
|
||||
case rejected(String)
|
||||
|
||||
public var errorDescription: String? {
|
||||
switch self {
|
||||
case .rejected(let reason):
|
||||
return "Query rejected: \(reason)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Built-in Validators
|
||||
|
||||
/// A validator that restricts queries to a specific set of table names.
|
||||
///
|
||||
/// This performs a simple keyword check — it verifies that the SQL references
|
||||
/// at least one of the allowed tables. This is a best-effort check, not a
|
||||
/// full SQL parser.
|
||||
public struct TableAllowlistValidator: QueryValidator {
|
||||
/// The set of table names queries are allowed to reference.
|
||||
public let allowedTables: Set<String>
|
||||
|
||||
/// Creates a validator with the given allowed table names.
|
||||
public init(allowedTables: Set<String>) {
|
||||
self.allowedTables = allowedTables
|
||||
}
|
||||
|
||||
public func validate(sql: String, operation: SQLOperation) throws {
|
||||
let upper = sql.uppercased()
|
||||
let found = allowedTables.contains { table in
|
||||
let pattern = table.uppercased()
|
||||
return upper.contains(pattern)
|
||||
}
|
||||
guard found else {
|
||||
throw QueryValidationError.rejected(
|
||||
"Query does not reference any allowed tables: \(allowedTables.sorted().joined(separator: ", "))"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A validator that enforces a maximum row limit on SELECT queries
|
||||
/// by checking for a LIMIT clause.
|
||||
public struct MaxRowLimitValidator: QueryValidator {
|
||||
/// The maximum number of rows allowed.
|
||||
public let maxRows: Int
|
||||
|
||||
/// Creates a validator that requires SELECT queries to include a LIMIT clause
|
||||
/// not exceeding `maxRows`.
|
||||
public init(maxRows: Int) {
|
||||
self.maxRows = maxRows
|
||||
}
|
||||
|
||||
public func validate(sql: String, operation: SQLOperation) throws {
|
||||
guard operation == .select else { return }
|
||||
|
||||
let upper = sql.uppercased()
|
||||
// Check if LIMIT is present
|
||||
guard let limitRange = upper.range(of: #"LIMIT\s+(\d+)"#, options: .regularExpression) else {
|
||||
throw QueryValidationError.rejected(
|
||||
"SELECT queries must include a LIMIT clause (max \(maxRows) rows)."
|
||||
)
|
||||
}
|
||||
|
||||
// Extract the limit value
|
||||
let limitSubstring = upper[limitRange]
|
||||
let digits = limitSubstring.components(separatedBy: .decimalDigits.inverted).joined()
|
||||
if let value = Int(digits), value > maxRows {
|
||||
throw QueryValidationError.rejected(
|
||||
"LIMIT \(value) exceeds the maximum allowed (\(maxRows))."
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
677
Sources/SwiftDBAI/Engine/ChatEngine.swift
Normal file
677
Sources/SwiftDBAI/Engine/ChatEngine.swift
Normal file
@@ -0,0 +1,677 @@
|
||||
// ChatEngine.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Orchestrates the conversation loop: user message → SQL generation → query
|
||||
// execution → result summarization → response.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
import GRDB
|
||||
|
||||
/// A message in the chat conversation.
|
||||
public struct ChatMessage: Sendable, Identifiable, Equatable {
|
||||
public let id: UUID
|
||||
public let role: Role
|
||||
public let content: String
|
||||
public let queryResult: QueryResult?
|
||||
public let sql: String?
|
||||
public let timestamp: Date
|
||||
/// The typed error, if this is an error message.
|
||||
public let error: SwiftDBAIError?
|
||||
|
||||
public enum Role: String, Sendable, Equatable {
|
||||
case user
|
||||
case assistant
|
||||
case error
|
||||
}
|
||||
|
||||
public init(
|
||||
id: UUID = UUID(),
|
||||
role: Role,
|
||||
content: String,
|
||||
queryResult: QueryResult? = nil,
|
||||
sql: String? = nil,
|
||||
timestamp: Date = Date(),
|
||||
error: SwiftDBAIError? = nil
|
||||
) {
|
||||
self.id = id
|
||||
self.role = role
|
||||
self.content = content
|
||||
self.queryResult = queryResult
|
||||
self.sql = sql
|
||||
self.timestamp = timestamp
|
||||
self.error = error
|
||||
}
|
||||
}
|
||||
|
||||
/// The response returned by `ChatEngine.send(_:)`.
|
||||
public struct ChatResponse: Sendable {
|
||||
/// The natural language summary of the result.
|
||||
public let summary: String
|
||||
|
||||
/// The SQL that was generated and executed, if any.
|
||||
public let sql: String?
|
||||
|
||||
/// The raw query result, if a query was executed.
|
||||
public let queryResult: QueryResult?
|
||||
}
|
||||
|
||||
/// Headless engine that orchestrates the full chat-with-database pipeline.
|
||||
///
|
||||
/// The engine:
|
||||
/// 1. Introspects the database schema (once, lazily)
|
||||
/// 2. Builds a system prompt with schema context
|
||||
/// 3. Sends the user's question to the LLM to generate SQL
|
||||
/// 4. Validates the SQL against the operation allowlist
|
||||
/// 5. Executes the SQL via GRDB
|
||||
/// 6. Summarizes results using `TextSummaryRenderer`
|
||||
/// 7. Returns the summary (and raw data) to the caller
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// let engine = ChatEngine(
|
||||
/// database: myDatabasePool,
|
||||
/// model: myLanguageModel
|
||||
/// )
|
||||
/// let response = try await engine.send("How many users signed up this week?")
|
||||
/// print(response.summary) // "There were 42 new signups this week."
|
||||
/// ```
|
||||
public final class ChatEngine: @unchecked Sendable {
|
||||
|
||||
// MARK: - Dependencies
|
||||
|
||||
private let database: any DatabaseWriter
|
||||
private let model: any LanguageModel
|
||||
private let allowlist: OperationAllowlist
|
||||
private let mutationPolicy: MutationPolicy?
|
||||
private let configuration: ChatEngineConfiguration
|
||||
private let summaryRenderer: TextSummaryRenderer
|
||||
private let sqlParser: SQLQueryParser
|
||||
|
||||
/// Optional delegate for intercepting destructive operations and observing SQL execution.
|
||||
private let delegate: (any ToolExecutionDelegate)?
|
||||
|
||||
// MARK: - State
|
||||
|
||||
private var schema: DatabaseSchema?
|
||||
private var conversationHistory: [ChatMessage] = []
|
||||
private let lock = NSLock()
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
/// Creates a new ChatEngine with a full configuration object.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||
/// - model: Any `AnyLanguageModel`-compatible language model.
|
||||
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only (SELECT only).
|
||||
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
|
||||
/// - delegate: Optional delegate for confirming destructive operations and observing SQL execution.
|
||||
public init(
|
||||
database: any DatabaseWriter,
|
||||
model: any LanguageModel,
|
||||
allowlist: OperationAllowlist = .readOnly,
|
||||
configuration: ChatEngineConfiguration = .default,
|
||||
delegate: (any ToolExecutionDelegate)? = nil
|
||||
) {
|
||||
self.database = database
|
||||
self.model = model
|
||||
self.allowlist = allowlist
|
||||
self.mutationPolicy = nil
|
||||
self.configuration = configuration
|
||||
self.delegate = delegate
|
||||
self.summaryRenderer = TextSummaryRenderer(
|
||||
model: model,
|
||||
maxRowsInPrompt: configuration.maxSummaryRows
|
||||
)
|
||||
self.sqlParser = SQLQueryParser(allowlist: allowlist)
|
||||
}
|
||||
|
||||
/// Creates a new ChatEngine with a `MutationPolicy` for table-level control.
|
||||
///
|
||||
/// This initializer provides fine-grained control over which mutations are
|
||||
/// allowed on which tables. The policy's operation allowlist is used for
|
||||
/// SQL validation, and table-level restrictions are enforced during parsing.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||
/// - model: Any `AnyLanguageModel`-compatible language model.
|
||||
/// - mutationPolicy: Controls which operations are allowed on which tables.
|
||||
/// - configuration: Engine configuration for timeouts, context window, validators, etc.
|
||||
/// - delegate: Optional delegate for confirming destructive operations and observing SQL execution.
|
||||
public init(
|
||||
database: any DatabaseWriter,
|
||||
model: any LanguageModel,
|
||||
mutationPolicy: MutationPolicy,
|
||||
configuration: ChatEngineConfiguration = .default,
|
||||
delegate: (any ToolExecutionDelegate)? = nil
|
||||
) {
|
||||
self.database = database
|
||||
self.model = model
|
||||
self.allowlist = mutationPolicy.operationAllowlist
|
||||
self.mutationPolicy = mutationPolicy
|
||||
self.configuration = configuration
|
||||
self.delegate = delegate
|
||||
self.summaryRenderer = TextSummaryRenderer(
|
||||
model: model,
|
||||
maxRowsInPrompt: configuration.maxSummaryRows
|
||||
)
|
||||
self.sqlParser = SQLQueryParser(mutationPolicy: mutationPolicy)
|
||||
}
|
||||
|
||||
/// Creates a new ChatEngine with individual parameters (convenience).
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - database: A GRDB `DatabaseWriter` (DatabasePool or DatabaseQueue).
|
||||
/// - model: Any `AnyLanguageModel`-compatible language model.
|
||||
/// - allowlist: SQL operations the LLM may generate. Defaults to read-only (SELECT only).
|
||||
/// - additionalContext: Optional extra instructions for the LLM system prompt.
|
||||
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
|
||||
public convenience init(
|
||||
database: any DatabaseWriter,
|
||||
model: any LanguageModel,
|
||||
allowlist: OperationAllowlist,
|
||||
additionalContext: String?,
|
||||
maxSummaryRows: Int = 50
|
||||
) {
|
||||
let config = ChatEngineConfiguration(
|
||||
maxSummaryRows: maxSummaryRows,
|
||||
additionalContext: additionalContext
|
||||
)
|
||||
self.init(
|
||||
database: database,
|
||||
model: model,
|
||||
allowlist: allowlist,
|
||||
configuration: config
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Public API
|
||||
|
||||
/// Sends a natural language message and returns a summarized response.
|
||||
///
|
||||
/// This is the primary entry point. The engine will:
|
||||
/// 1. Introspect the schema if not yet cached
|
||||
/// 2. Ask the LLM to generate SQL
|
||||
/// 3. Validate the SQL against the allowlist and custom validators
|
||||
/// 4. Execute the SQL (with timeout if configured)
|
||||
/// 5. Summarize the results using `TextSummaryRenderer`
|
||||
///
|
||||
/// All errors are caught and mapped to a distinct ``SwiftDBAIError`` case
|
||||
/// so callers always receive a typed, user-friendly error with a localized
|
||||
/// description suitable for display in a chat UI.
|
||||
///
|
||||
/// - Parameter message: The user's natural language question or command.
|
||||
/// - Returns: A `ChatResponse` containing the summary, SQL, and raw result.
|
||||
/// - Throws: ``SwiftDBAIError`` for every failure mode.
|
||||
public func send(_ message: String) async throws -> ChatResponse {
|
||||
// 1. Ensure schema is introspected
|
||||
let schema: DatabaseSchema
|
||||
do {
|
||||
schema = try await ensureSchema()
|
||||
} catch let error as SwiftDBAIError {
|
||||
throw error
|
||||
} catch {
|
||||
throw SwiftDBAIError.schemaIntrospectionFailed(reason: error.localizedDescription)
|
||||
}
|
||||
|
||||
// Check for empty schema
|
||||
if schema.tableNames.isEmpty {
|
||||
throw SwiftDBAIError.emptySchema
|
||||
}
|
||||
|
||||
// 2. Build prompt and get raw LLM response
|
||||
let promptBuilder = PromptBuilder(
|
||||
schema: schema,
|
||||
allowlist: allowlist,
|
||||
additionalContext: configuration.additionalContext
|
||||
)
|
||||
|
||||
let rawLLMResponse: String
|
||||
do {
|
||||
rawLLMResponse = try await generateRawResponse(
|
||||
question: message,
|
||||
promptBuilder: promptBuilder
|
||||
)
|
||||
} catch let error as SwiftDBAIError {
|
||||
throw error
|
||||
} catch {
|
||||
throw SwiftDBAIError.llmFailure(reason: error.localizedDescription)
|
||||
}
|
||||
|
||||
// 3. Parse and validate SQL through SQLQueryParser
|
||||
let parsed: ParsedSQL
|
||||
do {
|
||||
parsed = try sqlParser.parse(rawLLMResponse)
|
||||
} catch let error as SQLParsingError {
|
||||
throw error.toSwiftDBAIError(rawResponse: rawLLMResponse)
|
||||
} catch let error as SwiftDBAIError {
|
||||
throw error
|
||||
} catch {
|
||||
throw SwiftDBAIError.invalidSQL(sql: rawLLMResponse, reason: error.localizedDescription)
|
||||
}
|
||||
|
||||
// 4. Run custom validators
|
||||
do {
|
||||
try runCustomValidators(parsed: parsed)
|
||||
} catch let error as QueryValidationError {
|
||||
throw error
|
||||
} catch let error as SwiftDBAIError {
|
||||
throw error
|
||||
} catch {
|
||||
throw SwiftDBAIError.queryRejected(reason: error.localizedDescription)
|
||||
}
|
||||
|
||||
// 5. Handle confirmation-required operations (DELETE, DROP, etc.)
|
||||
if parsed.requiresConfirmation {
|
||||
if let delegate = self.delegate {
|
||||
// Build context for the delegate
|
||||
let classification = classifySQL(parsed.sql)
|
||||
let context = DestructiveOperationContext(
|
||||
sql: parsed.sql,
|
||||
statementKind: detectStatementKind(parsed.sql) ?? .delete,
|
||||
classification: classification,
|
||||
description: "Execute \(parsed.operation.rawValue.uppercased()) operation: \(parsed.sql)",
|
||||
targetTable: extractTargetTableForDelegate(from: parsed.sql, operation: parsed.operation)
|
||||
)
|
||||
// Ask the delegate for approval
|
||||
let approved = await delegate.confirmDestructiveOperation(context)
|
||||
if !approved {
|
||||
throw SwiftDBAIError.confirmationRequired(
|
||||
sql: parsed.sql,
|
||||
operation: parsed.operation.rawValue
|
||||
)
|
||||
}
|
||||
// Delegate approved — fall through to execution
|
||||
} else {
|
||||
// No delegate — throw confirmation required so caller can handle it
|
||||
throw SwiftDBAIError.confirmationRequired(
|
||||
sql: parsed.sql,
|
||||
operation: parsed.operation.rawValue
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 6. Execute the SQL (with timeout if configured)
|
||||
let result: QueryResult
|
||||
do {
|
||||
let classification = classifySQL(parsed.sql)
|
||||
await delegate?.willExecuteSQL(parsed.sql, classification: classification)
|
||||
result = try await executeSQLWithTimeout(parsed.sql)
|
||||
await delegate?.didExecuteSQL(parsed.sql, success: true)
|
||||
} catch let error as SwiftDBAIError {
|
||||
await delegate?.didExecuteSQL(parsed.sql, success: false)
|
||||
throw error
|
||||
} catch let error as ChatEngineError {
|
||||
await delegate?.didExecuteSQL(parsed.sql, success: false)
|
||||
// Map internal ChatEngineError (e.g. from timeout) to SwiftDBAIError
|
||||
throw error.toSwiftDBAIError()
|
||||
} catch {
|
||||
await delegate?.didExecuteSQL(parsed.sql, success: false)
|
||||
throw SwiftDBAIError.databaseError(reason: error.localizedDescription)
|
||||
}
|
||||
|
||||
// 7. Summarize the result using TextSummaryRenderer
|
||||
let summary: String
|
||||
do {
|
||||
summary = try await summaryRenderer.summarize(
|
||||
result: result,
|
||||
userQuestion: message
|
||||
)
|
||||
} catch let error as SwiftDBAIError {
|
||||
throw error
|
||||
} catch {
|
||||
throw SwiftDBAIError.llmFailure(reason: "Summarization failed: \(error.localizedDescription)")
|
||||
}
|
||||
|
||||
// 8. Record conversation history
|
||||
let userMessage = ChatMessage(role: .user, content: message)
|
||||
let assistantMessage = ChatMessage(
|
||||
role: .assistant,
|
||||
content: summary,
|
||||
queryResult: result,
|
||||
sql: parsed.sql
|
||||
)
|
||||
lock.withLock {
|
||||
conversationHistory.append(userMessage)
|
||||
conversationHistory.append(assistantMessage)
|
||||
}
|
||||
|
||||
return ChatResponse(
|
||||
summary: summary,
|
||||
sql: parsed.sql,
|
||||
queryResult: result
|
||||
)
|
||||
}
|
||||
|
||||
/// Sends a natural language message, executing a previously confirmed destructive operation.
|
||||
///
|
||||
/// Call this after receiving a `confirmationRequired` error and the user has confirmed.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - message: The original user message (for history recording).
|
||||
/// - confirmedSQL: The SQL that was confirmed by the user.
|
||||
/// - Returns: A `ChatResponse` with the result.
|
||||
public func sendConfirmed(_ message: String, confirmedSQL: String) async throws -> ChatResponse {
|
||||
let result: QueryResult
|
||||
do {
|
||||
let classification = classifySQL(confirmedSQL)
|
||||
await delegate?.willExecuteSQL(confirmedSQL, classification: classification)
|
||||
result = try await executeSQLWithTimeout(confirmedSQL)
|
||||
await delegate?.didExecuteSQL(confirmedSQL, success: true)
|
||||
} catch let error as SwiftDBAIError {
|
||||
await delegate?.didExecuteSQL(confirmedSQL, success: false)
|
||||
throw error
|
||||
} catch let error as ChatEngineError {
|
||||
await delegate?.didExecuteSQL(confirmedSQL, success: false)
|
||||
throw error.toSwiftDBAIError()
|
||||
} catch {
|
||||
await delegate?.didExecuteSQL(confirmedSQL, success: false)
|
||||
throw SwiftDBAIError.databaseError(reason: error.localizedDescription)
|
||||
}
|
||||
|
||||
let summary: String
|
||||
do {
|
||||
summary = try await summaryRenderer.summarize(
|
||||
result: result,
|
||||
userQuestion: message
|
||||
)
|
||||
} catch let error as SwiftDBAIError {
|
||||
throw error
|
||||
} catch {
|
||||
throw SwiftDBAIError.llmFailure(reason: "Summarization failed: \(error.localizedDescription)")
|
||||
}
|
||||
|
||||
let userMessage = ChatMessage(role: .user, content: message)
|
||||
let assistantMessage = ChatMessage(
|
||||
role: .assistant,
|
||||
content: summary,
|
||||
queryResult: result,
|
||||
sql: confirmedSQL
|
||||
)
|
||||
lock.withLock {
|
||||
conversationHistory.append(userMessage)
|
||||
conversationHistory.append(assistantMessage)
|
||||
}
|
||||
|
||||
return ChatResponse(
|
||||
summary: summary,
|
||||
sql: confirmedSQL,
|
||||
queryResult: result
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns the current conversation history.
|
||||
public var messages: [ChatMessage] {
|
||||
lock.lock()
|
||||
defer { lock.unlock() }
|
||||
return conversationHistory
|
||||
}
|
||||
|
||||
/// Eagerly introspects the database schema so it's ready before the first query.
|
||||
///
|
||||
/// Call this at view-appear time to pre-warm the schema cache. If the schema
|
||||
/// is already cached, this returns immediately. The returned `DatabaseSchema`
|
||||
/// can be used to display table/column info in the UI.
|
||||
///
|
||||
/// - Returns: The introspected `DatabaseSchema`.
|
||||
@discardableResult
|
||||
public func prepareSchema() async throws -> DatabaseSchema {
|
||||
try await ensureSchema()
|
||||
}
|
||||
|
||||
/// The number of tables discovered during schema introspection.
|
||||
/// Returns `nil` if the schema has not been introspected yet.
|
||||
public var tableCount: Int? {
|
||||
lock.withLock { schema?.tableNames.count }
|
||||
}
|
||||
|
||||
/// The cached schema, if introspection has completed.
|
||||
public var cachedSchema: DatabaseSchema? {
|
||||
lock.withLock { schema }
|
||||
}
|
||||
|
||||
/// Clears the conversation history and cached schema.
|
||||
///
|
||||
/// After calling this, the next `send(_:)` call will re-introspect the
|
||||
/// schema. Use ``clearHistory()`` if you only want to reset the conversation
|
||||
/// while keeping the cached schema.
|
||||
public func reset() {
|
||||
lock.withLock {
|
||||
conversationHistory.removeAll()
|
||||
schema = nil
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears only the conversation history, keeping the cached schema.
|
||||
///
|
||||
/// This is useful when you want to start a fresh conversation thread
|
||||
/// without re-introspecting the database. The schema cache remains valid
|
||||
/// as long as the database structure hasn't changed.
|
||||
public func clearHistory() {
|
||||
lock.withLock {
|
||||
conversationHistory.removeAll()
|
||||
}
|
||||
}
|
||||
|
||||
/// The current engine configuration.
|
||||
public var currentConfiguration: ChatEngineConfiguration {
|
||||
configuration
|
||||
}
|
||||
|
||||
// MARK: - Internal Helpers (visible for testing)
|
||||
|
||||
/// Ensures the database schema is introspected and cached.
|
||||
func ensureSchema() async throws -> DatabaseSchema {
|
||||
if let cached = lock.withLock({ schema }) {
|
||||
return cached
|
||||
}
|
||||
|
||||
let introspected = try await SchemaIntrospector.introspect(database: database)
|
||||
|
||||
lock.withLock { schema = introspected }
|
||||
|
||||
return introspected
|
||||
}
|
||||
|
||||
/// Asks the LLM to generate SQL from a natural language question.
|
||||
/// Returns the raw LLM response text (before parsing).
|
||||
///
|
||||
/// Uses the configured ``ChatEngineConfiguration/contextWindowSize`` to limit
|
||||
/// how many conversation messages are included as context for the LLM.
|
||||
private func generateRawResponse(
|
||||
question: String,
|
||||
promptBuilder: PromptBuilder
|
||||
) async throws -> String {
|
||||
let instructions = promptBuilder.buildSystemInstructions()
|
||||
|
||||
// Build user prompt — include full conversation history for follow-ups
|
||||
// Respect context window: only use recent messages for context
|
||||
let userPrompt: String
|
||||
let historySlice = lock.withLock { () -> [ChatMessage] in
|
||||
Array(contextWindowSlice())
|
||||
}
|
||||
|
||||
if historySlice.isEmpty {
|
||||
userPrompt = promptBuilder.buildUserPrompt(question)
|
||||
} else {
|
||||
userPrompt = promptBuilder.buildConversationPrompt(
|
||||
question,
|
||||
history: historySlice
|
||||
)
|
||||
}
|
||||
|
||||
let session = LanguageModelSession(
|
||||
model: model,
|
||||
instructions: instructions + "\n\nRespond with ONLY the SQL query. No explanations, no markdown, no code fences."
|
||||
)
|
||||
|
||||
let response = try await session.respond(to: userPrompt)
|
||||
return response.content.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
}
|
||||
|
||||
/// Returns the conversation history slice within the configured context window.
|
||||
/// Must be called within a `lock.withLock` closure.
|
||||
private func contextWindowSlice() -> ArraySlice<ChatMessage> {
|
||||
guard let windowSize = configuration.contextWindowSize else {
|
||||
return conversationHistory[...]
|
||||
}
|
||||
let count = conversationHistory.count
|
||||
let start = max(0, count - windowSize)
|
||||
return conversationHistory[start...]
|
||||
}
|
||||
|
||||
/// Runs all custom validators from the configuration against the parsed SQL.
|
||||
private func runCustomValidators(parsed: ParsedSQL) throws {
|
||||
for validator in configuration.validators {
|
||||
try validator.validate(sql: parsed.sql, operation: parsed.operation)
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts the target table name from a SQL statement for delegate context.
|
||||
private func extractTargetTableForDelegate(from sql: String, operation: SQLOperation) -> String? {
|
||||
let pattern: String
|
||||
switch operation {
|
||||
case .insert:
|
||||
pattern = #"INSERT\s+INTO\s+[`"\[]?(\w+)[`"\]]?"#
|
||||
case .update:
|
||||
pattern = #"UPDATE\s+[`"\[]?(\w+)[`"\]]?"#
|
||||
case .delete:
|
||||
pattern = #"DELETE\s+FROM\s+[`"\[]?(\w+)[`"\]]?"#
|
||||
case .select:
|
||||
return nil
|
||||
}
|
||||
guard let regex = try? NSRegularExpression(pattern: pattern, options: .caseInsensitive) else {
|
||||
return nil
|
||||
}
|
||||
let range = NSRange(sql.startIndex..., in: sql)
|
||||
guard let match = regex.firstMatch(in: sql, range: range),
|
||||
match.numberOfRanges > 1,
|
||||
let groupRange = Range(match.range(at: 1), in: sql) else {
|
||||
return nil
|
||||
}
|
||||
return String(sql[groupRange])
|
||||
}
|
||||
|
||||
/// Executes SQL with the configured timeout, if any.
|
||||
private func executeSQLWithTimeout(_ sql: String) async throws -> QueryResult {
|
||||
guard let timeout = configuration.queryTimeout else {
|
||||
return try await executeSQL(sql)
|
||||
}
|
||||
|
||||
return try await withThrowingTaskGroup(of: QueryResult.self) { group in
|
||||
group.addTask {
|
||||
try await self.executeSQL(sql)
|
||||
}
|
||||
|
||||
group.addTask {
|
||||
try await Task.sleep(for: .seconds(timeout))
|
||||
throw ChatEngineError.queryTimedOut(seconds: timeout)
|
||||
}
|
||||
|
||||
// Return whichever finishes first
|
||||
let result = try await group.next()!
|
||||
group.cancelAll()
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
/// Executes SQL against the database and returns a `QueryResult`.
|
||||
private func executeSQL(_ sql: String) async throws -> QueryResult {
|
||||
let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
|
||||
let isSelect = trimmed.hasPrefix("SELECT") || trimmed.hasPrefix("WITH")
|
||||
|
||||
let startTime = CFAbsoluteTimeGetCurrent()
|
||||
|
||||
if isSelect {
|
||||
let result = try await database.read { db -> (columns: [String], rows: [[String: QueryResult.Value]]) in
|
||||
let statement = try db.makeStatement(sql: sql)
|
||||
let columnNames = statement.columnNames
|
||||
|
||||
var rows: [[String: QueryResult.Value]] = []
|
||||
let cursor = try Row.fetchCursor(statement)
|
||||
while let row = try cursor.next() {
|
||||
var dict: [String: QueryResult.Value] = [:]
|
||||
for col in columnNames {
|
||||
dict[col] = Self.extractValue(row: row, column: col)
|
||||
}
|
||||
rows.append(dict)
|
||||
}
|
||||
return (columns: columnNames, rows: rows)
|
||||
}
|
||||
|
||||
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
|
||||
|
||||
return QueryResult(
|
||||
columns: result.columns,
|
||||
rows: result.rows,
|
||||
sql: sql,
|
||||
executionTime: elapsed
|
||||
)
|
||||
} else {
|
||||
// Mutation query
|
||||
let affected = try await database.write { db -> Int in
|
||||
try db.execute(sql: sql)
|
||||
return db.changesCount
|
||||
}
|
||||
|
||||
let elapsed = CFAbsoluteTimeGetCurrent() - startTime
|
||||
|
||||
return QueryResult(
|
||||
columns: [],
|
||||
rows: [],
|
||||
sql: sql,
|
||||
executionTime: elapsed,
|
||||
rowsAffected: affected
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts a `QueryResult.Value` from a GRDB `Row` for the given column.
|
||||
private static func extractValue(row: Row, column: String) -> QueryResult.Value {
|
||||
let dbValue: DatabaseValue = row[column]
|
||||
switch dbValue.storage {
|
||||
case .null:
|
||||
return .null
|
||||
case .int64(let i):
|
||||
return .integer(i)
|
||||
case .double(let d):
|
||||
return .real(d)
|
||||
case .string(let s):
|
||||
return .text(s)
|
||||
case .blob(let data):
|
||||
return .blob(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Errors
|
||||
|
||||
/// Errors that can occur during ChatEngine operations.
|
||||
public enum ChatEngineError: Error, LocalizedError, Sendable {
|
||||
/// SQL parsing/extraction from LLM response failed.
|
||||
case sqlParsingFailed(SQLParsingError)
|
||||
/// A destructive operation requires user confirmation before execution.
|
||||
case confirmationRequired(sql: String, operation: SQLOperation)
|
||||
/// Schema introspection failed.
|
||||
case schemaIntrospectionFailed(String)
|
||||
/// The SQL query exceeded the configured timeout.
|
||||
case queryTimedOut(seconds: TimeInterval)
|
||||
/// A custom query validator rejected the query.
|
||||
case validationFailed(String)
|
||||
|
||||
public var errorDescription: String? {
|
||||
switch self {
|
||||
case .sqlParsingFailed(let parsingError):
|
||||
return "SQL parsing failed: \(parsingError.description)"
|
||||
case .confirmationRequired(let sql, let op):
|
||||
return "The \(op.rawValue.uppercased()) operation requires confirmation: \(sql)"
|
||||
case .schemaIntrospectionFailed(let reason):
|
||||
return "Failed to introspect database schema: \(reason)"
|
||||
case .queryTimedOut(let seconds):
|
||||
return "Query timed out after \(Int(seconds)) seconds."
|
||||
case .validationFailed(let reason):
|
||||
return "Query validation failed: \(reason)"
|
||||
}
|
||||
}
|
||||
}
|
||||
288
Sources/SwiftDBAI/Engine/ToolExecutionDelegate.swift
Normal file
288
Sources/SwiftDBAI/Engine/ToolExecutionDelegate.swift
Normal file
@@ -0,0 +1,288 @@
|
||||
// ToolExecutionDelegate.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Delegate protocol for controlling SQL tool execution, including
|
||||
// confirmation of destructive operations before they reach the database.
|
||||
|
||||
import Foundation
|
||||
|
||||
// MARK: - Destructive SQL Classification
|
||||
|
||||
/// Classifies SQL statements by their destructive potential.
|
||||
///
|
||||
/// A statement is considered **destructive** if it modifies or removes data
|
||||
/// or schema objects. The classification drives the confirmation flow:
|
||||
/// destructive statements require explicit user approval via
|
||||
/// ``ToolExecutionDelegate/confirmDestructiveOperation(_:)``.
|
||||
public enum DestructiveClassification: Sendable, Equatable {
|
||||
/// The statement is read-only (e.g. SELECT). No confirmation needed.
|
||||
case safe
|
||||
|
||||
/// The statement modifies existing data (INSERT, UPDATE).
|
||||
case mutation(SQLStatementKind)
|
||||
|
||||
/// The statement deletes data or alters/drops schema objects.
|
||||
/// These always require confirmation, even when the operation is allowed.
|
||||
case destructive(SQLStatementKind)
|
||||
|
||||
/// Returns `true` when the statement requires user confirmation.
|
||||
public var requiresConfirmation: Bool {
|
||||
switch self {
|
||||
case .safe:
|
||||
return false
|
||||
case .mutation:
|
||||
return false
|
||||
case .destructive:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` when the statement modifies data or schema in any way.
|
||||
public var isMutating: Bool {
|
||||
switch self {
|
||||
case .safe:
|
||||
return false
|
||||
case .mutation, .destructive:
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The kind of SQL statement, used for classification and display.
|
||||
public enum SQLStatementKind: String, Sendable, Hashable, CaseIterable {
|
||||
case select = "SELECT"
|
||||
case insert = "INSERT"
|
||||
case update = "UPDATE"
|
||||
case delete = "DELETE"
|
||||
case drop = "DROP"
|
||||
case alter = "ALTER"
|
||||
case truncate = "TRUNCATE"
|
||||
|
||||
/// All kinds that are classified as destructive.
|
||||
public static let destructiveKinds: Set<SQLStatementKind> = [
|
||||
.delete, .drop, .alter, .truncate
|
||||
]
|
||||
|
||||
/// All kinds that are classified as mutations (data-modifying but not destructive).
|
||||
public static let mutationKinds: Set<SQLStatementKind> = [
|
||||
.insert, .update
|
||||
]
|
||||
|
||||
/// Whether this kind of statement is destructive.
|
||||
public var isDestructive: Bool {
|
||||
Self.destructiveKinds.contains(self)
|
||||
}
|
||||
|
||||
/// Whether this kind of statement is a mutation (INSERT/UPDATE).
|
||||
public var isMutation: Bool {
|
||||
Self.mutationKinds.contains(self)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Classification Function
|
||||
|
||||
/// Classifies a SQL statement string by its destructive potential.
|
||||
///
|
||||
/// The classifier inspects the first keyword token of the statement
|
||||
/// (case-insensitive) to determine the statement kind, then maps it
|
||||
/// to a ``DestructiveClassification``.
|
||||
///
|
||||
/// - Parameter sql: The SQL statement to classify.
|
||||
/// - Returns: The classification for the statement.
|
||||
public func classifySQL(_ sql: String) -> DestructiveClassification {
|
||||
guard let kind = detectStatementKind(sql) else {
|
||||
return .safe
|
||||
}
|
||||
|
||||
if kind.isDestructive {
|
||||
return .destructive(kind)
|
||||
} else if kind.isMutation {
|
||||
return .mutation(kind)
|
||||
} else {
|
||||
return .safe
|
||||
}
|
||||
}
|
||||
|
||||
/// Detects the ``SQLStatementKind`` from the leading keyword of a SQL string.
|
||||
///
|
||||
/// - Parameter sql: The SQL statement to inspect.
|
||||
/// - Returns: The detected kind, or `nil` if unrecognized.
|
||||
public func detectStatementKind(_ sql: String) -> SQLStatementKind? {
|
||||
let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
|
||||
|
||||
// Check each known statement kind against the first token
|
||||
if trimmed.hasPrefix("SELECT") || trimmed.hasPrefix("WITH") {
|
||||
return .select
|
||||
} else if trimmed.hasPrefix("INSERT") {
|
||||
return .insert
|
||||
} else if trimmed.hasPrefix("UPDATE") {
|
||||
return .update
|
||||
} else if trimmed.hasPrefix("DELETE") {
|
||||
return .delete
|
||||
} else if trimmed.hasPrefix("DROP") {
|
||||
return .drop
|
||||
} else if trimmed.hasPrefix("ALTER") {
|
||||
return .alter
|
||||
} else if trimmed.hasPrefix("TRUNCATE") {
|
||||
return .truncate
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MARK: - Destructive Operation Context
|
||||
|
||||
/// Context provided to the delegate when a destructive operation needs confirmation.
|
||||
///
|
||||
/// Contains all the information a UI or programmatic handler needs to
|
||||
/// decide whether to allow the operation.
|
||||
public struct DestructiveOperationContext: Sendable {
|
||||
/// The SQL statement that would be executed.
|
||||
public let sql: String
|
||||
|
||||
/// The detected kind of statement (DELETE, DROP, ALTER, TRUNCATE).
|
||||
public let statementKind: SQLStatementKind
|
||||
|
||||
/// The classification result.
|
||||
public let classification: DestructiveClassification
|
||||
|
||||
/// A human-readable description of what the operation will do.
|
||||
public let description: String
|
||||
|
||||
/// The target table name, if detected.
|
||||
public let targetTable: String?
|
||||
|
||||
public init(
|
||||
sql: String,
|
||||
statementKind: SQLStatementKind,
|
||||
classification: DestructiveClassification,
|
||||
description: String,
|
||||
targetTable: String? = nil
|
||||
) {
|
||||
self.sql = sql
|
||||
self.statementKind = statementKind
|
||||
self.classification = classification
|
||||
self.description = description
|
||||
self.targetTable = targetTable
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - ToolExecutionDelegate Protocol
|
||||
|
||||
/// A delegate that controls execution of SQL operations, providing
|
||||
/// confirmation gates for destructive statements.
|
||||
///
|
||||
/// Implement this protocol to intercept destructive SQL operations
|
||||
/// (DELETE, DROP, ALTER, TRUNCATE) before they are executed. The
|
||||
/// ``ChatEngine`` consults the delegate whenever it encounters a
|
||||
/// statement classified as ``DestructiveClassification/destructive(_:)``.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```swift
|
||||
/// struct MyDelegate: ToolExecutionDelegate {
|
||||
/// func confirmDestructiveOperation(
|
||||
/// _ context: DestructiveOperationContext
|
||||
/// ) async -> Bool {
|
||||
/// // Show a confirmation dialog to the user
|
||||
/// return await showAlert(
|
||||
/// "Confirm \(context.statementKind.rawValue)",
|
||||
/// message: context.description
|
||||
/// )
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
/// let engine = ChatEngine(
|
||||
/// database: pool,
|
||||
/// model: model,
|
||||
/// delegate: MyDelegate()
|
||||
/// )
|
||||
/// ```
|
||||
public protocol ToolExecutionDelegate: Sendable {
|
||||
|
||||
/// Called when a destructive SQL operation is about to be executed.
|
||||
///
|
||||
/// The delegate should present the operation details to the user and
|
||||
/// return `true` to proceed or `false` to cancel.
|
||||
///
|
||||
/// - Parameter context: Details about the destructive operation.
|
||||
/// - Returns: `true` to allow execution, `false` to reject it.
|
||||
func confirmDestructiveOperation(
|
||||
_ context: DestructiveOperationContext
|
||||
) async -> Bool
|
||||
|
||||
/// Called before any SQL statement is executed.
|
||||
///
|
||||
/// This is an observation hook — the engine does not wait for a
|
||||
/// decision. Override to log, audit, or instrument queries.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - sql: The SQL about to be executed.
|
||||
/// - classification: The destructive classification of the statement.
|
||||
func willExecuteSQL(
|
||||
_ sql: String,
|
||||
classification: DestructiveClassification
|
||||
) async
|
||||
|
||||
/// Called after a SQL statement completes execution.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - sql: The SQL that was executed.
|
||||
/// - success: Whether execution succeeded.
|
||||
func didExecuteSQL(
|
||||
_ sql: String,
|
||||
success: Bool
|
||||
) async
|
||||
}
|
||||
|
||||
// MARK: - Default Implementations
|
||||
|
||||
extension ToolExecutionDelegate {
|
||||
/// Default: rejects all destructive operations.
|
||||
public func confirmDestructiveOperation(
|
||||
_ context: DestructiveOperationContext
|
||||
) async -> Bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Default: no-op.
|
||||
public func willExecuteSQL(
|
||||
_ sql: String,
|
||||
classification: DestructiveClassification
|
||||
) async {}
|
||||
|
||||
/// Default: no-op.
|
||||
public func didExecuteSQL(
|
||||
_ sql: String,
|
||||
success: Bool
|
||||
) async {}
|
||||
}
|
||||
|
||||
// MARK: - Built-in Delegates
|
||||
|
||||
/// A delegate that automatically approves all destructive operations.
|
||||
///
|
||||
/// Use this only in testing or trusted environments where confirmation
|
||||
/// is not needed.
|
||||
public struct AutoApproveDelegate: ToolExecutionDelegate {
|
||||
public init() {}
|
||||
|
||||
public func confirmDestructiveOperation(
|
||||
_ context: DestructiveOperationContext
|
||||
) async -> Bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// A delegate that always rejects destructive operations.
|
||||
///
|
||||
/// This is the safest option and matches the default behavior.
|
||||
public struct RejectAllDelegate: ToolExecutionDelegate {
|
||||
public init() {}
|
||||
|
||||
public func confirmDestructiveOperation(
|
||||
_ context: DestructiveOperationContext
|
||||
) async -> Bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
143
Sources/SwiftDBAI/Models/ConversationHistory.swift
Normal file
143
Sources/SwiftDBAI/Models/ConversationHistory.swift
Normal file
@@ -0,0 +1,143 @@
|
||||
// ConversationHistory.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Ordered chat message history with configurable context window.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Stores an ordered sequence of ``ChatMessage`` instances with a configurable
|
||||
/// context window limit.
|
||||
///
|
||||
/// When the number of messages exceeds ``maxMessages``, the oldest messages are
|
||||
/// trimmed to keep the history within budget. This prevents unbounded token
|
||||
/// growth when building LLM prompts from conversation history.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// var history = ConversationHistory(maxMessages: 20)
|
||||
/// history.append(ChatMessage(role: .user, content: "How many users?"))
|
||||
/// history.append(ChatMessage(role: .assistant, content: "42", sql: "SELECT COUNT(*) FROM users"))
|
||||
/// print(history.promptText) // formatted for LLM context
|
||||
/// ```
|
||||
public struct ConversationHistory: Sendable {
|
||||
|
||||
/// The maximum number of messages to retain. `nil` means unlimited.
|
||||
public let maxMessages: Int?
|
||||
|
||||
/// All messages in chronological order.
|
||||
public private(set) var messages: [ChatMessage] = []
|
||||
|
||||
/// Creates a new conversation history.
|
||||
///
|
||||
/// - Parameter maxMessages: Maximum number of messages to keep in the
|
||||
/// context window. Pass `nil` for unlimited history. Defaults to 50.
|
||||
public init(maxMessages: Int? = 50) {
|
||||
precondition(maxMessages == nil || maxMessages! > 0,
|
||||
"maxMessages must be positive or nil")
|
||||
self.maxMessages = maxMessages
|
||||
}
|
||||
|
||||
/// The number of messages currently stored.
|
||||
public var count: Int { messages.count }
|
||||
|
||||
/// Whether the history is empty.
|
||||
public var isEmpty: Bool { messages.isEmpty }
|
||||
|
||||
// MARK: - Mutating Operations
|
||||
|
||||
/// Appends a message and trims the history if it exceeds the context window.
|
||||
public mutating func append(_ message: ChatMessage) {
|
||||
messages.append(message)
|
||||
trimIfNeeded()
|
||||
}
|
||||
|
||||
/// Appends multiple messages and trims once afterward.
|
||||
public mutating func append(contentsOf newMessages: [ChatMessage]) {
|
||||
messages.append(contentsOf: newMessages)
|
||||
trimIfNeeded()
|
||||
}
|
||||
|
||||
/// Removes all messages from the history.
|
||||
public mutating func clear() {
|
||||
messages.removeAll()
|
||||
}
|
||||
|
||||
// MARK: - Context Window
|
||||
|
||||
/// Returns the most recent messages formatted for inclusion in an LLM prompt.
|
||||
///
|
||||
/// Each message is formatted as `[role] content`, with SQL and query results
|
||||
/// included inline for assistant messages.
|
||||
///
|
||||
/// - Parameter limit: Optional override to further restrict the number of
|
||||
/// messages returned. When `nil`, uses the full retained history.
|
||||
/// - Returns: An array of prompt-formatted strings, one per message.
|
||||
public func promptMessages(limit: Int? = nil) -> [String] {
|
||||
let slice: ArraySlice<ChatMessage>
|
||||
if let limit {
|
||||
slice = messages.suffix(limit)
|
||||
} else {
|
||||
slice = messages[...]
|
||||
}
|
||||
return slice.map { message in
|
||||
Self.formatForPrompt(message)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the combined prompt text for all retained messages, separated by
|
||||
/// double newlines.
|
||||
public var promptText: String {
|
||||
promptMessages().joined(separator: "\n\n")
|
||||
}
|
||||
|
||||
// MARK: - Queries
|
||||
|
||||
/// Returns only user messages.
|
||||
public var userMessages: [ChatMessage] {
|
||||
messages.filter { $0.role == .user }
|
||||
}
|
||||
|
||||
/// Returns only assistant messages.
|
||||
public var assistantMessages: [ChatMessage] {
|
||||
messages.filter { $0.role == .assistant }
|
||||
}
|
||||
|
||||
/// Returns the last message, if any.
|
||||
public var lastMessage: ChatMessage? {
|
||||
messages.last
|
||||
}
|
||||
|
||||
/// Returns the most recent user query text, if any.
|
||||
public var lastUserQuery: String? {
|
||||
messages.last(where: { $0.role == .user })?.content
|
||||
}
|
||||
|
||||
/// Returns the most recent assistant message, if any.
|
||||
public var lastAssistantMessage: ChatMessage? {
|
||||
messages.last(where: { $0.role == .assistant })
|
||||
}
|
||||
|
||||
// MARK: - Private
|
||||
|
||||
/// Formats a ``ChatMessage`` into a string suitable for LLM prompt context.
|
||||
private static func formatForPrompt(_ message: ChatMessage) -> String {
|
||||
var parts: [String] = ["[\(message.role.rawValue)] \(message.content)"]
|
||||
|
||||
if let sql = message.sql {
|
||||
parts.append("SQL: \(sql)")
|
||||
}
|
||||
|
||||
if let result = message.queryResult {
|
||||
parts.append("Result:\n\(result.tabularDescription)")
|
||||
}
|
||||
|
||||
return parts.joined(separator: "\n")
|
||||
}
|
||||
|
||||
/// Trims the oldest messages to stay within the context window.
|
||||
private mutating func trimIfNeeded() {
|
||||
guard let max = maxMessages, messages.count > max else { return }
|
||||
let overflow = messages.count - max
|
||||
messages.removeFirst(overflow)
|
||||
}
|
||||
}
|
||||
136
Sources/SwiftDBAI/Models/QueryResult.swift
Normal file
136
Sources/SwiftDBAI/Models/QueryResult.swift
Normal file
@@ -0,0 +1,136 @@
|
||||
// QueryResult.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Structured result from SQL query execution.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Represents the result of executing a SQL query against the database.
|
||||
///
|
||||
/// Contains raw row data as dictionaries, column metadata, row count,
|
||||
/// the original SQL string, and execution timing.
|
||||
public struct QueryResult: Sendable, Equatable {
|
||||
|
||||
/// A single cell value from a query result.
|
||||
///
|
||||
/// Wraps SQLite's dynamic value types into a type-safe, Sendable enum.
|
||||
public enum Value: Sendable, Equatable, CustomStringConvertible {
|
||||
case text(String)
|
||||
case integer(Int64)
|
||||
case real(Double)
|
||||
case blob(Data)
|
||||
case null
|
||||
|
||||
public var description: String {
|
||||
switch self {
|
||||
case .text(let s): return s
|
||||
case .integer(let i): return String(i)
|
||||
case .real(let d):
|
||||
if d == d.rounded() && abs(d) < 1e15 {
|
||||
return String(format: "%.0f", d)
|
||||
}
|
||||
return String(d)
|
||||
case .blob(let data): return "<\(data.count) bytes>"
|
||||
case .null: return "NULL"
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the value as a `Double` if it is numeric, nil otherwise.
|
||||
public var doubleValue: Double? {
|
||||
switch self {
|
||||
case .integer(let i): return Double(i)
|
||||
case .real(let d): return d
|
||||
case .text(let s): return Double(s)
|
||||
default: return nil
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the value as a `String` (non-nil for all cases).
|
||||
public var stringValue: String { description }
|
||||
|
||||
/// Returns `true` if this value is `.null`.
|
||||
public var isNull: Bool {
|
||||
if case .null = self { return true }
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/// Column names in the order they appear in the result set.
|
||||
public let columns: [String]
|
||||
|
||||
/// Row data as an array of dictionaries mapping column name to value.
|
||||
public let rows: [[String: Value]]
|
||||
|
||||
/// Total number of rows returned.
|
||||
public var rowCount: Int { rows.count }
|
||||
|
||||
/// The SQL statement that was executed.
|
||||
public let sql: String
|
||||
|
||||
/// Time taken to execute the query, in seconds.
|
||||
public let executionTime: TimeInterval
|
||||
|
||||
/// Number of rows affected (for INSERT/UPDATE/DELETE). Nil for SELECT.
|
||||
public let rowsAffected: Int?
|
||||
|
||||
public init(
|
||||
columns: [String],
|
||||
rows: [[String: Value]],
|
||||
sql: String,
|
||||
executionTime: TimeInterval,
|
||||
rowsAffected: Int? = nil
|
||||
) {
|
||||
self.columns = columns
|
||||
self.rows = rows
|
||||
self.sql = sql
|
||||
self.executionTime = executionTime
|
||||
self.rowsAffected = rowsAffected
|
||||
}
|
||||
|
||||
// MARK: - Convenience Accessors
|
||||
|
||||
/// Returns all values for a given column, in row order.
|
||||
public func values(forColumn column: String) -> [Value] {
|
||||
rows.compactMap { $0[column] }
|
||||
}
|
||||
|
||||
/// Returns a compact tabular string representation of the results.
|
||||
///
|
||||
/// Useful for embedding query results into LLM prompts.
|
||||
public var tabularDescription: String {
|
||||
guard !rows.isEmpty else {
|
||||
return "(empty result set)"
|
||||
}
|
||||
|
||||
var lines: [String] = []
|
||||
|
||||
// Header
|
||||
lines.append(columns.joined(separator: " | "))
|
||||
lines.append(String(repeating: "-", count: lines[0].count))
|
||||
|
||||
// Rows (cap at 50 for prompt size)
|
||||
let displayRows = rows.prefix(50)
|
||||
for row in displayRows {
|
||||
let vals = columns.map { col in
|
||||
row[col]?.description ?? "NULL"
|
||||
}
|
||||
lines.append(vals.joined(separator: " | "))
|
||||
}
|
||||
|
||||
if rows.count > 50 {
|
||||
lines.append("... and \(rows.count - 50) more rows")
|
||||
}
|
||||
|
||||
return lines.joined(separator: "\n")
|
||||
}
|
||||
|
||||
/// Returns true if the result looks like a single aggregate value
|
||||
/// (1 row, 1-3 columns, all numeric).
|
||||
public var isAggregate: Bool {
|
||||
guard rowCount == 1, columns.count <= 3 else { return false }
|
||||
let firstRow = rows[0]
|
||||
return columns.allSatisfy { col in
|
||||
firstRow[col]?.doubleValue != nil
|
||||
}
|
||||
}
|
||||
}
|
||||
380
Sources/SwiftDBAI/Parsing/SQLQueryParser.swift
Normal file
380
Sources/SwiftDBAI/Parsing/SQLQueryParser.swift
Normal file
@@ -0,0 +1,380 @@
|
||||
// SQLQueryParser.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Extracts and validates SQL statements from raw LLM response text.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Errors that can occur during SQL parsing and validation.
|
||||
public enum SQLParsingError: Error, Sendable, Equatable, CustomStringConvertible {
|
||||
/// No SQL statement could be found in the LLM response.
|
||||
case noSQLFound
|
||||
|
||||
/// The SQL statement uses an operation not in the allowlist.
|
||||
case operationNotAllowed(SQLOperation)
|
||||
|
||||
/// A destructive operation (DELETE) requires user confirmation.
|
||||
case confirmationRequired(sql: String, operation: SQLOperation)
|
||||
|
||||
/// The mutation targets a table not in the allowed mutation tables.
|
||||
case tableNotAllowed(table: String, operation: SQLOperation)
|
||||
|
||||
/// The SQL contains a disallowed keyword (e.g., DROP, ALTER, TRUNCATE).
|
||||
case dangerousOperation(String)
|
||||
|
||||
/// Multiple SQL statements were found but only single-statement execution is supported.
|
||||
case multipleStatements
|
||||
|
||||
public var description: String {
|
||||
switch self {
|
||||
case .noSQLFound:
|
||||
return "No SQL statement found in the response."
|
||||
case .operationNotAllowed(let op):
|
||||
return "Operation '\(op.rawValue.uppercased())' is not allowed by the current configuration."
|
||||
case .confirmationRequired(let sql, let op):
|
||||
return "The \(op.rawValue.uppercased()) operation requires confirmation: \(sql)"
|
||||
case .tableNotAllowed(let table, let op):
|
||||
return "The \(op.rawValue.uppercased()) operation is not allowed on table '\(table)'."
|
||||
case .dangerousOperation(let keyword):
|
||||
return "Dangerous SQL operation '\(keyword)' is never allowed."
|
||||
case .multipleStatements:
|
||||
return "Only single SQL statements are supported."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of successfully parsing SQL from an LLM response.
|
||||
public struct ParsedSQL: Sendable, Equatable {
|
||||
/// The cleaned SQL statement ready for execution.
|
||||
public let sql: String
|
||||
|
||||
/// The detected operation type.
|
||||
public let operation: SQLOperation
|
||||
|
||||
/// Whether this operation requires user confirmation before execution.
|
||||
public let requiresConfirmation: Bool
|
||||
|
||||
public init(sql: String, operation: SQLOperation, requiresConfirmation: Bool = false) {
|
||||
self.sql = sql
|
||||
self.operation = operation
|
||||
self.requiresConfirmation = requiresConfirmation
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts SQL statements from raw LLM response text and validates them
|
||||
/// against the configured ``OperationAllowlist``.
|
||||
///
|
||||
/// The parser handles common LLM output patterns:
|
||||
/// - SQL in markdown code blocks (```sql ... ```)
|
||||
/// - SQL in generic code blocks (``` ... ```)
|
||||
/// - Raw SQL statements in plain text
|
||||
/// - SQL prefixed with labels like "SQL:" or "Query:"
|
||||
public struct SQLQueryParser: Sendable {
|
||||
|
||||
/// Keywords that are never allowed regardless of allowlist configuration.
|
||||
private static let dangerousKeywords: Set<String> = [
|
||||
"DROP", "ALTER", "TRUNCATE", "CREATE", "GRANT", "REVOKE",
|
||||
"ATTACH", "DETACH", "PRAGMA", "VACUUM", "REINDEX"
|
||||
]
|
||||
|
||||
/// The operation allowlist to validate against.
|
||||
private let allowlist: OperationAllowlist
|
||||
|
||||
/// The mutation policy for table-level restrictions.
|
||||
private let mutationPolicy: MutationPolicy?
|
||||
|
||||
/// Creates a parser with the given operation allowlist.
|
||||
/// - Parameter allowlist: The set of permitted operations. Defaults to read-only.
|
||||
public init(allowlist: OperationAllowlist = .readOnly) {
|
||||
self.allowlist = allowlist
|
||||
self.mutationPolicy = nil
|
||||
}
|
||||
|
||||
/// Creates a parser with a mutation policy (preferred initializer).
|
||||
/// - Parameter mutationPolicy: The mutation policy controlling operations and table access.
|
||||
public init(mutationPolicy: MutationPolicy) {
|
||||
self.allowlist = mutationPolicy.operationAllowlist
|
||||
self.mutationPolicy = mutationPolicy
|
||||
}
|
||||
|
||||
/// Extracts and validates a SQL statement from raw LLM response text.
|
||||
///
|
||||
/// - Parameter text: The raw text from the LLM response.
|
||||
/// - Returns: A ``ParsedSQL`` containing the validated statement.
|
||||
/// - Throws: ``SQLParsingError`` if extraction or validation fails.
|
||||
public func parse(_ text: String) throws -> ParsedSQL {
|
||||
let sql = try extractSQL(from: text)
|
||||
return try validate(sql)
|
||||
}
|
||||
|
||||
// MARK: - Extraction
|
||||
|
||||
/// Attempts to extract a SQL statement from the LLM response text.
|
||||
/// Tries multiple strategies in order of confidence.
|
||||
func extractSQL(from text: String) throws -> String {
|
||||
// Strategy 1: SQL in markdown fenced code block with sql language tag
|
||||
if let sql = extractFromSQLCodeBlock(text) {
|
||||
return sql
|
||||
}
|
||||
|
||||
// Strategy 2: SQL in generic fenced code block
|
||||
if let sql = extractFromGenericCodeBlock(text) {
|
||||
return sql
|
||||
}
|
||||
|
||||
// Strategy 3: SQL after a label like "SQL:" or "Query:"
|
||||
if let sql = extractFromLabel(text) {
|
||||
return sql
|
||||
}
|
||||
|
||||
// Strategy 4: Direct SQL detection in plain text
|
||||
if let sql = extractDirectSQL(text) {
|
||||
return sql
|
||||
}
|
||||
|
||||
throw SQLParsingError.noSQLFound
|
||||
}
|
||||
|
||||
/// Extracts SQL from a ```sql ... ``` code block.
|
||||
private func extractFromSQLCodeBlock(_ text: String) -> String? {
|
||||
let pattern = #"```sql\s*\n([\s\S]*?)```"#
|
||||
return firstMatch(pattern: pattern, in: text, group: 1)?
|
||||
.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
.nonEmptyOrNil
|
||||
}
|
||||
|
||||
/// Extracts SQL from a generic ``` ... ``` code block.
|
||||
private func extractFromGenericCodeBlock(_ text: String) -> String? {
|
||||
let pattern = #"```\s*\n([\s\S]*?)```"#
|
||||
guard let content = firstMatch(pattern: pattern, in: text, group: 1)?
|
||||
.trimmingCharacters(in: .whitespacesAndNewlines) else {
|
||||
return nil
|
||||
}
|
||||
// Only accept if it looks like SQL
|
||||
guard looksLikeSQL(content) else { return nil }
|
||||
return content.nonEmptyOrNil
|
||||
}
|
||||
|
||||
/// Extracts SQL after labels like "SQL:", "Query:", "Here's the query:"
|
||||
private func extractFromLabel(_ text: String) -> String? {
|
||||
// Match the SQL keyword up to end-of-line (handling multi-line SQL with indentation)
|
||||
let pattern = #"(?:SQL|Query|Statement)\s*:\s*\n?\s*((?:SELECT|INSERT|UPDATE|DELETE|WITH)\b.+?)(?:\n(?!\s)|$)"#
|
||||
guard let content = firstMatch(pattern: pattern, in: text, group: 1, options: [.caseInsensitive, .dotMatchesLineSeparators])?
|
||||
.trimmingCharacters(in: .whitespacesAndNewlines) else {
|
||||
return nil
|
||||
}
|
||||
guard looksLikeSQL(content) else { return nil }
|
||||
return content.nonEmptyOrNil
|
||||
}
|
||||
|
||||
/// Detects SQL directly in the text by matching known statement patterns.
|
||||
private func extractDirectSQL(_ text: String) -> String? {
|
||||
// Match SQL statement, allowing semicolons inside single-quoted string literals
|
||||
let pattern = #"(?:^|\n)\s*((?:SELECT|INSERT|UPDATE|DELETE)\b(?:[^;']|'[^']*')*;?)"#
|
||||
guard let content = firstMatch(pattern: pattern, in: text, group: 1, options: .caseInsensitive)?
|
||||
.trimmingCharacters(in: .whitespacesAndNewlines) else {
|
||||
return nil
|
||||
}
|
||||
return content.nonEmptyOrNil
|
||||
}
|
||||
|
||||
// MARK: - Validation
|
||||
|
||||
/// Validates a SQL string against the allowlist and safety rules.
|
||||
func validate(_ sql: String) throws -> ParsedSQL {
|
||||
let cleaned = cleanSQL(sql)
|
||||
|
||||
guard !cleaned.isEmpty else {
|
||||
throw SQLParsingError.noSQLFound
|
||||
}
|
||||
|
||||
// Check for multiple statements (semicolons in non-trivial positions)
|
||||
if containsMultipleStatements(cleaned) {
|
||||
throw SQLParsingError.multipleStatements
|
||||
}
|
||||
|
||||
// Check for dangerous operations first (before allowlist)
|
||||
try checkDangerousKeywords(cleaned)
|
||||
|
||||
// Detect the operation type
|
||||
let operation = detectOperation(cleaned)
|
||||
|
||||
// Check against the allowlist
|
||||
guard allowlist.isAllowed(operation) else {
|
||||
throw SQLParsingError.operationNotAllowed(operation)
|
||||
}
|
||||
|
||||
// Check table-level restrictions for mutation operations
|
||||
if let policy = mutationPolicy, operation != .select,
|
||||
let targetTable = extractTargetTable(from: cleaned, operation: operation) {
|
||||
guard policy.isAllowed(operation: operation, on: targetTable) else {
|
||||
throw SQLParsingError.tableNotAllowed(table: targetTable, operation: operation)
|
||||
}
|
||||
}
|
||||
|
||||
// DELETE requires confirmation when policy says so, or always by default
|
||||
let requiresConfirmation: Bool
|
||||
if let policy = mutationPolicy {
|
||||
requiresConfirmation = policy.requiresConfirmation(for: operation)
|
||||
} else {
|
||||
requiresConfirmation = operation == .delete
|
||||
}
|
||||
|
||||
return ParsedSQL(
|
||||
sql: cleaned,
|
||||
operation: operation,
|
||||
requiresConfirmation: requiresConfirmation
|
||||
)
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
/// Cleans a SQL string by removing trailing semicolons (outside string literals) and excess whitespace.
|
||||
private func cleanSQL(_ sql: String) -> String {
|
||||
var cleaned = sql.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
// Remove trailing semicolons only if they're outside string literals
|
||||
while cleaned.hasSuffix(";") && !isInsideStringLiteral(sql: cleaned, position: cleaned.index(before: cleaned.endIndex)) {
|
||||
cleaned = String(cleaned.dropLast()).trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
}
|
||||
// Collapse internal whitespace outside string literals
|
||||
cleaned = collapseWhitespace(cleaned)
|
||||
return cleaned
|
||||
}
|
||||
|
||||
/// Collapses whitespace while preserving string literal contents.
|
||||
private func collapseWhitespace(_ sql: String) -> String {
|
||||
var result = ""
|
||||
var inString = false
|
||||
var prevWasSpace = false
|
||||
for ch in sql {
|
||||
if ch == "'" {
|
||||
inString.toggle()
|
||||
prevWasSpace = false
|
||||
result.append(ch)
|
||||
} else if inString {
|
||||
result.append(ch)
|
||||
} else if ch.isWhitespace {
|
||||
if !prevWasSpace {
|
||||
result.append(" ")
|
||||
prevWasSpace = true
|
||||
}
|
||||
} else {
|
||||
prevWasSpace = false
|
||||
result.append(ch)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
/// Returns true if the character at the given position is inside a single-quoted string literal.
|
||||
private func isInsideStringLiteral(sql: String, position: String.Index) -> Bool {
|
||||
var inString = false
|
||||
for idx in sql.indices {
|
||||
if idx == position { return inString }
|
||||
if sql[idx] == "'" { inString.toggle() }
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/// Checks whether cleaned SQL contains multiple statements.
|
||||
private func containsMultipleStatements(_ sql: String) -> Bool {
|
||||
// Remove string literals before checking for semicolons
|
||||
var inString = false
|
||||
for ch in sql {
|
||||
if ch == "'" {
|
||||
inString.toggle()
|
||||
} else if ch == ";" && !inString {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
/// Checks for dangerous SQL keywords that are never allowed.
|
||||
private func checkDangerousKeywords(_ sql: String) throws {
|
||||
let upper = sql.uppercased()
|
||||
// Tokenize to avoid partial matches (e.g., "DROPDOWN" matching "DROP")
|
||||
let tokens = upper.components(separatedBy: .alphanumerics.inverted)
|
||||
.filter { !$0.isEmpty }
|
||||
|
||||
for keyword in Self.dangerousKeywords {
|
||||
if tokens.contains(keyword) {
|
||||
throw SQLParsingError.dangerousOperation(keyword)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detects the SQL operation type from the first keyword.
|
||||
private func detectOperation(_ sql: String) -> SQLOperation {
|
||||
let upper = sql.uppercased().trimmingCharacters(in: .whitespaces)
|
||||
|
||||
if upper.hasPrefix("SELECT") || upper.hasPrefix("WITH") {
|
||||
return .select
|
||||
} else if upper.hasPrefix("INSERT") {
|
||||
return .insert
|
||||
} else if upper.hasPrefix("UPDATE") {
|
||||
return .update
|
||||
} else if upper.hasPrefix("DELETE") {
|
||||
return .delete
|
||||
}
|
||||
|
||||
// Default to select for unrecognized patterns (e.g. EXPLAIN)
|
||||
return .select
|
||||
}
|
||||
|
||||
/// Extracts the target table name from a mutation SQL statement.
|
||||
///
|
||||
/// Handles common patterns:
|
||||
/// - `INSERT INTO table_name ...`
|
||||
/// - `UPDATE table_name SET ...`
|
||||
/// - `DELETE FROM table_name ...`
|
||||
private func extractTargetTable(from sql: String, operation: SQLOperation) -> String? {
|
||||
let pattern: String
|
||||
switch operation {
|
||||
case .insert:
|
||||
pattern = #"INSERT\s+INTO\s+[`"\[]?(\w+)[`"\]]?"#
|
||||
case .update:
|
||||
pattern = #"UPDATE\s+[`"\[]?(\w+)[`"\]]?"#
|
||||
case .delete:
|
||||
pattern = #"DELETE\s+FROM\s+[`"\[]?(\w+)[`"\]]?"#
|
||||
case .select:
|
||||
return nil
|
||||
}
|
||||
return firstMatch(pattern: pattern, in: sql, group: 1, options: .caseInsensitive)
|
||||
}
|
||||
|
||||
/// Returns true if the text looks like a SQL statement.
|
||||
private func looksLikeSQL(_ text: String) -> Bool {
|
||||
let upper = text.uppercased().trimmingCharacters(in: .whitespaces)
|
||||
let sqlPrefixes = ["SELECT", "INSERT", "UPDATE", "DELETE", "WITH"]
|
||||
return sqlPrefixes.contains { upper.hasPrefix($0) }
|
||||
}
|
||||
|
||||
/// Extracts the first regex match group from the text.
|
||||
private func firstMatch(
|
||||
pattern: String,
|
||||
in text: String,
|
||||
group: Int,
|
||||
options: NSRegularExpression.Options = []
|
||||
) -> String? {
|
||||
guard let regex = try? NSRegularExpression(pattern: pattern, options: options) else {
|
||||
return nil
|
||||
}
|
||||
let range = NSRange(text.startIndex..., in: text)
|
||||
guard let match = regex.firstMatch(in: text, range: range),
|
||||
match.numberOfRanges > group,
|
||||
let groupRange = Range(match.range(at: group), in: text) else {
|
||||
return nil
|
||||
}
|
||||
return String(text[groupRange])
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - String Extension
|
||||
|
||||
private extension String {
|
||||
/// Returns nil if the string is empty, otherwise returns self.
|
||||
var nonEmptyOrNil: String? {
|
||||
isEmpty ? nil : self
|
||||
}
|
||||
}
|
||||
211
Sources/SwiftDBAI/Prompt/PromptBuilder.swift
Normal file
211
Sources/SwiftDBAI/Prompt/PromptBuilder.swift
Normal file
@@ -0,0 +1,211 @@
|
||||
/// Builds structured LLM prompts for SQL generation from a database schema
|
||||
/// and natural language input.
|
||||
///
|
||||
/// `PromptBuilder` is the bridge between the introspected database schema and
|
||||
/// the LLM. It produces two things:
|
||||
/// 1. A **system instructions** string containing schema context and behavioral rules
|
||||
/// 2. A **user prompt** string wrapping the natural language question
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// let builder = PromptBuilder(schema: mySchema, allowlist: .readOnly)
|
||||
/// let instructions = builder.buildSystemInstructions()
|
||||
/// let prompt = builder.buildUserPrompt("How many users signed up this week?")
|
||||
/// ```
|
||||
public struct PromptBuilder: Sendable {
|
||||
/// The database schema to include as context.
|
||||
public let schema: DatabaseSchema
|
||||
|
||||
/// Which SQL operations the LLM may generate.
|
||||
public let allowlist: OperationAllowlist
|
||||
|
||||
/// Optional additional context to append to the system instructions
|
||||
/// (e.g., business-specific terminology or query hints).
|
||||
public let additionalContext: String?
|
||||
|
||||
/// Creates a prompt builder for the given schema and allowlist.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - schema: The introspected database schema.
|
||||
/// - allowlist: Permitted SQL operations. Defaults to ``OperationAllowlist/readOnly``.
|
||||
/// - additionalContext: Extra instructions appended to the system prompt.
|
||||
public init(
|
||||
schema: DatabaseSchema,
|
||||
allowlist: OperationAllowlist = .readOnly,
|
||||
additionalContext: String? = nil
|
||||
) {
|
||||
self.schema = schema
|
||||
self.allowlist = allowlist
|
||||
self.additionalContext = additionalContext
|
||||
}
|
||||
|
||||
// MARK: - System Instructions
|
||||
|
||||
/// Builds the system instructions string that should be passed as the
|
||||
/// `instructions` parameter when creating a `LanguageModelSession`.
|
||||
///
|
||||
/// The instructions include:
|
||||
/// - Role definition
|
||||
/// - The full database schema
|
||||
/// - SQL generation rules and constraints
|
||||
/// - The operation allowlist
|
||||
/// - Output format requirements
|
||||
public func buildSystemInstructions() -> String {
|
||||
var sections: [String] = []
|
||||
|
||||
// 1. Role
|
||||
sections.append(Self.roleSection)
|
||||
|
||||
// 2. Schema
|
||||
sections.append(buildSchemaSection())
|
||||
|
||||
// 3. Operation permissions
|
||||
sections.append(buildPermissionsSection())
|
||||
|
||||
// 4. SQL generation rules
|
||||
sections.append(Self.sqlRulesSection)
|
||||
|
||||
// 5. Output format
|
||||
sections.append(Self.outputFormatSection)
|
||||
|
||||
// 6. Additional context
|
||||
if let additionalContext, !additionalContext.isEmpty {
|
||||
sections.append("ADDITIONAL CONTEXT\n=================\n\(additionalContext)")
|
||||
}
|
||||
|
||||
return sections.joined(separator: "\n\n")
|
||||
}
|
||||
|
||||
// MARK: - User Prompt
|
||||
|
||||
/// Wraps a natural language question into a user prompt string.
|
||||
///
|
||||
/// - Parameter question: The user's natural language question.
|
||||
/// - Returns: A formatted prompt string for the LLM.
|
||||
public func buildUserPrompt(_ question: String) -> String {
|
||||
question
|
||||
}
|
||||
|
||||
/// Builds a follow-up prompt that includes prior SQL context for
|
||||
/// multi-turn conversations.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - question: The user's follow-up question.
|
||||
/// - previousSQL: The SQL from the previous turn, for context.
|
||||
/// - previousResultSummary: A brief summary of what the previous query returned.
|
||||
/// - Returns: A formatted prompt string.
|
||||
public func buildFollowUpPrompt(
|
||||
_ question: String,
|
||||
previousSQL: String,
|
||||
previousResultSummary: String
|
||||
) -> String {
|
||||
"""
|
||||
Previous query: \(previousSQL)
|
||||
Previous result: \(previousResultSummary)
|
||||
|
||||
Follow-up question: \(question)
|
||||
"""
|
||||
}
|
||||
|
||||
/// Builds a prompt that includes the full conversation history within the
|
||||
/// configured context window, enabling the LLM to resolve follow-up
|
||||
/// references (pronouns, implicit table/column references, etc.).
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - question: The user's current question.
|
||||
/// - history: The conversation history messages within the context window.
|
||||
/// - Returns: A formatted prompt string with conversation context.
|
||||
public func buildConversationPrompt(
|
||||
_ question: String,
|
||||
history: [ChatMessage]
|
||||
) -> String {
|
||||
guard !history.isEmpty else {
|
||||
return buildUserPrompt(question)
|
||||
}
|
||||
|
||||
var lines: [String] = []
|
||||
lines.append("CONVERSATION HISTORY")
|
||||
lines.append("====================")
|
||||
|
||||
for message in history {
|
||||
switch message.role {
|
||||
case .user:
|
||||
lines.append("User: \(message.content)")
|
||||
case .assistant:
|
||||
if let sql = message.sql {
|
||||
lines.append("Assistant SQL: \(sql)")
|
||||
}
|
||||
lines.append("Assistant: \(message.content)")
|
||||
case .error:
|
||||
lines.append("Error: \(message.content)")
|
||||
}
|
||||
}
|
||||
|
||||
lines.append("")
|
||||
lines.append("CURRENT QUESTION")
|
||||
lines.append("================")
|
||||
lines.append(question)
|
||||
|
||||
return lines.joined(separator: "\n")
|
||||
}
|
||||
|
||||
// MARK: - Private Sections
|
||||
|
||||
private func buildSchemaSection() -> String {
|
||||
var lines: [String] = []
|
||||
lines.append("DATABASE SCHEMA")
|
||||
lines.append("===============")
|
||||
lines.append("")
|
||||
lines.append(schema.schemaDescription)
|
||||
return lines.joined(separator: "\n")
|
||||
}
|
||||
|
||||
private func buildPermissionsSection() -> String {
|
||||
var lines: [String] = []
|
||||
lines.append("PERMISSIONS")
|
||||
lines.append("===========")
|
||||
lines.append(allowlist.describeForLLM())
|
||||
return lines.joined(separator: "\n")
|
||||
}
|
||||
|
||||
// MARK: - Static Content
|
||||
|
||||
static let roleSection = """
|
||||
ROLE
|
||||
====
|
||||
You are a SQL assistant for a SQLite database. Your job is to translate \
|
||||
natural language questions into valid SQLite SQL queries based on the \
|
||||
database schema provided below. You must ONLY reference tables and columns \
|
||||
that exist in the schema. Never fabricate table or column names.
|
||||
"""
|
||||
|
||||
static let sqlRulesSection = """
|
||||
SQL GENERATION RULES
|
||||
====================
|
||||
1. Use ONLY the tables and columns listed in the schema above.
|
||||
2. Use SQLite-compatible syntax (e.g., || for string concatenation, \
|
||||
IFNULL instead of COALESCE where needed).
|
||||
3. Use appropriate JOINs when queries span multiple tables — reference \
|
||||
the foreign key relationships in the schema.
|
||||
4. For date/time operations, use SQLite date functions \
|
||||
(date(), time(), datetime(), strftime()).
|
||||
5. Use parameterized-style values where possible. For literal values \
|
||||
from the user's question, embed them directly in the SQL.
|
||||
6. Always include an ORDER BY clause when the user implies ordering.
|
||||
7. Use LIMIT when the user asks for "top N" or "first N" results.
|
||||
8. For aggregate queries (count, sum, average, min, max), use the \
|
||||
appropriate SQL aggregate functions.
|
||||
9. When the user's question is ambiguous, prefer the simplest valid \
|
||||
interpretation.
|
||||
10. Never generate DDL statements (CREATE, ALTER, DROP TABLE).
|
||||
"""
|
||||
|
||||
static let outputFormatSection = """
|
||||
OUTPUT FORMAT
|
||||
=============
|
||||
When generating SQL, call the appropriate tool with the SQL query. \
|
||||
After receiving query results, provide a concise natural language \
|
||||
summary of the data. Be specific with numbers and names from the results. \
|
||||
If no rows are returned, say so clearly.
|
||||
"""
|
||||
}
|
||||
423
Sources/SwiftDBAI/Rendering/ChartDataDetector.swift
Normal file
423
Sources/SwiftDBAI/Rendering/ChartDataDetector.swift
Normal file
@@ -0,0 +1,423 @@
|
||||
// ChartDataDetector.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Analyzes query results to determine chart eligibility and
|
||||
// recommends appropriate chart types based on data shape.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Detects whether a `DataTable` is suitable for charting and
|
||||
/// recommends the best chart type based on data shape heuristics.
|
||||
///
|
||||
/// The detector examines column types, row counts, and value distributions
|
||||
/// to produce a `ChartRecommendation` that the rendering layer can use
|
||||
/// to auto-select an appropriate Swift Charts visualization.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// let detector = ChartDataDetector()
|
||||
/// if let recommendation = detector.detect(table) {
|
||||
/// switch recommendation.chartType {
|
||||
/// case .bar: // render bar chart
|
||||
/// case .line: // render line chart
|
||||
/// case .pie: // render pie chart
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
public struct ChartDataDetector: Sendable {
|
||||
|
||||
// MARK: - Chart Types
|
||||
|
||||
/// The type of chart recommended for the data.
|
||||
public enum ChartType: String, Sendable, Equatable, CaseIterable {
|
||||
/// Vertical bar chart — best for categorical comparisons.
|
||||
case bar
|
||||
/// Line chart — best for time series or ordered sequences.
|
||||
case line
|
||||
/// Pie/donut chart — best for proportional breakdowns with few categories.
|
||||
case pie
|
||||
}
|
||||
|
||||
/// A recommendation for how to chart a `DataTable`.
|
||||
public struct ChartRecommendation: Sendable, Equatable {
|
||||
/// The recommended chart type.
|
||||
public let chartType: ChartType
|
||||
|
||||
/// The column to use for the category axis (x-axis / labels).
|
||||
public let categoryColumn: String
|
||||
|
||||
/// The column to use for the value axis (y-axis / sizes).
|
||||
public let valueColumn: String
|
||||
|
||||
/// Confidence score from 0.0 (guess) to 1.0 (strong match).
|
||||
public let confidence: Double
|
||||
|
||||
/// Human-readable reason for this recommendation.
|
||||
public let reason: String
|
||||
|
||||
public init(
|
||||
chartType: ChartType,
|
||||
categoryColumn: String,
|
||||
valueColumn: String,
|
||||
confidence: Double,
|
||||
reason: String
|
||||
) {
|
||||
self.chartType = chartType
|
||||
self.categoryColumn = categoryColumn
|
||||
self.valueColumn = valueColumn
|
||||
self.confidence = confidence
|
||||
self.reason = reason
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Configuration
|
||||
|
||||
/// Minimum rows required to consider chart-eligible.
|
||||
public let minimumRows: Int
|
||||
|
||||
/// Maximum rows for a pie chart (too many slices becomes unreadable).
|
||||
public let maxPieSlices: Int
|
||||
|
||||
/// Maximum rows for any chart before it becomes cluttered.
|
||||
public let maximumRows: Int
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
/// Creates a detector with configurable thresholds.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - minimumRows: Minimum rows for chart eligibility (default: 2).
|
||||
/// - maxPieSlices: Maximum categories for pie charts (default: 8).
|
||||
/// - maximumRows: Maximum rows for any chart (default: 100).
|
||||
public init(
|
||||
minimumRows: Int = 2,
|
||||
maxPieSlices: Int = 8,
|
||||
maximumRows: Int = 100
|
||||
) {
|
||||
self.minimumRows = minimumRows
|
||||
self.maxPieSlices = maxPieSlices
|
||||
self.maximumRows = maximumRows
|
||||
}
|
||||
|
||||
// MARK: - Detection
|
||||
|
||||
/// Analyzes a `DataTable` and returns a chart recommendation, or `nil`
|
||||
/// if the data is not suitable for charting.
|
||||
///
|
||||
/// - Parameter table: The data table to analyze.
|
||||
/// - Returns: A recommendation, or `nil` if no chart type fits.
|
||||
public func detect(_ table: DataTable) -> ChartRecommendation? {
|
||||
// Must have at least 2 columns (category + value) and enough rows
|
||||
guard table.columnCount >= 2,
|
||||
table.rowCount >= minimumRows,
|
||||
table.rowCount <= maximumRows else {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Find candidate category and value columns
|
||||
guard let (categoryCol, valueCol) = findCategoryValuePair(in: table) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
let chartType = recommendChartType(
|
||||
table: table,
|
||||
categoryColumn: categoryCol,
|
||||
valueColumn: valueCol
|
||||
)
|
||||
|
||||
let confidence = computeConfidence(
|
||||
table: table,
|
||||
categoryColumn: categoryCol,
|
||||
valueColumn: valueCol,
|
||||
chartType: chartType
|
||||
)
|
||||
|
||||
let reason = describeReason(
|
||||
chartType: chartType,
|
||||
categoryColumn: categoryCol,
|
||||
valueColumn: valueCol,
|
||||
table: table
|
||||
)
|
||||
|
||||
return ChartRecommendation(
|
||||
chartType: chartType,
|
||||
categoryColumn: categoryCol.name,
|
||||
valueColumn: valueCol.name,
|
||||
confidence: confidence,
|
||||
reason: reason
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns all viable chart recommendations, ranked by confidence.
|
||||
///
|
||||
/// - Parameter table: The data table to analyze.
|
||||
/// - Returns: An array of recommendations sorted by confidence (highest first).
|
||||
public func allRecommendations(for table: DataTable) -> [ChartRecommendation] {
|
||||
guard table.columnCount >= 2,
|
||||
table.rowCount >= minimumRows,
|
||||
table.rowCount <= maximumRows else {
|
||||
return []
|
||||
}
|
||||
|
||||
guard let (categoryCol, valueCol) = findCategoryValuePair(in: table) else {
|
||||
return []
|
||||
}
|
||||
|
||||
return ChartType.allCases.compactMap { chartType in
|
||||
guard isViable(chartType, table: table, categoryColumn: categoryCol) else {
|
||||
return nil
|
||||
}
|
||||
|
||||
let confidence = computeConfidence(
|
||||
table: table,
|
||||
categoryColumn: categoryCol,
|
||||
valueColumn: valueCol,
|
||||
chartType: chartType
|
||||
)
|
||||
|
||||
let reason = describeReason(
|
||||
chartType: chartType,
|
||||
categoryColumn: categoryCol,
|
||||
valueColumn: valueCol,
|
||||
table: table
|
||||
)
|
||||
|
||||
return ChartRecommendation(
|
||||
chartType: chartType,
|
||||
categoryColumn: categoryCol.name,
|
||||
valueColumn: valueCol.name,
|
||||
confidence: confidence,
|
||||
reason: reason
|
||||
)
|
||||
}
|
||||
.sorted { $0.confidence > $1.confidence }
|
||||
}
|
||||
|
||||
// MARK: - Private Helpers
|
||||
|
||||
/// Finds the best (category, value) column pair from the table.
|
||||
private func findCategoryValuePair(
|
||||
in table: DataTable
|
||||
) -> (category: DataTable.Column, value: DataTable.Column)? {
|
||||
let numericColumns = table.columns.filter { isNumeric($0) }
|
||||
let categoryColumns = table.columns.filter { isCategory($0) }
|
||||
|
||||
// Prefer: first text/category column + first numeric column
|
||||
if let cat = categoryColumns.first, let val = numericColumns.first {
|
||||
return (cat, val)
|
||||
}
|
||||
|
||||
// Fallback: if all columns are numeric, use first as category, second as value
|
||||
if numericColumns.count >= 2 {
|
||||
return (numericColumns[0], numericColumns[1])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
/// Recommends the single best chart type for the data shape.
|
||||
private func recommendChartType(
|
||||
table: DataTable,
|
||||
categoryColumn: DataTable.Column,
|
||||
valueColumn: DataTable.Column
|
||||
) -> ChartType {
|
||||
// Line: time series or sequential numeric categories (check first — strongest signal)
|
||||
if isTimeSeries(categoryColumn, in: table) || isSequential(categoryColumn, in: table) {
|
||||
return .line
|
||||
}
|
||||
|
||||
// Pie: small number of categories with all-positive values
|
||||
// Only when clearly categorical (text labels) and few rows
|
||||
if table.rowCount <= maxPieSlices,
|
||||
isCategory(categoryColumn),
|
||||
isPieCandidate(table: table, valueColumn: valueColumn),
|
||||
looksProportional(table: table, valueColumn: valueColumn) {
|
||||
return .pie
|
||||
}
|
||||
|
||||
// Default: bar chart for categorical comparisons
|
||||
return .bar
|
||||
}
|
||||
|
||||
/// Checks if a chart type is viable for the given data.
|
||||
private func isViable(
|
||||
_ chartType: ChartType,
|
||||
table: DataTable,
|
||||
categoryColumn: DataTable.Column
|
||||
) -> Bool {
|
||||
switch chartType {
|
||||
case .pie:
|
||||
return table.rowCount <= maxPieSlices
|
||||
case .line:
|
||||
return table.rowCount >= minimumRows
|
||||
case .bar:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines if a column holds numeric data.
|
||||
private func isNumeric(_ column: DataTable.Column) -> Bool {
|
||||
switch column.inferredType {
|
||||
case .integer, .real:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines if a column holds categorical (label) data.
|
||||
private func isCategory(_ column: DataTable.Column) -> Bool {
|
||||
switch column.inferredType {
|
||||
case .text, .mixed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if the value column contains all non-negative values,
|
||||
/// making it a candidate for pie charts.
|
||||
private func isPieCandidate(
|
||||
table: DataTable,
|
||||
valueColumn: DataTable.Column
|
||||
) -> Bool {
|
||||
let values = table.numericValues(forColumn: valueColumn.name)
|
||||
guard !values.isEmpty else { return false }
|
||||
// All values must be positive for a meaningful pie chart
|
||||
return values.allSatisfy { $0 > 0 }
|
||||
}
|
||||
|
||||
/// Heuristic: do values look like they represent parts of a whole?
|
||||
///
|
||||
/// Checks for aggregate-like column names (count, total, sum, amount, pct, etc.)
|
||||
/// or if values sum to a round number suggesting percentages/proportions.
|
||||
private func looksProportional(
|
||||
table: DataTable,
|
||||
valueColumn: DataTable.Column
|
||||
) -> Bool {
|
||||
let proportionalNames: Set<String> = ["count", "total", "sum", "amount", "pct",
|
||||
"percent", "percentage", "share", "proportion",
|
||||
"quantity", "qty", "num", "number"]
|
||||
// Split on common separators and check for exact word matches
|
||||
let lowerName = valueColumn.name.lowercased()
|
||||
let words = Set(lowerName.split { $0 == "_" || $0 == "-" || $0 == " " }.map(String.init))
|
||||
if !words.isDisjoint(with: proportionalNames) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if values sum to ~100 (percentages)
|
||||
let values = table.numericValues(forColumn: valueColumn.name)
|
||||
let sum = values.reduce(0, +)
|
||||
if abs(sum - 100.0) < 1.0 {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
/// Heuristic: does the category column look like time-series data?
|
||||
///
|
||||
/// Checks for date-like patterns (YYYY, YYYY-MM, YYYY-MM-DD)
|
||||
/// or common time-related column names.
|
||||
private func isTimeSeries(_ column: DataTable.Column, in table: DataTable) -> Bool {
|
||||
let timeNames = ["date", "time", "timestamp", "year", "month", "day",
|
||||
"week", "quarter", "period", "created_at", "updated_at"]
|
||||
let lowerName = column.name.lowercased()
|
||||
if timeNames.contains(where: { lowerName.contains($0) }) {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if text values look like dates
|
||||
if column.inferredType == .text {
|
||||
let values = table.stringValues(forColumn: column.name)
|
||||
let datePattern = #/^\d{4}(-\d{2}){0,2}$/#
|
||||
let matchCount = values.prefix(5).filter { (try? datePattern.wholeMatch(in: $0)) != nil }.count
|
||||
if matchCount >= 3 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
/// Heuristic: does the category column contain sequential numeric values?
|
||||
private func isSequential(_ column: DataTable.Column, in table: DataTable) -> Bool {
|
||||
guard isNumeric(column) else { return false }
|
||||
let values = table.numericValues(forColumn: column.name)
|
||||
guard values.count >= 3 else { return false }
|
||||
|
||||
// Check if values are monotonically increasing
|
||||
for i in 1..<values.count {
|
||||
if values[i] <= values[i - 1] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
/// Computes a confidence score for a specific chart type + data combination.
|
||||
private func computeConfidence(
|
||||
table: DataTable,
|
||||
categoryColumn: DataTable.Column,
|
||||
valueColumn: DataTable.Column,
|
||||
chartType: ChartType
|
||||
) -> Double {
|
||||
var score = 0.5 // baseline
|
||||
|
||||
// Bonus: clear category/value split (text + numeric)
|
||||
if isCategory(categoryColumn) && isNumeric(valueColumn) {
|
||||
score += 0.2
|
||||
}
|
||||
|
||||
// Bonus: reasonable row count for the chart type
|
||||
switch chartType {
|
||||
case .bar:
|
||||
if table.rowCount >= 2 && table.rowCount <= 20 {
|
||||
score += 0.15
|
||||
}
|
||||
case .line:
|
||||
if isTimeSeries(categoryColumn, in: table) {
|
||||
score += 0.2
|
||||
} else if isSequential(categoryColumn, in: table) {
|
||||
score += 0.1
|
||||
}
|
||||
case .pie:
|
||||
if table.rowCount <= maxPieSlices && isPieCandidate(table: table, valueColumn: valueColumn) {
|
||||
score += 0.2
|
||||
}
|
||||
// Penalty: too many slices
|
||||
if table.rowCount > 5 {
|
||||
score -= 0.1
|
||||
}
|
||||
}
|
||||
|
||||
// Bonus: no null values in key columns
|
||||
let categoryNulls = table.columnValues(named: categoryColumn.name).filter(\.isNull).count
|
||||
let valueNulls = table.columnValues(named: valueColumn.name).filter(\.isNull).count
|
||||
if categoryNulls == 0 && valueNulls == 0 {
|
||||
score += 0.1
|
||||
}
|
||||
|
||||
return min(max(score, 0.0), 1.0)
|
||||
}
|
||||
|
||||
/// Generates a human-readable reason for the recommendation.
|
||||
private func describeReason(
|
||||
chartType: ChartType,
|
||||
categoryColumn: DataTable.Column,
|
||||
valueColumn: DataTable.Column,
|
||||
table: DataTable
|
||||
) -> String {
|
||||
switch chartType {
|
||||
case .bar:
|
||||
return "\(table.rowCount) categories comparing \(valueColumn.name) by \(categoryColumn.name)"
|
||||
case .line:
|
||||
if isTimeSeries(categoryColumn, in: table) {
|
||||
return "\(valueColumn.name) over time (\(categoryColumn.name))"
|
||||
}
|
||||
return "\(valueColumn.name) trend across \(table.rowCount) points"
|
||||
case .pie:
|
||||
return "Proportional breakdown of \(valueColumn.name) by \(categoryColumn.name)"
|
||||
}
|
||||
}
|
||||
}
|
||||
255
Sources/SwiftDBAI/Rendering/DataTable.swift
Normal file
255
Sources/SwiftDBAI/Rendering/DataTable.swift
Normal file
@@ -0,0 +1,255 @@
|
||||
// DataTable.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Structured table representation for rendering query results
|
||||
// in SwiftUI table views and charts.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// A structured, row-column table built from a `QueryResult`.
|
||||
///
|
||||
/// `DataTable` provides indexed access to rows and columns, typed column
|
||||
/// metadata, and convenience methods for extracting data suitable for
|
||||
/// SwiftUI `Table` views and Swift Charts.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// let table = DataTable(queryResult)
|
||||
/// print(table.columnCount) // 3
|
||||
/// print(table[row: 0, column: 1]) // .text("Alice")
|
||||
/// ```
|
||||
public struct DataTable: Sendable, Equatable {
|
||||
|
||||
// MARK: - Column Metadata
|
||||
|
||||
/// Metadata for a single column in the data table.
|
||||
public struct Column: Sendable, Equatable, Identifiable {
|
||||
/// Stable identifier for the column (same as `name`).
|
||||
public var id: String { name }
|
||||
|
||||
/// Column name from the query result set.
|
||||
public let name: String
|
||||
|
||||
/// Index of this column in the table (0-based).
|
||||
public let index: Int
|
||||
|
||||
/// Inferred data type based on the values in this column.
|
||||
public let inferredType: InferredType
|
||||
|
||||
public init(name: String, index: Int, inferredType: InferredType) {
|
||||
self.name = name
|
||||
self.index = index
|
||||
self.inferredType = inferredType
|
||||
}
|
||||
}
|
||||
|
||||
/// The inferred data type for a column, determined by inspecting its values.
|
||||
public enum InferredType: Sendable, Equatable {
|
||||
/// All non-null values are integers.
|
||||
case integer
|
||||
/// All non-null values are numeric (mix of integer and real).
|
||||
case real
|
||||
/// All non-null values are text.
|
||||
case text
|
||||
/// Values contain blob data.
|
||||
case blob
|
||||
/// Column contains only null values or is empty.
|
||||
case null
|
||||
/// Values are a mix of incompatible types.
|
||||
case mixed
|
||||
}
|
||||
|
||||
// MARK: - Row Type
|
||||
|
||||
/// A single row in the data table, providing indexed and named access.
|
||||
public struct Row: Sendable, Equatable, Identifiable {
|
||||
/// Row index (0-based), used as stable identity.
|
||||
public let id: Int
|
||||
|
||||
/// Values in column order.
|
||||
public let values: [QueryResult.Value]
|
||||
|
||||
/// Column names for named access.
|
||||
private let columnNames: [String]
|
||||
|
||||
public init(id: Int, values: [QueryResult.Value], columnNames: [String]) {
|
||||
self.id = id
|
||||
self.values = values
|
||||
self.columnNames = columnNames
|
||||
}
|
||||
|
||||
/// Access a value by column index.
|
||||
public subscript(columnIndex: Int) -> QueryResult.Value {
|
||||
values[columnIndex]
|
||||
}
|
||||
|
||||
/// Access a value by column name. Returns `.null` if the column doesn't exist.
|
||||
public subscript(columnName: String) -> QueryResult.Value {
|
||||
guard let idx = columnNames.firstIndex(of: columnName) else {
|
||||
return .null
|
||||
}
|
||||
return values[idx]
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Properties
|
||||
|
||||
/// Column metadata in order.
|
||||
public let columns: [Column]
|
||||
|
||||
/// All rows in order.
|
||||
public let rows: [Row]
|
||||
|
||||
/// The SQL that produced this table.
|
||||
public let sql: String
|
||||
|
||||
/// Execution time of the underlying query.
|
||||
public let executionTime: TimeInterval
|
||||
|
||||
/// Number of columns.
|
||||
public var columnCount: Int { columns.count }
|
||||
|
||||
/// Number of rows.
|
||||
public var rowCount: Int { rows.count }
|
||||
|
||||
/// Whether the table has no rows.
|
||||
public var isEmpty: Bool { rows.isEmpty }
|
||||
|
||||
/// Column names in order.
|
||||
public var columnNames: [String] { columns.map(\.name) }
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
/// Creates a `DataTable` from a `QueryResult`.
|
||||
///
|
||||
/// Converts the dictionary-based row representation into an indexed
|
||||
/// array representation and infers column types from the data.
|
||||
///
|
||||
/// - Parameter queryResult: The raw query result to convert.
|
||||
public init(_ queryResult: QueryResult) {
|
||||
let colNames = queryResult.columns
|
||||
|
||||
// Build indexed rows
|
||||
let indexedRows: [Row] = queryResult.rows.enumerated().map { idx, rowDict in
|
||||
let values = colNames.map { col in
|
||||
rowDict[col] ?? .null
|
||||
}
|
||||
return Row(id: idx, values: values, columnNames: colNames)
|
||||
}
|
||||
|
||||
// Infer column types
|
||||
let inferredColumns: [Column] = colNames.enumerated().map { colIdx, name in
|
||||
let type = Self.inferType(
|
||||
from: indexedRows.map { $0.values[colIdx] }
|
||||
)
|
||||
return Column(name: name, index: colIdx, inferredType: type)
|
||||
}
|
||||
|
||||
self.columns = inferredColumns
|
||||
self.rows = indexedRows
|
||||
self.sql = queryResult.sql
|
||||
self.executionTime = queryResult.executionTime
|
||||
}
|
||||
|
||||
/// Creates a `DataTable` directly from components (useful for testing).
|
||||
public init(
|
||||
columns: [Column],
|
||||
rows: [Row],
|
||||
sql: String = "",
|
||||
executionTime: TimeInterval = 0
|
||||
) {
|
||||
self.columns = columns
|
||||
self.rows = rows
|
||||
self.sql = sql
|
||||
self.executionTime = executionTime
|
||||
}
|
||||
|
||||
// MARK: - Subscript Access
|
||||
|
||||
/// Access a cell by row and column index.
|
||||
public subscript(row rowIndex: Int, column columnIndex: Int) -> QueryResult.Value {
|
||||
rows[rowIndex].values[columnIndex]
|
||||
}
|
||||
|
||||
/// Access a cell by row index and column name.
|
||||
public subscript(row rowIndex: Int, column columnName: String) -> QueryResult.Value {
|
||||
rows[rowIndex][columnName]
|
||||
}
|
||||
|
||||
// MARK: - Column Data Extraction
|
||||
|
||||
/// Returns all values for a column by index, in row order.
|
||||
public func columnValues(at index: Int) -> [QueryResult.Value] {
|
||||
rows.map { $0.values[index] }
|
||||
}
|
||||
|
||||
/// Returns all values for a column by name, in row order.
|
||||
public func columnValues(named name: String) -> [QueryResult.Value] {
|
||||
guard let col = columns.first(where: { $0.name == name }) else {
|
||||
return []
|
||||
}
|
||||
return columnValues(at: col.index)
|
||||
}
|
||||
|
||||
/// Returns all non-null `Double` values for a column (useful for charting).
|
||||
public func numericValues(forColumn name: String) -> [Double] {
|
||||
columnValues(named: name).compactMap(\.doubleValue)
|
||||
}
|
||||
|
||||
/// Returns all non-null `String` values for a column (useful for labels).
|
||||
public func stringValues(forColumn name: String) -> [String] {
|
||||
columnValues(named: name).compactMap { value in
|
||||
if case .null = value { return nil }
|
||||
return value.stringValue
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Type Inference
|
||||
|
||||
/// Infers the predominant type from an array of values.
|
||||
static func inferType(from values: [QueryResult.Value]) -> InferredType {
|
||||
var hasInteger = false
|
||||
var hasReal = false
|
||||
var hasText = false
|
||||
var hasBlob = false
|
||||
var hasNonNull = false
|
||||
|
||||
for value in values {
|
||||
switch value {
|
||||
case .integer:
|
||||
hasInteger = true
|
||||
hasNonNull = true
|
||||
case .real:
|
||||
hasReal = true
|
||||
hasNonNull = true
|
||||
case .text:
|
||||
hasText = true
|
||||
hasNonNull = true
|
||||
case .blob:
|
||||
hasBlob = true
|
||||
hasNonNull = true
|
||||
case .null:
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
guard hasNonNull else { return .null }
|
||||
|
||||
// Count how many distinct types are present
|
||||
let typeCount = [hasInteger, hasReal, hasText, hasBlob].filter { $0 }.count
|
||||
|
||||
if typeCount == 1 {
|
||||
if hasInteger { return .integer }
|
||||
if hasReal { return .real }
|
||||
if hasText { return .text }
|
||||
if hasBlob { return .blob }
|
||||
}
|
||||
|
||||
// Integer + real → treat as real (numeric promotion)
|
||||
if typeCount == 2, hasInteger, hasReal {
|
||||
return .real
|
||||
}
|
||||
|
||||
return .mixed
|
||||
}
|
||||
}
|
||||
301
Sources/SwiftDBAI/Rendering/TextSummaryRenderer.swift
Normal file
301
Sources/SwiftDBAI/Rendering/TextSummaryRenderer.swift
Normal file
@@ -0,0 +1,301 @@
|
||||
// TextSummaryRenderer.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Converts raw SQL query results into natural language text summaries
|
||||
// using the LLM via AnyLanguageModel.
|
||||
|
||||
import AnyLanguageModel
|
||||
import Foundation
|
||||
|
||||
/// Renders SQL query results as natural language text summaries.
|
||||
///
|
||||
/// The renderer takes a `QueryResult` and the user's original question,
|
||||
/// sends them to the LLM for summarization, and returns a concise,
|
||||
/// human-readable response.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// let renderer = TextSummaryRenderer(model: myModel)
|
||||
/// let summary = try await renderer.summarize(
|
||||
/// result: queryResult,
|
||||
/// userQuestion: "How many orders were placed last month?"
|
||||
/// )
|
||||
/// print(summary) // "There were 42 orders placed last month."
|
||||
/// ```
|
||||
public struct TextSummaryRenderer: Sendable {
|
||||
|
||||
/// The language model used to generate summaries.
|
||||
private let model: any LanguageModel
|
||||
|
||||
/// Maximum number of rows to include in the LLM prompt.
|
||||
///
|
||||
/// Results larger than this are truncated with a note about total count.
|
||||
public let maxRowsInPrompt: Int
|
||||
|
||||
/// Creates a new text summary renderer.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - model: Any `AnyLanguageModel`-compatible language model.
|
||||
/// - maxRowsInPrompt: Maximum rows to send to the LLM for summarization (default: 50).
|
||||
public init(model: any LanguageModel, maxRowsInPrompt: Int = 50) {
|
||||
self.model = model
|
||||
self.maxRowsInPrompt = maxRowsInPrompt
|
||||
}
|
||||
|
||||
/// Generates a natural language summary of query results.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - result: The raw `QueryResult` from SQL execution.
|
||||
/// - userQuestion: The original natural language question from the user.
|
||||
/// - context: Optional additional context (e.g., table descriptions) to help the LLM.
|
||||
/// - Returns: A natural language text summary of the results.
|
||||
public func summarize(
|
||||
result: QueryResult,
|
||||
userQuestion: String,
|
||||
context: String? = nil
|
||||
) async throws -> String {
|
||||
// For mutation results (INSERT/UPDATE/DELETE), use a simple template
|
||||
if let affected = result.rowsAffected {
|
||||
return summarizeMutation(result: result, affected: affected)
|
||||
}
|
||||
|
||||
// For empty results, no need to call the LLM
|
||||
if result.rows.isEmpty {
|
||||
return "No results found for your query."
|
||||
}
|
||||
|
||||
// For simple aggregates, produce a direct answer without LLM
|
||||
if let directAnswer = tryDirectAggregateSummary(result: result, userQuestion: userQuestion) {
|
||||
return directAnswer
|
||||
}
|
||||
|
||||
// Build the prompt and ask the LLM to summarize
|
||||
let prompt = buildSummarizationPrompt(
|
||||
result: result,
|
||||
userQuestion: userQuestion,
|
||||
context: context
|
||||
)
|
||||
|
||||
let session = LanguageModelSession(
|
||||
model: model,
|
||||
instructions: summaryInstructions
|
||||
)
|
||||
|
||||
let response = try await session.respond(to: prompt)
|
||||
return response.content.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
}
|
||||
|
||||
/// Generates a summary without calling the LLM, using simple templates.
|
||||
///
|
||||
/// Useful when LLM access is unavailable, or for fast local rendering.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - result: The raw `QueryResult` from SQL execution.
|
||||
/// - userQuestion: The original natural language question.
|
||||
/// - Returns: A template-based text summary.
|
||||
public func localSummary(result: QueryResult, userQuestion: String) -> String {
|
||||
if let affected = result.rowsAffected {
|
||||
return summarizeMutation(result: result, affected: affected)
|
||||
}
|
||||
|
||||
if result.rows.isEmpty {
|
||||
return "No results found for your query."
|
||||
}
|
||||
|
||||
if let directAnswer = tryDirectAggregateSummary(result: result, userQuestion: userQuestion) {
|
||||
return directAnswer
|
||||
}
|
||||
|
||||
return buildTemplateSummary(result: result)
|
||||
}
|
||||
|
||||
// MARK: - Private Helpers
|
||||
|
||||
/// System instructions for the summarization session.
|
||||
private var summaryInstructions: String {
|
||||
"""
|
||||
You are a data assistant that summarizes SQL query results in natural language.
|
||||
|
||||
Rules:
|
||||
- Be concise and direct. Answer the user's question first, then add detail if helpful.
|
||||
- Use natural language, not SQL or code.
|
||||
- For numeric results, include the exact numbers.
|
||||
- For lists of records, summarize the count and highlight notable items.
|
||||
- If the data contains dates, format them in a readable way.
|
||||
- Do not mention SQL, databases, tables, columns, or queries in your response.
|
||||
- Do not include markdown formatting.
|
||||
- Keep your response under 3 sentences for simple results, under 5 for complex ones.
|
||||
"""
|
||||
}
|
||||
|
||||
/// Builds the prompt sent to the LLM for summarization.
|
||||
private func buildSummarizationPrompt(
|
||||
result: QueryResult,
|
||||
userQuestion: String,
|
||||
context: String?
|
||||
) -> String {
|
||||
var parts: [String] = []
|
||||
|
||||
parts.append("User's question: \(userQuestion)")
|
||||
|
||||
if let context {
|
||||
parts.append("Context: \(context)")
|
||||
}
|
||||
|
||||
parts.append("Query returned \(result.rowCount) row(s) with columns: \(result.columns.joined(separator: ", "))")
|
||||
|
||||
// Include the result data (truncated if large)
|
||||
let dataStr = formatResultData(result)
|
||||
parts.append("Data:\n\(dataStr)")
|
||||
|
||||
parts.append("Summarize these results in natural language, directly answering the user's question.")
|
||||
|
||||
return parts.joined(separator: "\n\n")
|
||||
}
|
||||
|
||||
/// Formats the query result data as a compact table for the LLM prompt.
|
||||
private func formatResultData(_ result: QueryResult) -> String {
|
||||
let rowsToInclude = Array(result.rows.prefix(maxRowsInPrompt))
|
||||
var lines: [String] = []
|
||||
|
||||
// Header
|
||||
lines.append(result.columns.joined(separator: " | "))
|
||||
|
||||
// Rows
|
||||
for row in rowsToInclude {
|
||||
let values = result.columns.map { col in
|
||||
row[col]?.description ?? "NULL"
|
||||
}
|
||||
lines.append(values.joined(separator: " | "))
|
||||
}
|
||||
|
||||
if result.rowCount > maxRowsInPrompt {
|
||||
lines.append("(\(result.rowCount - maxRowsInPrompt) additional rows not shown)")
|
||||
}
|
||||
|
||||
return lines.joined(separator: "\n")
|
||||
}
|
||||
|
||||
/// Produces a direct answer for simple aggregate queries (1 row, few columns).
|
||||
private func tryDirectAggregateSummary(result: QueryResult, userQuestion: String) -> String? {
|
||||
guard result.isAggregate else { return nil }
|
||||
|
||||
let row = result.rows[0]
|
||||
|
||||
// Single numeric column — e.g., "COUNT(*)" → "42"
|
||||
if result.columns.count == 1 {
|
||||
let col = result.columns[0]
|
||||
guard let value = row[col] else { return nil }
|
||||
let formatted = formatNumber(value)
|
||||
return "The result is \(formatted)."
|
||||
}
|
||||
|
||||
// Multiple aggregate columns — e.g., COUNT, AVG, SUM
|
||||
let parts = result.columns.compactMap { col -> String? in
|
||||
guard let value = row[col] else { return nil }
|
||||
let label = humanizeColumnName(col)
|
||||
let formatted = formatNumber(value)
|
||||
return "\(label): \(formatted)"
|
||||
}
|
||||
return parts.joined(separator: ", ") + "."
|
||||
}
|
||||
|
||||
/// Formats a numeric Value for display.
|
||||
private func formatNumber(_ value: QueryResult.Value) -> String {
|
||||
switch value {
|
||||
case .integer(let i):
|
||||
return NumberFormatter.localizedString(from: NSNumber(value: i), number: .decimal)
|
||||
case .real(let d):
|
||||
if d == d.rounded() && abs(d) < 1e12 {
|
||||
return NumberFormatter.localizedString(from: NSNumber(value: Int64(d)), number: .decimal)
|
||||
}
|
||||
let formatter = NumberFormatter()
|
||||
formatter.numberStyle = .decimal
|
||||
formatter.maximumFractionDigits = 2
|
||||
return formatter.string(from: NSNumber(value: d)) ?? String(d)
|
||||
default:
|
||||
return value.description
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a column name like "total_count" or "AVG(price)" into a readable label.
|
||||
private func humanizeColumnName(_ name: String) -> String {
|
||||
// Handle SQL function names: "COUNT(*)" → "count", "AVG(price)" → "average price"
|
||||
let functionPatterns: [(pattern: String, label: String)] = [
|
||||
("COUNT", "count"),
|
||||
("SUM", "total"),
|
||||
("AVG", "average"),
|
||||
("MIN", "minimum"),
|
||||
("MAX", "maximum"),
|
||||
]
|
||||
|
||||
let upper = name.uppercased()
|
||||
for (pattern, label) in functionPatterns {
|
||||
if upper.hasPrefix(pattern + "(") {
|
||||
// Extract the inner column name
|
||||
let start = name.index(name.startIndex, offsetBy: pattern.count + 1)
|
||||
let end = name.index(before: name.endIndex)
|
||||
if start < end {
|
||||
let inner = String(name[start..<end])
|
||||
if inner == "*" { return label }
|
||||
return "\(label) \(humanizeColumnName(inner))"
|
||||
}
|
||||
return label
|
||||
}
|
||||
}
|
||||
|
||||
// snake_case → space-separated
|
||||
return name
|
||||
.replacingOccurrences(of: "_", with: " ")
|
||||
.lowercased()
|
||||
}
|
||||
|
||||
/// Produces a template-based summary without calling the LLM.
|
||||
private func buildTemplateSummary(result: QueryResult) -> String {
|
||||
let count = result.rowCount
|
||||
|
||||
if count == 1 {
|
||||
// Single record — list field values
|
||||
let row = result.rows[0]
|
||||
let details = result.columns.prefix(5).compactMap { col -> String? in
|
||||
guard let val = row[col], !val.isNull else { return nil }
|
||||
return "\(humanizeColumnName(col)): \(val.description)"
|
||||
}
|
||||
return "Found 1 result. \(details.joined(separator: ", "))."
|
||||
}
|
||||
|
||||
// Multiple records
|
||||
var summary = "Found \(count) results"
|
||||
|
||||
// If there's a clear "name" or "title" column, list first few
|
||||
let nameColumns = ["name", "title", "label", "description"]
|
||||
if let nameCol = result.columns.first(where: { nameColumns.contains($0.lowercased()) }) {
|
||||
let names = result.rows.prefix(3).compactMap { $0[nameCol]?.description }
|
||||
if !names.isEmpty {
|
||||
summary += " including \(names.joined(separator: ", "))"
|
||||
if count > 3 { summary += ", and \(count - 3) more" }
|
||||
}
|
||||
}
|
||||
|
||||
return summary + "."
|
||||
}
|
||||
|
||||
/// Summarizes a mutation (INSERT/UPDATE/DELETE) result.
|
||||
private func summarizeMutation(result: QueryResult, affected: Int) -> String {
|
||||
let sql = result.sql.trimmingCharacters(in: .whitespacesAndNewlines).uppercased()
|
||||
|
||||
let operation: String
|
||||
if sql.hasPrefix("INSERT") {
|
||||
operation = "inserted"
|
||||
} else if sql.hasPrefix("UPDATE") {
|
||||
operation = "updated"
|
||||
} else if sql.hasPrefix("DELETE") {
|
||||
operation = "deleted"
|
||||
} else {
|
||||
operation = "affected"
|
||||
}
|
||||
|
||||
let noun = affected == 1 ? "row" : "rows"
|
||||
return "Successfully \(operation) \(affected) \(noun)."
|
||||
}
|
||||
}
|
||||
164
Sources/SwiftDBAI/Schema/DatabaseSchema.swift
Normal file
164
Sources/SwiftDBAI/Schema/DatabaseSchema.swift
Normal file
@@ -0,0 +1,164 @@
|
||||
// DatabaseSchema.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Auto-introspected SQLite schema model types.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// Complete schema representation of an SQLite database.
|
||||
public struct DatabaseSchema: Sendable, Equatable {
|
||||
/// All tables in the database, keyed by table name.
|
||||
public let tables: [String: TableSchema]
|
||||
|
||||
/// Ordered table names (preserves discovery order).
|
||||
public let tableNames: [String]
|
||||
|
||||
/// Returns a compact text description suitable for LLM system prompts.
|
||||
public var schemaDescription: String {
|
||||
var lines: [String] = []
|
||||
for name in tableNames {
|
||||
guard let table = tables[name] else { continue }
|
||||
lines.append(table.descriptionForLLM)
|
||||
}
|
||||
return lines.joined(separator: "\n\n")
|
||||
}
|
||||
|
||||
/// Returns a description suitable for LLM system prompts.
|
||||
/// Alias for `schemaDescription` for API compatibility.
|
||||
public func describeForLLM() -> String {
|
||||
schemaDescription
|
||||
}
|
||||
|
||||
public init(tables: [String: TableSchema], tableNames: [String]) {
|
||||
self.tables = tables
|
||||
self.tableNames = tableNames
|
||||
}
|
||||
}
|
||||
|
||||
/// Schema for a single SQLite table.
|
||||
public struct TableSchema: Sendable, Equatable {
|
||||
public let name: String
|
||||
public let columns: [ColumnSchema]
|
||||
public let primaryKey: [String]
|
||||
public let foreignKeys: [ForeignKeySchema]
|
||||
public let indexes: [IndexSchema]
|
||||
|
||||
/// Text description for embedding in LLM prompts.
|
||||
public var descriptionForLLM: String {
|
||||
var parts: [String] = []
|
||||
let colDefs = columns.map { col in
|
||||
var def = " \(col.name) \(col.type)"
|
||||
if col.isPrimaryKey { def += " PRIMARY KEY" }
|
||||
if col.isNotNull { def += " NOT NULL" }
|
||||
if let defaultValue = col.defaultValue { def += " DEFAULT \(defaultValue)" }
|
||||
return def
|
||||
}
|
||||
parts.append("TABLE \(name) (\n\(colDefs.joined(separator: ",\n"))\n)")
|
||||
|
||||
if !foreignKeys.isEmpty {
|
||||
let fkDescs = foreignKeys.map {
|
||||
" FOREIGN KEY (\($0.fromColumn)) REFERENCES \($0.toTable)(\($0.toColumn))"
|
||||
}
|
||||
parts.append("FOREIGN KEYS:\n\(fkDescs.joined(separator: "\n"))")
|
||||
}
|
||||
|
||||
if !indexes.isEmpty {
|
||||
let idxDescs = indexes.map {
|
||||
" INDEX \($0.name) ON (\($0.columns.joined(separator: ", ")))\($0.isUnique ? " UNIQUE" : "")"
|
||||
}
|
||||
parts.append("INDEXES:\n\(idxDescs.joined(separator: "\n"))")
|
||||
}
|
||||
|
||||
return parts.joined(separator: "\n")
|
||||
}
|
||||
|
||||
public init(
|
||||
name: String,
|
||||
columns: [ColumnSchema],
|
||||
primaryKey: [String],
|
||||
foreignKeys: [ForeignKeySchema],
|
||||
indexes: [IndexSchema]
|
||||
) {
|
||||
self.name = name
|
||||
self.columns = columns
|
||||
self.primaryKey = primaryKey
|
||||
self.foreignKeys = foreignKeys
|
||||
self.indexes = indexes
|
||||
}
|
||||
}
|
||||
|
||||
/// Schema for a single column.
|
||||
public struct ColumnSchema: Sendable, Equatable {
|
||||
/// Column position (0-based).
|
||||
public let cid: Int
|
||||
/// Column name.
|
||||
public let name: String
|
||||
/// Declared SQLite type (e.g. "TEXT", "INTEGER", "REAL", "BLOB").
|
||||
public let type: String
|
||||
/// Whether the column has a NOT NULL constraint.
|
||||
public let isNotNull: Bool
|
||||
/// Default value expression, if any.
|
||||
public let defaultValue: String?
|
||||
/// Whether this column is part of the primary key.
|
||||
public let isPrimaryKey: Bool
|
||||
|
||||
public init(
|
||||
cid: Int,
|
||||
name: String,
|
||||
type: String,
|
||||
isNotNull: Bool,
|
||||
defaultValue: String?,
|
||||
isPrimaryKey: Bool
|
||||
) {
|
||||
self.cid = cid
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.isNotNull = isNotNull
|
||||
self.defaultValue = defaultValue
|
||||
self.isPrimaryKey = isPrimaryKey
|
||||
}
|
||||
}
|
||||
|
||||
/// Schema for a foreign key relationship.
|
||||
public struct ForeignKeySchema: Sendable, Equatable {
|
||||
/// Column in the source table.
|
||||
public let fromColumn: String
|
||||
/// Referenced table name.
|
||||
public let toTable: String
|
||||
/// Referenced column name.
|
||||
public let toColumn: String
|
||||
/// ON UPDATE action (e.g. "CASCADE", "NO ACTION").
|
||||
public let onUpdate: String
|
||||
/// ON DELETE action.
|
||||
public let onDelete: String
|
||||
|
||||
public init(
|
||||
fromColumn: String,
|
||||
toTable: String,
|
||||
toColumn: String,
|
||||
onUpdate: String,
|
||||
onDelete: String
|
||||
) {
|
||||
self.fromColumn = fromColumn
|
||||
self.toTable = toTable
|
||||
self.toColumn = toColumn
|
||||
self.onUpdate = onUpdate
|
||||
self.onDelete = onDelete
|
||||
}
|
||||
}
|
||||
|
||||
/// Schema for a database index.
|
||||
public struct IndexSchema: Sendable, Equatable {
|
||||
/// Index name.
|
||||
public let name: String
|
||||
/// Whether the index enforces uniqueness.
|
||||
public let isUnique: Bool
|
||||
/// Columns included in the index, in order.
|
||||
public let columns: [String]
|
||||
|
||||
public init(name: String, isUnique: Bool, columns: [String]) {
|
||||
self.name = name
|
||||
self.isUnique = isUnique
|
||||
self.columns = columns
|
||||
}
|
||||
}
|
||||
153
Sources/SwiftDBAI/Schema/SchemaIntrospector.swift
Normal file
153
Sources/SwiftDBAI/Schema/SchemaIntrospector.swift
Normal file
@@ -0,0 +1,153 @@
|
||||
// SchemaIntrospector.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Auto-introspects SQLite database schema using GRDB.
|
||||
|
||||
import GRDB
|
||||
|
||||
/// Introspects an SQLite database schema by querying sqlite_master and PRAGMA statements.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// let dbPool = try DatabasePool(path: "path/to/db.sqlite")
|
||||
/// let schema = try await SchemaIntrospector.introspect(database: dbPool)
|
||||
/// print(schema.schemaDescription)
|
||||
/// ```
|
||||
public struct SchemaIntrospector: Sendable {
|
||||
|
||||
// MARK: - Public API
|
||||
|
||||
/// Introspects the full schema of the given database.
|
||||
///
|
||||
/// Discovers all user tables (excluding sqlite_ internal tables),
|
||||
/// their columns, primary keys, foreign keys, and indexes.
|
||||
///
|
||||
/// - Parameter database: A GRDB `DatabaseReader` (DatabasePool or DatabaseQueue).
|
||||
/// - Returns: A complete `DatabaseSchema` representation.
|
||||
public static func introspect(database: any DatabaseReader) async throws -> DatabaseSchema {
|
||||
try await database.read { db in
|
||||
try introspect(db: db)
|
||||
}
|
||||
}
|
||||
|
||||
/// Synchronous introspection within an existing database access context.
|
||||
///
|
||||
/// - Parameter db: A GRDB `Database` instance from within a read/write block.
|
||||
/// - Returns: A complete `DatabaseSchema` representation.
|
||||
public static func introspect(db: Database) throws -> DatabaseSchema {
|
||||
let tableNames = try fetchTableNames(db: db)
|
||||
var tables: [String: TableSchema] = [:]
|
||||
|
||||
for tableName in tableNames {
|
||||
let columns = try fetchColumns(db: db, table: tableName)
|
||||
let primaryKey = try fetchPrimaryKey(db: db, table: tableName)
|
||||
let foreignKeys = try fetchForeignKeys(db: db, table: tableName)
|
||||
let indexes = try fetchIndexes(db: db, table: tableName)
|
||||
|
||||
// Mark columns that are part of the primary key
|
||||
let pkSet = Set(primaryKey)
|
||||
let annotatedColumns = columns.map { col in
|
||||
ColumnSchema(
|
||||
cid: col.cid,
|
||||
name: col.name,
|
||||
type: col.type,
|
||||
isNotNull: col.isNotNull,
|
||||
defaultValue: col.defaultValue,
|
||||
isPrimaryKey: pkSet.contains(col.name)
|
||||
)
|
||||
}
|
||||
|
||||
tables[tableName] = TableSchema(
|
||||
name: tableName,
|
||||
columns: annotatedColumns,
|
||||
primaryKey: primaryKey,
|
||||
foreignKeys: foreignKeys,
|
||||
indexes: indexes
|
||||
)
|
||||
}
|
||||
|
||||
return DatabaseSchema(tables: tables, tableNames: tableNames)
|
||||
}
|
||||
|
||||
// MARK: - Private Helpers
|
||||
|
||||
/// Fetches all user table names from sqlite_master.
|
||||
private static func fetchTableNames(db: Database) throws -> [String] {
|
||||
let sql = """
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type = 'table'
|
||||
AND name NOT LIKE 'sqlite_%'
|
||||
ORDER BY name
|
||||
"""
|
||||
return try String.fetchAll(db, sql: sql)
|
||||
}
|
||||
|
||||
/// Fetches column metadata for a table using PRAGMA table_info.
|
||||
private static func fetchColumns(db: Database, table: String) throws -> [ColumnSchema] {
|
||||
let sql = "PRAGMA table_info(\(table.quotedDatabaseIdentifier))"
|
||||
let rows = try Row.fetchAll(db, sql: sql)
|
||||
return rows.map { row in
|
||||
ColumnSchema(
|
||||
cid: row["cid"],
|
||||
name: row["name"],
|
||||
type: (row["type"] as String?) ?? "",
|
||||
isNotNull: row["notnull"] == 1,
|
||||
defaultValue: row["dflt_value"],
|
||||
isPrimaryKey: row["pk"] != 0
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetches primary key columns for a table.
|
||||
private static func fetchPrimaryKey(db: Database, table: String) throws -> [String] {
|
||||
let sql = "PRAGMA table_info(\(table.quotedDatabaseIdentifier))"
|
||||
let rows = try Row.fetchAll(db, sql: sql)
|
||||
return rows
|
||||
.filter { ($0["pk"] as Int) > 0 }
|
||||
.sorted { ($0["pk"] as Int) < ($1["pk"] as Int) }
|
||||
.map { $0["name"] }
|
||||
}
|
||||
|
||||
/// Fetches foreign key relationships for a table.
|
||||
private static func fetchForeignKeys(db: Database, table: String) throws -> [ForeignKeySchema] {
|
||||
let sql = "PRAGMA foreign_key_list(\(table.quotedDatabaseIdentifier))"
|
||||
let rows = try Row.fetchAll(db, sql: sql)
|
||||
return rows.map { row in
|
||||
ForeignKeySchema(
|
||||
fromColumn: row["from"],
|
||||
toTable: row["table"],
|
||||
toColumn: row["to"],
|
||||
onUpdate: row["on_update"] ?? "NO ACTION",
|
||||
onDelete: row["on_delete"] ?? "NO ACTION"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetches indexes and their columns for a table.
|
||||
private static func fetchIndexes(db: Database, table: String) throws -> [IndexSchema] {
|
||||
let indexListSQL = "PRAGMA index_list(\(table.quotedDatabaseIdentifier))"
|
||||
let indexRows = try Row.fetchAll(db, sql: indexListSQL)
|
||||
|
||||
var indexes: [IndexSchema] = []
|
||||
for indexRow in indexRows {
|
||||
let indexName: String = indexRow["name"]
|
||||
let isUnique: Bool = indexRow["unique"] == 1
|
||||
|
||||
// Skip auto-generated indexes for primary keys
|
||||
if indexName.hasPrefix("sqlite_autoindex_") { continue }
|
||||
|
||||
let infoSQL = "PRAGMA index_info(\(indexName.quotedDatabaseIdentifier))"
|
||||
let infoRows = try Row.fetchAll(db, sql: infoSQL)
|
||||
let columns: [String] = infoRows
|
||||
.sorted { ($0["seqno"] as Int) < ($1["seqno"] as Int) }
|
||||
.map { $0["name"] }
|
||||
|
||||
indexes.append(IndexSchema(
|
||||
name: indexName,
|
||||
isUnique: isUnique,
|
||||
columns: columns
|
||||
))
|
||||
}
|
||||
return indexes
|
||||
}
|
||||
}
|
||||
215
Sources/SwiftDBAI/SwiftDBAIError.swift
Normal file
215
Sources/SwiftDBAI/SwiftDBAIError.swift
Normal file
@@ -0,0 +1,215 @@
|
||||
// SwiftDBAIError.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Unified error type for the SwiftDBAI package.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// The top-level error type for SwiftDBAI operations.
|
||||
///
|
||||
/// `SwiftDBAIError` provides a single, typed error surface that covers
|
||||
/// every failure mode a consumer of SwiftDBAI may encounter — from invalid
|
||||
/// SQL and LLM failures to schema mismatches and safety violations.
|
||||
///
|
||||
/// Every case includes a user-friendly `localizedDescription` suitable for
|
||||
/// displaying directly in a chat interface.
|
||||
public enum SwiftDBAIError: Error, LocalizedError, Sendable, Equatable {
|
||||
|
||||
// MARK: - SQL Errors
|
||||
|
||||
/// No SQL statement could be extracted from the LLM response.
|
||||
case noSQLGenerated
|
||||
|
||||
/// The generated SQL is syntactically invalid or failed execution.
|
||||
case invalidSQL(sql: String, reason: String)
|
||||
|
||||
/// The SQL uses an operation (e.g. DELETE) not in the developer's allowlist.
|
||||
case operationNotAllowed(operation: String)
|
||||
|
||||
/// Multiple SQL statements were generated but only single-statement execution is supported.
|
||||
case multipleStatementsNotSupported
|
||||
|
||||
/// A dangerous SQL keyword (DROP, ALTER, TRUNCATE) was detected.
|
||||
case dangerousOperationBlocked(keyword: String)
|
||||
|
||||
// MARK: - LLM Errors
|
||||
|
||||
/// The LLM failed to produce a response.
|
||||
case llmFailure(reason: String)
|
||||
|
||||
/// The LLM response could not be parsed into an actionable result.
|
||||
case llmResponseUnparseable(response: String)
|
||||
|
||||
/// The LLM request timed out.
|
||||
case llmTimeout(seconds: TimeInterval)
|
||||
|
||||
// MARK: - Schema Errors
|
||||
|
||||
/// Schema introspection of the database failed.
|
||||
case schemaIntrospectionFailed(reason: String)
|
||||
|
||||
/// The generated SQL references a table that does not exist in the schema.
|
||||
case tableNotFound(tableName: String)
|
||||
|
||||
/// The generated SQL references a column that does not exist on the given table.
|
||||
case columnNotFound(columnName: String, tableName: String)
|
||||
|
||||
/// The database schema is empty (no user tables found).
|
||||
case emptySchema
|
||||
|
||||
// MARK: - Safety & Validation Errors
|
||||
|
||||
/// A destructive operation requires explicit user confirmation before execution.
|
||||
case confirmationRequired(sql: String, operation: String)
|
||||
|
||||
/// A mutation targets a table not in the allowed mutation tables.
|
||||
case tableNotAllowedForMutation(tableName: String, operation: String)
|
||||
|
||||
/// A custom query validator rejected the query.
|
||||
case queryRejected(reason: String)
|
||||
|
||||
// MARK: - Database Errors
|
||||
|
||||
/// The underlying database operation failed.
|
||||
case databaseError(reason: String)
|
||||
|
||||
/// The query exceeded the configured execution timeout.
|
||||
case queryTimedOut(seconds: TimeInterval)
|
||||
|
||||
// MARK: - Configuration Errors
|
||||
|
||||
/// The engine has not been configured correctly.
|
||||
case configurationError(reason: String)
|
||||
|
||||
// MARK: - Error Classification
|
||||
|
||||
/// Whether this error represents a safety/permissions issue (not a bug).
|
||||
public var isSafetyError: Bool {
|
||||
switch self {
|
||||
case .operationNotAllowed, .dangerousOperationBlocked,
|
||||
.confirmationRequired, .tableNotAllowedForMutation, .queryRejected:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether this error is recoverable by rephrasing the user's question.
|
||||
public var isRecoverable: Bool {
|
||||
switch self {
|
||||
case .noSQLGenerated, .llmResponseUnparseable, .invalidSQL,
|
||||
.tableNotFound, .columnNotFound:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether this error requires user action (e.g. confirmation).
|
||||
public var requiresUserAction: Bool {
|
||||
if case .confirmationRequired = self { return true }
|
||||
return false
|
||||
}
|
||||
|
||||
// MARK: - LocalizedError
|
||||
|
||||
public var errorDescription: String? {
|
||||
switch self {
|
||||
// SQL
|
||||
case .noSQLGenerated:
|
||||
return "I couldn't generate a SQL query from your request. Could you rephrase your question?"
|
||||
case .invalidSQL(let sql, let reason):
|
||||
return "The generated query is invalid — \(reason). Query: \(sql)"
|
||||
case .operationNotAllowed(let operation):
|
||||
return "The \(operation.uppercased()) operation is not allowed by the current configuration."
|
||||
case .multipleStatementsNotSupported:
|
||||
return "Only single SQL statements are supported. Please ask one question at a time."
|
||||
case .dangerousOperationBlocked(let keyword):
|
||||
return "The \(keyword.uppercased()) operation is blocked for safety. This operation is never allowed."
|
||||
|
||||
// LLM
|
||||
case .llmFailure(let reason):
|
||||
return "The language model encountered an error: \(reason)"
|
||||
case .llmResponseUnparseable(let response):
|
||||
return "I received a response but couldn't understand it. Raw response: \(response.prefix(200))"
|
||||
case .llmTimeout(let seconds):
|
||||
return "The language model did not respond within \(Int(seconds)) seconds. Please try again."
|
||||
|
||||
// Schema
|
||||
case .schemaIntrospectionFailed(let reason):
|
||||
return "Failed to read the database schema: \(reason)"
|
||||
case .tableNotFound(let tableName):
|
||||
return "The table '\(tableName)' does not exist in this database."
|
||||
case .columnNotFound(let columnName, let tableName):
|
||||
return "The column '\(columnName)' does not exist on table '\(tableName)'."
|
||||
case .emptySchema:
|
||||
return "This database has no tables. There's nothing to query yet."
|
||||
|
||||
// Safety
|
||||
case .confirmationRequired(let sql, let operation):
|
||||
return "The \(operation.uppercased()) operation requires your confirmation before running: \(sql)"
|
||||
case .tableNotAllowedForMutation(let tableName, let operation):
|
||||
return "The \(operation.uppercased()) operation is not allowed on table '\(tableName)'."
|
||||
case .queryRejected(let reason):
|
||||
return "Query rejected: \(reason)"
|
||||
|
||||
// Database
|
||||
case .databaseError(let reason):
|
||||
return "A database error occurred: \(reason)"
|
||||
case .queryTimedOut(let seconds):
|
||||
return "The query timed out after \(Int(seconds)) seconds. Try a simpler query."
|
||||
|
||||
// Configuration
|
||||
case .configurationError(let reason):
|
||||
return "Configuration error: \(reason)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Conversion from SQLParsingError
|
||||
|
||||
extension SQLParsingError {
|
||||
/// Maps a ``SQLParsingError`` to the corresponding ``SwiftDBAIError`` case.
|
||||
///
|
||||
/// - Parameter rawResponse: The raw LLM response text (used for context in `.noSQLFound`).
|
||||
/// - Returns: A ``SwiftDBAIError`` with the same semantic meaning.
|
||||
func toSwiftDBAIError(rawResponse: String = "") -> SwiftDBAIError {
|
||||
switch self {
|
||||
case .noSQLFound:
|
||||
if rawResponse.isEmpty {
|
||||
return .noSQLGenerated
|
||||
}
|
||||
return .llmResponseUnparseable(response: rawResponse)
|
||||
case .operationNotAllowed(let op):
|
||||
return .operationNotAllowed(operation: op.rawValue)
|
||||
case .confirmationRequired(let sql, let op):
|
||||
return .confirmationRequired(sql: sql, operation: op.rawValue)
|
||||
case .tableNotAllowed(let table, let op):
|
||||
return .tableNotAllowedForMutation(tableName: table, operation: op.rawValue)
|
||||
case .dangerousOperation(let keyword):
|
||||
return .dangerousOperationBlocked(keyword: keyword)
|
||||
case .multipleStatements:
|
||||
return .multipleStatementsNotSupported
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Conversion from ChatEngineError
|
||||
|
||||
extension ChatEngineError {
|
||||
/// Maps a ``ChatEngineError`` to the corresponding ``SwiftDBAIError`` case.
|
||||
func toSwiftDBAIError() -> SwiftDBAIError {
|
||||
switch self {
|
||||
case .sqlParsingFailed(let parsingError):
|
||||
return parsingError.toSwiftDBAIError()
|
||||
case .confirmationRequired(let sql, let operation):
|
||||
return .confirmationRequired(sql: sql, operation: operation.rawValue)
|
||||
case .schemaIntrospectionFailed(let reason):
|
||||
return .schemaIntrospectionFailed(reason: reason)
|
||||
case .queryTimedOut(let seconds):
|
||||
return .queryTimedOut(seconds: seconds)
|
||||
case .validationFailed(let reason):
|
||||
return .queryRejected(reason: reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
182
Sources/SwiftDBAI/Views/Charts/BarChartView.swift
Normal file
182
Sources/SwiftDBAI/Views/Charts/BarChartView.swift
Normal file
@@ -0,0 +1,182 @@
|
||||
// BarChartView.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// A SwiftUI bar chart that renders DataTable values using Swift Charts.
|
||||
// Best for categorical comparisons (e.g., sales by region, counts by status).
|
||||
|
||||
import SwiftUI
|
||||
import Charts
|
||||
|
||||
/// A bar chart view that renders a `DataTable` column pair using Swift Charts.
|
||||
///
|
||||
/// Displays vertical bars with category labels on the x-axis and numeric
|
||||
/// values on the y-axis. Automatically colors bars using the accent gradient
|
||||
/// and supports scrolling when many categories are present.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// BarChartView(
|
||||
/// dataTable: table,
|
||||
/// categoryColumn: "department",
|
||||
/// valueColumn: "total_sales"
|
||||
/// )
|
||||
/// ```
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
public struct BarChartView: View {
|
||||
|
||||
/// The data to chart.
|
||||
public let dataTable: DataTable
|
||||
|
||||
/// Column name for category labels (x-axis).
|
||||
public let categoryColumn: String
|
||||
|
||||
/// Column name for numeric values (y-axis).
|
||||
public let valueColumn: String
|
||||
|
||||
/// Optional chart title.
|
||||
public var title: String?
|
||||
|
||||
/// Maximum number of bars to display before truncating.
|
||||
public var maxBars: Int
|
||||
|
||||
public init(
|
||||
dataTable: DataTable,
|
||||
categoryColumn: String,
|
||||
valueColumn: String,
|
||||
title: String? = nil,
|
||||
maxBars: Int = 30
|
||||
) {
|
||||
self.dataTable = dataTable
|
||||
self.categoryColumn = categoryColumn
|
||||
self.valueColumn = valueColumn
|
||||
self.title = title
|
||||
self.maxBars = maxBars
|
||||
}
|
||||
|
||||
public var body: some View {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
if let title {
|
||||
Text(title)
|
||||
.font(.caption.weight(.semibold))
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
|
||||
if chartData.isEmpty {
|
||||
emptyChartView
|
||||
} else {
|
||||
chartContent
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Chart Content
|
||||
|
||||
@ViewBuilder
|
||||
private var chartContent: some View {
|
||||
Chart(chartData, id: \.label) { item in
|
||||
BarMark(
|
||||
x: .value(categoryColumn, item.label),
|
||||
y: .value(valueColumn, item.value)
|
||||
)
|
||||
.foregroundStyle(
|
||||
.linearGradient(
|
||||
colors: [.accentColor, .accentColor.opacity(0.7)],
|
||||
startPoint: .bottom,
|
||||
endPoint: .top
|
||||
)
|
||||
)
|
||||
.cornerRadius(4)
|
||||
}
|
||||
.chartXAxis {
|
||||
AxisMarks(values: .automatic) { _ in
|
||||
AxisValueLabel()
|
||||
.font(.caption2)
|
||||
}
|
||||
}
|
||||
.chartYAxis {
|
||||
AxisMarks(position: .leading) { _ in
|
||||
AxisGridLine(stroke: StrokeStyle(lineWidth: 0.5, dash: [4, 4]))
|
||||
.foregroundStyle(.secondary.opacity(0.3))
|
||||
AxisValueLabel()
|
||||
.font(.caption2)
|
||||
}
|
||||
}
|
||||
.frame(minHeight: 200)
|
||||
|
||||
if isTruncated {
|
||||
truncationNotice
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Empty State
|
||||
|
||||
@ViewBuilder
|
||||
private var emptyChartView: some View {
|
||||
VStack(spacing: 8) {
|
||||
Image(systemName: "chart.bar")
|
||||
.font(.title2)
|
||||
.foregroundStyle(.secondary)
|
||||
Text("No chartable data")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
.frame(maxWidth: .infinity, minHeight: 100)
|
||||
}
|
||||
|
||||
// MARK: - Truncation Notice
|
||||
|
||||
@ViewBuilder
|
||||
private var truncationNotice: some View {
|
||||
Text("Showing \(maxBars) of \(dataTable.rowCount) categories")
|
||||
.font(.caption2)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
|
||||
// MARK: - Data Extraction
|
||||
|
||||
private var isTruncated: Bool {
|
||||
dataTable.rowCount > maxBars
|
||||
}
|
||||
|
||||
private var chartData: [ChartDataPoint] {
|
||||
let labels = dataTable.stringValues(forColumn: categoryColumn)
|
||||
let values = dataTable.numericValues(forColumn: valueColumn)
|
||||
|
||||
let count = min(labels.count, values.count, maxBars)
|
||||
guard count > 0 else { return [] }
|
||||
|
||||
return (0..<count).map { i in
|
||||
ChartDataPoint(label: labels[i], value: values[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Preview
|
||||
|
||||
#if DEBUG
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
#Preview("Bar Chart") {
|
||||
let columns: [DataTable.Column] = [
|
||||
.init(name: "department", index: 0, inferredType: .text),
|
||||
.init(name: "revenue", index: 1, inferredType: .real),
|
||||
]
|
||||
let departments = ["Engineering", "Sales", "Marketing", "Support", "Design"]
|
||||
let rows: [DataTable.Row] = departments.enumerated().map { i, dept in
|
||||
DataTable.Row(
|
||||
id: i,
|
||||
values: [.text(dept), .real(Double.random(in: 50_000...200_000))],
|
||||
columnNames: ["department", "revenue"]
|
||||
)
|
||||
}
|
||||
let table = DataTable(columns: columns, rows: rows)
|
||||
|
||||
BarChartView(
|
||||
dataTable: table,
|
||||
categoryColumn: "department",
|
||||
valueColumn: "revenue",
|
||||
title: "Revenue by Department"
|
||||
)
|
||||
.padding()
|
||||
.frame(height: 300)
|
||||
}
|
||||
#endif
|
||||
21
Sources/SwiftDBAI/Views/Charts/ChartDataPoint.swift
Normal file
21
Sources/SwiftDBAI/Views/Charts/ChartDataPoint.swift
Normal file
@@ -0,0 +1,21 @@
|
||||
// ChartDataPoint.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Shared data model used by all chart views.
|
||||
|
||||
import Foundation
|
||||
|
||||
/// A single data point for chart rendering.
|
||||
///
|
||||
/// Pairs a string label (category) with a numeric value.
|
||||
/// Used as the common data format across BarChartView,
|
||||
/// LineChartView, and PieChartView.
|
||||
struct ChartDataPoint: Sendable, Identifiable {
|
||||
var id: String { label }
|
||||
|
||||
/// The category label (x-axis or slice label).
|
||||
let label: String
|
||||
|
||||
/// The numeric value (y-axis or slice size).
|
||||
let value: Double
|
||||
}
|
||||
135
Sources/SwiftDBAI/Views/Charts/ChartResultView.swift
Normal file
135
Sources/SwiftDBAI/Views/Charts/ChartResultView.swift
Normal file
@@ -0,0 +1,135 @@
|
||||
// ChartResultView.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Auto-selecting chart view that uses ChartDataDetector to pick the
|
||||
// best chart type for a given DataTable.
|
||||
|
||||
import SwiftUI
|
||||
import Charts
|
||||
|
||||
/// A chart view that automatically selects the best chart type for a `DataTable`.
|
||||
///
|
||||
/// Uses `ChartDataDetector` to analyze the data shape and renders the
|
||||
/// appropriate chart (bar, line, or pie). If the data isn't suitable for
|
||||
/// charting, the view renders nothing.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// ChartResultView(dataTable: myTable)
|
||||
/// ```
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
public struct ChartResultView: View {
|
||||
|
||||
/// The data table to chart.
|
||||
public let dataTable: DataTable
|
||||
|
||||
/// Optional override: force a specific chart type.
|
||||
public var chartType: ChartDataDetector.ChartType?
|
||||
|
||||
/// The detector used to analyze the data.
|
||||
private let detector: ChartDataDetector
|
||||
|
||||
public init(
|
||||
dataTable: DataTable,
|
||||
chartType: ChartDataDetector.ChartType? = nil,
|
||||
detector: ChartDataDetector = ChartDataDetector()
|
||||
) {
|
||||
self.dataTable = dataTable
|
||||
self.chartType = chartType
|
||||
self.detector = detector
|
||||
}
|
||||
|
||||
public var body: some View {
|
||||
if let recommendation = resolvedRecommendation {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
chartView(for: recommendation)
|
||||
|
||||
Text(recommendation.reason)
|
||||
.font(.caption2)
|
||||
.foregroundStyle(.tertiary)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Chart Selection
|
||||
|
||||
@ViewBuilder
|
||||
private func chartView(
|
||||
for recommendation: ChartDataDetector.ChartRecommendation
|
||||
) -> some View {
|
||||
switch recommendation.chartType {
|
||||
case .bar:
|
||||
BarChartView(
|
||||
dataTable: dataTable,
|
||||
categoryColumn: recommendation.categoryColumn,
|
||||
valueColumn: recommendation.valueColumn
|
||||
)
|
||||
case .line:
|
||||
LineChartView(
|
||||
dataTable: dataTable,
|
||||
categoryColumn: recommendation.categoryColumn,
|
||||
valueColumn: recommendation.valueColumn
|
||||
)
|
||||
case .pie:
|
||||
PieChartView(
|
||||
dataTable: dataTable,
|
||||
categoryColumn: recommendation.categoryColumn,
|
||||
valueColumn: recommendation.valueColumn
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Resolution
|
||||
|
||||
/// Resolves the chart recommendation, using the override type if provided.
|
||||
private var resolvedRecommendation: ChartDataDetector.ChartRecommendation? {
|
||||
if let override = chartType {
|
||||
// Use forced chart type — still need column pair from detector
|
||||
let all = detector.allRecommendations(for: dataTable)
|
||||
// Try to find recommendation for the forced type
|
||||
if let match = all.first(where: { $0.chartType == override }) {
|
||||
return match
|
||||
}
|
||||
// Fallback: use first recommendation and override its type
|
||||
if let first = all.first {
|
||||
return ChartDataDetector.ChartRecommendation(
|
||||
chartType: override,
|
||||
categoryColumn: first.categoryColumn,
|
||||
valueColumn: first.valueColumn,
|
||||
confidence: first.confidence * 0.8,
|
||||
reason: first.reason
|
||||
)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Auto-detect best chart type
|
||||
return detector.detect(dataTable)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Preview
|
||||
|
||||
#if DEBUG
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
#Preview("Auto Chart — Bar") {
|
||||
let columns: [DataTable.Column] = [
|
||||
.init(name: "city", index: 0, inferredType: .text),
|
||||
.init(name: "population", index: 1, inferredType: .integer),
|
||||
]
|
||||
let cities = ["NYC", "LA", "Chicago", "Houston", "Phoenix"]
|
||||
let pops: [Int64] = [8_336_817, 3_979_576, 2_693_976, 2_320_268, 1_680_992]
|
||||
let rows: [DataTable.Row] = cities.enumerated().map { i, city in
|
||||
DataTable.Row(
|
||||
id: i,
|
||||
values: [.text(city), .integer(pops[i])],
|
||||
columnNames: ["city", "population"]
|
||||
)
|
||||
}
|
||||
let table = DataTable(columns: columns, rows: rows)
|
||||
|
||||
ChartResultView(dataTable: table)
|
||||
.padding()
|
||||
.frame(height: 300)
|
||||
}
|
||||
#endif
|
||||
206
Sources/SwiftDBAI/Views/Charts/LineChartView.swift
Normal file
206
Sources/SwiftDBAI/Views/Charts/LineChartView.swift
Normal file
@@ -0,0 +1,206 @@
|
||||
// LineChartView.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// A SwiftUI line chart that renders DataTable values using Swift Charts.
|
||||
// Best for time series or sequential data (e.g., revenue over months).
|
||||
|
||||
import SwiftUI
|
||||
import Charts
|
||||
|
||||
/// A line chart view that renders a `DataTable` column pair using Swift Charts.
|
||||
///
|
||||
/// Displays a connected line with optional area fill, point markers,
|
||||
/// and smooth interpolation. Best suited for time series or sequential data.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// LineChartView(
|
||||
/// dataTable: table,
|
||||
/// categoryColumn: "month",
|
||||
/// valueColumn: "revenue"
|
||||
/// )
|
||||
/// ```
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
public struct LineChartView: View {
|
||||
|
||||
/// The data to chart.
|
||||
public let dataTable: DataTable
|
||||
|
||||
/// Column name for category/time labels (x-axis).
|
||||
public let categoryColumn: String
|
||||
|
||||
/// Column name for numeric values (y-axis).
|
||||
public let valueColumn: String
|
||||
|
||||
/// Optional chart title.
|
||||
public var title: String?
|
||||
|
||||
/// Whether to show an area fill below the line.
|
||||
public var showAreaFill: Bool
|
||||
|
||||
/// Whether to show point markers at each data point.
|
||||
public var showPoints: Bool
|
||||
|
||||
/// Maximum data points to display.
|
||||
public var maxPoints: Int
|
||||
|
||||
public init(
|
||||
dataTable: DataTable,
|
||||
categoryColumn: String,
|
||||
valueColumn: String,
|
||||
title: String? = nil,
|
||||
showAreaFill: Bool = true,
|
||||
showPoints: Bool = true,
|
||||
maxPoints: Int = 100
|
||||
) {
|
||||
self.dataTable = dataTable
|
||||
self.categoryColumn = categoryColumn
|
||||
self.valueColumn = valueColumn
|
||||
self.title = title
|
||||
self.showAreaFill = showAreaFill
|
||||
self.showPoints = showPoints
|
||||
self.maxPoints = maxPoints
|
||||
}
|
||||
|
||||
public var body: some View {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
if let title {
|
||||
Text(title)
|
||||
.font(.caption.weight(.semibold))
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
|
||||
if chartData.isEmpty {
|
||||
emptyChartView
|
||||
} else {
|
||||
chartContent
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Chart Content
|
||||
|
||||
@ViewBuilder
|
||||
private var chartContent: some View {
|
||||
Chart(chartData, id: \.label) { item in
|
||||
LineMark(
|
||||
x: .value(categoryColumn, item.label),
|
||||
y: .value(valueColumn, item.value)
|
||||
)
|
||||
.foregroundStyle(Color.accentColor)
|
||||
.lineStyle(StrokeStyle(lineWidth: 2))
|
||||
.interpolationMethod(.catmullRom)
|
||||
|
||||
if showAreaFill {
|
||||
AreaMark(
|
||||
x: .value(categoryColumn, item.label),
|
||||
y: .value(valueColumn, item.value)
|
||||
)
|
||||
.foregroundStyle(
|
||||
.linearGradient(
|
||||
colors: [
|
||||
Color.accentColor.opacity(0.2),
|
||||
Color.accentColor.opacity(0.02),
|
||||
],
|
||||
startPoint: .top,
|
||||
endPoint: .bottom
|
||||
)
|
||||
)
|
||||
.interpolationMethod(.catmullRom)
|
||||
}
|
||||
|
||||
if showPoints {
|
||||
PointMark(
|
||||
x: .value(categoryColumn, item.label),
|
||||
y: .value(valueColumn, item.value)
|
||||
)
|
||||
.foregroundStyle(Color.accentColor)
|
||||
.symbolSize(30)
|
||||
}
|
||||
}
|
||||
.chartXAxis {
|
||||
AxisMarks(values: .automatic) { _ in
|
||||
AxisValueLabel()
|
||||
.font(.caption2)
|
||||
}
|
||||
}
|
||||
.chartYAxis {
|
||||
AxisMarks(position: .leading) { _ in
|
||||
AxisGridLine(stroke: StrokeStyle(lineWidth: 0.5, dash: [4, 4]))
|
||||
.foregroundStyle(.secondary.opacity(0.3))
|
||||
AxisValueLabel()
|
||||
.font(.caption2)
|
||||
}
|
||||
}
|
||||
.frame(minHeight: 200)
|
||||
|
||||
if isTruncated {
|
||||
Text("Showing \(maxPoints) of \(dataTable.rowCount) data points")
|
||||
.font(.caption2)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Empty State
|
||||
|
||||
@ViewBuilder
|
||||
private var emptyChartView: some View {
|
||||
VStack(spacing: 8) {
|
||||
Image(systemName: "chart.xyaxis.line")
|
||||
.font(.title2)
|
||||
.foregroundStyle(.secondary)
|
||||
Text("No chartable data")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
.frame(maxWidth: .infinity, minHeight: 100)
|
||||
}
|
||||
|
||||
// MARK: - Data Extraction
|
||||
|
||||
private var isTruncated: Bool {
|
||||
dataTable.rowCount > maxPoints
|
||||
}
|
||||
|
||||
private var chartData: [ChartDataPoint] {
|
||||
let labels = dataTable.stringValues(forColumn: categoryColumn)
|
||||
let values = dataTable.numericValues(forColumn: valueColumn)
|
||||
|
||||
let count = min(labels.count, values.count, maxPoints)
|
||||
guard count > 0 else { return [] }
|
||||
|
||||
return (0..<count).map { i in
|
||||
ChartDataPoint(label: labels[i], value: values[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Preview
|
||||
|
||||
#if DEBUG
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
#Preview("Line Chart") {
|
||||
let columns: [DataTable.Column] = [
|
||||
.init(name: "month", index: 0, inferredType: .text),
|
||||
.init(name: "revenue", index: 1, inferredType: .real),
|
||||
]
|
||||
let months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun"]
|
||||
let rows: [DataTable.Row] = months.enumerated().map { i, month in
|
||||
DataTable.Row(
|
||||
id: i,
|
||||
values: [.text(month), .real(Double(i + 1) * 15_000 + Double.random(in: -3000...3000))],
|
||||
columnNames: ["month", "revenue"]
|
||||
)
|
||||
}
|
||||
let table = DataTable(columns: columns, rows: rows)
|
||||
|
||||
LineChartView(
|
||||
dataTable: table,
|
||||
categoryColumn: "month",
|
||||
valueColumn: "revenue",
|
||||
title: "Monthly Revenue"
|
||||
)
|
||||
.padding()
|
||||
.frame(height: 300)
|
||||
}
|
||||
#endif
|
||||
234
Sources/SwiftDBAI/Views/Charts/PieChartView.swift
Normal file
234
Sources/SwiftDBAI/Views/Charts/PieChartView.swift
Normal file
@@ -0,0 +1,234 @@
|
||||
// PieChartView.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// A SwiftUI pie/donut chart that renders DataTable values using Swift Charts.
|
||||
// Best for proportional breakdowns with few categories (e.g., market share).
|
||||
|
||||
import SwiftUI
|
||||
import Charts
|
||||
|
||||
/// A pie chart view that renders a `DataTable` column pair using Swift Charts.
|
||||
///
|
||||
/// Displays proportional slices with category labels. Each slice is
|
||||
/// automatically colored from a curated palette and sized relative to
|
||||
/// its proportion of the total. Best suited for data with few categories
|
||||
/// (≤ 8) where all values are positive.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// PieChartView(
|
||||
/// dataTable: table,
|
||||
/// categoryColumn: "status",
|
||||
/// valueColumn: "count"
|
||||
/// )
|
||||
/// ```
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
public struct PieChartView: View {
|
||||
|
||||
/// The data to chart.
|
||||
public let dataTable: DataTable
|
||||
|
||||
/// Column name for category labels (slice labels).
|
||||
public let categoryColumn: String
|
||||
|
||||
/// Column name for numeric values (slice sizes).
|
||||
public let valueColumn: String
|
||||
|
||||
/// Optional chart title.
|
||||
public var title: String?
|
||||
|
||||
/// Inner radius ratio for donut style (0 = full pie, >0 = donut).
|
||||
public var innerRadiusRatio: CGFloat
|
||||
|
||||
/// Maximum number of slices before grouping remaining into "Other".
|
||||
public var maxSlices: Int
|
||||
|
||||
public init(
|
||||
dataTable: DataTable,
|
||||
categoryColumn: String,
|
||||
valueColumn: String,
|
||||
title: String? = nil,
|
||||
innerRadiusRatio: CGFloat = 0.4,
|
||||
maxSlices: Int = 8
|
||||
) {
|
||||
self.dataTable = dataTable
|
||||
self.categoryColumn = categoryColumn
|
||||
self.valueColumn = valueColumn
|
||||
self.title = title
|
||||
self.innerRadiusRatio = innerRadiusRatio
|
||||
self.maxSlices = maxSlices
|
||||
}
|
||||
|
||||
public var body: some View {
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
if let title {
|
||||
Text(title)
|
||||
.font(.caption.weight(.semibold))
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
|
||||
if chartData.isEmpty {
|
||||
emptyChartView
|
||||
} else {
|
||||
HStack(alignment: .center, spacing: 16) {
|
||||
chartContent
|
||||
legendView
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Chart Content
|
||||
|
||||
@ViewBuilder
|
||||
private var chartContent: some View {
|
||||
Chart(chartData, id: \.label) { item in
|
||||
SectorMark(
|
||||
angle: .value(valueColumn, item.value),
|
||||
innerRadius: .ratio(innerRadiusRatio),
|
||||
angularInset: 1.5
|
||||
)
|
||||
.foregroundStyle(by: .value(categoryColumn, item.label))
|
||||
.cornerRadius(3)
|
||||
}
|
||||
.chartForegroundStyleScale(
|
||||
domain: chartData.map(\.label),
|
||||
range: sliceColors
|
||||
)
|
||||
.chartLegend(.hidden)
|
||||
.frame(minWidth: 150, minHeight: 150)
|
||||
.aspectRatio(1, contentMode: .fit)
|
||||
}
|
||||
|
||||
// MARK: - Legend
|
||||
|
||||
@ViewBuilder
|
||||
private var legendView: some View {
|
||||
VStack(alignment: .leading, spacing: 6) {
|
||||
ForEach(Array(chartData.enumerated()), id: \.element.label) { index, item in
|
||||
HStack(spacing: 8) {
|
||||
Circle()
|
||||
.fill(sliceColors[index % sliceColors.count])
|
||||
.frame(width: 8, height: 8)
|
||||
|
||||
Text(item.label)
|
||||
.font(.caption)
|
||||
.foregroundStyle(.primary)
|
||||
.lineLimit(1)
|
||||
|
||||
Spacer()
|
||||
|
||||
Text(percentageText(for: item.value))
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
.monospacedDigit()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Empty State
|
||||
|
||||
@ViewBuilder
|
||||
private var emptyChartView: some View {
|
||||
VStack(spacing: 8) {
|
||||
Image(systemName: "chart.pie")
|
||||
.font(.title2)
|
||||
.foregroundStyle(.secondary)
|
||||
Text("No chartable data")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
.frame(maxWidth: .infinity, minHeight: 100)
|
||||
}
|
||||
|
||||
// MARK: - Colors
|
||||
|
||||
/// Curated color palette for pie slices.
|
||||
private var sliceColors: [Color] {
|
||||
[
|
||||
.blue,
|
||||
.green,
|
||||
.orange,
|
||||
.purple,
|
||||
.pink,
|
||||
.cyan,
|
||||
.yellow,
|
||||
.indigo,
|
||||
.mint,
|
||||
.teal,
|
||||
]
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
private var total: Double {
|
||||
chartData.reduce(0) { $0 + $1.value }
|
||||
}
|
||||
|
||||
private func percentageText(for value: Double) -> String {
|
||||
guard total > 0 else { return "0%" }
|
||||
let pct = (value / total) * 100
|
||||
if pct >= 10 {
|
||||
return String(format: "%.0f%%", pct)
|
||||
}
|
||||
return String(format: "%.1f%%", pct)
|
||||
}
|
||||
|
||||
// MARK: - Data Extraction
|
||||
|
||||
private var chartData: [ChartDataPoint] {
|
||||
let labels = dataTable.stringValues(forColumn: categoryColumn)
|
||||
let values = dataTable.numericValues(forColumn: valueColumn)
|
||||
|
||||
let count = min(labels.count, values.count)
|
||||
guard count > 0 else { return [] }
|
||||
|
||||
// Build all points, sorted by value descending
|
||||
var points = (0..<count).map { i in
|
||||
ChartDataPoint(label: labels[i], value: values[i])
|
||||
}
|
||||
.filter { $0.value > 0 }
|
||||
.sorted { $0.value > $1.value }
|
||||
|
||||
// Group excess slices into "Other"
|
||||
if points.count > maxSlices {
|
||||
let kept = Array(points.prefix(maxSlices - 1))
|
||||
let otherValue = points.dropFirst(maxSlices - 1).reduce(0) { $0 + $1.value }
|
||||
points = kept + [ChartDataPoint(label: "Other", value: otherValue)]
|
||||
}
|
||||
|
||||
return points
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Preview
|
||||
|
||||
#if DEBUG
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
#Preview("Pie Chart") {
|
||||
let columns: [DataTable.Column] = [
|
||||
.init(name: "status", index: 0, inferredType: .text),
|
||||
.init(name: "count", index: 1, inferredType: .integer),
|
||||
]
|
||||
let statuses = ["Active", "Inactive", "Pending", "Archived"]
|
||||
let counts: [Int64] = [45, 20, 15, 10]
|
||||
let rows: [DataTable.Row] = statuses.enumerated().map { i, status in
|
||||
DataTable.Row(
|
||||
id: i,
|
||||
values: [.text(status), .integer(counts[i])],
|
||||
columnNames: ["status", "count"]
|
||||
)
|
||||
}
|
||||
let table = DataTable(columns: columns, rows: rows)
|
||||
|
||||
PieChartView(
|
||||
dataTable: table,
|
||||
categoryColumn: "status",
|
||||
valueColumn: "count",
|
||||
title: "Users by Status"
|
||||
)
|
||||
.padding()
|
||||
.frame(height: 250)
|
||||
}
|
||||
#endif
|
||||
214
Sources/SwiftDBAI/Views/ChatView.swift
Normal file
214
Sources/SwiftDBAI/Views/ChatView.swift
Normal file
@@ -0,0 +1,214 @@
|
||||
// ChatView.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Drop-in SwiftUI view for chatting with a SQLite database.
|
||||
// Renders messages with automatic data table display for query results.
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/// A drop-in SwiftUI chat interface for querying SQLite databases
|
||||
/// with natural language.
|
||||
///
|
||||
/// `ChatView` renders the full conversation including:
|
||||
/// - User messages (right-aligned, accent-colored)
|
||||
/// - Assistant responses with text summaries
|
||||
/// - **Automatic data tables** via `ScrollableDataTableView` when query results
|
||||
/// contain tabular data (rows + columns)
|
||||
/// - SQL query disclosure for transparency
|
||||
/// - Error messages with red styling
|
||||
/// - A loading indicator while the engine is processing
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// let engine = ChatEngine(database: myPool, model: myModel)
|
||||
/// let viewModel = ChatViewModel(engine: engine)
|
||||
///
|
||||
/// ChatView(viewModel: viewModel)
|
||||
/// ```
|
||||
///
|
||||
/// Or use the convenience initializer:
|
||||
/// ```swift
|
||||
/// ChatView(engine: myEngine)
|
||||
/// ```
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
public struct ChatView: View {
|
||||
@Bindable private var viewModel: ChatViewModel
|
||||
@State private var inputText: String = ""
|
||||
@FocusState private var isInputFocused: Bool
|
||||
|
||||
/// Creates a ChatView with an existing view model.
|
||||
///
|
||||
/// - Parameter viewModel: The `ChatViewModel` driving this view.
|
||||
public init(viewModel: ChatViewModel) {
|
||||
self.viewModel = viewModel
|
||||
}
|
||||
|
||||
/// Creates a ChatView with a `ChatEngine`, automatically creating
|
||||
/// a `ChatViewModel`.
|
||||
///
|
||||
/// - Parameter engine: The `ChatEngine` to power the chat.
|
||||
public init(engine: ChatEngine) {
|
||||
self.viewModel = ChatViewModel(engine: engine)
|
||||
}
|
||||
|
||||
public var body: some View {
|
||||
VStack(spacing: 0) {
|
||||
messageList
|
||||
Divider()
|
||||
inputBar
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Message List
|
||||
|
||||
@ViewBuilder
|
||||
private var messageList: some View {
|
||||
ScrollViewReader { proxy in
|
||||
ScrollView {
|
||||
LazyVStack(spacing: 12) {
|
||||
if viewModel.messages.isEmpty {
|
||||
emptyState
|
||||
}
|
||||
|
||||
ForEach(viewModel.messages) { message in
|
||||
messageBubble(for: message)
|
||||
.id(message.id)
|
||||
}
|
||||
|
||||
if viewModel.isLoading {
|
||||
loadingIndicator
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 16)
|
||||
.padding(.vertical, 12)
|
||||
}
|
||||
.onChange(of: viewModel.messages.count) { _, _ in
|
||||
if let lastMessage = viewModel.messages.last {
|
||||
withAnimation(.easeOut(duration: 0.3)) {
|
||||
proxy.scrollTo(lastMessage.id, anchor: .bottom)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Empty State
|
||||
|
||||
@ViewBuilder
|
||||
private var emptyState: some View {
|
||||
VStack(spacing: 12) {
|
||||
Image(systemName: "bubble.left.and.text.bubble.right")
|
||||
.font(.system(size: 40))
|
||||
.foregroundStyle(.tertiary)
|
||||
Text("Ask a question about your data")
|
||||
.font(.headline)
|
||||
.foregroundStyle(.secondary)
|
||||
Text("Try something like \"How many records are in the database?\"")
|
||||
.font(.subheadline)
|
||||
.foregroundStyle(.tertiary)
|
||||
.multilineTextAlignment(.center)
|
||||
}
|
||||
.frame(maxWidth: .infinity)
|
||||
.padding(.vertical, 60)
|
||||
}
|
||||
|
||||
// MARK: - Loading Indicator
|
||||
|
||||
@ViewBuilder
|
||||
private var loadingIndicator: some View {
|
||||
HStack(alignment: .top) {
|
||||
HStack(spacing: 8) {
|
||||
ProgressView()
|
||||
.controlSize(.small)
|
||||
Text("Querying…")
|
||||
.font(.callout)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
.padding(.horizontal, 14)
|
||||
.padding(.vertical, 10)
|
||||
.background(
|
||||
Self.assistantBackgroundColor,
|
||||
in: RoundedRectangle(cornerRadius: 16, style: .continuous)
|
||||
)
|
||||
|
||||
Spacer(minLength: 48)
|
||||
}
|
||||
.id("loading-indicator")
|
||||
.transition(.opacity.combined(with: .move(edge: .bottom)))
|
||||
}
|
||||
|
||||
private static var assistantBackgroundColor: Color {
|
||||
#if os(macOS)
|
||||
Color(nsColor: .controlBackgroundColor)
|
||||
#else
|
||||
Color(uiColor: .secondarySystemGroupedBackground)
|
||||
#endif
|
||||
}
|
||||
|
||||
// MARK: - Input Bar
|
||||
|
||||
@ViewBuilder
|
||||
private var inputBar: some View {
|
||||
HStack(spacing: 8) {
|
||||
TextField("Ask about your data…", text: $inputText, axis: .vertical)
|
||||
.textFieldStyle(.plain)
|
||||
.lineLimit(1...5)
|
||||
.focused($isInputFocused)
|
||||
.onSubmit { sendMessage() }
|
||||
.submitLabel(.send)
|
||||
|
||||
Button(action: sendMessage) {
|
||||
Image(systemName: "arrow.up.circle.fill")
|
||||
.font(.title2)
|
||||
.foregroundStyle(canSend ? Color.accentColor : Color.secondary)
|
||||
}
|
||||
.disabled(!canSend)
|
||||
.keyboardShortcut(.return, modifiers: .command)
|
||||
}
|
||||
.padding(.horizontal, 16)
|
||||
.padding(.vertical, 10)
|
||||
}
|
||||
|
||||
// MARK: - Message Bubble
|
||||
|
||||
@ViewBuilder
|
||||
private func messageBubble(for message: ChatMessage) -> some View {
|
||||
if message.role == .error {
|
||||
MessageBubbleView(
|
||||
message: message,
|
||||
onRetry: makeRetryAction(for: message)
|
||||
)
|
||||
} else {
|
||||
MessageBubbleView(message: message)
|
||||
}
|
||||
}
|
||||
|
||||
private func makeRetryAction(for errorMessage: ChatMessage) -> @Sendable () async -> Void {
|
||||
let vm = viewModel
|
||||
let messageId = errorMessage.id
|
||||
return { @MainActor [vm] in
|
||||
let allMessages = await MainActor.run { vm.messages }
|
||||
if let lastUserMessage = allMessages
|
||||
.prefix(while: { $0.id != messageId })
|
||||
.last(where: { $0.role == .user }) {
|
||||
await vm.send(lastUserMessage.content)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Helpers
|
||||
|
||||
private var canSend: Bool {
|
||||
!inputText.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty && !viewModel.isLoading
|
||||
}
|
||||
|
||||
private func sendMessage() {
|
||||
guard canSend else { return }
|
||||
let text = inputText
|
||||
inputText = ""
|
||||
|
||||
Task {
|
||||
await viewModel.send(text)
|
||||
}
|
||||
}
|
||||
}
|
||||
137
Sources/SwiftDBAI/Views/ChatViewModel.swift
Normal file
137
Sources/SwiftDBAI/Views/ChatViewModel.swift
Normal file
@@ -0,0 +1,137 @@
|
||||
// ChatViewModel.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Observable view model that bridges ChatEngine with the SwiftUI ChatView.
|
||||
|
||||
import Foundation
|
||||
import Observation
|
||||
|
||||
/// The readiness state of the schema introspection.
|
||||
public enum SchemaReadiness: Sendable, Equatable {
|
||||
/// Schema has not been loaded yet.
|
||||
case idle
|
||||
/// Schema introspection is in progress.
|
||||
case loading
|
||||
/// Schema is ready with the given number of tables.
|
||||
case ready(tableCount: Int)
|
||||
/// Schema introspection failed.
|
||||
case failed(String)
|
||||
|
||||
public var isReady: Bool {
|
||||
if case .ready = self { return true }
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/// Observable view model that drives the `ChatView`.
|
||||
///
|
||||
/// Wraps `ChatEngine` to provide reactive state updates for the SwiftUI layer.
|
||||
/// Manages the message list, loading state, error presentation, and schema
|
||||
/// readiness. Call ``prepare()`` at view-appear time to eagerly introspect the
|
||||
/// database schema.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// let viewModel = ChatViewModel(engine: myChatEngine)
|
||||
/// ChatView(viewModel: viewModel)
|
||||
/// ```
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
@Observable
|
||||
@MainActor
|
||||
public final class ChatViewModel {
|
||||
|
||||
// MARK: - Public State
|
||||
|
||||
/// All messages in the conversation, in chronological order.
|
||||
public private(set) var messages: [ChatMessage] = []
|
||||
|
||||
/// Whether the engine is currently processing a request.
|
||||
public private(set) var isLoading: Bool = false
|
||||
|
||||
/// The most recent error message, if any. Cleared on next send.
|
||||
public private(set) var errorMessage: String?
|
||||
|
||||
/// Current schema readiness state.
|
||||
public private(set) var schemaReadiness: SchemaReadiness = .idle
|
||||
|
||||
// MARK: - Dependencies
|
||||
|
||||
private let engine: ChatEngine
|
||||
|
||||
// MARK: - Initialization
|
||||
|
||||
/// Creates a new ChatViewModel.
|
||||
///
|
||||
/// - Parameter engine: The `ChatEngine` to use for processing messages.
|
||||
public init(engine: ChatEngine) {
|
||||
self.engine = engine
|
||||
}
|
||||
|
||||
// MARK: - Schema Preparation
|
||||
|
||||
/// Eagerly introspects the database schema so it's ready before the first query.
|
||||
///
|
||||
/// This should be called from a `.task` modifier on the view. It transitions
|
||||
/// `schemaReadiness` through `.loading` → `.ready` (or `.failed`).
|
||||
/// If the schema is already cached, this completes immediately.
|
||||
public func prepare() async {
|
||||
// Don't re-prepare if already ready
|
||||
if schemaReadiness.isReady { return }
|
||||
|
||||
schemaReadiness = .loading
|
||||
|
||||
do {
|
||||
let schema = try await engine.prepareSchema()
|
||||
schemaReadiness = .ready(tableCount: schema.tableNames.count)
|
||||
} catch {
|
||||
schemaReadiness = .failed(error.localizedDescription)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Public API
|
||||
|
||||
/// Sends a user message and appends the response to the conversation.
|
||||
///
|
||||
/// - Parameter text: The natural language message from the user.
|
||||
public func send(_ text: String) async {
|
||||
let trimmed = text.trimmingCharacters(in: .whitespacesAndNewlines)
|
||||
guard !trimmed.isEmpty else { return }
|
||||
|
||||
errorMessage = nil
|
||||
|
||||
// Add user message immediately
|
||||
let userMessage = ChatMessage(role: .user, content: trimmed)
|
||||
messages.append(userMessage)
|
||||
|
||||
isLoading = true
|
||||
defer { isLoading = false }
|
||||
|
||||
do {
|
||||
let response = try await engine.send(trimmed)
|
||||
|
||||
let assistantMessage = ChatMessage(
|
||||
role: .assistant,
|
||||
content: response.summary,
|
||||
queryResult: response.queryResult,
|
||||
sql: response.sql
|
||||
)
|
||||
messages.append(assistantMessage)
|
||||
} catch {
|
||||
let typedError = (error as? SwiftDBAIError)
|
||||
let errorMsg = ChatMessage(
|
||||
role: .error,
|
||||
content: error.localizedDescription,
|
||||
error: typedError
|
||||
)
|
||||
messages.append(errorMsg)
|
||||
errorMessage = error.localizedDescription
|
||||
}
|
||||
}
|
||||
|
||||
/// Clears the conversation and resets the engine state.
|
||||
public func reset() {
|
||||
messages.removeAll()
|
||||
errorMessage = nil
|
||||
engine.reset()
|
||||
}
|
||||
}
|
||||
220
Sources/SwiftDBAI/Views/DataChatView.swift
Normal file
220
Sources/SwiftDBAI/Views/DataChatView.swift
Normal file
@@ -0,0 +1,220 @@
|
||||
// DataChatView.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Zero-config SwiftUI view: provide a database path and a model, get a chat UI.
|
||||
|
||||
import AnyLanguageModel
|
||||
import GRDB
|
||||
import SwiftUI
|
||||
|
||||
/// A convenience SwiftUI view that wraps the full chat-with-database stack.
|
||||
///
|
||||
/// `DataChatView` is the simplest entry point into SwiftDBAI. It requires only
|
||||
/// a database file path and a language model — no schema files, no annotations,
|
||||
/// no manual setup. The view creates a GRDB connection, a `ChatEngine`,
|
||||
/// a `ChatViewModel`, and renders a fully functional `ChatView`.
|
||||
///
|
||||
/// Usage with just a path and model:
|
||||
/// ```swift
|
||||
/// DataChatView(
|
||||
/// databasePath: "/path/to/mydata.sqlite",
|
||||
/// model: OllamaLanguageModel(model: "llama3")
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// Usage with additional configuration:
|
||||
/// ```swift
|
||||
/// DataChatView(
|
||||
/// databasePath: documentsURL.appendingPathComponent("app.db").path,
|
||||
/// model: OpenAILanguageModel(apiKey: key),
|
||||
/// allowlist: .standard,
|
||||
/// additionalContext: "This database stores a recipe app's data."
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// If you already have a GRDB `DatabasePool` or `DatabaseQueue`, use
|
||||
/// `ChatView` with a `ChatEngine` directly for full control.
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
public struct DataChatView: View {
|
||||
@State private var viewModel: ChatViewModel
|
||||
@State private var loadError: DataChatError?
|
||||
|
||||
/// Creates a DataChatView from a database file path and language model.
|
||||
///
|
||||
/// This is the zero-config convenience initializer. It opens a GRDB
|
||||
/// `DatabasePool` at the given path, creates a `ChatEngine` with
|
||||
/// read-only defaults, and wires up the full chat UI.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - databasePath: Absolute path to a SQLite database file.
|
||||
/// - model: Any `AnyLanguageModel`-compatible language model instance.
|
||||
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly` (SELECT only).
|
||||
/// - additionalContext: Optional extra context about the database for the LLM system prompt
|
||||
/// (e.g., "This database stores e-commerce orders and products.").
|
||||
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
|
||||
public init(
|
||||
databasePath: String,
|
||||
model: any LanguageModel,
|
||||
allowlist: OperationAllowlist = .readOnly,
|
||||
additionalContext: String? = nil,
|
||||
maxSummaryRows: Int = 50
|
||||
) {
|
||||
do {
|
||||
let pool = try DatabasePool(path: databasePath)
|
||||
let engine = ChatEngine(
|
||||
database: pool,
|
||||
model: model,
|
||||
allowlist: allowlist,
|
||||
additionalContext: additionalContext,
|
||||
maxSummaryRows: maxSummaryRows
|
||||
)
|
||||
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
|
||||
self._loadError = State(initialValue: nil)
|
||||
} catch {
|
||||
// If the database can't be opened, create a placeholder engine
|
||||
// and store the error to display in the UI.
|
||||
let queue = try! DatabaseQueue()
|
||||
let engine = ChatEngine(
|
||||
database: queue,
|
||||
model: model,
|
||||
allowlist: allowlist,
|
||||
additionalContext: additionalContext,
|
||||
maxSummaryRows: maxSummaryRows
|
||||
)
|
||||
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
|
||||
self._loadError = State(initialValue: DataChatError.databaseOpenFailed(
|
||||
path: databasePath,
|
||||
underlying: error
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a DataChatView from an existing GRDB database connection and language model.
|
||||
///
|
||||
/// Use this initializer when you already have a configured `DatabasePool` or
|
||||
/// `DatabaseQueue` and want the convenience of `DataChatView` without
|
||||
/// creating a `ChatEngine` yourself.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - database: A GRDB `DatabaseWriter` (`DatabasePool` or `DatabaseQueue`).
|
||||
/// - model: Any `AnyLanguageModel`-compatible language model instance.
|
||||
/// - allowlist: SQL operations the LLM may generate. Defaults to `.readOnly`.
|
||||
/// - additionalContext: Optional extra context about the database for the LLM.
|
||||
/// - maxSummaryRows: Maximum rows to include when summarizing results (default: 50).
|
||||
public init(
|
||||
database: any DatabaseWriter,
|
||||
model: any LanguageModel,
|
||||
allowlist: OperationAllowlist = .readOnly,
|
||||
additionalContext: String? = nil,
|
||||
maxSummaryRows: Int = 50
|
||||
) {
|
||||
let engine = ChatEngine(
|
||||
database: database,
|
||||
model: model,
|
||||
allowlist: allowlist,
|
||||
additionalContext: additionalContext,
|
||||
maxSummaryRows: maxSummaryRows
|
||||
)
|
||||
self._viewModel = State(initialValue: ChatViewModel(engine: engine))
|
||||
self._loadError = State(initialValue: nil)
|
||||
}
|
||||
|
||||
public var body: some View {
|
||||
if let error = loadError {
|
||||
errorView(error)
|
||||
} else {
|
||||
ChatView(viewModel: viewModel)
|
||||
.task {
|
||||
await viewModel.prepare()
|
||||
}
|
||||
.overlay {
|
||||
if case .loading = viewModel.schemaReadiness {
|
||||
schemaLoadingView
|
||||
}
|
||||
if case .failed(let reason) = viewModel.schemaReadiness {
|
||||
schemaErrorView(reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Schema Loading View
|
||||
|
||||
@ViewBuilder
|
||||
private var schemaLoadingView: some View {
|
||||
VStack(spacing: 16) {
|
||||
ProgressView()
|
||||
.controlSize(.large)
|
||||
Text("Introspecting database schema…")
|
||||
.font(.subheadline)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
.frame(maxWidth: .infinity, maxHeight: .infinity)
|
||||
.background(.ultraThinMaterial)
|
||||
}
|
||||
|
||||
// MARK: - Schema Error View
|
||||
|
||||
@ViewBuilder
|
||||
private func schemaErrorView(_ reason: String) -> some View {
|
||||
VStack(spacing: 16) {
|
||||
Image(systemName: "exclamationmark.triangle.fill")
|
||||
.font(.system(size: 40))
|
||||
.foregroundStyle(.orange)
|
||||
|
||||
Text("Schema Introspection Failed")
|
||||
.font(.headline)
|
||||
|
||||
Text(reason)
|
||||
.font(.subheadline)
|
||||
.foregroundStyle(.secondary)
|
||||
.multilineTextAlignment(.center)
|
||||
.padding(.horizontal, 32)
|
||||
|
||||
Button("Retry") {
|
||||
Task {
|
||||
await viewModel.prepare()
|
||||
}
|
||||
}
|
||||
.buttonStyle(.borderedProminent)
|
||||
}
|
||||
.frame(maxWidth: .infinity, maxHeight: .infinity)
|
||||
.background(.ultraThinMaterial)
|
||||
}
|
||||
|
||||
// MARK: - Database Open Error View
|
||||
|
||||
@ViewBuilder
|
||||
private func errorView(_ error: DataChatError) -> some View {
|
||||
VStack(spacing: 16) {
|
||||
Image(systemName: "exclamationmark.triangle.fill")
|
||||
.font(.system(size: 40))
|
||||
.foregroundStyle(.red)
|
||||
|
||||
Text("Unable to Open Database")
|
||||
.font(.headline)
|
||||
|
||||
Text(error.localizedDescription)
|
||||
.font(.subheadline)
|
||||
.foregroundStyle(.secondary)
|
||||
.multilineTextAlignment(.center)
|
||||
.padding(.horizontal, 32)
|
||||
}
|
||||
.frame(maxWidth: .infinity, maxHeight: .infinity)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Errors
|
||||
|
||||
/// Errors specific to `DataChatView` initialization.
|
||||
public enum DataChatError: Error, LocalizedError, Sendable {
|
||||
/// The database file could not be opened at the given path.
|
||||
case databaseOpenFailed(path: String, underlying: any Error)
|
||||
|
||||
public var errorDescription: String? {
|
||||
switch self {
|
||||
case .databaseOpenFailed(let path, let underlying):
|
||||
return "Could not open database at \"\(path)\": \(underlying.localizedDescription)"
|
||||
}
|
||||
}
|
||||
}
|
||||
360
Sources/SwiftDBAI/Views/ErrorMessageView.swift
Normal file
360
Sources/SwiftDBAI/Views/ErrorMessageView.swift
Normal file
@@ -0,0 +1,360 @@
|
||||
// ErrorMessageView.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Reusable SwiftUI component that renders error messages with contextual
|
||||
// icons, descriptions, and optional retry actions based on the error type.
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/// A reusable SwiftUI component that renders a ``SwiftDBAIError`` with an
|
||||
/// appropriate icon, human-readable message, and optional retry action.
|
||||
///
|
||||
/// The view automatically selects a visual treatment based on the error
|
||||
/// category:
|
||||
///
|
||||
/// | Category | Icon | Color | Retry? |
|
||||
/// |-------------------|-------------------------------|---------|--------|
|
||||
/// | Safety / blocked | `shield.trianglebadge.excl…` | Orange | No |
|
||||
/// | Confirmation | `hand.raised.fill` | Yellow | Yes* |
|
||||
/// | LLM failure | `brain` | Purple | Yes |
|
||||
/// | Schema / DB | `cylinder.split.1x2` | Red | No |
|
||||
/// | Recoverable SQL | `arrow.clockwise` | Blue | Yes |
|
||||
/// | Generic | `exclamationmark.triangle` | Red | No |
|
||||
///
|
||||
/// *Confirmation retry triggers the confirm callback, not a standard retry.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// ErrorMessageView(
|
||||
/// error: .llmTimeout(seconds: 30),
|
||||
/// onRetry: { /* resend the message */ }
|
||||
/// )
|
||||
/// ```
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
public struct ErrorMessageView: View {
|
||||
/// The error to display. When `nil`, the view falls back to the raw message.
|
||||
private let error: SwiftDBAIError?
|
||||
|
||||
/// The raw error message string (used as fallback when error is nil).
|
||||
private let message: String
|
||||
|
||||
/// Called when the user taps the retry button. `nil` hides the button.
|
||||
private let onRetry: (@Sendable () async -> Void)?
|
||||
|
||||
/// Called when the user confirms a destructive operation.
|
||||
private let onConfirm: (@Sendable () async -> Void)?
|
||||
|
||||
@State private var isRetrying = false
|
||||
|
||||
// MARK: - Initializers
|
||||
|
||||
/// Creates an ErrorMessageView from a typed ``SwiftDBAIError``.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - error: The ``SwiftDBAIError`` to display.
|
||||
/// - onRetry: An optional async closure invoked when the user taps retry.
|
||||
/// - onConfirm: An optional async closure invoked when the user confirms
|
||||
/// a destructive operation (only relevant for `.confirmationRequired`).
|
||||
public init(
|
||||
error: SwiftDBAIError,
|
||||
onRetry: (@Sendable () async -> Void)? = nil,
|
||||
onConfirm: (@Sendable () async -> Void)? = nil
|
||||
) {
|
||||
self.error = error
|
||||
self.message = error.localizedDescription
|
||||
self.onRetry = onRetry
|
||||
self.onConfirm = onConfirm
|
||||
}
|
||||
|
||||
/// Creates an ErrorMessageView from a ``ChatMessage``.
|
||||
///
|
||||
/// Extracts the typed error if available, otherwise falls back to the
|
||||
/// message content string.
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - message: The chat message with role `.error`.
|
||||
/// - onRetry: An optional async closure invoked when the user taps retry.
|
||||
/// - onConfirm: An optional async closure invoked when the user confirms
|
||||
/// a destructive operation.
|
||||
public init(
|
||||
chatMessage: ChatMessage,
|
||||
onRetry: (@Sendable () async -> Void)? = nil,
|
||||
onConfirm: (@Sendable () async -> Void)? = nil
|
||||
) {
|
||||
self.error = chatMessage.error
|
||||
self.message = chatMessage.content
|
||||
self.onRetry = onRetry
|
||||
self.onConfirm = onConfirm
|
||||
}
|
||||
|
||||
/// Creates an ErrorMessageView from a plain string (untyped fallback).
|
||||
///
|
||||
/// - Parameters:
|
||||
/// - message: The error message string.
|
||||
/// - onRetry: An optional async closure invoked when the user taps retry.
|
||||
public init(
|
||||
message: String,
|
||||
onRetry: (@Sendable () async -> Void)? = nil
|
||||
) {
|
||||
self.error = nil
|
||||
self.message = message
|
||||
self.onRetry = onRetry
|
||||
self.onConfirm = nil
|
||||
}
|
||||
|
||||
// MARK: - Body
|
||||
|
||||
public var body: some View {
|
||||
VStack(alignment: .leading, spacing: 10) {
|
||||
// Icon + message row
|
||||
HStack(alignment: .firstTextBaseline, spacing: 8) {
|
||||
Image(systemName: iconName)
|
||||
.foregroundStyle(iconColor)
|
||||
.font(.callout)
|
||||
.accessibilityHidden(true)
|
||||
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
if let title = errorTitle {
|
||||
Text(title)
|
||||
.font(.callout.weight(.semibold))
|
||||
.foregroundStyle(iconColor)
|
||||
}
|
||||
|
||||
Text(message)
|
||||
.font(.body)
|
||||
.foregroundStyle(.primary)
|
||||
.textSelection(.enabled)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
|
||||
if let hint = recoveryHint {
|
||||
Text(hint)
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Action buttons
|
||||
if showRetryButton || showConfirmButton {
|
||||
HStack(spacing: 12) {
|
||||
if showConfirmButton {
|
||||
confirmButton
|
||||
}
|
||||
if showRetryButton {
|
||||
retryButton
|
||||
}
|
||||
}
|
||||
.padding(.leading, 26) // Align with text (icon width + spacing)
|
||||
}
|
||||
}
|
||||
.accessibilityElement(children: .combine)
|
||||
.accessibilityLabel(accessibilityDescription)
|
||||
}
|
||||
|
||||
// MARK: - Action Buttons
|
||||
|
||||
@ViewBuilder
|
||||
private var retryButton: some View {
|
||||
Button {
|
||||
guard !isRetrying else { return }
|
||||
isRetrying = true
|
||||
Task {
|
||||
await onRetry?()
|
||||
isRetrying = false
|
||||
}
|
||||
} label: {
|
||||
HStack(spacing: 4) {
|
||||
if isRetrying {
|
||||
ProgressView()
|
||||
.controlSize(.mini)
|
||||
} else {
|
||||
Image(systemName: "arrow.clockwise")
|
||||
.font(.caption)
|
||||
}
|
||||
Text(retryButtonLabel)
|
||||
.font(.caption.weight(.medium))
|
||||
}
|
||||
.padding(.horizontal, 10)
|
||||
.padding(.vertical, 6)
|
||||
.background(iconColor.opacity(0.12))
|
||||
.foregroundStyle(iconColor)
|
||||
.clipShape(Capsule())
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
.disabled(isRetrying)
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private var confirmButton: some View {
|
||||
Button {
|
||||
Task {
|
||||
await onConfirm?()
|
||||
}
|
||||
} label: {
|
||||
HStack(spacing: 4) {
|
||||
Image(systemName: "checkmark.circle")
|
||||
.font(.caption)
|
||||
Text("Confirm")
|
||||
.font(.caption.weight(.medium))
|
||||
}
|
||||
.padding(.horizontal, 10)
|
||||
.padding(.vertical, 6)
|
||||
.background(Color.orange.opacity(0.12))
|
||||
.foregroundStyle(.orange)
|
||||
.clipShape(Capsule())
|
||||
}
|
||||
.buttonStyle(.plain)
|
||||
}
|
||||
|
||||
// MARK: - Error Classification
|
||||
|
||||
private var errorCategory: ErrorCategory {
|
||||
guard let error else { return .generic }
|
||||
|
||||
if error.requiresUserAction {
|
||||
return .confirmation
|
||||
}
|
||||
if error.isSafetyError {
|
||||
return .safety
|
||||
}
|
||||
if error.isRecoverable {
|
||||
return .recoverable
|
||||
}
|
||||
|
||||
switch error {
|
||||
case .llmFailure, .llmResponseUnparseable, .llmTimeout:
|
||||
return .llm
|
||||
case .schemaIntrospectionFailed, .emptySchema, .databaseError, .queryTimedOut:
|
||||
return .database
|
||||
case .configurationError:
|
||||
return .configuration
|
||||
default:
|
||||
return .generic
|
||||
}
|
||||
}
|
||||
|
||||
private enum ErrorCategory {
|
||||
case safety
|
||||
case confirmation
|
||||
case llm
|
||||
case database
|
||||
case recoverable
|
||||
case configuration
|
||||
case generic
|
||||
}
|
||||
|
||||
// MARK: - Visual Properties
|
||||
|
||||
private var iconName: String {
|
||||
switch errorCategory {
|
||||
case .safety:
|
||||
return "shield.trianglebadge.exclamationmark.fill"
|
||||
case .confirmation:
|
||||
return "hand.raised.fill"
|
||||
case .llm:
|
||||
return "brain"
|
||||
case .database:
|
||||
return "cylinder.split.1x2"
|
||||
case .recoverable:
|
||||
return "arrow.clockwise"
|
||||
case .configuration:
|
||||
return "gearshape.triangle.fill"
|
||||
case .generic:
|
||||
return "exclamationmark.triangle.fill"
|
||||
}
|
||||
}
|
||||
|
||||
private var iconColor: Color {
|
||||
switch errorCategory {
|
||||
case .safety:
|
||||
return .orange
|
||||
case .confirmation:
|
||||
return .yellow
|
||||
case .llm:
|
||||
return .purple
|
||||
case .database:
|
||||
return .red
|
||||
case .recoverable:
|
||||
return .blue
|
||||
case .configuration:
|
||||
return .gray
|
||||
case .generic:
|
||||
return .red
|
||||
}
|
||||
}
|
||||
|
||||
private var errorTitle: String? {
|
||||
switch errorCategory {
|
||||
case .safety:
|
||||
return "Operation Blocked"
|
||||
case .confirmation:
|
||||
return "Confirmation Required"
|
||||
case .llm:
|
||||
return "AI Provider Error"
|
||||
case .database:
|
||||
return "Database Error"
|
||||
case .recoverable:
|
||||
return "Query Issue"
|
||||
case .configuration:
|
||||
return "Configuration Error"
|
||||
case .generic:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
private var recoveryHint: String? {
|
||||
guard let error else { return nil }
|
||||
|
||||
switch error {
|
||||
case .noSQLGenerated, .llmResponseUnparseable:
|
||||
return "Try rephrasing your question."
|
||||
case .tableNotFound:
|
||||
return "Check that you're referring to an existing table."
|
||||
case .columnNotFound:
|
||||
return "Verify the column name matches your schema."
|
||||
case .invalidSQL:
|
||||
return "The AI generated an invalid query. Try asking differently."
|
||||
case .llmTimeout:
|
||||
return "The AI took too long. Try a simpler question."
|
||||
case .llmFailure:
|
||||
return "The AI service may be temporarily unavailable."
|
||||
case .emptySchema:
|
||||
return "Add some tables to your database first."
|
||||
case .queryTimedOut:
|
||||
return "Try a simpler query or add database indexes."
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Button Visibility
|
||||
|
||||
private var showRetryButton: Bool {
|
||||
guard onRetry != nil else { return false }
|
||||
return errorCategory == .recoverable || errorCategory == .llm
|
||||
}
|
||||
|
||||
private var showConfirmButton: Bool {
|
||||
guard onConfirm != nil else { return false }
|
||||
return errorCategory == .confirmation
|
||||
}
|
||||
|
||||
private var retryButtonLabel: String {
|
||||
switch errorCategory {
|
||||
case .llm:
|
||||
return "Retry"
|
||||
case .recoverable:
|
||||
return "Try Again"
|
||||
default:
|
||||
return "Retry"
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Accessibility
|
||||
|
||||
private var accessibilityDescription: String {
|
||||
let prefix = errorTitle.map { "\($0): " } ?? "Error: "
|
||||
return prefix + message
|
||||
}
|
||||
}
|
||||
205
Sources/SwiftDBAI/Views/MessageBubbleView.swift
Normal file
205
Sources/SwiftDBAI/Views/MessageBubbleView.swift
Normal file
@@ -0,0 +1,205 @@
|
||||
// MessageBubbleView.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// Renders a single ChatMessage as a styled bubble with optional
|
||||
// data table and SQL disclosure for query results.
|
||||
|
||||
import SwiftUI
|
||||
import Charts
|
||||
|
||||
/// Renders a single `ChatMessage` in the chat conversation.
|
||||
///
|
||||
/// - **User messages** display right-aligned with an accent-colored background
|
||||
/// and white text, using a continuous rounded rectangle shape.
|
||||
/// - **Assistant messages** display left-aligned with a secondary background.
|
||||
/// The natural language text summary is the primary content, rendered with
|
||||
/// full `.body` font and `.primary` foreground for readability.
|
||||
/// If the message contains a `queryResult` with tabular data, a
|
||||
/// `ScrollableDataTableView` is automatically embedded below the summary.
|
||||
/// An optional SQL disclosure group shows the generated query.
|
||||
/// - **Error messages** display left-aligned with a red-tinted background
|
||||
/// and an exclamation mark icon.
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
struct MessageBubbleView: View {
|
||||
let message: ChatMessage
|
||||
|
||||
/// Whether to show the SQL query in a disclosure group.
|
||||
var showSQL: Bool = true
|
||||
|
||||
/// Maximum height for the data table before it scrolls.
|
||||
var maxTableHeight: CGFloat = 300
|
||||
|
||||
/// Called when the user taps "Retry" on a recoverable error.
|
||||
var onRetry: (@Sendable () async -> Void)?
|
||||
|
||||
/// Called when the user confirms a destructive operation.
|
||||
var onConfirm: (@Sendable () async -> Void)?
|
||||
|
||||
var body: some View {
|
||||
HStack(alignment: .top) {
|
||||
if message.role == .user { Spacer(minLength: 48) }
|
||||
|
||||
bubbleContent
|
||||
.padding(.horizontal, 14)
|
||||
.padding(.vertical, 10)
|
||||
.background(bubbleBackground)
|
||||
.clipShape(bubbleShape)
|
||||
|
||||
if message.role != .user { Spacer(minLength: 48) }
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Bubble Content
|
||||
|
||||
@ViewBuilder
|
||||
private var bubbleContent: some View {
|
||||
switch message.role {
|
||||
case .user:
|
||||
userContent
|
||||
case .assistant:
|
||||
assistantContent
|
||||
case .error:
|
||||
errorContent
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - User Content
|
||||
|
||||
@ViewBuilder
|
||||
private var userContent: some View {
|
||||
Text(message.content)
|
||||
.font(.body)
|
||||
.foregroundStyle(.white)
|
||||
.textSelection(.enabled)
|
||||
}
|
||||
|
||||
// MARK: - Assistant Content (Text Summary + Data Table + SQL)
|
||||
|
||||
@ViewBuilder
|
||||
private var assistantContent: some View {
|
||||
VStack(alignment: .leading, spacing: 10) {
|
||||
// Natural language text summary — primary content
|
||||
Text(message.content)
|
||||
.font(.body)
|
||||
.foregroundStyle(.primary)
|
||||
.textSelection(.enabled)
|
||||
.fixedSize(horizontal: false, vertical: true)
|
||||
|
||||
// Data table — automatically shown when queryResult has tabular data
|
||||
if let queryResult = message.queryResult,
|
||||
!queryResult.columns.isEmpty,
|
||||
!queryResult.rows.isEmpty {
|
||||
dataTableSection(for: queryResult)
|
||||
}
|
||||
|
||||
// SQL disclosure — collapsed by default for transparency
|
||||
if showSQL, let sql = message.sql {
|
||||
sqlDisclosure(sql: sql)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Error Content
|
||||
|
||||
@ViewBuilder
|
||||
private var errorContent: some View {
|
||||
ErrorMessageView(
|
||||
chatMessage: message,
|
||||
onRetry: onRetry,
|
||||
onConfirm: onConfirm
|
||||
)
|
||||
}
|
||||
|
||||
/// Maximum height for the chart section.
|
||||
var maxChartHeight: CGFloat = 250
|
||||
|
||||
/// Whether to show auto-detected charts. Defaults to `true`.
|
||||
var showCharts: Bool = true
|
||||
|
||||
// MARK: - Chart Detection
|
||||
|
||||
/// The shared detector used for chart eligibility checks.
|
||||
private static let chartDetector = ChartDataDetector()
|
||||
|
||||
// MARK: - Data Table Section
|
||||
|
||||
@ViewBuilder
|
||||
private func dataTableSection(for queryResult: QueryResult) -> some View {
|
||||
let dataTable = DataTable(queryResult)
|
||||
|
||||
VStack(alignment: .leading, spacing: 8) {
|
||||
// Chart — automatically shown when ChartDataDetector finds eligible data
|
||||
if showCharts {
|
||||
chartSection(for: dataTable)
|
||||
}
|
||||
|
||||
Divider()
|
||||
|
||||
ScrollableDataTableView(
|
||||
dataTable: dataTable,
|
||||
showAlternatingRows: true,
|
||||
showFooter: true
|
||||
)
|
||||
.frame(maxHeight: maxTableHeight)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Chart Section
|
||||
|
||||
@ViewBuilder
|
||||
private func chartSection(for dataTable: DataTable) -> some View {
|
||||
let detector = Self.chartDetector
|
||||
if detector.detect(dataTable) != nil {
|
||||
VStack(alignment: .leading, spacing: 4) {
|
||||
ChartResultView(dataTable: dataTable, detector: detector)
|
||||
.frame(maxHeight: maxChartHeight)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - SQL Disclosure
|
||||
|
||||
@ViewBuilder
|
||||
private func sqlDisclosure(sql: String) -> some View {
|
||||
DisclosureGroup {
|
||||
Text(sql)
|
||||
.font(.system(.caption, design: .monospaced))
|
||||
.foregroundStyle(.secondary)
|
||||
.textSelection(.enabled)
|
||||
.padding(8)
|
||||
.frame(maxWidth: .infinity, alignment: .leading)
|
||||
.background(Color.primary.opacity(0.04))
|
||||
.clipShape(RoundedRectangle(cornerRadius: 6))
|
||||
} label: {
|
||||
Label("SQL Query", systemImage: "chevron.left.forwardslash.chevron.right")
|
||||
.font(.caption)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Styling Helpers
|
||||
|
||||
private var bubbleShape: RoundedRectangle {
|
||||
RoundedRectangle(cornerRadius: 16, style: .continuous)
|
||||
}
|
||||
|
||||
@ViewBuilder
|
||||
private var bubbleBackground: some View {
|
||||
switch message.role {
|
||||
case .user:
|
||||
Color.accentColor
|
||||
case .assistant:
|
||||
Self.assistantBackgroundColor
|
||||
case .error:
|
||||
Color.red.opacity(0.1)
|
||||
}
|
||||
}
|
||||
|
||||
private static var assistantBackgroundColor: Color {
|
||||
#if os(macOS)
|
||||
Color(nsColor: .controlBackgroundColor)
|
||||
#else
|
||||
Color(uiColor: .secondarySystemGroupedBackground)
|
||||
#endif
|
||||
}
|
||||
}
|
||||
267
Sources/SwiftDBAI/Views/ScrollableDataTableView.swift
Normal file
267
Sources/SwiftDBAI/Views/ScrollableDataTableView.swift
Normal file
@@ -0,0 +1,267 @@
|
||||
// ScrollableDataTableView.swift
|
||||
// SwiftDBAI
|
||||
//
|
||||
// A SwiftUI view that renders a DataTable with horizontal and vertical
|
||||
// scrolling, styled column headers, and row cells.
|
||||
|
||||
import SwiftUI
|
||||
|
||||
/// A scrollable table view that renders a `DataTable` with column headers
|
||||
/// and row cells, supporting both horizontal and vertical scrolling.
|
||||
///
|
||||
/// Usage:
|
||||
/// ```swift
|
||||
/// ScrollableDataTableView(dataTable: myDataTable)
|
||||
/// ```
|
||||
///
|
||||
/// The view automatically sizes columns based on content, highlights
|
||||
/// alternating rows for readability, and right-aligns numeric columns.
|
||||
public struct ScrollableDataTableView: View {
|
||||
/// The data table to render.
|
||||
public let dataTable: DataTable
|
||||
|
||||
/// Minimum width for each column in points.
|
||||
public var minimumColumnWidth: CGFloat
|
||||
|
||||
/// Maximum width for each column in points.
|
||||
public var maximumColumnWidth: CGFloat
|
||||
|
||||
/// Whether to show alternating row backgrounds.
|
||||
public var showAlternatingRows: Bool
|
||||
|
||||
/// Whether to show the row count footer.
|
||||
public var showFooter: Bool
|
||||
|
||||
public init(
|
||||
dataTable: DataTable,
|
||||
minimumColumnWidth: CGFloat = 80,
|
||||
maximumColumnWidth: CGFloat = 250,
|
||||
showAlternatingRows: Bool = true,
|
||||
showFooter: Bool = true
|
||||
) {
|
||||
self.dataTable = dataTable
|
||||
self.minimumColumnWidth = minimumColumnWidth
|
||||
self.maximumColumnWidth = maximumColumnWidth
|
||||
self.showAlternatingRows = showAlternatingRows
|
||||
self.showFooter = showFooter
|
||||
}
|
||||
|
||||
public var body: some View {
|
||||
if dataTable.isEmpty {
|
||||
emptyView
|
||||
} else {
|
||||
tableContent
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Empty State
|
||||
|
||||
@ViewBuilder
|
||||
private var emptyView: some View {
|
||||
VStack(spacing: 8) {
|
||||
Image(systemName: "tablecells")
|
||||
.font(.largeTitle)
|
||||
.foregroundStyle(.secondary)
|
||||
Text("No results")
|
||||
.font(.headline)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
.frame(maxWidth: .infinity, minHeight: 100)
|
||||
}
|
||||
|
||||
// MARK: - Table Content
|
||||
|
||||
@ViewBuilder
|
||||
private var tableContent: some View {
|
||||
VStack(alignment: .leading, spacing: 0) {
|
||||
ScrollView([.horizontal, .vertical]) {
|
||||
LazyVStack(alignment: .leading, spacing: 0, pinnedViews: [.sectionHeaders]) {
|
||||
Section {
|
||||
ForEach(dataTable.rows) { row in
|
||||
rowView(row)
|
||||
}
|
||||
} header: {
|
||||
headerRow
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if showFooter {
|
||||
footerView
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Header
|
||||
|
||||
@ViewBuilder
|
||||
private var headerRow: some View {
|
||||
HStack(spacing: 0) {
|
||||
ForEach(dataTable.columns) { column in
|
||||
Text(column.name)
|
||||
.font(.caption.weight(.semibold))
|
||||
.foregroundStyle(.primary)
|
||||
.lineLimit(1)
|
||||
.frame(
|
||||
width: columnWidth(for: column),
|
||||
alignment: alignment(for: column)
|
||||
)
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 6)
|
||||
|
||||
if column.index < dataTable.columnCount - 1 {
|
||||
Divider()
|
||||
}
|
||||
}
|
||||
}
|
||||
.background(.bar)
|
||||
.overlay(alignment: .bottom) {
|
||||
Divider()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Row
|
||||
|
||||
@ViewBuilder
|
||||
private func rowView(_ row: DataTable.Row) -> some View {
|
||||
HStack(spacing: 0) {
|
||||
ForEach(dataTable.columns) { column in
|
||||
cellView(value: row[column.index], column: column)
|
||||
|
||||
if column.index < dataTable.columnCount - 1 {
|
||||
Divider()
|
||||
}
|
||||
}
|
||||
}
|
||||
.background(rowBackground(for: row))
|
||||
.overlay(alignment: .bottom) {
|
||||
Divider()
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Cell
|
||||
|
||||
@ViewBuilder
|
||||
private func cellView(value: QueryResult.Value, column: DataTable.Column) -> some View {
|
||||
Group {
|
||||
switch value {
|
||||
case .null:
|
||||
Text("NULL")
|
||||
.foregroundStyle(.tertiary)
|
||||
.italic()
|
||||
case .blob(let data):
|
||||
Text("<\(data.count) bytes>")
|
||||
.foregroundStyle(.secondary)
|
||||
default:
|
||||
Text(value.stringValue)
|
||||
.foregroundStyle(.primary)
|
||||
}
|
||||
}
|
||||
.font(.caption)
|
||||
.lineLimit(2)
|
||||
.frame(
|
||||
width: columnWidth(for: column),
|
||||
alignment: alignment(for: column)
|
||||
)
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 4)
|
||||
}
|
||||
|
||||
// MARK: - Footer
|
||||
|
||||
@ViewBuilder
|
||||
private var footerView: some View {
|
||||
HStack {
|
||||
Text("\(dataTable.rowCount) row\(dataTable.rowCount == 1 ? "" : "s")")
|
||||
.font(.caption2)
|
||||
.foregroundStyle(.secondary)
|
||||
Spacer()
|
||||
if dataTable.executionTime > 0 {
|
||||
Text(String(format: "%.1f ms", dataTable.executionTime * 1000))
|
||||
.font(.caption2)
|
||||
.foregroundStyle(.secondary)
|
||||
}
|
||||
}
|
||||
.padding(.horizontal, 8)
|
||||
.padding(.vertical, 4)
|
||||
.background(.bar)
|
||||
}
|
||||
|
||||
// MARK: - Layout Helpers
|
||||
|
||||
/// Determines column width based on the column name length and type.
|
||||
private func columnWidth(for column: DataTable.Column) -> CGFloat {
|
||||
// Estimate based on header text length
|
||||
let headerWidth = CGFloat(column.name.count) * 8 + 16
|
||||
|
||||
// Sample some row values to estimate content width
|
||||
let sampleRows = dataTable.rows.prefix(20)
|
||||
let maxContentWidth = sampleRows.reduce(CGFloat(0)) { maxWidth, row in
|
||||
let value = row[column.index]
|
||||
let textLength = CGFloat(value.stringValue.count) * 7
|
||||
return max(maxWidth, textLength)
|
||||
}
|
||||
|
||||
let estimatedWidth = max(headerWidth, maxContentWidth) + 16
|
||||
return min(max(estimatedWidth, minimumColumnWidth), maximumColumnWidth)
|
||||
}
|
||||
|
||||
/// Returns the alignment for a column based on its inferred type.
|
||||
private func alignment(for column: DataTable.Column) -> Alignment {
|
||||
switch column.inferredType {
|
||||
case .integer, .real:
|
||||
return .trailing
|
||||
default:
|
||||
return .leading
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the background color for alternating rows.
|
||||
@ViewBuilder
|
||||
private func rowBackground(for row: DataTable.Row) -> some View {
|
||||
if showAlternatingRows && row.id.isMultiple(of: 2) {
|
||||
Color.clear
|
||||
} else if showAlternatingRows {
|
||||
Color.primary.opacity(0.03)
|
||||
} else {
|
||||
Color.clear
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MARK: - Preview Support
|
||||
|
||||
#if DEBUG
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
#Preview("Data Table") {
|
||||
let columns: [DataTable.Column] = [
|
||||
.init(name: "id", index: 0, inferredType: .integer),
|
||||
.init(name: "name", index: 1, inferredType: .text),
|
||||
.init(name: "score", index: 2, inferredType: .real),
|
||||
]
|
||||
let rows: [DataTable.Row] = (0..<25).map { i in
|
||||
DataTable.Row(
|
||||
id: i,
|
||||
values: [
|
||||
.integer(Int64(i + 1)),
|
||||
.text("Item \(i + 1)"),
|
||||
.real(Double.random(in: 1.0...100.0)),
|
||||
],
|
||||
columnNames: ["id", "name", "score"]
|
||||
)
|
||||
}
|
||||
let table = DataTable(columns: columns, rows: rows, sql: "SELECT * FROM items", executionTime: 0.023)
|
||||
|
||||
ScrollableDataTableView(dataTable: table)
|
||||
.frame(height: 400)
|
||||
.padding()
|
||||
}
|
||||
|
||||
@available(iOS 17.0, macOS 14.0, visionOS 1.0, *)
|
||||
#Preview("Empty Table") {
|
||||
let table = DataTable(columns: [], rows: [], sql: "", executionTime: 0)
|
||||
ScrollableDataTableView(dataTable: table)
|
||||
.frame(height: 200)
|
||||
.padding()
|
||||
}
|
||||
#endif
|
||||
Reference in New Issue
Block a user