From ec9320661f8a683797c7f9b04e357461ddbd3719 Mon Sep 17 00:00:00 2001 From: Lars Schumacher Date: Sun, 21 Jan 2024 03:20:04 +0100 Subject: [PATCH] fix(mysql): Close prepared statement if persistence is disabled (#2905) * close prepared statement if persistence or statement cache are disabled * add tests --- Cargo.toml | 2 +- sqlx-mysql/src/connection/executor.rs | 90 ++++++++++++++++++--------- sqlx-sqlite/Cargo.toml | 2 +- tests/mysql/mysql.rs | 66 ++++++++++++++++++++ 4 files changed, 129 insertions(+), 31 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f13975a94f..ab7be1ca54 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -178,7 +178,7 @@ tempdir = "0.3.7" criterion = {version = "0.4", features = ["async_tokio"]} # Needed to test SQLCipher -libsqlite3-sys = { version = "0.27", features = ["bundled-sqlcipher"] } +libsqlite3-sys = { version = "0.26", features = ["bundled-sqlcipher"] } # # Any diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index b668c67396..21fec1ec6b 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -25,16 +25,10 @@ use futures_util::{pin_mut, TryStreamExt}; use std::{borrow::Cow, sync::Arc}; impl MySqlConnection { - async fn get_or_prepare<'c>( + async fn prepare_statement<'c>( &mut self, sql: &str, - persistent: bool, ) -> Result<(u32, MySqlStatementMetadata), Error> { - if let Some(statement) = self.cache_statement.get_mut(sql) { - // is internally reference-counted - return Ok((*statement).clone()); - } - // https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html // https://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html#packet-COM_STMT_PREPARE_OK @@ -72,11 +66,23 @@ impl MySqlConnection { column_names: Arc::new(column_names), }; - if persistent && self.cache_statement.is_enabled() { - // in case of the cache being full, close the least recently used statement - if let Some((id, _)) = self.cache_statement.insert(sql, (id, metadata.clone())) { - self.stream.send_packet(StmtClose { statement: id }).await?; - } + Ok((id, metadata)) + } + + async fn get_or_prepare_statement<'c>( + &mut self, + sql: &str, + ) -> Result<(u32, MySqlStatementMetadata), Error> { + if let Some(statement) = self.cache_statement.get_mut(sql) { + // is internally reference-counted + return Ok((*statement).clone()); + } + + let (id, metadata) = self.prepare_statement(sql).await?; + + // in case of the cache being full, close the least recently used statement + if let Some((id, _)) = self.cache_statement.insert(sql, (id, metadata.clone())) { + self.stream.send_packet(StmtClose { statement: id }).await?; } Ok((id, metadata)) @@ -102,21 +108,37 @@ impl MySqlConnection { let mut columns = Arc::new(Vec::new()); let (mut column_names, format, mut needs_metadata) = if let Some(arguments) = arguments { - let (id, metadata) = self.get_or_prepare( - sql, - persistent, - ) - .await?; - - // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html - self.stream - .send_packet(StatementExecute { - statement: id, - arguments: &arguments, - }) - .await?; - - (metadata.column_names, MySqlValueFormat::Binary, false) + if persistent && self.cache_statement.is_enabled() { + let (id, metadata) = self + .get_or_prepare_statement(sql) + .await?; + + // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html + self.stream + .send_packet(StatementExecute { + statement: id, + arguments: &arguments, + }) + .await?; + + (metadata.column_names, MySqlValueFormat::Binary, false) + } else { + let (id, metadata) = self + .prepare_statement(sql) + .await?; + + // https://dev.mysql.com/doc/internals/en/com-stmt-execute.html + self.stream + .send_packet(StatementExecute { + statement: id, + arguments: &arguments, + }) + .await?; + + self.stream.send_packet(StmtClose { statement: id }).await?; + + (metadata.column_names, MySqlValueFormat::Binary, false) + } } else { // https://dev.mysql.com/doc/internals/en/com-query.html self.stream.send_packet(Query(sql)).await?; @@ -269,7 +291,15 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { Box::pin(async move { self.stream.wait_until_ready().await?; - let (_, metadata) = self.get_or_prepare(sql, true).await?; + let metadata = if self.cache_statement.is_enabled() { + self.get_or_prepare_statement(sql).await?.1 + } else { + let (id, metadata) = self.prepare_statement(sql).await?; + + self.stream.send_packet(StmtClose { statement: id }).await?; + + metadata + }; Ok(MySqlStatement { sql: Cow::Borrowed(sql), @@ -287,7 +317,9 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { Box::pin(async move { self.stream.wait_until_ready().await?; - let (_, metadata) = self.get_or_prepare(sql, false).await?; + let (id, metadata) = self.prepare_statement(sql).await?; + + self.stream.send_packet(StmtClose { statement: id }).await?; let columns = (&*metadata.columns).clone(); diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml index dd260808fe..c1ab012094 100644 --- a/sqlx-sqlite/Cargo.toml +++ b/sqlx-sqlite/Cargo.toml @@ -46,7 +46,7 @@ regex = { version = "1.5.5", optional = true } urlencoding = "2.1.3" [dependencies.libsqlite3-sys] -version = "0.27.0" +version = "0.26.0" default-features = false features = [ "pkg-config", diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index 586cef2ed0..ba10824f89 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -237,6 +237,57 @@ async fn it_caches_statements() -> anyhow::Result<()> { Ok(()) } +#[sqlx_macros::test] +async fn it_closes_statements_with_persistent_disabled() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let old_statement_count = select_statement_count(&mut conn).await.unwrap_or_default(); + + for i in 0..2 { + let row = sqlx::query("SELECT ? AS val") + .bind(i) + .persistent(false) + .fetch_one(&mut conn) + .await?; + + let val: i32 = row.get("val"); + + assert_eq!(i, val); + } + + let new_statement_count = select_statement_count(&mut conn).await.unwrap_or_default(); + + assert_eq!(old_statement_count, new_statement_count); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_closes_statements_with_cache_disabled() -> anyhow::Result<()> { + setup_if_needed(); + + let mut url = url::Url::parse(&env::var("DATABASE_URL")?)?; + url.query_pairs_mut() + .append_pair("statement-cache-capacity", "0"); + + let mut conn = MySqlConnection::connect(url.as_ref()).await?; + + let old_statement_count = select_statement_count(&mut conn).await.unwrap_or_default(); + + for index in 1..=10_i32 { + let _ = sqlx::query("SELECT ?") + .bind(index) + .execute(&mut conn) + .await?; + } + + let new_statement_count = select_statement_count(&mut conn).await.unwrap_or_default(); + + assert_eq!(old_statement_count, new_statement_count); + + Ok(()) +} + #[sqlx_macros::test] async fn it_can_bind_null_and_non_null_issue_540() -> anyhow::Result<()> { let mut conn = new::().await?; @@ -510,3 +561,18 @@ async fn test_shrink_buffers() -> anyhow::Result<()> { Ok(()) } + +async fn select_statement_count(conn: &mut MySqlConnection) -> Result { + // Fails if performance schema does not exist + sqlx::query_scalar( + r#" + SELECT COUNT(*) + FROM performance_schema.threads AS t + INNER JOIN performance_schema.prepared_statements_instances AS psi + ON psi.OWNER_THREAD_ID = t.THREAD_ID + WHERE t.processlist_id = CONNECTION_ID() + "#, + ) + .fetch_one(conn) + .await +}