Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[server] Implement shared server in tests #957 #978

Merged
merged 3 commits into from
Jan 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 47 additions & 18 deletions agdb_server/src/db_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -709,16 +709,18 @@ impl DbPool {
pub(crate) fn remove_user(&self, username: &str, config: &Config) -> ServerResult {
let user_id = self.find_user_id(username)?;
let dbs = self.find_user_databases(user_id)?;
for db in dbs.iter() {
self.get_pool_mut()?.remove(&db.name);
}
let mut ids = dbs
.into_iter()
.iter()
.map(|db| db.db_id.unwrap())
.collect::<Vec<DbId>>();
ids.push(user_id);
self.db_mut()?
.exec_mut(&QueryBuilder::remove().ids(ids).query())?;

for db in dbs.into_iter() {
self.get_pool_mut()?.remove(&db.name);
}

let user_dir = Path::new(&config.data_dir).join(username);
if user_dir.exists() {
std::fs::remove_dir_all(user_dir)?;
Expand Down Expand Up @@ -760,7 +762,14 @@ impl DbPool {
self.add_db_user(owner, db, new_owner, DbUserRole::Admin, user)?;
}

let server_db = self.get_pool_mut()?.remove(&db_name).unwrap();
let server_db = ServerDb(
self.get_pool()?
.get(&db_name)
.ok_or(db_not_found(&db_name))?
.0
.clone(),
);

server_db
.get_mut()?
.rename(target_name.to_string_lossy().as_ref())
Expand All @@ -770,21 +779,23 @@ impl DbPool {
&format!("db rename error: {}", e.description),
)
})?;
self.get_pool_mut()?.insert(new_name.to_string(), server_db);
database.name = new_name.to_string();

let backup_path = db_backup_file(owner, db, config);

if backup_path.exists() {
let new_backup_path = db_backup_file(new_owner, new_db, config);
let backups_dir = new_backup_path.parent().unwrap();
std::fs::create_dir_all(backups_dir)?;
std::fs::rename(backup_path, new_backup_path)?;
}

self.get_pool_mut()?.insert(new_name.to_string(), server_db);

database.name = new_name.to_string();
self.db_mut()?
.exec_mut(&QueryBuilder::insert().element(&database).query())?;

self.get_pool_mut()?.remove(&db_name).unwrap();

Ok(())
}

Expand Down Expand Up @@ -814,7 +825,8 @@ impl DbPool {
let current_path = db_file(owner, db, config);
let backup_temp = db_backup_dir(owner, config).join(db);

self.get_pool_mut()?.remove(&db_name);
let mut pool = self.get_pool_mut()?;
pool.remove(&db_name);
std::fs::rename(&current_path, &backup_temp)?;
std::fs::rename(&backup_path, &current_path)?;
std::fs::rename(backup_temp, backup_path)?;
Expand All @@ -823,21 +835,38 @@ impl DbPool {
database.db_type,
current_path.to_string_lossy()
))?;
self.get_pool_mut()?.insert(db_name, server_db);
pool.insert(db_name, server_db);
database.backup = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
self.save_db(database)?;

Ok(())
}

