From ea59261f41d18be400e5c03feaa608342703bf58 Mon Sep 17 00:00:00 2001 From: everpcpc Date: Mon, 10 Jul 2023 18:05:49 +0800 Subject: [PATCH] feat: support use warehouse with cloud (#142) --- bindings/nodejs/Cargo.toml | 2 +- bindings/python/Cargo.toml | 4 ++-- cli/Cargo.toml | 4 ++-- core/Cargo.toml | 2 +- core/src/client.rs | 37 ++++++++++++++++++++++++++----------- driver/Cargo.toml | 4 ++-- 6 files changed, 34 insertions(+), 19 deletions(-) diff --git a/bindings/nodejs/Cargo.toml b/bindings/nodejs/Cargo.toml index 65f92144a..f497ad54b 100644 --- a/bindings/nodejs/Cargo.toml +++ b/bindings/nodejs/Cargo.toml @@ -10,7 +10,7 @@ crate-type = ["cdylib"] doc = false [dependencies] -databend-driver = { path = "../../driver", version = "0.2.24", features = ["rustls", "flight-sql"] } +databend-driver = { path = "../../driver", version = "0.3.0", features = ["rustls", "flight-sql"] } futures = "0.3.28" napi = { version = "2.13.2", default-features = false, features = [ "napi6", diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 06f36b00c..20577e0c8 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -28,9 +28,9 @@ doc = false [dependencies] chrono = { version = "0.4.24", default-features = false, features = ["std"] } +databend-client = { version = "0.2.0", path = "../../core" } +databend-driver = { path = "../../driver", version = "0.3.0", features = ["rustls", "flight-sql"] } futures = "0.3.28" -databend-driver = { path = "../../driver", version = "0.2.20", features = ["rustls", "flight-sql"] } -databend-client = { version = "0.1.15", path = "../../core" } pyo3 = { version = "0.18", features = ["abi3-py37"] } pyo3-asyncio = { version = "0.18", features = ["tokio-runtime"] } tokio = "1" diff --git a/cli/Cargo.toml b/cli/Cargo.toml index da526dfa3..8cbcb3f31 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bendsql" -version = "0.3.12" +version = "0.4.0" edition = "2021" license = "Apache-2.0" description = "Databend Native Command Line Tool" @@ -15,7 +15,7 @@ chrono = { version = "0.4.24", default-features = false, features = ["clock"] } clap = { version = "4.1.0", features = ["derive", "env"] } comfy-table = "6.1.4" csv = "1.2.1" -databend-driver = { path = "../driver", version = "0.2.23", features = ["rustls", "flight-sql"] } +databend-driver = { path = "../driver", version = "0.3.0", features = ["rustls", "flight-sql"] } futures = { version = "0.3", default-features = false, features = ["alloc"] } humantime-serde = "1.1.1" indicatif = "0.17.3" diff --git a/core/Cargo.toml b/core/Cargo.toml index e115209f4..a21955d3a 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "databend-client" -version = "0.1.16" +version = "0.2.0" edition = "2021" license = "Apache-2.0" description = "Databend Client for Rust" diff --git a/core/src/client.rs b/core/src/client.rs index fb1611217..8ee8d0867 100644 --- a/core/src/client.rs +++ b/core/src/client.rs @@ -77,7 +77,7 @@ pub struct APIClient { pub port: u16, tenant: Option, - warehouse: Option, + warehouse: Arc>>, pub database: Arc>>, pub user: String, password: Option, @@ -135,7 +135,7 @@ impl APIClient { client.tenant = Some(v.to_string()); } "warehouse" => { - client.warehouse = Some(v.to_string()); + client.warehouse = Arc::new(Mutex::new(Some(v.to_string()))); } "sslmode" => { if v == "disable" { @@ -167,12 +167,13 @@ impl APIClient { .with_pagination(self.make_pagination()) .with_session(session_settings); let endpoint = self.endpoint.join("v1/query")?; + let headers = self.make_headers().await?; let resp = self .cli .post(endpoint) .json(&req) .basic_auth(self.user.clone(), self.password.clone()) - .headers(self.make_headers()?) + .headers(headers) .send() .await?; if resp.status() != StatusCode::OK { @@ -194,7 +195,15 @@ impl APIClient { } if let Some(settings) = &session.settings { for (k, v) in settings { - session_settings.insert(k.clone(), v.clone()); + match k.as_str() { + "warehouse" => { + let mut warehouse = self.warehouse.lock().await; + *warehouse = Some(v.clone()); + } + _ => { + session_settings.insert(k.clone(), v.clone()); + } + } } } } @@ -203,11 +212,12 @@ impl APIClient { pub async fn query_page(&self, next_uri: &str) -> Result { let endpoint = self.endpoint.join(next_uri)?; + let headers = self.make_headers().await?; let resp = self .cli .get(endpoint) .basic_auth(self.user.clone(), self.password.clone()) - .headers(self.make_headers()?) + .headers(headers) .send() .await?; if resp.status() != StatusCode::OK { @@ -289,12 +299,13 @@ impl APIClient { Some(pagination) } - fn make_headers(&self) -> Result { + async fn make_headers(&self) -> Result { let mut headers = HeaderMap::new(); if let Some(tenant) = &self.tenant { headers.insert("X-DATABEND-TENANT", tenant.parse()?); } - if let Some(warehouse) = &self.warehouse { + let warehouse = self.warehouse.lock().await; + if let Some(warehouse) = &*warehouse { headers.insert("X-DATABEND-WAREHOUSE", warehouse.parse()?); } Ok(headers) @@ -318,12 +329,13 @@ impl APIClient { .with_session(session_settings) .with_stage_attachment(stage_attachment); let endpoint = self.endpoint.join("v1/query")?; + let headers = self.make_headers().await?; let resp = self .cli .post(endpoint) .json(&req) .basic_auth(self.user.clone(), self.password.clone()) - .headers(self.make_headers()?) + .headers(headers) .send() .await?; if resp.status() != StatusCode::OK { @@ -361,7 +373,7 @@ impl APIClient { ) -> Result<()> { let endpoint = self.endpoint.join("v1/upload_to_stage")?; let location = StageLocation::try_from(stage_location)?; - let mut headers = self.make_headers()?; + let mut headers = self.make_headers().await?; headers.insert("stage_name", location.name.parse()?); let stream = Body::wrap_stream(ReaderStream::new(data)); let part = Part::stream_with_length(stream, size).file_name(location.path); @@ -451,7 +463,7 @@ impl Default for APIClient { host: "localhost".to_string(), port: 8000, tenant: None, - warehouse: None, + warehouse: Arc::new(Mutex::new(None)), database: Arc::new(Mutex::new(None)), user: "root".to_string(), password: None, @@ -484,7 +496,10 @@ mod test { assert_eq!(client.max_rows_in_buffer, Some(5000000)); assert_eq!(client.max_rows_per_page, Some(10000)); assert_eq!(client.tenant, None); - assert_eq!(client.warehouse, Some("wh".to_string())); + assert_eq!( + *client.warehouse.try_lock().unwrap(), + Some("wh".to_string()) + ); Ok(()) } diff --git a/driver/Cargo.toml b/driver/Cargo.toml index 056ef9b0b..a60fd52e4 100644 --- a/driver/Cargo.toml +++ b/driver/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "databend-driver" -version = "0.2.24" +version = "0.3.0" edition = "2021" license = "Apache-2.0" description = "Databend Driver for Rust" @@ -21,7 +21,7 @@ flight-sql = ["dep:arrow-array", "dep:arrow-cast", "dep:arrow-flight", "dep:arro [dependencies] async-trait = "0.1.68" chrono = { version = "0.4.24", default-features = false, features = ["clock"] } -databend-client = { version = "0.1.16", path = "../core" } +databend-client = { version = "0.2.0", path = "../core" } dyn-clone = "1.0.11" http = "0.2.9" percent-encoding = "2.2.0"