Skip to content

Commit

Permalink
Add partiql-types and literals typing (#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
am357 authored Jun 14, 2023
1 parent ff3e950 commit babd7ac
Show file tree
Hide file tree
Showing 10 changed files with 476 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]
### Changed
- *BREAKING:* partiql-logical-planner: moves `NameResolver` to `partiql-ast-passes`

### Added
- Add ability for the `partiql-extension-ion` extension to encode/decode `Value` to/from Ion `Element`
- Add `partiql-types` crate that includes data models for PartiQL Types.
- Add `partiql_ast_passes::static_typer` for type annotating the AST.

### Fixes

## [0.5.0] - 2023-06-06
Expand Down
3 changes: 3 additions & 0 deletions partiql-ast-passes/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ bench = false
[dependencies]
partiql-ast = { path = "../partiql-ast", version = "0.5.*" }
partiql-catalog = { path = "../partiql-catalog", version = "0.5.*" }
partiql-types = { path = "../partiql-types", version = "0.5.*" }

assert_matches = "1.5.*"
fnv = "1"
indexmap = "1.9"
thiserror = "1.0"

[dev-dependencies]
partiql-parser = { path = "../partiql-parser", version = "0.5.*" }

[features]
default = []
1 change: 1 addition & 0 deletions partiql-ast-passes/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
pub mod error;
pub mod name_resolver;
pub mod static_typer;
275 changes: 275 additions & 0 deletions partiql-ast-passes/src/static_typer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
use crate::error::{AstTransformError, AstTransformationError};
use partiql_ast::ast::{
AstNode, AstTypeMap, Bag, Expr, List, Lit, NodeId, Query, QuerySet, Struct,
};
use partiql_ast::visit::{Traverse, Visit, Visitor};
use partiql_catalog::Catalog;
use partiql_types::{ArrayType, BagType, StaticType, StaticTypeKind, StructType};

/// AST visitor that records a [`StaticType`] for the nodes it types.
///
/// Build one with [`AstStaticTyper::new`] and consume it with
/// [`AstStaticTyper::type_nodes`], which returns either the collected type map
/// or the accumulated errors.
#[derive(Debug, Clone)]
// NOTE(review): presumably needed because `catalog` is stored but never read in this module — confirm.
#[allow(dead_code)]
pub struct AstStaticTyper<'c> {
/// Ids of the AST nodes currently being visited, innermost last
/// (pushed in `enter_ast_node`, popped in `exit_ast_node`).
id_stack: Vec<NodeId>,
/// One entry per struct/bag/list currently being visited; each entry collects
/// the types recorded for the children typed so far.
container_stack: Vec<Vec<StaticType>>,
/// Errors accumulated during the traversal; reported by `type_nodes`.
errors: Vec<AstTransformError>,
/// Types recorded so far, keyed by the id of the node they belong to.
type_map: AstTypeMap<StaticType>,
/// Catalog supplied by the caller; not consulted by the typing code in this module.
catalog: &'c dyn Catalog,
}

impl<'c> AstStaticTyper<'c> {
    /// Creates a typer with empty traversal state that borrows `catalog` for `'c`.
    pub fn new(catalog: &'c dyn Catalog) -> Self {
        Self {
            id_stack: Default::default(),
            container_stack: Default::default(),
            errors: Default::default(),
            type_map: Default::default(),
            catalog,
        }
    }

    /// Visits `query` and returns the types recorded for its nodes.
    ///
    /// Consumes the typer, so each instance types at most one query.
    ///
    /// # Errors
    ///
    /// Returns an [`AstTransformationError`] carrying every error collected during
    /// the traversal if at least one was recorded.
    ///
    /// # Panics
    ///
    /// Panics if the traversal reaches a construct that is still marked `todo!()`
    /// in the visitor callbacks.
    pub fn type_nodes(
        mut self,
        query: &AstNode<Query>,
    ) -> Result<AstTypeMap<StaticType>, AstTransformationError> {
        query.visit(&mut self);

        if !self.errors.is_empty() {
            return Err(AstTransformationError {
                errors: self.errors,
            });
        }
        Ok(self.type_map)
    }

    /// Returns the id of the innermost AST node currently being visited.
    ///
    /// # Panics
    ///
    /// Panics if no node is being visited, i.e. the id stack is empty. The id stack
    /// is only filled by `enter_ast_node`, so this signals a traversal-order bug
    /// rather than bad user input.
    #[inline]
    fn current_node(&self) -> &NodeId {
        self.id_stack
            .last()
            .expect("`current_node` requires an enclosing AST node (id stack is empty)")
    }
}

