Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqlite): add preupdate hook #3625

Merged
merged 5 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/sqlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ jobs:
- run: >
cargo clippy
--no-default-features
--features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
--features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
-- -D warnings

# Run beta for new warnings but don't break the build.
# Use a subdirectory of `target` to avoid clobbering the cache.
- run: >
cargo +beta clippy
--no-default-features
--features all-databases,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
--features all-databases,_unstable-all-types,sqlite-preupdate-hook,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros
--target-dir target/beta/

check-minimal-versions:
Expand Down Expand Up @@ -140,7 +140,7 @@ jobs:
- run: >
cargo test
--no-default-features
--features any,macros,${{ matrix.linking }},_unstable-all-types,runtime-${{ matrix.runtime }}
--features any,macros,${{ matrix.linking }},${{ matrix.linking == 'sqlite' && 'sqlite-preupdate-hook,' || ''}}_unstable-all-types,runtime-${{ matrix.runtime }}
--
--test-threads=1
env:
Expand Down
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ authors.workspace = true
repository.workspace = true

[package.metadata.docs.rs]
features = ["all-databases", "_unstable-all-types"]
features = ["all-databases", "_unstable-all-types", "sqlite-preupdate-hook"]
rustdoc-args = ["--cfg", "docsrs"]

[features]
Expand Down Expand Up @@ -108,6 +108,7 @@ postgres = ["sqlx-postgres", "sqlx-macros?/postgres"]
mysql = ["sqlx-mysql", "sqlx-macros?/mysql"]
sqlite = ["_sqlite", "sqlx-sqlite/bundled", "sqlx-macros?/sqlite"]
sqlite-unbundled = ["_sqlite", "sqlx-sqlite/unbundled", "sqlx-macros?/sqlite-unbundled"]
sqlite-preupdate-hook = ["sqlx-sqlite/preupdate-hook"]
abonander marked this conversation as resolved.
Show resolved Hide resolved

# types
json = ["sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"]
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ be removed in the future.
* May result in link errors if the SQLite version is too old. Version `3.20.0` or newer is recommended.
* Can increase build time due to the use of bindgen.

- `sqlite-preupdate-hook`: enables SQLite's [preupdate hook](https://sqlite.org/c3ref/preupdate_count.html) API.
* Exposed as a separate feature because it's generally not enabled by default.
* Using this feature with `sqlite-unbundled` may cause linker failures if the system SQLite version does not support it.

- `any`: Add support for the `Any` database driver, which can proxy to a database driver at runtime.

- `derive`: Add support for the derive family macros, those are `FromRow`, `Type`, `Encode`, `Decode`.
Expand Down
2 changes: 2 additions & 0 deletions sqlx-sqlite/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ uuid = ["dep:uuid", "sqlx-core/uuid"]

regexp = ["dep:regex"]

preupdate-hook = ["libsqlite3-sys/preupdate_hook"]

bundled = ["libsqlite3-sys/bundled"]
unbundled = ["libsqlite3-sys/buildtime_bindgen"]

Expand Down
2 changes: 2 additions & 0 deletions sqlx-sqlite/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ impl EstablishParams {
log_settings: self.log_settings.clone(),
progress_handler_callback: None,
update_hook_callback: None,
#[cfg(feature = "preupdate-hook")]
preupdate_hook_callback: None,
abonander marked this conversation as resolved.
Show resolved Hide resolved
commit_hook_callback: None,
rollback_hook_callback: None,
})
Expand Down
221 changes: 221 additions & 0 deletions sqlx-sqlite/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use libsqlite3_sys::{
sqlite3, sqlite3_commit_hook, sqlite3_progress_handler, sqlite3_rollback_hook,
sqlite3_update_hook, SQLITE_DELETE, SQLITE_INSERT, SQLITE_UPDATE,
};
#[cfg(feature = "preupdate-hook")]
pub use preupdate_hook::*;

pub(crate) use handle::ConnectionHandle;
use sqlx_core::common::StatementCache;
Expand Down Expand Up @@ -88,6 +90,7 @@ pub struct UpdateHookResult<'a> {
pub table: &'a str,
pub rowid: i64,
}

