Skip to content

Commit

Permalink
feat(query): support decimal (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy-li authored May 16, 2023
1 parent 027f639 commit b3d6e3e
Show file tree
Hide file tree
Showing 7 changed files with 274 additions and 28 deletions.
1 change: 1 addition & 0 deletions cli/tests/00-base.result
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ a 1 true
2
3
with comment
3.00 3.00
bye
2 changes: 2 additions & 0 deletions cli/tests/00-base.sql
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ select /* ignore this block */ 'with comment';
select 'in comment block';
*/

select 1.00 + 2.00, 3.00;

select 'bye';
drop table test;
4 changes: 2 additions & 2 deletions driver/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ rustls = ["databend-client/rustls"]
# Enable native-tls for TLS support
native-tls = ["databend-client/native-tls"]

flight-sql = ["dep:arrow", "dep:arrow-array", "dep:arrow-cast", "dep:arrow-flight", "dep:arrow-schema", "dep:tonic"]
flight-sql = ["dep:arrow-array", "dep:arrow-cast", "dep:arrow-flight", "dep:arrow-schema", "dep:tonic"]

[dependencies]
async-trait = "0.1.68"
Expand All @@ -30,7 +30,7 @@ tokio = { version = "1.27.0", features = ["macros"] }
tokio-stream = "0.1.12"
url = { version = "2.3.1", default-features = false }

arrow = { version = "38.0.0", optional = true }
arrow = { version = "38.0.0" }
arrow-array = { version = "38.0.0", optional = true }
arrow-cast = { version = "38.0.0", features = ["prettyprint"], optional = true }
arrow-flight = { version = "38.0.0", features = ["flight-sql-experimental"], optional = true }
Expand Down
2 changes: 1 addition & 1 deletion driver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ mod value;

pub use conn::{new_connection, Connection};
pub use rows::{QueryProgress, Row, RowIterator, RowProgressIterator, RowWithProgress};
pub use schema::{DataType, Schema, SchemaRef};
pub use schema::{DataType, DecimalSize, Schema, SchemaRef};
pub use value::{NumberValue, Value};
88 changes: 70 additions & 18 deletions driver/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,26 @@ pub enum NumberDataType {
Float64,
}

// #[derive(Debug, Clone, PartialEq, Eq)]
// pub struct DecimalSize {
// pub precision: u8,
// pub scale: u8,
// }

// #[derive(Debug, Clone, PartialEq, Eq)]
// pub enum DecimalDataType {
// Decimal128(DecimalSize),
// Decimal256(DecimalSize),
// }
/// Precision and scale of a decimal type, i.e. the `p` and `s` of `Decimal(p, s)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DecimalSize {
    /// First argument of the `Decimal(p, s)` type descriptor.
    pub precision: u8,
    /// Second argument of the `Decimal(p, s)` type descriptor; the number of
    /// digits rendered after the decimal point when a value is displayed.
    pub scale: u8,
}

/// A decimal type tagged with the integer width backing its values,
/// carrying the [`DecimalSize`] declared for the column.
///
/// When a type is parsed from a textual descriptor, precisions up to 38 map
/// to `Decimal128` and larger precisions map to `Decimal256`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecimalDataType {
    /// Decimal whose values are carried as `i128`.
    Decimal128(DecimalSize),
    /// Decimal whose values are carried as `i256`.
    Decimal256(DecimalSize),
}

impl DecimalDataType {
    /// Returns the precision/scale pair, regardless of which width variant
    /// is wrapping it.
    pub fn decimal_size(&self) -> &DecimalSize {
        // Both variants hold the same payload type, so one or-pattern covers
        // them while staying exhaustive.
        let (DecimalDataType::Decimal128(inner) | DecimalDataType::Decimal256(inner)) = self;
        inner
    }
}

