Skip to content

Commit

Permalink
feat: Client add close function. (#534)
Browse files Browse the repository at this point in the history
* refactor: init log before creating session.

* refactor: move log for "start query" to a common fn.

* refactor: add log for login.

* feat: add close interface.
  • Loading branch information
youngsofun authored Dec 16, 2024
1 parent ba058d5 commit 2f7835b
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 28 deletions.
16 changes: 8 additions & 8 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,14 @@ pub async fn main() -> Result<()> {
}
settings.time = args.time;

let log_dir = format!(
"{}/.bendsql",
std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
);

let _guards = trace::init_logging(&log_dir, &args.log_level).await?;
info!("-> bendsql version: {}", VERSION.as_str());

let mut session = match session::Session::try_new(dsn, settings, is_repl).await {
Ok(session) => session,
Err(err) => {
Expand All @@ -390,14 +398,6 @@ pub async fn main() -> Result<()> {
}
};

let log_dir = format!(
"{}/.bendsql",
std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
);

let _guards = trace::init_logging(&log_dir, &args.log_level).await?;
info!("-> bendsql version: {}", VERSION.as_str());

if args.check {
session.check().await?;
return Ok(());
Expand Down
13 changes: 8 additions & 5 deletions cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;
use std::io::BufRead;
use std::path::Path;
use std::sync::Arc;

use anyhow::anyhow;
use anyhow::Result;
use async_recursion::async_recursion;
Expand All @@ -31,6 +26,10 @@ use rustyline::config::Builder;
use rustyline::error::ReadlineError;
use rustyline::history::DefaultHistory;
use rustyline::{CompletionType, Editor};
use std::collections::BTreeMap;
use std::io::BufRead;
use std::path::Path;
use std::sync::Arc;
use tokio::fs::{remove_file, File};
use tokio::io::AsyncWriteExt;
use tokio::task::JoinHandle;
Expand Down Expand Up @@ -352,6 +351,9 @@ impl Session {
},
}
}
if let Err(e) = self.conn.close().await {
println!("got error when closing session: {}", e);
}
println!("Bye~");
let _ = rl.save_history(&get_history_path());
}
Expand Down Expand Up @@ -394,6 +396,7 @@ impl Session {
println!("{:.3}", server_time_ms / 1000.0);
}
}
self.conn.close().await.ok();
Ok(())
}

Expand Down
45 changes: 30 additions & 15 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

Expand All @@ -32,7 +32,7 @@ use crate::{
response::QueryResponse,
session::SessionState,
};
use log::{error, info, warn};
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use percent_encoding::percent_decode_str;
use reqwest::cookie::CookieStore;
Expand Down Expand Up @@ -78,6 +78,8 @@ pub struct APIClient {
disable_session_token: bool,
session_token_info: Option<Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>>,

closed: Arc<AtomicBool>,

server_version: Option<String>,

wait_time_secs: Option<i64>,
Expand Down Expand Up @@ -354,6 +356,7 @@ impl APIClient {
}

pub async fn start_query(&self, sql: &str) -> Result<QueryResponse> {
info!("start query: {}", sql);
self.start_query_inner(sql, None).await
}

Expand Down Expand Up @@ -483,7 +486,6 @@ impl APIClient {
}

pub async fn query(&self, sql: &str) -> Result<QueryResponse> {
info!("query: {}", sql);
let resp = self.start_query(sql).await?;
self.wait_for_query(resp).await
}
Expand Down Expand Up @@ -652,7 +654,7 @@ impl APIClient {
Err(Error::Logic(status, ..)) | Err(Error::Response { status, .. })
if status == 404 =>
{
// old server
info!("login return 404, skip login on the old version server");
return Ok(());
}
Err(e) => return Err(e),
Expand All @@ -664,15 +666,17 @@ impl APIClient {
LoginResponseResult::Ok(info) => {
self.server_version = Some(info.version.clone());
if let Some(tokens) = info.tokens {
info!("login success with session token");
self.session_token_info =
Some(Arc::new(parking_lot::Mutex::new((tokens, Instant::now()))))
}
info!("login success without session token");
}
}
Ok(())
}

fn build_log_out_request(&mut self) -> Result<Request> {
fn build_log_out_request(&self) -> Result<Request> {
let endpoint = self.endpoint.join("/v1/session/logout")?;

let session_state = self.session_state();
Expand All @@ -691,8 +695,12 @@ impl APIClient {
}

fn need_logout(&self) -> bool {
self.session_token_info.is_some()
|| self.session_state.lock().need_keep_alive.unwrap_or(false)
(self.session_token_info.is_some()
|| self.session_state.lock().need_keep_alive.unwrap_or(false))
&& !self
.closed
.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
.unwrap()
}

async fn refresh_session_token(
Expand Down Expand Up @@ -882,20 +890,26 @@ impl APIClient {
sleep(jitter(Duration::from_secs(10))).await;
}
}
}

impl Drop for APIClient {
fn drop(&mut self) {
pub async fn close(&self) {
if self.need_logout() {
let cli = self.cli.clone();
let req = self
.build_log_out_request()
.expect("failed to build logout request");
tokio::spawn(async move {
if let Err(err) = cli.execute(req).await {
error!("logout request failed: {}", err);
};
});
if let Err(err) = cli.execute(req).await {
error!("logout request failed: {}", err);
} else {
debug!("logout success");
};
}
}
}

impl Drop for APIClient {
fn drop(&mut self) {
if self.need_logout() {
warn!("APIClient::close() was not called");
}
}
}
Expand Down Expand Up @@ -937,6 +951,7 @@ impl Default for APIClient {
disable_session_token: true,
disable_login: false,
session_token_info: None,
closed: Arc::new(Default::default()),
server_version: None,
}
}
Expand Down
3 changes: 3 additions & 0 deletions driver/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
#[async_trait]
pub trait Connection: Send + Sync {
async fn info(&self) -> ConnectionInfo;
async fn close(&self) -> Result<()> {
Ok(())
}

async fn version(&self) -> Result<String> {
let row = self.query_row("SELECT version()").await?;
Expand Down
5 changes: 5 additions & 0 deletions driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ impl Connection for RestAPIConnection {
}
}

async fn close(&self) -> Result<()> {
self.client.close().await;
Ok(())
}

async fn exec(&self, sql: &str) -> Result<i64> {
info!("exec: {}", sql);
let mut resp = self.client.start_query(sql).await?;
Expand Down

0 comments on commit 2f7835b

Please sign in to comment.