Files
SwiftDBAI/Tests/SwiftDBAITests/DatabaseToolTests.swift
Krishna Kumar fcd752466a SwiftDBAI: natural language queries for any SQLite database
Drop-in SwiftUI chat view, headless ChatEngine, LLM-agnostic via
AnyLanguageModel. Read-only by default with configurable allowlists.
Robust SQL parser with 63 tests. Includes demo app with GitHub stars dataset.
2026-04-05 17:11:12 -05:00

318 lines
11 KiB
Swift

// DatabaseToolTests.swift
// SwiftDBAI
import Testing
import Foundation
import GRDB
@testable import SwiftDBAI
/// Integration tests for `DatabaseTool`, run against a small SQLite database
/// that `makeTestDatabase()` builds fresh for every test.
///
/// Tests that subscript `rows` first gate on `#require`, so an unexpected row
/// count is reported as a test failure instead of trapping the test process
/// with an out-of-range index.
@Suite("DatabaseTool")
struct DatabaseToolTests {

    // MARK: - Helper

    /// Creates an in-memory database with sample data for testing.
    ///
    /// Schema: `users(id, name, email)` and `posts(id, user_id, title, body,
    /// created_at)` with `posts.user_id REFERENCES users(id)`, plus an index on
    /// `posts.user_id`. Foreign keys are enabled on the connection.
    /// Data: Alice (id 1) has two posts, Bob (id 2) has one.
    private func makeTestDatabase() throws -> DatabaseQueue {
        var configuration = Configuration()
        configuration.foreignKeysEnabled = true
        // Named `dbQueue` so it is not shadowed by the `db` handle in `write`.
        let dbQueue = try DatabaseQueue(configuration: configuration)
        try dbQueue.write { db in
            try db.execute(sql: """
                CREATE TABLE users (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL,
                    email TEXT UNIQUE
                );
                """)
            try db.execute(sql: """
                CREATE TABLE posts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    user_id INTEGER NOT NULL REFERENCES users(id),
                    title TEXT NOT NULL,
                    body TEXT,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                );
                """)
            try db.execute(sql: """
                CREATE INDEX idx_posts_user ON posts(user_id);
                """)
            // Insert sample data
            try db.execute(sql: "INSERT INTO users (name, email) VALUES ('Alice', 'alice@example.com')")
            try db.execute(sql: "INSERT INTO users (name, email) VALUES ('Bob', 'bob@example.com')")
            try db.execute(sql: "INSERT INTO posts (user_id, title, body) VALUES (1, 'Hello World', 'First post')")
            try db.execute(sql: "INSERT INTO posts (user_id, title, body) VALUES (1, 'Second Post', 'More content')")
            try db.execute(sql: "INSERT INTO posts (user_id, title, body) VALUES (2, 'Bob Post', 'Bob writes')")
        }
        return dbQueue
    }

    // MARK: - Creation

