Skip to content

Commit

Permalink
feat(cli): support setting --role
Browse files Browse the repository at this point in the history
  • Loading branch information
everpcpc committed Nov 20, 2023
1 parent 7718530 commit bea3ccf
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 15 deletions.
9 changes: 9 additions & 0 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ struct Args {

#[clap(short = 'l', default_value = "info", long)]
log_level: String,

#[clap(short = 'r', long, help = "Downgrade role name")]
role: Option<String>,
}

/// Parse a single key-value pair
Expand Down Expand Up @@ -266,6 +269,9 @@ pub async fn main() -> Result<()> {
if args.database.is_some() {
eprintln!("warning: --database is ignored when --dsn is set");
}
if args.role.is_some() {
eprintln!("warning: --role is ignored when --dsn is set");
}
if !args.set.is_empty() {
eprintln!("warning: --set is ignored when --dsn is set");
}
Expand Down Expand Up @@ -293,6 +299,9 @@ pub async fn main() -> Result<()> {
for (k, v) in args.set {
config.connection.args.insert(k, v);
}
if let Some(role) = args.role {
config.connection.args.insert("role".to_string(), role);
}
let conn_args = ConnectionArgs {
host: config.connection.host.clone(),
port: config.connection.port,
Expand Down
25 changes: 10 additions & 15 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ pub struct APIClient {

tenant: Option<String>,
warehouse: Arc<Mutex<Option<String>>>,
database: Arc<Mutex<Option<String>>>,
session_state: Arc<Mutex<SessionState>>,

wait_time_secs: Option<i64>,
Expand Down Expand Up @@ -87,7 +86,7 @@ impl APIClient {
"" => None,
s => Some(s.to_string()),
};
client.database = Arc::new(Mutex::new(database.clone()));
let mut role = None;
let mut scheme = "https";
let mut session_settings = BTreeMap::new();
for (k, v) in u.query_pairs() {
Expand Down Expand Up @@ -125,6 +124,7 @@ impl APIClient {
"warehouse" => {
client.warehouse = Arc::new(Mutex::new(Some(v.to_string())));
}
"role" => role = Some(v.to_string()),
"sslmode" => match v.as_ref() {
"disable" => scheme = "http",
"require" | "enable" => scheme = "https",
Expand Down Expand Up @@ -169,6 +169,7 @@ impl APIClient {
client.session_state = Arc::new(Mutex::new(
SessionState::default()
.with_settings(Some(session_settings))
.with_role(role)
.with_database(database),
));
Ok(client)
Expand All @@ -180,8 +181,13 @@ impl APIClient {
}

pub async fn current_database(&self) -> Option<String> {
let guard = self.database.lock().await;
guard.clone()
let guard = self.session_state.lock().await;
guard.database.clone()
}

pub async fn current_role(&self) -> Option<String> {
let guard = self.session_state.lock().await;
guard.role.clone()
}

fn gen_query_id(&self) -> String {
Expand All @@ -200,12 +206,6 @@ impl APIClient {
*session_state = session.clone();
}

// process database changed via session.db
if session.database.is_some() {
let mut database = self.database.lock().await;
*database = session.database.clone();
}

// process warehouse changed via session settings
if let Some(settings) = session.settings.as_ref() {
if let Some(v) = settings.get("warehouse") {
Expand Down Expand Up @@ -533,7 +533,6 @@ impl Default for APIClient {
port: 8000,
tenant: None,
warehouse: Arc::new(Mutex::new(None)),
database: Arc::new(Mutex::new(None)),
user: "root".to_string(),
password: None,
session_state: Arc::new(Mutex::new(SessionState::default())),
Expand All @@ -559,10 +558,6 @@ mod test {
assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?);
assert_eq!(client.user, "username");
assert_eq!(client.password, Some("password".to_string()));
assert_eq!(
*client.database.try_lock().unwrap(),
Some("test".to_string())
);
assert_eq!(client.wait_time_secs, Some(10));
assert_eq!(client.max_rows_in_buffer, Some(5000000));
assert_eq!(client.max_rows_per_page, Some(10000));
Expand Down

0 comments on commit bea3ccf

Please sign in to comment.