diff --git a/Cargo.lock b/Cargo.lock index f8c6ac5..91882a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3894,6 +3894,7 @@ dependencies = [ name = "tauri-plugin-sqlite" version = "0.1.0" dependencies = [ + "base64 0.22.1", "futures-core", "indexmap 2.12.1", "log", diff --git a/Cargo.toml b/Cargo.toml index 7fd6279..c0c826c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ futures-core = "0.3.31" time = "0.3.44" tokio = { version = "1.48.0", features = ["sync"] } indexmap = { version = "2.12.1", features = ["serde"] } +base64 = "0.22.1" # SQLx for types and queries (time feature enables datetime type decoding) sqlx = { version = "0.8.6", features = ["sqlite", "json", "time", "runtime-tokio"] } diff --git a/README.md b/README.md index 2d9596b..582e0fc 100644 --- a/README.md +++ b/README.md @@ -185,12 +185,18 @@ type SqlValue = string | number | boolean | null | Uint8Array Supported SQLite types: * **TEXT** - `string` values (also used for DATE, TIME, DATETIME) - * **INTEGER** - `number` values (integers) + * **INTEGER** - `number` values (integers, preserved up to i64 range) * **REAL** - `number` values (floating point) * **BOOLEAN** - `boolean` values * **NULL** - `null` value * **BLOB** - `Uint8Array` for binary data +> **Note:** JavaScript's `number` type can safely represent integers up to +> ±2^53 - 1 (±9,007,199,254,740,991). The plugin preserves integer precision by +> binding integers as SQLite's INTEGER type (i64). For values within the i64 +> range (-9,223,372,036,854,775,808 to 9,223,372,036,854,775,807), full precision +> is maintained. Values outside this range may lose precision. + ```typescript // Example with different types await db.execute( @@ -265,7 +271,6 @@ Common error codes include: * `INVALID_PATH` - Invalid database path * `IO_ERROR` - File system error * `MIGRATION_ERROR` - Migration failed - * `READ_ONLY_QUERY_IN_EXECUTE` - Attempted to use execute() for a read-only query * `MULTIPLE_ROWS_RETURNED` - `fetchOne()` query returned multiple rows ### Executing SELECT Queries @@ -304,45 +309,35 @@ if (user) { ### Using Transactions -Transactions ensure that multiple operations either all succeed or all fail together, -maintaining data consistency: +Execute multiple database operations atomically using `executeTransaction()`. All +statements either succeed together or fail together, maintaining data consistency: ```typescript -// Begin a transaction -await db.beginTransaction(); - -try { - // Execute multiple operations atomically - await db.execute( - 'INSERT INTO users (name, email) VALUES ($1, $2)', - ['Alice', 'alice@example.com'] - ); - - await db.execute( - 'INSERT INTO audit_log (action, user) VALUES ($1, $2)', - ['user_created', 'Alice'] - ); - - // Commit if all operations succeed - await db.commitTransaction(); - console.log('Transaction completed successfully'); - -} catch (error) { - // Rollback if any operation fails - await db.rollbackTransaction(); - console.error('Transaction failed, rolled back:', error); - throw error; -} -``` - -**Important Notes:** - - * All operations between `beginTransaction()` and - `commitTransaction()`/`rollbackTransaction()` are executed as a single atomic unit - * If an error occurs, call `rollbackTransaction()` to discard all changes - * Nested transactions are not supported - * Always ensure transactions are either committed or rolled back to avoid locking - issues +// Execute multiple inserts atomically +const results = await db.executeTransaction([ + ['INSERT INTO users (name, email) VALUES ($1, $2)', ['Alice', 'alice@example.com']], + ['INSERT INTO audit_log (action, user) VALUES ($1, $2)', ['user_created', 'Alice']] +]); +console.log(`User ID: ${results[0].lastInsertId}`); +console.log(`Log rows affected: ${results[1].rowsAffected}`); + +// Bank transfer example - all operations succeed or all fail +const results = await db.executeTransaction([ + ['UPDATE accounts SET balance = balance - $1 WHERE id = $2', [100, 1]], + ['UPDATE accounts SET balance = balance + $1 WHERE id = $2', [100, 2]], + ['INSERT INTO transfers (from_id, to_id, amount) VALUES ($1, $2, $3)', [1, 2, 100]] +]); +console.log(`Transfer ID: ${results[2].lastInsertId}`); +``` + +**How it works:** + + * Automatically executes `BEGIN` before running statements + * Executes all statements in order + * Commits with `COMMIT` if all statements succeed + * Rolls back with `ROLLBACK` if any statement fails + * The write connection is held for the entire transaction, ensuring atomicity + * Errors are thrown after rollback, preserving the original error message ### Closing Connections @@ -465,8 +460,8 @@ const filtered = await db.fetchAll( ) ``` -> **Important:** Do NOT use `execute()` for read-only queries. It will return -> an error. Always use `fetchAll()` or `fetchOne()` for reads. +> **Note:** Use `execute()` and `executeTransaction()` for write operations. +> For SELECT queries, use `fetchAll()` or `fetchOne()`. ## Configuration @@ -541,7 +536,7 @@ await Database.closeAll() #### Instance Methods -##### `execute(query: string, bindValues?: unknown[]): Promise` +##### `execute(query: string, bindValues?: unknown[]): Promise` Execute a write query (INSERT, UPDATE, DELETE, CREATE, etc.). @@ -602,9 +597,9 @@ await db.remove() ### TypeScript Interfaces ```typescript -interface QueryResult { +interface WriteQueryResult { rowsAffected: number // Number of rows modified - lastInsertId: number // ROWID of last inserted row + lastInsertId: number // ROWID of last inserted row (not set for WITHOUT ROWID tables, returns 0) } interface CustomConfig { diff --git a/api-iife.js b/api-iife.js index b82593c..3ae8c51 100644 --- a/api-iife.js +++ b/api-iife.js @@ -1 +1 @@ -if("__TAURI__"in window){var __TAURI_PLUGIN_SQLITE__=function(){"use strict";async function t(t,a={},e){return window.__TAURI_INTERNALS__.invoke(t,a,e)}"function"==typeof SuppressedError&&SuppressedError;class a{constructor(t){this.path=t}static async load(e,n){const i=await t("plugin:sqlite|load",{db:e,customConfig:n});return new a(i)}static get(t){return new a(t)}async execute(a,e){const[n,i]=await t("plugin:sqlite|execute",{db:this.path,query:a,values:e??[]});return{lastInsertId:i,rowsAffected:n}}async fetchAll(a,e){return await t("plugin:sqlite|fetch_all",{db:this.path,query:a,values:e??[]})}async fetchOne(a,e){return await t("plugin:sqlite|fetch_one",{db:this.path,query:a,values:e??[]})}async beginTransaction(){await t("plugin:sqlite|begin_transaction",{db:this.path})}async commitTransaction(){await t("plugin:sqlite|commit_transaction",{db:this.path})}async rollbackTransaction(){await t("plugin:sqlite|rollback_transaction",{db:this.path})}async close(){return await t("plugin:sqlite|close",{db:this.path})}static async closeAll(){return await t("plugin:sqlite|close_all")}async remove(){return await t("plugin:sqlite|remove",{db:this.path})}}return a}();Object.defineProperty(window.__TAURI__,"sqlite",{value:__TAURI_PLUGIN_SQLITE__})} +if("__TAURI__"in window){var __TAURI_PLUGIN_SQLITE__=function(){"use strict";async function t(t,e={},s){return window.__TAURI_INTERNALS__.invoke(t,e,s)}"function"==typeof SuppressedError&&SuppressedError;class e{constructor(t){this.path=t}static async load(s,n){const a=await t("plugin:sqlite|load",{db:s,customConfig:n});return new e(a)}static get(t){return new e(t)}async execute(e,s){const[n,a]=await t("plugin:sqlite|execute",{db:this.path,query:e,values:s??[]});return{lastInsertId:a,rowsAffected:n}}async executeTransaction(e){return await t("plugin:sqlite|execute_transaction",{db:this.path,statements:e.map(([t,e])=>({query:t,values:e??[]}))})}async fetchAll(e,s){return await t("plugin:sqlite|fetch_all",{db:this.path,query:e,values:s??[]})}async fetchOne(e,s){return await t("plugin:sqlite|fetch_one",{db:this.path,query:e,values:s??[]})}async close(){return await t("plugin:sqlite|close",{db:this.path})}static async closeAll(){return await t("plugin:sqlite|close_all")}async remove(){return await t("plugin:sqlite|remove",{db:this.path})}}return e}();Object.defineProperty(window.__TAURI__,"sqlite",{value:__TAURI_PLUGIN_SQLITE__})} diff --git a/build.rs b/build.rs index 503d8fe..1d694b8 100644 --- a/build.rs +++ b/build.rs @@ -1,3 +1,4 @@ fn main() { + // TODO: Add commands to the plugin tauri_plugin::Builder::new(&["hello"]).build(); } diff --git a/guest-js/index.ts b/guest-js/index.ts index a31642f..e2d176a 100644 --- a/guest-js/index.ts +++ b/guest-js/index.ts @@ -12,10 +12,17 @@ import { invoke } from '@tauri-apps/api/core' */ export type SqlValue = string | number | boolean | null | Uint8Array -export interface QueryResult { - /** The number of rows affected by the query. */ +/** + * Result returned from write operations (INSERT, UPDATE, DELETE, etc.). + */ +export interface WriteQueryResult { + /** The number of rows affected by the write operation. */ rowsAffected: number - /** The last inserted row ID (SQLite ROWID). */ + /** + * The last inserted row ID (SQLite ROWID). + * Only set for INSERT operations on tables with a ROWID. + * Tables created with WITHOUT ROWID will not set this value (returns 0). + */ lastInsertId: number } @@ -107,11 +114,9 @@ export default class Database { * **execute** * * Executes a write query against the database (INSERT, UPDATE, DELETE, etc.). - * This method is specifically for mutations that modify data. + * This method is for mutations that modify data. * - * **Important:** Do NOT use this for SELECT queries. Use `fetchX()` instead. - * Using `execute()` for read queries will trigger an error to prevent unnecessary - * write mode initialization. + * For SELECT queries, use `fetchAll()` or `fetchOne()` instead. * * SQLite uses `$1`, `$2`, etc. for parameter binding. * @@ -132,7 +137,7 @@ export default class Database { * ); * ``` */ - async execute(query: string, bindValues?: SqlValue[]): Promise { + async execute(query: string, bindValues?: SqlValue[]): Promise { const [rowsAffected, lastInsertId] = await invoke<[number, number]>( 'plugin:sqlite|execute', { @@ -147,6 +152,50 @@ export default class Database { } } + /** + * **executeTransaction** + * + * Executes multiple write statements atomically within a transaction. + * All statements either succeed together or fail together. + * + * The function automatically: + * - Begins a transaction (BEGIN) + * - Executes all statements in order + * - Commits on success (COMMIT) + * - Rolls back on any error (ROLLBACK) + * + * @param statements - Array of [query, values?] tuples to execute + * @returns Promise that resolves with results for each statement when all complete successfully + * @throws SqliteError if any statement fails (after rollback) + * + * @example + * ```ts + * // Execute multiple inserts atomically + * const results = await db.executeTransaction([ + * ['INSERT INTO users (name, email) VALUES ($1, $2)', ['Alice', 'alice@example.com']], + * ['INSERT INTO audit_log (action, user) VALUES ($1, $2)', ['user_created', 'Alice']] + * ]); + * console.log(`User ID: ${results[0].lastInsertId}`); + * console.log(`Log rows affected: ${results[1].rowsAffected}`); + * + * // Mixed operations + * const results = await db.executeTransaction([ + * ['UPDATE accounts SET balance = balance - $1 WHERE id = $2', [100, 1]], + * ['UPDATE accounts SET balance = balance + $1 WHERE id = $2', [100, 2]], + * ['INSERT INTO transfers (from_id, to_id, amount) VALUES ($1, $2, $3)', [1, 2, 100]] + * ]); + * ``` + */ + async executeTransaction(statements: Array<[string, SqlValue[]?]>): Promise { + return await invoke('plugin:sqlite|execute_transaction', { + db: this.path, + statements: statements.map(([query, values]) => ({ + query, + values: values ?? [] + })) + }) + } + /** * **fetchAll** * @@ -211,78 +260,6 @@ export default class Database { return result } - /** - * **beginTransaction** - * - * Begins a new database transaction. All subsequent operations will be - * part of this transaction until `commitTransaction()` or `rollbackTransaction()` - * is called. - * - * Transactions provide atomicity - either all operations succeed or all are rolled back. - * - * @example - * ```ts - * await db.beginTransaction(); - * try { - * await db.execute('INSERT INTO users (name) VALUES ($1)', ['Alice']); - * await db.execute('INSERT INTO logs (action) VALUES ($1)', ['user_created']); - * await db.commitTransaction(); - * } catch (error) { - * await db.rollbackTransaction(); - * throw error; - * } - * ``` - */ - async beginTransaction(): Promise { - await invoke('plugin:sqlite|begin_transaction', { - db: this.path - }) - } - - /** - * **commitTransaction** - * - * Commits the current transaction, making all changes permanent. - * - * @example - * ```ts - * await db.beginTransaction(); - * await db.execute('INSERT INTO users (name) VALUES ($1)', ['Alice']); - * await db.execute('INSERT INTO logs (action) VALUES ($1)', ['user_created']); - * await db.commitTransaction(); - * ``` - */ - async commitTransaction(): Promise { - await invoke('plugin:sqlite|commit_transaction', { - db: this.path - }) - } - - /** - * **rollbackTransaction** - * - * Rolls back the current transaction, discarding all changes made since - * `beginTransaction()` was called. - * - * @example - * ```ts - * await db.beginTransaction(); - * try { - * await db.execute('INSERT INTO users (name) VALUES ($1)', ['Alice']); - * await db.execute('INSERT INTO logs (action) VALUES ($1)', ['user_created']); - * await db.commitTransaction(); - * } catch (error) { - * await db.rollbackTransaction(); - * throw error; - * } - * ``` - */ - async rollbackTransaction(): Promise { - await invoke('plugin:sqlite|rollback_transaction', { - db: this.path - }) - } - /** * **close** * diff --git a/src/decode.rs b/src/decode.rs new file mode 100644 index 0000000..ed7a081 --- /dev/null +++ b/src/decode.rs @@ -0,0 +1,150 @@ +use serde_json::Value as JsonValue; +use sqlx::sqlite::SqliteValueRef; +use sqlx::{TypeInfo, Value, ValueRef}; +use time::PrimitiveDateTime; + +use crate::Error; + +/// Convert a SQLite value to a JSON value. +/// +/// This function handles the type conversion from SQLite's native types +/// to JSON-compatible representations. +pub fn to_json(value: SqliteValueRef) -> Result { + if value.is_null() { + return Ok(JsonValue::Null); + } + + let column_type = value.type_info(); + + // Handle types based on SQLite's type affinity + let result = match column_type.name() { + "TEXT" => { + if let Ok(v) = value.to_owned().try_decode::() { + JsonValue::String(v) + } else { + JsonValue::Null + } + } + + "REAL" => { + if let Ok(v) = value.to_owned().try_decode::() { + JsonValue::from(v) + } else { + JsonValue::Null + } + } + + "INTEGER" | "NUMERIC" => { + if let Ok(v) = value.to_owned().try_decode::() { + JsonValue::Number(v.into()) + } else { + JsonValue::Null + } + } + + "BOOLEAN" => { + if let Ok(v) = value.to_owned().try_decode::() { + JsonValue::Bool(v) + } else { + JsonValue::Null + } + } + + "DATE" => { + // SQLite stores dates as TEXT in ISO 8601 format + if let Ok(v) = value.to_owned().try_decode::() { + JsonValue::String(v) + } else { + JsonValue::Null + } + } + + "TIME" => { + // SQLite stores time as TEXT in HH:MM:SS format + if let Ok(v) = value.to_owned().try_decode::() { + JsonValue::String(v) + } else { + JsonValue::Null + } + } + + "DATETIME" => { + // Try to decode as PrimitiveDateTime + if let Ok(dt) = value.to_owned().try_decode::() { + JsonValue::String(dt.to_string()) + } else if let Ok(v) = value.to_owned().try_decode::() { + // Fall back to string representation + JsonValue::String(v) + } else { + JsonValue::Null + } + } + + "BLOB" => { + if let Ok(blob) = value.to_owned().try_decode::>() { + // Encode binary data as base64 for JSON serialization + JsonValue::String(base64_encode(&blob)) + } else { + JsonValue::Null + } + } + + "NULL" => JsonValue::Null, + + _ => { + // For unknown types, try to decode as text + if let Ok(text) = value.to_owned().try_decode::() { + JsonValue::String(text) + } else { + return Err(Error::UnsupportedDatatype(format!( + "Unknown SQLite type: {}", + column_type.name() + ))); + } + } + }; + + Ok(result) +} + +/// Base64 encode binary data for JSON serialization. +/// +/// SQLite BLOB columns are encoded as base64 strings when serialized to JSON, +/// as JSON does not have a native binary type. +fn base64_encode(data: &[u8]) -> String { + use base64::Engine; + base64::engine::general_purpose::STANDARD.encode(data) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_base64_encode() { + assert_eq!(base64_encode(b"hello"), "aGVsbG8="); + assert_eq!(base64_encode(&[1, 2, 3, 4, 5]), "AQIDBAU="); + assert_eq!(base64_encode(&[]), ""); + } + + #[test] + fn test_base64_encode_binary() { + // Test with binary data including null bytes + assert_eq!(base64_encode(&[0, 0, 0]), "AAAA"); + assert_eq!(base64_encode(&[255, 255, 255]), "////"); + } + + #[test] + fn test_base64_encode_large() { + // Test with larger binary data + let data: Vec = (0..255).collect(); + let encoded = base64_encode(&data); + assert!(!encoded.is_empty()); + // Verify it's valid base64 (only contains valid chars) + assert!( + encoded + .chars() + .all(|c| c.is_alphanumeric() || c == '+' || c == '/' || c == '=') + ); + } +} diff --git a/src/error.rs b/src/error.rs index 3741b35..ae15e36 100644 --- a/src/error.rs +++ b/src/error.rs @@ -41,13 +41,16 @@ pub enum Error { #[error("io error: {0}")] Io(#[from] std::io::Error), - /// Read-only query executed with execute command. - #[error("execute() should not be used for read-only queries. Use fetchX() instead.")] - ReadOnlyQueryInExecute, - /// Multiple rows returned from fetchOne query. #[error("fetchOne() query returned {0} rows, expected 0 or 1")] MultipleRowsReturned(usize), + + /// Transaction failed and rollback also failed. + #[error("transaction failed: {transaction_error}; rollback also failed: {rollback_error}")] + TransactionRollbackFailed { + transaction_error: String, + rollback_error: String, + }, } impl Error { @@ -69,8 +72,8 @@ impl Error { Error::DatabaseNotLoaded(_) => "DATABASE_NOT_LOADED".to_string(), Error::UnsupportedDatatype(_) => "UNSUPPORTED_DATATYPE".to_string(), Error::Io(_) => "IO_ERROR".to_string(), - Error::ReadOnlyQueryInExecute => "READ_ONLY_QUERY_IN_EXECUTE".to_string(), Error::MultipleRowsReturned(_) => "MULTIPLE_ROWS_RETURNED".to_string(), + Error::TransactionRollbackFailed { .. } => "TRANSACTION_ROLLBACK_FAILED".to_string(), } } } @@ -87,3 +90,108 @@ impl Serialize for Error { response.serialize(serializer) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_code_database_not_loaded() { + let err = Error::DatabaseNotLoaded("test.db".into()); + assert_eq!(err.error_code(), "DATABASE_NOT_LOADED"); + } + + #[test] + fn test_error_code_invalid_path() { + let err = Error::InvalidPath("/bad/path".into()); + assert_eq!(err.error_code(), "INVALID_PATH"); + } + + #[test] + fn test_error_code_unsupported_datatype() { + let err = Error::UnsupportedDatatype("WEIRD_TYPE".into()); + assert_eq!(err.error_code(), "UNSUPPORTED_DATATYPE"); + } + + #[test] + fn test_error_code_multiple_rows() { + let err = Error::MultipleRowsReturned(5); + assert_eq!(err.error_code(), "MULTIPLE_ROWS_RETURNED"); + } + + #[test] + fn test_error_serialization_structure() { + let err = Error::DatabaseNotLoaded("mydb.db".into()); + let json = serde_json::to_value(&err).unwrap(); + + // Verify structure has both code and message fields + assert!(json.is_object()); + assert!(json.get("code").is_some()); + assert!(json.get("message").is_some()); + } + + #[test] + fn test_error_serialization_database_not_loaded() { + let err = Error::DatabaseNotLoaded("mydb.db".into()); + let json = serde_json::to_value(&err).unwrap(); + + assert_eq!(json["code"], "DATABASE_NOT_LOADED"); + assert!(json["message"].as_str().unwrap().contains("mydb.db")); + assert!(json["message"].as_str().unwrap().contains("not loaded")); + } + + #[test] + fn test_error_serialization_invalid_path() { + let err = Error::InvalidPath("/bad/path".into()); + let json = serde_json::to_value(&err).unwrap(); + + assert_eq!(json["code"], "INVALID_PATH"); + assert!(json["message"].as_str().unwrap().contains("/bad/path")); + } + + #[test] + fn test_error_serialization_multiple_rows() { + let err = Error::MultipleRowsReturned(3); + let json = serde_json::to_value(&err).unwrap(); + + assert_eq!(json["code"], "MULTIPLE_ROWS_RETURNED"); + let message = json["message"].as_str().unwrap(); + assert!(message.contains("3 rows")); + assert!(message.contains("0 or 1")); + } + + #[test] + fn test_error_message_format() { + // Verify error messages are descriptive + let err = Error::MultipleRowsReturned(5); + let message = err.to_string(); + assert!(message.contains("fetchOne()")); + assert!(message.contains("5 rows")); + assert!(message.contains("expected 0 or 1")); + } + + #[test] + fn test_error_code_transaction_rollback_failed() { + let err = Error::TransactionRollbackFailed { + transaction_error: "constraint violation".to_string(), + rollback_error: "connection lost".to_string(), + }; + assert_eq!(err.error_code(), "TRANSACTION_ROLLBACK_FAILED"); + } + + #[test] + fn test_error_serialization_transaction_rollback_failed() { + let err = Error::TransactionRollbackFailed { + transaction_error: "constraint violation".to_string(), + rollback_error: "connection lost".to_string(), + }; + let json = serde_json::to_value(&err).unwrap(); + + assert_eq!(json["code"], "TRANSACTION_ROLLBACK_FAILED"); + let message = json["message"].as_str().unwrap(); + assert!(message.contains("constraint violation")); + assert!(message.contains("connection lost")); + assert!(message.contains("transaction failed")); + assert!(message.contains("rollback also failed")); + } +} diff --git a/src/lib.rs b/src/lib.rs index e2e9e38..753630e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,21 +1,90 @@ -use tauri::{Runtime, plugin::TauriPlugin}; +use std::collections::HashMap; +use std::future::Future; + +use serde::Deserialize; +use tauri::{Runtime, plugin::Builder as PluginBuilder}; +use tokio::sync::RwLock; mod commands; +mod decode; mod error; +mod wrapper; pub use error::{Error, Result}; +pub use wrapper::{DatabaseWrapper, WriteQueryResult}; + +/// Database instances managed by the plugin. +/// +/// This struct maintains a thread-safe map of database paths to their corresponding +/// connection wrappers. +#[derive(Default)] +pub struct DbInstances(pub RwLock>); + +/// Plugin configuration. +/// +/// Defines databases to preload during plugin initialization. +#[derive(Default, Clone, Deserialize)] +pub struct PluginConfig { + /// List of database paths to load on plugin initialization + #[serde(default)] + #[allow(dead_code)] // Will be used in future PR + preload: Vec, +} + +/// Helper function to run async commands in both async and sync contexts. +/// +/// This handles the case where we're already in a Tokio runtime (use `block_in_place`) +/// or need to create one (use Tauri's async runtime). +#[allow(dead_code)] // Will be used in a future PR +fn run_async_command(cmd: F) -> F::Output { + if tokio::runtime::Handle::try_current().is_ok() { + tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(cmd)) + } else { + tauri::async_runtime::block_on(cmd) + } +} + +/// Builder for the SQLite plugin. +/// +/// Use this to configure the plugin and build the plugin instance. +/// +/// # Example +/// +/// ```rust,ignore +/// use tauri_plugin_sqlite::Builder; +/// +/// // In your Tauri app setup: +/// tauri::Builder::default() +/// .plugin(Builder::new().build()) +/// .run(tauri::generate_context!()) +/// .expect("error while running tauri application"); +/// ``` +#[derive(Default)] +pub struct Builder; + +impl Builder { + /// Create a new builder instance. + pub fn new() -> Self { + Self + } + + /// Build the plugin. + /// + /// Full implementation with preload and lifecycle hooks will be added in a future PR. + pub fn build(self) -> tauri::plugin::TauriPlugin> { + // Future PR: Full implementation with setup, preload, and cleanup hooks + PluginBuilder::>::new("sqlite") + .setup(|_app, _api| { + // Future PR: Database preloading and lifecycle management + Ok(()) + }) + .build() + } +} -/// Initializes the plugin. -pub fn init() -> TauriPlugin { - tauri::plugin::Builder::new("sqlite") - // .invoke_handler(tauri::generate_handler![ - // commands::load, - // commands::execute, - // commands::fetch_all, - // commands::fetch_one, - // commands::close, - // commands::close_all, - // commands::remove - // ]) - .build() +/// Initializes the plugin with default configuration. +/// +/// For custom configuration, use `Builder` instead. +pub fn init() -> tauri::plugin::TauriPlugin> { + Builder::new().build() } diff --git a/src/wrapper.rs b/src/wrapper.rs new file mode 100644 index 0000000..c64cf5c --- /dev/null +++ b/src/wrapper.rs @@ -0,0 +1,255 @@ +use std::fs::create_dir_all; +use std::path::PathBuf; +use std::sync::Arc; + +use indexmap::IndexMap; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; +use sqlx::{Column, Executor, Row}; +use sqlx_sqlite_conn_mgr::{SqliteDatabase, SqliteDatabaseConfig}; +use tauri::{AppHandle, Manager, Runtime}; + +use crate::Error; + +/// Result returned from write operations (e.g. INSERT, UPDATE, DELETE). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WriteQueryResult { + /// The number of rows affected by the write operation. + pub rows_affected: u64, + /// The last inserted row ID (SQLite ROWID). + /// + /// Only set for INSERT operations on tables with a ROWID. + /// Tables created with `WITHOUT ROWID` will not set this value (returns 0). + pub last_insert_id: i64, +} + +/// Wrapper around SqliteDatabase that adapts it for the plugin interface +pub struct DatabaseWrapper { + inner: Arc, +} + +impl DatabaseWrapper { + /// Connect to a SQLite database via the connection manager + pub async fn connect( + path: &str, + app: &AppHandle, + custom_config: Option, + ) -> Result { + // Resolve path relative to app_config_dir + let abs_path = resolve_database_path(path, app)?; + + // Use connection manager to connect with optional custom config + let db = SqliteDatabase::connect(&abs_path, custom_config).await?; + + Ok(Self { inner: db }) + } + + /// Execute a write query (INSERT/UPDATE/DELETE) + pub async fn execute( + &self, + query: String, + values: Vec, + ) -> Result { + // Acquire writer for mutations + let mut writer = self.inner.acquire_writer().await?; + + let mut q = sqlx::query(&query); + for value in values { + q = bind_value(q, value); + } + + let result = q.execute(&mut *writer).await?; + Ok(WriteQueryResult { + rows_affected: result.rows_affected(), + last_insert_id: result.last_insert_rowid(), + }) + } + + /// Execute multiple write statements atomically within a transaction. + /// + /// This method: + /// 1. Begins a transaction (BEGIN) + /// 2. Executes all statements in order + /// 3. Commits on success (COMMIT) + /// 4. Rolls back on any error (ROLLBACK) + /// + /// The writer is held for the entire transaction, ensuring atomicity. + /// Returns the result of each statement execution. + pub async fn execute_transaction( + &self, + statements: Vec<(String, Vec)>, + ) -> Result, Error> { + // Acquire writer for the entire transaction + let mut writer = self.inner.acquire_writer().await?; + + // Begin transaction + sqlx::query("BEGIN IMMEDIATE").execute(&mut *writer).await?; + + // Execute all statements, collecting results and rolling back on error + let result = async { + let mut results = Vec::new(); + for (query, values) in statements { + let mut q = sqlx::query(&query); + for value in values { + q = bind_value(q, value); + } + let exec_result = q.execute(&mut *writer).await?; + results.push(WriteQueryResult { + rows_affected: exec_result.rows_affected(), + last_insert_id: exec_result.last_insert_rowid(), + }); + } + Ok::, Error>(results) + } + .await; + + // Commit or rollback based on result + match result { + Ok(results) => { + sqlx::query("COMMIT").execute(&mut *writer).await?; + Ok(results) + } + Err(e) => { + match sqlx::query("ROLLBACK").execute(&mut *writer).await { + // Rollback succeeded, return original error + Ok(_) => Err(e), + + // Rollback also failed, return the rollback error and the original error + Err(rollback_err) => Err(Error::TransactionRollbackFailed { + transaction_error: e.to_string(), + rollback_error: rollback_err.to_string(), + }), + } + } + } + } + + /// Execute a SELECT query, possibly returning multiple rows + pub async fn fetch_all( + &self, + query: String, + values: Vec, + ) -> Result>, Error> { + // Use read pool for queries + let pool = self.inner.read_pool()?; + + let mut q = sqlx::query(&query); + for value in values { + q = bind_value(q, value); + } + + let rows = pool.fetch_all(q).await?; + + // Decode rows to JSON + let mut values = Vec::new(); + for row in rows { + let mut value = IndexMap::default(); + for (i, column) in row.columns().iter().enumerate() { + let v = row.try_get_raw(i)?; + let v = crate::decode::to_json(v)?; + value.insert(column.name().to_string(), v); + } + values.push(value); + } + + Ok(values) + } + + /// Execute a SELECT query expecting zero or one result + pub async fn fetch_one( + &self, + query: String, + values: Vec, + ) -> Result>, Error> { + // Use read pool for queries + let pool = self.inner.read_pool()?; + + // Add LIMIT 2 to detect if query returns multiple rows + // We only need to fetch up to 2 rows to know if there's more than 1 + let limited_query = format!("{} LIMIT 2", query.trim_end_matches(';')); + + let mut q = sqlx::query(&limited_query); + for value in values { + q = bind_value(q, value); + } + + let rows = pool.fetch_all(q).await?; + + // Validate row count + match rows.len() { + 0 => Ok(None), + 1 => { + // Decode single row to JSON + let row = &rows[0]; + let mut value = IndexMap::default(); + for (i, column) in row.columns().iter().enumerate() { + let v = row.try_get_raw(i)?; + let v = crate::decode::to_json(v)?; + value.insert(column.name().to_string(), v); + } + Ok(Some(value)) + } + count => { + // Multiple rows returned - this is an error + Err(Error::MultipleRowsReturned(count)) + } + } + } + + /// Close the database connection + pub async fn close(self) -> Result<(), Error> { + // Close via Arc (handles both owned and shared cases) + self.inner.close().await?; + Ok(()) + } + + /// Close the database connection and remove all database files + pub async fn remove(self) -> Result<(), Error> { + // Remove via Arc (handles both owned and shared cases) + self.inner.remove().await?; + Ok(()) + } +} + +/// Helper function to bind a JSON value to a SQLx query +fn bind_value<'a>( + query: sqlx::query::Query<'a, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'a>>, + value: JsonValue, +) -> sqlx::query::Query<'a, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'a>> { + if value.is_null() { + query.bind(None::) + } else if value.is_string() { + query.bind(value.as_str().unwrap().to_owned()) + } else if let Some(number) = value.as_number() { + // Preserve integer precision by binding as i64 when possible + if let Some(int_val) = number.as_i64() { + query.bind(int_val) + } else if let Some(uint_val) = number.as_u64() { + // Try to fit u64 into i64 (SQLite's INTEGER type) + if uint_val <= i64::MAX as u64 { + query.bind(uint_val as i64) + } else { + // Value too large for i64, use f64 (will lose precision) + query.bind(uint_val as f64) + } + } else { + // Not an integer, bind as f64 + query.bind(number.as_f64().unwrap_or_default()) + } + } else { + query.bind(value) + } +} + +/// Resolve database file path relative to app config directory +fn resolve_database_path(path: &str, app: &AppHandle) -> Result { + let app_path = app + .path() + .app_config_dir() + .expect("No App config path was found!"); + + create_dir_all(&app_path).expect("Couldn't create app config dir"); + + // Join the relative path to the app config directory + Ok(app_path.join(path)) +}