impl<'c, 'ast> Visitor<'ast> for AstStaticTyper<'c> {
fn enter_ast_node(&mut self, id: NodeId) -> Traverse {
self.id_stack.push(id);
Traverse::Continue
}

fn exit_ast_node(&mut self, id: NodeId) -> Traverse {
assert_eq!(self.id_stack.pop(), Some(id));
Traverse::Continue
}

fn enter_query(&mut self, _query: &'ast Query) -> Traverse {
Traverse::Continue
}

fn exit_query(&mut self, _query: &'ast Query) -> Traverse {
Traverse::Continue
}

fn enter_query_set(&mut self, _query_set: &'ast QuerySet) -> Traverse {
match _query_set {
QuerySet::SetOp(_) => {
todo!()
}
QuerySet::Select(_) => {}
QuerySet::Expr(_) => {}
QuerySet::Values(_) => {
todo!()
}
QuerySet::Table(_) => {
todo!()
}
}
Traverse::Continue
}

fn exit_query_set(&mut self, _query_set: &'ast QuerySet) -> Traverse {
Traverse::Continue
}

fn enter_expr(&mut self, _expr: &'ast Expr) -> Traverse {
Traverse::Continue
}

fn exit_expr(&mut self, _expr: &'ast Expr) -> Traverse {
Traverse::Continue
}

fn enter_lit(&mut self, _lit: &'ast Lit) -> Traverse {
// Currently we're assuming no-schema, hence typing to arbitrary sized scalars.
// TODO type to the corresponding scalar with the introduction of schema
let kind = match _lit {
Lit::Null => StaticTypeKind::Null,
Lit::Missing => StaticTypeKind::Missing,
Lit::Int8Lit(_) => StaticTypeKind::Int,
Lit::Int16Lit(_) => StaticTypeKind::Int,
Lit::Int32Lit(_) => StaticTypeKind::Int,
Lit::Int64Lit(_) => StaticTypeKind::Int,
Lit::DecimalLit(_) => StaticTypeKind::Decimal,
Lit::NumericLit(_) => StaticTypeKind::Decimal,
Lit::RealLit(_) => StaticTypeKind::Float64,
Lit::FloatLit(_) => StaticTypeKind::Float64,
Lit::DoubleLit(_) => StaticTypeKind::Float64,
Lit::BoolLit(_) => StaticTypeKind::Bool,
Lit::IonStringLit(_) => todo!(),
Lit::CharStringLit(_) => StaticTypeKind::String,
Lit::NationalCharStringLit(_) => StaticTypeKind::String,
Lit::BitStringLit(_) => todo!(),
Lit::HexStringLit(_) => todo!(),
Lit::StructLit(_) => StaticTypeKind::Struct(StructType::unconstrained()),
Lit::ListLit(_) => StaticTypeKind::Array(ArrayType::array()),
Lit::BagLit(_) => StaticTypeKind::Bag(BagType::bag()),
Lit::TypedLit(_, _) => todo!(),
};

let ty = StaticType::new(kind);
let id = *self.current_node();
if let Some(c) = self.container_stack.last_mut() {
c.push(ty.clone())
}
self.type_map.insert(id, ty);
Traverse::Continue
}

fn enter_struct(&mut self, _struct: &'ast Struct) -> Traverse {
self.container_stack.push(vec![]);
Traverse::Continue
}

fn exit_struct(&mut self, _struct: &'ast Struct) -> Traverse {
let id = *self.current_node();
let fields = self.container_stack.pop();

// Such type checking will very likely move to a common module
// TODO move to a more appropriate place for re-use.
if let Some(f) = fields {
// We already fail during parsing if the struct has wrong number of key-value pairs, e.g.:
// {'a', 1, 'b'}
// However, adding this check here.
let is_malformed = f.len() % 2 > 0;
if is_malformed {
self.errors.push(AstTransformError::IllegalState(
"Struct key-value pairs are malformed".to_string(),
));
}

let has_invalid_keys = f.chunks(2).map(|t| &t[0]).any(|t| !t.is_string());
if has_invalid_keys || is_malformed {
self.errors.push(AstTransformError::IllegalState(
"Struct keys can only resolve to `String` type".to_string(),
));
}
}

let ty = StaticType::new_struct(StructType::unconstrained());
self.type_map.insert(id, ty.clone());

if let Some(c) = self.container_stack.last_mut() {
c.push(ty)
}

Traverse::Continue
}

fn enter_bag(&mut self, _bag: &'ast Bag) -> Traverse {
self.container_stack.push(vec![]);
Traverse::Continue
}

fn exit_bag(&mut self, _bag: &'ast Bag) -> Traverse {
// TODO add schema validation of BAG elements, e.g. for Schema Bag<Int> if there is at least
// one element that isn't INT there is a type checking error.

// TODO clarify if we need to record the internal types of bag literal or stick w/Schema?
self.container_stack.pop();

let id = *self.current_node();
let ty = StaticType::new_bag(BagType::bag());

self.type_map.insert(id, ty.clone());
if let Some(s) = self.container_stack.last_mut() {
s.push(ty)
}
Traverse::Continue
}

fn enter_list(&mut self, _list: &'ast List) -> Traverse {
self.container_stack.push(vec![]);
Traverse::Continue
}

fn exit_list(&mut self, _list: &'ast List) -> Traverse {
// TODO clarify if we need to record the internal types of array literal or stick w/Schema?
// one element that isn't INT there is a type checking error.

// TODO clarify if we need to record the internal types of array literal or stick w/Schema?
self.container_stack.pop();

let id = *self.current_node();
let ty = StaticType::new_array(ArrayType::array());

self.type_map.insert(id, ty.clone());
if let Some(s) = self.container_stack.last_mut() {
s.push(ty)
}
Traverse::Continue
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use assert_matches::assert_matches;
    use partiql_ast::ast;
    use partiql_catalog::PartiqlCatalog;
    use partiql_types::{StaticType, StaticTypeKind};

    /// Parses `q` and runs the static typer over the resulting query.
    ///
    /// Panics if `q` does not parse or if the parsed statement is not a `Query`.
    fn type_statement(q: &str) -> Result<AstTypeMap<StaticType>, AstTransformationError> {
        let parsed = partiql_parser::Parser::default()
            .parse(q)
            .expect("Expect successful parse");

        let catalog = PartiqlCatalog::default();
        let typer = AstStaticTyper::new(&catalog);
        match parsed.ast.as_ref() {
            ast::Expr::Query(query) => typer.type_nodes(query),
            _ => panic!("Typing statement other than `Query` are unsupported"),
        }
    }

    /// Types `q` and returns the kind of the last entry in the resulting type map.
    fn run_literal_test(q: &str) -> StaticTypeKind {
        let out = type_statement(q).expect("type map");
        let last: &StaticType = out.values().last().unwrap();
        last.kind().clone()
    }

    #[test]
    fn simple_test() {
        // Absent values.
        assert_matches!(run_literal_test("NULL"), StaticTypeKind::Null);
        assert_matches!(run_literal_test("MISSING"), StaticTypeKind::Missing);
        assert_matches!(run_literal_test("Missing"), StaticTypeKind::Missing);

        // Scalars.
        assert_matches!(run_literal_test("true"), StaticTypeKind::Bool);
        assert_matches!(run_literal_test("false"), StaticTypeKind::Bool);
        assert_matches!(run_literal_test("1"), StaticTypeKind::Int);
        assert_matches!(run_literal_test("1.5"), StaticTypeKind::Decimal);
        assert_matches!(run_literal_test("'hello world!'"), StaticTypeKind::String);

        // Containers, including nested ones.
        assert_matches!(
            run_literal_test("[1, 2 , {'a': 2}]"),
            StaticTypeKind::Array(_)
        );
        assert_matches!(
            run_literal_test("<<'1', {'a': 11}>>"),
            StaticTypeKind::Bag(_)
        );
        assert_matches!(
            run_literal_test("{'a': 1, 'b': 3, 'c': [1, 2]}"),
            StaticTypeKind::Struct(_)
        );
    }

    #[test]
    fn simple_err_test() {
        // The second key is the path `a.b` rather than a string literal.
        assert!(type_statement("{'a': 1, a.b: 3}").is_err());
    }
}
4 changes: 1 addition & 3 deletions partiql-ast/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@ path = "src/lib.rs"
bench = false

[dependencies]
indexmap = "1.9"
rust_decimal = { version = "1.25.0", default-features = false, features = ["std"] }

serde = { version = "1.*", features = ["derive"], optional = true }


[dev-dependencies]


[features]
default = []
serde = [
Expand Down
3 changes: 3 additions & 0 deletions partiql-ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// As more changes to this AST are expected, unless explicitly advised, using the structures exposed
// in this crate directly is not recommended.

use indexmap::IndexMap;
use rust_decimal::Decimal as RustDecimal;

use std::fmt;
Expand All @@ -17,6 +18,8 @@ use serde::{Deserialize, Serialize};

use partiql_ast_macros::Visit;

/// Map from an AST node's [`NodeId`] to a value of type `T` associated with that node.
pub type AstTypeMap<T> = IndexMap<NodeId, T>;

/// Identifier wrapping a `u32`; it is the key type of [`AstTypeMap`].
///
/// `Serialize`/`Deserialize` are derived only when the `serde` feature is enabled.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NodeId(pub u32);
Expand Down
1 change: 1 addition & 0 deletions partiql-catalog/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ bench = false
partiql-value = { path = "../partiql-value", version = "0.5.*" }
partiql-parser = { path = "../partiql-parser", version = "0.5.*" }
partiql-logical = { path = "../partiql-logical", version = "0.5.*" }

thiserror = "1.0"
ordered-float = "3.*"
itertools = "0.10.*"
Expand Down
1 change: 1 addition & 0 deletions partiql-logical-planner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ partiql-ast = { path = "../partiql-ast", version = "0.5.*" }
partiql-parser = { path = "../partiql-parser", version = "0.5.*" }
partiql-catalog = { path = "../partiql-catalog", version = "0.5.*" }
partiql-ast-passes = { path = "../partiql-ast-passes", version = "0.5.*" }

ion-rs = "0.18"
ordered-float = "3.*"
itertools = "0.10.*"
Expand Down
1 change: 1 addition & 0 deletions partiql-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ edition.workspace = true
bench = false

[dependencies]

ordered-float = "3.*"
itertools = "0.10.*"
unicase = "2.6"
Expand Down
Loading

2 comments on commit babd7ac

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'PartiQL (rust) Benchmark'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.50.

Benchmark suite Current: babd7ac Previous: ff3e950 Ratio
numbers 160 ns/iter (± 10) 106 ns/iter (± 0) 1.51

This comment was automatically generated by workflow using github-action-benchmark.

CC: @am357 @am357 @partiql

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PartiQL (rust) Benchmark

Benchmark suite Current: babd7ac Previous: ff3e950 Ratio
parse-1 6933 ns/iter (± 847) 5980 ns/iter (± 10) 1.16
parse-15 65659 ns/iter (± 5519) 56132 ns/iter (± 341) 1.17
parse-30 135387 ns/iter (± 8910) 113067 ns/iter (± 87) 1.20
compile-1 6051 ns/iter (± 361) 5277 ns/iter (± 18) 1.15
compile-15 44369 ns/iter (± 5790) 36765 ns/iter (± 38) 1.21
compile-30 95681 ns/iter (± 21741) 74446 ns/iter (± 57) 1.29
plan-1 26157 ns/iter (± 2358) 20083 ns/iter (± 410) 1.30
plan-15 502372 ns/iter (± 30894) 365764 ns/iter (± 870) 1.37
plan-30 1001031 ns/iter (± 68498) 742654 ns/iter (± 1077) 1.35
eval-1 27122381 ns/iter (± 1817512) 21715733 ns/iter (± 719301) 1.25
eval-15 142875254 ns/iter (± 5879075) 122458027 ns/iter (± 499204) 1.17
eval-30 270865397 ns/iter (± 19224673) 239598074 ns/iter (± 537477) 1.13
join 17216 ns/iter (± 1123) 14209 ns/iter (± 32) 1.21
simple 8165 ns/iter (± 523) 7170 ns/iter (± 28) 1.14
simple-no 737 ns/iter (± 57) 632 ns/iter (± 0) 1.17
numbers 160 ns/iter (± 10) 106 ns/iter (± 0) 1.51
parse-simple 827 ns/iter (± 55) 707 ns/iter (± 3) 1.17
parse-ion 3026 ns/iter (± 264) 2664 ns/iter (± 13) 1.14
parse-group 10373 ns/iter (± 1067) 8678 ns/iter (± 25) 1.20
parse-complex 27630 ns/iter (± 1303) 22316 ns/iter (± 53) 1.24
parse-complex-fexpr 45491 ns/iter (± 3352) 35171 ns/iter (± 90) 1.29

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.