diff --git a/cli/tests/00-base.result b/cli/tests/00-base.result index 3a6b33f85..3c3d7f05d 100644 --- a/cli/tests/00-base.result +++ b/cli/tests/00-base.result @@ -5,4 +5,5 @@ a 1 true 2 3 with comment +3.00 3.00 bye diff --git a/cli/tests/00-base.sql b/cli/tests/00-base.sql index a932ae7c7..19a83b5a8 100644 --- a/cli/tests/00-base.sql +++ b/cli/tests/00-base.sql @@ -17,5 +17,7 @@ select /* ignore this block */ 'with comment'; select 'in comment block'; */ +select 1.00 + 2.00, 3.00; + select 'bye'; drop table test; diff --git a/driver/Cargo.toml b/driver/Cargo.toml index f444d5d63..02e907660 100644 --- a/driver/Cargo.toml +++ b/driver/Cargo.toml @@ -16,7 +16,7 @@ rustls = ["databend-client/rustls"] # Enable native-tls for TLS support native-tls = ["databend-client/native-tls"] -flight-sql = ["dep:arrow", "dep:arrow-array", "dep:arrow-cast", "dep:arrow-flight", "dep:arrow-schema", "dep:tonic"] +flight-sql = ["dep:arrow-array", "dep:arrow-cast", "dep:arrow-flight", "dep:arrow-schema", "dep:tonic"] [dependencies] async-trait = "0.1.68" @@ -30,7 +30,7 @@ tokio = { version = "1.27.0", features = ["macros"] } tokio-stream = "0.1.12" url = { version = "2.3.1", default-features = false } -arrow = { version = "38.0.0", optional = true } +arrow = { version = "38.0.0" } arrow-array = { version = "38.0.0", optional = true } arrow-cast = { version = "38.0.0", features = ["prettyprint"], optional = true } arrow-flight = { version = "38.0.0", features = ["flight-sql-experimental"], optional = true } diff --git a/driver/src/lib.rs b/driver/src/lib.rs index fa8dd9674..1793ba0b8 100644 --- a/driver/src/lib.rs +++ b/driver/src/lib.rs @@ -22,5 +22,5 @@ mod value; pub use conn::{new_connection, Connection}; pub use rows::{QueryProgress, Row, RowIterator, RowProgressIterator, RowWithProgress}; -pub use schema::{DataType, Schema, SchemaRef}; +pub use schema::{DataType, DecimalSize, Schema, SchemaRef}; pub use value::{NumberValue, Value}; diff --git a/driver/src/schema.rs b/driver/src/schema.rs index 3e0a4c136..c6879b622 100644 --- a/driver/src/schema.rs +++ b/driver/src/schema.rs @@ -35,17 +35,26 @@ pub enum NumberDataType { Float64, } -// #[derive(Debug, Clone, PartialEq, Eq)] -// pub struct DecimalSize { -// pub precision: u8, -// pub scale: u8, -// } - -// #[derive(Debug, Clone, PartialEq, Eq)] -// pub enum DecimalDataType { -// Decimal128(DecimalSize), -// Decimal256(DecimalSize), -// } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct DecimalSize { + pub precision: u8, + pub scale: u8, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DecimalDataType { + Decimal128(DecimalSize), + Decimal256(DecimalSize), +} + +impl DecimalDataType { + pub fn decimal_size(&self) -> &DecimalSize { + match self { + DecimalDataType::Decimal128(size) => size, + DecimalDataType::Decimal256(size) => size, + } + } +} #[derive(Debug, Clone)] pub enum DataType { @@ -55,9 +64,7 @@ pub enum DataType { Boolean, String, Number(NumberDataType), - Decimal, - // TODO:(everpcpc) fix Decimal internal type - // Decimal(DecimalDataType), + Decimal(DecimalDataType), Timestamp, Date, Nullable(Box), @@ -98,7 +105,10 @@ impl std::fmt::Display for DataType { NumberDataType::Float32 => write!(f, "Float32"), NumberDataType::Float64 => write!(f, "Float64"), }, - DataType::Decimal => write!(f, "Decimal"), + DataType::Decimal(d) => { + let size = d.decimal_size(); + write!(f, "Decimal({}, {})", size.precision, size.scale) + } DataType::Timestamp => write!(f, "Timestamp"), DataType::Date => write!(f, "Date"), DataType::Nullable(inner) => write!(f, "Nullable({})", inner), @@ -152,7 +162,22 @@ impl TryFrom<&TypeDesc<'_>> for DataType { "UInt64" => DataType::Number(NumberDataType::UInt64), "Float32" => DataType::Number(NumberDataType::Float32), "Float64" => DataType::Number(NumberDataType::Float64), - "Decimal" => DataType::Decimal, + "Decimal" => { + let precision = desc.args[0].name.parse::()?; + let scale = desc.args[1].name.parse::()?; + + if precision <= 38 { + DataType::Decimal(DecimalDataType::Decimal128(DecimalSize { + precision, + scale, + })) + } else { + DataType::Decimal(DecimalDataType::Decimal256(DecimalSize { + precision, + scale, + })) + } + } "Timestamp" => DataType::Timestamp, "Date" => DataType::Date, "Nullable" => { @@ -247,8 +272,18 @@ impl TryFrom<&Arc> for Field { | ArrowDataType::FixedSizeBinary(_) => DataType::String, ArrowDataType::Timestamp(_, _) => DataType::Timestamp, ArrowDataType::Date32 => DataType::Date, - ArrowDataType::Decimal128(_, _) => DataType::Decimal, - ArrowDataType::Decimal256(_, _) => DataType::Decimal, + ArrowDataType::Decimal128(p, s) => { + DataType::Decimal(DecimalDataType::Decimal128(DecimalSize { + precision: *p, + scale: *s as u8, + })) + } + ArrowDataType::Decimal256(p, s) => { + DataType::Decimal(DecimalDataType::Decimal256(DecimalSize { + precision: *p, + scale: *s as u8, + })) + } _ => { return Err(Error::Parsing(format!( "Unsupported datatype for arrow field: {:?}", @@ -356,6 +391,23 @@ mod test { args: vec![], }, }, + TestCase { + desc: "decimal type", + input: "Decimal(42, 42)", + output: TypeDesc { + name: "Decimal", + args: vec![ + TypeDesc { + name: "42", + args: vec![], + }, + TypeDesc { + name: "42", + args: vec![], + }, + ], + }, + }, TestCase { desc: "nullable type", input: "Nullable(Nothing)", diff --git a/driver/src/value.rs b/driver/src/value.rs index dc903593f..abb3ac85b 100644 --- a/driver/src/value.rs +++ b/driver/src/value.rs @@ -12,9 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +use arrow::datatypes::i256; +use arrow_array::{ArrowNativeTypeOp, Decimal128Array, Decimal256Array}; use chrono::{Datelike, NaiveDate, NaiveDateTime}; -use crate::error::{ConvertError, Error, Result}; +use crate::{ + error::{ConvertError, Error, Result}, + schema::{DecimalDataType, DecimalSize}, +}; +use std::fmt::Write; // Thu 1970-01-01 is R.D. 719163 const DAYS_FROM_CE: i32 = 719_163; @@ -45,16 +51,16 @@ pub enum NumberValue { UInt64(u64), Float32(f32), Float64(f64), + Decimal128(i128, DecimalSize), + Decimal256(i256, DecimalSize), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq)] pub enum Value { Null, Boolean(bool), String(String), Number(NumberValue), - // TODO:(everpcpc) Decimal(DecimalValue), - // Decimal(String), /// Microseconds from 1970-01-01 00:00:00 UTC Timestamp(i64), Date(i32), @@ -82,9 +88,11 @@ impl Value { NumberValue::UInt64(_) => DataType::Number(NumberDataType::UInt64), NumberValue::Float32(_) => DataType::Number(NumberDataType::Float32), NumberValue::Float64(_) => DataType::Number(NumberDataType::Float64), + NumberValue::Decimal128(_, s) => DataType::Decimal(DecimalDataType::Decimal128(*s)), + NumberValue::Decimal256(_, s) => DataType::Decimal(DecimalDataType::Decimal256(*s)), }, - // Self::Decimal(_) => DataType::Decimal, Self::Timestamp(_) => DataType::Timestamp, + Self::Date(_) => DataType::Date, // TODO:(everpcpc) fix nested type // Self::Array(v) => DataType::Array(Box::new(v[0].get_type())), @@ -133,7 +141,17 @@ impl TryFrom<(&DataType, &str)> for Value { DataType::Number(NumberDataType::Float64) => { Ok(Self::Number(NumberValue::Float64(v.parse()?))) } - // DataType::Decimal => Ok(Self::Decimal(v)), + + DataType::Decimal(DecimalDataType::Decimal128(size)) => { + let d = parse_decimal(v, *size)?; + Ok(Self::Number(d)) + } + + DataType::Decimal(DecimalDataType::Decimal256(size)) => { + let d = parse_decimal(v, *size)?; + Ok(Self::Number(d)) + } + DataType::Timestamp => Ok(Self::Timestamp( chrono::NaiveDateTime::parse_from_str(v, "%Y-%m-%d %H:%M:%S%.6f")? .timestamp_micros(), @@ -210,6 +228,32 @@ impl TryFrom<(&ArrowField, &Arc, usize)> for Value { None => Err(ConvertError::new("float64", format!("{:?}", array)).into()), }, + ArrowDataType::Decimal128(p, s) => { + match array.as_any().downcast_ref::() { + Some(array) => Ok(Value::Number(NumberValue::Decimal128( + array.value(seq), + DecimalSize { + precision: *p, + scale: *s as u8, + }, + ))), + None => Err(ConvertError::new("Decimal128", format!("{:?}", array)).into()), + } + } + + ArrowDataType::Decimal256(p, s) => { + match array.as_any().downcast_ref::() { + Some(array) => Ok(Value::Number(NumberValue::Decimal256( + array.value(seq), + DecimalSize { + precision: *p, + scale: *s as u8, + }, + ))), + None => Err(ConvertError::new("Decimal256", format!("{:?}", array)).into()), + } + } + ArrowDataType::Binary => match array.as_any().downcast_ref::() { Some(array) => Ok(Value::String(String::from_utf8(array.value(seq).to_vec())?)), None => Err(ConvertError::new("binary", format!("{:?}", array)).into()), @@ -418,6 +462,8 @@ impl std::fmt::Display for NumberValue { NumberValue::UInt64(i) => write!(f, "{}", i), NumberValue::Float32(i) => write!(f, "{}", i), NumberValue::Float64(i) => write!(f, "{}", i), + NumberValue::Decimal128(v, s) => write!(f, "{}", display_decimal_128(*v, s.scale)), + NumberValue::Decimal256(v, s) => write!(f, "{}", display_decimal_256(*v, s.scale)), } } } @@ -443,3 +489,116 @@ impl std::fmt::Display for Value { } } } + +pub fn display_decimal_128(num: i128, scale: u8) -> String { + let mut buf = String::new(); + if scale == 0 { + write!(buf, "{}", num).unwrap(); + } else { + let pow_scale = 10_i128.pow(scale as u32); + if num >= 0 { + write!( + buf, + "{}.{:0>width$}", + num / pow_scale, + (num % pow_scale).abs(), + width = scale as usize + ) + .unwrap(); + } else { + write!( + buf, + "-{}.{:0>width$}", + -num / pow_scale, + (num % pow_scale).abs(), + width = scale as usize + ) + .unwrap(); + } + } + buf +} + +pub fn display_decimal_256(num: i256, scale: u8) -> String { + let mut buf = String::new(); + if scale == 0 { + write!(buf, "{}", num).unwrap(); + } else { + let pow_scale = i256::from_i128(10i128).pow_wrapping(scale as u32); + // -1/10 = 0 + if num >= i256::ZERO { + write!( + buf, + "{}.{:0>width$}", + num / pow_scale, + (num % pow_scale).wrapping_abs(), + width = scale as usize + ) + .unwrap(); + } else { + write!( + buf, + "-{}.{:0>width$}", + -num / pow_scale, + (num % pow_scale).wrapping_abs(), + width = scale as usize + ) + .unwrap(); + } + } + buf +} + +/// assume text is from +/// used only for expr, so put more weight on readability +pub fn parse_decimal(text: &str, size: DecimalSize) -> Result { + let mut start = 0; + let bytes = text.as_bytes(); + while bytes[start] == b'0' { + start += 1 + } + let text = &text[start..]; + let point_pos = text.find('.'); + let e_pos = text.find(|c| c == 'e' || c == 'E'); + let (i_part, f_part, e_part) = match (point_pos, e_pos) { + (Some(p1), Some(p2)) => (&text[..p1], &text[(p1 + 1)..p2], Some(&text[(p2 + 1)..])), + (Some(p), None) => (&text[..p], &text[(p + 1)..], None), + (None, Some(p)) => (&text[..p], "", Some(&text[(p + 1)..])), + _ => { + unreachable!() + } + }; + let exp = match e_part { + Some(s) => s.parse::()?, + None => 0, + }; + if i_part.len() as i32 + exp > 76 { + Err(ConvertError::new("decimal", format!("{:?}", text)).into()) + } else { + let mut digits = Vec::with_capacity(76); + digits.extend_from_slice(i_part.as_bytes()); + digits.extend_from_slice(f_part.as_bytes()); + if digits.is_empty() { + digits.push(b'0') + } + let scale = f_part.len() as i32 - exp; + if scale < 0 { + // e.g 123.1e3 + for _ in 0..(-scale) { + digits.push(b'0') + } + }; + + let precision = std::cmp::min(digits.len(), 76); + let digits = unsafe { std::str::from_utf8_unchecked(&digits[..precision]) }; + + if size.precision > 38 { + Ok(NumberValue::Decimal256( + i256::from_string(digits).unwrap(), + size, + )) + } else { + Ok(NumberValue::Decimal128(digits.parse::()?, size)) + } + } +} diff --git a/driver/tests/driver/select_simple.rs b/driver/tests/driver/select_simple.rs index 8af5cdc80..b9b47a386 100644 --- a/driver/tests/driver/select_simple.rs +++ b/driver/tests/driver/select_simple.rs @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::assert_eq; + use chrono::{DateTime, NaiveDate, NaiveDateTime}; -use databend_driver::{new_connection, Connection}; +use databend_driver::{new_connection, Connection, DecimalSize, NumberValue, Value}; use crate::common::DEFAULT_DSN; @@ -118,6 +120,36 @@ async fn select_datetime() { } } +#[tokio::test] +async fn select_decimal() { + let conn = prepare().await; + let row = conn + .query_row("select 1::Decimal(15,2), 2.0 + 3.0") + .await + .unwrap(); + assert!(row.is_some()); + let row = row.unwrap(); + assert_eq!( + row.values().to_owned(), + vec![ + Value::Number(NumberValue::Decimal128( + 100i128, + DecimalSize { + precision: 15, + scale: 2 + }, + )), + Value::Number(NumberValue::Decimal128( + 50i128, + DecimalSize { + precision: 3, + scale: 1 + }, + )), + ] + ); +} + #[tokio::test] async fn select_nullable() { let conn = prepare().await;