From 84752fa2640ca4bfa990beb22ff27111ffc33630 Mon Sep 17 00:00:00 2001 From: Michael Vlach Date: Tue, 9 Jan 2024 00:02:20 +0100 Subject: [PATCH 1/3] fix multiple simultaneous logins --- agdb_server/src/db_pool.rs | 33 +++++-- agdb_server/src/routes/user.rs | 3 +- agdb_server/tests/routes/server_test.rs | 11 +-- agdb_server/tests/test_server.rs | 116 ++++++++++++++++++------ 4 files changed, 117 insertions(+), 46 deletions(-) diff --git a/agdb_server/src/db_pool.rs b/agdb_server/src/db_pool.rs index 023e31271..13b9ad784 100644 --- a/agdb_server/src/db_pool.rs +++ b/agdb_server/src/db_pool.rs @@ -830,14 +830,31 @@ impl DbPool { Ok(()) } - pub(crate) fn save_token(&self, user: DbId, token: &str) -> ServerResult { - self.db_mut()?.exec_mut( - &QueryBuilder::insert() - .values_uniform(vec![("token", token).into()]) - .ids(user) - .query(), - )?; - Ok(()) + pub(crate) fn save_token(&self, user: DbId, token: String) -> ServerResult { + self.db_mut()?.transaction_mut(|t| { + let existing = t + .exec( + &QueryBuilder::select() + .values(vec!["token".into()]) + .ids(user) + .query(), + )? + .elements[0] + .values[0] + .value + .to_string(); + if existing.is_empty() { + t.exec_mut( + &QueryBuilder::insert() + .values_uniform(vec![("token", &token).into()]) + .ids(user) + .query(), + )?; + Ok(token.clone()) + } else { + Ok(existing) + } + }) } pub(crate) fn save_user(&self, user: ServerUser) -> ServerResult { diff --git a/agdb_server/src/routes/user.rs b/agdb_server/src/routes/user.rs index 76d817a60..504da3774 100644 --- a/agdb_server/src/routes/user.rs +++ b/agdb_server/src/routes/user.rs @@ -35,8 +35,7 @@ pub(crate) async fn login( let token = if user.token.is_empty() { let token_uuid = Uuid::new_v4(); let token = token_uuid.to_string(); - db_pool.save_token(user.db_id.unwrap(), &token)?; - token + db_pool.save_token(user.db_id.unwrap(), token)? } else { user.token }; diff --git a/agdb_server/tests/routes/server_test.rs b/agdb_server/tests/routes/server_test.rs index c5bc4cd64..78044331d 100644 --- a/agdb_server/tests/routes/server_test.rs +++ b/agdb_server/tests/routes/server_test.rs @@ -1,8 +1,5 @@ use crate::TestServer; -use crate::ADMIN; -use assert_cmd::cargo::CommandCargoExt; use reqwest::StatusCode; -use std::process::Command; #[tokio::test] async fn error() -> anyhow::Result<()> { @@ -68,12 +65,6 @@ async fn openapi() -> anyhow::Result<()> { #[tokio::test] async fn db_config_reuse() -> anyhow::Result<()> { let mut server = TestServer::new().await?; - server.api.user_login(ADMIN, ADMIN).await?; - server.api.admin_shutdown().await?; - - assert!(server.process.wait()?.success()); - server.process = Command::cargo_bin("agdb_server")? - .current_dir(&server.dir) - .spawn()?; + server.restart().await?; Ok(()) } diff --git a/agdb_server/tests/test_server.rs b/agdb_server/tests/test_server.rs index 99b17e3b3..0825ad76b 100644 --- a/agdb_server/tests/test_server.rs +++ b/agdb_server/tests/test_server.rs @@ -10,6 +10,7 @@ use std::process::Child; use std::process::Command; use std::sync::atomic::AtomicU16; use std::sync::atomic::Ordering; +use std::sync::RwLock; use std::time::Duration; const BINARY: &str = "agdb_server"; @@ -26,16 +27,25 @@ const SHUTDOWN_RETRY_ATTEMPTS: u16 = 100; static PORT: AtomicU16 = AtomicU16::new(DEFAULT_PORT); static COUNTER: AtomicU16 = AtomicU16::new(1); +static MUTEX: std::sync::OnceLock> = std::sync::OnceLock::new(); +static INSTANCES: AtomicU16 = AtomicU16::new(0); +static SERVER: RwLock> = RwLock::new(None); pub struct TestServer { pub dir: String, pub data_dir: String, + pub port: u16, pub api: AgdbApi, +} + +struct TestServerImpl { + pub dir: String, + pub data_dir: String, pub port: u16, pub process: Child, } -impl TestServer { +impl TestServerImpl { pub async fn new() -> anyhow::Result { let port = PORT.fetch_add(1, Ordering::Relaxed) + std::process::id() as u16; let dir = format!("{BINARY}.{port}.test"); @@ -57,7 +67,7 @@ impl TestServer { serde_yaml::to_writer(file, &config)?; let process = Command::cargo_bin(BINARY)?.current_dir(&dir).spawn()?; - let api = AgdbApi::new(ReqwestClient::new(), &Self::url_base(), port); + let api = AgdbApi::new(ReqwestClient::new(), &TestServer::url_base(), port); for _ in 0..RETRY_ATTEMPS { if let Ok(status) = api.status().await { @@ -65,7 +75,6 @@ impl TestServer { return Ok(Self { dir, data_dir, - api, port, process, }); @@ -78,26 +87,6 @@ impl TestServer { anyhow::bail!("Failed to start server") } - pub fn next_user_name(&mut self) -> String { - format!("db_user{}", COUNTER.fetch_add(1, Ordering::Relaxed)) - } - - pub fn next_db_name(&mut self) -> String { - format!("db{}", COUNTER.fetch_add(1, Ordering::Relaxed)) - } - - pub fn url(&self, uri: &str) -> String { - format!("{}:{}/api/v1{uri}", Self::url_base(), self.port) - } - - fn remove_dir_if_exists(dir: &str) -> anyhow::Result<()> { - if Path::new(dir).exists() { - std::fs::remove_dir_all(dir)?; - } - - Ok(()) - } - fn shutdown_server(&mut self) -> anyhow::Result<()> { if self.process.try_wait()?.is_some() { return Ok(()); @@ -111,7 +100,11 @@ impl TestServer { std::thread::spawn(move || -> anyhow::Result<()> { let client = reqwest::blocking::Client::new(); let token: String = client - .post(format!("{}:{}/api/v1/user/login", Self::url_base(), port)) + .post(format!( + "{}:{}/api/v1/user/login", + TestServer::url_base(), + port + )) .json(&admin) .send()? .json()?; @@ -119,7 +112,7 @@ impl TestServer { client .post(format!( "{}:{}/api/v1/admin/shutdown", - Self::url_base(), + TestServer::url_base(), port )) .bearer_auth(token) @@ -142,14 +135,85 @@ impl TestServer { Ok(()) } + fn remove_dir_if_exists(dir: &str) -> anyhow::Result<()> { + if Path::new(dir).exists() { + std::fs::remove_dir_all(dir)?; + } + + Ok(()) + } +} + +impl TestServer { + pub async fn new() -> anyhow::Result { + let _guard = MUTEX + .get_or_init(|| tokio::sync::Mutex::new(())) + .lock() + .await; + INSTANCES.fetch_add(1, Ordering::Relaxed); + + if SERVER.read().unwrap().is_none() { + println!("CREATING"); + *SERVER.write().unwrap() = Some(TestServerImpl::new().await?); + } + + let read_lock = SERVER.read().unwrap(); + let server = read_lock.as_ref().unwrap(); + + Ok(Self { + api: AgdbApi::new(ReqwestClient::new(), &Self::url_base(), server.port), + dir: server.dir.clone(), + port: server.port, + data_dir: server.data_dir.clone(), + }) + } + + pub fn next_user_name(&mut self) -> String { + format!("db_user{}", COUNTER.fetch_add(1, Ordering::Relaxed)) + } + + pub fn next_db_name(&mut self) -> String { + format!("db{}", COUNTER.fetch_add(1, Ordering::Relaxed)) + } + + pub async fn restart(&mut self) -> anyhow::Result<()> { + let _guard = MUTEX + .get_or_init(|| tokio::sync::Mutex::new(())) + .lock() + .await; + *SERVER.write().unwrap() = Some(TestServerImpl::new().await?); + Ok(()) + } + + pub fn url(&self, uri: &str) -> String { + format!("{}:{}/api/v1{uri}", Self::url_base(), self.port) + } + fn url_base() -> String { format!("{PROTOCOL}://{HOST}") } } -impl Drop for TestServer { +impl Drop for TestServerImpl { fn drop(&mut self) { Self::shutdown_server(self).unwrap(); Self::remove_dir_if_exists(&self.dir).unwrap(); } } + +impl Drop for TestServer { + fn drop(&mut self) { + let mutex = MUTEX.get_or_init(|| tokio::sync::Mutex::new(())); + let _guard = loop { + if let Ok(g) = mutex.try_lock() { + break g; + } + }; + let instances = INSTANCES.fetch_sub(1, Ordering::Relaxed) - 1; + + if instances == 0 { + println!("DROPPING"); + *SERVER.write().unwrap() = None; + } + } +} From 66c4ffa4724fdfa88e9f546ca90a93d61e1b8d80 Mon Sep 17 00:00:00 2001 From: Michael Vlach Date: Tue, 9 Jan 2024 20:59:28 +0100 Subject: [PATCH 2/3] fix server test --- agdb_server/src/db_pool.rs | 27 ++++++--- .../tests/routes/admin_user_remove_test.rs | 1 - agdb_server/tests/routes/server_test.rs | 16 +++++- agdb_server/tests/test_server.rs | 55 ++++++++++--------- 4 files changed, 63 insertions(+), 36 deletions(-) diff --git a/agdb_server/src/db_pool.rs b/agdb_server/src/db_pool.rs index 13b9ad784..d14b7c03b 100644 --- a/agdb_server/src/db_pool.rs +++ b/agdb_server/src/db_pool.rs @@ -709,16 +709,18 @@ impl DbPool { pub(crate) fn remove_user(&self, username: &str, config: &Config) -> ServerResult { let user_id = self.find_user_id(username)?; let dbs = self.find_user_databases(user_id)?; - for db in dbs.iter() { - self.get_pool_mut()?.remove(&db.name); - } let mut ids = dbs - .into_iter() + .iter() .map(|db| db.db_id.unwrap()) .collect::>(); ids.push(user_id); self.db_mut()? .exec_mut(&QueryBuilder::remove().ids(ids).query())?; + + for db in dbs.into_iter() { + self.get_pool_mut()?.remove(&db.name); + } + let user_dir = Path::new(&config.data_dir).join(username); if user_dir.exists() { std::fs::remove_dir_all(user_dir)?; @@ -760,7 +762,14 @@ impl DbPool { self.add_db_user(owner, db, new_owner, DbUserRole::Admin, user)?; } - let server_db = self.get_pool_mut()?.remove(&db_name).unwrap(); + let server_db = ServerDb( + self.get_pool()? + .get(&db_name) + .ok_or(db_not_found(&db_name))? + .0 + .clone(), + ); + server_db .get_mut()? .rename(target_name.to_string_lossy().as_ref()) @@ -770,11 +779,8 @@ impl DbPool { &format!("db rename error: {}", e.description), ) })?; - self.get_pool_mut()?.insert(new_name.to_string(), server_db); - database.name = new_name.to_string(); let backup_path = db_backup_file(owner, db, config); - if backup_path.exists() { let new_backup_path = db_backup_file(new_owner, new_db, config); let backups_dir = new_backup_path.parent().unwrap(); @@ -782,9 +788,14 @@ impl DbPool { std::fs::rename(backup_path, new_backup_path)?; } + self.get_pool_mut()?.insert(new_name.to_string(), server_db); + + database.name = new_name.to_string(); self.db_mut()? .exec_mut(&QueryBuilder::insert().element(&database).query())?; + self.get_pool_mut()?.remove(&db_name).unwrap(); + Ok(()) } diff --git a/agdb_server/tests/routes/admin_user_remove_test.rs b/agdb_server/tests/routes/admin_user_remove_test.rs index e24d89319..b9d4c98bb 100644 --- a/agdb_server/tests/routes/admin_user_remove_test.rs +++ b/agdb_server/tests/routes/admin_user_remove_test.rs @@ -37,7 +37,6 @@ async fn remove_with_other() -> anyhow::Result<()> { .admin_db_user_add(owner, db, user, DbUserRole::Write) .await?; server.api.admin_user_remove(owner).await?; - assert!(server.api.admin_db_list().await?.1.is_empty()); assert!(!Path::new(&server.data_dir).join(owner).exists()); server.api.user_login(user, user).await?; assert!(server.api.db_list().await?.1.is_empty()); diff --git a/agdb_server/tests/routes/server_test.rs b/agdb_server/tests/routes/server_test.rs index 78044331d..e8d47beb7 100644 --- a/agdb_server/tests/routes/server_test.rs +++ b/agdb_server/tests/routes/server_test.rs @@ -1,5 +1,11 @@ use crate::TestServer; +use crate::TestServerImpl; +use crate::ADMIN; +use agdb_api::AgdbApi; +use agdb_api::ReqwestClient; +use assert_cmd::cargo::CommandCargoExt; use reqwest::StatusCode; +use std::process::Command; #[tokio::test] async fn error() -> anyhow::Result<()> { @@ -64,7 +70,13 @@ async fn openapi() -> anyhow::Result<()> { #[tokio::test] async fn db_config_reuse() -> anyhow::Result<()> { - let mut server = TestServer::new().await?; - server.restart().await?; + let mut server = TestServerImpl::new().await?; + let mut client = AgdbApi::new(ReqwestClient::new(), &TestServer::url_base(), server.port); + client.user_login(ADMIN, ADMIN).await?; + client.admin_shutdown().await?; + assert!(server.process.wait()?.success()); + server.process = Command::cargo_bin("agdb_server")? + .current_dir(&server.dir) + .spawn()?; Ok(()) } diff --git a/agdb_server/tests/test_server.rs b/agdb_server/tests/test_server.rs index 0825ad76b..0b2ed92de 100644 --- a/agdb_server/tests/test_server.rs +++ b/agdb_server/tests/test_server.rs @@ -10,7 +10,6 @@ use std::process::Child; use std::process::Command; use std::sync::atomic::AtomicU16; use std::sync::atomic::Ordering; -use std::sync::RwLock; use std::time::Duration; const BINARY: &str = "agdb_server"; @@ -28,8 +27,8 @@ const SHUTDOWN_RETRY_ATTEMPTS: u16 = 100; static PORT: AtomicU16 = AtomicU16::new(DEFAULT_PORT); static COUNTER: AtomicU16 = AtomicU16::new(1); static MUTEX: std::sync::OnceLock> = std::sync::OnceLock::new(); -static INSTANCES: AtomicU16 = AtomicU16::new(0); -static SERVER: RwLock> = RwLock::new(None); +static SERVER: std::sync::OnceLock>> = + std::sync::OnceLock::new(); pub struct TestServer { pub dir: String, @@ -43,6 +42,7 @@ struct TestServerImpl { pub data_dir: String, pub port: u16, pub process: Child, + pub instances: u16, } impl TestServerImpl { @@ -77,6 +77,7 @@ impl TestServerImpl { data_dir, port, process, + instances: 1, }); } } @@ -150,15 +151,16 @@ impl TestServer { .get_or_init(|| tokio::sync::Mutex::new(())) .lock() .await; - INSTANCES.fetch_add(1, Ordering::Relaxed); + let global_server = SERVER.get_or_init(|| tokio::sync::RwLock::new(None)); + let mut server_guard = global_server.try_write().unwrap(); - if SERVER.read().unwrap().is_none() { - println!("CREATING"); - *SERVER.write().unwrap() = Some(TestServerImpl::new().await?); + if server_guard.is_none() { + *server_guard = Some(TestServerImpl::new().await?); + } else { + server_guard.as_mut().unwrap().instances += 1; } - let read_lock = SERVER.read().unwrap(); - let server = read_lock.as_ref().unwrap(); + let server = server_guard.as_ref().unwrap(); Ok(Self { api: AgdbApi::new(ReqwestClient::new(), &Self::url_base(), server.port), @@ -169,21 +171,21 @@ impl TestServer { } pub fn next_user_name(&mut self) -> String { - format!("db_user{}", COUNTER.fetch_add(1, Ordering::Relaxed)) + format!("db_user{}", COUNTER.fetch_add(1, Ordering::SeqCst)) } pub fn next_db_name(&mut self) -> String { - format!("db{}", COUNTER.fetch_add(1, Ordering::Relaxed)) + format!("db{}", COUNTER.fetch_add(1, Ordering::SeqCst)) } - pub async fn restart(&mut self) -> anyhow::Result<()> { - let _guard = MUTEX - .get_or_init(|| tokio::sync::Mutex::new(())) - .lock() - .await; - *SERVER.write().unwrap() = Some(TestServerImpl::new().await?); - Ok(()) - } + // pub async fn restart(&mut self) -> anyhow::Result<()> { + // let _guard = MUTEX + // .get_or_init(|| tokio::sync::Mutex::new(())) + // .lock() + // .await; + // *SERVER.write().await = Some(TestServerImpl::new().await?); + // Ok(()) + // } pub fn url(&self, uri: &str) -> String { format!("{}:{}/api/v1{uri}", Self::url_base(), self.port) @@ -203,17 +205,20 @@ impl Drop for TestServerImpl { impl Drop for TestServer { fn drop(&mut self) { - let mutex = MUTEX.get_or_init(|| tokio::sync::Mutex::new(())); + let mutex = MUTEX.get().unwrap(); let _guard = loop { if let Ok(g) = mutex.try_lock() { break g; } }; - let instances = INSTANCES.fetch_sub(1, Ordering::Relaxed) - 1; - - if instances == 0 { - println!("DROPPING"); - *SERVER.write().unwrap() = None; + let global_server = SERVER.get().unwrap(); + let mut server_guard = global_server.try_write().unwrap(); + let server = server_guard.as_mut().unwrap(); + + if server.instances == 1 { + *server_guard = None; + } else { + server.instances -= 1; } } } From e8b95910c1362e1d9679089bd6ed75edccea0d86 Mon Sep 17 00:00:00 2001 From: Michael Vlach Date: Tue, 9 Jan 2024 21:57:38 +0100 Subject: [PATCH 3/3] fix restore --- agdb_server/src/db_pool.rs | 5 +++-- agdb_server/tests/routes/admin_user_remove_test.rs | 7 +++++++ .../tests/routes/{server_test.rs => misc_routes.rs} | 0 agdb_server/tests/routes/mod.rs | 2 +- agdb_server/tests/test_server.rs | 9 --------- 5 files changed, 11 insertions(+), 12 deletions(-) rename agdb_server/tests/routes/{server_test.rs => misc_routes.rs} (100%) diff --git a/agdb_server/src/db_pool.rs b/agdb_server/src/db_pool.rs index d14b7c03b..f3160b8b4 100644 --- a/agdb_server/src/db_pool.rs +++ b/agdb_server/src/db_pool.rs @@ -825,7 +825,8 @@ impl DbPool { let current_path = db_file(owner, db, config); let backup_temp = db_backup_dir(owner, config).join(db); - self.get_pool_mut()?.remove(&db_name); + let mut pool = self.get_pool_mut()?; + pool.remove(&db_name); std::fs::rename(¤t_path, &backup_temp)?; std::fs::rename(&backup_path, ¤t_path)?; std::fs::rename(backup_temp, backup_path)?; @@ -834,7 +835,7 @@ impl DbPool { database.db_type, current_path.to_string_lossy() ))?; - self.get_pool_mut()?.insert(db_name, server_db); + pool.insert(db_name, server_db); database.backup = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); self.save_db(database)?; diff --git a/agdb_server/tests/routes/admin_user_remove_test.rs b/agdb_server/tests/routes/admin_user_remove_test.rs index b9d4c98bb..1582527b4 100644 --- a/agdb_server/tests/routes/admin_user_remove_test.rs +++ b/agdb_server/tests/routes/admin_user_remove_test.rs @@ -37,6 +37,13 @@ async fn remove_with_other() -> anyhow::Result<()> { .admin_db_user_add(owner, db, user, DbUserRole::Write) .await?; server.api.admin_user_remove(owner).await?; + assert!(!server + .api + .admin_user_list() + .await? + .1 + .iter() + .any(|u| u.name == *owner)); assert!(!Path::new(&server.data_dir).join(owner).exists()); server.api.user_login(user, user).await?; assert!(server.api.db_list().await?.1.is_empty()); diff --git a/agdb_server/tests/routes/server_test.rs b/agdb_server/tests/routes/misc_routes.rs similarity index 100% rename from agdb_server/tests/routes/server_test.rs rename to agdb_server/tests/routes/misc_routes.rs diff --git a/agdb_server/tests/routes/mod.rs b/agdb_server/tests/routes/mod.rs index 76ef4d831..a78be18c6 100644 --- a/agdb_server/tests/routes/mod.rs +++ b/agdb_server/tests/routes/mod.rs @@ -26,7 +26,7 @@ mod db_rename_test; mod db_user_add_test; mod db_user_list; mod db_user_remove_test; -mod server_test; +mod misc_routes; mod user_change_password_test; mod user_login_test; mod user_logout_test; diff --git a/agdb_server/tests/test_server.rs b/agdb_server/tests/test_server.rs index 0b2ed92de..e907dabea 100644 --- a/agdb_server/tests/test_server.rs +++ b/agdb_server/tests/test_server.rs @@ -178,15 +178,6 @@ impl TestServer { format!("db{}", COUNTER.fetch_add(1, Ordering::SeqCst)) } - // pub async fn restart(&mut self) -> anyhow::Result<()> { - // let _guard = MUTEX - // .get_or_init(|| tokio::sync::Mutex::new(())) - // .lock() - // .await; - // *SERVER.write().await = Some(TestServerImpl::new().await?); - // Ok(()) - // } - pub fn url(&self, uri: &str) -> String { format!("{}:{}/api/v1{uri}", Self::url_base(), self.port) }