From 1765ad52bf45a4638afbf66311613979e0613397 Mon Sep 17 00:00:00 2001 From: Arash M <27716912+am357@users.noreply.github.com> Date: Thu, 18 May 2023 08:47:05 -0700 Subject: [PATCH 1/3] Add `partiql-types` and literals typing This is the first commit to introduce types to `partiql-lang-rust`. With the changes in this commit, we're introducing the `partiql-types` crate, which includes the initial model for our built-in types. We also introduce an `AstStaticTyper` in `partiql-ast-passes` which provides an API for receiving `partiql-ast` and outputting a mapping from AST Node Ids to their corresponding types. In the first version, we're loosely typing the literals with the expectation that this goes through multiple iterations. Here is the high-level plan: [ ] Type literals [ ] Type variable references in SFW [ ] Type operators [ ] Type using schemas in typing environment [ ] Add type coercions (using a Type Lattice?) [ ] Update CHANGELOG.md and add documentation --- CHANGELOG.md | 6 + partiql-ast-passes/Cargo.toml | 3 + partiql-ast-passes/src/lib.rs | 1 + partiql-ast-passes/src/static_typer.rs | 275 +++++++++++++++++++++++++ partiql-ast/Cargo.toml | 4 +- partiql-ast/src/ast.rs | 3 + partiql-catalog/Cargo.toml | 1 + partiql-logical-planner/Cargo.toml | 1 + partiql-types/Cargo.toml | 1 + partiql-types/src/lib.rs | 193 +++++++++++++++++ 10 files changed, 485 insertions(+), 3 deletions(-) create mode 100644 partiql-ast-passes/src/static_typer.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 60a2bb40..c9f06a0b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Changed +- *BREAKING:* partiql-logical-planner: moves `NameResolver` to `partiql-ast-passes` + +### Added +- Adds `partiql-types` crate that includes data models for PartiQL Types. 
+- Adds `partiql_ast_passes::static_typer` for type annotating the AST. ## [0.5.0] - 2023-06-06 ### Changed diff --git a/partiql-ast-passes/Cargo.toml b/partiql-ast-passes/Cargo.toml index 3da2a9cf..5c3d7040 100644 --- a/partiql-ast-passes/Cargo.toml +++ b/partiql-ast-passes/Cargo.toml @@ -21,8 +21,11 @@ bench = false [dependencies] partiql-ast = { path = "../partiql-ast", version = "0.5.*" } +partiql-parser = { path = "../partiql-parser", version = "0.5.*" } partiql-catalog = { path = "../partiql-catalog", version = "0.5.*" } +partiql-types = { path = "../partiql-types", version = "0.5.*" } +assert_matches = "1.5.*" fnv = "1" indexmap = "1.9" thiserror = "1.0" diff --git a/partiql-ast-passes/src/lib.rs b/partiql-ast-passes/src/lib.rs index 75ca02a8..17f857a0 100644 --- a/partiql-ast-passes/src/lib.rs +++ b/partiql-ast-passes/src/lib.rs @@ -6,3 +6,4 @@ pub mod error; pub mod name_resolver; +pub mod static_typer; diff --git a/partiql-ast-passes/src/static_typer.rs b/partiql-ast-passes/src/static_typer.rs new file mode 100644 index 00000000..c0ecfe44 --- /dev/null +++ b/partiql-ast-passes/src/static_typer.rs @@ -0,0 +1,275 @@ +use crate::error::{AstTransformError, AstTransformationError}; +use partiql_ast::ast::{ + AstNode, AstTypeMap, Bag, Expr, List, Lit, NodeId, Query, QuerySet, Struct, +}; +use partiql_ast::visit::{Traverse, Visit, Visitor}; +use partiql_catalog::Catalog; +use partiql_types::{ArrayType, BagType, StaticType, StaticTypeKind, StructType}; + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct AstStaticTyper<'c> { + id_stack: Vec, + container_stack: Vec>, + errors: Vec, + type_map: AstTypeMap, + catalog: &'c dyn Catalog, +} + +impl<'c> AstStaticTyper<'c> { + pub fn new(catalog: &'c dyn Catalog) -> Self { + AstStaticTyper { + id_stack: Default::default(), + container_stack: Default::default(), + errors: Default::default(), + type_map: Default::default(), + catalog, + } + } + + pub fn type_nodes( + mut self, + query: &AstNode, + ) -> Result, 
AstTransformationError> { + query.visit(&mut self); + if self.errors.is_empty() { + Ok(self.type_map) + } else { + Err(AstTransformationError { + errors: self.errors, + }) + } + } + + #[inline] + fn current_node(&self) -> &NodeId { + self.id_stack.last().unwrap() + } +} + +impl<'c, 'ast> Visitor<'ast> for AstStaticTyper<'c> { + fn enter_ast_node(&mut self, id: NodeId) -> Traverse { + self.id_stack.push(id); + Traverse::Continue + } + + fn exit_ast_node(&mut self, id: NodeId) -> Traverse { + assert_eq!(self.id_stack.pop(), Some(id)); + Traverse::Continue + } + + fn enter_query(&mut self, _query: &'ast Query) -> Traverse { + Traverse::Continue + } + + fn exit_query(&mut self, _query: &'ast Query) -> Traverse { + Traverse::Continue + } + + fn enter_query_set(&mut self, _query_set: &'ast QuerySet) -> Traverse { + match _query_set { + QuerySet::SetOp(_) => { + todo!() + } + QuerySet::Select(_) => {} + QuerySet::Expr(_) => {} + QuerySet::Values(_) => { + todo!() + } + QuerySet::Table(_) => { + todo!() + } + } + Traverse::Continue + } + + fn exit_query_set(&mut self, _query_set: &'ast QuerySet) -> Traverse { + Traverse::Continue + } + + fn enter_expr(&mut self, _expr: &'ast Expr) -> Traverse { + Traverse::Continue + } + + fn exit_expr(&mut self, _expr: &'ast Expr) -> Traverse { + Traverse::Continue + } + + fn enter_lit(&mut self, _lit: &'ast Lit) -> Traverse { + // Currently we're assuming no-schema, hence typing to arbitrary sized scalars. 
+ // TODO type to the corresponding scalar with the introduction of schema + let kind = match _lit { + Lit::Null => StaticTypeKind::Null, + Lit::Missing => StaticTypeKind::Missing, + Lit::Int8Lit(_) => StaticTypeKind::Int, + Lit::Int16Lit(_) => StaticTypeKind::Int, + Lit::Int32Lit(_) => StaticTypeKind::Int, + Lit::Int64Lit(_) => StaticTypeKind::Int, + Lit::DecimalLit(_) => StaticTypeKind::Decimal, + Lit::NumericLit(_) => StaticTypeKind::Decimal, + Lit::RealLit(_) => StaticTypeKind::Float64, + Lit::FloatLit(_) => StaticTypeKind::Float64, + Lit::DoubleLit(_) => StaticTypeKind::Float64, + Lit::BoolLit(_) => StaticTypeKind::Bool, + Lit::IonStringLit(_) => todo!(), + Lit::CharStringLit(_) => StaticTypeKind::String, + Lit::NationalCharStringLit(_) => StaticTypeKind::String, + Lit::BitStringLit(_) => todo!(), + Lit::HexStringLit(_) => todo!(), + Lit::StructLit(_) => StaticTypeKind::Struct(StructType::unconstrained()), + Lit::ListLit(_) => StaticTypeKind::Array(ArrayType::array()), + Lit::BagLit(_) => StaticTypeKind::Bag(BagType::bag()), + Lit::TypedLit(_, _) => todo!(), + }; + + let ty = StaticType::new(kind); + let id = *self.current_node(); + if let Some(c) = self.container_stack.last_mut() { + c.push(ty.clone()) + } + self.type_map.insert(id, ty); + Traverse::Continue + } + + fn enter_struct(&mut self, _struct: &'ast Struct) -> Traverse { + self.container_stack.push(vec![]); + Traverse::Continue + } + + fn exit_struct(&mut self, _struct: &'ast Struct) -> Traverse { + let id = *self.current_node(); + let fields = self.container_stack.pop(); + + // Such type checking will very likely move to a common module + // TODO move to a more appropriate place for re-use. + if let Some(f) = fields { + // We already fail during parsing if the struct has wrong number of key-value pairs, e.g.: + // {'a', 1, 'b'} + // However, adding this check here. 
+ let is_malformed = f.len() % 2 > 0; + if is_malformed { + self.errors.push(AstTransformError::IllegalState( + "Struct key-value pairs are malformed".to_string(), + )); + } + + let has_invalid_keys = f.chunks(2).map(|t| &t[0]).any(|t| !t.is_string()); + if has_invalid_keys || is_malformed { + self.errors.push(AstTransformError::IllegalState( + "Struct keys can only resolve to `String` type".to_string(), + )); + } + } + + let ty = StaticType::new_struct(StructType::unconstrained()); + self.type_map.insert(id, ty.clone()); + + if let Some(c) = self.container_stack.last_mut() { + c.push(ty) + } + + Traverse::Continue + } + + fn enter_bag(&mut self, _bag: &'ast Bag) -> Traverse { + self.container_stack.push(vec![]); + Traverse::Continue + } + + fn exit_bag(&mut self, _bag: &'ast Bag) -> Traverse { + // TODO add schema validation of BAG elements, e.g. for Schema Bag if there is at least + // one element that isn't INT there is a type checking error. + + // TODO clarify if we need to record the internal types of bag literal or stick w/Schema? + self.container_stack.pop(); + + let id = *self.current_node(); + let ty = StaticType::new_bag(BagType::bag()); + + self.type_map.insert(id, ty.clone()); + if let Some(s) = self.container_stack.last_mut() { + s.push(ty) + } + Traverse::Continue + } + + fn enter_list(&mut self, _list: &'ast List) -> Traverse { + self.container_stack.push(vec![]); + Traverse::Continue + } + + fn exit_list(&mut self, _list: &'ast List) -> Traverse { + // TODO clarify if we need to record the internal types of array literal or stick w/Schema? + // one element that isn't INT there is a type checking error. + + // TODO clarify if we need to record the internal types of array literal or stick w/Schema? 
+ self.container_stack.pop(); + + let id = *self.current_node(); + let ty = StaticType::new_array(ArrayType::array()); + + self.type_map.insert(id, ty.clone()); + if let Some(s) = self.container_stack.last_mut() { + s.push(ty) + } + Traverse::Continue + } +} + +#[cfg(test)] +mod tests { + use super::*; + use assert_matches::assert_matches; + use partiql_ast::ast; + use partiql_catalog::PartiqlCatalog; + use partiql_types::{StaticType, StaticTypeKind}; + + #[test] + fn simple_test() { + assert_matches!(run_literal_test("NULL"), StaticTypeKind::Null); + assert_matches!(run_literal_test("MISSING"), StaticTypeKind::Missing); + assert_matches!(run_literal_test("Missing"), StaticTypeKind::Missing); + assert_matches!(run_literal_test("true"), StaticTypeKind::Bool); + assert_matches!(run_literal_test("false"), StaticTypeKind::Bool); + assert_matches!(run_literal_test("1"), StaticTypeKind::Int); + assert_matches!(run_literal_test("1.5"), StaticTypeKind::Decimal); + assert_matches!(run_literal_test("'hello world!'"), StaticTypeKind::String); + assert_matches!( + run_literal_test("[1, 2 , {'a': 2}]"), + StaticTypeKind::Array(_) + ); + assert_matches!( + run_literal_test("<<'1', {'a': 11}>>"), + StaticTypeKind::Bag(_) + ); + assert_matches!( + run_literal_test("{'a': 1, 'b': 3, 'c': [1, 2]}"), + StaticTypeKind::Struct(_) + ); + } + + #[test] + fn simple_err_test() { + assert!(type_statement("{'a': 1, a.b: 3}").is_err()); + } + + fn run_literal_test(q: &str) -> StaticTypeKind { + let out = type_statement(q).expect("type map"); + let values: Vec<&StaticType> = out.values().collect(); + values.last().unwrap().kind().clone() + } + + fn type_statement(q: &str) -> Result, AstTransformationError> { + let parsed = partiql_parser::Parser::default() + .parse(q) + .expect("Expect successful parse"); + + let catalog = PartiqlCatalog::default(); + let typer = AstStaticTyper::new(&catalog); + if let ast::Expr::Query(q) = parsed.ast.as_ref() { + typer.type_nodes(&q) + } else { + 
panic!("Typing statement other than `Query` are unsupported") + } + } +} diff --git a/partiql-ast/Cargo.toml b/partiql-ast/Cargo.toml index d8503eea..a4d2f3de 100644 --- a/partiql-ast/Cargo.toml +++ b/partiql-ast/Cargo.toml @@ -20,14 +20,12 @@ path = "src/lib.rs" bench = false [dependencies] +indexmap = "1.9" rust_decimal = { version = "1.25.0", default-features = false, features = ["std"] } - serde = { version = "1.*", features = ["derive"], optional = true } - [dev-dependencies] - [features] default = [] serde = [ diff --git a/partiql-ast/src/ast.rs b/partiql-ast/src/ast.rs index d40b918c..f320384c 100644 --- a/partiql-ast/src/ast.rs +++ b/partiql-ast/src/ast.rs @@ -8,6 +8,7 @@ // As more changes to this AST are expected, unless explicitly advised, using the structures exposed // in this crate directly is not recommended. +use indexmap::IndexMap; use rust_decimal::Decimal as RustDecimal; use std::fmt; @@ -17,6 +18,8 @@ use serde::{Deserialize, Serialize}; use partiql_ast_macros::Visit; +pub type AstTypeMap = IndexMap; + #[derive(Copy, Clone, Debug, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct NodeId(pub u32); diff --git a/partiql-catalog/Cargo.toml b/partiql-catalog/Cargo.toml index 54b77903..6b811a19 100644 --- a/partiql-catalog/Cargo.toml +++ b/partiql-catalog/Cargo.toml @@ -24,6 +24,7 @@ bench = false partiql-value = { path = "../partiql-value", version = "0.5.*" } partiql-parser = { path = "../partiql-parser", version = "0.5.*" } partiql-logical = { path = "../partiql-logical", version = "0.5.*" } + thiserror = "1.0" ordered-float = "3.*" itertools = "0.10.*" diff --git a/partiql-logical-planner/Cargo.toml b/partiql-logical-planner/Cargo.toml index 3d9b45b0..f6c42a2b 100644 --- a/partiql-logical-planner/Cargo.toml +++ b/partiql-logical-planner/Cargo.toml @@ -28,6 +28,7 @@ partiql-ast = { path = "../partiql-ast", version = "0.5.*" } partiql-parser = { path = "../partiql-parser", version = "0.5.*" } 
partiql-catalog = { path = "../partiql-catalog", version = "0.5.*" } partiql-ast-passes = { path = "../partiql-ast-passes", version = "0.5.*" } + ion-rs = "0.17" ordered-float = "3.*" itertools = "0.10.*" diff --git a/partiql-types/Cargo.toml b/partiql-types/Cargo.toml index 16a346cb..60d70ddb 100644 --- a/partiql-types/Cargo.toml +++ b/partiql-types/Cargo.toml @@ -21,6 +21,7 @@ edition.workspace = true bench = false [dependencies] + ordered-float = "3.*" itertools = "0.10.*" unicase = "2.6" diff --git a/partiql-types/src/lib.rs b/partiql-types/src/lib.rs index 25682808..d17f9dfa 100644 --- a/partiql-types/src/lib.rs +++ b/partiql-types/src/lib.rs @@ -1,3 +1,196 @@ +use std::collections::HashSet; + +pub trait Type {} + +impl Type for StaticType {} + +#[derive(Debug, Clone)] +pub struct StaticType { + kind: StaticTypeKind, +} + +#[allow(dead_code)] +impl StaticType { + pub fn new(kind: StaticTypeKind) -> StaticType { + StaticType { kind } + } + + pub fn new_struct(s: StructType) -> StaticType { + StaticType { + kind: StaticTypeKind::Struct(s), + } + } + + pub fn new_bag(b: BagType) -> StaticType { + StaticType { + kind: StaticTypeKind::Bag(b), + } + } + + pub fn new_array(a: ArrayType) -> StaticType { + StaticType { + kind: StaticTypeKind::Array(a), + } + } + + pub fn union_of(types: HashSet) -> StaticType { + StaticType { + kind: StaticTypeKind::AnyOf(AnyOf::new(types)), + } + } + + pub fn is_string(&self) -> bool { + matches!( + &self, + StaticType { + kind: StaticTypeKind::String + } + ) + } + + pub fn kind(&self) -> &StaticTypeKind { + &self.kind + } +} + +#[derive(Debug, Clone)] +pub enum StaticTypeKind { + Any, + AnyOf(AnyOf), + + // Absent Types + Null, + Missing, + + // Scalar Types + Int, + Int8, + Int16, + Int32, + Int64, + Bool, + Decimal, + DecimalP(usize, usize), + Float32, + Float64, + + String, + StringFixed(usize), + StringVarying(usize), + + // Container Type + Struct(StructType), + Bag(BagType), + Array(ArrayType), + // TODO Add Sexp +} + 
+#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct AnyOf { + types: HashSet, +} + +impl AnyOf { + pub fn new(types: HashSet) -> Self { + AnyOf { types } + } +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct StructType { + constraints: Vec, +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct StructField { + name: String, + value: StaticType, +} + +impl From<(String, T)> for StructField +where + T: Into, +{ + fn from(pair: (String, T)) -> Self { + StructField { + name: pair.0, + value: pair.1.into(), + } + } +} + +impl StructType { + pub fn unconstrained() -> Self { + StructType { + constraints: vec![], + } + } + + pub fn constrained(constraints: Vec) -> Self { + StructType { constraints } + } +} + +#[derive(Debug, Clone)] +pub enum StructConstraint { + Open(bool), + Ordered(bool), + DuplicateAttrs(bool), + Fields(StructField), +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct BagType { + element_type: Box, + constraints: Vec, +} + +impl BagType { + pub fn bag() -> Self { + BagType::bag_of(Box::new(StaticType { + kind: StaticTypeKind::Any, + })) + } + + pub fn bag_of(typ: Box) -> Self { + BagType { + element_type: typ, + constraints: vec![CollectionConstraint::Ordered(false)], + } + } +} + +#[derive(Debug, Clone)] +#[allow(dead_code)] +pub struct ArrayType { + element_type: Box, + constraints: Vec, +} + +impl ArrayType { + pub fn array() -> Self { + ArrayType::array_of(Box::new(StaticType { + kind: StaticTypeKind::Any, + })) + } + + pub fn array_of(typ: Box) -> Self { + ArrayType { + element_type: typ, + constraints: vec![CollectionConstraint::Ordered(true)], + } + } +} + +#[derive(Debug, Clone)] +enum CollectionConstraint { + Ordered(bool), +} + #[cfg(test)] mod tests { #[test] From 01716e604dfd3255259350ce0a7a40862e44ab25 Mon Sep 17 00:00:00 2001 From: Arash M <27716912+am357@users.noreply.github.com> Date: Tue, 13 Jun 2023 17:30:01 -0700 Subject: [PATCH 2/3] Address PR feedback: Remove constrained primitives --- 
partiql-types/src/lib.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/partiql-types/src/lib.rs b/partiql-types/src/lib.rs index d17f9dfa..2819e521 100644 --- a/partiql-types/src/lib.rs +++ b/partiql-types/src/lib.rs @@ -64,25 +64,17 @@ pub enum StaticTypeKind { // Scalar Types Int, - Int8, - Int16, - Int32, - Int64, Bool, Decimal, - DecimalP(usize, usize), - Float32, - Float64, + Float64, String, - StringFixed(usize), - StringVarying(usize), // Container Type Struct(StructType), Bag(BagType), Array(ArrayType), - // TODO Add Sexp + // TODO Add Sexp, TIMESTAMP } #[derive(Debug, Clone)] From 47736deb0942058da4a54c34c2a92dc8f2103e12 Mon Sep 17 00:00:00 2001 From: Arash Maymandi <27716912+am357@users.noreply.github.com> Date: Wed, 14 Jun 2023 07:36:16 -0700 Subject: [PATCH 3/3] Move `partiql-parser` to `dev-dependencies` --- partiql-ast-passes/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/partiql-ast-passes/Cargo.toml b/partiql-ast-passes/Cargo.toml index 5c3d7040..081b298d 100644 --- a/partiql-ast-passes/Cargo.toml +++ b/partiql-ast-passes/Cargo.toml @@ -21,7 +21,6 @@ bench = false [dependencies] partiql-ast = { path = "../partiql-ast", version = "0.5.*" } -partiql-parser = { path = "../partiql-parser", version = "0.5.*" } partiql-catalog = { path = "../partiql-catalog", version = "0.5.*" } partiql-types = { path = "../partiql-types", version = "0.5.*" } @@ -31,6 +30,7 @@ indexmap = "1.9" thiserror = "1.0" [dev-dependencies] +partiql-parser = { path = "../partiql-parser", version = "0.5.*" } [features] default = []