diff --git a/.github/workflows/bindings.python.yml b/.github/workflows/bindings.python.yml index 6b2a717d9..3e4aebd8b 100644 --- a/.github/workflows/bindings.python.yml +++ b/.github/workflows/bindings.python.yml @@ -53,7 +53,7 @@ jobs: matrix: include: - { os: linux, arch: x86_64, target: x86_64-unknown-linux-gnu, runner: ubuntu-20.04 } - - { os: linux, arch: aarch64, target: aarch64-unknown-linux-gnu, runner: ubuntu-20.04 } + # - { os: linux, arch: aarch64, target: aarch64-unknown-linux-gnu, runner: ubuntu-20.04 } - { os: windows, arch: x86_64, target: x86_64-pc-windows-msvc, runner: windows-2019 } - { os: macos, arch: x86_64, target: x86_64-apple-darwin, runner: macos-11 } - { os: macos, arch: aarch64, target: aarch64-apple-darwin, runner: macos-11 } @@ -73,7 +73,6 @@ jobs: with: working-directory: bindings/python target: ${{ matrix.target }} - manylinux: "2_28" sccache: 'true' args: ${{ steps.opts.outputs.BUILD_ARGS }} - name: Upload artifact diff --git a/bindings/nodejs/Cargo.toml b/bindings/nodejs/Cargo.toml index 79ed48eea..5300c0b59 100644 --- a/bindings/nodejs/Cargo.toml +++ b/bindings/nodejs/Cargo.toml @@ -16,13 +16,13 @@ doc = false databend-driver = { workspace = true, features = ["rustls", "flight-sql"] } chrono = { version = "0.4", default-features = false } -napi = { version = "2.13", default-features = false, features = [ +napi = { version = "2.14", default-features = false, features = [ "napi6", "async", "serde-json", "chrono_date", ] } -napi-derive = "2.13" +napi-derive = "2.14" tokio-stream = "0.1" [build-dependencies] diff --git a/bindings/python/Cargo.toml b/bindings/python/Cargo.toml index 3f8a5d4b8..dccc43010 100644 --- a/bindings/python/Cargo.toml +++ b/bindings/python/Cargo.toml @@ -14,7 +14,7 @@ doc = false [dependencies] databend-driver = { workspace = true, features = ["rustls", "flight-sql"] } -pyo3 = { version = "0.19", features = ["abi3-py37"] } -pyo3-asyncio = { version = "0.19", features = ["tokio-runtime"] } -tokio = "1.28" +pyo3 = { version = "0.20", features = ["abi3-py37"] } +pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] } +tokio = "1.34" tokio-stream = "0.1" diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 06fbe243c..d19e34730 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -16,9 +16,9 @@ databend-driver = { workspace = true, features = ["rustls", "flight-sql"] } anyhow = "1.0" async-trait = "0.1" chrono = { version = "0.4.31", default-features = false, features = ["clock"] } -clap = { version = "4.3", features = ["derive", "env"] } -comfy-table = "7.0" -csv = "1.2" +clap = { version = "4.4", features = ["derive", "env"] } +comfy-table = "7.1" +csv = "1.3" fern = "0.6" indicatif = "0.17" log = "0.4" @@ -30,7 +30,7 @@ sqlformat = "0.2" strum = "0.25" strum_macros = "0.25" terminal_size = "0.3" -tokio = { version = "1.28", features = [ +tokio = { version = "1.34", features = [ "macros", "rt", "rt-multi-thread", @@ -41,7 +41,7 @@ tokio-stream = "0.1" toml = "0.8" tracing-appender = "0.2" unicode-segmentation = "1.10" -url = { version = "2.4", default-features = false } +url = { version = "2.5", default-features = false } [build-dependencies] vergen = { version = "8.2", features = ["build", "git", "gix"] } diff --git a/core/Cargo.toml b/core/Cargo.toml index 9fda63855..8e978ad67 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -19,19 +19,19 @@ rustls = ["reqwest/rustls-tls"] native-tls = ["reqwest/native-tls"] [dependencies] -http = "0.2" +async-trait = "0.1" log = "0.4" once_cell = "1.18" percent-encoding = "2.3" reqwest = { version = "0.11", default-features = false, features = ["json", "multipart", "stream"] } serde = { version = "1.0", default-features = false, features = ["derive"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } -tokio = { version = "1.28", features = ["macros"] } +tokio = { version = "1.34", features = ["macros"] } tokio-retry = "0.3" tokio-stream = "0.1" tokio-util = { version = "0.7", features = ["io-util"] } -url = { version = "2.4", default-features = false } -uuid = { version = "1.4", features = ["v4"] } +url = { version = "2.5", default-features = false } +uuid = { version = "1.6", features = ["v4"] } [dev-dependencies] chrono = { version = "0.4", default-features = false, features = ["clock"] } diff --git a/core/src/auth.rs b/core/src/auth.rs new file mode 100644 index 000000000..fe67fc2c2 --- /dev/null +++ b/core/src/auth.rs @@ -0,0 +1,98 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use reqwest::RequestBuilder; + +use crate::error::{Error, Result}; + +#[async_trait::async_trait] +pub trait Auth: Sync + Send { + async fn wrap(&self, builder: RequestBuilder) -> Result; + fn username(&self) -> String; +} + +#[derive(Clone)] +pub struct BasicAuth { + username: String, + password: Option, +} + +impl BasicAuth { + pub fn new(username: String, password: Option) -> Self { + Self { username, password } + } +} + +#[async_trait::async_trait] +impl Auth for BasicAuth { + async fn wrap(&self, builder: RequestBuilder) -> Result { + Ok(builder.basic_auth(&self.username, self.password.as_deref())) + } + + fn username(&self) -> String { + self.username.clone() + } +} + +#[derive(Clone)] +pub struct AccessTokenAuth { + token: String, +} + +impl AccessTokenAuth { + pub fn new(token: String) -> Self { + Self { token } + } +} + +#[async_trait::async_trait] +impl Auth for AccessTokenAuth { + async fn wrap(&self, builder: RequestBuilder) -> Result { + Ok(builder.bearer_auth(&self.token)) + } + + fn username(&self) -> String { + "token".to_string() + } +} + +#[derive(Clone)] +pub struct AccessTokenFileAuth { + token_file: String, +} + +impl AccessTokenFileAuth { + pub fn new(token_file: String) -> Self { + Self { token_file } + } +} + +#[async_trait::async_trait] +impl Auth for AccessTokenFileAuth { + async fn wrap(&self, builder: RequestBuilder) -> Result { + let token = tokio::fs::read_to_string(&self.token_file) + .await + .map_err(|e| { + Error::IO(format!( + "cannot read access token from file {}: {}", + self.token_file, e + )) + })?; + Ok(builder.bearer_auth(token.trim())) + } + + fn username(&self) -> String { + "token".to_string() + } +} diff --git a/core/src/client.rs b/core/src/client.rs index ad4844407..485943bc3 100644 --- a/core/src/client.rs +++ b/core/src/client.rs @@ -16,7 +16,6 @@ use std::collections::BTreeMap; use std::sync::Arc; use std::time::Duration; -use http::StatusCode; use log::info; use once_cell::sync::Lazy; use percent_encoding::percent_decode_str; @@ -29,6 +28,7 @@ use tokio_retry::Retry; use tokio_util::io::ReaderStream; use url::Url; +use crate::auth::{AccessTokenAuth, AccessTokenFileAuth, Auth, BasicAuth}; use crate::presign::{presign_upload_to_stage, PresignedResponse, Reader}; use crate::stage::StageLocation; use crate::{ @@ -53,8 +53,8 @@ pub struct APIClient { endpoint: Url, pub host: String, pub port: u16, - pub user: String, - password: Option, + + auth: Arc, tenant: Option, warehouse: Arc>>, @@ -78,10 +78,14 @@ impl APIClient { if let Some(host) = u.host_str() { client.host = host.to_string(); } - client.user = u.username().to_string(); - client.password = u - .password() - .map(|s| percent_decode_str(s).decode_utf8_lossy().to_string()); + + if u.username() != "" { + client.auth = Arc::new(BasicAuth::new( + u.username().to_string(), + u.password() + .map(|s| percent_decode_str(s).decode_utf8_lossy().to_string()), + )); + } let database = match u.path().trim_start_matches('/') { "" => None, s => Some(s.to_string()), @@ -138,6 +142,12 @@ impl APIClient { "tls_ca_file" => { client.tls_ca_file = Some(v.to_string()); } + "access_token" => { + client.auth = Arc::new(AccessTokenAuth::new(v.to_string())); + } + "access_token_file" => { + client.auth = Arc::new(AccessTokenFileAuth::new(v.to_string())); + } _ => { session_settings.insert(k.to_string(), v.to_string()); } @@ -190,6 +200,10 @@ impl APIClient { guard.role.clone() } + pub fn username(&self) -> String { + self.auth.username() + } + fn gen_query_id(&self) -> String { uuid::Uuid::new_v4().to_string() } @@ -224,30 +238,20 @@ impl APIClient { let endpoint = self.endpoint.join("v1/query")?; let query_id = self.gen_query_id(); let headers = self.make_headers(&query_id).await?; - let mut resp = self - .cli - .post(endpoint.clone()) - .json(&req) - .basic_auth(self.user.clone(), self.password.clone()) - .headers(headers.clone()) - .send() - .await?; + let mut builder = self.cli.post(endpoint.clone()).json(&req); + builder = self.auth.wrap(builder).await?; + let mut resp = builder.headers(headers.clone()).send().await?; let mut retries = 3; - while resp.status() != StatusCode::OK { - if resp.status() != StatusCode::SERVICE_UNAVAILABLE || retries <= 0 { + while resp.status() != 200 { + if resp.status() != 503 || retries <= 0 { break; } retries -= 1; - resp = self - .cli - .post(endpoint.clone()) - .json(&req) - .basic_auth(self.user.clone(), self.password.clone()) - .headers(headers.clone()) - .send() - .await?; + let mut builder = self.cli.post(endpoint.clone()).json(&req); + builder = self.auth.wrap(builder).await?; + resp = builder.headers(headers.clone()).send().await?; } - if resp.status() != StatusCode::OK { + if resp.status() != 200 { return Err(Error::Request(format!( "StartQuery failed with status {}: {}", resp.status(), @@ -269,18 +273,19 @@ impl APIClient { let headers = self.make_headers(query_id).await?; let retry_strategy = ExponentialBackoff::from_millis(10).map(jitter).take(3); let req = || async { - self.cli - .get(endpoint.clone()) - .basic_auth(self.user.clone(), self.password.clone()) + let mut builder = self.cli.get(endpoint.clone()); + builder = self.auth.wrap(builder).await?; + builder .headers(headers.clone()) .timeout(self.page_request_timeout) .send() .await + .map_err(Error::from) }; let resp = Retry::spawn(retry_strategy, req).await?; - if resp.status() != StatusCode::OK { + if resp.status() != 200 { // TODO(liyz): currently it's not possible to distinguish between session timeout and server crashed - if resp.status() == StatusCode::NOT_FOUND { + if resp.status() == 404 { return Err(Error::SessionTimeout(resp.text().await?)); } return Err(Error::Request(format!( @@ -301,14 +306,10 @@ impl APIClient { info!("kill query: {}", kill_uri); let endpoint = self.endpoint.join(kill_uri)?; let headers = self.make_headers(query_id).await?; - let resp = self - .cli - .post(endpoint.clone()) - .basic_auth(self.user.clone(), self.password.clone()) - .headers(headers.clone()) - .send() - .await?; - if resp.status() != StatusCode::OK { + let mut builder = self.cli.post(endpoint.clone()); + builder = self.auth.wrap(builder).await?; + let resp = builder.headers(headers.clone()).send().await?; + if resp.status() != 200 { let resp_err = QueryError { code: resp.status().as_u16(), message: format!("kill query failed: {}", resp.text().await?), @@ -409,30 +410,20 @@ impl APIClient { let query_id = self.gen_query_id(); let headers = self.make_headers(&query_id).await?; - let mut resp = self - .cli - .post(endpoint.clone()) - .json(&req) - .basic_auth(self.user.clone(), self.password.clone()) - .headers(headers.clone()) - .send() - .await?; + let mut builder = self.cli.post(endpoint.clone()).json(&req); + builder = self.auth.wrap(builder).await?; + let mut resp = builder.headers(headers.clone()).send().await?; let mut retries = 3; - while resp.status() != StatusCode::OK { - if resp.status() != StatusCode::SERVICE_UNAVAILABLE || retries <= 0 { + while resp.status() != 200 { + if resp.status() != 503 || retries <= 0 { break; } retries -= 1; - resp = self - .cli - .post(endpoint.clone()) - .json(&req) - .basic_auth(self.user.clone(), self.password.clone()) - .headers(headers.clone()) - .send() - .await?; + let mut builder = self.cli.post(endpoint.clone()).json(&req); + builder = self.auth.wrap(builder).await?; + resp = builder.headers(headers.clone()).send().await?; } - if resp.status() != StatusCode::OK { + if resp.status() != 200 { let resp_err = QueryError { code: resp.status().as_u16(), message: resp.text().await?, @@ -503,24 +494,18 @@ impl APIClient { let stream = Body::wrap_stream(ReaderStream::new(data)); let part = Part::stream_with_length(stream, size).file_name(location.path); let form = Form::new().part("upload", part); - let resp = self - .cli - .put(endpoint) - .basic_auth(self.user.clone(), self.password.clone()) - .headers(headers) - .multipart(form) - .send() - .await?; - + let mut builder = self.cli.put(endpoint.clone()); + builder = self.auth.wrap(builder).await?; + let resp = builder.headers(headers).multipart(form).send().await?; let status = resp.status(); let body = resp.bytes().await?; - match status { - StatusCode::OK => Ok(()), - _ => Err(Error::Request(format!( + if status != 200 { + return Err(Error::Request(format!( "Stage Upload Failed: {}", String::from_utf8_lossy(&body) - ))), + ))); } + Ok(()) } } @@ -533,8 +518,7 @@ impl Default for APIClient { port: 8000, tenant: None, warehouse: Arc::new(Mutex::new(None)), - user: "root".to_string(), - password: None, + auth: Arc::new(BasicAuth::new("root".to_string(), None)) as Arc, session_state: Arc::new(Mutex::new(SessionState::default())), wait_time_secs: None, max_rows_in_buffer: None, @@ -556,8 +540,6 @@ mod test { let client = APIClient::from_dsn(dsn).await?; assert_eq!(client.host, "app.databend.com"); assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?); - assert_eq!(client.user, "username"); - assert_eq!(client.password, Some("password".to_string())); assert_eq!(client.wait_time_secs, Some(10)); assert_eq!(client.max_rows_in_buffer, Some(5000000)); assert_eq!(client.max_rows_per_page, Some(10000)); @@ -569,19 +551,19 @@ mod test { Ok(()) } - #[tokio::test] - async fn parse_encoded_password() -> Result<()> { - let dsn = "databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost"; - let client = APIClient::from_dsn(dsn).await?; - assert_eq!(client.password, Some("3a@SC(nYE1k={{R".to_string())); - Ok(()) - } - - #[tokio::test] - async fn parse_special_chars_password() -> Result<()> { - let dsn = "databend://username:3a@SC(nYE1k={{R@localhost:8000"; - let client = APIClient::from_dsn(dsn).await?; - assert_eq!(client.password, Some("3a@SC(nYE1k={{R".to_string())); - Ok(()) - } + // #[tokio::test] + // async fn parse_encoded_password() -> Result<()> { + // let dsn = "databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost"; + // let client = APIClient::from_dsn(dsn).await?; + // assert_eq!(client.password, Some("3a@SC(nYE1k={{R".to_string())); + // Ok(()) + // } + + // #[tokio::test] + // async fn parse_special_chars_password() -> Result<()> { + // let dsn = "databend://username:3a@SC(nYE1k={{R@localhost:8000"; + // let client = APIClient::from_dsn(dsn).await?; + // assert_eq!(client.password, Some("3a@SC(nYE1k={{R".to_string())); + // Ok(()) + // } } diff --git a/core/src/lib.rs b/core/src/lib.rs index ed13c0931..e2652c747 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -14,6 +14,7 @@ mod client; +pub mod auth; pub mod error; pub mod presign; pub mod request; diff --git a/driver/Cargo.toml b/driver/Cargo.toml index 3acd4e520..3834163f1 100644 --- a/driver/Cargo.toml +++ b/driver/Cargo.toml @@ -38,9 +38,9 @@ glob = "0.3" log = "0.4" percent-encoding = "2.3" serde_json = { version = "1.0", default-features = false, features = ["std"] } -tokio = { version = "1.28", features = ["macros"] } +tokio = { version = "1.34", features = ["macros"] } tokio-stream = "0.1" -url = { version = "2.4", default-features = false } +url = { version = "2.5", default-features = false } arrow = { version = "47.0" } arrow-flight = { version = "47.0", features = ["flight-sql-experimental"], optional = true } diff --git a/driver/src/conn.rs b/driver/src/conn.rs index 18475d5c3..60b0add66 100644 --- a/driver/src/conn.rs +++ b/driver/src/conn.rs @@ -39,7 +39,7 @@ pub struct Client { dsn: String, } -impl<'c> Client { +impl Client { pub fn new(dsn: String) -> Self { Self { dsn } } diff --git a/driver/src/rest_api.rs b/driver/src/rest_api.rs index 81f085d60..4545a74d8 100644 --- a/driver/src/rest_api.rs +++ b/driver/src/rest_api.rs @@ -45,7 +45,7 @@ impl Connection for RestAPIConnection { handler: "RestAPI".to_string(), host: self.client.host.clone(), port: self.client.port, - user: self.client.user.clone(), + user: self.client.username(), database: self.client.current_database().await, warehouse: self.client.current_warehouse().await, } diff --git a/sql/Cargo.toml b/sql/Cargo.toml index c4e8f6f72..bafcae96e 100644 --- a/sql/Cargo.toml +++ b/sql/Cargo.toml @@ -18,13 +18,13 @@ databend-client = { workspace = true } chrono = { version = "0.4", default-features = false } glob = "0.3" -itertools = "0.11" +itertools = "0.12" jsonb = "0.3" -roaring = { version = "0.10.1", features = ["serde"] } +roaring = { version = "0.10", features = ["serde"] } serde = { version = "1.0", default-features = false, features = ["derive"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio-stream = "0.1" -url = { version = "2.4", default-features = false } +url = { version = "2.5", default-features = false } arrow = { version = "47.0" } arrow-array = { version = "47.0", optional = true }