diff --git a/CHANGELOG.md b/CHANGELOG.md index f92a6f98..86831e0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Changed +- *BREAKING:* partiql-logical-planner: moves `NameResolver` to `partiql-ast-passes` + ### Added - Add ability for partiql-extension-ion extension encoding/decoding of `Value` to/from Ion `Element` +- Add `partiql-types` crate that includes data models for PartiQL Types. +- Add `partiql_ast_passes::static_typer` for type annotating the AST. + ### Fixes ## [0.5.0] - 2023-06-06 diff --git a/partiql-ast-passes/Cargo.toml b/partiql-ast-passes/Cargo.toml index 3da2a9cf..081b298d 100644 --- a/partiql-ast-passes/Cargo.toml +++ b/partiql-ast-passes/Cargo.toml @@ -22,12 +22,15 @@ bench = false [dependencies] partiql-ast = { path = "../partiql-ast", version = "0.5.*" } partiql-catalog = { path = "../partiql-catalog", version = "0.5.*" } +partiql-types = { path = "../partiql-types", version = "0.5.*" } +assert_matches = "1.5.*" fnv = "1" indexmap = "1.9" thiserror = "1.0" [dev-dependencies] +partiql-parser = { path = "../partiql-parser", version = "0.5.*" } [features] default = [] diff --git a/partiql-ast-passes/src/lib.rs b/partiql-ast-passes/src/lib.rs index 75ca02a8..17f857a0 100644 --- a/partiql-ast-passes/src/lib.rs +++ b/partiql-ast-passes/src/lib.rs @@ -6,3 +6,4 @@ pub mod error; pub mod name_resolver; +pub mod static_typer; diff --git a/partiql-ast-passes/src/static_typer.rs b/partiql-ast-passes/src/static_typer.rs new file mode 100644 index 00000000..c0ecfe44 --- /dev/null +++ b/partiql-ast-passes/src/static_typer.rs @@ -0,0 +1,275 @@ +use crate::error::{AstTransformError, AstTransformationError}; +use partiql_ast::ast::{ + AstNode, AstTypeMap, Bag, Expr, List, Lit, NodeId, Query, QuerySet, Struct, +}; +use partiql_ast::visit::{Traverse, Visit, Visitor}; +use partiql_catalog::Catalog; +use partiql_types::{ArrayType, BagType, StaticType, StaticTypeKind, StructType}; + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct AstStaticTyper<'c> { + id_stack: Vec, + container_stack: Vec>, + errors: Vec, + type_map: AstTypeMap, + catalog: &'c dyn Catalog, +} + +impl<'c> AstStaticTyper<'c> { + pub fn new(catalog: &'c dyn Catalog) -> Self { + AstStaticTyper { + id_stack: Default::default(), + container_stack: Default::default(), + errors: Default::default(), + type_map: Default::default(), + catalog, + } + } + + pub fn type_nodes( + mut self, + query: &AstNode, + ) -> Result, AstTransformationError> { + query.visit(&mut self); + if self.errors.is_empty() { + Ok(self.type_map) + } else { + Err(AstTransformationError { + errors: self.errors, + }) + } + } + + #[inline] + fn current_node(&self) -> &NodeId { + self.id_stack.last().unwrap() + } +} + +impl<'c, 'ast> Visitor<'ast> for AstStaticTyper<'c> { + fn enter_ast_node(&mut self, id: NodeId) -> Traverse { + self.id_stack.push(id); + Traverse::Continue + } + + fn exit_ast_node(&mut self, id: NodeId) -> Traverse { + assert_eq!(self.id_stack.pop(), Some(id)); + Traverse::Continue + } + + fn enter_query(&mut self, _query: &'ast Query) -> Traverse { + Traverse::Continue + } + + fn exit_query(&mut self, _query: &'ast Query) -> Traverse { + Traverse::Continue + } + + fn enter_query_set(&mut self, _query_set: &'ast QuerySet) -> Traverse { + match _query_set { + QuerySet::SetOp(_) => { + todo!() + } + QuerySet::Select(_) => {} + QuerySet::Expr(_) => {} + QuerySet::Values(_) => { + todo!() + } + QuerySet::Table(_) => { + todo!() + } + } + Traverse::Continue + } + + fn exit_query_set(&mut self, _query_set: &'ast QuerySet) -> Traverse { + Traverse::Continue + } + + fn enter_expr(&mut self, _expr: &'ast Expr) -> Traverse { + Traverse::Continue + } + + fn exit_expr(&mut self, _expr: &'ast Expr) -> Traverse { + Traverse::Continue + } + + fn enter_lit(&mut self, _lit: &'ast Lit) -> Traverse { + // Currently we're assuming no-schema, hence typing to arbitrary sized scalars. + // TODO type to the corresponding scalar with the introduction of schema + let kind = match _lit { + Lit::Null => StaticTypeKind::Null, + Lit::Missing => StaticTypeKind::Missing, + Lit::Int8Lit(_) => StaticTypeKind::Int, + Lit::Int16Lit(_) => StaticTypeKind::Int, + Lit::Int32Lit(_) => StaticTypeKind::Int, + Lit::Int64Lit(_) => StaticTypeKind::Int, + Lit::DecimalLit(_) => StaticTypeKind::Decimal, + Lit::NumericLit(_) => StaticTypeKind::Decimal, + Lit::RealLit(_) => StaticTypeKind::Float64, + Lit::FloatLit(_) => StaticTypeKind::Float64, + Lit::DoubleLit(_) => StaticTypeKind::Float64, + Lit::BoolLit(_) => StaticTypeKind::Bool, + Lit::IonStringLit(_) => todo!(), + Lit::CharStringLit(_) => StaticTypeKind::String, + Lit::NationalCharStringLit(_) => StaticTypeKind::String, + Lit::BitStringLit(_) => todo!(), + Lit::HexStringLit(_) => todo!(), + Lit::StructLit(_) => StaticTypeKind::Struct(StructType::unconstrained()), + Lit::ListLit(_) => StaticTypeKind::Array(ArrayType::array()), + Lit::BagLit(_) => StaticTypeKind::Bag(BagType::bag()), + Lit::TypedLit(_, _) => todo!(), + }; + + let ty = StaticType::new(kind); + let id = *self.current_node(); + if let Some(c) = self.container_stack.last_mut() { + c.push(ty.clone()) + } + self.type_map.insert(id, ty); + Traverse::Continue + } + + fn enter_struct(&mut self, _struct: &'ast Struct) -> Traverse { + self.container_stack.push(vec![]); + Traverse::Continue + } + + fn exit_struct(&mut self, _struct: &'ast Struct) -> Traverse { + let id = *self.current_node(); + let fields = self.container_stack.pop(); + + // Such type checking will very likely move to a common module + // TODO move to a more appropriate place for re-use. + if let Some(f) = fields { + // We already fail during parsing if the struct has wrong number of key-value pairs, e.g.: + // {'a', 1, 'b'} + // However, adding this check here. + let is_malformed = f.len() % 2 > 0; + if is_malformed { + self.errors.push(AstTransformError::IllegalState( + "Struct key-value pairs are malformed".to_string(), + )); + } + + let has_invalid_keys = f.chunks(2).map(|t| &t[0]).any(|t| !t.is_string()); + if has_invalid_keys || is_malformed { + self.errors.push(AstTransformError::IllegalState( + "Struct keys can only resolve to `String` type".to_string(), + )); + } + } + + let ty = StaticType::new_struct(StructType::unconstrained()); + self.type_map.insert(id, ty.clone()); + + if let Some(c) = self.container_stack.last_mut() { + c.push(ty) + } + + Traverse::Continue + } + + fn enter_bag(&mut self, _bag: &'ast Bag) -> Traverse { + self.container_stack.push(vec![]); + Traverse::Continue + } + + fn exit_bag(&mut self, _bag: &'ast Bag) -> Traverse { + // TODO add schema validation of BAG elements, e.g. for Schema Bag if there is at least + // one element that isn't INT there is a type checking error. + + // TODO clarify if we need to record the internal types of bag literal or stick w/Schema? + self.container_stack.pop(); + + let id = *self.current_node(); + let ty = StaticType::new_bag(BagType::bag()); + + self.type_map.insert(id, ty.clone()); + if let Some(s) = self.container_stack.last_mut() { + s.push(ty) + } + Traverse::Continue + } + + fn enter_list(&mut self, _list: &'ast List) -> Traverse { + self.container_stack.push(vec![]); + Traverse::Continue + } + + fn exit_list(&mut self, _list: &'ast List) -> Traverse { + // TODO clarify if we need to record the internal types of array literal or stick w/Schema? + // one element that isn't INT there is a type checking error. + + // TODO clarify if we need to record the internal types of array literal or stick w/Schema? + self.container_stack.pop(); + + let id = *self.current_node(); + let ty = StaticType::new_array(ArrayType::array()); + + self.type_map.insert(id, ty.clone()); + if let Some(s) = self.container_stack.last_mut() { + s.push(ty) + } + Traverse::Continue + } +} + +#[cfg(test)] +mod tests { + use super::*; + use assert_matches::assert_matches; + use partiql_ast::ast; + use partiql_catalog::PartiqlCatalog; + use partiql_types::{StaticType, StaticTypeKind}; + + #[test] + fn simple_test() { + assert_matches!(run_literal_test("NULL"), StaticTypeKind::Null); + assert_matches!(run_literal_test("MISSING"), StaticTypeKind::Missing); + assert_matches!(run_literal_test("Missing"), StaticTypeKind::Missing); + assert_matches!(run_literal_test("true"), StaticTypeKind::Bool); + assert_matches!(run_literal_test("false"), StaticTypeKind::Bool); + assert_matches!(run_literal_test("1"), StaticTypeKind::Int); + assert_matches!(run_literal_test("1.5"), StaticTypeKind::Decimal); + assert_matches!(run_literal_test("'hello world!'"), StaticTypeKind::String); + assert_matches!( + run_literal_test("[1, 2 , {'a': 2}]"), + StaticTypeKind::Array(_) + ); + assert_matches!( + run_literal_test("<<'1', {'a': 11}>>"), + StaticTypeKind::Bag(_) + ); + assert_matches!( + run_literal_test("{'a': 1, 'b': 3, 'c': [1, 2]}"), + StaticTypeKind::Struct(_) + ); + } + + #[test] + fn simple_err_test() { + assert!(type_statement("{'a': 1, a.b: 3}").is_err()); + } + + fn run_literal_test(q: &str) -> StaticTypeKind { + let out = type_statement(q).expect("type map"); + let values: Vec<&StaticType> = out.values().collect(); + values.last().unwrap().kind().clone() + } + + fn type_statement(q: &str) -> Result, AstTransformationError> { + let parsed = partiql_parser::Parser::default() + .parse(q) + .expect("Expect successful parse"); + + let catalog = PartiqlCatalog::default(); + let typer = AstStaticTyper::new(&catalog); + if let ast::Expr::Query(q) = parsed.ast.as_ref() { + typer.type_nodes(&q) + } else { + panic!("Typing statement other than `Query` are unsupported") + } + } +} diff --git a/partiql-ast/Cargo.toml b/partiql-ast/Cargo.toml index d8503eea..a4d2f3de 100644 --- a/partiql-ast/Cargo.toml +++ b/partiql-ast/Cargo.toml @@ -20,14 +20,12 @@ path = "src/lib.rs" bench = false [dependencies] +indexmap = "1.9" rust_decimal = { version = "1.25.0", default-features = false, features = ["std"] } - serde = { version = "1.*", features = ["derive"], optional = true } - [dev-dependencies] - [features] default = [] serde = [ diff --git a/partiql-ast/src/ast.rs b/partiql-ast/src/ast.rs index d40b918c..f320384c 100644 --- a/partiql-ast/src/ast.rs +++ b/partiql-ast/src/ast.rs @@ -8,6 +8,7 @@ // As more changes to this AST are expected, unless explicitly advised, using the structures exposed // in this crate directly is not recommended. +use indexmap::IndexMap; use rust_decimal::Decimal as RustDecimal; use std::fmt; @@ -17,6 +18,8 @@ use serde::{Deserialize, Serialize}; use partiql_ast_macros::Visit; +pub type AstTypeMap = IndexMap; + #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct NodeId(pub u32); diff --git a/partiql-catalog/Cargo.toml b/partiql-catalog/Cargo.toml index 54b77903..6b811a19 100644 --- a/partiql-catalog/Cargo.toml +++ b/partiql-catalog/Cargo.toml @@ -24,6 +24,7 @@ bench = false partiql-value = { path = "../partiql-value", version = "0.5.*" } partiql-parser = { path = "../partiql-parser", version = "0.5.*" } partiql-logical = { path = "../partiql-logical", version = "0.5.*" } + thiserror = "1.0" ordered-float = "3.*" itertools = "0.10.*" diff --git a/partiql-logical-planner/Cargo.toml b/partiql-logical-planner/Cargo.toml index 890ee79e..36724e5b 100644 --- a/partiql-logical-planner/Cargo.toml +++ b/partiql-logical-planner/Cargo.toml @@ -28,6 +28,7 @@ partiql-ast = { path = "../partiql-ast", version = "0.5.*" } partiql-parser = { path = "../partiql-parser", version = "0.5.*" } partiql-catalog = { path = "../partiql-catalog", version = "0.5.*" } partiql-ast-passes = { path = "../partiql-ast-passes", version = "0.5.*" } + ion-rs = "0.18" ordered-float = "3.*" itertools = "0.10.*" diff --git a/partiql-types/Cargo.toml b/partiql-types/Cargo.toml index 16a346cb..60d70ddb 100644 --- a/partiql-types/Cargo.toml +++ b/partiql-types/Cargo.toml @@ -21,6 +21,7 @@ edition.workspace = true bench = false [dependencies] + ordered-float = "3.*" itertools = "0.10.*" unicase = "2.6" diff --git a/partiql-types/src/lib.rs b/partiql-types/src/lib.rs index 25682808..2819e521 100644 --- a/partiql-types/src/lib.rs +++ b/partiql-types/src/lib.rs @@ -1,3 +1,188 @@ +use std::collections::HashSet; + +pub trait Type {} + +impl Type for StaticType {} + +#[derive(Debug, Clone)] +pub struct StaticType { + kind: StaticTypeKind, +} + +#[allow(dead_code)] +impl StaticType { + pub fn new(kind: StaticTypeKind) -> StaticType { + StaticType { kind } + } + + pub fn new_struct(s: StructType) -> StaticType { + StaticType { + kind: StaticTypeKind::Struct(s), + } + } + + pub fn new_bag(b: BagType) -> StaticType { + StaticType { + kind: StaticTypeKind::Bag(b), + } + } + + pub fn new_array(a: ArrayType) -> StaticType { + StaticType { + kind: StaticTypeKind::Array(a), + } + } + + pub fn union_of(types: HashSet) -> StaticType { + StaticType { + kind: StaticTypeKind::AnyOf(AnyOf::new(types)), + } + } + + pub fn is_string(&self) -> bool { + matches!( + &self, + StaticType { + kind: StaticTypeKind::String + } + ) + } + + pub fn kind(&self) -> &StaticTypeKind { + &self.kind + } +} + +#[derive(Debug, Clone)] +pub enum StaticTypeKind { + Any, + AnyOf(AnyOf), + + // Absent Types + Null, + Missing, + + // Scalar Types + Int, + Bool, + Decimal, + + Float64, + String, + + // Container Type + Struct(StructType), + Bag(BagType), + Array(ArrayType), + // TODO Add Sexp, TIMESTAMP +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct AnyOf { + types: HashSet, +} + +impl AnyOf { + pub fn new(types: HashSet) -> Self { + AnyOf { types } + } +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct StructType { + constraints: Vec, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct StructField { + name: String, + value: StaticType, +} + +impl From<(String, T)> for StructField +where + T: Into, +{ + fn from(pair: (String, T)) -> Self { + StructField { + name: pair.0, + value: pair.1.into(), + } + } +} + +impl StructType { + pub fn unconstrained() -> Self { + StructType { + constraints: vec![], + } + } + + pub fn constrained(constraints: Vec) -> Self { + StructType { constraints } + } +} + +#[derive(Debug, Clone)] +pub enum StructConstraint { + Open(bool), + Ordered(bool), + DuplicateAttrs(bool), + Fields(StructField), +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct BagType { + element_type: Box, + constraints: Vec, +} + +impl BagType { + pub fn bag() -> Self { + BagType::bag_of(Box::new(StaticType { + kind: StaticTypeKind::Any, + })) + } + + pub fn bag_of(typ: Box) -> Self { + BagType { + element_type: typ, + constraints: vec![CollectionConstraint::Ordered(false)], + } + } +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct ArrayType { + element_type: Box, + constraints: Vec, +} + +impl ArrayType { + pub fn array() -> Self { + ArrayType::array_of(Box::new(StaticType { + kind: StaticTypeKind::Any, + })) + } + + pub fn array_of(typ: Box) -> Self { + ArrayType { + element_type: typ, + constraints: vec![CollectionConstraint::Ordered(true)], + } + } +} + +#[derive(Debug, Clone)] +enum CollectionConstraint { + Ordered(bool), +} + #[cfg(test)] mod tests { #[test]