Skip to content

Commit

Permalink
fix: change read cert to async (#175)
Browse files Browse the repository at this point in the history
* feat(driver): wrap Connection with Client

* fix: change read cert to async

* fix: add license header for pyi

* fix: update nodejs binding
  • Loading branch information
everpcpc authored Jul 28, 2023
1 parent f19733f commit e3e83d4
Show file tree
Hide file tree
Showing 19 changed files with 133 additions and 78 deletions.
3 changes: 2 additions & 1 deletion bindings/nodejs/generated.js
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,10 @@ if (!nativeBinding) {
throw new Error(`Failed to load native binding`)
}

const { Client, ConnectionInfo, Schema, Field, RowIterator, RowIteratorExt, RowOrProgress, Row, QueryProgress } = nativeBinding
const { Client, Connection, ConnectionInfo, Schema, Field, RowIterator, RowIteratorExt, RowOrProgress, Row, QueryProgress } = nativeBinding

module.exports.Client = Client
module.exports.Connection = Connection
module.exports.ConnectionInfo = ConnectionInfo
module.exports.Schema = Schema
module.exports.Field = Field
Expand Down
4 changes: 4 additions & 0 deletions bindings/nodejs/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
export class Client {
/** Create a new databend client with a given DSN. */
constructor(dsn: string)
/** Get a connection from the client. */
getConn(): Promise<Connection>
}
export class Connection {
/** Get the connection information. */
info(): Promise<ConnectionInfo>
/** Get the databend version. */
Expand Down
24 changes: 20 additions & 4 deletions bindings/nodejs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ use futures::StreamExt;
use napi::bindgen_prelude::*;

#[napi]
pub struct Client(Box<dyn databend_driver::Connection>);
pub struct Client(databend_driver::Client);

#[napi]
pub struct Connection(Box<dyn databend_driver::Connection>);

#[napi]
pub struct ConnectionInfo(databend_driver::ConnectionInfo);
Expand Down Expand Up @@ -267,11 +270,24 @@ impl QueryProgress {
impl Client {
/// Create a new databend client with a given DSN.
#[napi(constructor)]
pub fn new(dsn: String) -> Result<Self> {
let conn = databend_driver::new_connection(&dsn).map_err(format_napi_error)?;
Ok(Self(conn))
pub fn new(dsn: String) -> Self {
let client = databend_driver::Client::new(dsn);
Self(client)
}

/// Get a connection from the client.
#[napi]
pub async fn get_conn(&self) -> Result<Connection> {
self.0
.get_conn()
.await
.map(|conn| Connection(conn))
.map_err(format_napi_error)
}
}

#[napi]
impl Connection {
/// Get the connection information.
#[napi]
pub async fn info(&self) -> ConnectionInfo {
Expand Down
19 changes: 10 additions & 9 deletions bindings/nodejs/tests/binding.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,19 @@ const dsn = process.env.TEST_DATABEND_DSN
? process.env.TEST_DATABEND_DSN
: "databend://root:@localhost:8000/default?sslmode=disable";

Given("A new Databend Driver Client", function () {
Given("A new Databend Driver Client", async function () {
this.client = new Client(dsn);
this.conn = await this.client.getConn();
});

Then("Select string {string} should be equal to {string}", async function (input, output) {
const row = await this.client.queryRow(`SELECT '${input}'`);
const row = await this.conn.queryRow(`SELECT '${input}'`);
const value = row.values()[0];
assert.equal(output, value);
});

Then("Select numbers should iterate all rows", async function () {
let rows = await this.client.queryIter("SELECT number FROM numbers(5)");
let rows = await this.conn.queryIter("SELECT number FROM numbers(5)");
let ret = [];
let row = await rows.next();
while (row) {
Expand All @@ -45,8 +46,8 @@ Then("Select numbers should iterate all rows", async function () {
});

When("Create a test table", async function () {
await this.client.exec("DROP TABLE IF EXISTS test");
await this.client.exec(`CREATE TABLE test (
await this.conn.exec("DROP TABLE IF EXISTS test");
await this.conn.exec(`CREATE TABLE test (
i64 Int64,
u64 UInt64,
f64 Float64,
Expand All @@ -58,11 +59,11 @@ When("Create a test table", async function () {
});

Then("Insert and Select should be equal", async function () {
await this.client.exec(`INSERT INTO test VALUES
await this.conn.exec(`INSERT INTO test VALUES
(-1, 1, 1.0, '1', '1', '2011-03-06', '2011-03-06 06:20:00'),
(-2, 2, 2.0, '2', '2', '2012-05-31', '2012-05-31 11:20:00'),
(-3, 3, 3.0, '3', '2', '2016-04-04', '2016-04-04 11:30:00')`);
const rows = await this.client.queryIter("SELECT * FROM test");
const rows = await this.conn.queryIter("SELECT * FROM test");
const ret = [];
let row = await rows.next();
while (row) {
Expand All @@ -83,11 +84,11 @@ Then("Stream load and Select should be equal", async function () {
["-2", "2", "2.0", "2", "2", "2012-05-31", "2012-05-31T11:20:00Z"],
["-3", "3", "3.0", "3", "2", "2016-04-04", "2016-04-04T11:30:00Z"],
];
const progress = await this.client.streamLoad(`INSERT INTO test VALUES`, values);
const progress = await this.conn.streamLoad(`INSERT INTO test VALUES`, values);
assert.equal(progress.writeRows, 3);
assert.equal(progress.writeBytes, 178);

const rows = await this.client.queryIter("SELECT * FROM test");
const rows = await this.conn.queryIter("SELECT * FROM test");
const ret = [];
let row = await rows.next();
while (row) {
Expand Down
14 changes: 14 additions & 0 deletions bindings/python/python/databend_driver/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright 2021 Datafuse Labs
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

class AsyncDatabendDriver:
def __init__(self, dsn: str): ... # NOQA
async def exec(self, sql: str) -> int: ... # NOQA
5 changes: 3 additions & 2 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ mod asyncio;

use crate::asyncio::*;

use databend_driver::{new_connection, Connection};
use databend_driver::{Client, Connection};

use pyo3::create_exception;
use pyo3::exceptions::PyException;
Expand All @@ -39,7 +39,8 @@ pub type FusedConnector = Arc<dyn Connection>;
// For bindings
impl Connector {
pub fn new_connector(dsn: &str) -> Result<Box<Self>, Error> {
let conn = new_connection(dsn).unwrap();
let client = Client::new(dsn.to_string());
let conn = futures::executor::block_on(client.get_conn()).unwrap();
let r = Self {
connector: FusedConnector::from(conn),
};
Expand Down
11 changes: 6 additions & 5 deletions cli/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::Arc;

use anyhow::anyhow;
use anyhow::Result;
use databend_driver::{new_connection, Connection};
use databend_driver::{Client, Connection};
use rustyline::config::Builder;
use rustyline::error::ReadlineError;
use rustyline::history::DefaultHistory;
Expand All @@ -36,7 +36,7 @@ use crate::helper::CliHelper;
use crate::VERSION;

pub struct Session {
dsn: String,
client: Client,
conn: Box<dyn Connection>,
is_repl: bool,

Expand All @@ -47,7 +47,8 @@ pub struct Session {

impl Session {
pub async fn try_new(dsn: String, settings: Settings, is_repl: bool) -> Result<Self> {
let conn = new_connection(&dsn)?;
let client = Client::new(dsn);
let conn = client.get_conn().await?;
let info = conn.info().await;
if is_repl {
println!("Welcome to BendSQL {}.", VERSION.as_str());
Expand All @@ -61,7 +62,7 @@ impl Session {
}

Ok(Self {
dsn,
client,
conn,
is_repl,
settings,
Expand Down Expand Up @@ -366,7 +367,7 @@ impl Session {
}

async fn reconnect(&mut self) -> Result<()> {
self.conn = new_connection(&self.dsn)?;
self.conn = self.client.get_conn().await?;
if self.is_repl {
let info = self.conn.info().await;
eprintln!(
Expand Down
22 changes: 11 additions & 11 deletions core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ pub struct APIClient {
}

impl APIClient {
pub fn from_dsn(dsn: &str) -> Result<Self> {
pub async fn from_dsn(dsn: &str) -> Result<Self> {
let u = Url::parse(dsn)?;
let mut client = Self::default();
if let Some(host) = u.host_str() {
Expand Down Expand Up @@ -166,7 +166,7 @@ impl APIClient {
#[cfg(any(feature = "rustls", feature = "native-tls"))]
if scheme == "https" {
if let Some(ref ca_file) = client.tls_ca_file {
let cert_pem = std::fs::read(ca_file)?;
let cert_pem = tokio::fs::read(ca_file).await?;
let cert = reqwest::Certificate::from_pem(&cert_pem)?;
client.cli = HttpClient::builder().add_root_certificate(cert).build()?;
}
Expand Down Expand Up @@ -543,10 +543,10 @@ impl Default for APIClient {
mod test {
use super::*;

#[test]
fn parse_dsn() -> Result<()> {
#[tokio::test]
async fn parse_dsn() -> Result<()> {
let dsn = "databend://username:password@app.databend.com/test?wait_time_secs=10&max_rows_in_buffer=5000000&max_rows_per_page=10000&warehouse=wh&sslmode=disable";
let client = APIClient::from_dsn(dsn)?;
let client = APIClient::from_dsn(dsn).await?;
assert_eq!(client.host, "app.databend.com");
assert_eq!(client.endpoint, Url::parse("http://app.databend.com:80")?);
assert_eq!(client.user, "username");
Expand All @@ -566,18 +566,18 @@ mod test {
Ok(())
}

#[test]
fn parse_encoded_password() -> Result<()> {
#[tokio::test]
async fn parse_encoded_password() -> Result<()> {
let dsn = "databend://username:3a%40SC(nYE1k%3D%7B%7BR@localhost";
let client = APIClient::from_dsn(dsn)?;
let client = APIClient::from_dsn(dsn).await?;
assert_eq!(client.password, Some("3a@SC(nYE1k={{R".to_string()));
Ok(())
}

#[test]
fn parse_special_chars_password() -> Result<()> {
#[tokio::test]
async fn parse_special_chars_password() -> Result<()> {
let dsn = "databend://username:3a@SC(nYE1k={{R@localhost:8000";
let client = APIClient::from_dsn(dsn)?;
let client = APIClient::from_dsn(dsn).await?;
assert_eq!(client.password, Some("3a@SC(nYE1k={{R".to_string()));
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion core/tests/core/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::common::DEFAULT_DSN;
#[tokio::test]
async fn select_simple() {
let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN);
let client = APIClient::from_dsn(dsn).unwrap();
let client = APIClient::from_dsn(dsn).await.unwrap();
let resp = client.query("select 15532").await.unwrap();
assert_eq!(resp.data, [["15532"]]);
}
6 changes: 4 additions & 2 deletions core/tests/core/stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ use crate::common::DEFAULT_DSN;
async fn insert_with_stage(presigned: bool) {
let dsn = option_env!("TEST_DATABEND_DSN").unwrap_or(DEFAULT_DSN);
let client = if presigned {
APIClient::from_dsn(dsn).unwrap()
APIClient::from_dsn(dsn).await.unwrap()
} else {
APIClient::from_dsn(&format!("{}&presigned_url_disabled=1", dsn)).unwrap()
APIClient::from_dsn(&format!("{}&presigned_url_disabled=1", dsn))
.await
.unwrap()
};

let file = File::open("tests/core/data/sample.csv").await.unwrap();
Expand Down
7 changes: 4 additions & 3 deletions driver/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ Databend unified SQL client for RestAPI and FlightSQL
### exec

```rust
use databend_driver::new_connection;
use databend_driver::Client;

let dsn = "databend://root:@localhost:8000/default?sslmode=disable";
let conn = new_connection(dsn).unwrap();
let dsn = "databend://root:@localhost:8000/default?sslmode=disable".to_string();
let client = Client::new(dsn);
let conn = client.get_conn().await.unwrap();

let sql_create = "CREATE TABLE books (
title VARCHAR,
Expand Down
48 changes: 29 additions & 19 deletions driver/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,35 @@ use crate::rows::{Row, RowIterator, RowProgressIterator};
use crate::schema::Schema;
use crate::QueryProgress;

pub struct Client {
dsn: String,
}

impl<'c> Client {
pub fn new(dsn: String) -> Self {
Self { dsn }
}

pub async fn get_conn(&self) -> Result<Box<dyn Connection>> {
let u = Url::parse(&self.dsn)?;
match u.scheme() {
"databend" | "databend+http" | "databend+https" => {
let conn = RestAPIConnection::try_create(&self.dsn).await?;
Ok(Box::new(conn))
}
#[cfg(feature = "flight-sql")]
"databend+flight" | "databend+grpc" => {
let conn = FlightSQLConnection::try_create(&self.dsn).await?;
Ok(Box::new(conn))
}
_ => Err(Error::Parsing(format!(
"Unsupported scheme: {}",
u.scheme()
))),
}
}
}

pub struct ConnectionInfo {
pub handler: String,
pub host: String,
Expand Down Expand Up @@ -70,22 +99,3 @@ pub trait Connection: DynClone + Send + Sync {
) -> Result<QueryProgress>;
}
dyn_clone::clone_trait_object!(Connection);

pub fn new_connection(dsn: &str) -> Result<Box<dyn Connection>> {
let u = Url::parse(dsn)?;
match u.scheme() {
"databend" | "databend+http" | "databend+https" => {
let conn = RestAPIConnection::try_create(dsn)?;
Ok(Box::new(conn))
}
#[cfg(feature = "flight-sql")]
"databend+flight" | "databend+grpc" => {
let conn = FlightSQLConnection::try_create(dsn)?;
Ok(Box::new(conn))
}
_ => Err(Error::Parsing(format!(
"Unsupported scheme: {}",
u.scheme()
))),
}
}
8 changes: 4 additions & 4 deletions driver/src/flight_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ impl Connection for FlightSQLConnection {
}

impl FlightSQLConnection {
pub fn try_create(dsn: &str) -> Result<Self> {
let (args, endpoint) = Self::parse_dsn(dsn)?;
pub async fn try_create(dsn: &str) -> Result<Self> {
let (args, endpoint) = Self::parse_dsn(dsn).await?;
let channel = endpoint.connect_lazy();
let mut client = FlightSqlServiceClient::new(channel);
// enable progress
Expand Down Expand Up @@ -137,7 +137,7 @@ impl FlightSQLConnection {
Ok(())
}

fn parse_dsn(dsn: &str) -> Result<(Args, Endpoint)> {
async fn parse_dsn(dsn: &str) -> Result<(Args, Endpoint)> {
let u = Url::parse(dsn)?;
let args = Args::from_url(&u)?;
let mut endpoint = Endpoint::new(args.uri.clone())?
Expand All @@ -153,7 +153,7 @@ impl FlightSQLConnection {
let tls_config = match args.tls_ca_file {
None => ClientTlsConfig::new(),
Some(ref ca_file) => {
let pem = std::fs::read(ca_file)?;
let pem = tokio::fs::read(ca_file).await?;
let cert = tonic::transport::Certificate::from_pem(pem);
ClientTlsConfig::new().ca_certificate(cert)
}
Expand Down
2 changes: 1 addition & 1 deletion driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ mod rows;
mod schema;
mod value;

pub use conn::{new_connection, Connection, ConnectionInfo};
pub use conn::{Client, Connection, ConnectionInfo};
pub use error::Error;
pub use rows::{QueryProgress, Row, RowIterator, RowProgressIterator, RowWithProgress};
pub use schema::{DataType, DecimalSize, Field, Schema, SchemaRef};
Expand Down
Loading

0 comments on commit e3e83d4

Please sign in to comment.