Skip to content

Commit

Permalink
fix multiple simultaneous logins
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelvlach committed Jan 8, 2024
1 parent 3f71ebf commit 84752fa
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 46 deletions.
33 changes: 25 additions & 8 deletions agdb_server/src/db_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -830,14 +830,31 @@ impl DbPool {
Ok(())
}

pub(crate) fn save_token(&self, user: DbId, token: &str) -> ServerResult {
self.db_mut()?.exec_mut(
&QueryBuilder::insert()
.values_uniform(vec![("token", token).into()])
.ids(user)
.query(),
)?;
Ok(())
pub(crate) fn save_token(&self, user: DbId, token: String) -> ServerResult<String> {
self.db_mut()?.transaction_mut(|t| {
let existing = t
.exec(
&QueryBuilder::select()
.values(vec!["token".into()])
.ids(user)
.query(),
)?
.elements[0]
.values[0]
.value
.to_string();
if existing.is_empty() {
t.exec_mut(
&QueryBuilder::insert()
.values_uniform(vec![("token", &token).into()])
.ids(user)
.query(),
)?;
Ok(token.clone())
} else {
Ok(existing)
}
})
}

pub(crate) fn save_user(&self, user: ServerUser) -> ServerResult {
Expand Down
3 changes: 1 addition & 2 deletions agdb_server/src/routes/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ pub(crate) async fn login(
let token = if user.token.is_empty() {
let token_uuid = Uuid::new_v4();
let token = token_uuid.to_string();
db_pool.save_token(user.db_id.unwrap(), &token)?;
token
db_pool.save_token(user.db_id.unwrap(), token)?
} else {
user.token
};
Expand Down
11 changes: 1 addition & 10 deletions agdb_server/tests/routes/server_test.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use crate::TestServer;
use crate::ADMIN;
use assert_cmd::cargo::CommandCargoExt;
use reqwest::StatusCode;
use std::process::Command;

#[tokio::test]
async fn error() -> anyhow::Result<()> {
Expand Down Expand Up @@ -68,12 +65,6 @@ async fn openapi() -> anyhow::Result<()> {
#[tokio::test]
async fn db_config_reuse() -> anyhow::Result<()> {
let mut server = TestServer::new().await?;
server.api.user_login(ADMIN, ADMIN).await?;
server.api.admin_shutdown().await?;

assert!(server.process.wait()?.success());
server.process = Command::cargo_bin("agdb_server")?
.current_dir(&server.dir)
.spawn()?;
server.restart().await?;
Ok(())
}
116 changes: 90 additions & 26 deletions agdb_server/tests/test_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use std::process::Child;
use std::process::Command;
use std::sync::atomic::AtomicU16;
use std::sync::atomic::Ordering;
use std::sync::RwLock;
use std::time::Duration;

const BINARY: &str = "agdb_server";
Expand All @@ -26,16 +27,25 @@ const SHUTDOWN_RETRY_ATTEMPTS: u16 = 100;

static PORT: AtomicU16 = AtomicU16::new(DEFAULT_PORT);
static COUNTER: AtomicU16 = AtomicU16::new(1);
static MUTEX: std::sync::OnceLock<tokio::sync::Mutex<()>> = std::sync::OnceLock::new();
static INSTANCES: AtomicU16 = AtomicU16::new(0);
static SERVER: RwLock<Option<TestServerImpl>> = RwLock::new(None);

pub struct TestServer {
pub dir: String,
pub data_dir: String,
pub port: u16,
pub api: AgdbApi<ReqwestClient>,
}

struct TestServerImpl {
pub dir: String,
pub data_dir: String,
pub port: u16,
pub process: Child,
}

impl TestServer {
impl TestServerImpl {
pub async fn new() -> anyhow::Result<Self> {
let port = PORT.fetch_add(1, Ordering::Relaxed) + std::process::id() as u16;
let dir = format!("{BINARY}.{port}.test");
Expand All @@ -57,15 +67,14 @@ impl TestServer {
serde_yaml::to_writer(file, &config)?;

let process = Command::cargo_bin(BINARY)?.current_dir(&dir).spawn()?;
let api = AgdbApi::new(ReqwestClient::new(), &Self::url_base(), port);
let api = AgdbApi::new(ReqwestClient::new(), &TestServer::url_base(), port);

for _ in 0..RETRY_ATTEMPS {
if let Ok(status) = api.status().await {
if status == 200 {
return Ok(Self {
dir,
data_dir,
api,
port,
process,
});
Expand All @@ -78,26 +87,6 @@ impl TestServer {
anyhow::bail!("Failed to start server")
}

pub fn next_user_name(&mut self) -> String {
format!("db_user{}", COUNTER.fetch_add(1, Ordering::Relaxed))
}

pub fn next_db_name(&mut self) -> String {
format!("db{}", COUNTER.fetch_add(1, Ordering::Relaxed))
}

pub fn url(&self, uri: &str) -> String {
format!("{}:{}/api/v1{uri}", Self::url_base(), self.port)
}

fn remove_dir_if_exists(dir: &str) -> anyhow::Result<()> {
if Path::new(dir).exists() {
std::fs::remove_dir_all(dir)?;
}

Ok(())
}

fn shutdown_server(&mut self) -> anyhow::Result<()> {
if self.process.try_wait()?.is_some() {
return Ok(());
Expand All @@ -111,15 +100,19 @@ impl TestServer {
std::thread::spawn(move || -> anyhow::Result<()> {
let client = reqwest::blocking::Client::new();
let token: String = client
.post(format!("{}:{}/api/v1/user/login", Self::url_base(), port))
.post(format!(
"{}:{}/api/v1/user/login",
TestServer::url_base(),
port
))
.json(&admin)
.send()?
.json()?;

client
.post(format!(
"{}:{}/api/v1/admin/shutdown",
Self::url_base(),
TestServer::url_base(),
port
))
.bearer_auth(token)
Expand All @@ -142,14 +135,85 @@ impl TestServer {
Ok(())
}

fn remove_dir_if_exists(dir: &str) -> anyhow::Result<()> {
if Path::new(dir).exists() {
std::fs::remove_dir_all(dir)?;
}

Ok(())
}
}

impl TestServer {
pub async fn new() -> anyhow::Result<Self> {
let _guard = MUTEX
.get_or_init(|| tokio::sync::Mutex::new(()))
.lock()
.await;
INSTANCES.fetch_add(1, Ordering::Relaxed);

if SERVER.read().unwrap().is_none() {
println!("CREATING");
*SERVER.write().unwrap() = Some(TestServerImpl::new().await?);
}

let read_lock = SERVER.read().unwrap();
let server = read_lock.as_ref().unwrap();

Ok(Self {
api: AgdbApi::new(ReqwestClient::new(), &Self::url_base(), server.port),
dir: server.dir.clone(),
port: server.port,
data_dir: server.data_dir.clone(),
})
}

pub fn next_user_name(&mut self) -> String {
format!("db_user{}", COUNTER.fetch_add(1, Ordering::Relaxed))
}

pub fn next_db_name(&mut self) -> String {
format!("db{}", COUNTER.fetch_add(1, Ordering::Relaxed))
}

pub async fn restart(&mut self) -> anyhow::Result<()> {
let _guard = MUTEX
.get_or_init(|| tokio::sync::Mutex::new(()))
.lock()
.await;
*SERVER.write().unwrap() = Some(TestServerImpl::new().await?);
Ok(())
}

pub fn url(&self, uri: &str) -> String {
format!("{}:{}/api/v1{uri}", Self::url_base(), self.port)
}

fn url_base() -> String {
format!("{PROTOCOL}://{HOST}")
}
}

impl Drop for TestServer {
impl Drop for TestServerImpl {
fn drop(&mut self) {
Self::shutdown_server(self).unwrap();
Self::remove_dir_if_exists(&self.dir).unwrap();
}
}

impl Drop for TestServer {
fn drop(&mut self) {
let mutex = MUTEX.get_or_init(|| tokio::sync::Mutex::new(()));
let _guard = loop {
if let Ok(g) = mutex.try_lock() {
break g;
}
};
let instances = INSTANCES.fetch_sub(1, Ordering::Relaxed) - 1;

if instances == 0 {
println!("DROPPING");
*SERVER.write().unwrap() = None;
}
}
}

0 comments on commit 84752fa

Please sign in to comment.