From e64ad8795420b53ae38423aee27793700bc808ce Mon Sep 17 00:00:00 2001 From: everpcpc Date: Tue, 23 Apr 2024 10:33:58 +0800 Subject: [PATCH] feat: make password & dsn sensitive (#404) * feat: make password & dsn sensitive * chore: tmp disable geometry test --- cli/Cargo.toml | 3 +- cli/src/args.rs | 13 ++-- cli/src/display.rs | 11 ++-- cli/src/main.rs | 73 +++++++++++++---------- cli/tests/00-base.result | 2 +- cli/tests/00-base.sql | 2 +- core/src/auth.rs | 89 +++++++++++++++++++++++++--- core/src/client.rs | 14 ++--- core/src/error.rs | 6 ++ driver/src/flight_sql.rs | 13 ++-- driver/tests/driver/select_iter.rs | 16 ----- driver/tests/driver/select_simple.rs | 34 ++++++++++- sql/src/error.rs | 6 ++ sql/src/value.rs | 4 +- tests/docker-compose.yaml | 2 +- 15 files changed, 196 insertions(+), 92 deletions(-) diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 41631b3b0..b9c647a7e 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -11,9 +11,11 @@ authors = { workspace = true } repository = { workspace = true } [dependencies] +databend-client = { workspace = true } databend-driver = { workspace = true, features = ["rustls", "flight-sql"] } anyhow = "1.0" +async-recursion = "1.1.0" async-trait = "0.1" chrono = { version = "0.4.35", default-features = false, features = ["clock"] } clap = { version = "4.4", features = ["derive", "env"] } @@ -42,7 +44,6 @@ toml = "0.8" tracing-appender = "0.2" unicode-segmentation = "1.10" url = { version = "2.5", default-features = false } -async-recursion = "1.1.0" [build-dependencies] vergen = { version = "8.2", features = ["build", "git", "gix"] } diff --git a/cli/src/args.rs b/cli/src/args.rs index 25643b487..35cf9a4c5 100644 --- a/cli/src/args.rs +++ b/cli/src/args.rs @@ -15,13 +15,14 @@ use std::collections::BTreeMap; use anyhow::{anyhow, Result}; +use databend_client::auth::SensitiveString; #[derive(Debug, Clone, PartialEq, Default)] pub struct ConnectionArgs { pub host: String, pub port: Option, pub user: String, - pub password: 
Option, + pub password: SensitiveString, pub database: Option, pub flight: bool, pub args: BTreeMap, @@ -33,9 +34,7 @@ impl ConnectionArgs { dsn.set_host(Some(&self.host))?; _ = dsn.set_port(self.port); _ = dsn.set_username(&self.user); - if let Some(password) = self.password { - _ = dsn.set_password(Some(&password)) - }; + _ = dsn.set_password(Some(self.password.inner())); if let Some(database) = self.database { dsn.set_path(&database); } @@ -64,7 +63,7 @@ impl ConnectionArgs { let host = u.host_str().ok_or(anyhow!("missing host"))?.to_string(); let port = u.port(); let user = u.username().to_string(); - let password = u.password().map(|s| s.to_string()); + let password = SensitiveString::from(u.password().unwrap_or_default()); let database = u.path().strip_prefix('/').map(|s| s.to_string()); Ok(Self { host, @@ -90,7 +89,7 @@ mod test { host: "app.databend.com".to_string(), port: None, user: "username".to_string(), - password: Some("3a@SC(nYE1k={{R".to_string()), + password: SensitiveString::from("3a@SC(nYE1k={{R"), database: Some("test".to_string()), flight: false, args: { @@ -113,7 +112,7 @@ mod test { host: "app.databend.com".to_string(), port: Some(443), user: "username".to_string(), - password: Some("3a@SC(nYE1k={{R".to_string()), + password: SensitiveString::from("3a@SC(nYE1k={{R"), database: Some("test".to_string()), flight: false, args: { diff --git a/cli/src/display.rs b/cli/src/display.rs index 9283865c5..b0f1e7984 100644 --- a/cli/src/display.rs +++ b/cli/src/display.rs @@ -14,23 +14,20 @@ use std::collections::HashSet; use std::fmt::Write; -use unicode_segmentation::UnicodeSegmentation; use anyhow::{anyhow, Result}; use comfy_table::{Cell, CellAlignment, Table}; -use terminal_size::{terminal_size, Width}; - use databend_driver::{Row, RowStatsIterator, RowWithStats, SchemaRef, ServerStats}; +use indicatif::{HumanBytes, ProgressBar, ProgressState, ProgressStyle}; use rustyline::highlight::Highlighter; +use terminal_size::{terminal_size, Width}; use 
tokio::time::Instant; use tokio_stream::StreamExt; +use unicode_segmentation::UnicodeSegmentation; -use indicatif::{HumanBytes, ProgressBar, ProgressState, ProgressStyle}; - -use crate::config::OutputQuoteStyle; use crate::{ ast::format_query, - config::{ExpandMode, OutputFormat, Settings}, + config::{ExpandMode, OutputFormat, OutputQuoteStyle, Settings}, helper::CliHelper, session::QueryKind, }; diff --git a/cli/src/main.rs b/cli/src/main.rs index 4e9f90147..e0ba959a5 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -27,14 +27,17 @@ use std::{ io::{stdin, IsTerminal}, }; -use crate::args::ConnectionArgs; -use crate::config::OutputQuoteStyle; use anyhow::{anyhow, Result}; use clap::{ArgAction, CommandFactory, Parser, ValueEnum}; -use config::{Config, OutputFormat, Settings, TimeOption}; +use databend_client::auth::SensitiveString; use log::info; use once_cell::sync::Lazy; +use crate::{ + args::ConnectionArgs, + config::{Config, OutputFormat, OutputQuoteStyle, Settings, TimeOption}, +}; + static VERSION: Lazy = Lazy::new(|| { let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown"); let sha = option_env!("VERGEN_GIT_SHA").unwrap_or("dev"); @@ -106,32 +109,45 @@ struct Args { #[clap(long, help = "Print help information")] help: bool, - #[clap(long, help = "Using flight sql protocol")] + #[clap(long, help = "Using flight sql protocol, ignored when --dsn is set")] flight: bool, - #[clap(long, help = "Enable TLS")] + #[clap(long, help = "Enable TLS, ignored when --dsn is set")] tls: bool, - #[clap(short = 'h', long, help = "Databend Server host, Default: 127.0.0.1")] + #[clap( + short = 'h', + long, + help = "Databend Server host, Default: 127.0.0.1, ignored when --dsn is set" + )] host: Option, - #[clap(short = 'P', long, help = "Databend Server port, Default: 8000")] + #[clap( + short = 'P', + long, + help = "Databend Server port, Default: 8000, ignored when --dsn is set" + )] port: Option, - #[clap(short = 'u', long, help = "Default: root")] + 
#[clap(short = 'u', long, help = "Default: root, overrides username in DSN")] user: Option, - #[clap(short = 'p', long, env = "BENDSQL_PASSWORD")] - password: Option, + #[clap( + short = 'p', + long, + env = "BENDSQL_PASSWORD", + help = "Password, overrides password in DSN" + )] + password: Option, - #[clap(short = 'D', long, help = "Database name")] + #[clap(short = 'D', long, help = "Database name, overrides database in DSN")] database: Option, - #[clap(long, value_parser = parse_key_val::, help = "Settings")] + #[clap(long, value_parser = parse_key_val::, help = "Settings, ignored when --dsn is set")] set: Vec<(String, String)>, #[clap(long, env = "BENDSQL_DSN", help = "Data source name")] - dsn: Option, + dsn: Option, #[clap(short = 'n', long, help = "Force non-interactive mode")] non_interactive: bool, @@ -219,15 +235,6 @@ pub async fn main() -> Result<()> { if args.port.is_some() { eprintln!("warning: --port is ignored when --dsn is set"); } - if args.user.is_some() { - eprintln!("warning: --user is ignored when --dsn is set"); - } - if args.password.is_some() { - eprintln!("warning: --password is ignored when --dsn is set"); - } - if args.role.is_some() { - eprintln!("warning: --role is ignored when --dsn is set"); - } if !args.set.is_empty() { eprintln!("warning: --set is ignored when --dsn is set"); } @@ -237,7 +244,7 @@ pub async fn main() -> Result<()> { if args.flight { eprintln!("warning: --flight is ignored when --dsn is set"); } - ConnectionArgs::from_dsn(&dsn)? + ConnectionArgs::from_dsn(dsn.inner())? 
} None => { if let Some(host) = args.host { @@ -246,9 +253,6 @@ pub async fn main() -> Result<()> { if let Some(port) = args.port { config.connection.port = Some(port); } - if let Some(user) = args.user { - config.connection.user = user; - } for (k, v) in args.set { config.connection.args.insert(k, v); } @@ -258,14 +262,11 @@ pub async fn main() -> Result<()> { .args .insert("sslmode".to_string(), "disable".to_string()); } - if let Some(role) = args.role { - config.connection.args.insert("role".to_string(), role); - } ConnectionArgs { host: config.connection.host.clone(), port: config.connection.port, user: config.connection.user.clone(), - password: args.password, + password: SensitiveString::from(""), database: config.connection.database.clone(), flight: args.flight, args: config.connection.args.clone(), @@ -276,6 +277,18 @@ pub async fn main() -> Result<()> { if args.database.is_some() { conn_args.database = args.database; } + // override user if specified in command line (conn_args is already built, so update it directly) + if let Some(user) = args.user { + conn_args.user = user; + } + // override password if specified in command line + if let Some(password) = args.password { + conn_args.password = password; + } + // override role if specified in command line (conn_args is already built, so update it directly) + if let Some(role) = args.role { + conn_args.args.insert("role".to_string(), role); + } let dsn = conn_args.get_dsn()?; let mut settings = Settings::default(); diff --git a/cli/tests/00-base.result b/cli/tests/00-base.result index 4962e053b..ccbe2cb43 100644 --- a/cli/tests/00-base.result +++ b/cli/tests/00-base.result @@ -20,5 +20,5 @@ Asia/Shanghai NULL {'k1':'v1','k2':'v2'} (2,NULL) 1 NULL 1 ab NULL v1 2 NULL -{'k1':'v1','k2':'v2'} [6162,78797A] ('[1,2]','SRID=4326;POINT(1 2)','2024-04-10') +{'k1':'v1','k2':'v2'} [6162,78797A] ('[1,2]','2024-04-10') bye diff --git a/cli/tests/00-base.sql b/cli/tests/00-base.sql index e07c32703..d441fadfd 100644 --- a/cli/tests/00-base.sql +++ b/cli/tests/00-base.sql @@ -46,7 +46,7 @@ insert into 
test_nested values([1,2,3], null, (1, 'ab')), (null, {'k1':'v1', 'k2 select * from test_nested; select a[1], b['k1'], c:x, c:y from test_nested; -select {'k1':'v1','k2':'v2'}, [to_binary('ab'), to_binary('xyz')], (parse_json('[1,2]'), st_geometryfromwkt('SRID=4326;POINT(1.0 2.0)'), to_date('2024-04-10')); +select {'k1':'v1','k2':'v2'}, [to_binary('ab'), to_binary('xyz')], (parse_json('[1,2]'), to_date('2024-04-10')); select 'bye'; drop table test; diff --git a/core/src/auth.rs b/core/src/auth.rs index fe67fc2c2..a85c0001b 100644 --- a/core/src/auth.rs +++ b/core/src/auth.rs @@ -25,19 +25,22 @@ pub trait Auth: Sync + Send { #[derive(Clone)] pub struct BasicAuth { username: String, - password: Option, + password: SensitiveString, } impl BasicAuth { - pub fn new(username: String, password: Option) -> Self { - Self { username, password } + pub fn new(username: impl ToString, password: impl ToString) -> Self { + Self { + username: username.to_string(), + password: SensitiveString(password.to_string()), + } } } #[async_trait::async_trait] impl Auth for BasicAuth { async fn wrap(&self, builder: RequestBuilder) -> Result { - Ok(builder.basic_auth(&self.username, self.password.as_deref())) + Ok(builder.basic_auth(&self.username, Some(self.password.inner()))) } fn username(&self) -> String { @@ -47,19 +50,21 @@ impl Auth for BasicAuth { #[derive(Clone)] pub struct AccessTokenAuth { - token: String, + token: SensitiveString, } impl AccessTokenAuth { - pub fn new(token: String) -> Self { - Self { token } + pub fn new(token: impl ToString) -> Self { + Self { + token: SensitiveString::from(token.to_string()), + } } } #[async_trait::async_trait] impl Auth for AccessTokenAuth { async fn wrap(&self, builder: RequestBuilder) -> Result { - Ok(builder.bearer_auth(&self.token)) + Ok(builder.bearer_auth(self.token.inner())) } fn username(&self) -> String { @@ -73,7 +78,8 @@ pub struct AccessTokenFileAuth { } impl AccessTokenFileAuth { - pub fn new(token_file: String) -> Self { + pub fn 
new(token_file: impl ToString) -> Self { + let token_file = token_file.to_string(); Self { token_file } } } @@ -96,3 +102,68 @@ impl Auth for AccessTokenFileAuth { "token".to_string() } } + +#[derive(::serde::Deserialize, ::serde::Serialize)] +#[serde(from = "String", into = "String")] +#[derive(Clone, Default, PartialEq, Eq)] +pub struct SensitiveString(String); + +impl From for SensitiveString { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From<&str> for SensitiveString { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} + +impl From for String { + fn from(value: SensitiveString) -> Self { + value.0 + } +} + +impl std::fmt::Display for SensitiveString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "**REDACTED**") + } +} + +impl std::fmt::Debug for SensitiveString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // we keep the double quotes here to keep the String behavior + write!(f, "\"**REDACTED**\"") + } +} + +impl SensitiveString { + #[must_use] + pub fn inner(&self) -> &str { + self.0.as_str() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn serialization() { + let json_value = "\"foo\""; + let value: SensitiveString = serde_json::from_str(json_value).unwrap(); + let result: String = serde_json::to_string(&value).unwrap(); + assert_eq!(result, json_value); + } + + #[test] + fn hide_content() { + let value = SensitiveString("hello world".to_string()); + let display = format!("{value}"); + assert_eq!(display, "**REDACTED**"); + let debug = format!("{value:?}"); + assert_eq!(debug, "\"**REDACTED**\""); + } +} diff --git a/core/src/client.rs b/core/src/client.rs index 5fda2388f..ca4c3e70d 100644 --- a/core/src/client.rs +++ b/core/src/client.rs @@ -90,11 +90,9 @@ impl APIClient { } if u.username() != "" { - client.auth = Arc::new(BasicAuth::new( - u.username().to_string(), - u.password() - .map(|s| 
percent_decode_str(s).decode_utf8_lossy().to_string()), - )); + let password = u.password().unwrap_or_default(); + let password = percent_decode_str(password).decode_utf8()?; + client.auth = Arc::new(BasicAuth::new(u.username(), password)); } let database = match u.path().trim_start_matches('/') { "" => None, @@ -156,10 +154,10 @@ impl APIClient { client.tls_ca_file = Some(v.to_string()); } "access_token" => { - client.auth = Arc::new(AccessTokenAuth::new(v.to_string())); + client.auth = Arc::new(AccessTokenAuth::new(v)); } "access_token_file" => { - client.auth = Arc::new(AccessTokenFileAuth::new(v.to_string())); + client.auth = Arc::new(AccessTokenFileAuth::new(v)); } _ => { session_settings.insert(k.to_string(), v.to_string()); @@ -581,7 +579,7 @@ impl Default for APIClient { port: 8000, tenant: None, warehouse: Arc::new(Mutex::new(None)), - auth: Arc::new(BasicAuth::new("root".to_string(), None)) as Arc, + auth: Arc::new(BasicAuth::new("root", "")) as Arc, session_state: Arc::new(Mutex::new(SessionState::default())), wait_time_secs: None, max_rows_in_buffer: None, diff --git a/core/src/error.rs b/core/src/error.rs index f4efb8bbd..6d535cd8d 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -83,3 +83,9 @@ impl From for Error { Error::IO(e.to_string()) } } + +impl From for Error { + fn from(e: std::str::Utf8Error) -> Self { + Error::Parsing(e.to_string()) + } +} diff --git a/driver/src/flight_sql.rs b/driver/src/flight_sql.rs index d80f054c2..dd1d2b467 100644 --- a/driver/src/flight_sql.rs +++ b/driver/src/flight_sql.rs @@ -31,6 +31,7 @@ use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; use tonic::Streaming; use url::Url; +use databend_client::auth::SensitiveString; use databend_client::presign::{presign_upload_to_stage, PresignedResponse}; use databend_sql::error::{Error, Result}; use databend_sql::rows::{Row, RowIterator, RowStatsIterator, RowWithStats, Rows, ServerStats}; @@ -166,7 +167,7 @@ impl FlightSQLConnection { } let mut client = 
self.client.lock().await; let _token = client - .handshake(&self.args.user, &self.args.password) + .handshake(&self.args.user, self.args.password.inner()) .await?; *handshaked = true; Ok(()) @@ -206,7 +207,7 @@ struct Args { host: String, port: u16, user: String, - password: String, + password: SensitiveString, database: Option, tenant: Option, warehouse: Option, @@ -234,7 +235,7 @@ impl Default for Args { tls: true, tls_ca_file: None, user: "root".to_string(), - password: "".to_string(), + password: SensitiveString::from(""), connect_timeout: Duration::from_secs(20), query_timeout: Duration::from_secs(60), tcp_nodelay: true, @@ -308,9 +309,9 @@ impl Args { None => format!("{}://{}:{}", scheme, host, port), }; args.user = u.username().to_string(); - args.password = percent_decode_str(u.password().unwrap_or_default()) - .decode_utf8_lossy() - .to_string(); + let password = u.password().unwrap_or_default(); + let password = percent_decode_str(password).decode_utf8()?; + args.password = SensitiveString::from(password.to_string()); Ok(args) } } diff --git a/driver/tests/driver/select_iter.rs b/driver/tests/driver/select_iter.rs index 4178ba731..9a1f5e501 100644 --- a/driver/tests/driver/select_iter.rs +++ b/driver/tests/driver/select_iter.rs @@ -226,19 +226,3 @@ async fn select_sleep() { } assert_eq!(result, vec![0]); } - -// #[tokio::test] -// async fn select_bitmap_string() { -// let (conn, _) = prepare("select_bitmap_string").await; -// let mut rows = conn -// .query_iter("select build_bitmap([1,2,3,4,5,6]), 11::String") -// .await -// .unwrap(); -// let mut result = vec![]; -// while let Some(row) = rows.next().await { -// let row: (String, String) = row.unwrap().try_into().unwrap(); -// assert!(row.0.contains('\0')); -// result.push(row.1); -// } -// assert_eq!(result, vec!["11".to_string()]); -// } diff --git a/driver/tests/driver/select_simple.rs b/driver/tests/driver/select_simple.rs index f44c72fd9..7792e2d6f 100644 --- a/driver/tests/driver/select_simple.rs 
+++ b/driver/tests/driver/select_simple.rs @@ -278,15 +278,15 @@ async fn select_tuple() { assert_eq!(val1, ("[1,2]".to_string(), vec![1, 2], true,)); let row2 = conn - .query_row("select (st_geometryfromwkt('SRID=4126;POINT(3.0 5.0)'), to_timestamp('2024-10-22 10:11:12'))") + .query_row("select (to_binary('xyz'), to_timestamp('2024-10-22 10:11:12'))") .await .unwrap() .unwrap(); - let (val2,): ((String, NaiveDateTime),) = row2.try_into().unwrap(); + let (val2,): ((Vec, NaiveDateTime),) = row2.try_into().unwrap(); assert_eq!( val2, ( - "SRID=4126;POINT(3 5)".to_string(), + vec![120, 121, 122], DateTime::parse_from_rfc3339("2024-10-22T10:11:12Z") .unwrap() .naive_utc() @@ -294,6 +294,34 @@ async fn select_tuple() { ); } +#[tokio::test] +async fn select_variant() { + // TODO: +} + +#[tokio::test] +async fn select_bitmap() { + // TODO: + // let (conn, _) = prepare("select_bitmap_string").await; + // let mut rows = conn + // .query_iter("select build_bitmap([1,2,3,4,5,6]), 11::String") + // .await + // .unwrap(); + // let mut result = vec![]; + // while let Some(row) = rows.next().await { + // let row: (String, String) = row.unwrap().try_into().unwrap(); + // assert!(row.0.contains('\0')); + // result.push(row.1); + // } + // assert_eq!(result, vec!["11".to_string()]); +} + +#[tokio::test] +async fn select_geometry() { + // TODO: response type changed to json after + // https://github.com/datafuselabs/databend/pull/15214 +} + #[tokio::test] async fn select_multiple_columns() { let conn = prepare().await; diff --git a/sql/src/error.rs b/sql/src/error.rs index eda306b39..295e11474 100644 --- a/sql/src/error.rs +++ b/sql/src/error.rs @@ -159,6 +159,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: std::str::Utf8Error) -> Self { + Error::Parsing(e.to_string()) + } +} + impl From for Error { fn from(e: std::string::FromUtf8Error) -> Self { Error::Parsing(e.to_string()) diff --git a/sql/src/value.rs b/sql/src/value.rs index 68d2acc4f..be2f6418b 100644 --- 
a/sql/src/value.rs +++ b/sql/src/value.rs @@ -948,8 +948,8 @@ pub fn parse_decimal(text: &str, size: DecimalSize) -> Result { pub fn parse_geometry(raw_data: &[u8]) -> Result { let mut data = std::io::Cursor::new(raw_data); - let wkt = Ewkt::from_wkb(&mut data, WkbDialect::Ewkb); - wkt.map(|g| g.0).map_err(|e| e.into()) + let wkt = Ewkt::from_wkb(&mut data, WkbDialect::Ewkb)?; + Ok(wkt.0) } struct ValueDecoder {} diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 53ec75688..f773bea52 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -7,7 +7,7 @@ services: volumes: - ./data:/data databend: - image: docker.io/datafuselabs/databend:nightly + image: docker.io/datafuselabs/databend environment: - QUERY_STORAGE_TYPE=s3 - AWS_S3_ENDPOINT=http://localhost:9000