    @Test("Creates tool from database connection")
    func testCreationFromDatabase() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        #expect(tool.name == "execute_sql")
        #expect(!tool.description.isEmpty)
    }

    @Test("Creates tool from database path")
    func testCreationFromPath() async throws {
        let dbPath = FileManager.default.temporaryDirectory
            .appendingPathComponent("test_\(UUID().uuidString).sqlite")
            .path
        // Best-effort cleanup: a leftover temp file must not fail the test.
        defer { try? FileManager.default.removeItem(atPath: dbPath) }

        // Create a database at the path
        let dbQueue = try DatabaseQueue(path: dbPath)
        try await dbQueue.write { db in
            try db.execute(sql: "CREATE TABLE test (id INTEGER PRIMARY KEY)")
        }

        let tool = try await DatabaseTool(databasePath: dbPath)
        #expect(tool.name == "execute_sql")
        // The tool should have picked up the on-disk schema, not an empty one.
        #expect(tool.schemaContext.contains("test"))
    }

    // MARK: - Schema Context

    @Test("Schema context contains table info")
    func testSchemaContext() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let context = tool.schemaContext
        #expect(context.contains("users"))
        #expect(context.contains("posts"))
        #expect(context.contains("name"))
        #expect(context.contains("email"))
    }

    @Test("System prompt snippet contains schema")
    func testSystemPromptSnippet() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let snippet = tool.systemPromptSnippet
        #expect(snippet.contains("users"))
        #expect(snippet.contains("posts"))
        #expect(snippet.contains("execute_sql"))
        #expect(snippet.contains("SELECT"))
    }

    // MARK: - SQL Execution

    @Test("Executes valid SELECT query")
    func testExecuteSelect() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let result = try tool.execute(sql: "SELECT name, email FROM users ORDER BY name")
        #expect(result.rowCount == 2)
        #expect(result.columns == ["name", "email"])
        // Stop here, rather than crash below, if the row count is wrong.
        try #require(result.rows.count == 2)
        #expect(result.rows[0]["name"] == "Alice")
        #expect(result.rows[1]["name"] == "Bob")
    }

    @Test("Executes query with JOIN")
    func testExecuteJoin() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let result = try tool.execute(sql: """
            SELECT u.name, COUNT(p.id) as post_count
            FROM users u
            JOIN posts p ON p.user_id = u.id
            GROUP BY u.name
            ORDER BY u.name
            """)
        #expect(result.rowCount == 2)
        try #require(result.rows.count == 2)
        #expect(result.rows[0]["name"] == "Alice")
        #expect(result.rows[0]["post_count"] == "2")
        #expect(result.rows[1]["name"] == "Bob")
        #expect(result.rows[1]["post_count"] == "1")
    }

    @Test("Rejects INSERT with read-only allowlist")
    func testRejectInsert() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db, allowlist: .readOnly)
        #expect(throws: SQLParsingError.self) {
            try tool.execute(sql: "INSERT INTO users (name, email) VALUES ('Eve', 'eve@example.com')")
        }
        // The rejection must happen before execution: no row was added.
        let check = try tool.execute(sql: "SELECT COUNT(*) AS cnt FROM users")
        #expect(check.rows.first?["cnt"] == "2")
    }

    @Test("Rejects DELETE with read-only allowlist")
    func testRejectDelete() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db, allowlist: .readOnly)
        #expect(throws: SQLParsingError.self) {
            try tool.execute(sql: "DELETE FROM users WHERE id = 1")
        }
        // The rejection must happen before execution: no row was removed.
        let check = try tool.execute(sql: "SELECT COUNT(*) AS cnt FROM users")
        #expect(check.rows.first?["cnt"] == "2")
    }

    @Test("Rejects DROP as dangerous operation")
    func testRejectDrop() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db, allowlist: .unrestricted)
        #expect(throws: SQLParsingError.self) {
            try tool.execute(sql: "DROP TABLE users")
        }
        // Even with an unrestricted allowlist, the table must still exist.
        let check = try tool.execute(sql: "SELECT COUNT(*) AS cnt FROM users")
        #expect(check.rows.first?["cnt"] == "2")
    }

    @Test("Executes raw query returning QueryResult")
    func testExecuteRaw() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let result = try tool.executeRaw(sql: "SELECT COUNT(*) as cnt FROM users")
        #expect(result.rowCount == 1)
        #expect(result.columns == ["cnt"])
        let firstRow = try #require(result.rows.first)
        guard case .integer(let count) = firstRow["cnt"] else {
            Issue.record("Expected integer value")
            return
        }
        #expect(count == 2)
    }

    // MARK: - ToolResult Formatting

    @Test("ToolResult JSON serialization")
    func testToolResultJSON() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let result = try tool.execute(sql: "SELECT name FROM users ORDER BY name LIMIT 1")
        let json = result.jsonString
        #expect(json.contains("\"columns\""))
        #expect(json.contains("\"rows\""))
        #expect(json.contains("\"row_count\""))
        #expect(json.contains("Alice"))

        // Verify it is valid JSON; a malformed payload fails the test instead
        // of crashing on a force-unwrap or force-cast.
        let data = Data(json.utf8)
        let object = try JSONSerialization.jsonObject(with: data)
        let parsed = try #require(object as? [String: Any])
        #expect(parsed["row_count"] as? Int == 1)
    }

    @Test("ToolResult markdown table formatting")
    func testToolResultMarkdown() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let result = try tool.execute(sql: "SELECT name, email FROM users ORDER BY name")
        let md = result.markdownTable
        #expect(md.contains("| name | email |"))
        #expect(md.contains("| --- | --- |"))
        #expect(md.contains("| Alice | alice@example.com |"))
        #expect(md.contains("| Bob | bob@example.com |"))
    }

    @Test("ToolResult markdown table with empty result")
    func testToolResultMarkdownEmpty() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let result = try tool.execute(sql: "SELECT name FROM users WHERE name = 'Nobody'")
        #expect(result.markdownTable == "_No results._")
    }

    @Test("ToolResult text summary")
    func testToolResultTextSummary() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let result = try tool.execute(sql: "SELECT name FROM users")
        let summary = result.textSummary
        #expect(summary.contains("2 rows"))
        #expect(summary.contains("name"))
    }

    @Test("ToolResult text summary with empty result")
    func testToolResultTextSummaryEmpty() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let result = try tool.execute(sql: "SELECT name FROM users WHERE 1 = 0")
        #expect(result.textSummary.contains("no results"))
    }

    // MARK: - Parameters Schema

    @Test("Parameters schema has correct structure")
    func testParametersSchema() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let schema = tool.parametersSchema
        #expect(schema["type"] as? String == "object")
        // Unwrap once so a missing key fails here, at the point of cause.
        let properties = try #require(schema["properties"] as? [String: Any])
        let sqlProp = try #require(properties["sql"] as? [String: Any])
        #expect(sqlProp["type"] as? String == "string")
        let required = schema["required"] as? [String]
        #expect(required == ["sql"])
    }

    // MARK: - OpenAI Function Definition

    @Test("OpenAI function definition has correct format")
    func testOpenAIFunctionDefinition() async throws {
        let db = try makeTestDatabase()
        let tool = try await DatabaseTool(database: db)
        let def = tool.openAIFunctionDefinition
        #expect(def["type"] as? String == "function")
        let function = try #require(def["function"] as? [String: Any])
        #expect(function["name"] as? String == "execute_sql")
        #expect(function["description"] as? String != nil)
        #expect(function["parameters"] as? [String: Any] != nil)

        // Verify it can be serialized to JSON
        let data = try JSONSerialization.data(withJSONObject: def)
        #expect(!data.isEmpty)
    }

    // MARK: - ToolResult Codable

    @Test("ToolResult is Codable")
    func testToolResultCodable() throws {
        let result = ToolResult(
            columns: ["name", "age"],
            rows: [["name": "Alice", "age": "30"]],
            rowCount: 1,
            executionTime: 0.005,
            sql: "SELECT name, age FROM users"
        )
        let data = try JSONEncoder().encode(result)
        let decoded = try JSONDecoder().decode(ToolResult.self, from: data)
        #expect(decoded.columns == result.columns)
        #expect(decoded.rows == result.rows)
        #expect(decoded.rowCount == result.rowCount)
        #expect(decoded.sql == result.sql)
    }
}