Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Client add close function. #534

Merged
merged 4 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,14 @@ pub async fn main() -> Result<()> {
}
settings.time = args.time;

let log_dir = format!(
"{}/.bendsql",
std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
);

let _guards = trace::init_logging(&log_dir, &args.log_level).await?;
info!("-> bendsql version: {}", VERSION.as_str());

let mut session = match session::Session::try_new(dsn, settings, is_repl).await {
Ok(session) => session,
Err(err) => {
Expand All @@ -390,14 +398,6 @@ pub async fn main() -> Result<()> {
}
};

let log_dir = format!(
"{}/.bendsql",
std::env::var("HOME").unwrap_or_else(|_| ".".to_string())
);

let _guards = trace::init_logging(&log_dir, &args.log_level).await?;
info!("-> bendsql version: {}", VERSION.as_str());

if args.check {
session.check().await?;
return Ok(());
Expand Down
13 changes: 8 additions & 5 deletions cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;
use std::io::BufRead;
use std::path::Path;
use std::sync::Arc;

use anyhow::anyhow;
use anyhow::Result;
use async_recursion::async_recursion;
Expand All @@ -31,6 +26,10 @@ use rustyline::config::Builder;
use rustyline::error::ReadlineError;
use rustyline::history::DefaultHistory;
use rustyline::{CompletionType, Editor};
use std::collections::BTreeMap;
use std::io::BufRead;
use std::path::Path;
use std::sync::Arc;
use tokio::fs::{remove_file, File};
use tokio::io::AsyncWriteExt;
use tokio::task::JoinHandle;
Expand Down Expand Up @@ -352,6 +351,9 @@ impl Session {
},
}
}
if let Err(e) = self.conn.close().await {
println!("got error when closing session: {}", e);
}
println!("Bye~");
let _ = rl.save_history(&get_history_path());
}
Expand Down Expand Up @@ -394,6 +396,7 @@ impl Session {
println!("{:.3}", server_time_ms / 1000.0);
}
}
self.conn.close().await.ok();
Ok(())
}

Expand Down
45 changes: 30 additions & 15 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

Expand All @@ -32,7 +32,7 @@ use crate::{
response::QueryResponse,
session::SessionState,
};
use log::{error, info, warn};
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use percent_encoding::percent_decode_str;
use reqwest::cookie::CookieStore;
Expand Down Expand Up @@ -78,6 +78,8 @@ pub struct APIClient {
disable_session_token: bool,
session_token_info: Option<Arc<parking_lot::Mutex<(SessionTokenInfo, Instant)>>>,

closed: Arc<AtomicBool>,

server_version: Option<String>,

wait_time_secs: Option<i64>,
Expand Down Expand Up @@ -354,6 +356,7 @@ impl APIClient {
}

pub async fn start_query(&self, sql: &str) -> Result<QueryResponse> {
info!("start query: {}", sql);
self.start_query_inner(sql, None).await
}

Expand Down Expand Up @@ -483,7 +486,6 @@ impl APIClient {
}

pub async fn query(&self, sql: &str) -> Result<QueryResponse> {
info!("query: {}", sql);
let resp = self.start_query(sql).await?;
self.wait_for_query(resp).await
}
Expand Down Expand Up @@ -652,7 +654,7 @@ impl APIClient {
Err(Error::Logic(status, ..)) | Err(Error::Response { status, .. })
if status == 404 =>
{
// old server
info!("login return 404, skip login on the old version server");
return Ok(());
}
Err(e) => return Err(e),
Expand All @@ -664,15 +666,17 @@ impl APIClient {
LoginResponseResult::Ok(info) => {
self.server_version = Some(info.version.clone());
if let Some(tokens) = info.tokens {
info!("login success with session token");
self.session_token_info =
Some(Arc::new(parking_lot::Mutex::new((tokens, Instant::now()))))
}
info!("login success without session token");
}
}
Ok(())
}

fn build_log_out_request(&mut self) -> Result<Request> {
fn build_log_out_request(&self) -> Result<Request> {
let endpoint = self.endpoint.join("/v1/session/logout")?;

let session_state = self.session_state();
Expand All @@ -691,8 +695,12 @@ impl APIClient {
}

fn need_logout(&self) -> bool {
self.session_token_info.is_some()
|| self.session_state.lock().need_keep_alive.unwrap_or(false)
(self.session_token_info.is_some()
|| self.session_state.lock().need_keep_alive.unwrap_or(false))
&& !self
.closed
.compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
.unwrap()
}

async fn refresh_session_token(
Expand Down Expand Up @@ -882,20 +890,26 @@ impl APIClient {
sleep(jitter(Duration::from_secs(10))).await;
}
}
}

impl Drop for APIClient {
fn drop(&mut self) {
pub async fn close(&self) {
if self.need_logout() {
let cli = self.cli.clone();
let req = self
.build_log_out_request()
.expect("failed to build logout request");
tokio::spawn(async move {
if let Err(err) = cli.execute(req).await {
error!("logout request failed: {}", err);
};
});
if let Err(err) = cli.execute(req).await {
error!("logout request failed: {}", err);
} else {
debug!("logout success");
};
}
}
}

impl Drop for APIClient {
fn drop(&mut self) {
if self.need_logout() {
warn!("APIClient::close() was not called");
}
}
}
Expand Down Expand Up @@ -937,6 +951,7 @@ impl Default for APIClient {
disable_session_token: true,
disable_login: false,
session_token_info: None,
closed: Arc::new(Default::default()),
server_version: None,
}
}
Expand Down
3 changes: 3 additions & 0 deletions driver/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ pub type Reader = Box<dyn AsyncRead + Send + Sync + Unpin + 'static>;
#[async_trait]
pub trait Connection: Send + Sync {
async fn info(&self) -> ConnectionInfo;
async fn close(&self) -> Result<()> {
Ok(())
}

async fn version(&self) -> Result<String> {
let row = self.query_row("SELECT version()").await?;
Expand Down
5 changes: 5 additions & 0 deletions driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ impl Connection for RestAPIConnection {
}
}

async fn close(&self) -> Result<()> {
self.client.close().await;
Ok(())
}

async fn exec(&self, sql: &str) -> Result<i64> {
info!("exec: {}", sql);
let mut resp = self.client.start_query(sql).await?;
Expand Down
Loading