pub(crate) fn save_token(&self, user: DbId, token: &str) -> ServerResult {
self.db_mut()?.exec_mut(
&QueryBuilder::insert()
.values_uniform(vec![("token", token).into()])
.ids(user)
.query(),
)?;
Ok(())
pub(crate) fn save_token(&self, user: DbId, token: String) -> ServerResult<String> {
self.db_mut()?.transaction_mut(|t| {
let existing = t
.exec(
&QueryBuilder::select()
.values(vec!["token".into()])
.ids(user)
.query(),
)?
.elements[0]
.values[0]
.value
.to_string();
if existing.is_empty() {
t.exec_mut(
&QueryBuilder::insert()
.values_uniform(vec![("token", &token).into()])
.ids(user)
.query(),
)?;
Ok(token.clone())
} else {
Ok(existing)
}
})
}

pub(crate) fn save_user(&self, user: ServerUser) -> ServerResult {
Expand Down
3 changes: 1 addition & 2 deletions agdb_server/src/routes/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ pub(crate) async fn login(
let token = if user.token.is_empty() {
let token_uuid = Uuid::new_v4();
let token = token_uuid.to_string();
db_pool.save_token(user.db_id.unwrap(), &token)?;
token
db_pool.save_token(user.db_id.unwrap(), token)?
} else {
user.token
};
Expand Down
8 changes: 7 additions & 1 deletion agdb_server/tests/routes/admin_user_remove_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ async fn remove_with_other() -> anyhow::Result<()> {
.admin_db_user_add(owner, db, user, DbUserRole::Write)
.await?;
server.api.admin_user_remove(owner).await?;
assert!(server.api.admin_db_list().await?.1.is_empty());
assert!(!server
.api
.admin_user_list()
.await?
.1
.iter()
.any(|u| u.name == *owner));
assert!(!Path::new(&server.data_dir).join(owner).exists());
server.api.user_login(user, user).await?;
assert!(server.api.db_list().await?.1.is_empty());
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::TestServer;
use crate::TestServerImpl;
use crate::ADMIN;
use agdb_api::AgdbApi;
use agdb_api::ReqwestClient;
use assert_cmd::cargo::CommandCargoExt;
use reqwest::StatusCode;
use std::process::Command;
Expand Down Expand Up @@ -67,10 +70,10 @@ async fn openapi() -> anyhow::Result<()> {

#[tokio::test]
async fn db_config_reuse() -> anyhow::Result<()> {
let mut server = TestServer::new().await?;
server.api.user_login(ADMIN, ADMIN).await?;
server.api.admin_shutdown().await?;

let mut server = TestServerImpl::new().await?;
let mut client = AgdbApi::new(ReqwestClient::new(), &TestServer::url_base(), server.port);
client.user_login(ADMIN, ADMIN).await?;
client.admin_shutdown().await?;
assert!(server.process.wait()?.success());
server.process = Command::cargo_bin("agdb_server")?
.current_dir(&server.dir)
Expand Down
2 changes: 1 addition & 1 deletion agdb_server/tests/routes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ mod db_rename_test;
mod db_user_add_test;
mod db_user_list;
mod db_user_remove_test;
mod server_test;
mod misc_routes;
mod user_change_password_test;
mod user_login_test;
mod user_logout_test;
112 changes: 86 additions & 26 deletions agdb_server/tests/test_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,26 @@ const SHUTDOWN_RETRY_ATTEMPTS: u16 = 100;

static PORT: AtomicU16 = AtomicU16::new(DEFAULT_PORT);
static COUNTER: AtomicU16 = AtomicU16::new(1);
static MUTEX: std::sync::OnceLock<tokio::sync::Mutex<()>> = std::sync::OnceLock::new();
static SERVER: std::sync::OnceLock<tokio::sync::RwLock<Option<TestServerImpl>>> =
std::sync::OnceLock::new();

pub struct TestServer {
pub dir: String,
pub data_dir: String,
pub port: u16,
pub api: AgdbApi<ReqwestClient>,
}

struct TestServerImpl {
pub dir: String,
pub data_dir: String,
pub port: u16,
pub process: Child,
pub instances: u16,
}

impl TestServer {
impl TestServerImpl {
pub async fn new() -> anyhow::Result<Self> {
let port = PORT.fetch_add(1, Ordering::Relaxed) + std::process::id() as u16;
let dir = format!("{BINARY}.{port}.test");
Expand All @@ -57,17 +67,17 @@ impl TestServer {
serde_yaml::to_writer(file, &config)?;

let process = Command::cargo_bin(BINARY)?.current_dir(&dir).spawn()?;
let api = AgdbApi::new(ReqwestClient::new(), &Self::url_base(), port);
let api = AgdbApi::new(ReqwestClient::new(), &TestServer::url_base(), port);

for _ in 0..RETRY_ATTEMPS {
if let Ok(status) = api.status().await {
if status == 200 {
return Ok(Self {
dir,
data_dir,
api,
port,
process,
instances: 1,
});
}
}
Expand All @@ -78,26 +88,6 @@ impl TestServer {
anyhow::bail!("Failed to start server")
}

pub fn next_user_name(&mut self) -> String {
format!("db_user{}", COUNTER.fetch_add(1, Ordering::Relaxed))
}

pub fn next_db_name(&mut self) -> String {
format!("db{}", COUNTER.fetch_add(1, Ordering::Relaxed))
}

pub fn url(&self, uri: &str) -> String {
format!("{}:{}/api/v1{uri}", Self::url_base(), self.port)
}

fn remove_dir_if_exists(dir: &str) -> anyhow::Result<()> {
if Path::new(dir).exists() {
std::fs::remove_dir_all(dir)?;
}

Ok(())
}

fn shutdown_server(&mut self) -> anyhow::Result<()> {
if self.process.try_wait()?.is_some() {
return Ok(());
Expand All @@ -111,15 +101,19 @@ impl TestServer {
std::thread::spawn(move || -> anyhow::Result<()> {
let client = reqwest::blocking::Client::new();
let token: String = client
.post(format!("{}:{}/api/v1/user/login", Self::url_base(), port))
.post(format!(
"{}:{}/api/v1/user/login",
TestServer::url_base(),
port
))
.json(&admin)
.send()?
.json()?;

client
.post(format!(
"{}:{}/api/v1/admin/shutdown",
Self::url_base(),
TestServer::url_base(),
port
))
.bearer_auth(token)
Expand All @@ -142,14 +136,80 @@ impl TestServer {
Ok(())
}

fn remove_dir_if_exists(dir: &str) -> anyhow::Result<()> {
if Path::new(dir).exists() {
std::fs::remove_dir_all(dir)?;
}

Ok(())
}
}

impl TestServer {
pub async fn new() -> anyhow::Result<Self> {
let _guard = MUTEX
.get_or_init(|| tokio::sync::Mutex::new(()))
.lock()
.await;
let global_server = SERVER.get_or_init(|| tokio::sync::RwLock::new(None));
let mut server_guard = global_server.try_write().unwrap();

if server_guard.is_none() {
*server_guard = Some(TestServerImpl::new().await?);
} else {
server_guard.as_mut().unwrap().instances += 1;
}

let server = server_guard.as_ref().unwrap();

Ok(Self {
api: AgdbApi::new(ReqwestClient::new(), &Self::url_base(), server.port),
dir: server.dir.clone(),
port: server.port,
data_dir: server.data_dir.clone(),
})
}

pub fn next_user_name(&mut self) -> String {
format!("db_user{}", COUNTER.fetch_add(1, Ordering::SeqCst))
}

pub fn next_db_name(&mut self) -> String {
format!("db{}", COUNTER.fetch_add(1, Ordering::SeqCst))
}

pub fn url(&self, uri: &str) -> String {
format!("{}:{}/api/v1{uri}", Self::url_base(), self.port)
}

fn url_base() -> String {
format!("{PROTOCOL}://{HOST}")
}
}

impl Drop for TestServer {
impl Drop for TestServerImpl {
fn drop(&mut self) {
Self::shutdown_server(self).unwrap();
Self::remove_dir_if_exists(&self.dir).unwrap();
}
}

impl Drop for TestServer {
fn drop(&mut self) {
let mutex = MUTEX.get().unwrap();
let _guard = loop {
if let Ok(g) = mutex.try_lock() {
break g;
}
};
let global_server = SERVER.get().unwrap();
let mut server_guard = global_server.try_write().unwrap();
let server = server_guard.as_mut().unwrap();

if server.instances == 1 {
*server_guard = None;
} else {
server.instances -= 1;
}
}
}