#[derive(Debug, Clone)]
pub enum DataType {
Expand All @@ -55,9 +64,7 @@ pub enum DataType {
Boolean,
String,
Number(NumberDataType),
Decimal,
// TODO:(everpcpc) fix Decimal internal type
// Decimal(DecimalDataType),
Decimal(DecimalDataType),
Timestamp,
Date,
Nullable(Box<DataType>),
Expand Down Expand Up @@ -98,7 +105,10 @@ impl std::fmt::Display for DataType {
NumberDataType::Float32 => write!(f, "Float32"),
NumberDataType::Float64 => write!(f, "Float64"),
},
DataType::Decimal => write!(f, "Decimal"),
DataType::Decimal(d) => {
let size = d.decimal_size();
write!(f, "Decimal({}, {})", size.precision, size.scale)
}
DataType::Timestamp => write!(f, "Timestamp"),
DataType::Date => write!(f, "Date"),
DataType::Nullable(inner) => write!(f, "Nullable({})", inner),
Expand Down Expand Up @@ -152,7 +162,22 @@ impl TryFrom<&TypeDesc<'_>> for DataType {
"UInt64" => DataType::Number(NumberDataType::UInt64),
"Float32" => DataType::Number(NumberDataType::Float32),
"Float64" => DataType::Number(NumberDataType::Float64),
"Decimal" => DataType::Decimal,
"Decimal" => {
let precision = desc.args[0].name.parse::<u8>()?;
let scale = desc.args[1].name.parse::<u8>()?;

if precision <= 38 {
DataType::Decimal(DecimalDataType::Decimal128(DecimalSize {
precision,
scale,
}))
} else {
DataType::Decimal(DecimalDataType::Decimal256(DecimalSize {
precision,
scale,
}))
}
}
"Timestamp" => DataType::Timestamp,
"Date" => DataType::Date,
"Nullable" => {
Expand Down Expand Up @@ -247,8 +272,18 @@ impl TryFrom<&Arc<ArrowField>> for Field {
| ArrowDataType::FixedSizeBinary(_) => DataType::String,
ArrowDataType::Timestamp(_, _) => DataType::Timestamp,
ArrowDataType::Date32 => DataType::Date,
ArrowDataType::Decimal128(_, _) => DataType::Decimal,
ArrowDataType::Decimal256(_, _) => DataType::Decimal,
ArrowDataType::Decimal128(p, s) => {
DataType::Decimal(DecimalDataType::Decimal128(DecimalSize {
precision: *p,
scale: *s as u8,
}))
}
ArrowDataType::Decimal256(p, s) => {
DataType::Decimal(DecimalDataType::Decimal256(DecimalSize {
precision: *p,
scale: *s as u8,
}))
}
_ => {
return Err(Error::Parsing(format!(
"Unsupported datatype for arrow field: {:?}",
Expand Down Expand Up @@ -356,6 +391,23 @@ mod test {
args: vec![],
},
},
TestCase {
desc: "decimal type",
input: "Decimal(42, 42)",
output: TypeDesc {
name: "Decimal",
args: vec![
TypeDesc {
name: "42",
args: vec![],
},
TypeDesc {
name: "42",
args: vec![],
},
],
},
},
TestCase {
desc: "nullable type",
input: "Nullable(Nothing)",
Expand Down
171 changes: 165 additions & 6 deletions driver/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use arrow::datatypes::i256;
use arrow_array::{ArrowNativeTypeOp, Decimal128Array, Decimal256Array};
use chrono::{Datelike, NaiveDate, NaiveDateTime};

use crate::error::{ConvertError, Error, Result};
use crate::{
error::{ConvertError, Error, Result},
schema::{DecimalDataType, DecimalSize},
};
use std::fmt::Write;

// Thu 1970-01-01 is R.D. 719163
const DAYS_FROM_CE: i32 = 719_163;
Expand Down Expand Up @@ -45,16 +51,16 @@ pub enum NumberValue {
UInt64(u64),
Float32(f32),
Float64(f64),
Decimal128(i128, DecimalSize),
Decimal256(i256, DecimalSize),
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
Null,
Boolean(bool),
String(String),
Number(NumberValue),
// TODO:(everpcpc) Decimal(DecimalValue),
// Decimal(String),
/// Microseconds from 1970-01-01 00:00:00 UTC
Timestamp(i64),
Date(i32),
Expand Down Expand Up @@ -82,9 +88,11 @@ impl Value {
NumberValue::UInt64(_) => DataType::Number(NumberDataType::UInt64),
NumberValue::Float32(_) => DataType::Number(NumberDataType::Float32),
NumberValue::Float64(_) => DataType::Number(NumberDataType::Float64),
NumberValue::Decimal128(_, s) => DataType::Decimal(DecimalDataType::Decimal128(*s)),
NumberValue::Decimal256(_, s) => DataType::Decimal(DecimalDataType::Decimal256(*s)),
},
// Self::Decimal(_) => DataType::Decimal,
Self::Timestamp(_) => DataType::Timestamp,

Self::Date(_) => DataType::Date,
// TODO:(everpcpc) fix nested type
// Self::Array(v) => DataType::Array(Box::new(v[0].get_type())),
Expand Down Expand Up @@ -133,7 +141,17 @@ impl TryFrom<(&DataType, &str)> for Value {
DataType::Number(NumberDataType::Float64) => {
Ok(Self::Number(NumberValue::Float64(v.parse()?)))
}
// DataType::Decimal => Ok(Self::Decimal(v)),

DataType::Decimal(DecimalDataType::Decimal128(size)) => {
let d = parse_decimal(v, *size)?;
Ok(Self::Number(d))
}

DataType::Decimal(DecimalDataType::Decimal256(size)) => {
let d = parse_decimal(v, *size)?;
Ok(Self::Number(d))
}

DataType::Timestamp => Ok(Self::Timestamp(
chrono::NaiveDateTime::parse_from_str(v, "%Y-%m-%d %H:%M:%S%.6f")?
.timestamp_micros(),
Expand Down Expand Up @@ -210,6 +228,32 @@ impl TryFrom<(&ArrowField, &Arc<dyn ArrowArray>, usize)> for Value {
None => Err(ConvertError::new("float64", format!("{:?}", array)).into()),
},

ArrowDataType::Decimal128(p, s) => {
match array.as_any().downcast_ref::<Decimal128Array>() {
Some(array) => Ok(Value::Number(NumberValue::Decimal128(
array.value(seq),
DecimalSize {
precision: *p,
scale: *s as u8,
},
))),
None => Err(ConvertError::new("Decimal128", format!("{:?}", array)).into()),
}
}

ArrowDataType::Decimal256(p, s) => {
match array.as_any().downcast_ref::<Decimal256Array>() {
Some(array) => Ok(Value::Number(NumberValue::Decimal256(
array.value(seq),
DecimalSize {
precision: *p,
scale: *s as u8,
},
))),
None => Err(ConvertError::new("Decimal256", format!("{:?}", array)).into()),
}
}

ArrowDataType::Binary => match array.as_any().downcast_ref::<BinaryArray>() {
Some(array) => Ok(Value::String(String::from_utf8(array.value(seq).to_vec())?)),
None => Err(ConvertError::new("binary", format!("{:?}", array)).into()),
Expand Down Expand Up @@ -418,6 +462,8 @@ impl std::fmt::Display for NumberValue {
NumberValue::UInt64(i) => write!(f, "{}", i),
NumberValue::Float32(i) => write!(f, "{}", i),
NumberValue::Float64(i) => write!(f, "{}", i),
NumberValue::Decimal128(v, s) => write!(f, "{}", display_decimal_128(*v, s.scale)),
NumberValue::Decimal256(v, s) => write!(f, "{}", display_decimal_256(*v, s.scale)),
}
}
}
Expand All @@ -443,3 +489,116 @@ impl std::fmt::Display for Value {
}
}
}

/// Renders a 128-bit decimal as text.
///
/// `num` is the unscaled integer value and `scale` is the number of digits
/// placed after the decimal point, so `display_decimal_128(300, 2)` yields
/// `"3.00"`. With `scale == 0` the integer is printed as-is, without a
/// decimal point.
///
/// The function is total: it does not overflow for `i128::MIN`, nor for
/// scales whose power of ten does not fit in 128 bits.
pub fn display_decimal_128(num: i128, scale: u8) -> String {
    if scale == 0 {
        return num.to_string();
    }

    // Integer division truncates toward zero (-1 / 10 == 0), so the sign is
    // emitted separately and the digits are computed on the magnitude.
    let sign = if num < 0 { "-" } else { "" };
    // `unsigned_abs` cannot overflow, unlike `-num` for `i128::MIN`.
    let magnitude = num.unsigned_abs();
    let width = scale as usize;

    match 10_u128.checked_pow(u32::from(scale)) {
        Some(pow_scale) => format!(
            "{}{}.{:0>width$}",
            sign,
            magnitude / pow_scale,
            magnitude % pow_scale,
            width = width
        ),
        // 10^scale exceeds u128::MAX, hence also exceeds any magnitude: the
        // integer part is zero and every digit is fractional.
        None => format!("{}0.{:0>width$}", sign, magnitude, width = width),
    }
}

/// Renders a 256-bit decimal as text.
///
/// `num` is the unscaled integer value and `scale` is the number of digits
/// placed after the decimal point. With `scale == 0` the integer is printed
/// as-is, without a decimal point.
pub fn display_decimal_256(num: i256, scale: u8) -> String {
    let mut out = String::new();
    if scale == 0 {
        write!(out, "{}", num).unwrap();
        return out;
    }

    let divisor = i256::from_i128(10i128).pow_wrapping(scale as u32);
    // Division truncates toward zero (-1 / 10 == 0), which would lose the
    // sign of values in (-1, 0); so the sign is written explicitly and the
    // whole part is taken from the negated value.
    let (sign, whole) = if num >= i256::ZERO {
        ("", num / divisor)
    } else {
        ("-", -num / divisor)
    };
    let fraction = (num % divisor).wrapping_abs();

    write!(
        out,
        "{}{}.{:0>width$}",
        sign,
        whole,
        fraction,
        width = scale as usize
    )
    .unwrap();
    out
}

/// Parses the textual form of a decimal into a [`NumberValue`].
///
/// Accepted shapes are an optional integer part, an optional `.` followed by
/// fractional digits, and an optional `e`/`E` exponent (e.g. `"3.00"`,
/// `"123"`, `"123.1e3"`). Leading zeros are ignored. The digits of the
/// integer and fractional parts are concatenated into the unscaled integer;
/// when the exponent is larger than the number of fractional digits, the
/// missing trailing zeros are appended. At most 76 digits are kept.
///
/// `size` selects the representation — `Decimal256` when
/// `size.precision > 38`, `Decimal128` otherwise — and is attached to the
/// result unchanged.
///
/// NOTE(review): the digits are not rescaled to `size.scale`; this presumably
/// relies on the text already carrying exactly `size.scale` fractional
/// digits — confirm against the producer of `text`.
///
/// # Errors
///
/// Returns a conversion error when `text` is empty, when a `.` appears after
/// the exponent marker, when the integer part plus exponent exceeds 76
/// digits, or when the collected digits are not a valid integer. A malformed
/// exponent surfaces the underlying integer parse error.
pub fn parse_decimal(text: &str, size: DecimalSize) -> Result<NumberValue> {
    if text.is_empty() {
        return Err(ConvertError::new("decimal", format!("{:?}", text)).into());
    }

    // Bounds-safe replacement for the manual index loop: inputs such as "0"
    // or "000" simply become the empty string instead of running off the end.
    let text = text.trim_start_matches('0');
    let invalid = || ConvertError::new("decimal", format!("{:?}", text));

    let point_pos = text.find('.');
    let e_pos = text.find(|c: char| c == 'e' || c == 'E');
    let (i_part, f_part, e_part) = match (point_pos, e_pos) {
        (Some(p1), Some(p2)) if p1 < p2 => {
            (&text[..p1], &text[(p1 + 1)..p2], Some(&text[(p2 + 1)..]))
        }
        // A decimal point inside the exponent (e.g. "1e2.5") is malformed.
        (Some(_), Some(_)) => return Err(invalid().into()),
        (Some(p), None) => (&text[..p], &text[(p + 1)..], None),
        (None, Some(p)) => (&text[..p], "", Some(&text[(p + 1)..])),
        // Plain integers, e.g. values of a `Decimal(p, 0)` column.
        (None, None) => (text, "", None),
    };

    // Widen to i64 so extreme exponents cannot overflow the arithmetic below.
    let exp = match e_part {
        Some(s) => i64::from(s.parse::<i32>()?),
        None => 0,
    };
    if i_part.len() as i64 + exp > 76 {
        return Err(invalid().into());
    }

    let mut digits = Vec::with_capacity(76);
    digits.extend_from_slice(i_part.as_bytes());
    digits.extend_from_slice(f_part.as_bytes());
    if digits.is_empty() {
        digits.push(b'0');
    }

    let scale = f_part.len() as i64 - exp;
    if scale < 0 {
        // e.g. 123.1e3: the exponent shifts past the fractional digits, so
        // pad with zeros. The guard above bounds the padding by 76.
        let padded_len = digits.len() + (-scale) as usize;
        digits.resize(padded_len, b'0');
    }

    let precision = std::cmp::min(digits.len(), 76);
    // Checked conversion: truncating at 76 bytes could split a multi-byte
    // character if the input contained non-ASCII text.
    let digits = match std::str::from_utf8(&digits[..precision]) {
        Ok(s) => s,
        Err(_) => return Err(invalid().into()),
    };

    if size.precision > 38 {
        match i256::from_string(digits) {
            Some(v) => Ok(NumberValue::Decimal256(v, size)),
            None => Err(invalid().into()),
        }
    } else {
        Ok(NumberValue::Decimal128(digits.parse::<i128>()?, size))
    }
}
Loading

0 comments on commit b3d6e3e

Please sign in to comment.