Skip to content

Commit

Permalink
feat(driver): support parsing to struct (#190)
Browse files Browse the repository at this point in the history
* refactor: split crates

* feat(driver): support parsing to struct
  • Loading branch information
everpcpc authored Aug 23, 2023
1 parent d6622c3 commit c47a102
Show file tree
Hide file tree
Showing 19 changed files with 402 additions and 91 deletions.
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
[workspace]
default-members = ["core", "driver", "cli"]
default-members = ["core", "sql", "driver", "macros", "cli"]
members = [
"core",
"sql",
"driver",
"macros",
"cli",
"bindings/python",
"bindings/nodejs",
Expand All @@ -20,3 +22,5 @@ repository = "https://github.com/datafuselabs/bendsql"
[workspace.dependencies]
databend-client = { path = "core", version = "0.5.2" }
databend-driver = { path = "driver", version = "0.5.2" }
databend-driver-macros = { path = "macros", version = "0.5.2" }
databend-sql = { path = "sql", version = "0.5.2" }
2 changes: 1 addition & 1 deletion bindings/nodejs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[package]
name = "bindings-nodejs"
name = "databend-nodejs"
publish = false

version = { workspace = true }
Expand Down
11 changes: 8 additions & 3 deletions driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,17 @@ rustls = ["databend-client/rustls"]
# Enable native-tls for TLS support
native-tls = ["databend-client/native-tls"]

flight-sql = ["dep:arrow-array", "dep:arrow-cast", "dep:arrow-flight", "dep:arrow-schema", "dep:tonic"]
flight-sql = [
"dep:arrow-flight",
"dep:arrow-schema",
"dep:tonic",
"databend-sql/flight-sql",
]

[dependencies]
databend-client = { workspace = true }
databend-driver-macros = { workspace = true }
databend-sql = { workspace = true }

async-trait = "0.1.68"
chrono = { version = "0.4.26", default-features = false, features = ["clock"] }
Expand All @@ -35,8 +42,6 @@ tokio-stream = "0.1.14"
url = { version = "2.4.0", default-features = false }

arrow = { version = "41.0.0" }
arrow-array = { version = "41.0.0", optional = true }
arrow-cast = { version = "41.0.0", features = ["prettyprint"], optional = true }
arrow-flight = { version = "41.0.0", features = ["flight-sql-experimental"], optional = true }
arrow-schema = { version = "41.0.0", optional = true }
tonic = { version = "0.9.2", default-features = false, features = [
Expand Down
8 changes: 4 additions & 4 deletions driver/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ use url::Url;
#[cfg(feature = "flight-sql")]
use crate::flight_sql::FlightSQLConnection;

use crate::error::{Error, Result};
use databend_sql::error::{Error, Result};
use databend_sql::rows::{QueryProgress, Row, RowIterator, RowProgressIterator};
use databend_sql::schema::Schema;

use crate::rest_api::RestAPIConnection;
use crate::rows::{Row, RowIterator, RowProgressIterator};
use crate::schema::Schema;
use crate::QueryProgress;

pub struct Client {
dsn: String,
Expand Down
9 changes: 6 additions & 3 deletions driver/src/flight_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ use tonic::transport::{Channel, ClientTlsConfig, Endpoint};
use tonic::Streaming;
use url::Url;

use databend_sql::error::{Error, Result};
use databend_sql::rows::{
QueryProgress, Row, RowIterator, RowProgressIterator, RowWithProgress, Rows,
};
use databend_sql::schema::Schema;

use crate::conn::{Connection, ConnectionInfo, Reader};
use crate::error::{Error, Result};
use crate::rows::{QueryProgress, Row, RowIterator, RowProgressIterator, RowWithProgress, Rows};
use crate::Schema;

#[derive(Clone)]
pub struct FlightSQLConnection {
Expand Down
18 changes: 10 additions & 8 deletions driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
// limitations under the License.

mod conn;
mod error;
#[cfg(feature = "flight-sql")]
mod flight_sql;
mod rest_api;
mod rows;
mod schema;
mod value;

pub use conn::{Client, Connection, ConnectionInfo};
pub use error::Error;
pub use rows::{QueryProgress, Row, RowIterator, RowProgressIterator, RowWithProgress};
pub use schema::{DataType, DecimalSize, Field, Schema, SchemaRef};
pub use value::{NumberValue, Value};

// pub use for convenience
pub use databend_sql::error::Error;
pub use databend_sql::rows::{
QueryProgress, Row, RowIterator, RowProgressIterator, RowWithProgress,
};
pub use databend_sql::schema::{DataType, DecimalSize, Field, Schema, SchemaRef};
pub use databend_sql::value::{NumberValue, Value};

pub use databend_driver_macros::TryFromRow;
8 changes: 4 additions & 4 deletions driver/src/rest_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ use databend_client::response::QueryResponse;
use databend_client::APIClient;
use tokio_stream::{Stream, StreamExt};

use databend_sql::error::{Error, Result};
use databend_sql::rows::{QueryProgress, Row, RowIterator, RowProgressIterator, RowWithProgress};
use databend_sql::schema::{Schema, SchemaRef};

use crate::conn::{Connection, ConnectionInfo, Reader};
use crate::error::{Error, Result};
use crate::rows::{Row, RowIterator, RowProgressIterator, RowWithProgress};
use crate::schema::{Schema, SchemaRef};
use crate::QueryProgress;

#[derive(Clone)]
pub struct RestAPIConnection {
Expand Down
86 changes: 81 additions & 5 deletions driver/tests/driver/select_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use databend_driver::{Client, Connection};
use tokio_stream::StreamExt;

use databend_driver::{Client, Connection};

use crate::common::DEFAULT_DSN;

async fn prepare(name: &str) -> (Box<dyn Connection>, String) {
Expand All @@ -25,9 +26,8 @@ async fn prepare(name: &str) -> (Box<dyn Connection>, String) {
(conn, table)
}

#[tokio::test]
async fn select_iter() {
let (conn, table) = prepare("select_iter").await;
async fn prepare_data(name: &str) -> (Box<dyn Connection>, String) {
let (conn, table) = prepare(name).await;
let sql_create = format!(
"CREATE TABLE `{}` (
i64 Int64,
Expand All @@ -48,6 +48,14 @@ async fn select_iter() {
(-3, 3, 3.0, '3', '2', '2016-04-04', '2016-04-04 11:30:00')",
table
);
conn.exec(&sql_insert).await.unwrap();
(conn, table)
}

#[tokio::test]
async fn select_iter_tuple() {
let (conn, table) = prepare_data("select_iter_tuple").await;

type RowResult = (
i64,
u64,
Expand Down Expand Up @@ -92,7 +100,6 @@ async fn select_iter() {
.naive_utc(),
),
];
conn.exec(&sql_insert).await.unwrap();
let sql_select = format!("SELECT * FROM `{}`", table);
let mut rows = conn.query_iter(&sql_select).await.unwrap();
let mut row_count = 0;
Expand All @@ -107,6 +114,75 @@ async fn select_iter() {
conn.exec(&sql_drop).await.unwrap();
}

#[tokio::test]
async fn select_iter_struct() {
let (conn, table) = prepare_data("select_iter_struct").await;

use databend_driver::TryFromRow;
#[derive(TryFromRow)]
struct RowResult {
i64: i64,
u64: u64,
f64: f64,
s: String,
s2: String,
d: chrono::NaiveDate,
t: chrono::NaiveDateTime,
}

let expected: Vec<RowResult> = vec![
RowResult {
i64: -1,
u64: 1,
f64: 1.0,
s: "1".into(),
s2: "1".into(),
d: chrono::NaiveDate::from_ymd_opt(2011, 3, 6).unwrap(),
t: chrono::DateTime::parse_from_rfc3339("2011-03-06T06:20:00Z")
.unwrap()
.naive_utc(),
},
RowResult {
i64: -2,
u64: 2,
f64: 2.0,
s: "2".into(),
s2: "2".into(),
d: chrono::NaiveDate::from_ymd_opt(2012, 5, 31).unwrap(),
t: chrono::DateTime::parse_from_rfc3339("2012-05-31T11:20:00Z")
.unwrap()
.naive_utc(),
},
RowResult {
i64: -3,
u64: 3,
f64: 3.0,
s: "3".into(),
s2: "2".into(),
d: chrono::NaiveDate::from_ymd_opt(2016, 4, 4).unwrap(),
t: chrono::DateTime::parse_from_rfc3339("2016-04-04T11:30:00Z")
.unwrap()
.naive_utc(),
},
];

let sql_select = format!("SELECT * FROM `{}`", table);
let mut rows = conn.query_iter(&sql_select).await.unwrap();
let mut row_count = 0;
while let Some(row) = rows.next().await {
let v: RowResult = row.unwrap().try_into().unwrap();
let expected_row = &expected[row_count];
assert_eq!(v.i64, expected_row.i64);
assert_eq!(v.u64, expected_row.u64);
assert_eq!(v.f64, expected_row.f64);
assert_eq!(v.s, expected_row.s);
assert_eq!(v.s2, expected_row.s2);
assert_eq!(v.d, expected_row.d);
assert_eq!(v.t, expected_row.t);
row_count += 1;
}
}

#[tokio::test]
async fn select_numbers() {
let (conn, _) = prepare("select_numbers").await;
Expand Down
19 changes: 19 additions & 0 deletions macros/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "databend-driver-macros"
description = "Macros for Databend Driver"
categories = ["database"]
keywords = ["databend", "database", "macros"]

version = { workspace = true }
edition = { workspace = true }
license = { workspace = true }
authors = { workspace = true }
repository = { workspace = true }

[lib]
proc-macro = true

[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = "2.0"
66 changes: 66 additions & 0 deletions macros/src/from_row.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use proc_macro::TokenStream;
use quote::{quote, quote_spanned};
use syn::{spanned::Spanned, DeriveInput};

/// #[derive(TryFromRow)] derives TryFromRow for struct
pub fn from_row_derive(tokens_input: TokenStream) -> TokenStream {
let item = syn::parse::<DeriveInput>(tokens_input).expect("No DeriveInput");
let struct_fields = crate::parser::parse_named_fields(&item, "TryFromRow");

let struct_name = &item.ident;
let (impl_generics, ty_generics, where_clause) = item.generics.split_for_impl();

let set_fields_code = struct_fields.named.iter().map(|field| {
let field_name = &field.ident;
let field_type = &field.ty;

quote_spanned! {field.span() =>
#field_name: {
let (col_ix, col_value) = vals_iter
.next()
.unwrap(); // vals_iter size is checked before this code is reached, so
// it is safe to unwrap
let t = col_value.get_type();

<#field_type>::try_from(col_value)
.map_err(|_| Error::InvalidResponse(format!("failed converting column {} from type({:?}) to type({})", col_ix, t, std::any::type_name::<#field_type>())))?
},
}
});

let fields_count = struct_fields.named.len();
let generated = quote! {
use databend_sql::rows::Row;
use databend_sql::error::{Error, Result};

impl #impl_generics TryFrom<Row> for #struct_name #ty_generics #where_clause {
type Error = Error;
fn try_from(row: Row) -> Result<Self> {
if #fields_count != row.len() {
return Err(Error::InvalidResponse(format!("row size mismatch: expected {} columns, got {}", #fields_count, row.len())));
}
let mut vals_iter = row.into_iter().enumerate();

Ok(#struct_name {
#(#set_fields_code)*
})
}
}
};

TokenStream::from(generated)
}
23 changes: 23 additions & 0 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use proc_macro::TokenStream;

mod from_row;
mod parser;

#[proc_macro_derive(TryFromRow)]
pub fn from_row_derive(tokens_input: TokenStream) -> TokenStream {
from_row::from_row_derive(tokens_input)
}
33 changes: 33 additions & 0 deletions macros/src/parser.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2021 Datafuse Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use syn::{Data, DeriveInput, Fields, FieldsNamed};

/// Parses the tokens_input to a DeriveInput and returns the struct name from which it derives and
/// the named fields
pub(crate) fn parse_named_fields<'a>(
input: &'a DeriveInput,
current_derive: &str,
) -> &'a FieldsNamed {
match &input.data {
Data::Struct(data) => match &data.fields {
Fields::Named(named_fields) => named_fields,
_ => panic!(
"derive({}) works only for structs with named fields. Tuples don't need derive.",
current_derive
),
},
_ => panic!("derive({}) works only on structs!", current_derive),
}
}
Loading

0 comments on commit c47a102

Please sign in to comment.