pub(crate) struct UpdateHookHandler(NonNull<dyn FnMut(UpdateHookResult) + Send + 'static>);
unsafe impl Send for UpdateHookHandler {}

Expand All @@ -112,6 +115,8 @@ pub(crate) struct ConnectionState {
progress_handler_callback: Option<Handler>,

update_hook_callback: Option<UpdateHookHandler>,
#[cfg(feature = "preupdate-hook")]
preupdate_hook_callback: Option<preupdate_hook::PreupdateHookHandler>,

commit_hook_callback: Option<CommitHookHandler>,

Expand Down Expand Up @@ -544,3 +549,219 @@ impl Statements {
self.temp = None;
}
}

#[cfg(feature = "preupdate-hook")]
mod preupdate_hook {
abonander marked this conversation as resolved.
Show resolved Hide resolved
use super::ConnectionState;
use super::LockedSqliteHandle;
use super::SqliteOperation;
use crate::type_info::DataType;
use crate::{SqliteError, SqliteTypeInfo, SqliteValue};
use libsqlite3_sys::{
sqlite3, sqlite3_preupdate_count, sqlite3_preupdate_depth, sqlite3_preupdate_hook,
sqlite3_preupdate_new, sqlite3_preupdate_old, sqlite3_value, sqlite3_value_type, SQLITE_OK,
};
use sqlx_core::error::Error;
use std::ffi::CStr;
use std::fmt::Debug;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::catch_unwind;
use std::ptr;
use std::ptr::NonNull;

pub struct PreupdateHookResult<'a> {
pub operation: SqliteOperation,
pub database: &'a str,
pub table: &'a str,
pub case: PreupdateCase,
}

pub(crate) struct PreupdateHookHandler(
NonNull<dyn FnMut(PreupdateHookResult) + Send + 'static>,
);
unsafe impl Send for PreupdateHookHandler {}

/// The possible cases for when a PreUpdate Hook gets triggered. Allows access to the relevant
/// functions for each case through the contained values.
pub enum PreupdateCase {
/// Pre-update hook was triggered by an insert.
Insert(PreupdateNewValueAccessor),
/// Pre-update hook was triggered by a delete.
Delete(PreupdateOldValueAccessor),
/// Pre-update hook was triggered by an update.
Update {
old_value_accessor: PreupdateOldValueAccessor,
new_value_accessor: PreupdateNewValueAccessor,
},
/// This variant is not normally produced by SQLite. You may encounter it
/// if you're using a different version than what's supported by this library.
Unknown,
}
abonander marked this conversation as resolved.
Show resolved Hide resolved

/// An accessor for the old values of the row being deleted/updated during the preupdate callback.
#[derive(Debug)]
pub struct PreupdateOldValueAccessor {
db: *mut sqlite3,
abonander marked this conversation as resolved.
Show resolved Hide resolved
old_row_id: i64,
}

impl PreupdateOldValueAccessor {
/// Gets the amount of columns in the row being deleted/updated.
pub fn get_column_count(&self) -> i32 {
unsafe { sqlite3_preupdate_count(self.db) }
}

/// Gets the depth of the query that triggered the preupdate hook.
/// Returns 0 if the preupdate callback was invoked as a result of
/// a direct insert, update, or delete operation;
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
pub fn get_query_depth(&self) -> i32 {
unsafe { sqlite3_preupdate_depth(self.db) }
}

/// Gets the row id of the row being updated/deleted.
pub fn get_old_row_id(&self) -> i64 {
self.old_row_id
}

/// Gets the value of the row being updated/deleted at the specified index.
pub fn get_old_column_value(&self, i: i32) -> Result<SqliteValue, Error> {
let mut p_value: *mut sqlite3_value = ptr::null_mut();
unsafe {
let ret = sqlite3_preupdate_old(self.db, i, &mut p_value);
Copy link
Collaborator

@abonander abonander Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sqlite3_value returned by this call and sqlite3_preupdate_new has weird semantics. It's a "protected" value so it's thread-safe, but at the same time the documentation also specifies:

The sqlite3_value that P points to will be destroyed when the preupdate callback returns.

We should handle this as a SqliteValueRef instead, using the same lifetime. The user can then use .to_owned() to get a fully independent value if they need it.

You'll need to add a new case here:

Value(&'r SqliteValue),

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also check i against sqlite3_preupdate_count because the result is undefined if the index is out of bounds.

Copy link
Contributor Author

@aschey aschey Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should handle this as a SqliteValueRef instead, using the same lifetime. The user can then use .to_owned() to get a fully independent value if they need it.

SqliteValue::new calls sqlite3_value_dup which creates a copy of the sqlite3_value. I tested this out by storing the returned SqliteValue in a mutex and ensuring the returned value could still be decoded properly after the callback was completed.

I can add something to SqliteValueRef for this to avoid the call to sqlite3_value_dup, but I think that would require re-implementing a lot of the logic in SqliteValue or modifying it so that it can operate on both an owned or borrowed value. Let me know if I'm missing something.

EDIT: added this in a separate commit. I can revert it if this isn't what you had in mind.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also check i against sqlite3_preupdate_count because the result is undefined if the index is out of bounds.

I think the documentation might be a bit conservative here (or they don't want to make any guarantees) because it does check for these conditions and return an error, but I went ahead and added an explicit check here too.

if ret != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(self.db))));
}
let data_type = DataType::from_code(sqlite3_value_type(p_value));
Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type)))
}
}
}

/// An accessor for the new values of the row being inserted/updated during the preupdate callback.
#[derive(Debug)]
pub struct PreupdateNewValueAccessor {
db: *mut sqlite3,
new_row_id: i64,
}

