From 9551055efd0171b6e31f2c0d40561d166b5569e6 Mon Sep 17 00:00:00 2001 From: everpcpc Date: Thu, 1 Feb 2024 15:22:10 +0800 Subject: [PATCH] feat: support setting name for each client --- bindings/nodejs/Cargo.toml | 1 + bindings/nodejs/src/lib.rs | 9 +++++- bindings/python/Cargo.toml | 1 + bindings/python/src/asyncio.rs | 5 +-- bindings/python/src/blocking.rs | 5 +-- bindings/python/src/types.rs | 6 ++++ cli/src/session.rs | 30 +++++++++++++++--- core/src/client.rs | 54 +++++++++++++++++++++------------ core/tests/core/simple.rs | 2 +- core/tests/core/stage.rs | 4 +-- driver/Cargo.toml | 1 + driver/src/conn.rs | 19 ++++++++++-- driver/src/flight_sql.rs | 7 +++-- driver/src/rest_api.rs | 4 +-- 14 files changed, 108 insertions(+), 40 deletions(-) diff --git a/bindings/nodejs/Cargo.toml b/bindings/nodejs/Cargo.toml index 5300c0b59..0ca51d488 100644 --- a/bindings/nodejs/Cargo.toml +++ b/bindings/nodejs/Cargo.toml @@ -23,6 +23,7 @@ napi = { version = "2.14", default-features = false, features = [ "chrono_date", ] } napi-derive = "2.14" +once_cell = "1.18" tokio-stream = "0.1" [build-dependencies] diff --git a/bindings/nodejs/src/lib.rs b/bindings/nodejs/src/lib.rs index abec23fc1..c266e3b85 100644 --- a/bindings/nodejs/src/lib.rs +++ b/bindings/nodejs/src/lib.rs @@ -17,8 +17,14 @@ extern crate napi_derive; use chrono::{NaiveDate, NaiveDateTime, NaiveTime}; use napi::bindgen_prelude::*; +use once_cell::sync::Lazy; use tokio_stream::StreamExt; +static VERSION: Lazy = Lazy::new(|| { + let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown"); + version.to_string() +}); + #[napi] pub struct Client(databend_driver::Client); @@ -279,7 +285,8 @@ impl Client { /// Create a new databend client with a given DSN. #[napi(constructor)] pub fn new(dsn: String) -> Self { - let client = databend_driver::Client::new(dsn); + let name = format!("databend-driver-nodejs/{}", VERSION.as_str()); + let client = databend_driver::Client::new(dsn).with_name(name); Self(client) } diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index dccc43010..2c2bfe8ae 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -14,6 +14,7 @@ doc = false [dependencies] databend-driver = { workspace = true, features = ["rustls", "flight-sql"] } +once_cell = "1.18" pyo3 = { version = "0.20", features = ["abi3-py37"] } pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] } tokio = "1.34" diff --git a/bindings/python/src/asyncio.rs b/bindings/python/src/asyncio.rs index 1ae4e6c4c..3b8eb3d2a 100644 --- a/bindings/python/src/asyncio.rs +++ b/bindings/python/src/asyncio.rs @@ -15,7 +15,7 @@ use pyo3::prelude::*; use pyo3_asyncio::tokio::future_into_py; -use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats}; +use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats, VERSION}; #[pyclass(module = "databend_driver")] pub struct AsyncDatabendClient(databend_driver::Client); @@ -25,7 +25,8 @@ impl AsyncDatabendClient { #[new] #[pyo3(signature = (dsn))] pub fn new(dsn: String) -> PyResult { - let client = databend_driver::Client::new(dsn); + let name = format!("databend-driver-python/{}", VERSION.as_str()); + let client = databend_driver::Client::new(dsn).with_name(name); Ok(Self(client)) } diff --git a/bindings/python/src/blocking.rs b/bindings/python/src/blocking.rs index b348cf3ae..8724fc5e9 100644 --- a/bindings/python/src/blocking.rs +++ b/bindings/python/src/blocking.rs @@ -14,7 +14,7 @@ use pyo3::prelude::*; -use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats}; +use crate::types::{ConnectionInfo, DriverError, Row, RowIterator, ServerStats, VERSION}; #[pyclass(module = "databend_driver")] pub struct BlockingDatabendClient(databend_driver::Client); @@ -24,7 +24,8 @@ impl BlockingDatabendClient { #[new] #[pyo3(signature = (dsn))] pub fn new(dsn: String) -> PyResult { - let client = databend_driver::Client::new(dsn); + let name = format!("databend-driver-python/{}", VERSION.as_str()); + let client = databend_driver::Client::new(dsn).with_name(name); Ok(Self(client)) } diff --git a/bindings/python/src/types.rs b/bindings/python/src/types.rs index 1358d8701..0b8158123 100644 --- a/bindings/python/src/types.rs +++ b/bindings/python/src/types.rs @@ -14,6 +14,7 @@ use std::sync::Arc; +use once_cell::sync::Lazy; use pyo3::exceptions::{PyException, PyStopAsyncIteration, PyStopIteration}; use pyo3::prelude::*; use pyo3::types::{PyDict, PyList, PyTuple}; @@ -21,6 +22,11 @@ use pyo3_asyncio::tokio::future_into_py; use tokio::sync::Mutex; use tokio_stream::StreamExt; +pub static VERSION: Lazy = Lazy::new(|| { + let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown"); + version.to_string() +}); + pub struct Value(databend_driver::Value); impl IntoPy for Value { diff --git a/cli/src/session.rs b/cli/src/session.rs index d9102e0ff..f240ffd0d 100644 --- a/cli/src/session.rs +++ b/cli/src/session.rs @@ -22,6 +22,7 @@ use anyhow::Result; use chrono::NaiveDateTime; use databend_driver::ServerStats; use databend_driver::{Client, Connection}; +use once_cell::sync::Lazy; use rustyline::config::Builder; use rustyline::error::ReadlineError; use rustyline::history::DefaultHistory; @@ -40,6 +41,15 @@ use crate::VERSION; static PROMPT_SQL: &str = "select name from system.tables union all select name from system.columns union all select name from system.databases union all select name from system.functions"; +static VERSION_SHORT: Lazy = Lazy::new(|| { + let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown"); + let sha = option_env!("VERGEN_GIT_SHA").unwrap_or("dev"); + match option_env!("BENDSQL_BUILD_INFO") { + Some(info) => format!("{}-{}", version, info), + None => format!("{}-{}", version, sha), + } +}); + pub struct Session { client: Client, conn: Box, @@ -54,16 +64,26 @@ pub struct Session { impl Session { pub async fn try_new(dsn: String, settings: Settings, is_repl: bool) -> Result { - let client = Client::new(dsn); + let client = Client::new(dsn).with_name(format!("bendsql/{}", VERSION_SHORT.as_str())); let conn = client.get_conn().await?; let info = conn.info().await; let mut keywords = Vec::with_capacity(1024); if is_repl { println!("Welcome to BendSQL {}.", VERSION.as_str()); - println!( - "Connecting to {}:{} as user {}.", - info.host, info.port, info.user - ); + match info.warehouse { + Some(ref warehouse) => { + println!( + "Connecting to {}:{} with warehouse {} as user {}", + info.host, info.port, warehouse, info.user + ); + } + None => { + println!( + "Connecting to {}:{} as user {}.", + info.host, info.port, info.user + ); + } + } let version = conn.version().await?; println!("Connected to {}", version); println!(); diff --git a/core/src/client.rs b/core/src/client.rs index bc8a5262a..923b85848 100644 --- a/core/src/client.rs +++ b/core/src/client.rs @@ -50,6 +50,7 @@ static VERSION: Lazy = Lazy::new(|| { #[derive(Clone)] pub struct APIClient { pub cli: HttpClient, + scheme: String, endpoint: Url, pub host: String, pub port: u16, @@ -72,7 +73,14 @@ pub struct APIClient { } impl APIClient { - pub async fn from_dsn(dsn: &str) -> Result { + pub async fn new(dsn: &str, name: Option) -> Result { + let mut client = Self::from_dsn(dsn).await?; + client.build_client(name).await?; + client.check_presign().await?; + Ok(client) + } + + async fn from_dsn(dsn: &str) -> Result { let u = Url::parse(dsn)?; let mut client = Self::default(); if let Some(host) = u.host_str() { @@ -176,21 +184,9 @@ impl APIClient { _ => unreachable!(), }, }; + client.scheme = scheme.to_string(); - let mut cli_builder = HttpClient::builder() - .user_agent(format!("databend-client-rust/{}", VERSION.as_str())) - .pool_idle_timeout(Duration::from_secs(1)); - #[cfg(any(feature = "rustls", feature = "native-tls"))] - if scheme == "https" { - if let Some(ref ca_file) = client.tls_ca_file { - let cert_pem = tokio::fs::read(ca_file).await?; - let cert = reqwest::Certificate::from_pem(&cert_pem)?; - cli_builder = cli_builder.add_root_certificate(cert); - } - } - client.cli = cli_builder.build()?; client.endpoint = Url::parse(&format!("{}://{}:{}", scheme, client.host, client.port))?; - client.session_state = Arc::new(Mutex::new( SessionState::default() .with_settings(Some(session_settings)) @@ -198,12 +194,30 @@ impl APIClient { .with_database(database), )); - client.init_presign().await?; - Ok(client) } - async fn init_presign(&mut self) -> Result<()> { + async fn build_client(&mut self, name: Option) -> Result<()> { + let ua = match name { + Some(n) => n, + None => format!("databend-client-rust/{}", VERSION.as_str()), + }; + let mut cli_builder = HttpClient::builder() + .user_agent(ua) + .pool_idle_timeout(Duration::from_secs(1)); + #[cfg(any(feature = "rustls", feature = "native-tls"))] + if self.scheme == "https" { + if let Some(ref ca_file) = self.tls_ca_file { + let cert_pem = tokio::fs::read(ca_file).await?; + let cert = reqwest::Certificate::from_pem(&cert_pem)?; + cli_builder = cli_builder.add_root_certificate(cert); + } + } + self.cli = cli_builder.build()?; + Ok(()) + } + + async fn check_presign(&mut self) -> Result<()> { match self.presign { PresignMode::Auto => { if self.host.ends_with(".databend.com") || self.host.ends_with(".databend.cn") { @@ -212,7 +226,7 @@ impl APIClient { self.presign = PresignMode::Off; } } - PresignMode::Detect => match self.get_presigned_upload_url("~/.bendsql/check").await { + PresignMode::Detect => match self.get_presigned_upload_url("@~/.bendsql/check").await { Ok(_) => self.presign = PresignMode::On, Err(e) => { warn!("presign mode off with error detected: {}", e); @@ -344,7 +358,8 @@ impl APIClient { } let resp: QueryResponse = resp.json().await?; self.handle_session(&resp.session).await; - // TODO: duplicate warnings with start_query, maybe we should only print warnings on final response + // TODO: duplicate warnings with start_query, + // maybe we should only print warnings on final response // self.handle_warnings(&resp); match resp.error { Some(err) => Err(Error::InvalidResponse(err)), @@ -570,6 +585,7 @@ impl Default for APIClient { fn default() -> Self { Self { cli: HttpClient::new(), + scheme: "http".to_string(), endpoint: Url::parse("http://localhost:8080").unwrap(), host: "localhost".to_string(), port: 8000, diff --git a/core/tests/core/simple.rs b/core/tests/core/simple.rs index 13fcf1d47..b48d1c13d 100644 --- a/core/tests/core/simple.rs +++ b/core/tests/core/simple.rs @@ -19,7 +19,7 @@ use crate::common::DEFAULT_DSN; #[tokio::test] async fn select_simple() { let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN); - let client = APIClient::from_dsn(dsn).await.unwrap(); + let client = APIClient::new(dsn, None).await.unwrap(); let resp = client.start_query("select 15532").await.unwrap(); assert_eq!(resp.data, [["15532"]]); } diff --git a/core/tests/core/stage.rs b/core/tests/core/stage.rs index 6b3e392a8..1700557b7 100644 --- a/core/tests/core/stage.rs +++ b/core/tests/core/stage.rs @@ -21,11 +21,11 @@ use crate::common::DEFAULT_DSN; async fn insert_with_stage(presign: bool) { let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN); let client = if presign { - APIClient::from_dsn(&format!("{}&presign=on", dsn)) + APIClient::new(&format!("{}&presign=on", dsn), None) .await .unwrap() } else { - APIClient::from_dsn(&format!("{}&presign=off", dsn)) + APIClient::new(&format!("{}&presign=off", dsn), None) .await .unwrap() }; diff --git a/driver/Cargo.toml b/driver/Cargo.toml index 3834163f1..2c48fe13e 100644 --- a/driver/Cargo.toml +++ b/driver/Cargo.toml @@ -36,6 +36,7 @@ csv = "1.3" dyn-clone = "1.0" glob = "0.3" log = "0.4" +once_cell = "1.18" percent-encoding = "2.3" serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio = { version = "1.34", features = ["macros"] } diff --git a/driver/src/conn.rs b/driver/src/conn.rs index 212ca433a..e49e47f94 100644 --- a/driver/src/conn.rs +++ b/driver/src/conn.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use async_trait::async_trait; use dyn_clone::DynClone; +use once_cell::sync::Lazy; use tokio::io::AsyncRead; use tokio_stream::StreamExt; use url::Url; @@ -34,26 +35,38 @@ use databend_sql::value::{NumberValue, Value}; use crate::rest_api::RestAPIConnection; +static VERSION: Lazy = Lazy::new(|| { + let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown"); + version.to_string() +}); + #[derive(Clone)] pub struct Client { dsn: String, + name: String, } impl Client { pub fn new(dsn: String) -> Self { - Self { dsn } + let name = format!("databend-driver-rust/{}", VERSION.as_str()); + Self { dsn, name } + } + + pub fn with_name(mut self, name: String) -> Self { + self.name = name; + self } pub async fn get_conn(&self) -> Result> { let u = Url::parse(&self.dsn)?; match u.scheme() { "databend" | "databend+http" | "databend+https" => { - let conn = RestAPIConnection::try_create(&self.dsn).await?; + let conn = RestAPIConnection::try_create(&self.dsn, self.name.clone()).await?; Ok(Box::new(conn)) } #[cfg(feature = "flight-sql")] "databend+flight" | "databend+grpc" => { - let conn = FlightSQLConnection::try_create(&self.dsn).await?; + let conn = FlightSQLConnection::try_create(&self.dsn, self.name.clone()).await?; Ok(Box::new(conn)) } _ => Err(Error::Parsing(format!( diff --git a/driver/src/flight_sql.rs b/driver/src/flight_sql.rs index a5996c933..2e6ebc934 100644 --- a/driver/src/flight_sql.rs +++ b/driver/src/flight_sql.rs @@ -146,8 +146,8 @@ impl Connection for FlightSQLConnection { } impl FlightSQLConnection { - pub async fn try_create(dsn: &str) -> Result { - let (args, endpoint) = Self::parse_dsn(dsn).await?; + pub async fn try_create(dsn: &str, name: String) -> Result { + let (args, endpoint) = Self::parse_dsn(dsn, name).await?; let channel = endpoint.connect_lazy(); let mut client = FlightSqlServiceClient::new(channel); // enable progress @@ -178,10 +178,11 @@ impl FlightSQLConnection { Ok(()) } - async fn parse_dsn(dsn: &str) -> Result<(Args, Endpoint)> { + async fn parse_dsn(dsn: &str, name: String) -> Result<(Args, Endpoint)> { let u = Url::parse(dsn)?; let args = Args::from_url(&u)?; let mut endpoint = Endpoint::new(args.uri.clone())? + .user_agent(name)? .connect_timeout(args.connect_timeout) .timeout(args.query_timeout) .tcp_nodelay(args.tcp_nodelay) diff --git a/driver/src/rest_api.rs b/driver/src/rest_api.rs index ac140ff35..2456e463f 100644 --- a/driver/src/rest_api.rs +++ b/driver/src/rest_api.rs @@ -190,8 +190,8 @@ impl Connection for RestAPIConnection { } impl<'o> RestAPIConnection { - pub async fn try_create(dsn: &str) -> Result { - let client = APIClient::from_dsn(dsn).await?; + pub async fn try_create(dsn: &str, name: String) -> Result { + let client = APIClient::new(dsn, Some(name)).await?; Ok(Self { client }) }