impl PreupdateNewValueAccessor {
/// Gets the amount of columns in the row being inserted/updated.
pub fn get_column_count(&self) -> i32 {
unsafe { sqlite3_preupdate_count(self.db) }
}

/// Gets the depth of the query that triggered the preupdate hook.
/// Returns 0 if the preupdate callback was invoked as a result of
/// a direct insert, update, or delete operation;
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
pub fn get_query_depth(&self) -> i32 {
unsafe { sqlite3_preupdate_depth(self.db) }
}

/// Gets the row id of the row being inserted/updated.
pub fn get_new_row_id(&self) -> i64 {
self.new_row_id
}

/// Gets the value of the row being updated/deleted at the specified index.
pub fn get_new_column_value(&self, i: i32) -> Result<SqliteValue, Error> {
let mut p_value: *mut sqlite3_value = ptr::null_mut();
unsafe {
let ret = sqlite3_preupdate_new(self.db, i, &mut p_value);
if ret != SQLITE_OK {
return Err(Error::Database(Box::new(SqliteError::new(self.db))));
}
let data_type = DataType::from_code(sqlite3_value_type(p_value));
Ok(SqliteValue::new(p_value, SqliteTypeInfo(data_type)))
}
}
}

impl ConnectionState {
pub(crate) fn remove_preupdate_hook(&mut self) {
if let Some(mut handler) = self.preupdate_hook_callback.take() {
unsafe {
sqlite3_preupdate_hook(self.handle.as_ptr(), None, ptr::null_mut());
let _ = { Box::from_raw(handler.0.as_mut()) };
}
}
}
}

impl LockedSqliteHandle<'_> {
/// Registers a hook that is invoked prior to each `INSERT`, `UPDATE`, and `DELETE` operation on a database table.
/// At most one preupdate hook may be registered at a time on a single database connection.
///
/// The preupdate hook only fires for changes to real database tables;
/// it is not invoked for changes to virtual tables or to system tables like sqlite_sequence or sqlite_stat1.
///
/// See https://sqlite.org/c3ref/preupdate_count.html
pub fn set_preupdate_hook<F>(&mut self, callback: F)
where
F: FnMut(PreupdateHookResult) + Send + 'static,
{
unsafe {
let callback_boxed = Box::new(callback);
// SAFETY: `Box::into_raw()` always returns a non-null pointer.
let callback = NonNull::new_unchecked(Box::into_raw(callback_boxed));
let handler = callback.as_ptr() as *mut _;
self.guard.remove_preupdate_hook();
self.guard.preupdate_hook_callback = Some(PreupdateHookHandler(callback));

sqlite3_preupdate_hook(
self.as_raw_handle().as_mut(),
Some(preupdate_hook::<F>),
handler,
);
}
}

pub fn remove_preupdate_hook(&mut self) {
self.guard.remove_preupdate_hook();
}
}

extern "C" fn preupdate_hook<F>(
callback: *mut c_void,
db: *mut sqlite3,
op_code: c_int,
database: *const c_char,
table: *const c_char,
old_row_id: i64,
new_row_id: i64,
) where
F: FnMut(PreupdateHookResult),
{
unsafe {
let _ = catch_unwind(|| {
let callback: *mut F = callback.cast::<F>();
let operation: SqliteOperation = op_code.into();
let database = CStr::from_ptr(database).to_str().unwrap_or_default();
let table = CStr::from_ptr(table).to_str().unwrap_or_default();

let preupdate_case = match operation {
SqliteOperation::Insert => {
PreupdateCase::Insert(PreupdateNewValueAccessor { db, new_row_id })
}
SqliteOperation::Delete => {
PreupdateCase::Delete(PreupdateOldValueAccessor { db, old_row_id })
}
SqliteOperation::Update => PreupdateCase::Update {
old_value_accessor: PreupdateOldValueAccessor { db, old_row_id },
new_value_accessor: PreupdateNewValueAccessor { db, new_row_id },
},
SqliteOperation::Unknown(_) => PreupdateCase::Unknown,
};
(*callback)(PreupdateHookResult {
operation,
database,
table,
case: preupdate_case,
})
});
}
}
}
4 changes: 4 additions & 0 deletions sqlx-sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ use std::sync::atomic::AtomicBool;
pub use arguments::{SqliteArgumentValue, SqliteArguments};
pub use column::SqliteColumn;
pub use connection::{LockedSqliteHandle, SqliteConnection, SqliteOperation, UpdateHookResult};
#[cfg(feature = "preupdate-hook")]
pub use connection::{
PreupdateCase, PreupdateHookResult, PreupdateNewValueAccessor, PreupdateOldValueAccessor,
};
pub use database::Sqlite;
pub use error::SqliteError;
pub use options::{
Expand Down
Loading
Loading