diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b77fa23b0..79dd071f8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,24 @@ Thank you to all who have contributed! ## [Unreleased] +### Added + +### Changed + +### Deprecated + +### Fixed + +### Removed + +### Security + +### Contributors +Thank you to all who have contributed! +- @ + +## [0.14.0-alpha] - 2023-12-15 + ### Added - Adds top-level IR node creation functions. - Adds `componentN` functions (destructuring) to IR nodes via Kotlin data classes @@ -43,22 +61,29 @@ Thank you to all who have contributed! - **Breaking** The default integer literal type is now 32-bit; if the literal can not fit in a 32-bit integer, it overflows to 64-bit. - **BREAKING** `PartiQLValueType` now distinguishes between Arbitrary Precision Decimal and Fixed Precision Decimal. - **BREAKING** Function Signature Changes. Now Function signature has two subclasses, `Scalar` and `Aggregation`. +- **BREAKING** Plugin Changes. Only return one Connector.Factory, use Kotlin fields. JVM signature remains the same. - **BREAKING** In the produced plan: - The new plan is fully resolved and typed. - Operators will be converted to function call. - Changes the return type of `filter_distinct` to a list if input collection is list +- Changes the `PartiQLValue` collections to implement Iterable rather than Sequence, allowing for multiple consumption. +- **BREAKING** Moves PartiQLParserBuilder.standard().build() to be PartiQLParser.default(). +- **BREAKING** Changed modeling of `EXCLUDE` in `partiql-ast` ### Deprecated ### Fixed - Fixes the CLI hanging on invalid queries. See issue #1230. - Fixes Timestamp Type parsing issue. Previously Timestamp Type would get parsed to a Time type. +- Fixes PIVOT parsing to assign the key and value as defined by spec section 14. - Fixes the physical plan compiler to return list when `DISTINCT` used with `ORDER BY` ### Removed - **Breaking** Removed IR factory in favor of static top-level functions. Change `Ast.foo()` to `foo()` - **Breaking** Removed `org.partiql.lang.planner.transforms.AstToPlan`. Use `org.partiql.planner.PartiQLPlanner`. +- **Breaking** Removed `org.partiql.lang.planner.transforms.PartiQLSchemaInferencer`. In order to achieve the same functionality, one would need to use the `org.partiql.planner.PartiQLPlanner`. + - To get the inferred type of the query result, one can do: `(plan.statement as Statement.Query).root.type` ### Security @@ -66,6 +91,8 @@ Thank you to all who have contributed! Thank you to all who have contributed! - @rchowell - @johnedquinn +- @yliuuuu +- @alancai98 ## [0.13.2-alpha] - 2023-09-29 @@ -920,6 +947,7 @@ breaking changes if migrating from v0.9.2. The breaking changes accidentally int Initial alpha release of PartiQL. [Unreleased]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.13.2-alpha...HEAD +[0.14.0-alpha]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.13.2-alpha...v0.14.0-alpha [0.13.2-alpha]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.13.1-alpha...v0.13.2-alpha [0.13.1-alpha]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.13.0-alpha...v0.13.1-alpha [0.13.0-alpha]: https://github.com/partiql/partiql-lang-kotlin/compare/v0.12.0-alpha...v0.13.0-alpha diff --git a/README.md b/README.md index ab35dde43e..0ceadb8dfe 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ This project is published to [Maven Central](https://search.maven.org/artifact/o | Group ID | Artifact ID | Recommended Version | |---------------|-----------------------|---------------------| -| `org.partiql` | `partiql-lang-kotlin` | `0.13.2` | +| `org.partiql` | `partiql-lang-kotlin` | `0.14.0` | For Maven builds, add the following to your `pom.xml`: diff --git a/gradle.properties b/gradle.properties index 8a7580fe4e..0be69b0ea8 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,5 +1,5 @@ group=org.partiql -version=0.14.0-SNAPSHOT +version=0.14.1-SNAPSHOT ossrhUsername=EMPTY ossrhPassword=EMPTY diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt index 67c3c954ec..0dd1a879ff 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/helpers/ToLegacyAst.kt @@ -708,9 +708,12 @@ private class AstTranslator(val metas: Map) : AstBaseVisi projectExpr(expr, alias, metas) } + // !! + // Legacy AST mislabels key and value in PIVOT, swapping the order here to recreate bug for compatibility. + // !! override fun visitSelectPivot(node: Select.Pivot, ctx: Ctx) = translate(node) { metas -> - val value = visitExpr(node.value, ctx) - val key = visitExpr(node.key, ctx) + val key = visitExpr(node.value, ctx) // SWAP val -> key + val value = visitExpr(node.key, ctx) // SWAP key -> val projectPivot(value, key, metas) } @@ -753,42 +756,42 @@ private class AstTranslator(val metas: Map) : AstBaseVisi } override fun visitExclude(node: Exclude, ctx: Ctx): PartiqlAst.ExcludeOp = translate(node) { metas -> - val excludeExprs = node.exprs.translate(ctx) + val excludeExprs = node.items.translate(ctx) excludeOp(excludeExprs, metas) } - override fun visitExcludeExcludeExpr(node: Exclude.ExcludeExpr, ctx: Ctx) = translate(node) { metas -> - val root = visitIdentifierSymbol(node.root, ctx) + override fun visitExcludeItem(node: Exclude.Item, ctx: Ctx) = translate(node) { metas -> + val root = visitExprVar(node.root, ctx) val steps = node.steps.translate(ctx) - excludeExpr(root = root, steps = steps, metas) + excludeExpr(root = identifier_(root.name, root.case), steps = steps, metas) } override fun visitExcludeStep(node: Exclude.Step, ctx: Ctx) = super.visitExcludeStep(node, ctx) as PartiqlAst.ExcludeStep - override fun visitExcludeStepExcludeTupleAttr(node: Exclude.Step.ExcludeTupleAttr, ctx: Ctx) = translate(node) { metas -> + override fun visitExcludeStepStructField(node: Exclude.Step.StructField, ctx: Ctx) = translate(node) { metas -> val attr = node.symbol.symbol val case = node.symbol.caseSensitivity.toLegacyCaseSensitivity() excludeTupleAttr(identifier(attr, case), metas) } - override fun visitExcludeStepExcludeCollectionIndex( - node: Exclude.Step.ExcludeCollectionIndex, + override fun visitExcludeStepCollIndex( + node: Exclude.Step.CollIndex, ctx: Ctx ) = translate(node) { metas -> val index = node.index.toLong() excludeCollectionIndex(index, metas) } - override fun visitExcludeStepExcludeTupleWildcard( - node: Exclude.Step.ExcludeTupleWildcard, + override fun visitExcludeStepStructWildcard( + node: Exclude.Step.StructWildcard, ctx: Ctx ) = translate(node) { metas -> excludeTupleWildcard(metas) } - override fun visitExcludeStepExcludeCollectionWildcard( - node: Exclude.Step.ExcludeCollectionWildcard, + override fun visitExcludeStepCollWildcard( + node: Exclude.Step.CollWildcard, ctx: Ctx ) = translate(node) { metas -> excludeCollectionWildcard(metas) diff --git a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt index 676f687dc2..05dc0a9610 100644 --- a/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt +++ b/partiql-ast/src/main/kotlin/org/partiql/ast/sql/SqlDialect.kt @@ -86,32 +86,32 @@ public abstract class SqlDialect : AstBaseVisitor() { override fun visitExclude(node: Exclude, head: SqlBlock): SqlBlock { var h = head h = h concat " EXCLUDE " - h = h concat list(start = null, end = null) { node.exprs } + h = h concat list(start = null, end = null) { node.items } return h } - override fun visitExcludeExcludeExpr(node: Exclude.ExcludeExpr, head: SqlBlock): SqlBlock { + override fun visitExcludeItem(node: Exclude.Item, head: SqlBlock): SqlBlock { var h = head - h = h concat visitIdentifierSymbol(node.root, SqlBlock.Nil) + h = h concat visitExprVar(node.root, SqlBlock.Nil) h = h concat list(delimiter = null, start = null, end = null) { node.steps } return h } - override fun visitExcludeStepExcludeCollectionIndex(node: Exclude.Step.ExcludeCollectionIndex, head: SqlBlock): SqlBlock { + override fun visitExcludeStepCollIndex(node: Exclude.Step.CollIndex, head: SqlBlock): SqlBlock { return head concat r("[${node.index}]") } - override fun visitExcludeStepExcludeTupleWildcard(node: Exclude.Step.ExcludeTupleWildcard, head: SqlBlock): SqlBlock { + override fun visitExcludeStepStructWildcard(node: Exclude.Step.StructWildcard, head: SqlBlock): SqlBlock { return head concat r(".*") } - override fun visitExcludeStepExcludeTupleAttr(node: Exclude.Step.ExcludeTupleAttr, head: SqlBlock): SqlBlock { + override fun visitExcludeStepStructField(node: Exclude.Step.StructField, head: SqlBlock): SqlBlock { var h = head concat r(".") h = h concat visitIdentifierSymbol(node.symbol, SqlBlock.Nil) return h } - override fun visitExcludeStepExcludeCollectionWildcard(node: Exclude.Step.ExcludeCollectionWildcard, head: SqlBlock): SqlBlock { + override fun visitExcludeStepCollWildcard(node: Exclude.Step.CollWildcard, head: SqlBlock): SqlBlock { return head concat r("[*]") } diff --git a/partiql-ast/src/main/resources/partiql_ast.ion b/partiql-ast/src/main/resources/partiql_ast.ion index 1c598fcc9e..67c48663c8 100644 --- a/partiql-ast/src/main/resources/partiql_ast.ion +++ b/partiql-ast/src/main/resources/partiql_ast.ion @@ -563,17 +563,17 @@ select::[ ] exclude::{ - exprs: list::[exclude_expr], + items: list::[item], _: [ - exclude_expr::{ - root: '.identifier.symbol', + item::{ + root: '.expr.var', steps: list::[step], }, step::[ - exclude_tuple_attr::{ symbol: '.identifier.symbol' }, - exclude_collection_index::{ index: int }, - exclude_tuple_wildcard::{}, - exclude_collection_wildcard::{}, + struct_field::{ symbol: '.identifier.symbol' }, + coll_index::{ index: int }, + struct_wildcard::{}, + coll_wildcard::{}, ] ] } diff --git a/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt b/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt index 59f2082860..6f59c5069b 100644 --- a/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt +++ b/partiql-ast/src/test/kotlin/org/partiql/ast/helpers/ToLegacyAstTest.kt @@ -536,10 +536,14 @@ class ToLegacyAstTest { } } }, - expect("(project_pivot (lit 1) (lit 2))") { + expect("(project_pivot (lit 2) (lit 1))") { selectPivot { - value = exprLit(int32Value(1)) + // PIVOT 1 AT 2 + // - 1 is the VALUE + // - 2 is the KEY + // In the legacy implementation these were accidentally flipped key = exprLit(int32Value(2)) + value = exprLit(int32Value(1)) } }, expect("(project_value (lit null))") { diff --git a/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt index 422d88f5c9..28b161bef6 100644 --- a/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt +++ b/partiql-ast/src/test/kotlin/org/partiql/ast/sql/SqlDialectTest.kt @@ -1073,11 +1073,9 @@ class SqlDialectTest { type = From.Value.Type.SCAN } exclude = exclude { - exprs += excludeExcludeExpr { - root = id("t", Identifier.CaseSensitivity.INSENSITIVE) - steps += excludeStepExcludeTupleAttr { - symbol = id("a", Identifier.CaseSensitivity.INSENSITIVE) - } + items += excludeItem { + root = v("t") + steps += insensitiveExcludeStructField("a") } } } @@ -1090,21 +1088,21 @@ class SqlDialectTest { type = From.Value.Type.SCAN } exclude = exclude { - exprs += excludeExcludeExpr { - root = id("a", Identifier.CaseSensitivity.INSENSITIVE) - steps += insensitiveExcludeTupleAttr("b") + items += excludeItem { + root = v("a") + steps += insensitiveExcludeStructField("b") } - exprs += excludeExcludeExpr { - root = id("c", Identifier.CaseSensitivity.INSENSITIVE) - steps += insensitiveExcludeTupleAttr("d") + items += excludeItem { + root = v("c") + steps += insensitiveExcludeStructField("d") } - exprs += excludeExcludeExpr { - root = id("e", Identifier.CaseSensitivity.INSENSITIVE) - steps += insensitiveExcludeTupleAttr("f") + items += excludeItem { + root = v("e") + steps += insensitiveExcludeStructField("f") } - exprs += excludeExcludeExpr { - root = id("g", Identifier.CaseSensitivity.INSENSITIVE) - steps += insensitiveExcludeTupleAttr("h") + items += excludeItem { + root = v("g") + steps += insensitiveExcludeStructField("h") } } } @@ -1117,25 +1115,25 @@ class SqlDialectTest { type = From.Value.Type.SCAN } exclude = exclude { - exprs += excludeExcludeExpr { - root = id("t", Identifier.CaseSensitivity.INSENSITIVE) + items += excludeItem { + root = v("t") steps += mutableListOf( - insensitiveExcludeTupleAttr("a"), - sensitiveExcludeTupleAttr("b"), - excludeStepExcludeTupleWildcard(), - excludeStepExcludeCollectionWildcard(), - insensitiveExcludeTupleAttr("c"), + insensitiveExcludeStructField("a"), + sensitiveExcludeStructField("b"), + excludeStepStructWildcard(), + excludeStepCollWildcard(), + insensitiveExcludeStructField("c"), ) } - exprs += excludeExcludeExpr { - root = id("s", Identifier.CaseSensitivity.SENSITIVE) + items += excludeItem { + root = exprVar(id("s", Identifier.CaseSensitivity.SENSITIVE), Expr.Var.Scope.DEFAULT) steps += mutableListOf( - excludeStepExcludeCollectionIndex(0), - insensitiveExcludeTupleAttr("d"), - sensitiveExcludeTupleAttr("e"), - excludeStepExcludeCollectionWildcard(), - insensitiveExcludeTupleAttr("f"), - excludeStepExcludeTupleWildcard(), + excludeStepCollIndex(0), + insensitiveExcludeStructField("d"), + sensitiveExcludeStructField("e"), + excludeStepCollWildcard(), + insensitiveExcludeStructField("f"), + excludeStepStructWildcard(), ) } } @@ -1143,11 +1141,11 @@ class SqlDialectTest { }, ) - private fun AstBuilder.insensitiveExcludeTupleAttr(str: String) = excludeStepExcludeTupleAttr { + private fun AstBuilder.insensitiveExcludeStructField(str: String) = excludeStepStructField { symbol = id(str, Identifier.CaseSensitivity.INSENSITIVE) } - private fun AstBuilder.sensitiveExcludeTupleAttr(str: String) = excludeStepExcludeTupleAttr { + private fun AstBuilder.sensitiveExcludeStructField(str: String) = excludeStepStructField { symbol = id(str, Identifier.CaseSensitivity.SENSITIVE) } diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt index 481030278f..6ff092079e 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/Main.kt @@ -17,19 +17,16 @@ package org.partiql.cli import AstPrinter import com.amazon.ion.system.IonSystemBuilder -import com.amazon.ionelement.api.field -import com.amazon.ionelement.api.ionString -import com.amazon.ionelement.api.ionStructOf import org.partiql.cli.pico.PartiQLCommand import org.partiql.cli.shell.info import org.partiql.lang.eval.EvaluationSession -import org.partiql.parser.PartiQLParserBuilder +import org.partiql.parser.PartiQLParser import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner -import org.partiql.planner.PartiQLPlannerBuilder -import org.partiql.plugins.local.LocalPlugin +import org.partiql.plugins.local.LocalConnector import picocli.CommandLine import java.io.PrintStream +import java.nio.file.Paths import java.util.UUID import kotlin.system.exitProcess @@ -53,15 +50,14 @@ object Debug { private const val USER_ID = "DEBUG_USER_ID" - private val plugins = listOf(LocalPlugin()) - private val catalogs = mapOf( - "local" to ionStructOf( - field("connector_name", ionString("local")), - ) - ) + private val root = Paths.get(System.getProperty("user.home")).resolve(".partiql/local") - private val planner = PartiQLPlannerBuilder().plugins(plugins).build() - private val parser = PartiQLParserBuilder.standard().build() + private val planner = PartiQLPlanner.builder() + .catalogs( + "local" to LocalConnector.Metadata(root) + ) + .build() + private val parser = PartiQLParser.default() // !! // IMPLEMENT DEBUG BEHAVIOR HERE @@ -80,7 +76,6 @@ object Debug { val sess = PartiQLPlanner.Session( queryId = UUID.randomUUID().toString(), userId = "debug", - catalogConfig = catalogs, ) val result = planner.plan(statement, sess).plan out.info("-- Plan ----------") diff --git a/partiql-cli/src/main/kotlin/org/partiql/cli/utils/ServiceLoaderUtil.kt b/partiql-cli/src/main/kotlin/org/partiql/cli/utils/ServiceLoaderUtil.kt index d5e47477d9..0ed38ae35b 100644 --- a/partiql-cli/src/main/kotlin/org/partiql/cli/utils/ServiceLoaderUtil.kt +++ b/partiql-cli/src/main/kotlin/org/partiql/cli/utils/ServiceLoaderUtil.kt @@ -113,7 +113,7 @@ class ServiceLoaderUtil { } else { listOf() } - return plugins.flatMap { plugin -> plugin.getFunctions() } + return plugins.flatMap { plugin -> plugin.functions } .filterIsInstance() .map { partiqlFunc -> PartiQLtoExprFunction(partiqlFunc) } } @@ -269,28 +269,42 @@ class ServiceLoaderUtil { PartiQLValueType.INTERVAL -> TODO() PartiQLValueType.BAG -> { - (partiqlValue as? BagValue<*>)?.elements?.map { PartiQLtoExprValue(it) }?.let { newBag(it) } - ?: ExprValue.nullValue + if (partiqlValue.isNull) { + ExprValue.nullValue + } else { + newBag((partiqlValue as? BagValue<*>)!!.map { PartiQLtoExprValue(it) }) + } } PartiQLValueType.LIST -> { - (partiqlValue as? ListValue<*>)?.elements?.map { PartiQLtoExprValue(it) }?.let { newList(it) } - ?: ExprValue.nullValue + if (partiqlValue.isNull) { + ExprValue.nullValue + } else { + newList((partiqlValue as? ListValue<*>)!!.map { PartiQLtoExprValue(it) }) + } } PartiQLValueType.SEXP -> { - (partiqlValue as? SexpValue<*>)?.elements?.map { PartiQLtoExprValue(it) }?.let { newSexp(it) } - ?: ExprValue.nullValue + if (partiqlValue.isNull) { + ExprValue.nullValue + } else { + newSexp((partiqlValue as? SexpValue<*>)!!.map { PartiQLtoExprValue(it) }) + } } PartiQLValueType.STRUCT -> { - (partiqlValue as? StructValue<*>)?.fields?.map { - PartiQLtoExprValue(it.second).namedValue( - newString( - it.first + if (partiqlValue.isNull) { + ExprValue.nullValue + } else { + val entries = (partiqlValue as? StructValue<*>)!!.entries + entries.map { + PartiQLtoExprValue(it.second).namedValue( + newString( + it.first + ) ) - ) - }?.let { newStruct(it, StructOrdering.ORDERED) } ?: ExprValue.nullValue + }.let { newStruct(it, StructOrdering.ORDERED) } + } } PartiQLValueType.DECIMAL -> TODO() @@ -448,7 +462,7 @@ class ServiceLoaderUtil { PartiQLValueType.INTERVAL -> TODO() PartiQLValueType.BAG -> when (exprValue.type) { ExprValueType.NULL -> bagValue(null) - ExprValueType.BAG -> bagValue(exprValue.map { ExprToPartiQLValue(it, ExprToPartiQLValueType(it)) }.asSequence()) + ExprValueType.BAG -> bagValue(exprValue.map { ExprToPartiQLValue(it, ExprToPartiQLValueType(it)) }) else -> throw ExprToPartiQLValueTypeMismatchException( PartiQLValueType.BAG, ExprToPartiQLValueType(exprValue) ) @@ -460,7 +474,7 @@ class ServiceLoaderUtil { ExprToPartiQLValue( it, ExprToPartiQLValueType(it) ) - }.asSequence() + } ) else -> throw ExprToPartiQLValueTypeMismatchException( PartiQLValueType.LIST, ExprToPartiQLValueType(exprValue) @@ -473,7 +487,7 @@ class ServiceLoaderUtil { ExprToPartiQLValue( it, ExprToPartiQLValueType(it) ) - }.asSequence() + } ) else -> throw ExprToPartiQLValueTypeMismatchException( PartiQLValueType.SEXP, ExprToPartiQLValueType(exprValue) @@ -486,7 +500,7 @@ class ServiceLoaderUtil { Pair( it.name?.stringValue() ?: "", ExprToPartiQLValue(it, ExprToPartiQLValueType(it)) ) - }.asSequence() + } ) else -> throw ExprToPartiQLValueTypeMismatchException( PartiQLValueType.STRUCT, ExprToPartiQLValueType(exprValue) diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt index c1362377d1..a8ba36b7a8 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/Compiler.kt @@ -10,7 +10,6 @@ import org.partiql.eval.internal.operator.rel.RelProject import org.partiql.eval.internal.operator.rel.RelScan import org.partiql.eval.internal.operator.rex.ExprCollection import org.partiql.eval.internal.operator.rex.ExprLiteral -import org.partiql.eval.internal.operator.rex.ExprPathKey import org.partiql.eval.internal.operator.rex.ExprSelect import org.partiql.eval.internal.operator.rex.ExprStruct import org.partiql.eval.internal.operator.rex.ExprVar @@ -85,24 +84,6 @@ internal object Compiler { return super.visitRexOp(node.op, ctx) as Operator.Expr } - override fun visitRexOpPath(node: Rex.Op.Path, ctx: Unit): Operator { - val root = visitRex(node.root, ctx) - var path = root - node.steps.forEach { - when (it) { - is Rex.Op.Path.Step.Key -> { - val key = visitRex(it.key, ctx) - path = ExprPathKey(path, key) - } - is Rex.Op.Path.Step.Index -> TODO() - is Rex.Op.Path.Step.Symbol -> TODO() - is Rex.Op.Path.Step.Unpivot -> TODO() - is Rex.Op.Path.Step.Wildcard -> TODO() - } - } - return path - } - override fun visitRelOpJoin(node: Rel.Op.Join, ctx: Unit): Operator { val lhs = visitRel(node.lhs, ctx) val rhs = visitRel(node.rhs, ctx) diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt index 57ccb5b5bb..f7d5c85d2c 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelJoinNestedLoop.kt @@ -69,7 +69,7 @@ internal abstract class RelJoinNestedLoop : Operator.Relation { private fun PartiQLValue.padNull(): PartiQLValue { return when (this) { is StructValue<*> -> { - val newFields = this.fields?.map { it.first to nullValue() } + val newFields = this.entries.map { it.first to nullValue() } structValue(newFields) } else -> nullValue() diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScan.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScan.kt index d12a8043f7..9adaa51a44 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScan.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rel/RelScan.kt @@ -15,7 +15,7 @@ internal class RelScan( override fun open() { val r = expr.eval(Record.empty) records = when (r) { - is CollectionValue<*> -> r.elements!!.map { Record.of(it) }.iterator() + is CollectionValue<*> -> r.map { Record.of(it) }.iterator() else -> iterator { yield(Record.of(r)) } } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCall.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCall.kt index cec58bb9c6..d8046c6a18 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCall.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCall.kt @@ -5,11 +5,12 @@ import org.partiql.eval.internal.Record import org.partiql.eval.internal.helpers.toNull import org.partiql.eval.internal.operator.Operator import org.partiql.spi.function.PartiQLFunction +import org.partiql.spi.function.PartiQLFunctionExperimental import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.missingValue -@OptIn(PartiQLValueExperimental::class) +@OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) internal class ExprCall( private val fn: PartiQLFunction.Scalar, private val inputs: Array, diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCollection.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCollection.kt index 7c251887f1..e784dd37e1 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCollection.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprCollection.kt @@ -13,7 +13,7 @@ internal class ExprCollection( @PartiQLValueExperimental override fun eval(record: Record): PartiQLValue { return bagValue( - values.map { it.eval(record) }.asSequence() + values.map { it.eval(record) } ) } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprSelect.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprSelect.kt index 8f753db7ea..eb71d4435c 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprSelect.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprSelect.kt @@ -31,6 +31,6 @@ internal class ExprSelect( elements.add(e) } input.close() - return bagValue(elements.asSequence()) + return bagValue(elements) } } diff --git a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt index 774e3e6eb7..b46c3b14cc 100644 --- a/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt +++ b/partiql-eval/src/main/kotlin/org/partiql/eval/internal/operator/rex/ExprStruct.kt @@ -16,7 +16,7 @@ internal class ExprStruct(val fields: List) : Operator.Expr { val value = it.value.eval(record) key.value!! to value } - return structValue(fields.asSequence()) + return structValue(fields) } internal class Field(val key: Operator.Expr, val value: Operator.Expr) diff --git a/partiql-eval/src/main/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencer.kt b/partiql-eval/src/main/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencer.kt deleted file mode 100644 index 78df76e912..0000000000 --- a/partiql-eval/src/main/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencer.kt +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Copyright Amazon.com, Inc. or its affiliates. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"). - * You may not use this file except in compliance with the License. - * A copy of the License is located at: - * - * http://aws.amazon.com/apache2.0/ - * - * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific - * language governing permissions and limitations under the License. - */ - -package org.partiql.lang.planner.transforms - -import org.partiql.annotations.ExperimentalPartiQLSchemaInferencer -import org.partiql.errors.ErrorCode -import org.partiql.errors.Problem -import org.partiql.errors.ProblemHandler -import org.partiql.errors.ProblemSeverity -import org.partiql.errors.Property -import org.partiql.errors.PropertyValueMap -import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION -import org.partiql.lang.SqlException -import org.partiql.lang.planner.PlanningProblemDetails -import org.partiql.lang.planner.transforms.PartiQLSchemaInferencer.infer -import org.partiql.lang.util.propertyValueMapOf -import org.partiql.parser.PartiQLParserBuilder -import org.partiql.plan.PartiQLPlan -import org.partiql.plan.Statement -import org.partiql.planner.PartiQLPlanner -import org.partiql.planner.PartiQLPlannerBuilder -import org.partiql.spi.Plugin -import org.partiql.types.StaticType - -/** - * Vends functions, such as [infer], to infer the output [StaticType] of a PartiQL query. - */ -@ExperimentalPartiQLSchemaInferencer -public object PartiQLSchemaInferencer { - - /** - * Infers a query's schema. - * - * As an example, consider the following query: - * ```partiql - * SELECT a FROM t - * ``` - * - * The inferred [StaticType] of the above query will resemble a [StaticType.BAG] with an element type [StaticType.STRUCT] with a - * single field named "a". - * - * Consider another valid PartiQL query: - * ```partiql - * 1 + 1 - * ``` - * - * In the above example, the inferred [StaticType] will resemble a [StaticType.INT]. - * - * @param query the PartiQL statement to infer - * @param ctx relevant metadata for inference - * @return the type of the output data. - * @throws SqlException always throws a [SqlException]. - */ - @JvmStatic - @Throws(InferenceException::class) - public fun infer( - query: String, - ctx: Context - ): StaticType { - return try { - inferInternal(query, ctx).second - } catch (t: Throwable) { - throw when (t) { - is SqlException -> InferenceException( - t.message, - t.errorCode, - t.errorContext, - t.cause - ) - else -> InferenceException( - err = Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.CompileError("Unhandled exception occurred.") - ), - cause = t - ) - } - } - } - - /** - * Context object required for performing schema inference. - */ - public class Context( - public val session: PartiQLPlanner.Session, - public val plugins: List, - public val problemHandler: ProblemHandler = ProblemThrower() - ) - - public class InferenceException( - message: String = "", - errorCode: ErrorCode, - errorContext: PropertyValueMap, - cause: Throwable? = null - ) : SqlException(message, errorCode, errorContext, cause) { - - constructor(err: Problem, cause: Throwable? = null) : - this( - message = "", - errorCode = ErrorCode.INTERNAL_ERROR, - errorContext = propertyValueMapOf( - Property.LINE_NUMBER to err.sourceLocation.lineNum, - Property.COLUMN_NUMBER to err.sourceLocation.charOffset, - Property.MESSAGE to err.details.message - ), - cause = cause - ) - } - - // - // - // INTERNAL - // - // - - internal class ProblemThrower : ProblemHandler { - override fun handleProblem(problem: Problem) { - if (problem.details.severity == ProblemSeverity.ERROR) { - throw InferenceException(problem) - } - } - } - - internal fun inferInternal(query: String, ctx: Context): Pair { - val parser = PartiQLParserBuilder.standard().build() - val planner = PartiQLPlannerBuilder() - .plugins(ctx.plugins) - .build() - val ast = parser.parse(query).root - val plan = planner.plan(ast, ctx.session, ctx.problemHandler::handleProblem).plan - if (plan.statement !is Statement.Query) { - throw InferenceException( - Problem( - UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.CompileError("Invalid statement, only `Statement.Query` supported for schema inference") - ) - ) - } - return plan to (plan.statement as Statement.Query).root.type - } -} diff --git a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt index c305ae28a6..f6699eb733 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt +++ b/partiql-eval/src/test/kotlin/org/partiql/eval/internal/PartiQLEngineDefaultTest.kt @@ -1,9 +1,10 @@ package org.partiql.eval.internal +import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test import org.partiql.eval.PartiQLEngine import org.partiql.eval.PartiQLResult -import org.partiql.parser.PartiQLParserBuilder +import org.partiql.parser.PartiQLParser import org.partiql.planner.PartiQLPlanner import org.partiql.planner.PartiQLPlannerBuilder import org.partiql.value.BagValue @@ -20,12 +21,15 @@ import kotlin.test.assertEquals /** * This holds sanity tests during the development of the [PartiQLEngine.default] implementation. + * + * TODO need to update implementations */ +@Disabled class PartiQLEngineDefaultTest { private val engine = PartiQLEngine.default() private val planner = PartiQLPlannerBuilder().build() - private val parser = PartiQLParserBuilder.standard().build() + private val parser = PartiQLParser.default() @OptIn(PartiQLValueExperimental::class) @Test @@ -38,7 +42,7 @@ class PartiQLEngineDefaultTest { val result = engine.execute(prepared) as PartiQLResult.Value val output = result.value as BagValue<*> - val expected = bagValue(sequenceOf(int32Value(1), int32Value(1))) + val expected = bagValue(int32Value(1), int32Value(1)) assertEquals(expected, output) } @@ -53,14 +57,15 @@ class PartiQLEngineDefaultTest { val result = engine.execute(prepared) as PartiQLResult.Value val output = result.value as BagValue<*> - val expected = bagValue(sequenceOf(int32Value(10), int32Value(20), int32Value(30))) + val expected = bagValue(int32Value(10), int32Value(20), int32Value(30)) assertEquals(expected, output) } @OptIn(PartiQLValueExperimental::class) @Test fun testFilter() { - val statement = parser.parse("SELECT VALUE t FROM <> AS t WHERE t;").root + val statement = + parser.parse("SELECT VALUE t FROM <> AS t WHERE t;").root val session = PartiQLPlanner.Session("q", "u") val plan = planner.plan(statement, session) @@ -68,7 +73,7 @@ class PartiQLEngineDefaultTest { val result = engine.execute(prepared) as PartiQLResult.Value val output = result.value as BagValue<*> - val expected = bagValue(sequenceOf(boolValue(true), boolValue(true))) + val expected = bagValue(boolValue(true), boolValue(true)) assertEquals(expected, output) } @@ -83,7 +88,7 @@ class PartiQLEngineDefaultTest { val result = engine.execute(prepared) as PartiQLResult.Value val output = result.value as BagValue<*> - val expected = bagValue(sequenceOf(structValue(sequenceOf("a" to int32Value(1), "b" to int32Value(2))))) + val expected = bagValue(structValue("a" to int32Value(1), "b" to int32Value(2))) assertEquals(expected, output) } @@ -98,14 +103,15 @@ class PartiQLEngineDefaultTest { val result = engine.execute(prepared) as PartiQLResult.Value val output = result.value as BagValue<*> - val expected = bagValue(sequenceOf(structValue(sequenceOf("a" to int32Value(1), "b" to nullValue())))) + val expected = bagValue(structValue("a" to int32Value(1), "b" to nullValue())) assertEquals(expected, output) } @OptIn(PartiQLValueExperimental::class) @Test fun testJoinOuterFull() { - val statement = parser.parse("SELECT a, b FROM << { 'a': 1 } >> t FULL OUTER JOIN << { 'b': 2 } >> s ON false;").root + val statement = + parser.parse("SELECT a, b FROM << { 'a': 1 } >> t FULL OUTER JOIN << { 'b': 2 } >> s ON false;").root val session = PartiQLPlanner.Session("q", "u") val plan = planner.plan(statement, session) @@ -118,20 +124,14 @@ class PartiQLEngineDefaultTest { val output = result.value as BagValue<*> val expected = bagValue( - sequenceOf( - structValue( - sequenceOf( - "a" to int32Value(1), - "b" to nullValue() - ) - ), - structValue( - sequenceOf( - "a" to nullValue(), - "b" to int32Value(2) - ) - ), - ) + structValue( + "a" to int32Value(1), + "b" to nullValue() + ), + structValue( + "a" to nullValue(), + "b" to int32Value(2) + ), ) assertEquals(expected, output, comparisonString(expected, output)) } @@ -139,7 +139,8 @@ class PartiQLEngineDefaultTest { @OptIn(PartiQLValueExperimental::class) @Test fun testJoinOuterFullOnTrue() { - val statement = parser.parse("SELECT a, b FROM << { 'a': 1 } >> t FULL OUTER JOIN << { 'b': 2 } >> s ON TRUE;").root + val statement = + parser.parse("SELECT a, b FROM << { 'a': 1 } >> t FULL OUTER JOIN << { 'b': 2 } >> s ON TRUE;").root val session = PartiQLPlanner.Session("q", "u") val plan = planner.plan(statement, session) @@ -152,14 +153,10 @@ class PartiQLEngineDefaultTest { val output = result.value as BagValue<*> val expected = bagValue( - sequenceOf( - structValue( - sequenceOf( - "a" to int32Value(1), - "b" to int32Value(2) - ) - ), - ) + structValue( + "a" to int32Value(1), + "b" to int32Value(2) + ), ) assertEquals(expected, output, comparisonString(expected, output)) } diff --git a/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLParserBenchmark.kt b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLParserBenchmark.kt index 250d6269cd..ec9f312521 100644 --- a/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLParserBenchmark.kt +++ b/partiql-lang/src/jmh/kotlin/org/partiql/jmh/benchmarks/PartiQLParserBenchmark.kt @@ -29,7 +29,7 @@ import org.partiql.jmh.utils.MEASUREMENT_ITERATION_VALUE_RECOMMENDED import org.partiql.jmh.utils.MEASUREMENT_TIME_VALUE_RECOMMENDED import org.partiql.jmh.utils.WARMUP_ITERATION_VALUE_RECOMMENDED import org.partiql.jmh.utils.WARMUP_TIME_VALUE_RECOMMENDED -import org.partiql.parser.PartiQLParserBuilder +import org.partiql.parser.PartiQLParser import org.partiql.parser.PartiQLParserException import java.util.concurrent.TimeUnit @@ -854,7 +854,7 @@ internal open class PartiQLParserBenchmark { @State(Scope.Thread) open class MyState { - val parser = PartiQLParserBuilder.standard().build() + val parser = PartiQLParser.default() val query15OrsAndLikes = """ SELECT * diff --git a/partiql-lang/src/main/kotlin/org/partiql/annotations/Experimental.kt b/partiql-lang/src/main/kotlin/org/partiql/annotations/Experimental.kt index 7f827013e8..7a4322c23e 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/annotations/Experimental.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/annotations/Experimental.kt @@ -20,6 +20,3 @@ annotation class ExperimentalPartiQLCompilerPipeline // TODO: Remove from experimental once https://github.com/partiql/partiql-docs/issues/31 is resolved and a RFC is approved @RequiresOptIn(message = "Window Function is experimental. It may be changed in the future without notice.", level = RequiresOptIn.Level.ERROR) annotation class ExperimentalWindowFunctions - -@RequiresOptIn(message = "PartiQLSchemaInferencer is experimental. It may be changed in the future without notice.", level = RequiresOptIn.Level.ERROR) -annotation class ExperimentalPartiQLSchemaInferencer diff --git a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/PartiQLParserBuilder.kt b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/PartiQLParserBuilder.kt index b1537c2ece..ff6a0b1993 100644 --- a/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/PartiQLParserBuilder.kt +++ b/partiql-lang/src/main/kotlin/org/partiql/lang/syntax/PartiQLParserBuilder.kt @@ -44,7 +44,7 @@ class PartiQLParserBuilder { val builder = PartiQLParserBuilder() builder.constructor = { _ -> // currently don't pass custom types - val delegate = org.partiql.parser.PartiQLParserBuilder.standard().build() + val delegate = org.partiql.parser.PartiQLParser.default() PartiQLShimParser(delegate) } return builder diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParser.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParser.kt index 4b8f08616a..8cbf60dd6b 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParser.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParser.kt @@ -15,6 +15,7 @@ package org.partiql.parser import org.partiql.ast.Statement +import org.partiql.parser.impl.PartiQLParserDefault public interface PartiQLParser { @@ -26,4 +27,13 @@ public interface PartiQLParser { val root: Statement, val locations: SourceLocations, ) + + public companion object { + + @JvmStatic + public fun builder(): PartiQLParserBuilder = PartiQLParserBuilder() + + @JvmStatic + public fun default(): PartiQLParser = PartiQLParserDefault() + } } diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParserBuilder.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParserBuilder.kt index 56318f47f6..b985ed7b04 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParserBuilder.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/PartiQLParserBuilder.kt @@ -21,14 +21,6 @@ import org.partiql.parser.impl.PartiQLParserDefault */ public class PartiQLParserBuilder { - public companion object { - - @JvmStatic - public fun standard(): PartiQLParserBuilder { - return PartiQLParserBuilder() - } - } - public fun build(): PartiQLParser { return PartiQLParserDefault() } diff --git a/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt b/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt index 08f394dfc3..870f427420 100644 --- a/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt +++ b/partiql-parser/src/main/kotlin/org/partiql/parser/impl/PartiQLParserDefault.kt @@ -52,11 +52,11 @@ import org.partiql.ast.Statement import org.partiql.ast.TableDefinition import org.partiql.ast.Type import org.partiql.ast.exclude -import org.partiql.ast.excludeExcludeExpr -import org.partiql.ast.excludeStepExcludeCollectionIndex -import org.partiql.ast.excludeStepExcludeCollectionWildcard -import org.partiql.ast.excludeStepExcludeTupleAttr -import org.partiql.ast.excludeStepExcludeTupleWildcard +import org.partiql.ast.excludeItem +import org.partiql.ast.excludeStepCollIndex +import org.partiql.ast.excludeStepCollWildcard +import org.partiql.ast.excludeStepStructField +import org.partiql.ast.excludeStepStructWildcard import org.partiql.ast.exprAgg import org.partiql.ast.exprBagOp import org.partiql.ast.exprBetween @@ -939,8 +939,8 @@ internal class PartiQLParserDefault : PartiQLParser { } override fun visitSelectPivot(ctx: GeneratedParser.SelectPivotContext) = translate(ctx) { - val key = visitExpr(ctx.pivot) - val value = visitExpr(ctx.at) + val key = visitExpr(ctx.at) + val value = visitExpr(ctx.pivot) selectPivot(key, value) } @@ -1058,20 +1058,21 @@ internal class PartiQLParserDefault : PartiQLParser { } override fun visitExcludeExpr(ctx: GeneratedParser.ExcludeExprContext) = translate(ctx) { - val root = visitSymbolPrimitive(ctx.symbolPrimitive()) + val rootId = visitSymbolPrimitive(ctx.symbolPrimitive()) + val root = exprVar(rootId, Expr.Var.Scope.DEFAULT) val steps = visitOrEmpty(ctx.excludeExprSteps()) - excludeExcludeExpr(root, steps) + excludeItem(root, steps) } override fun visitExcludeExprTupleAttr(ctx: GeneratedParser.ExcludeExprTupleAttrContext) = translate(ctx) { val identifier = visitSymbolPrimitive(ctx.symbolPrimitive()) - excludeStepExcludeTupleAttr(identifier) + excludeStepStructField(identifier) } override fun visitExcludeExprCollectionIndex(ctx: GeneratedParser.ExcludeExprCollectionIndexContext) = translate(ctx) { val index = ctx.index.text.toInt() - excludeStepExcludeCollectionIndex(index) + excludeStepCollIndex(index) } override fun visitExcludeExprCollectionAttr(ctx: GeneratedParser.ExcludeExprCollectionAttrContext) = @@ -1081,17 +1082,17 @@ internal class PartiQLParserDefault : PartiQLParser { attr, Identifier.CaseSensitivity.SENSITIVE, ) - excludeStepExcludeTupleAttr(identifier) + excludeStepStructField(identifier) } override fun visitExcludeExprCollectionWildcard(ctx: org.partiql.parser.antlr.PartiQLParser.ExcludeExprCollectionWildcardContext) = translate(ctx) { - excludeStepExcludeCollectionWildcard() + excludeStepCollWildcard() } override fun visitExcludeExprTupleWildcard(ctx: org.partiql.parser.antlr.PartiQLParser.ExcludeExprTupleWildcardContext) = translate(ctx) { - excludeStepExcludeTupleWildcard() + excludeStepStructWildcard() } /** diff --git a/partiql-plan/src/main/resources/partiql_plan.ion b/partiql-plan/src/main/resources/partiql_plan.ion index 170fa29017..1e970807f3 100644 --- a/partiql-plan/src/main/resources/partiql_plan.ion +++ b/partiql-plan/src/main/resources/partiql_plan.ion @@ -8,15 +8,32 @@ imports::{ } parti_q_l_plan::{ - globals: list::[global], // (globals ...) + catalogs: list::[catalog], // (catalogs ...) statement: statement, // (statement ...) } -// Globals - -global::{ - path: '.identifier.qualified', - type: static_type, +// Represent an instance of a database. +// - Currently, `symbols` represents all values from this catalog to be used in this plan. +// - Eventually, TODO functions may be resolved to a specific namespace within a catalog. +catalog::{ + name: string, + symbols: list::[symbol], + _: [ + // A reference to a value contained within a catalog. + symbol::{ + // The path to a value WITHIN a catalog. Note: This should not start with the catalog's name. Also, this + // should not be empty + path: list::[string], + type: static_type, + _: [ + // A reference to a symbol + ref::{ + catalog: int, + symbol: int + } + ] + } + ] } // Functions @@ -70,33 +87,19 @@ rex::{ }, global::{ - ref: int, + ref: '.catalog.symbol.ref' }, - path::{ - root: rex, - steps: list::[step], - _: [ - step::[ - // The key MUST be an integer expression. Ex: a[0], a[1 + 1] - index::{ key: rex }, - - // Case-sensitive lookup. The key MUST be a string expression. Ex: a["b"], a."b", a[CAST(b AS STRING)] - key::{ key: rex }, + path::[ + // The key MUST be an integer expression. Ex: a[0], a[1 + 1] + index::{ root: rex, key: rex }, - // Case-insensitive lookup. The key MUST be a literal string. Ex: a.b - symbol::{ key: string }, + // Case-sensitive lookup. The key MUST be a string expression. Ex: a["b"], a."b", a[CAST(b AS STRING)] + key::{ root: rex, key: rex }, - // For arrays. Ex: a[*] - // TODO: Do we need this? According to specification: [1,2,3][*] ⇔ SELECT VALUE v FROM [1, 2, 3] AS v - wildcard::{}, - - // For tuples. Ex: a.* - // TODO: Do we need this? According to specification: {'a':1, 'b':2}.* ⇔ SELECT VALUE v FROM UNPIVOT {'a':1, 'b':2} AS v - unpivot::{}, - ], - ], - }, + // Case-insensitive lookup. The key MUST be a literal string. Ex: a.b + symbol::{ root: rex, key: string }, + ], call::[ static::{ @@ -288,18 +291,14 @@ rel::{ items: list::[item], _: [ item::{ - root: '.identifier.symbol', + root: '.rex.op.var', steps: list::[step], }, step::[ - attr::{ - symbol: '.identifier.symbol', - }, - pos::{ - index: int, - }, + struct_field::{ symbol: '.identifier.symbol' }, + coll_index::{ index: int }, struct_wildcard::{}, - collection_wildcard::{}, + coll_wildcard::{}, ], ], }, diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt index 934a030b1e..5c3a01b544 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/Header.kt @@ -10,27 +10,27 @@ import org.partiql.value.PartiQLValueType * A (temporary) place for function definitions; there are whispers of loading this as information_schema. */ @OptIn(PartiQLValueExperimental::class) -public abstract class Header { +internal abstract class Header { /** * Definition namespace e.g. partiql, spark, redshift, ... */ - public abstract val namespace: String + abstract val namespace: String /** * Scalar function signatures available via call syntax. */ - public open val functions: List = emptyList() + open val functions: List = emptyList() /** * Hidden scalar function signatures available via operator or special form syntax. */ - public open val operators: List = emptyList() + open val operators: List = emptyList() /** * Aggregation function signatures. */ - public open val aggregations: List = emptyList() + open val aggregations: List = emptyList() /** * Type relationships; this is primarily a helper for defining operators. @@ -55,7 +55,7 @@ public abstract class Header { // HELPERS // ==================================== - public companion object { + companion object { @JvmStatic internal fun unary(name: String, returns: PartiQLValueType, value: PartiQLValueType) = diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt index af7e6698d3..2bbae8b146 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLHeader.kt @@ -25,7 +25,7 @@ import org.partiql.value.PartiQLValueType.TIMESTAMP * */ @OptIn(PartiQLValueExperimental::class) -public object PartiQLHeader : Header() { +internal object PartiQLHeader : Header() { override val namespace: String = "partiql" diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt index 6960fcd430..8030703b34 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlanner.kt @@ -1,6 +1,5 @@ package org.partiql.planner -import com.amazon.ionelement.api.StructElement import org.partiql.ast.Statement import org.partiql.errors.Problem import org.partiql.errors.ProblemCallback @@ -39,7 +38,6 @@ public interface PartiQLPlanner { * @property userId * @property currentCatalog * @property currentDirectory - * @property catalogConfig * @property instant */ public class Session( @@ -47,7 +45,15 @@ public interface PartiQLPlanner { public val userId: String, public val currentCatalog: String? = null, public val currentDirectory: List = emptyList(), - public val catalogConfig: Map = emptyMap(), public val instant: Instant = Instant.now(), ) + + public companion object { + + @JvmStatic + public fun builder(): PartiQLPlannerBuilder = PartiQLPlannerBuilder() + + @JvmStatic + public fun default(): PartiQLPlanner = PartiQLPlannerBuilder().build() + } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt index 64ec9af2e2..05d4b06fde 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerBuilder.kt @@ -1,27 +1,51 @@ package org.partiql.planner -import org.partiql.spi.Plugin +import org.partiql.spi.connector.ConnectorMetadata /** - * PartiQLPlannerBuilder + * PartiQLPlannerBuilder is used to programmatically construct a [PartiQLPlanner] implementation. + * + * Usage: + * PartiQLPlanner.builder() + * .addCatalog("foo", FooConnector()) + * .addCatalog("bar", BarConnector()) + * .builder() */ public class PartiQLPlannerBuilder { private var headers: MutableList
= mutableListOf(PartiQLHeader) - private var plugins: List = emptyList() + private var catalogs: MutableMap = mutableMapOf() private var passes: List = emptyList() - public fun build(): PartiQLPlanner = PartiQLPlannerDefault(headers, plugins, passes) + /** + * Build the builder, return an implementation of a [PartiQLPlanner]. + * + * @return + */ + public fun build(): PartiQLPlanner = PartiQLPlannerDefault(headers, catalogs, passes) - public fun plugins(plugins: List): PartiQLPlannerBuilder = this.apply { - this.plugins = plugins + /** + * Java style method for assigning a Catalog name to [ConnectorMetadata]. + * + * @param catalog + * @param metadata + * @return + */ + public fun addCatalog(catalog: String, metadata: ConnectorMetadata): PartiQLPlannerBuilder = this.apply { + this.catalogs[catalog] = metadata } - public fun passes(passes: List): PartiQLPlannerBuilder = this.apply { - this.passes = passes + /** + * Kotlin style method for assigning Catalog names to [ConnectorMetadata]. + * + * @param catalogs + * @return + */ + public fun catalogs(vararg catalogs: Pair): PartiQLPlannerBuilder = this.apply { + this.catalogs = mutableMapOf(*catalogs) } - public fun headers(headers: List
): PartiQLPlannerBuilder = this.apply { - this.headers += headers + public fun passes(passes: List): PartiQLPlannerBuilder = this.apply { + this.passes = passes } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt index e0dcda8028..9ba2281fed 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/PartiQLPlannerDefault.kt @@ -8,14 +8,14 @@ import org.partiql.planner.internal.ir.PartiQLVersion import org.partiql.planner.internal.transforms.AstToPlan import org.partiql.planner.internal.transforms.PlanTransform import org.partiql.planner.internal.typer.PlanTyper -import org.partiql.spi.Plugin +import org.partiql.spi.connector.ConnectorMetadata /** * Default PartiQL logical query planner. */ internal class PartiQLPlannerDefault( private val headers: List
, - private val plugins: List, + private val catalogs: Map, private val passes: List, ) : PartiQLPlanner { @@ -25,7 +25,7 @@ internal class PartiQLPlannerDefault( onProblem: ProblemCallback, ): PartiQLPlanner.Result { // 0. Initialize the planning environment - val env = Env(headers, plugins, session) + val env = Env(headers, catalogs, session) // 1. Normalize val ast = statement.normalize() @@ -37,7 +37,7 @@ internal class PartiQLPlannerDefault( val typer = PlanTyper(env, onProblem) val internal = org.partiql.planner.internal.ir.PartiQLPlan( version = PartiQLVersion.VERSION_0_1, - globals = env.globals, + catalogs = env.catalogs, statement = typer.resolve(root), ) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt index 2e85724933..80a7eff38a 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/Env.kt @@ -3,31 +3,26 @@ package org.partiql.planner.internal import org.partiql.planner.Header import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.ir.Agg +import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Fn -import org.partiql.planner.internal.ir.Global import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex -import org.partiql.planner.internal.ir.global -import org.partiql.planner.internal.ir.identifierQualified import org.partiql.planner.internal.ir.identifierSymbol import org.partiql.planner.internal.typer.FnResolver import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath -import org.partiql.spi.Plugin -import org.partiql.spi.connector.Connector import org.partiql.spi.connector.ConnectorMetadata import org.partiql.spi.connector.ConnectorObjectHandle import org.partiql.spi.connector.ConnectorObjectPath import org.partiql.spi.connector.ConnectorSession -import org.partiql.spi.connector.Constants import org.partiql.types.StaticType import org.partiql.types.StructType import org.partiql.types.TupleConstraint /** - * Handle for associating a catalog with the metadata; pair of catalog to data. + * Handle for associating a catalog name with catalog related metadata objects. */ internal typealias Handle = Pair @@ -69,11 +64,13 @@ internal class TypeEnv( /** * Metadata regarding a resolved variable. + * @property depth The depth/level of the path match. */ internal sealed interface ResolvedVar { public val type: StaticType public val ordinal: Int + public val depth: Int /** * Metadata for a resolved local variable. @@ -88,20 +85,22 @@ internal sealed interface ResolvedVar { override val ordinal: Int, val rootType: StaticType, val replacementSteps: List, - val depth: Int + override val depth: Int ) : ResolvedVar /** * Metadata for a resolved global variable * * @property type Resolved StaticType - * @property ordinal Index offset in the environment `globals` list + * @property ordinal The relevant catalog's index offset in the [Env.catalogs] list * @property depth The depth/level of the path match. + * @property position The relevant value's index offset in the [Catalog.values] list */ class Global( override val type: StaticType, override val ordinal: Int, - val depth: Int, + override val depth: Int, + val position: Int ) : ResolvedVar } @@ -122,19 +121,19 @@ internal enum class ResolutionStrategy { * PartiQL Planner Global Environment of Catalogs backed by given plugins. * * @property headers List of namespaced definitions - * @property plugins List of plugins for global resolution + * @property catalogs List of plugins for global resolution * @property session Session details */ internal class Env( private val headers: List
, - private val plugins: List, + private val connectors: Map, private val session: PartiQLPlanner.Session, ) { /** * Collect the list of all referenced globals during planning. */ - public val globals = mutableListOf() + public val catalogs = mutableListOf() /** * Encapsulate all function resolving logic within [FnResolver]. @@ -147,26 +146,7 @@ internal class Env( } /** - * Map of catalog names to its underlying connector - */ - private val catalogs: Map - - // Initialize connectors - init { - val catalogs = mutableMapOf() - val connectors = plugins.flatMap { it.getConnectorFactories() } - // map catalogs to connectors - for ((catalog, config) in session.catalogConfig) { - // find corresponding connector - val connectorName = config[Constants.CONFIG_KEY_CONNECTOR_NAME].stringValue - val connector = connectors.first { it.getName() == connectorName } - // initialize connector with given config - catalogs[catalog] = connector.create(catalog, config) - } - this.catalogs = catalogs.toMap() - } - /** * Leverages a [FunctionResolver] to find a matching function defined in the [Header] scalar function catalog. */ internal fun resolveFn(fn: Fn.Unresolved, args: List) = fnResolver.resolveFn(fn, args) @@ -197,8 +177,9 @@ internal class Env( * @return */ internal fun getObjectDescriptor(handle: Handle): StaticType { - val metadata = getMetadata(BindingName(handle.first, BindingCase.SENSITIVE))!!.second - return metadata.getObjectType(connectorSession, handle.second)!! + val metadata = getMetadata(BindingName(handle.first, BindingCase.SENSITIVE))?.second + ?: error("Unable to fetch connector metadata based on handle $handle") + return metadata.getObjectType(connectorSession, handle.second) ?: error("Unable to produce Static Type") } /** @@ -208,9 +189,8 @@ internal class Env( * @return */ private fun getMetadata(catalogName: BindingName): Handle? { - val catalogKey = catalogs.keys.firstOrNull { catalogName.isEquivalentTo(it) } ?: return null - val connector = catalogs[catalogKey] ?: return null - val metadata = connector.getMetadata(connectorSession) + val catalogKey = connectors.keys.firstOrNull { catalogName.isEquivalentTo(it) } ?: return null + val metadata = connectors[catalogKey] ?: return null return catalogKey to metadata } @@ -231,15 +211,47 @@ internal class Env( getObjectHandle(cat, catalogPath)?.let { handle -> getObjectDescriptor(handle).let { type -> val depth = calculateMatched(originalPath, catalogPath, handle.second.absolutePath) - val qualifiedPath = identifierQualified( - root = handle.first.toIdentifier(), - steps = handle.second.absolutePath.steps.map { it.toIdentifier() } - ) - val global = global(qualifiedPath, type) - globals.add(global) + val (catalogIndex, valueIndex) = getOrAddCatalogValue(handle.first, handle.second.absolutePath.steps, type) // Return resolution metadata - ResolvedVar.Global(type, globals.size - 1, depth) + ResolvedVar.Global(type, catalogIndex, depth, valueIndex) + } + } + } + } + + /** + * @return a [Pair] where [Pair.first] is the catalog index and [Pair.second] is the value index within that catalog + */ + private fun getOrAddCatalogValue(catalogName: String, valuePath: List, valueType: StaticType): Pair { + val catalogIndex = getOrAddCatalog(catalogName) + val symbols = catalogs[catalogIndex].symbols + return symbols.indexOfFirst { value -> + value.path == valuePath + }.let { index -> + when (index) { + -1 -> { + catalogs[catalogIndex] = catalogs[catalogIndex].copy( + symbols = symbols + listOf(Catalog.Symbol(valuePath, valueType)) + ) + catalogIndex to 0 + } + else -> { + catalogIndex to index + } + } + } + } + + private fun getOrAddCatalog(catalogName: String): Int { + return catalogs.indexOfFirst { catalog -> + catalog.name == catalogName + }.let { + when (it) { + -1 -> { + catalogs.add(Catalog(catalogName, emptyList())) + catalogs.lastIndex } + else -> it } } } @@ -304,7 +316,7 @@ internal class Env( /** * Check locals, else search structs. */ - private fun resolveLocalBind(path: BindingPath, locals: List): ResolvedVar? { + internal fun resolveLocalBind(path: BindingPath, locals: List): ResolvedVar? { if (path.steps.isEmpty()) { return null } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt index d357cfbb28..0d23c2a5ca 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Nodes.kt @@ -4,9 +4,11 @@ package org.partiql.planner.internal.ir import org.partiql.planner.internal.ir.builder.AggResolvedBuilder import org.partiql.planner.internal.ir.builder.AggUnresolvedBuilder +import org.partiql.planner.internal.ir.builder.CatalogBuilder +import org.partiql.planner.internal.ir.builder.CatalogSymbolBuilder +import org.partiql.planner.internal.ir.builder.CatalogSymbolRefBuilder import org.partiql.planner.internal.ir.builder.FnResolvedBuilder import org.partiql.planner.internal.ir.builder.FnUnresolvedBuilder -import org.partiql.planner.internal.ir.builder.GlobalBuilder import org.partiql.planner.internal.ir.builder.IdentifierQualifiedBuilder import org.partiql.planner.internal.ir.builder.IdentifierSymbolBuilder import org.partiql.planner.internal.ir.builder.PartiQlPlanBuilder @@ -19,9 +21,9 @@ import org.partiql.planner.internal.ir.builder.RelOpErrBuilder import org.partiql.planner.internal.ir.builder.RelOpExceptBuilder import org.partiql.planner.internal.ir.builder.RelOpExcludeBuilder import org.partiql.planner.internal.ir.builder.RelOpExcludeItemBuilder -import org.partiql.planner.internal.ir.builder.RelOpExcludeStepAttrBuilder -import org.partiql.planner.internal.ir.builder.RelOpExcludeStepCollectionWildcardBuilder -import org.partiql.planner.internal.ir.builder.RelOpExcludeStepPosBuilder +import org.partiql.planner.internal.ir.builder.RelOpExcludeStepCollIndexBuilder +import org.partiql.planner.internal.ir.builder.RelOpExcludeStepCollWildcardBuilder +import org.partiql.planner.internal.ir.builder.RelOpExcludeStepStructFieldBuilder import org.partiql.planner.internal.ir.builder.RelOpExcludeStepStructWildcardBuilder import org.partiql.planner.internal.ir.builder.RelOpFilterBuilder import org.partiql.planner.internal.ir.builder.RelOpIntersectBuilder @@ -46,11 +48,9 @@ import org.partiql.planner.internal.ir.builder.RexOpCollectionBuilder import org.partiql.planner.internal.ir.builder.RexOpErrBuilder import org.partiql.planner.internal.ir.builder.RexOpGlobalBuilder import org.partiql.planner.internal.ir.builder.RexOpLitBuilder -import org.partiql.planner.internal.ir.builder.RexOpPathBuilder -import org.partiql.planner.internal.ir.builder.RexOpPathStepIndexBuilder -import org.partiql.planner.internal.ir.builder.RexOpPathStepSymbolBuilder -import org.partiql.planner.internal.ir.builder.RexOpPathStepUnpivotBuilder -import org.partiql.planner.internal.ir.builder.RexOpPathStepWildcardBuilder +import org.partiql.planner.internal.ir.builder.RexOpPathIndexBuilder +import org.partiql.planner.internal.ir.builder.RexOpPathKeyBuilder +import org.partiql.planner.internal.ir.builder.RexOpPathSymbolBuilder import org.partiql.planner.internal.ir.builder.RexOpPivotBuilder import org.partiql.planner.internal.ir.builder.RexOpSelectBuilder import org.partiql.planner.internal.ir.builder.RexOpStructBuilder @@ -80,13 +80,13 @@ internal data class PartiQLPlan( @JvmField internal val version: PartiQLVersion, @JvmField - internal val globals: List, + internal val catalogs: List, @JvmField internal val statement: Statement, ) : PlanNode() { internal override val children: List by lazy { val kids = mutableListOf() - kids.addAll(globals) + kids.addAll(catalogs) kids.add(statement) kids.filterNotNull() } @@ -100,24 +100,58 @@ internal data class PartiQLPlan( } } -internal data class Global( +internal data class Catalog( @JvmField - internal val path: Identifier.Qualified, + internal val name: String, @JvmField - internal val type: StaticType, + internal val symbols: List, ) : PlanNode() { internal override val children: List by lazy { val kids = mutableListOf() - kids.add(path) + kids.addAll(symbols) kids.filterNotNull() } internal override fun accept(visitor: PlanVisitor, ctx: C): R = - visitor.visitGlobal(this, ctx) + visitor.visitCatalog(this, ctx) + + internal data class Symbol( + @JvmField + internal val path: List, + @JvmField + internal val type: StaticType, + ) : PlanNode() { + internal override val children: List = emptyList() + + internal override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitCatalogSymbol(this, ctx) + + internal data class Ref( + @JvmField + internal val catalog: Int, + @JvmField + internal val symbol: Int, + ) : PlanNode() { + internal override val children: List = emptyList() + + internal override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitCatalogSymbolRef(this, ctx) + + internal companion object { + @JvmStatic + internal fun builder(): CatalogSymbolRefBuilder = CatalogSymbolRefBuilder() + } + } + + internal companion object { + @JvmStatic + internal fun builder(): CatalogSymbolBuilder = CatalogSymbolBuilder() + } + } internal companion object { @JvmStatic - internal fun builder(): GlobalBuilder = GlobalBuilder() + internal fun builder(): CatalogBuilder = CatalogBuilder() } } @@ -380,9 +414,13 @@ internal data class Rex( internal data class Global( @JvmField - internal val ref: Int, + internal val ref: Catalog.Symbol.Ref, ) : Op() { - internal override val children: List = emptyList() + internal override val children: List by lazy { + val kids = mutableListOf() + kids.add(ref) + kids.filterNotNull() + } internal override fun accept(visitor: PlanVisitor, ctx: C): R = visitor.visitRexOpGlobal(this, ctx) @@ -393,142 +431,77 @@ internal data class Rex( } } - internal data class Path( - @JvmField - internal val root: Rex, - @JvmField - internal val steps: List, - ) : Op() { - internal override val children: List by lazy { - val kids = mutableListOf() - kids.add(root) - kids.addAll(steps) - kids.filterNotNull() + internal sealed class Path : Op() { + internal override fun accept(visitor: PlanVisitor, ctx: C): R = when (this) { + is Index -> visitor.visitRexOpPathIndex(this, ctx) + is Key -> visitor.visitRexOpPathKey(this, ctx) + is Symbol -> visitor.visitRexOpPathSymbol(this, ctx) } - internal override fun accept(visitor: PlanVisitor, ctx: C): R = - visitor.visitRexOpPath(this, ctx) - - internal sealed class Step : PlanNode() { - internal override fun accept(visitor: PlanVisitor, ctx: C): R = when (this) { - is Index -> visitor.visitRexOpPathStepIndex(this, ctx) - is Symbol -> visitor.visitRexOpPathStepSymbol(this, ctx) - is Wildcard -> visitor.visitRexOpPathStepWildcard(this, ctx) - is Unpivot -> visitor.visitRexOpPathStepUnpivot(this, ctx) - is Key -> visitor.visitRexOpPathStepKey(this, ctx) + internal data class Index( + @JvmField + internal val root: Rex, + @JvmField + internal val key: Rex, + ) : Path() { + internal override val children: List by lazy { + val kids = mutableListOf() + kids.add(root) + kids.add(key) + kids.filterNotNull() } - internal data class Index( - @JvmField - internal val key: Rex, - ) : Step() { - internal override val children: List by lazy { - val kids = mutableListOf() - kids.add(key) - kids.filterNotNull() - } - - internal override fun accept(visitor: PlanVisitor, ctx: C): R = - visitor.visitRexOpPathStepIndex(this, ctx) + internal override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitRexOpPathIndex(this, ctx) - internal companion object { - @JvmStatic - internal fun builder(): RexOpPathStepIndexBuilder = RexOpPathStepIndexBuilder() - } + internal companion object { + @JvmStatic + internal fun builder(): RexOpPathIndexBuilder = RexOpPathIndexBuilder() } + } - /** - * This represents a case-sensitive lookup on a tuple. Ex: a['b'] or a[CAST('a' || 'b' AS STRING)]. - * This would normally contain the dot notation for case-sensitive lookup, however, due to - * limitations -- we cannot consolidate these. See [Symbol] for more information. - * - * The main difference is that this does NOT include `a."b"` - */ - internal data class Key( - @JvmField - internal val key: Rex, - ) : Step() { - internal override val children: List by lazy { - val kids = mutableListOf() - kids.add(key) - kids.filterNotNull() - } - - internal override fun accept(visitor: PlanVisitor, ctx: C): R = - visitor.visitRexOpPathStepKey(this, ctx) - - internal companion object { - @JvmStatic - internal fun builder(): RexOpPathStepIndexBuilder = RexOpPathStepIndexBuilder() - } + internal data class Key( + @JvmField + internal val root: Rex, + @JvmField + internal val key: Rex, + ) : Path() { + internal override val children: List by lazy { + val kids = mutableListOf() + kids.add(root) + kids.add(key) + kids.filterNotNull() } - /** - * This represents a lookup on a tuple. We differentiate a [Key] and a [Symbol] at this point in the - * pipeline because we NEED to retain some syntactic knowledge for the following reason: we cannot - * use the syntactic index operation on a schema -- as it is not synonymous with a tuple. In other words, - * `.""` is not interchangeable with `['']`. - * - * So, in order to temporarily differentiate the `a."b"` from `a['b']` (see [Key]), we need to maintain - * the syntactic difference here. Note that this would potentially be mitigated by typing during the AST to Plan - * transformation. - * - * That being said, this represents a lookup on a tuple such as `a.b` or `a."b"`. - */ - internal data class Symbol( - @JvmField - internal val identifier: Identifier.Symbol, - ) : Step() { - internal override val children: List by lazy { - val kids = mutableListOf() - kids.add(identifier) - kids.filterNotNull() - } - - internal override fun accept(visitor: PlanVisitor, ctx: C): R = - visitor.visitRexOpPathStepSymbol(this, ctx) + internal override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitRexOpPathKey(this, ctx) - internal companion object { - @JvmStatic - internal fun builder(): RexOpPathStepSymbolBuilder = RexOpPathStepSymbolBuilder() - } + internal companion object { + @JvmStatic + fun builder(): RexOpPathKeyBuilder = RexOpPathKeyBuilder() } + } - internal data class Wildcard( - @JvmField - internal val ` `: Char = ' ', - ) : Step() { - internal override val children: List = emptyList() - - internal override fun accept(visitor: PlanVisitor, ctx: C): R = - visitor.visitRexOpPathStepWildcard(this, ctx) - - internal companion object { - @JvmStatic - internal fun builder(): RexOpPathStepWildcardBuilder = RexOpPathStepWildcardBuilder() - } + internal data class Symbol( + @JvmField + internal val root: Rex, + @JvmField + internal val key: String, + ) : Path() { + internal override val children: List by lazy { + val kids = mutableListOf() + kids.add(root) + kids.filterNotNull() } - internal data class Unpivot( - @JvmField - internal val ` `: Char = ' ', - ) : Step() { - internal override val children: List = emptyList() - - internal override fun accept(visitor: PlanVisitor, ctx: C): R = - visitor.visitRexOpPathStepUnpivot(this, ctx) + internal override fun accept(visitor: PlanVisitor, ctx: C): R = + visitor.visitRexOpPathSymbol(this, ctx) - internal companion object { - @JvmStatic - internal fun builder(): RexOpPathStepUnpivotBuilder = RexOpPathStepUnpivotBuilder() - } + internal companion object { + @JvmStatic + internal fun builder(): RexOpPathSymbolBuilder = RexOpPathSymbolBuilder() } } - - internal companion object { - @JvmStatic - internal fun builder(): RexOpPathBuilder = RexOpPathBuilder() - } } internal sealed class Call : Op() { @@ -1272,7 +1245,7 @@ internal data class Rel( internal data class Item( @JvmField - internal val root: Identifier.Symbol, + internal val root: Rex.Op.Var, @JvmField internal val steps: List, ) : PlanNode() { @@ -1294,13 +1267,13 @@ internal data class Rel( internal sealed class Step : PlanNode() { internal override fun accept(visitor: PlanVisitor, ctx: C): R = when (this) { - is Attr -> visitor.visitRelOpExcludeStepAttr(this, ctx) - is Pos -> visitor.visitRelOpExcludeStepPos(this, ctx) + is StructField -> visitor.visitRelOpExcludeStepStructField(this, ctx) + is CollIndex -> visitor.visitRelOpExcludeStepCollIndex(this, ctx) is StructWildcard -> visitor.visitRelOpExcludeStepStructWildcard(this, ctx) - is CollectionWildcard -> visitor.visitRelOpExcludeStepCollectionWildcard(this, ctx) + is CollWildcard -> visitor.visitRelOpExcludeStepCollWildcard(this, ctx) } - internal data class Attr( + internal data class StructField( @JvmField internal val symbol: Identifier.Symbol, ) : Step() { @@ -1311,26 +1284,28 @@ internal data class Rel( } internal override fun accept(visitor: PlanVisitor, ctx: C): R = - visitor.visitRelOpExcludeStepAttr(this, ctx) + visitor.visitRelOpExcludeStepStructField(this, ctx) internal companion object { @JvmStatic - internal fun builder(): RelOpExcludeStepAttrBuilder = RelOpExcludeStepAttrBuilder() + internal fun builder(): RelOpExcludeStepStructFieldBuilder = + RelOpExcludeStepStructFieldBuilder() } } - internal data class Pos( + internal data class CollIndex( @JvmField internal val index: Int, ) : Step() { internal override val children: List = emptyList() internal override fun accept(visitor: PlanVisitor, ctx: C): R = - visitor.visitRelOpExcludeStepPos(this, ctx) + visitor.visitRelOpExcludeStepCollIndex(this, ctx) internal companion object { @JvmStatic - internal fun builder(): RelOpExcludeStepPosBuilder = RelOpExcludeStepPosBuilder() + internal fun builder(): RelOpExcludeStepCollIndexBuilder = + RelOpExcludeStepCollIndexBuilder() } } @@ -1350,19 +1325,19 @@ internal data class Rel( } } - internal data class CollectionWildcard( + internal data class CollWildcard( @JvmField internal val ` `: Char = ' ', ) : Step() { internal override val children: List = emptyList() internal override fun accept(visitor: PlanVisitor, ctx: C): R = - visitor.visitRelOpExcludeStepCollectionWildcard(this, ctx) + visitor.visitRelOpExcludeStepCollWildcard(this, ctx) internal companion object { @JvmStatic - internal fun builder(): RelOpExcludeStepCollectionWildcardBuilder = - RelOpExcludeStepCollectionWildcardBuilder() + internal fun builder(): RelOpExcludeStepCollWildcardBuilder = + RelOpExcludeStepCollWildcardBuilder() } } } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Plan.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Plan.kt index ddd3b9547d..e7516fba9f 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Plan.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/Plan.kt @@ -10,11 +10,17 @@ import org.partiql.value.PartiQLValueExperimental internal fun partiQLPlan( version: PartiQLVersion, - globals: List, + catalogs: List, statement: Statement, -): PartiQLPlan = PartiQLPlan(version, globals, statement) +): PartiQLPlan = PartiQLPlan(version, catalogs, statement) -internal fun global(path: Identifier.Qualified, type: StaticType): Global = Global(path, type) +internal fun catalog(name: String, symbols: List): Catalog = Catalog(name, symbols) + +internal fun catalogSymbol(path: List, type: StaticType): Catalog.Symbol = + Catalog.Symbol(path, type) + +internal fun catalogSymbolRef(catalog: Int, symbol: Int): Catalog.Symbol.Ref = + Catalog.Symbol.Ref(catalog, symbol) internal fun fnResolved(signature: FunctionSignature.Scalar): Fn.Resolved = Fn.Resolved(signature) @@ -44,23 +50,16 @@ internal fun rexOpVarResolved(ref: Int): Rex.Op.Var.Resolved = Rex.Op.Var.Resolv internal fun rexOpVarUnresolved(identifier: Identifier, scope: Rex.Op.Var.Scope): Rex.Op.Var.Unresolved = Rex.Op.Var.Unresolved(identifier, scope) -internal fun rexOpGlobal(ref: Int): Rex.Op.Global = Rex.Op.Global(ref) - -internal fun rexOpPath(root: Rex, steps: List): Rex.Op.Path = Rex.Op.Path( - root, - steps -) +internal fun rexOpGlobal(ref: Catalog.Symbol.Ref): Rex.Op.Global = Rex.Op.Global(ref) -internal fun rexOpPathStepIndex(key: Rex): Rex.Op.Path.Step.Index = Rex.Op.Path.Step.Index(key) +internal fun rexOpPathIndex(root: Rex, key: Rex): Rex.Op.Path.Index = Rex.Op.Path.Index(root, key) -internal fun rexOpPathStepKey(key: Rex): Rex.Op.Path.Step.Key = Rex.Op.Path.Step.Key(key) +internal fun rexOpPathKey(root: Rex, key: Rex): Rex.Op.Path.Key = Rex.Op.Path.Key(root, key) -internal fun rexOpPathStepSymbol(identifier: Identifier.Symbol): Rex.Op.Path.Step.Symbol = - Rex.Op.Path.Step.Symbol(identifier) - -internal fun rexOpPathStepWildcard(): Rex.Op.Path.Step.Wildcard = Rex.Op.Path.Step.Wildcard() - -internal fun rexOpPathStepUnpivot(): Rex.Op.Path.Step.Unpivot = Rex.Op.Path.Step.Unpivot() +internal fun rexOpPathSymbol(root: Rex, key: String): Rex.Op.Path.Symbol = Rex.Op.Path.Symbol( + root, + key +) internal fun rexOpCallStatic(fn: Fn, args: List): Rex.Op.Call.Static = Rex.Op.Call.Static( fn, @@ -163,19 +162,20 @@ internal fun relOpAggregateCall(agg: Agg, args: List): Rel.Op.Aggregate.Cal internal fun relOpExclude(input: Rel, items: List): Rel.Op.Exclude = Rel.Op.Exclude(input, items) -internal fun relOpExcludeItem(root: Identifier.Symbol, steps: List): - Rel.Op.Exclude.Item = Rel.Op.Exclude.Item(root, steps) +internal fun relOpExcludeItem(root: Rex.Op.Var, steps: List): Rel.Op.Exclude.Item = + Rel.Op.Exclude.Item(root, steps) -internal fun relOpExcludeStepAttr(symbol: Identifier.Symbol): Rel.Op.Exclude.Step.Attr = - Rel.Op.Exclude.Step.Attr(symbol) +internal fun relOpExcludeStepStructField(symbol: Identifier.Symbol): Rel.Op.Exclude.Step.StructField = + Rel.Op.Exclude.Step.StructField(symbol) -internal fun relOpExcludeStepPos(index: Int): Rel.Op.Exclude.Step.Pos = Rel.Op.Exclude.Step.Pos(index) +internal fun relOpExcludeStepCollIndex(index: Int): Rel.Op.Exclude.Step.CollIndex = + Rel.Op.Exclude.Step.CollIndex(index) internal fun relOpExcludeStepStructWildcard(): Rel.Op.Exclude.Step.StructWildcard = Rel.Op.Exclude.Step.StructWildcard() -internal fun relOpExcludeStepCollectionWildcard(): Rel.Op.Exclude.Step.CollectionWildcard = - Rel.Op.Exclude.Step.CollectionWildcard() +internal fun relOpExcludeStepCollWildcard(): Rel.Op.Exclude.Step.CollWildcard = + Rel.Op.Exclude.Step.CollWildcard() internal fun relOpErr(message: String): Rel.Op.Err = Rel.Op.Err(message) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/builder/PlanBuilder.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/builder/PlanBuilder.kt index 8ea53d5c5f..bbfd68e420 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/builder/PlanBuilder.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/builder/PlanBuilder.kt @@ -3,8 +3,8 @@ package org.partiql.planner.internal.ir.builder import org.partiql.planner.internal.ir.Agg +import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Fn -import org.partiql.planner.internal.ir.Global import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.PartiQLPlan import org.partiql.planner.internal.ir.PartiQLVersion @@ -22,21 +22,41 @@ internal fun plan(block: PlanBuilder.() -> T) = PlanBuilder().blo internal class PlanBuilder { internal fun partiQLPlan( version: PartiQLVersion? = null, - globals: MutableList = mutableListOf(), + catalogs: MutableList = mutableListOf(), statement: Statement? = null, block: PartiQlPlanBuilder.() -> Unit = {}, ): PartiQLPlan { - val builder = PartiQlPlanBuilder(version, globals, statement) + val builder = PartiQlPlanBuilder(version, catalogs, statement) builder.block() return builder.build() } - internal fun global( - path: Identifier.Qualified? = null, + internal fun catalog( + name: String? = null, + symbols: MutableList = mutableListOf(), + block: CatalogBuilder.() -> Unit = {}, + ): Catalog { + val builder = CatalogBuilder(name, symbols) + builder.block() + return builder.build() + } + + internal fun catalogSymbol( + path: MutableList = mutableListOf(), type: StaticType? = null, - block: GlobalBuilder.() -> Unit = {}, - ): Global { - val builder = GlobalBuilder(path, type) + block: CatalogSymbolBuilder.() -> Unit = {}, + ): Catalog.Symbol { + val builder = CatalogSymbolBuilder(path, type) + builder.block() + return builder.build() + } + + internal fun catalogSymbolRef( + catalog: Int? = null, + symbol: Int? = null, + block: CatalogSymbolRefBuilder.() -> Unit = {}, + ): Catalog.Symbol.Ref { + val builder = CatalogSymbolRefBuilder(catalog, symbol) builder.block() return builder.build() } @@ -140,57 +160,41 @@ internal class PlanBuilder { return builder.build() } - internal fun rexOpGlobal(ref: Int? = null, block: RexOpGlobalBuilder.() -> Unit = {}): Rex.Op.Global { + internal fun rexOpGlobal( + ref: Catalog.Symbol.Ref? = null, + block: RexOpGlobalBuilder.() -> Unit = {} + ): Rex.Op.Global { val builder = RexOpGlobalBuilder(ref) builder.block() return builder.build() } - internal fun rexOpPath( + internal fun rexOpPathIndex( root: Rex? = null, - steps: MutableList = mutableListOf(), - block: RexOpPathBuilder.() -> Unit = {}, - ): Rex.Op.Path { - val builder = RexOpPathBuilder(root, steps) - builder.block() - return builder.build() - } - - internal fun rexOpPathStepIndex( key: Rex? = null, - block: RexOpPathStepIndexBuilder.() -> Unit = {}, - ): Rex.Op.Path.Step.Index { - val builder = RexOpPathStepIndexBuilder(key) + block: RexOpPathIndexBuilder.() -> Unit = {}, + ): Rex.Op.Path.Index { + val builder = RexOpPathIndexBuilder(root, key) builder.block() return builder.build() } - internal fun rexOpPathStepKey( + internal fun rexOpPathKey( + root: Rex? = null, key: Rex? = null, - block: RexOpPathStepKeyBuilder.() -> Unit = {}, - ): Rex.Op.Path.Step.Key { - val builder = RexOpPathStepKeyBuilder(key) - builder.block() - return builder.build() - } - - internal fun rexOpPathStepSymbol( - identifier: Identifier.Symbol? = null, - block: RexOpPathStepSymbolBuilder.() -> Unit = {}, - ): Rex.Op.Path.Step.Symbol { - val builder = RexOpPathStepSymbolBuilder(identifier) + block: RexOpPathKeyBuilder.() -> Unit = {}, + ): Rex.Op.Path.Key { + val builder = RexOpPathKeyBuilder(root, key) builder.block() return builder.build() } - internal fun rexOpPathStepWildcard(block: RexOpPathStepWildcardBuilder.() -> Unit = {}): Rex.Op.Path.Step.Wildcard { - val builder = RexOpPathStepWildcardBuilder() - builder.block() - return builder.build() - } - - internal fun rexOpPathStepUnpivot(block: RexOpPathStepUnpivotBuilder.() -> Unit = {}): Rex.Op.Path.Step.Unpivot { - val builder = RexOpPathStepUnpivotBuilder() + internal fun rexOpPathSymbol( + root: Rex? = null, + key: String? = null, + block: RexOpPathSymbolBuilder.() -> Unit = {}, + ): Rex.Op.Path.Symbol { + val builder = RexOpPathSymbolBuilder(root, key) builder.block() return builder.build() } @@ -501,7 +505,7 @@ internal class PlanBuilder { } internal fun relOpExcludeItem( - root: Identifier.Symbol? = null, + root: Rex.Op.Var? = null, steps: MutableList = mutableListOf(), block: RelOpExcludeItemBuilder.() -> Unit = {}, ): Rel.Op.Exclude.Item { @@ -510,36 +514,38 @@ internal class PlanBuilder { return builder.build() } - internal fun relOpExcludeStepAttr( + internal fun relOpExcludeStepStructField( symbol: Identifier.Symbol? = null, - block: RelOpExcludeStepAttrBuilder.() -> Unit = {}, - ): Rel.Op.Exclude.Step.Attr { - val builder = RelOpExcludeStepAttrBuilder(symbol) + block: RelOpExcludeStepStructFieldBuilder.() -> Unit = {} + ): Rel.Op.Exclude.Step.StructField { + val builder = RelOpExcludeStepStructFieldBuilder(symbol) builder.block() return builder.build() } - internal fun relOpExcludeStepPos( + internal fun relOpExcludeStepCollIndex( index: Int? = null, - block: RelOpExcludeStepPosBuilder.() -> Unit = {}, - ): Rel.Op.Exclude.Step.Pos { - val builder = RelOpExcludeStepPosBuilder(index) + block: RelOpExcludeStepCollIndexBuilder.() -> Unit = {} + ): Rel.Op.Exclude.Step.CollIndex { + val builder = RelOpExcludeStepCollIndexBuilder(index) builder.block() return builder.build() } internal fun relOpExcludeStepStructWildcard( - block: RelOpExcludeStepStructWildcardBuilder.() -> Unit = {}, + block: RelOpExcludeStepStructWildcardBuilder.() -> Unit = + {} ): Rel.Op.Exclude.Step.StructWildcard { val builder = RelOpExcludeStepStructWildcardBuilder() builder.block() return builder.build() } - internal fun relOpExcludeStepCollectionWildcard( - block: RelOpExcludeStepCollectionWildcardBuilder.() -> Unit = {}, - ): Rel.Op.Exclude.Step.CollectionWildcard { - val builder = RelOpExcludeStepCollectionWildcardBuilder() + internal fun relOpExcludeStepCollWildcard( + block: RelOpExcludeStepCollWildcardBuilder.() -> Unit = + {} + ): Rel.Op.Exclude.Step.CollWildcard { + val builder = RelOpExcludeStepCollWildcardBuilder() builder.block() return builder.build() } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/builder/PlanBuilders.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/builder/PlanBuilders.kt index 8f4cf3197d..be428c8c75 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/builder/PlanBuilders.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/builder/PlanBuilders.kt @@ -3,8 +3,8 @@ package org.partiql.planner.internal.ir.builder import org.partiql.planner.internal.ir.Agg +import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Fn -import org.partiql.planner.internal.ir.Global import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.PartiQLPlan import org.partiql.planner.internal.ir.PartiQLVersion @@ -18,15 +18,15 @@ import org.partiql.value.PartiQLValueExperimental internal class PartiQlPlanBuilder( internal var version: PartiQLVersion? = null, - internal var globals: MutableList = mutableListOf(), + internal var catalogs: MutableList = mutableListOf(), internal var statement: Statement? = null, ) { internal fun version(version: PartiQLVersion?): PartiQlPlanBuilder = this.apply { this.version = version } - internal fun globals(globals: MutableList): PartiQlPlanBuilder = this.apply { - this.globals = globals + internal fun catalogs(catalogs: MutableList): PartiQlPlanBuilder = this.apply { + this.catalogs = catalogs } internal fun statement(statement: Statement?): PartiQlPlanBuilder = this.apply { @@ -34,25 +34,59 @@ internal class PartiQlPlanBuilder( } internal fun build(): PartiQLPlan = PartiQLPlan( - version = version!!, globals = globals, + version = version!!, catalogs = catalogs, statement = statement!! ) } -internal class GlobalBuilder( - internal var path: Identifier.Qualified? = null, +internal class CatalogBuilder( + internal var name: String? = null, + internal var symbols: MutableList = mutableListOf(), +) { + internal fun name(name: String?): CatalogBuilder = this.apply { + this.name = name + } + + internal fun symbols(symbols: MutableList): CatalogBuilder = this.apply { + this.symbols = symbols + } + + internal fun build(): Catalog = Catalog(name = name!!, symbols = symbols) +} + +internal class CatalogSymbolBuilder( + internal var path: MutableList = mutableListOf(), internal var type: StaticType? = null, ) { - internal fun path(path: Identifier.Qualified?): GlobalBuilder = this.apply { + internal fun path(path: MutableList): CatalogSymbolBuilder = this.apply { this.path = path } - internal fun type(type: StaticType?): GlobalBuilder = this.apply { + internal fun type(type: StaticType?): CatalogSymbolBuilder = this.apply { this.type = type } - internal fun build(): Global = Global(path = path!!, type = type!!) + internal fun build(): Catalog.Symbol = Catalog.Symbol(path = path, type = type!!) +} + +internal class CatalogSymbolRefBuilder( + internal var catalog: Int? = null, + internal var symbol: Int? = null, +) { + internal fun catalog(catalog: Int?): CatalogSymbolRefBuilder = this.apply { + this.catalog = catalog + } + + internal fun symbol(symbol: Int?): CatalogSymbolRefBuilder = this.apply { + this.symbol = symbol + } + + internal fun build(): Catalog.Symbol.Ref = Catalog.Symbol.Ref( + catalog = catalog!!, + symbol = + symbol!! + ) } internal class FnResolvedBuilder( @@ -205,66 +239,58 @@ internal class RexOpVarUnresolvedBuilder( } internal class RexOpGlobalBuilder( - internal var ref: Int? = null, + internal var ref: Catalog.Symbol.Ref? = null, ) { - internal fun ref(ref: Int?): RexOpGlobalBuilder = this.apply { + internal fun ref(ref: Catalog.Symbol.Ref?): RexOpGlobalBuilder = this.apply { this.ref = ref } internal fun build(): Rex.Op.Global = Rex.Op.Global(ref = ref!!) } -internal class RexOpPathBuilder( +internal class RexOpPathIndexBuilder( internal var root: Rex? = null, - internal var steps: MutableList = mutableListOf(), + internal var key: Rex? = null, ) { - internal fun root(root: Rex?): RexOpPathBuilder = this.apply { + internal fun root(root: Rex?): RexOpPathIndexBuilder = this.apply { this.root = root } - internal fun steps(steps: MutableList): RexOpPathBuilder = this.apply { - this.steps = steps + internal fun key(key: Rex?): RexOpPathIndexBuilder = this.apply { + this.key = key } - internal fun build(): Rex.Op.Path = Rex.Op.Path(root = root!!, steps = steps) + internal fun build(): Rex.Op.Path.Index = Rex.Op.Path.Index(root = root!!, key = key!!) } -internal class RexOpPathStepIndexBuilder( +internal class RexOpPathKeyBuilder( + internal var root: Rex? = null, internal var key: Rex? = null, ) { - internal fun key(key: Rex?): RexOpPathStepIndexBuilder = this.apply { - this.key = key + internal fun root(root: Rex?): RexOpPathKeyBuilder = this.apply { + this.root = root } - internal fun build(): Rex.Op.Path.Step.Index = Rex.Op.Path.Step.Index(key = key!!) -} - -internal class RexOpPathStepKeyBuilder( - internal var key: Rex? = null, -) { - internal fun key(key: Rex?): RexOpPathStepKeyBuilder = this.apply { + internal fun key(key: Rex?): RexOpPathKeyBuilder = this.apply { this.key = key } - internal fun build(): Rex.Op.Path.Step.Key = Rex.Op.Path.Step.Key(key = key!!) + internal fun build(): Rex.Op.Path.Key = Rex.Op.Path.Key(root = root!!, key = key!!) } -internal class RexOpPathStepSymbolBuilder( - internal var identifier: Identifier.Symbol? = null, +internal class RexOpPathSymbolBuilder( + internal var root: Rex? = null, + internal var key: String? = null, ) { - internal fun identifier(identifier: Identifier.Symbol?): RexOpPathStepSymbolBuilder = this.apply { - this.identifier = identifier + internal fun root(root: Rex?): RexOpPathSymbolBuilder = this.apply { + this.root = root } - internal fun build(): Rex.Op.Path.Step.Symbol = Rex.Op.Path.Step.Symbol(identifier = identifier!!) -} - -internal class RexOpPathStepWildcardBuilder() { - internal fun build(): Rex.Op.Path.Step.Wildcard = Rex.Op.Path.Step.Wildcard() -} + internal fun key(key: String?): RexOpPathSymbolBuilder = this.apply { + this.key = key + } -internal class RexOpPathStepUnpivotBuilder() { - internal fun build(): Rex.Op.Path.Step.Unpivot = Rex.Op.Path.Step.Unpivot() + internal fun build(): Rex.Op.Path.Symbol = Rex.Op.Path.Symbol(root = root!!, key = key!!) } internal class RexOpCallStaticBuilder( @@ -749,10 +775,10 @@ internal class RelOpExcludeBuilder( } internal class RelOpExcludeItemBuilder( - internal var root: Identifier.Symbol? = null, + internal var root: Rex.Op.Var? = null, internal var steps: MutableList = mutableListOf(), ) { - internal fun root(root: Identifier.Symbol?): RelOpExcludeItemBuilder = this.apply { + internal fun root(root: Rex.Op.Var?): RelOpExcludeItemBuilder = this.apply { this.root = root } @@ -763,33 +789,35 @@ internal class RelOpExcludeItemBuilder( internal fun build(): Rel.Op.Exclude.Item = Rel.Op.Exclude.Item(root = root!!, steps = steps) } -internal class RelOpExcludeStepAttrBuilder( +internal class RelOpExcludeStepStructFieldBuilder( internal var symbol: Identifier.Symbol? = null, ) { - internal fun symbol(symbol: Identifier.Symbol?): RelOpExcludeStepAttrBuilder = this.apply { + internal fun symbol(symbol: Identifier.Symbol?): RelOpExcludeStepStructFieldBuilder = this.apply { this.symbol = symbol } - internal fun build(): Rel.Op.Exclude.Step.Attr = Rel.Op.Exclude.Step.Attr(symbol = symbol!!) + internal fun build(): Rel.Op.Exclude.Step.StructField = Rel.Op.Exclude.Step.StructField( + symbol = + symbol!! + ) } -internal class RelOpExcludeStepPosBuilder( +internal class RelOpExcludeStepCollIndexBuilder( internal var index: Int? = null, ) { - internal fun index(index: Int?): RelOpExcludeStepPosBuilder = this.apply { + internal fun index(index: Int?): RelOpExcludeStepCollIndexBuilder = this.apply { this.index = index } - internal fun build(): Rel.Op.Exclude.Step.Pos = Rel.Op.Exclude.Step.Pos(index = index!!) + internal fun build(): Rel.Op.Exclude.Step.CollIndex = Rel.Op.Exclude.Step.CollIndex(index = index!!) } internal class RelOpExcludeStepStructWildcardBuilder() { internal fun build(): Rel.Op.Exclude.Step.StructWildcard = Rel.Op.Exclude.Step.StructWildcard() } -internal class RelOpExcludeStepCollectionWildcardBuilder() { - internal fun build(): Rel.Op.Exclude.Step.CollectionWildcard = - Rel.Op.Exclude.Step.CollectionWildcard() +internal class RelOpExcludeStepCollWildcardBuilder() { + internal fun build(): Rel.Op.Exclude.Step.CollWildcard = Rel.Op.Exclude.Step.CollWildcard() } internal class RelOpErrBuilder( diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/util/PlanRewriter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/util/PlanRewriter.kt index 0aae2cd0da..a9e807160a 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/util/PlanRewriter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/util/PlanRewriter.kt @@ -6,8 +6,8 @@ package org.partiql.planner.internal.ir.util import org.partiql.planner.internal.ir.Agg +import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Fn -import org.partiql.planner.internal.ir.Global import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.PartiQLPlan import org.partiql.planner.internal.ir.PlanNode @@ -87,25 +87,37 @@ internal abstract class PlanRewriter : PlanBaseVisitor() { override fun visitPartiQLPlan(node: PartiQLPlan, ctx: C): PlanNode { val version = node.version - val globals = _visitList(node.globals, ctx, ::visitGlobal) + val globals = _visitList(node.catalogs, ctx, ::visitCatalog) val statement = visitStatement(node.statement, ctx) as Statement - return if (version !== node.version || globals !== node.globals || statement !== node.statement) { + return if (version !== node.version || globals !== node.catalogs || statement !== node.statement) { PartiQLPlan(version, globals, statement) } else { node } } - override fun visitGlobal(node: Global, ctx: C): PlanNode { - val path = visitIdentifierQualified(node.path, ctx) as Identifier.Qualified - val type = node.type - return if (path !== node.path || type !== node.type) { - Global(path, type) + override fun visitCatalog(node: Catalog, ctx: C): PlanNode { + val name = node.name + val symbols = _visitList(node.symbols, ctx, ::visitCatalogSymbol) + return if (name !== node.name || symbols !== node.symbols) { + Catalog(name, symbols) } else { node } } + override fun visitCatalogSymbol(node: Catalog.Symbol, ctx: C): PlanNode { + val path = node.path + val type = node.type + return node + } + + override fun visitCatalogSymbolRef(node: Catalog.Symbol.Ref, ctx: C): PlanNode { + val catalog = node.catalog + val symbol = node.symbol + return node + } + override fun visitFnResolved(node: Fn.Resolved, ctx: C): PlanNode { val signature = node.signature return node @@ -192,42 +204,44 @@ internal abstract class PlanRewriter : PlanBaseVisitor() { } override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: C): PlanNode { - val ref = node.ref - return node + val ref = visitCatalogSymbolRef(node.ref, ctx) as Catalog.Symbol.Ref + return if (ref !== node.ref) { + Rex.Op.Global(ref) + } else { + node + } } - override fun visitRexOpPath(node: Rex.Op.Path, ctx: C): PlanNode { + override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: C): PlanNode { val root = visitRex(node.root, ctx) as Rex - val steps = _visitList(node.steps, ctx, ::visitRexOpPathStep) - return if (root !== node.root || steps !== node.steps) { - Rex.Op.Path(root, steps) + val key = visitRex(node.key, ctx) as Rex + return if (root !== node.root || key !== node.key) { + Rex.Op.Path.Index(root, key) } else { node } } - override fun visitRexOpPathStepIndex(node: Rex.Op.Path.Step.Index, ctx: C): PlanNode { + override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: C): PlanNode { + val root = visitRex(node.root, ctx) as Rex val key = visitRex(node.key, ctx) as Rex - return if (key !== node.key) { - Rex.Op.Path.Step.Index(key) + return if (root !== node.root || key !== node.key) { + Rex.Op.Path.Key(root, key) } else { node } } - override fun visitRexOpPathStepSymbol(node: Rex.Op.Path.Step.Symbol, ctx: C): PlanNode { - val identifier = visitIdentifierSymbol(node.identifier, ctx) as Identifier.Symbol - return if (identifier !== node.identifier) { - Rex.Op.Path.Step.Symbol(identifier) + override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: C): PlanNode { + val root = visitRex(node.root, ctx) as Rex + val key = node.key + return if (root !== node.root || key !== node.key) { + Rex.Op.Path.Symbol(root, key) } else { node } } - override fun visitRexOpPathStepWildcard(node: Rex.Op.Path.Step.Wildcard, ctx: C): PlanNode = node - - override fun visitRexOpPathStepUnpivot(node: Rex.Op.Path.Step.Unpivot, ctx: C): PlanNode = node - override fun visitRexOpCallStatic(node: Rex.Op.Call.Static, ctx: C): PlanNode { val fn = visitFn(node.fn, ctx) as Fn val args = _visitList(node.args, ctx, ::visitRex) @@ -542,7 +556,7 @@ internal abstract class PlanRewriter : PlanBaseVisitor() { } override fun visitRelOpExcludeItem(node: Rel.Op.Exclude.Item, ctx: C): PlanNode { - val root = visitIdentifierSymbol(node.root, ctx) as Identifier.Symbol + val root = visitRexOpVar(node.root, ctx) as Rex.Op.Var val steps = _visitList(node.steps, ctx, ::visitRelOpExcludeStep) return if (root !== node.root || steps !== node.steps) { Rel.Op.Exclude.Item(root, steps) @@ -551,28 +565,32 @@ internal abstract class PlanRewriter : PlanBaseVisitor() { } } - override fun visitRelOpExcludeStepAttr(node: Rel.Op.Exclude.Step.Attr, ctx: C): PlanNode { + override fun visitRelOpExcludeStepStructField( + node: Rel.Op.Exclude.Step.StructField, + ctx: C + ): PlanNode { val symbol = visitIdentifierSymbol(node.symbol, ctx) as Identifier.Symbol return if (symbol !== node.symbol) { - Rel.Op.Exclude.Step.Attr(symbol) + Rel.Op.Exclude.Step.StructField(symbol) } else { node } } - override fun visitRelOpExcludeStepPos(node: Rel.Op.Exclude.Step.Pos, ctx: C): PlanNode { + override fun visitRelOpExcludeStepCollIndex(node: Rel.Op.Exclude.Step.CollIndex, ctx: C): + PlanNode { val index = node.index return node } override fun visitRelOpExcludeStepStructWildcard( node: Rel.Op.Exclude.Step.StructWildcard, - ctx: C, + ctx: C ): PlanNode = node - override fun visitRelOpExcludeStepCollectionWildcard( - node: Rel.Op.Exclude.Step.CollectionWildcard, - ctx: C, + override fun visitRelOpExcludeStepCollWildcard( + node: Rel.Op.Exclude.Step.CollWildcard, + ctx: C ): PlanNode = node override fun visitRelOpErr(node: Rel.Op.Err, ctx: C): PlanNode { diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/visitor/PlanBaseVisitor.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/visitor/PlanBaseVisitor.kt index afe8dac281..901d17895a 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/visitor/PlanBaseVisitor.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/visitor/PlanBaseVisitor.kt @@ -3,8 +3,8 @@ package org.partiql.planner.internal.ir.visitor import org.partiql.planner.internal.ir.Agg +import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Fn -import org.partiql.planner.internal.ir.Global import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.PartiQLPlan import org.partiql.planner.internal.ir.PlanNode @@ -18,7 +18,12 @@ internal abstract class PlanBaseVisitor : PlanVisitor { override fun visitPartiQLPlan(node: PartiQLPlan, ctx: C): R = defaultVisit(node, ctx) - override fun visitGlobal(node: Global, ctx: C): R = defaultVisit(node, ctx) + public override fun visitCatalog(node: Catalog, ctx: C): R = defaultVisit(node, ctx) + + public override fun visitCatalogSymbol(node: Catalog.Symbol, ctx: C): R = defaultVisit(node, ctx) + + public override fun visitCatalogSymbolRef(node: Catalog.Symbol.Ref, ctx: C): R = + defaultVisit(node, ctx) override fun visitFn(node: Fn, ctx: C): R = when (node) { is Fn.Resolved -> visitFnResolved(node, ctx) @@ -93,30 +98,23 @@ internal abstract class PlanBaseVisitor : PlanVisitor { override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: C): R = defaultVisit(node, ctx) - override fun visitRexOpPath(node: Rex.Op.Path, ctx: C): R = defaultVisit(node, ctx) - - override fun visitRexOpPathStep(node: Rex.Op.Path.Step, ctx: C): R = when (node) { - is Rex.Op.Path.Step.Index -> visitRexOpPathStepIndex(node, ctx) - is Rex.Op.Path.Step.Key -> visitRexOpPathStepKey(node, ctx) - is Rex.Op.Path.Step.Symbol -> visitRexOpPathStepSymbol(node, ctx) - is Rex.Op.Path.Step.Wildcard -> visitRexOpPathStepWildcard(node, ctx) - is Rex.Op.Path.Step.Unpivot -> visitRexOpPathStepUnpivot(node, ctx) + override fun visitRexOpPath(node: Rex.Op.Path, ctx: C): R = when (node) { + is Rex.Op.Path.Index -> visitRexOpPathIndex(node, ctx) + is Rex.Op.Path.Key -> visitRexOpPathKey(node, ctx) + is Rex.Op.Path.Symbol -> visitRexOpPathSymbol(node, ctx) } - override fun visitRexOpPathStepIndex(node: Rex.Op.Path.Step.Index, ctx: C): R = - defaultVisit(node, ctx) - - override fun visitRexOpPathStepKey(node: Rex.Op.Path.Step.Key, ctx: C): R = - defaultVisit(node, ctx) - - override fun visitRexOpPathStepSymbol(node: Rex.Op.Path.Step.Symbol, ctx: C): R = - defaultVisit(node, ctx) + override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: C): R = defaultVisit( + node, + ctx + ) - override fun visitRexOpPathStepWildcard(node: Rex.Op.Path.Step.Wildcard, ctx: C): R = - defaultVisit(node, ctx) + override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: C): R = defaultVisit(node, ctx) - override fun visitRexOpPathStepUnpivot(node: Rex.Op.Path.Step.Unpivot, ctx: C): R = - defaultVisit(node, ctx) + override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: C): R = defaultVisit( + node, + ctx + ) override fun visitRexOpCall(node: Rex.Op.Call, ctx: C): R = when (node) { is Rex.Op.Call.Static -> visitRexOpCallStatic(node, ctx) @@ -236,27 +234,28 @@ internal abstract class PlanBaseVisitor : PlanVisitor { defaultVisit(node, ctx) override fun visitRelOpExcludeStep(node: Rel.Op.Exclude.Step, ctx: C): R = when (node) { - is Rel.Op.Exclude.Step.Attr -> visitRelOpExcludeStepAttr(node, ctx) - is Rel.Op.Exclude.Step.Pos -> visitRelOpExcludeStepPos(node, ctx) + is Rel.Op.Exclude.Step.StructField -> visitRelOpExcludeStepStructField(node, ctx) + is Rel.Op.Exclude.Step.CollIndex -> visitRelOpExcludeStepCollIndex(node, ctx) is Rel.Op.Exclude.Step.StructWildcard -> visitRelOpExcludeStepStructWildcard(node, ctx) - is Rel.Op.Exclude.Step.CollectionWildcard -> visitRelOpExcludeStepCollectionWildcard(node, ctx) + is Rel.Op.Exclude.Step.CollWildcard -> visitRelOpExcludeStepCollWildcard(node, ctx) } - override fun visitRelOpExcludeStepAttr(node: Rel.Op.Exclude.Step.Attr, ctx: C): R = - defaultVisit(node, ctx) + override fun visitRelOpExcludeStepStructField( + node: Rel.Op.Exclude.Step.StructField, + ctx: C + ): R = defaultVisit(node, ctx) - override fun visitRelOpExcludeStepPos(node: Rel.Op.Exclude.Step.Pos, ctx: C): R = + override fun visitRelOpExcludeStepCollIndex(node: Rel.Op.Exclude.Step.CollIndex, ctx: C): R = defaultVisit(node, ctx) override fun visitRelOpExcludeStepStructWildcard( node: Rel.Op.Exclude.Step.StructWildcard, - ctx: C, + ctx: C ): R = defaultVisit(node, ctx) - override - fun visitRelOpExcludeStepCollectionWildcard( - node: Rel.Op.Exclude.Step.CollectionWildcard, - ctx: C, + override fun visitRelOpExcludeStepCollWildcard( + node: Rel.Op.Exclude.Step.CollWildcard, + ctx: C ): R = defaultVisit(node, ctx) override fun visitRelOpErr(node: Rel.Op.Err, ctx: C): R = defaultVisit(node, ctx) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/visitor/PlanVisitor.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/visitor/PlanVisitor.kt index f3114e780c..c2fd04e7b0 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/visitor/PlanVisitor.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/ir/visitor/PlanVisitor.kt @@ -3,8 +3,8 @@ package org.partiql.planner.internal.ir.visitor import org.partiql.planner.internal.ir.Agg +import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Fn -import org.partiql.planner.internal.ir.Global import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.PartiQLPlan import org.partiql.planner.internal.ir.PlanNode @@ -18,7 +18,11 @@ internal interface PlanVisitor { fun visitPartiQLPlan(node: PartiQLPlan, ctx: C): R - fun visitGlobal(node: Global, ctx: C): R + public fun visitCatalog(node: Catalog, ctx: C): R + + public fun visitCatalogSymbol(node: Catalog.Symbol, ctx: C): R + + public fun visitCatalogSymbolRef(node: Catalog.Symbol.Ref, ctx: C): R fun visitFn(node: Fn, ctx: C): R @@ -58,17 +62,11 @@ internal interface PlanVisitor { fun visitRexOpPath(node: Rex.Op.Path, ctx: C): R - fun visitRexOpPathStep(node: Rex.Op.Path.Step, ctx: C): R - - fun visitRexOpPathStepIndex(node: Rex.Op.Path.Step.Index, ctx: C): R + fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: C): R - fun visitRexOpPathStepKey(node: Rex.Op.Path.Step.Key, ctx: C): R + fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: C): R - fun visitRexOpPathStepSymbol(node: Rex.Op.Path.Step.Symbol, ctx: C): R - - fun visitRexOpPathStepWildcard(node: Rex.Op.Path.Step.Wildcard, ctx: C): R - - fun visitRexOpPathStepUnpivot(node: Rex.Op.Path.Step.Unpivot, ctx: C): R + fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: C): R fun visitRexOpCall(node: Rex.Op.Call, ctx: C): R @@ -142,19 +140,14 @@ internal interface PlanVisitor { fun visitRelOpExcludeStep(node: Rel.Op.Exclude.Step, ctx: C): R - fun visitRelOpExcludeStepAttr(node: Rel.Op.Exclude.Step.Attr, ctx: C): R + fun visitRelOpExcludeStepStructField(node: Rel.Op.Exclude.Step.StructField, ctx: C): R - fun visitRelOpExcludeStepPos(node: Rel.Op.Exclude.Step.Pos, ctx: C): R + fun visitRelOpExcludeStepCollIndex(node: Rel.Op.Exclude.Step.CollIndex, ctx: C): R fun visitRelOpExcludeStepStructWildcard(node: Rel.Op.Exclude.Step.StructWildcard, ctx: C): R - fun visitRelOpExcludeStepCollectionWildcard( - node: Rel.Op.Exclude.Step.CollectionWildcard, - ctx: C, - ): R - - fun visitRelOpErr(node: Rel.Op.Err, ctx: C): R + fun visitRelOpExcludeStepCollWildcard(node: Rel.Op.Exclude.Step.CollWildcard, ctx: C): R fun visitRelOpErr(node: Rel.Op.Err, ctx: C): R fun visitRelBinding(node: Rel.Binding, ctx: C): R } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt index b52698db6b..7bf954c873 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/PlanTransform.kt @@ -3,22 +3,16 @@ package org.partiql.planner.internal.transforms import org.partiql.errors.ProblemCallback import org.partiql.plan.PlanNode import org.partiql.plan.partiQLPlan -import org.partiql.plan.rex -import org.partiql.plan.rexOpLit -import org.partiql.plan.rexOpPathStepKey -import org.partiql.plan.rexOpPathStepSymbol import org.partiql.planner.internal.ir.Agg +import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Fn -import org.partiql.planner.internal.ir.Global import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.PartiQLPlan import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.Statement import org.partiql.planner.internal.ir.visitor.PlanBaseVisitor -import org.partiql.types.StaticType import org.partiql.value.PartiQLValueExperimental -import org.partiql.value.stringValue /** * This is an internal utility to translate from the internal unresolved plan used for typing to the public plan IR. @@ -35,15 +29,22 @@ internal object PlanTransform : PlanBaseVisitor() { } override fun visitPartiQLPlan(node: PartiQLPlan, ctx: ProblemCallback): org.partiql.plan.PartiQLPlan { - val globals = node.globals.map { visitGlobal(it, ctx) } + val catalogs = node.catalogs.map { visitCatalog(it, ctx) } val statement = visitStatement(node.statement, ctx) - return partiQLPlan(globals, statement) + return partiQLPlan(catalogs, statement) } - override fun visitGlobal(node: Global, ctx: ProblemCallback): org.partiql.plan.Global { - val path = visitIdentifierQualified(node.path, ctx) - val type = node.type - return org.partiql.plan.global(path, type) + override fun visitCatalog(node: Catalog, ctx: ProblemCallback): org.partiql.plan.Catalog { + val symbols = node.symbols.map { visitCatalogSymbol(it, ctx) } + return org.partiql.plan.Catalog(node.name, symbols) + } + + override fun visitCatalogSymbol(node: Catalog.Symbol, ctx: ProblemCallback): org.partiql.plan.Catalog.Symbol { + return org.partiql.plan.Catalog.Symbol(node.path, node.type) + } + + override fun visitCatalogSymbolRef(node: Catalog.Symbol.Ref, ctx: ProblemCallback): org.partiql.plan.Catalog.Symbol.Ref { + return org.partiql.plan.Catalog.Symbol.Ref(node.catalog, node.symbol) } override fun visitFnResolved(node: Fn.Resolved, ctx: ProblemCallback) = org.partiql.plan.fn(node.signature) @@ -108,37 +109,26 @@ internal object PlanTransform : PlanBaseVisitor() { override fun visitRexOpVarUnresolved(node: Rex.Op.Var.Unresolved, ctx: ProblemCallback) = org.partiql.plan.Rex.Op.Err("Unresolved variable $node") - override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: ProblemCallback) = org.partiql.plan.Rex.Op.Global(node.ref) + override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: ProblemCallback) = org.partiql.plan.Rex.Op.Global( + ref = visitCatalogSymbolRef(node.ref, ctx) + ) - override fun visitRexOpPath(node: Rex.Op.Path, ctx: ProblemCallback): org.partiql.plan.Rex.Op.Path { + override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: ProblemCallback): PlanNode { val root = visitRex(node.root, ctx) - val steps = node.steps.map { visitRexOpPathStep(it, ctx) } - return org.partiql.plan.Rex.Op.Path(root, steps) + val key = visitRex(node.root, ctx) + return org.partiql.plan.Rex.Op.Path.Index(root, key) } - override fun visitRexOpPathStep(node: Rex.Op.Path.Step, ctx: ProblemCallback) = - super.visit(node, ctx) as org.partiql.plan.Rex.Op.Path.Step - - override fun visitRexOpPathStepIndex(node: Rex.Op.Path.Step.Index, ctx: ProblemCallback) = - org.partiql.plan.Rex.Op.Path.Step.Index( - key = visitRex(node.key, ctx), - ) - - @OptIn(PartiQLValueExperimental::class) - override fun visitRexOpPathStepSymbol(node: Rex.Op.Path.Step.Symbol, ctx: ProblemCallback) = when (node.identifier.caseSensitivity) { - Identifier.CaseSensitivity.SENSITIVE -> rexOpPathStepKey(rex(StaticType.STRING, rexOpLit(stringValue(node.identifier.symbol)))) - Identifier.CaseSensitivity.INSENSITIVE -> rexOpPathStepSymbol(node.identifier.symbol) + override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: ProblemCallback): PlanNode { + val root = visitRex(node.root, ctx) + val key = visitRex(node.root, ctx) + return org.partiql.plan.Rex.Op.Path.Key(root, key) } - override fun visitRexOpPathStepKey(node: Rex.Op.Path.Step.Key, ctx: ProblemCallback): PlanNode = rexOpPathStepKey( - key = visitRex(node.key, ctx) - ) - - override fun visitRexOpPathStepWildcard(node: Rex.Op.Path.Step.Wildcard, ctx: ProblemCallback) = - org.partiql.plan.Rex.Op.Path.Step.Wildcard() - - override fun visitRexOpPathStepUnpivot(node: Rex.Op.Path.Step.Unpivot, ctx: ProblemCallback) = - org.partiql.plan.Rex.Op.Path.Step.Unpivot() + override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: ProblemCallback): PlanNode { + val root = visitRex(node.root, ctx) + return org.partiql.plan.Rex.Op.Path.Symbol(root, node.key) + } override fun visitRexOpCall(node: Rex.Op.Call, ctx: ProblemCallback) = super.visitRexOpCall(node, ctx) as org.partiql.plan.Rex.Op @@ -352,22 +342,27 @@ internal object PlanTransform : PlanBaseVisitor() { items = node.items.map { visitRelOpExcludeItem(it, ctx) }, ) - override fun visitRelOpExcludeItem(node: Rel.Op.Exclude.Item, ctx: ProblemCallback) = - org.partiql.plan.Rel.Op.Exclude.Item( - root = visitIdentifierSymbol(node.root, ctx), + override fun visitRelOpExcludeItem(node: Rel.Op.Exclude.Item, ctx: ProblemCallback): org.partiql.plan.Rel.Op.Exclude.Item { + val root = when (node.root) { + is Rex.Op.Var.Resolved -> visitRexOpVar(node.root, ctx) as org.partiql.plan.Rex.Op.Var + is Rex.Op.Var.Unresolved -> org.partiql.plan.Rex.Op.Var(-1) // unresolved in `PlanTyper` results in error + } + return org.partiql.plan.Rel.Op.Exclude.Item( + root = root, steps = node.steps.map { visitRelOpExcludeStep(it, ctx) }, ) + } override fun visitRelOpExcludeStep(node: Rel.Op.Exclude.Step, ctx: ProblemCallback) = super.visit(node, ctx) as org.partiql.plan.Rel.Op.Exclude.Step - override fun visitRelOpExcludeStepAttr(node: Rel.Op.Exclude.Step.Attr, ctx: ProblemCallback) = - org.partiql.plan.Rel.Op.Exclude.Step.Attr( + override fun visitRelOpExcludeStepStructField(node: Rel.Op.Exclude.Step.StructField, ctx: ProblemCallback) = + org.partiql.plan.Rel.Op.Exclude.Step.StructField( symbol = visitIdentifierSymbol(node.symbol, ctx), ) - override fun visitRelOpExcludeStepPos(node: Rel.Op.Exclude.Step.Pos, ctx: ProblemCallback) = - org.partiql.plan.Rel.Op.Exclude.Step.Pos( + override fun visitRelOpExcludeStepCollIndex(node: Rel.Op.Exclude.Step.CollIndex, ctx: ProblemCallback) = + org.partiql.plan.Rel.Op.Exclude.Step.CollIndex( index = node.index, ) @@ -376,10 +371,10 @@ internal object PlanTransform : PlanBaseVisitor() { ctx: ProblemCallback, ) = org.partiql.plan.Rel.Op.Exclude.Step.StructWildcard() - override fun visitRelOpExcludeStepCollectionWildcard( - node: Rel.Op.Exclude.Step.CollectionWildcard, + override fun visitRelOpExcludeStepCollWildcard( + node: Rel.Op.Exclude.Step.CollWildcard, ctx: ProblemCallback, - ) = org.partiql.plan.Rel.Op.Exclude.Step.CollectionWildcard() + ) = org.partiql.plan.Rel.Op.Exclude.Step.CollWildcard() override fun visitRelOpErr(node: Rel.Op.Err, ctx: ProblemCallback) = org.partiql.plan.Rel.Op.Err(node.message) diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt index 09772acbfc..73cba006fc 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RelConverter.kt @@ -43,9 +43,9 @@ import org.partiql.planner.internal.ir.relOpErr import org.partiql.planner.internal.ir.relOpExcept import org.partiql.planner.internal.ir.relOpExclude import org.partiql.planner.internal.ir.relOpExcludeItem -import org.partiql.planner.internal.ir.relOpExcludeStepAttr -import org.partiql.planner.internal.ir.relOpExcludeStepCollectionWildcard -import org.partiql.planner.internal.ir.relOpExcludeStepPos +import org.partiql.planner.internal.ir.relOpExcludeStepCollIndex +import org.partiql.planner.internal.ir.relOpExcludeStepCollWildcard +import org.partiql.planner.internal.ir.relOpExcludeStepStructField import org.partiql.planner.internal.ir.relOpExcludeStepStructWildcard import org.partiql.planner.internal.ir.relOpFilter import org.partiql.planner.internal.ir.relOpIntersect @@ -483,22 +483,22 @@ internal object RelConverter { return input } val type = input.type // PlanTyper handles typing the exclusion - val items = exclude.exprs.map { convertExcludeItem(it) } + val items = exclude.items.map { convertExcludeItem(it) } val op = relOpExclude(input, items) return rel(type, op) } - private fun convertExcludeItem(expr: Exclude.ExcludeExpr): Rel.Op.Exclude.Item { - val root = AstToPlan.convert(expr.root) + private fun convertExcludeItem(expr: Exclude.Item): Rel.Op.Exclude.Item { + val root = (expr.root.toRex(env)).op as Rex.Op.Var val steps = expr.steps.map { convertExcludeStep(it) } return relOpExcludeItem(root, steps) } private fun convertExcludeStep(step: Exclude.Step): Rel.Op.Exclude.Step = when (step) { - is Exclude.Step.ExcludeTupleAttr -> relOpExcludeStepAttr(AstToPlan.convert(step.symbol)) - is Exclude.Step.ExcludeCollectionIndex -> relOpExcludeStepPos(step.index) - is Exclude.Step.ExcludeCollectionWildcard -> relOpExcludeStepCollectionWildcard() - is Exclude.Step.ExcludeTupleWildcard -> relOpExcludeStepStructWildcard() + is Exclude.Step.StructField -> relOpExcludeStepStructField(AstToPlan.convert(step.symbol)) + is Exclude.Step.CollIndex -> relOpExcludeStepCollIndex(step.index) + is Exclude.Step.StructWildcard -> relOpExcludeStepStructWildcard() + is Exclude.Step.CollWildcard -> relOpExcludeStepCollWildcard() } // /** diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt index c1264b1fc4..498dea8a2c 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/transforms/RexConverter.kt @@ -26,17 +26,15 @@ import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.builder.plan import org.partiql.planner.internal.ir.fnUnresolved +import org.partiql.planner.internal.ir.identifierQualified import org.partiql.planner.internal.ir.identifierSymbol import org.partiql.planner.internal.ir.rex import org.partiql.planner.internal.ir.rexOpCallStatic import org.partiql.planner.internal.ir.rexOpCollection import org.partiql.planner.internal.ir.rexOpLit -import org.partiql.planner.internal.ir.rexOpPath -import org.partiql.planner.internal.ir.rexOpPathStepIndex -import org.partiql.planner.internal.ir.rexOpPathStepKey -import org.partiql.planner.internal.ir.rexOpPathStepSymbol -import org.partiql.planner.internal.ir.rexOpPathStepUnpivot -import org.partiql.planner.internal.ir.rexOpPathStepWildcard +import org.partiql.planner.internal.ir.rexOpPathIndex +import org.partiql.planner.internal.ir.rexOpPathKey +import org.partiql.planner.internal.ir.rexOpPathSymbol import org.partiql.planner.internal.ir.rexOpStruct import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpSubquery @@ -53,6 +51,7 @@ import org.partiql.value.int32Value import org.partiql.value.int64Value import org.partiql.value.io.PartiQLValueIonReaderBuilder import org.partiql.value.nullValue +import org.partiql.value.stringValue /** * Converts an AST expression node to a Plan Rex node; ignoring any typing. @@ -158,39 +157,87 @@ internal object RexConverter { } } + private fun mergeIdentifiers(root: Identifier, steps: List): Identifier { + if (steps.isEmpty()) { + return root + } + val (newRoot, firstSteps) = when (root) { + is Identifier.Symbol -> root to emptyList() + is Identifier.Qualified -> root.root to root.steps + } + val followingSteps = steps.flatMap { step -> + when (step) { + is Identifier.Symbol -> listOf(step) + is Identifier.Qualified -> listOf(step.root) + step.steps + } + } + return identifierQualified(newRoot, firstSteps + followingSteps) + } + override fun visitExprPath(node: Expr.Path, context: Env): Rex { - val type = (StaticType.ANY) // Args val root = visitExprCoerce(node.root, context) - val steps = node.steps.map { - when (it) { - is Expr.Path.Step.Index -> { - val key = visitExprCoerce(it.key, context) - when (val astKey = it.key) { - is Expr.Lit -> when (astKey.value) { - is StringValue -> rexOpPathStepKey(key) - else -> rexOpPathStepIndex(key) - } - is Expr.Cast -> when (astKey.asType is Type.String) { - true -> rexOpPathStepKey(key) - false -> rexOpPathStepIndex(key) + + // Attempt to create qualified identifier + val (newRoot, newSteps) = when (val op = root.op) { + is Rex.Op.Var.Unresolved -> { + val identifierSteps = mutableListOf() + run { + node.steps.forEach { step -> + if (step !is Expr.Path.Step.Symbol) { + return@run } - else -> rexOpPathStepIndex(key) + identifierSteps.add(AstToPlan.convert(step.symbol)) } } - is Expr.Path.Step.Symbol -> { - val identifier = AstToPlan.convert(it.symbol) - rexOpPathStepSymbol(identifier) + when (identifierSteps.size) { + 0 -> root to node.steps + else -> { + val newRoot = rex(StaticType.ANY, rexOpVarUnresolved(mergeIdentifiers(op.identifier, identifierSteps), op.scope)) + val newSteps = node.steps.subList(identifierSteps.size, node.steps.size) + newRoot to newSteps + } } - is Expr.Path.Step.Unpivot -> rexOpPathStepUnpivot() - is Expr.Path.Step.Wildcard -> rexOpPathStepWildcard() + } + else -> root to node.steps + } + + // Return wrapped path + return when (newSteps.isEmpty()) { + true -> newRoot + false -> newSteps.fold(newRoot) { current, step -> + val path = when (step) { + is Expr.Path.Step.Index -> { + val key = visitExprCoerce(step.key, context) + when (val astKey = step.key) { + is Expr.Lit -> when (astKey.value) { + is StringValue -> rexOpPathKey(current, key) + else -> rexOpPathIndex(current, key) + } + is Expr.Cast -> when (astKey.asType is Type.String) { + true -> rexOpPathKey(current, key) + false -> rexOpPathIndex(current, key) + } + else -> rexOpPathIndex(current, key) + } + } + is Expr.Path.Step.Symbol -> { + val identifier = AstToPlan.convert(step.symbol) + when (identifier.caseSensitivity) { + Identifier.CaseSensitivity.SENSITIVE -> rexOpPathKey(current, rexString(identifier.symbol)) + Identifier.CaseSensitivity.INSENSITIVE -> rexOpPathSymbol(current, identifier.symbol) + } + } + is Expr.Path.Step.Unpivot -> error("Unpivot path not supported yet") + is Expr.Path.Step.Wildcard -> error("Wildcard path not supported yet") + } + rex(StaticType.ANY, path) } } - // Rex - val op = rexOpPath(root, steps) - return rex(type, op) } + private fun rexString(str: String) = rex(StaticType.STRING, rexOpLit(stringValue(str))) + override fun visitExprCall(node: Expr.Call, context: Env): Rex { val type = (StaticType.ANY) // Fn diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt index c6280900f0..f9cfd4a9f4 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/PlanTyper.kt @@ -32,6 +32,7 @@ import org.partiql.planner.internal.ir.Rel import org.partiql.planner.internal.ir.Rex import org.partiql.planner.internal.ir.Statement import org.partiql.planner.internal.ir.aggResolved +import org.partiql.planner.internal.ir.catalogSymbolRef import org.partiql.planner.internal.ir.fnResolved import org.partiql.planner.internal.ir.identifierSymbol import org.partiql.planner.internal.ir.rel @@ -40,6 +41,8 @@ import org.partiql.planner.internal.ir.relOpAggregate import org.partiql.planner.internal.ir.relOpAggregateCall import org.partiql.planner.internal.ir.relOpDistinct import org.partiql.planner.internal.ir.relOpErr +import org.partiql.planner.internal.ir.relOpExclude +import org.partiql.planner.internal.ir.relOpExcludeItem import org.partiql.planner.internal.ir.relOpFilter import org.partiql.planner.internal.ir.relOpJoin import org.partiql.planner.internal.ir.relOpLimit @@ -59,8 +62,9 @@ import org.partiql.planner.internal.ir.rexOpCollection import org.partiql.planner.internal.ir.rexOpErr import org.partiql.planner.internal.ir.rexOpGlobal import org.partiql.planner.internal.ir.rexOpLit -import org.partiql.planner.internal.ir.rexOpPath -import org.partiql.planner.internal.ir.rexOpPathStepSymbol +import org.partiql.planner.internal.ir.rexOpPathIndex +import org.partiql.planner.internal.ir.rexOpPathKey +import org.partiql.planner.internal.ir.rexOpPathSymbol import org.partiql.planner.internal.ir.rexOpPivot import org.partiql.planner.internal.ir.rexOpSelect import org.partiql.planner.internal.ir.rexOpStruct @@ -88,6 +92,7 @@ import org.partiql.types.StaticType.Companion.BOOL import org.partiql.types.StaticType.Companion.MISSING import org.partiql.types.StaticType.Companion.NULL import org.partiql.types.StaticType.Companion.STRING +import org.partiql.types.StaticType.Companion.unionOf import org.partiql.types.StringType import org.partiql.types.StructType import org.partiql.types.TupleConstraint @@ -96,6 +101,8 @@ import org.partiql.value.BoolValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.TextValue import org.partiql.value.boolValue +import org.partiql.value.missingValue +import org.partiql.value.stringValue /** * Rewrites an untyped algebraic translation of the query to be both typed and have resolved variables. @@ -112,7 +119,7 @@ internal class PlanTyper( /** * Rewrite the statement with inferred types and resolved variables */ - public fun resolve(statement: Statement): Statement { + fun resolve(statement: Statement): Statement { if (statement !is Statement.Query) { throw IllegalArgumentException("PartiQLPlanner only supports Query statements") } @@ -182,15 +189,15 @@ internal class PlanTyper( } // compute element type - val t = rex.type as StructType + val t = rex.type val e = if (t.contentClosed) { StaticType.unionOf(t.fields.map { it.value }.toSet()).flatten() } else { - StaticType.ANY + ANY } // compute rel type - val kType = StaticType.STRING + val kType = STRING val vType = e val type = ctx!!.copyWithSchema(listOf(kType, vType)) @@ -353,7 +360,29 @@ internal class PlanTyper( // rewrite val type = ctx!!.copy(schema) - return rel(type, node) + + // resolve exclude path roots + val newItems = node.items.map { item -> + val resolvedRoot = when (val root = item.root) { + is Rex.Op.Var.Unresolved -> { + // resolve `root` to local binding + val bindingPath = root.identifier.toBindingPath() + when (val resolved = env.resolveLocalBind(bindingPath, init)) { + null -> { + handleUnresolvedExcludeRoot(root.identifier) + root + } + else -> rexOpVarResolved(resolved.ordinal) + } + } + is Rex.Op.Var.Resolved -> root + } + val steps = item.steps + relOpExcludeItem(resolvedRoot, steps) + } + + val op = relOpExclude(input, newItems) + return rel(type, op) } override fun visitRelOpAggregate(node: Rel.Op.Aggregate, ctx: Rel.Type?): Rel { @@ -419,104 +448,134 @@ internal class PlanTyper( if (resolvedVar == null) { handleUndefinedVariable(path.steps.last()) - return rex(StaticType.ANY, rexOpErr("Undefined variable ${node.identifier}")) + return rex(ANY, rexOpErr("Undefined variable ${node.identifier}")) } val type = resolvedVar.type val op = when (resolvedVar) { - is ResolvedVar.Global -> rexOpGlobal(resolvedVar.ordinal) - is ResolvedVar.Local -> resolvedLocalPath(resolvedVar) + is ResolvedVar.Global -> rexOpGlobal(catalogSymbolRef(resolvedVar.ordinal, resolvedVar.position)) + is ResolvedVar.Local -> rexOpVarResolved(resolvedVar.ordinal) // resolvedLocalPath(resolvedVar) + } + val variable = rex(type, op) + return when (resolvedVar.depth) { + path.steps.size -> variable + else -> { + val foldedPath = path.steps.subList(resolvedVar.depth, path.steps.size).fold(variable) { current, step -> + when (step.bindingCase) { + BindingCase.SENSITIVE -> rex(ANY, rexOpPathKey(current, rex(STRING, rexOpLit(stringValue(step.name))))) + BindingCase.INSENSITIVE -> rex(ANY, rexOpPathSymbol(current, step.name)) + } + } + visitRex(foldedPath, ctx) + } } - return rex(type, op) } override fun visitRexOpGlobal(node: Rex.Op.Global, ctx: StaticType?): Rex { - val global = env.globals[node.ref] - val type = global.type + val catalog = env.catalogs[node.ref.catalog] + val type = catalog.symbols[node.ref.symbol].type return rex(type, node) } - /** - * Match path as far as possible (rewriting the steps), then infer based on resolved root and rewritten steps. - */ - override fun visitRexOpPath(node: Rex.Op.Path, ctx: StaticType?): Rex { - val visitedSteps = node.steps.map { visitRexOpPathStep(it, null) as Rex.Op.Path.Step } - // 1. Resolve path prefix - val (root, steps) = when (val rootOp = node.root.op) { - is Rex.Op.Var.Unresolved -> { - // Rewrite the root - val path = rexPathToBindingPath(rootOp, visitedSteps) - val resolvedVar = env.resolve(path, locals, rootOp.scope) - if (resolvedVar == null) { - handleUndefinedVariable(path.steps.last()) - return rex(StaticType.ANY, node) - } - val type = resolvedVar.type - val (op, steps) = when (resolvedVar) { - // Root (and some steps) was a local. Replace the matched nodes with disambiguated steps - // and return the remaining steps to continue typing. - is ResolvedVar.Local -> { - val amountRemaining = (visitedSteps.size + 1) - resolvedVar.depth - val remainingSteps = visitedSteps.takeLast(amountRemaining) - resolvedLocalPath(resolvedVar) to remainingSteps - } - is ResolvedVar.Global -> { - // Root (and some steps) was a global; replace root and re-calculate remaining steps. - val remainingFirstIndex = resolvedVar.depth - 1 - val remaining = when (remainingFirstIndex > visitedSteps.lastIndex) { - true -> emptyList() - false -> visitedSteps.subList(remainingFirstIndex, visitedSteps.size) - } - rexOpGlobal(resolvedVar.ordinal) to remaining - } - } - // rewrite root - rex(type, op) to steps - } - else -> visitRex(node.root, node.root.type) to visitedSteps + override fun visitRexOpPathIndex(node: Rex.Op.Path.Index, ctx: StaticType?): Rex { + val root = visitRex(node.root, node.root.type) + val key = visitRex(node.key, node.key.type) + if (key.type !is IntType) { + handleAlwaysMissing() + return rex(MISSING, rexOpErr("Collections must be indexed with integers, found ${key.type}")) } + val elementTypes = root.type.allTypes.map { type -> + val rootType = type as? CollectionType ?: return@map MISSING + if (rootType !is ListType && rootType !is SexpType) { + return@map MISSING + } + rootType.elementType + }.toSet() + val finalType = unionOf(elementTypes).flatten() + return rex(finalType.swallowAny(), rexOpPathIndex(root, key)) + } + + override fun visitRexOpPathKey(node: Rex.Op.Path.Key, ctx: StaticType?): Rex { + val root = visitRex(node.root, node.root.type) + val key = visitRex(node.key, node.key.type) - // short-circuit if whole path was matched - if (steps.isEmpty()) { - return root + // Check Key Type + val toAddTypes = key.type.allTypes.mapNotNull { keyType -> + when (keyType) { + is StringType -> null + is NullType -> NULL + else -> MISSING + } + } + if (toAddTypes.size == key.type.allTypes.size && toAddTypes.all { it is MissingType }) { + handleAlwaysMissing() + return rex(MISSING, rexOpErr("Expected string but found: ${key.type}")) } - // 2. TODO rewrite and type the steps containing expressions - // val typedSteps = steps.map { - // if (it is Rex.Op.Path.Step.Index) { - // val key = visitRex(it.key, null) - // rexOpPathStepIndex(key) - // } else it - // } + val pathTypes = root.type.allTypes.map { type -> + val struct = type as? StructType ?: return@map MISSING + + if (key.op is Rex.Op.Lit) { + val lit = key.op.value + if (lit is TextValue<*> && !lit.isNull) { + val id = identifierSymbol(lit.string!!, Identifier.CaseSensitivity.SENSITIVE) + inferStructLookup(struct, id).first + } else { + error("Expected text literal, but got $lit") + } + } else { + // cannot infer type of non-literal path step because we don't know its value + // we might improve upon this with some constant folding prior to typing + ANY + } + }.toSet() + val finalType = unionOf(pathTypes + toAddTypes).flatten() + return rex(finalType.swallowAny(), rexOpPathKey(root, key)) + } - // 3. Walk the steps, determine the path type, and replace each step with the disambiguated equivalent - // (AKA made sensitive, if possible) - var type = root.type - val newSteps = steps.map { step -> - val (stepType, replacementStep) = inferPathStep(type, step) - type = stepType - replacementStep + override fun visitRexOpPathSymbol(node: Rex.Op.Path.Symbol, ctx: StaticType?): Rex { + val root = visitRex(node.root, node.root.type) + + val paths = root.type.allTypes.map { type -> + val struct = type as? StructType ?: return@map rex(MISSING, rexOpLit(missingValue())) + val (pathType, replacementId) = inferStructLookup(struct, identifierSymbol(node.key, Identifier.CaseSensitivity.INSENSITIVE)) + when (replacementId.caseSensitivity) { + Identifier.CaseSensitivity.INSENSITIVE -> rex(pathType, rexOpPathSymbol(root, replacementId.symbol)) + Identifier.CaseSensitivity.SENSITIVE -> rex(pathType, rexOpPathKey(root, rexString(replacementId.symbol))) + } } + val type = unionOf(paths.map { it.type }.toSet()).flatten() - // 4. Invalid path reference; always MISSING - if (type == StaticType.MISSING) { - handleAlwaysMissing() - return rexErr("Unknown identifier $node") + // replace step only if all are disambiguated + val firstPathOp = paths.first().op + val replacementOp = when (paths.map { it.op }.all { it == firstPathOp }) { + true -> firstPathOp + false -> rexOpPathSymbol(root, node.key) } + return rex(type.swallowAny(), replacementOp) + } - // 5. Non-missing, root is resolved - return rex(type, rexOpPath(root, newSteps)) + /** + * "Swallows" ANY. If ANY is one of the types in the UNION type, we return ANY. If not, we flatten and return + * the [type]. + */ + private fun StaticType.swallowAny(): StaticType { + val flattened = this.flatten() + return when (flattened.allTypes.any { it is AnyType }) { + true -> ANY + false -> flattened + } } - // Default returns the original node, in some case we need the resolved node. - // i.e., the path step is a call node - override fun visitRexOpPathStep(node: Rex.Op.Path.Step, ctx: StaticType?): Rex.Op.Path.Step = - when (node) { - is Rex.Op.Path.Step.Index -> Rex.Op.Path.Step.Index(visitRex(node.key, ctx)) - is Rex.Op.Path.Step.Key -> Rex.Op.Path.Step.Key(visitRex(node.key, ctx)) - is Rex.Op.Path.Step.Symbol -> Rex.Op.Path.Step.Symbol(node.identifier) - is Rex.Op.Path.Step.Unpivot -> Rex.Op.Path.Step.Unpivot() - is Rex.Op.Path.Step.Wildcard -> Rex.Op.Path.Step.Wildcard() + private fun rexString(str: String) = rex(STRING, rexOpLit(stringValue(str))) + + override fun visitRexOpPath(node: Rex.Op.Path, ctx: StaticType?): Rex { + val path = super.visitRexOpPath(node, ctx) as Rex + if (path.type == MISSING) { + handleAlwaysMissing() + return rexErr("Path always returns missing $node") } + return path + } /** * Resolve and type scalar function calls. @@ -540,7 +599,7 @@ internal class PlanTyper( is FnMatch.Dynamic -> { val types = mutableSetOf() if (match.isMissable && !isNotMissable) { - types.add(StaticType.MISSING) + types.add(MISSING) } val candidates = match.candidates.map { candidate -> val rex = toRexCall(candidate, args, isNotMissable) @@ -574,7 +633,7 @@ internal class PlanTyper( newArgs.forEach { if (it.type == MissingType && !isNotMissable) { handleAlwaysMissing() - return rex(StaticType.MISSING, rexOpCallStatic(newFn, newArgs)) + return rex(MISSING, rexOpCallStatic(newFn, newArgs)) } } @@ -615,14 +674,14 @@ internal class PlanTyper( // Return type with calculated nullability var type = when { - isNull -> StaticType.NULL + isNull -> NULL isNullable -> returns.toStaticType() else -> returns.toNonNullStaticType() } // Some operators can return MISSING during runtime if (match.isMissable && !isNotMissable) { - type = StaticType.unionOf(type, StaticType.MISSING) + type = StaticType.unionOf(type, MISSING) } // Finally, rewrite this node @@ -740,8 +799,8 @@ internal class PlanTyper( } val ref = call.args.getOrNull(0) ?: error("IS STRUCT requires an argument.") val simplifiedCondition = when { - ref.type.allTypes.all { it is StructType } -> rex(StaticType.BOOL, rexOpLit(boolValue(true))) - ref.type.allTypes.none { it is StructType } -> rex(StaticType.BOOL, rexOpLit(boolValue(false))) + ref.type.allTypes.all { it is StructType } -> rex(BOOL, rexOpLit(boolValue(true))) + ref.type.allTypes.none { it is StructType } -> rex(BOOL, rexOpLit(boolValue(false))) else -> condition } @@ -789,9 +848,9 @@ internal class PlanTyper( when (field.k.op) { is Rex.Op.Lit -> { // A field is only included in the StructType if its key is a text literal - val key = field.k.op as Rex.Op.Lit + val key = field.k.op if (key.value is TextValue<*>) { - val name = (key.value as TextValue<*>).string!! + val name = key.value.string!! val type = field.v.type structKeysSeent.add(name) structTypeFields.add(StructType.Field(name, type)) @@ -919,7 +978,7 @@ internal class PlanTyper( } override fun visitRexOpErr(node: Rex.Op.Err, ctx: StaticType?): PlanNode { - val type = ctx ?: StaticType.ANY + val type = ctx ?: ANY return rex(type, node) } @@ -968,13 +1027,13 @@ internal class PlanTyper( PlanningProblemDetails.CompileError("TupleUnion wasn't normalized to exclude union types.") ) ) - possibleOutputTypes.add(StaticType.MISSING) + possibleOutputTypes.add(MISSING) } is NullType -> { - return StaticType.NULL + return NULL } else -> { - return StaticType.MISSING + return MISSING } } } @@ -1051,88 +1110,6 @@ internal class PlanTyper( // Helpers - /** - * @return a [Pair] where the [Pair.first] represents the type of the [step] and the [Pair.second] represents - * the disambiguated [step]. - */ - private fun inferPathStep(type: StaticType, step: Rex.Op.Path.Step): Pair = - when (type) { - is AnyType -> StaticType.ANY to step - is StructType -> inferPathStep(type, step) - is ListType, is SexpType -> inferPathStep(type as CollectionType, step) to step - is AnyOfType -> { - when (type.types.size) { - 0 -> throw IllegalStateException("Cannot path on an empty StaticType union") - else -> { - val prevTypes = type.allTypes - if (prevTypes.any { it is AnyType }) { - StaticType.ANY to step - } else { - val results = prevTypes.map { inferPathStep(it, step) } - val types = results.map { it.first } - val firstResultStep = results.first().second - // replace step only if all are disambiguated - val replacementStep = when (results.map { it.second }.all { it == firstResultStep }) { - true -> firstResultStep - false -> step - } - AnyOfType(types.toSet()).flatten() to replacementStep - } - } - } - } - else -> StaticType.MISSING to step - } - - /** - * @return a [Pair] where the [Pair.first] represents the type of the [step] and the [Pair.second] represents - * the disambiguated [step]. - */ - private fun inferPathStep(struct: StructType, step: Rex.Op.Path.Step): Pair = when (step) { - // { 'a': 1 }[0] should always return missing since tuples cannot be navigated via integer indexes - is Rex.Op.Path.Step.Index -> { - handleAlwaysMissing() - MISSING to step - } - is Rex.Op.Path.Step.Symbol -> { - val (type, replacementId) = inferStructLookup(struct, step.identifier) - type to replacementId.toPathStep() - } - is Rex.Op.Path.Step.Key -> { - if (step.key.type !is StringType) { - error("Expected string but found: ${step.key.type}") - } - if (step.key.op is Rex.Op.Lit) { - val lit = step.key.op.value - if (lit is TextValue<*> && !lit.isNull) { - val id = identifierSymbol(lit.string!!, Identifier.CaseSensitivity.SENSITIVE) - val (type, replacementId) = inferStructLookup(struct, id) - type to replacementId.toPathStep() - } else { - error("Expected text literal, but got $lit") - } - } else { - // cannot infer type of non-literal path step because we don't know its value - // we might improve upon this with some constant folding prior to typing - ANY to step - } - } - is Rex.Op.Path.Step.Unpivot -> error("Unpivot not supported") - is Rex.Op.Path.Step.Wildcard -> error("Wildcard not supported") - } - - private fun Identifier.Symbol.toPathStep() = rexOpPathStepSymbol(this) - - private fun inferPathStep(collection: CollectionType, step: Rex.Op.Path.Step): StaticType { - if (step !is Rex.Op.Path.Step.Index) { - error("Path step on a collection must be an expression") - } - if (step.key.type !is IntType) { - error("Collections must be indexed with integers, found ${step.key.type}") - } - return collection.elementType - } - /** * Logic is as follows: * 1. If [struct] is closed and ordered: @@ -1156,13 +1133,13 @@ internal class PlanTyper( isClosed && isOrdered -> { struct.fields.firstOrNull { entry -> binding.isEquivalentTo(entry.key) }?.let { (sensitive(it.key) to it.value) - } ?: (key to StaticType.MISSING) + } ?: (key to MISSING) } // 2. Struct is closed isClosed -> { val matches = struct.fields.filter { entry -> binding.isEquivalentTo(entry.key) } when (matches.size) { - 0 -> (key to StaticType.MISSING) + 0 -> (key to MISSING) 1 -> matches.first().let { (sensitive(it.key) to it.value) } else -> { val firstKey = matches.first().key @@ -1175,7 +1152,7 @@ internal class PlanTyper( } } // 3. Struct is open - else -> (key to StaticType.ANY) + else -> (key to ANY) } return type to name } @@ -1197,7 +1174,7 @@ internal class PlanTyper( * Let TX be the single-column table that is the result of applying the * to each row of T and eliminating null values <--- all NULL values are eliminated as inputs */ - public fun resolveAgg(agg: Agg.Unresolved, arguments: List): Pair { + fun resolveAgg(agg: Agg.Unresolved, arguments: List): Pair { var missingArg = false val args = arguments.map { val arg = visitRex(it, null) @@ -1227,7 +1204,7 @@ internal class PlanTyper( // Some operators can return MISSING during runtime if (match.isMissable) { - type = StaticType.unionOf(type, StaticType.MISSING).flatten() + type = StaticType.unionOf(type, MISSING).flatten() } // Finally, rewrite this node @@ -1248,7 +1225,7 @@ internal class PlanTyper( private fun Rex.type(typeEnv: TypeEnv) = RexTyper(typeEnv).visitRex(this, this.type) - private fun rexErr(message: String) = rex(StaticType.MISSING, rexOpErr(message)) + private fun rexErr(message: String) = rex(MISSING, rexOpErr(message)) /** * I found decorating the tree with the binding names (for resolution) was easier than associating introduced @@ -1272,7 +1249,7 @@ internal class PlanTyper( is Identifier.Symbol -> BindingPath(listOf(this.toBindingName())) } - private fun Identifier.Qualified.toBindingPath() = BindingPath(steps = steps.map { it.toBindingName() }) + private fun Identifier.Qualified.toBindingPath() = BindingPath(steps = listOf(this.root.toBindingName()) + steps.map { it.toBindingName() }) private fun Identifier.Symbol.toBindingName() = BindingName( name = symbol, @@ -1289,33 +1266,10 @@ internal class PlanTyper( */ private fun List.toUnionType(): StaticType = AnyOfType(map { it.type }.toSet()).flatten() - /** - * Helper function which returns the literal string/symbol steps of a path expression as a [BindingPath]. - * - * TODO this does not handle constant expressions in `[]`, only literals - */ - @OptIn(PartiQLValueExperimental::class) - private fun rexPathToBindingPath(rootOp: Rex.Op.Var.Unresolved, steps: List): BindingPath { - if (rootOp.identifier !is Identifier.Symbol) { - throw IllegalArgumentException("Expected identifier symbol") - } - val bindingRoot = rootOp.identifier.toBindingName() - val bindingSteps = mutableListOf(bindingRoot) - for (step in steps) { - when (step) { - is Rex.Op.Path.Step.Index -> break - is Rex.Op.Path.Step.Symbol -> bindingSteps.add(step.identifier.toBindingName()) - is Rex.Op.Path.Step.Key -> break - else -> break // short-circuit - } - } - return BindingPath(bindingSteps) - } - private fun getElementTypeForFromSource(fromSourceType: StaticType): StaticType = when (fromSourceType) { is BagType -> fromSourceType.elementType is ListType -> fromSourceType.elementType - is AnyType -> StaticType.ANY + is AnyType -> ANY is AnyOfType -> AnyOfType(fromSourceType.types.map { getElementTypeForFromSource(it) }.toSet()) // All the other types coerce into a bag of themselves (including null/missing/sexp). else -> fromSourceType @@ -1349,25 +1303,6 @@ internal class PlanTyper( } } - /** - * Constructs a Rex.Op.Path from a resolved local - */ - private fun resolvedLocalPath(local: ResolvedVar.Local): Rex.Op { - val root = rex(local.rootType, rexOpVarResolved(local.ordinal)) - val steps = local.replacementSteps.map { - val case = when (it.bindingCase) { - BindingCase.SENSITIVE -> Identifier.CaseSensitivity.SENSITIVE - BindingCase.INSENSITIVE -> Identifier.CaseSensitivity.INSENSITIVE - } - val symbol = identifierSymbol(it.name, case) - rexOpPathStepSymbol(symbol) - } - return when (steps.isEmpty()) { - true -> root.op - false -> rexOpPath(root, steps) - } - } - // ERRORS private fun handleUndefinedVariable(name: BindingName) { @@ -1409,11 +1344,16 @@ internal class PlanTyper( ) } - private fun handleUnresolvedExcludeRoot(root: String) { + private fun handleUnresolvedExcludeRoot(root: Identifier) { onProblem( Problem( sourceLocation = UNKNOWN_PROBLEM_LOCATION, - details = PlanningProblemDetails.UnresolvedExcludeExprRoot(root) + details = PlanningProblemDetails.UnresolvedExcludeExprRoot( + when (root) { + is Identifier.Symbol -> root.symbol + is Identifier.Qualified -> root.toString() + } + ) ) ) } @@ -1437,7 +1377,7 @@ internal class PlanTyper( private fun Fn.Unresolved.isNotMissable(): Boolean { return when (identifier) { is Identifier.Qualified -> false - is Identifier.Symbol -> when ((identifier as Identifier.Symbol).symbol) { + is Identifier.Symbol -> when (identifier.symbol) { "and" -> true "or" -> true "not" -> true @@ -1450,7 +1390,7 @@ internal class PlanTyper( } private fun Fn.Unresolved.isTypeAssertion(): Boolean { - return (identifier is Identifier.Symbol && (identifier as Identifier.Symbol).symbol.startsWith("is")) + return (identifier is Identifier.Symbol && identifier.symbol.startsWith("is")) } /** @@ -1473,16 +1413,26 @@ internal class PlanTyper( private fun excludeBindings(input: List, item: Rel.Op.Exclude.Item): List { var matchedRoot = false val output = input.map { - if (item.root.isEquivalentTo(it.name)) { - matchedRoot = true - // recompute the StaticType of this binding after apply the exclusions - val type = it.type.exclude(item.steps, false) - it.copy(type = type) - } else { - it + when (val root = item.root) { + is Rex.Op.Var.Unresolved -> { + when (val id = root.identifier) { + is Identifier.Symbol -> { + if (id.isEquivalentTo(it.name)) { + matchedRoot = true + // recompute the StaticType of this binding after apply the exclusions + val type = it.type.exclude(item.steps, false) + it.copy(type = type) + } else { + it + } + } + is Identifier.Qualified -> it + } + } + is Rex.Op.Var.Resolved -> it } } - if (!matchedRoot) handleUnresolvedExcludeRoot(item.root.symbol) + if (!matchedRoot && item.root is Rex.Op.Var.Unresolved) handleUnresolvedExcludeRoot(item.root.identifier) return output } diff --git a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt index 4363116820..081f2a0cb6 100644 --- a/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt +++ b/partiql-planner/src/main/kotlin/org/partiql/planner/internal/typer/TypeUtils.kt @@ -178,7 +178,7 @@ internal fun StructType.exclude(steps: List, lastStepOption StructType.Field(k, v) } when (step) { - is Rel.Op.Exclude.Step.Attr -> { + is Rel.Op.Exclude.Step.StructField -> { if (step.symbol.isEquivalentTo(field.key)) { newField } else { @@ -202,12 +202,12 @@ internal fun StructType.exclude(steps: List, lastStepOption internal fun CollectionType.exclude(steps: List, lastStepOptional: Boolean = true): StaticType { var e = this.elementType when (steps.first()) { - is Rel.Op.Exclude.Step.Pos -> { + is Rel.Op.Exclude.Step.CollIndex -> { if (steps.size > 1) { e = e.exclude(steps.drop(1), true) } } - is Rel.Op.Exclude.Step.CollectionWildcard -> { + is Rel.Op.Exclude.Step.CollWildcard -> { if (steps.size > 1) { e = e.exclude(steps.drop(1), lastStepOptional) } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt index 725f0f481f..d406eda7b7 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/EnvTest.kt @@ -1,25 +1,18 @@ package org.partiql.planner.internal -import com.amazon.ionelement.api.field -import com.amazon.ionelement.api.ionString -import com.amazon.ionelement.api.ionStructOf import org.junit.jupiter.api.Assertions.assertNull import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.partiql.planner.PartiQLHeader import org.partiql.planner.PartiQLPlanner -import org.partiql.planner.internal.ir.Global -import org.partiql.planner.internal.ir.Identifier +import org.partiql.planner.internal.ir.Catalog import org.partiql.planner.internal.ir.Rex -import org.partiql.planner.internal.ir.identifierQualified -import org.partiql.planner.internal.ir.identifierSymbol -import org.partiql.plugins.local.LocalPlugin +import org.partiql.plugins.local.LocalConnector import org.partiql.spi.BindingCase import org.partiql.spi.BindingName import org.partiql.spi.BindingPath import org.partiql.types.StaticType import java.util.Random -import kotlin.io.path.pathString import kotlin.io.path.toPath import kotlin.test.assertEquals import kotlin.test.assertNotNull @@ -27,26 +20,16 @@ import kotlin.test.assertNotNull class EnvTest { companion object { - private val root = this::class.java.getResource("/catalogs/default")!!.toURI().toPath().pathString - val catalogConfig = mapOf( - "pql" to ionStructOf( - field("connector_name", ionString("local")), - field("root", ionString("$root/pql")), - ) - ) + private val root = this::class.java.getResource("/catalogs/default/pql")!!.toURI().toPath() private val EMPTY_TYPE_ENV = TypeEnv(schema = emptyList(), ResolutionStrategy.GLOBAL) - private val GLOBAL_OS = Global( - path = identifierQualified( - root = identifierSymbol("pql", Identifier.CaseSensitivity.SENSITIVE), - steps = listOf( - identifierSymbol("main", Identifier.CaseSensitivity.SENSITIVE), - identifierSymbol("os", Identifier.CaseSensitivity.SENSITIVE) - ) - ), - type = StaticType.STRING + private val GLOBAL_OS = Catalog( + name = "pql", + symbols = listOf( + Catalog.Symbol(path = listOf("main", "os"), type = StaticType.STRING) + ) ) } @@ -56,13 +39,14 @@ class EnvTest { fun init() { env = Env( listOf(PartiQLHeader), - listOf(LocalPlugin()), + mapOf( + "pql" to LocalConnector.Metadata(root) + ), PartiQLPlanner.Session( queryId = Random().nextInt().toString(), userId = "test-user", currentCatalog = "pql", currentDirectory = listOf("main"), - catalogConfig = catalogConfig ) ) } @@ -71,29 +55,29 @@ class EnvTest { fun testGlobalMatchingSensitiveName() { val path = BindingPath(listOf(BindingName("os", BindingCase.SENSITIVE))) assertNotNull(env.resolve(path, EMPTY_TYPE_ENV, Rex.Op.Var.Scope.DEFAULT)) - assertEquals(1, env.globals.size) - assert(env.globals.contains(GLOBAL_OS)) + assertEquals(1, env.catalogs.size) + assert(env.catalogs.contains(GLOBAL_OS)) } @Test fun testGlobalMatchingInsensitiveName() { val path = BindingPath(listOf(BindingName("oS", BindingCase.INSENSITIVE))) assertNotNull(env.resolve(path, EMPTY_TYPE_ENV, Rex.Op.Var.Scope.DEFAULT)) - assertEquals(1, env.globals.size) - assert(env.globals.contains(GLOBAL_OS)) + assertEquals(1, env.catalogs.size) + assert(env.catalogs.contains(GLOBAL_OS)) } @Test fun testGlobalNotMatchingSensitiveName() { val path = BindingPath(listOf(BindingName("oS", BindingCase.SENSITIVE))) assertNull(env.resolve(path, EMPTY_TYPE_ENV, Rex.Op.Var.Scope.DEFAULT)) - assert(env.globals.isEmpty()) + assert(env.catalogs.isEmpty()) } @Test fun testGlobalNotMatchingInsensitiveName() { val path = BindingPath(listOf(BindingName("nonexistent", BindingCase.INSENSITIVE))) assertNull(env.resolve(path, EMPTY_TYPE_ENV, Rex.Op.Var.Scope.DEFAULT)) - assert(env.globals.isEmpty()) + assert(env.catalogs.isEmpty()) } } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt index ebe1dd2779..897c8069d1 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PartiQLTyperTestBase.kt @@ -1,21 +1,18 @@ package org.partiql.planner.internal.typer -import com.amazon.ionelement.api.ionString -import com.amazon.ionelement.api.ionStructOf import org.junit.jupiter.api.DynamicContainer import org.junit.jupiter.api.DynamicTest -import org.partiql.errors.Problem import org.partiql.errors.ProblemCallback -import org.partiql.errors.ProblemSeverity -import org.partiql.parser.PartiQLParserBuilder +import org.partiql.parser.PartiQLParser import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner import org.partiql.planner.PartiQLPlannerBuilder import org.partiql.planner.test.PartiQLTest import org.partiql.planner.test.PartiQLTestProvider -import org.partiql.plugins.memory.MemoryCatalog -import org.partiql.plugins.memory.MemoryPlugin +import org.partiql.planner.util.ProblemCollector +import org.partiql.plugins.memory.MemoryConnector +import org.partiql.spi.connector.ConnectorMetadata import org.partiql.types.StaticType import java.util.Random import java.util.stream.Stream @@ -31,43 +28,23 @@ abstract class PartiQLTyperTestBase { } } - internal class ProblemCollector : ProblemCallback { - private val problemList = mutableListOf() - - val problems: List - get() = problemList - - val hasErrors: Boolean - get() = problemList.any { it.details.severity == ProblemSeverity.ERROR } - - val hasWarnings: Boolean - get() = problemList.any { it.details.severity == ProblemSeverity.WARNING } - - override fun invoke(problem: Problem) { - problemList.add(problem) - } - } - companion object { internal val session: ((String) -> PartiQLPlanner.Session) = { catalog -> PartiQLPlanner.Session( queryId = Random().nextInt().toString(), userId = "test-user", currentCatalog = catalog, - catalogConfig = mapOf( - catalog to ionStructOf( - "connector_name" to ionString("memory") - ) - ) ) } } val inputs = PartiQLTestProvider().apply { load() } - val testingPipeline: ((String, String, MemoryCatalog.Provider, ProblemCallback) -> PartiQLPlanner.Result) = { query, catalog, catalogProvider, collector -> - val ast = PartiQLParserBuilder.standard().build().parse(query).root - val planner = PartiQLPlannerBuilder().plugins(listOf(MemoryPlugin(catalogProvider))).build() + val testingPipeline: ((String, String, ConnectorMetadata, ProblemCallback) -> PartiQLPlanner.Result) = { query, catalog, metadata, collector -> + val ast = PartiQLParser.default().parse(query).root + val planner = PartiQLPlannerBuilder() + .addCatalog(catalog, metadata) + .build() planner.plan(ast, session(catalog), collector) } @@ -76,14 +53,13 @@ abstract class PartiQLTyperTestBase { tests: List, argsMap: Map>>, ): Stream { - val catalogProvider = MemoryCatalog.Provider() return tests.map { test -> val group = test.statement val children = argsMap.flatMap { (key, value) -> value.mapIndexed { index: Int, types: List -> val testName = "${testCategory}_${key}_$index" - catalogProvider[testName] = MemoryCatalog.of( + val metadata = MemoryConnector.Metadata.of( *( types.mapIndexed { i, t -> "t${i + 1}" to t @@ -96,7 +72,7 @@ abstract class PartiQLTyperTestBase { DynamicTest.dynamicTest(displayName) { val pc = ProblemCollector() if (key is TestResult.Success) { - val result = testingPipeline(statement, testName, catalogProvider, pc) + val result = testingPipeline(statement, testName, metadata, pc) val root = (result.plan.statement as Statement.Query).root val actualType = root.type assert(actualType == key.expectedType) { @@ -117,7 +93,7 @@ abstract class PartiQLTyperTestBase { } } } else { - val result = testingPipeline(statement, testName, catalogProvider, pc) + val result = testingPipeline(statement, testName, metadata, pc) val root = (result.plan.statement as Statement.Query).root val actualType = root.type assert(actualType == StaticType.MISSING) { diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt index 661e32802b..a2cc46dced 100644 --- a/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTest.kt @@ -1,52 +1,116 @@ package org.partiql.planner.internal.typer -import com.amazon.ionelement.api.field -import com.amazon.ionelement.api.ionString -import com.amazon.ionelement.api.ionStructOf import org.junit.jupiter.api.Test -import org.partiql.errors.Problem -import org.partiql.errors.ProblemCallback -import org.partiql.errors.ProblemHandler -import org.partiql.errors.ProblemSeverity import org.partiql.planner.PartiQLHeader import org.partiql.planner.PartiQLPlanner import org.partiql.planner.internal.Env import org.partiql.planner.internal.ir.Identifier import org.partiql.planner.internal.ir.Rex +import org.partiql.planner.internal.ir.catalogSymbolRef import org.partiql.planner.internal.ir.identifierSymbol import org.partiql.planner.internal.ir.rex import org.partiql.planner.internal.ir.rexOpGlobal import org.partiql.planner.internal.ir.rexOpLit -import org.partiql.planner.internal.ir.rexOpPath -import org.partiql.planner.internal.ir.rexOpPathStepSymbol +import org.partiql.planner.internal.ir.rexOpPathKey +import org.partiql.planner.internal.ir.rexOpPathSymbol import org.partiql.planner.internal.ir.rexOpStruct import org.partiql.planner.internal.ir.rexOpStructField import org.partiql.planner.internal.ir.rexOpVarUnresolved import org.partiql.planner.internal.ir.statementQuery -import org.partiql.plugins.local.LocalPlugin +import org.partiql.planner.util.ProblemCollector +import org.partiql.plugins.local.LocalConnector import org.partiql.types.StaticType +import org.partiql.types.StaticType.Companion.ANY +import org.partiql.types.StaticType.Companion.DECIMAL +import org.partiql.types.StaticType.Companion.FLOAT +import org.partiql.types.StaticType.Companion.INT2 +import org.partiql.types.StaticType.Companion.INT4 +import org.partiql.types.StaticType.Companion.STRING import org.partiql.types.StructType import org.partiql.types.TupleConstraint import org.partiql.value.PartiQLValueExperimental import org.partiql.value.int32Value import org.partiql.value.stringValue import java.util.Random -import kotlin.io.path.pathString import kotlin.io.path.toPath import kotlin.test.assertEquals class PlanTyperTest { companion object { - private val root = this::class.java.getResource("/catalogs/default")!!.toURI().toPath().pathString - private val catalogConfig = mapOf( - "pql" to ionStructOf( - field("connector_name", ionString("local")), - field("root", ionString("$root/pql")), + private val root = this::class.java.getResource("/catalogs/default/pql")!!.toURI().toPath() + + @OptIn(PartiQLValueExperimental::class) + private val LITERAL_STRUCT_1 = rex( + ANY, + rexOpStruct( + fields = listOf( + rexOpStructField( + k = rex(STRING, rexOpLit(stringValue("FiRsT_KeY"))), + v = rex( + ANY, + rexOpStruct( + fields = listOf( + rexOpStructField( + k = rex(STRING, rexOpLit(stringValue("sEcoNd_KEY"))), + v = rex(INT4, rexOpLit(int32Value(5))) + ) + ) + ) + ) + ) + ) ) ) + private val LITERAL_STRUCT_1_FIRST_KEY_TYPE = StructType( + fields = mapOf( + "sEcoNd_KEY" to INT4 + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Open(false) + ) + ) + + @OptIn(PartiQLValueExperimental::class) + private val LITERAL_STRUCT_1_TYPED: Rex + get() { + val topLevelStruct = StructType( + fields = mapOf( + "FiRsT_KeY" to LITERAL_STRUCT_1_FIRST_KEY_TYPE + ), + contentClosed = true, + constraints = setOf( + TupleConstraint.UniqueAttrs(true), + TupleConstraint.Open(false) + ) + ) + return rex( + type = topLevelStruct, + rexOpStruct( + fields = listOf( + rexOpStructField( + k = rex(STRING, rexOpLit(stringValue("FiRsT_KeY"))), + v = rex( + type = LITERAL_STRUCT_1_FIRST_KEY_TYPE, + rexOpStruct( + fields = listOf( + rexOpStructField( + k = rex(STRING, rexOpLit(stringValue("sEcoNd_KEY"))), + v = rex(INT4, rexOpLit(int32Value(5))) + ) + ) + ) + ) + ) + ) + ) + ) + } + private val ORDERED_DUPLICATES_STRUCT = StructType( fields = listOf( StructType.Field("definition", StaticType.STRING), @@ -111,13 +175,14 @@ class PlanTyperTest { val collector = ProblemCollector() val env = Env( listOf(PartiQLHeader), - listOf(LocalPlugin()), + mapOf( + "pql" to LocalConnector.Metadata(root) + ), PartiQLPlanner.Session( queryId = Random().nextInt().toString(), userId = "test-user", currentCatalog = "pql", currentDirectory = listOf("main"), - catalogConfig = catalogConfig ) ) return PlanTyperWrapper(PlanTyper(env, collector), collector) @@ -142,94 +207,12 @@ class PlanTyperTest { * It also checks that we type it all correctly as well. */ @Test - @OptIn(PartiQLValueExperimental::class) fun testReplacingStructs() { val wrapper = getTyper() val typer = wrapper.typer - val input = statementQuery( - root = rex( - type = StaticType.ANY, - op = rexOpPath( - root = rex( - StaticType.ANY, - rexOpStruct( - fields = listOf( - rexOpStructField( - k = rex(StaticType.STRING, rexOpLit(stringValue("FiRsT_KeY"))), - v = rex( - StaticType.ANY, - rexOpStruct( - fields = listOf( - rexOpStructField( - k = rex(StaticType.STRING, rexOpLit(stringValue("sEcoNd_KEY"))), - v = rex(StaticType.INT4, rexOpLit(int32Value(5))) - ) - ) - ) - ) - ) - ) - ) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("first_key", Identifier.CaseSensitivity.INSENSITIVE)), - rexOpPathStepSymbol(identifierSymbol("sEcoNd_KEY", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) - val firstKeyStruct = StructType( - fields = mapOf( - "sEcoNd_KEY" to StaticType.INT4 - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Open(false) - ) - ) - val topLevelStruct = StructType( - fields = mapOf( - "FiRsT_KeY" to firstKeyStruct - ), - contentClosed = true, - constraints = setOf( - TupleConstraint.UniqueAttrs(true), - TupleConstraint.Open(false) - ) - ) - val expected = statementQuery( - root = rex( - type = StaticType.INT4, - op = rexOpPath( - root = rex( - type = topLevelStruct, - rexOpStruct( - fields = listOf( - rexOpStructField( - k = rex(StaticType.STRING, rexOpLit(stringValue("FiRsT_KeY"))), - v = rex( - type = firstKeyStruct, - rexOpStruct( - fields = listOf( - rexOpStructField( - k = rex(StaticType.STRING, rexOpLit(stringValue("sEcoNd_KEY"))), - v = rex(StaticType.INT4, rexOpLit(int32Value(5))) - ) - ) - ) - ) - ) - ) - ) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("FiRsT_KeY", Identifier.CaseSensitivity.SENSITIVE)), - rexOpPathStepSymbol(identifierSymbol("sEcoNd_KEY", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) + val input = statementQuery(LITERAL_STRUCT_1.pathSymbol("first_key").pathKey("sEcoNd_KEY")) + val expected = statementQuery(LITERAL_STRUCT_1_TYPED.pathKey("FiRsT_KeY", LITERAL_STRUCT_1_FIRST_KEY_TYPE).pathKey("sEcoNd_KEY", INT4)) + val actual = typer.resolve(input) assertEquals(expected, actual) } @@ -239,36 +222,10 @@ class PlanTyperTest { val wrapper = getTyper() val typer = wrapper.typer val input = statementQuery( - root = rex( - type = StaticType.ANY, - op = rexOpPath( - root = rex( - StaticType.ANY, - rexOpVarUnresolved( - identifierSymbol("closed_ordered_duplicates_struct", Identifier.CaseSensitivity.SENSITIVE), - Rex.Op.Var.Scope.DEFAULT - ) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("DEFINITION", Identifier.CaseSensitivity.INSENSITIVE)), - ) - ) - ) - ) - val expected = statementQuery( - root = rex( - type = StaticType.STRING, - op = rexOpPath( - root = rex( - ORDERED_DUPLICATES_STRUCT, - rexOpGlobal(0) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("definition", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) + unresolvedSensitiveVar("closed_ordered_duplicates_struct").pathSymbol("DEFINITION") ) + val expected = statementQuery(global(ORDERED_DUPLICATES_STRUCT).pathKey("definition", STRING)) + val actual = typer.resolve(input) assertEquals(expected, actual) } @@ -277,37 +234,9 @@ class PlanTyperTest { fun testOrderedDuplicatesWithSensitivity() { val wrapper = getTyper() val typer = wrapper.typer - val input = statementQuery( - root = rex( - type = StaticType.ANY, - op = rexOpPath( - root = rex( - StaticType.ANY, - rexOpVarUnresolved( - identifierSymbol("closed_ordered_duplicates_struct", Identifier.CaseSensitivity.SENSITIVE), - Rex.Op.Var.Scope.DEFAULT - ) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("DEFINITION", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) - val expected = statementQuery( - root = rex( - type = StaticType.DECIMAL, - op = rexOpPath( - root = rex( - ORDERED_DUPLICATES_STRUCT, - rexOpGlobal(0) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("DEFINITION", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) + val input = statementQuery(unresolvedSensitiveVar("closed_ordered_duplicates_struct").pathKey("DEFINITION")) + val expected = statementQuery(global(ORDERED_DUPLICATES_STRUCT).pathKey("DEFINITION", DECIMAL)) + val actual = typer.resolve(input) assertEquals(expected, actual) } @@ -316,37 +245,9 @@ class PlanTyperTest { fun testUnorderedDuplicates() { val wrapper = getTyper() val typer = wrapper.typer - val input = statementQuery( - root = rex( - type = StaticType.ANY, - op = rexOpPath( - root = rex( - StaticType.ANY, - rexOpVarUnresolved( - identifierSymbol("closed_duplicates_struct", Identifier.CaseSensitivity.SENSITIVE), - Rex.Op.Var.Scope.DEFAULT - ) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("DEFINITION", Identifier.CaseSensitivity.INSENSITIVE)), - ) - ) - ) - ) - val expected = statementQuery( - root = rex( - type = StaticType.unionOf(StaticType.STRING, StaticType.FLOAT, StaticType.DECIMAL), - op = rexOpPath( - root = rex( - DUPLICATES_STRUCT, - rexOpGlobal(0) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("DEFINITION", Identifier.CaseSensitivity.INSENSITIVE)), - ) - ) - ) - ) + val input = statementQuery(unresolvedSensitiveVar("closed_duplicates_struct").pathSymbol("DEFINITION")) + val expected = statementQuery(global(DUPLICATES_STRUCT).pathSymbol("DEFINITION", StaticType.unionOf(STRING, FLOAT, DECIMAL))) + val actual = typer.resolve(input) assertEquals(expected, actual) } @@ -355,37 +256,9 @@ class PlanTyperTest { fun testUnorderedDuplicatesWithSensitivity() { val wrapper = getTyper() val typer = wrapper.typer - val input = statementQuery( - root = rex( - type = StaticType.ANY, - op = rexOpPath( - root = rex( - StaticType.ANY, - rexOpVarUnresolved( - identifierSymbol("closed_duplicates_struct", Identifier.CaseSensitivity.SENSITIVE), - Rex.Op.Var.Scope.DEFAULT - ) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("DEFINITION", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) - val expected = statementQuery( - root = rex( - type = StaticType.DECIMAL, - op = rexOpPath( - root = rex( - DUPLICATES_STRUCT, - rexOpGlobal(0) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("DEFINITION", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) + val input = statementQuery(unresolvedSensitiveVar("closed_duplicates_struct").pathKey("DEFINITION")) + val expected = statementQuery(global(DUPLICATES_STRUCT).pathKey("DEFINITION", DECIMAL)) + val actual = typer.resolve(input) assertEquals(expected, actual) } @@ -394,37 +267,9 @@ class PlanTyperTest { fun testUnorderedDuplicatesWithSensitivityAndDuplicateResults() { val wrapper = getTyper() val typer = wrapper.typer - val input = statementQuery( - root = rex( - type = StaticType.ANY, - op = rexOpPath( - root = rex( - StaticType.ANY, - rexOpVarUnresolved( - identifierSymbol("closed_duplicates_struct", Identifier.CaseSensitivity.SENSITIVE), - Rex.Op.Var.Scope.DEFAULT - ) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("definition", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) - val expected = statementQuery( - root = rex( - type = StaticType.unionOf(StaticType.STRING, StaticType.FLOAT), - op = rexOpPath( - root = rex( - DUPLICATES_STRUCT, - rexOpGlobal(0) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("definition", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) + val input = statementQuery(unresolvedSensitiveVar("closed_duplicates_struct").pathKey("definition")) + val expected = statementQuery(global(DUPLICATES_STRUCT).pathKey("definition", StaticType.unionOf(StaticType.STRING, StaticType.FLOAT))) + val actual = typer.resolve(input) assertEquals(expected, actual) } @@ -433,37 +278,9 @@ class PlanTyperTest { fun testOpenDuplicates() { val wrapper = getTyper() val typer = wrapper.typer - val input = statementQuery( - root = rex( - type = StaticType.ANY, - op = rexOpPath( - root = rex( - StaticType.ANY, - rexOpVarUnresolved( - identifierSymbol("open_duplicates_struct", Identifier.CaseSensitivity.SENSITIVE), - Rex.Op.Var.Scope.DEFAULT - ) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("definition", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) - val expected = statementQuery( - root = rex( - type = StaticType.ANY, - op = rexOpPath( - root = rex( - OPEN_DUPLICATES_STRUCT, - rexOpGlobal(0) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("definition", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) + val input = statementQuery(unresolvedSensitiveVar("open_duplicates_struct").pathKey("definition")) + val expected = statementQuery(global(OPEN_DUPLICATES_STRUCT).pathKey("definition")) + val actual = typer.resolve(input) assertEquals(expected, actual) } @@ -472,37 +289,9 @@ class PlanTyperTest { fun testUnionClosedDuplicates() { val wrapper = getTyper() val typer = wrapper.typer - val input = statementQuery( - root = rex( - type = StaticType.ANY, - op = rexOpPath( - root = rex( - StaticType.ANY, - rexOpVarUnresolved( - identifierSymbol("closed_union_duplicates_struct", Identifier.CaseSensitivity.SENSITIVE), - Rex.Op.Var.Scope.DEFAULT - ) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("definition", Identifier.CaseSensitivity.INSENSITIVE)), - ) - ) - ) - ) - val expected = statementQuery( - root = rex( - type = StaticType.unionOf(StaticType.STRING, StaticType.FLOAT, StaticType.DECIMAL, StaticType.INT2), - op = rexOpPath( - root = rex( - CLOSED_UNION_DUPLICATES_STRUCT, - rexOpGlobal(0) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("definition", Identifier.CaseSensitivity.INSENSITIVE)), - ) - ) - ) - ) + val input = statementQuery(unresolvedSensitiveVar("closed_union_duplicates_struct").pathSymbol("definition")) + val expected = statementQuery(global(CLOSED_UNION_DUPLICATES_STRUCT).pathSymbol("definition", StaticType.unionOf(STRING, FLOAT, DECIMAL, INT2))) + val actual = typer.resolve(input) assertEquals(expected, actual) } @@ -511,62 +300,34 @@ class PlanTyperTest { fun testUnionClosedDuplicatesWithSensitivity() { val wrapper = getTyper() val typer = wrapper.typer - val input = statementQuery( - root = rex( - type = StaticType.ANY, - op = rexOpPath( - root = rex( - StaticType.ANY, - rexOpVarUnresolved( - identifierSymbol("closed_union_duplicates_struct", Identifier.CaseSensitivity.SENSITIVE), - Rex.Op.Var.Scope.DEFAULT - ) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("definition", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) - val expected = statementQuery( - root = rex( - type = StaticType.unionOf(StaticType.STRING, StaticType.FLOAT, StaticType.INT2), - op = rexOpPath( - root = rex( - CLOSED_UNION_DUPLICATES_STRUCT, - rexOpGlobal(0) - ), - steps = listOf( - rexOpPathStepSymbol(identifierSymbol("definition", Identifier.CaseSensitivity.SENSITIVE)), - ) - ) - ) - ) + val input = statementQuery(unresolvedSensitiveVar("closed_union_duplicates_struct").pathKey("definition")) + val expected = statementQuery(global(CLOSED_UNION_DUPLICATES_STRUCT).pathKey("definition", StaticType.unionOf(STRING, FLOAT, INT2))) + val actual = typer.resolve(input) assertEquals(expected, actual) } - /** - * A [ProblemHandler] that collects all the encountered [Problem]s without throwing. - * - * This is intended to be used when wanting to collect multiple problems that may be encountered (e.g. a static type - * inference pass that can result in multiple errors and/or warnings). This handler does not collect other exceptions - * that may be thrown. - */ - internal class ProblemCollector : ProblemCallback { - private val problemList = mutableListOf() + @OptIn(PartiQLValueExperimental::class) + private fun rexString(str: String) = rex(STRING, rexOpLit(stringValue(str))) - val problems: List - get() = problemList + private fun Rex.pathKey(key: String, type: StaticType = ANY): Rex = Rex(type, rexOpPathKey(this, rexString(key))) - val hasErrors: Boolean - get() = problemList.any { it.details.severity == ProblemSeverity.ERROR } + private fun Rex.pathSymbol(key: String, type: StaticType = ANY): Rex = Rex(type, rexOpPathSymbol(this, key)) - val hasWarnings: Boolean - get() = problemList.any { it.details.severity == ProblemSeverity.WARNING } + private fun unresolvedSensitiveVar(name: String, type: StaticType = ANY): Rex { + return rex( + type, + rexOpVarUnresolved( + identifierSymbol(name, Identifier.CaseSensitivity.SENSITIVE), + Rex.Op.Var.Scope.DEFAULT + ) + ) + } - override fun invoke(problem: Problem) { - problemList.add(problem) - } + private fun global(type: StaticType, catalogIndex: Int = 0, symbolIndex: Int = 0): Rex { + return rex( + type, + rexOpGlobal(catalogSymbolRef(catalogIndex, symbolIndex)) + ) } } diff --git a/partiql-eval/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt similarity index 89% rename from partiql-eval/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt rename to partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt index 19ed6a2ff2..b7601c561d 100644 --- a/partiql-eval/src/test/kotlin/org/partiql/lang/planner/transforms/PartiQLSchemaInferencerTests.kt +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/internal/typer/PlanTyperTestsPorted.kt @@ -1,8 +1,5 @@ -package org.partiql.lang.planner.transforms +package org.partiql.planner.internal.typer -import com.amazon.ionelement.api.field -import com.amazon.ionelement.api.ionString -import com.amazon.ionelement.api.ionStructOf import com.amazon.ionelement.api.loadSingleElement import org.junit.jupiter.api.assertThrows import org.junit.jupiter.api.extension.ExtensionContext @@ -13,150 +10,94 @@ import org.junit.jupiter.params.provider.Arguments import org.junit.jupiter.params.provider.ArgumentsProvider import org.junit.jupiter.params.provider.ArgumentsSource import org.junit.jupiter.params.provider.MethodSource -import org.partiql.annotations.ExperimentalPartiQLSchemaInferencer import org.partiql.errors.Problem import org.partiql.errors.UNKNOWN_PROBLEM_LOCATION -import org.partiql.lang.eval.internal.ProblemCollector -import org.partiql.lang.planner.SchemaLoader.toStaticType -import org.partiql.lang.planner.transforms.PartiQLSchemaInferencerTests.ProblemHandler -import org.partiql.lang.planner.transforms.PartiQLSchemaInferencerTests.TestCase.ErrorTestCase -import org.partiql.lang.planner.transforms.PartiQLSchemaInferencerTests.TestCase.SuccessTestCase -import org.partiql.lang.planner.transforms.PartiQLSchemaInferencerTests.TestCase.ThrowingExceptionTestCase +import org.partiql.parser.PartiQLParser +import org.partiql.plan.PartiQLPlan +import org.partiql.plan.Statement import org.partiql.plan.debug.PlanPrinter import org.partiql.planner.PartiQLPlanner +import org.partiql.planner.PartiQLPlannerBuilder import org.partiql.planner.PlanningProblemDetails +import org.partiql.planner.internal.typer.PlanTyperTestsPorted.TestCase.ErrorTestCase +import org.partiql.planner.internal.typer.PlanTyperTestsPorted.TestCase.SuccessTestCase +import org.partiql.planner.internal.typer.PlanTyperTestsPorted.TestCase.ThrowingExceptionTestCase import org.partiql.planner.test.PartiQLTest import org.partiql.planner.test.PartiQLTestProvider -import org.partiql.plugins.memory.MemoryCatalog -import org.partiql.plugins.memory.MemoryPlugin +import org.partiql.planner.util.ProblemCollector +import org.partiql.plugins.local.toStaticType +import org.partiql.plugins.memory.MemoryConnector +import org.partiql.spi.connector.ConnectorMetadata import org.partiql.types.AnyOfType import org.partiql.types.AnyType import org.partiql.types.BagType import org.partiql.types.ListType import org.partiql.types.SexpType import org.partiql.types.StaticType -import org.partiql.types.StaticType.Companion.ANY -import org.partiql.types.StaticType.Companion.BAG -import org.partiql.types.StaticType.Companion.BOOL -import org.partiql.types.StaticType.Companion.DATE -import org.partiql.types.StaticType.Companion.DECIMAL -import org.partiql.types.StaticType.Companion.INT -import org.partiql.types.StaticType.Companion.INT4 -import org.partiql.types.StaticType.Companion.INT8 -import org.partiql.types.StaticType.Companion.MISSING -import org.partiql.types.StaticType.Companion.NULL -import org.partiql.types.StaticType.Companion.STRING -import org.partiql.types.StaticType.Companion.unionOf import org.partiql.types.StructType import org.partiql.types.TupleConstraint -import java.time.Instant import java.util.stream.Stream import kotlin.reflect.KClass import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertTrue -class PartiQLSchemaInferencerTests { - private val testProvider = PartiQLTestProvider() - - init { - // load test inputs - testProvider.load() - } - - @ParameterizedTest - @ArgumentsSource(TestProvider::class) - fun test(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("collections") - @Execution(ExecutionMode.CONCURRENT) - fun testCollections(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("selectStar") - @Execution(ExecutionMode.CONCURRENT) - fun testSelectStar(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("scanCases") - @Execution(ExecutionMode.CONCURRENT) - fun testScan(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("pivotCases") - @Execution(ExecutionMode.CONCURRENT) - fun testPivot(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("sessionVariables") - @Execution(ExecutionMode.CONCURRENT) - fun testSessionVariables(tc: TestCase) = runTest(tc) +class PlanTyperTestsPorted { - @ParameterizedTest - @MethodSource("bitwiseAnd") - @Execution(ExecutionMode.CONCURRENT) - fun testBitwiseAnd(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("unpivotCases") - @Execution(ExecutionMode.CONCURRENT) - fun testUnpivot(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("joinCases") - @Execution(ExecutionMode.CONCURRENT) - fun testJoins(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("excludeCases") - @Execution(ExecutionMode.CONCURRENT) - fun testExclude(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("orderByCases") - @Execution(ExecutionMode.CONCURRENT) - fun testOrderBy(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("tupleUnionCases") - @Execution(ExecutionMode.CONCURRENT) - fun testTupleUnion(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("aggregationCases") - @Execution(ExecutionMode.CONCURRENT) - fun testAggregations(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("scalarFunctions") - @Execution(ExecutionMode.CONCURRENT) - fun testScalarFunctions(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("pathExpressions") - @Execution(ExecutionMode.CONCURRENT) - fun testPathExpressions(tc: TestCase) = runTest(tc) - - @ParameterizedTest - @MethodSource("caseWhens") - @Execution(ExecutionMode.CONCURRENT) - fun testCaseWhens(tc: TestCase) = runTest(tc) + sealed class TestCase { + class SuccessTestCase( + val name: String, + val key: PartiQLTest.Key? = null, + val query: String? = null, + val catalog: String? = null, + val catalogPath: List = emptyList(), + val expected: StaticType, + val warnings: ProblemHandler? = null, + ) : TestCase() { + override fun toString(): String = "$name : $query" + } - @ParameterizedTest - @MethodSource("subqueryCases") - @Execution(ExecutionMode.CONCURRENT) - fun testSubqueries(tc: TestCase) = runTest(tc) + class ErrorTestCase( + val name: String, + val key: PartiQLTest.Key? = null, + val query: String? = null, + val catalog: String? = null, + val catalogPath: List = emptyList(), + val note: String? = null, + val expected: StaticType? = null, + val problemHandler: ProblemHandler? = null, + ) : TestCase() { + override fun toString(): String = "$name : $query" + } - @ParameterizedTest - @MethodSource("dynamicCalls") - @Execution(ExecutionMode.CONCURRENT) - fun testDynamicCalls(tc: TestCase) = runTest(tc) + class ThrowingExceptionTestCase( + val name: String, + val query: String, + val catalog: String? = null, + val catalogPath: List = emptyList(), + val note: String? = null, + val expectedThrowable: KClass, + ) : TestCase() { + override fun toString(): String { + return "$name : $query" + } + } + } companion object { - val inputStream = this::class.java.getResourceAsStream("/resource_path.txt")!! - val catalogProvider = MemoryCatalog.Provider().also { + private fun assertProblemExists(problem: () -> Problem) = ProblemHandler { problems, ignoreSourceLocation -> + when (ignoreSourceLocation) { + true -> assertTrue("Expected to find ${problem.invoke()} in $problems") { problems.any { it.details == problem.invoke().details } } + false -> assertTrue("Expected to find ${problem.invoke()} in $problems") { problems.any { it == problem.invoke() } } + } + } + + /** + * MemoryConnector.Factory from reading the resources in /resource_path.txt for Github CI/CD. + */ + val catalogs: List> by lazy { + val inputStream = this::class.java.getResourceAsStream("/resource_path.txt")!! val map = mutableMapOf>>() inputStream.reader().readLines().forEach { path -> if (path.startsWith("catalogs/default")) { @@ -178,41 +119,26 @@ class PartiQLSchemaInferencerTests { } } } - map.forEach { (k: String, v: MutableList>) -> - it[k] = MemoryCatalog.of(*v.toTypedArray()) + map.entries.map { + it.key to MemoryConnector.Metadata.of(*it.value.toTypedArray()) } } - private val PLUGINS = listOf(MemoryPlugin(catalogProvider)) - private const val USER_ID = "TEST_USER" - private val catalogConfig = mapOf( - "aws" to ionStructOf( - field("connector_name", ionString("memory")), - ), - "b" to ionStructOf( - field("connector_name", ionString("memory")), - ), - "db" to ionStructOf( - field("connector_name", ionString("memory")), - ), - "pql" to ionStructOf( - field("connector_name", ionString("memory")), - ), - "subqueries" to ionStructOf( - field("connector_name", ionString("memory")), - ), - ) + private fun key(name: String) = PartiQLTest.Key("schema_inferencer", name) + // + // testing result utility + // const val CATALOG_AWS = "aws" const val CATALOG_B = "b" const val CATALOG_DB = "db" val DB_SCHEMA_MARKETS = listOf("markets") val TYPE_BOOL = StaticType.BOOL - private val TYPE_AWS_DDB_PETS_ID = INT4 - private val TYPE_AWS_DDB_PETS_BREED = STRING + private val TYPE_AWS_DDB_PETS_ID = StaticType.INT4 + private val TYPE_AWS_DDB_PETS_BREED = StaticType.STRING val TABLE_AWS_DDB_PETS = BagType( elementType = StructType( fields = mapOf( @@ -243,7 +169,7 @@ class PartiQLSchemaInferencerTests { ) val TABLE_AWS_DDB_B = BagType( StructType( - fields = mapOf("identifier" to STRING), + fields = mapOf("identifier" to StaticType.STRING), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -254,7 +180,7 @@ class PartiQLSchemaInferencerTests { ) val TABLE_AWS_B_B = BagType( StructType( - fields = mapOf("identifier" to INT4), + fields = mapOf("identifier" to StaticType.INT4), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -263,14 +189,14 @@ class PartiQLSchemaInferencerTests { ) ) ) - val TYPE_B_B_B_B_B = INT4 + val TYPE_B_B_B_B_B = StaticType.INT4 private val TYPE_B_B_B_B = StructType( mapOf("b" to TYPE_B_B_B_B_B), contentClosed = true, constraints = setOf(TupleConstraint.Open(false), TupleConstraint.UniqueAttrs(true), TupleConstraint.Ordered) ) - val TYPE_B_B_B_C = INT4 - val TYPE_B_B_C = INT4 + val TYPE_B_B_B_C = StaticType.INT4 + val TYPE_B_B_C = StaticType.INT4 val TYPE_B_B_B = StructType( fields = mapOf( @@ -285,50 +211,42 @@ class PartiQLSchemaInferencerTests { ) ) - private fun assertProblemExists(problem: () -> Problem) = ProblemHandler { problems, ignoreSourceLocation -> - when (ignoreSourceLocation) { - true -> assertTrue("Expected to find ${problem.invoke()} in $problems") { problems.any { it.details == problem.invoke().details } } - false -> assertTrue("Expected to find ${problem.invoke()} in $problems") { problems.any { it == problem.invoke() } } - } - } - - // Tests - - private fun key(name: String) = PartiQLTest.Key("schema_inferencer", name) - + // + // Parameterized Test Source + // @JvmStatic fun collections() = listOf( SuccessTestCase( name = "Collection BAG", key = key("collections-01"), - expected = BagType(INT4), + expected = BagType(StaticType.INT4), ), SuccessTestCase( name = "Collection LIST", key = key("collections-02"), - expected = ListType(INT4), + expected = ListType(StaticType.INT4), ), SuccessTestCase( name = "Collection LIST", key = key("collections-03"), - expected = ListType(INT4), + expected = ListType(StaticType.INT4), ), SuccessTestCase( name = "Collection SEXP", key = key("collections-04"), - expected = SexpType(INT4), + expected = SexpType(StaticType.INT4), ), SuccessTestCase( name = "SELECT from array", key = key("collections-05"), - expected = BagType(INT4), + expected = BagType(StaticType.INT4), ), SuccessTestCase( name = "SELECT from array", key = key("collections-06"), expected = BagType( StructType( - fields = listOf(StructType.Field("x", INT4)), + fields = listOf(StructType.Field("x", StaticType.INT4)), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -374,8 +292,8 @@ class PartiQLSchemaInferencerTests { "name", StructType( fields = listOf( - StructType.Field("first", STRING), - StructType.Field("last", STRING), + StructType.Field("first", StaticType.STRING), + StructType.Field("last", StaticType.STRING), ), contentClosed = true, constraints = setOf( @@ -385,16 +303,16 @@ class PartiQLSchemaInferencerTests { ), ) ), - StructType.Field("ssn", STRING), - StructType.Field("employer", STRING.asNullable()), - StructType.Field("name", STRING), - StructType.Field("tax_id", INT8), + StructType.Field("ssn", StaticType.STRING), + StructType.Field("employer", StaticType.STRING.asNullable()), + StructType.Field("name", StaticType.STRING), + StructType.Field("tax_id", StaticType.INT8), StructType.Field( "address", StructType( fields = listOf( - StructType.Field("street", STRING), - StructType.Field("zip", INT4), + StructType.Field("street", StaticType.STRING), + StructType.Field("zip", StaticType.INT4), ), contentClosed = true, constraints = setOf( @@ -415,9 +333,9 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("first", STRING), - StructType.Field("last", STRING), - StructType.Field("full_name", STRING), + StructType.Field("first", StaticType.STRING), + StructType.Field("last", StaticType.STRING), + StructType.Field("full_name", StaticType.STRING), ), contentClosed = true, constraints = setOf( @@ -439,8 +357,8 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("first", STRING), - StructType.Field("i", INT8), + StructType.Field("first", StaticType.STRING), + StructType.Field("i", StaticType.INT8), ), contentClosed = true, constraints = setOf( @@ -473,22 +391,22 @@ class PartiQLSchemaInferencerTests { SuccessTestCase( name = "Current User", query = "CURRENT_USER", - expected = unionOf(STRING, StaticType.NULL) + expected = StaticType.unionOf(StaticType.STRING, StaticType.NULL) ), SuccessTestCase( name = "Current User Concat", query = "CURRENT_USER || 'hello'", - expected = unionOf(STRING, StaticType.NULL) + expected = StaticType.unionOf(StaticType.STRING, StaticType.NULL) ), SuccessTestCase( name = "Current User in WHERE", query = "SELECT VALUE a FROM [ 0 ] AS a WHERE CURRENT_USER = 'hello'", - expected = BagType(INT4) + expected = BagType(StaticType.INT4) ), SuccessTestCase( name = "Current User in WHERE", query = "SELECT VALUE a FROM [ 0 ] AS a WHERE CURRENT_USER = 5", - expected = BagType(INT4), + expected = BagType(StaticType.INT4), ), SuccessTestCase( name = "Testing CURRENT_USER and CURRENT_DATE Binders", @@ -504,11 +422,11 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("CURRENT_USER", STRING.asNullable()), - StructType.Field("CURRENT_DATE", DATE), - StructType.Field("curr_user", STRING.asNullable()), - StructType.Field("curr_date", DATE), - StructType.Field("name_desc", STRING.asNullable()), + StructType.Field("CURRENT_USER", StaticType.STRING.asNullable()), + StructType.Field("CURRENT_DATE", StaticType.DATE), + StructType.Field("curr_user", StaticType.STRING.asNullable()), + StructType.Field("curr_date", StaticType.DATE), + StructType.Field("name_desc", StaticType.STRING.asNullable()), ), contentClosed = true, constraints = setOf( @@ -529,8 +447,8 @@ class PartiQLSchemaInferencerTests { PlanningProblemDetails.UnknownFunction( "plus", listOf( - unionOf(STRING, StaticType.NULL), - STRING, + StaticType.unionOf(StaticType.STRING, StaticType.NULL), + StaticType.STRING, ), ) ) @@ -543,17 +461,17 @@ class PartiQLSchemaInferencerTests { SuccessTestCase( name = "BITWISE_AND_1", query = "1 & 2", - expected = INT4 + expected = StaticType.INT4 ), SuccessTestCase( name = "BITWISE_AND_2", query = "CAST(1 AS INT2) & CAST(2 AS INT2)", - expected = StaticType.unionOf(StaticType.INT2, MISSING) + expected = StaticType.unionOf(StaticType.INT2, StaticType.MISSING) ), SuccessTestCase( name = "BITWISE_AND_3", query = "1 & 2", - expected = INT4 + expected = StaticType.INT4 ), SuccessTestCase( name = "BITWISE_AND_4", @@ -563,17 +481,17 @@ class PartiQLSchemaInferencerTests { SuccessTestCase( name = "BITWISE_AND_5", query = "CAST(1 AS INT2) & 2", - expected = StaticType.unionOf(StaticType.INT4, MISSING) + expected = StaticType.unionOf(StaticType.INT4, StaticType.MISSING) ), SuccessTestCase( name = "BITWISE_AND_6", query = "CAST(1 AS INT2) & CAST(2 AS INT8)", - expected = StaticType.unionOf(StaticType.INT8, MISSING) + expected = StaticType.unionOf(StaticType.INT8, StaticType.MISSING) ), SuccessTestCase( name = "BITWISE_AND_7", query = "CAST(1 AS INT2) & 2", - expected = StaticType.unionOf(INT4, MISSING) + expected = StaticType.unionOf(StaticType.INT4, StaticType.MISSING) ), SuccessTestCase( name = "BITWISE_AND_8", @@ -588,7 +506,7 @@ class PartiQLSchemaInferencerTests { SuccessTestCase( name = "BITWISE_AND_10", query = "CAST(1 AS INT8) & 2", - expected = INT8 + expected = StaticType.INT8 ), SuccessTestCase( name = "BITWISE_AND_NULL_OPERAND", @@ -604,7 +522,7 @@ class PartiQLSchemaInferencerTests { sourceLocation = UNKNOWN_PROBLEM_LOCATION, details = PlanningProblemDetails.UnknownFunction( "bitwise_and", - listOf(INT4, MISSING) + listOf(StaticType.INT4, StaticType.MISSING) ) ) } @@ -616,7 +534,7 @@ class PartiQLSchemaInferencerTests { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnknownFunction("bitwise_and", listOf(INT4, STRING)) + PlanningProblemDetails.UnknownFunction("bitwise_and", listOf(StaticType.INT4, StaticType.STRING)) ) } ), @@ -627,7 +545,7 @@ class PartiQLSchemaInferencerTests { SuccessTestCase( name = "UNPIVOT", query = "SELECT VALUE v FROM UNPIVOT { 'a': 2 } AS v AT attr WHERE attr = 'a'", - expected = BagType(INT4) + expected = BagType(StaticType.INT4) ), ) @@ -639,7 +557,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "a" to INT4, + "a" to StaticType.INT4, "b" to StaticType.DECIMAL, ), contentClosed = true, @@ -657,8 +575,8 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "a" to INT4, - "b" to unionOf(NULL, DECIMAL), + "a" to StaticType.INT4, + "b" to StaticType.unionOf(StaticType.NULL, StaticType.DECIMAL), ), contentClosed = true, constraints = setOf( @@ -675,8 +593,8 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("b", unionOf(NULL, DECIMAL)), - StructType.Field("a", INT4), + StructType.Field("b", StaticType.unionOf(StaticType.NULL, StaticType.DECIMAL)), + StructType.Field("a", StaticType.INT4), ), contentClosed = true, constraints = setOf( @@ -693,8 +611,8 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("a", INT4), - StructType.Field("a", unionOf(NULL, DECIMAL)), + StructType.Field("a", StaticType.INT4), + StructType.Field("a", StaticType.unionOf(StaticType.NULL, StaticType.DECIMAL)), ), contentClosed = true, constraints = setOf( @@ -711,8 +629,8 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("a", INT4), - StructType.Field("a", unionOf(NULL, DECIMAL)), + StructType.Field("a", StaticType.INT4), + StructType.Field("a", StaticType.unionOf(StaticType.NULL, StaticType.DECIMAL)), ), contentClosed = true, constraints = setOf( @@ -739,9 +657,9 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("a", INT4), - StructType.Field("a", unionOf(DECIMAL, NULL)), - StructType.Field("a", unionOf(STRING, NULL)), + StructType.Field("a", StaticType.INT4), + StructType.Field("a", StaticType.unionOf(StaticType.DECIMAL, StaticType.NULL)), + StructType.Field("a", StaticType.unionOf(StaticType.STRING, StaticType.NULL)), ), contentClosed = true, constraints = setOf( @@ -758,8 +676,8 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("a", INT4), - StructType.Field("a", unionOf(DECIMAL, NULL)), + StructType.Field("a", StaticType.INT4), + StructType.Field("a", StaticType.unionOf(StaticType.DECIMAL, StaticType.NULL)), ), contentClosed = true, constraints = setOf( @@ -848,8 +766,8 @@ class PartiQLSchemaInferencerTests { fields = mapOf( "field" to AnyOfType( setOf( - INT4, - MISSING // c[1]'s `field` was excluded + StaticType.INT4, + StaticType.MISSING // c[1]'s `field` was excluded ) ) ), @@ -2030,8 +1948,8 @@ class PartiQLSchemaInferencerTests { query = "TUPLEUNION({ 'a': 1, 'a': 'hello' })", expected = StructType( fields = listOf( - StructType.Field("a", INT4), - StructType.Field("a", STRING), + StructType.Field("a", StaticType.INT4), + StructType.Field("a", StaticType.STRING), ), contentClosed = true, constraints = setOf( @@ -2052,7 +1970,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("b", INT4), + StructType.Field("b", StaticType.INT4), ), contentClosed = true, // TODO: This shouldn't be ordered. However, this doesn't come from the TUPLEUNION. It is @@ -2076,11 +1994,11 @@ class PartiQLSchemaInferencerTests { >> AS t """, expected = BagType( - unionOf( - MISSING, + StaticType.unionOf( + StaticType.MISSING, StructType( fields = listOf( - StructType.Field("b", INT4), + StructType.Field("b", StaticType.INT4), ), contentClosed = true, constraints = setOf( @@ -2105,12 +2023,12 @@ class PartiQLSchemaInferencerTests { >> AS t """, expected = BagType( - unionOf( - NULL, - MISSING, + StaticType.unionOf( + StaticType.NULL, + StaticType.MISSING, StructType( fields = listOf( - StructType.Field("b", INT4), + StructType.Field("b", StaticType.INT4), ), contentClosed = true, constraints = setOf( @@ -2120,7 +2038,7 @@ class PartiQLSchemaInferencerTests { ), StructType( fields = listOf( - StructType.Field("b", STRING), + StructType.Field("b", StaticType.STRING), ), contentClosed = true, constraints = setOf( @@ -2139,12 +2057,12 @@ class PartiQLSchemaInferencerTests { ) FROM aws.ddb.persons AS p """, expected = BagType( - unionOf( - MISSING, + StaticType.unionOf( + StaticType.MISSING, StructType( fields = listOf( - StructType.Field("first", STRING), - StructType.Field("last", STRING), + StructType.Field("first", StaticType.STRING), + StructType.Field("last", StaticType.STRING), ), contentClosed = false, constraints = setOf( @@ -2154,7 +2072,7 @@ class PartiQLSchemaInferencerTests { ), StructType( fields = listOf( - StructType.Field("full_name", STRING), + StructType.Field("full_name", StaticType.STRING), ), contentClosed = true, constraints = setOf( @@ -2175,14 +2093,14 @@ class PartiQLSchemaInferencerTests { ) FROM aws.ddb.persons AS p """, expected = BagType( - unionOf( - MISSING, + StaticType.unionOf( + StaticType.MISSING, StructType( fields = listOf( - StructType.Field("first", STRING), - StructType.Field("last", STRING), - StructType.Field("first", STRING), - StructType.Field("last", STRING), + StructType.Field("first", StaticType.STRING), + StructType.Field("last", StaticType.STRING), + StructType.Field("first", StaticType.STRING), + StructType.Field("last", StaticType.STRING), ), contentClosed = false, constraints = setOf( @@ -2192,9 +2110,9 @@ class PartiQLSchemaInferencerTests { ), StructType( fields = listOf( - StructType.Field("first", STRING), - StructType.Field("last", STRING), - StructType.Field("full_name", STRING), + StructType.Field("first", StaticType.STRING), + StructType.Field("last", StaticType.STRING), + StructType.Field("full_name", StaticType.STRING), ), contentClosed = false, constraints = setOf( @@ -2204,9 +2122,9 @@ class PartiQLSchemaInferencerTests { ), StructType( fields = listOf( - StructType.Field("full_name", STRING), - StructType.Field("first", STRING), - StructType.Field("last", STRING), + StructType.Field("full_name", StaticType.STRING), + StructType.Field("first", StaticType.STRING), + StructType.Field("last", StaticType.STRING), ), contentClosed = false, constraints = setOf( @@ -2216,8 +2134,8 @@ class PartiQLSchemaInferencerTests { ), StructType( fields = listOf( - StructType.Field("full_name", STRING), - StructType.Field("full_name", STRING), + StructType.Field("full_name", StaticType.STRING), + StructType.Field("full_name", StaticType.STRING), ), contentClosed = true, constraints = setOf( @@ -2242,7 +2160,7 @@ class PartiQLSchemaInferencerTests { ELSE 2 END; """, - expected = INT4 + expected = StaticType.INT4 ), SuccessTestCase( name = "Folded case when to grab the true", @@ -2252,7 +2170,7 @@ class PartiQLSchemaInferencerTests { WHEN TRUE THEN 'hello' END; """, - expected = STRING + expected = StaticType.STRING ), SuccessTestCase( name = "Boolean case when", @@ -2262,7 +2180,7 @@ class PartiQLSchemaInferencerTests { ELSE FALSE END; """, - expected = BOOL + expected = StaticType.BOOL ), SuccessTestCase( name = "Folded out false", @@ -2272,7 +2190,7 @@ class PartiQLSchemaInferencerTests { ELSE TRUE END; """, - expected = BOOL + expected = StaticType.BOOL ), SuccessTestCase( name = "Folded out false without default", @@ -2281,7 +2199,7 @@ class PartiQLSchemaInferencerTests { WHEN FALSE THEN 'IMPOSSIBLE TO GET' END; """, - expected = NULL + expected = StaticType.NULL ), SuccessTestCase( name = "Not folded gives us a nullable without default", @@ -2291,7 +2209,7 @@ class PartiQLSchemaInferencerTests { WHEN 2 THEN FALSE END; """, - expected = BOOL.asNullable() + expected = StaticType.BOOL.asNullable() ), SuccessTestCase( name = "Not folded gives us a nullable without default for query", @@ -2308,7 +2226,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "breed_descriptor" to STRING.asNullable(), + "breed_descriptor" to StaticType.STRING.asNullable(), ), contentClosed = true, constraints = setOf( @@ -2335,7 +2253,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "breed_descriptor" to STRING, + "breed_descriptor" to StaticType.STRING, ), contentClosed = true, constraints = setOf( @@ -2362,7 +2280,11 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "breed_descriptor" to unionOf(STRING, INT4, DECIMAL), + "breed_descriptor" to StaticType.unionOf( + StaticType.STRING, + StaticType.INT4, + StaticType.DECIMAL + ), ), contentClosed = true, constraints = setOf( @@ -2382,7 +2304,7 @@ class PartiQLSchemaInferencerTests { query = """ [0, 1, 2, 3][0] """, - expected = INT4 + expected = StaticType.INT4 ), SuccessTestCase( name = "Index on global list", @@ -2391,7 +2313,7 @@ class PartiQLSchemaInferencerTests { """, catalog = "pql", catalogPath = listOf("main"), - expected = STRING + expected = StaticType.STRING ), SuccessTestCase( name = "Index on list attribute of global table", @@ -2403,7 +2325,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "main_allergy" to STRING, + "main_allergy" to StaticType.STRING, ), contentClosed = true, constraints = setOf( @@ -2424,7 +2346,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "s" to STRING, + "s" to StaticType.STRING, ), contentClosed = true, constraints = setOf( @@ -2445,7 +2367,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "s" to STRING, + "s" to StaticType.STRING, ), contentClosed = true, constraints = setOf( @@ -2466,7 +2388,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "s" to STRING, + "s" to StaticType.STRING, ), contentClosed = true, constraints = setOf( @@ -2485,7 +2407,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "s" to STRING, + "s" to StaticType.STRING, ), contentClosed = true, constraints = setOf( @@ -2501,7 +2423,7 @@ class PartiQLSchemaInferencerTests { query = """ SELECT VALUE 1 FROM "pql"."main"['employer'] AS e; """, - expected = BagType(INT4), + expected = BagType(StaticType.INT4), problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, @@ -2514,7 +2436,7 @@ class PartiQLSchemaInferencerTests { query = """ SELECT VALUE 1 FROM "pql"['main']."employer" AS e; """, - expected = BagType(INT4), + expected = BagType(StaticType.INT4), problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, @@ -2527,7 +2449,7 @@ class PartiQLSchemaInferencerTests { query = """ { 'aBc': 1, 'AbC': 2.0 }['AbC']; """, - expected = DECIMAL + expected = StaticType.DECIMAL ), // This should fail because the Spec says tuple indexing MUST use a literal string or explicit cast. ErrorTestCase( @@ -2535,7 +2457,7 @@ class PartiQLSchemaInferencerTests { query = """ { 'aBc': 1, 'AbC': 2.0 }['Ab' || 'C']; """, - expected = MISSING, + expected = StaticType.MISSING, problemHandler = assertProblemExists { Problem( sourceLocation = UNKNOWN_PROBLEM_LOCATION, @@ -2551,7 +2473,7 @@ class PartiQLSchemaInferencerTests { query = """ { 'aBc': 1, 'AbC': 2.0 }[CAST('Ab' || 'C' AS STRING)]; """, - expected = ANY + expected = StaticType.ANY ), ) @@ -2569,7 +2491,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "upper_str" to STRING, + "upper_str" to StaticType.STRING, ), contentClosed = true, constraints = setOf( @@ -2585,7 +2507,7 @@ class PartiQLSchemaInferencerTests { query = """ UPPER('hello world') """, - expected = STRING + expected = StaticType.STRING ), SuccessTestCase( name = "UPPER on global string", @@ -2594,7 +2516,7 @@ class PartiQLSchemaInferencerTests { """, catalog = "pql", catalogPath = listOf("main"), - expected = STRING + expected = StaticType.STRING ), SuccessTestCase( name = "UPPER on global string", @@ -2603,7 +2525,7 @@ class PartiQLSchemaInferencerTests { """, catalog = "pql", catalogPath = listOf("main"), - expected = STRING + expected = StaticType.STRING ), SuccessTestCase( name = "UPPER on global struct", @@ -2612,7 +2534,7 @@ class PartiQLSchemaInferencerTests { """, catalog = "pql", catalogPath = listOf("main"), - expected = STRING + expected = StaticType.STRING ), SuccessTestCase( name = "UPPER on global nested struct", @@ -2621,7 +2543,7 @@ class PartiQLSchemaInferencerTests { """, catalog = "pql", catalogPath = listOf("main"), - expected = STRING + expected = StaticType.STRING ), SuccessTestCase( name = "UPPER on global table", @@ -2634,7 +2556,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "upper_breed" to STRING, + "upper_breed" to StaticType.STRING, ), contentClosed = true, constraints = setOf( @@ -2655,10 +2577,10 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "a" to INT4, - "_1" to INT4, - "_2" to INT4.asNullable(), - "_3" to INT4.asNullable(), + "a" to StaticType.INT4, + "_1" to StaticType.INT4, + "_2" to StaticType.INT4.asNullable(), + "_3" to StaticType.INT4.asNullable(), ), contentClosed = true, constraints = setOf( @@ -2675,10 +2597,10 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "a" to INT4, - "c" to INT4, - "s" to INT4.asNullable(), - "m" to INT4.asNullable(), + "a" to StaticType.INT4, + "c" to StaticType.INT4, + "s" to StaticType.INT4.asNullable(), + "m" to StaticType.INT4.asNullable(), ), contentClosed = true, constraints = setOf( @@ -2696,7 +2618,7 @@ class PartiQLSchemaInferencerTests { StructType( fields = mapOf( "a" to StaticType.DECIMAL, - "c" to INT4, + "c" to StaticType.INT4, "s" to StaticType.DECIMAL.asNullable(), "m" to StaticType.DECIMAL.asNullable(), ), @@ -2725,7 +2647,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "a" to unionOf(INT4, INT8), + "a" to StaticType.unionOf(StaticType.INT4, StaticType.INT8), ), contentClosed = true, constraints = setOf( @@ -2749,7 +2671,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "a" to unionOf(INT4, INT8, MISSING), + "a" to StaticType.unionOf(StaticType.INT4, StaticType.INT8, StaticType.MISSING), ), contentClosed = true, constraints = setOf( @@ -2773,7 +2695,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "a" to unionOf(INT4, INT8, MISSING), + "a" to StaticType.unionOf(StaticType.INT4, StaticType.INT8, StaticType.MISSING), ), contentClosed = true, constraints = setOf( @@ -2797,7 +2719,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "c" to unionOf(MISSING, DECIMAL), + "c" to StaticType.unionOf(StaticType.MISSING, StaticType.DECIMAL), ), contentClosed = true, constraints = setOf( @@ -2819,13 +2741,13 @@ class PartiQLSchemaInferencerTests { { 'a': 'hello world!' } >> AS t """.trimIndent(), - expected = BagType(MISSING), + expected = BagType(StaticType.MISSING), problemHandler = assertProblemExists { Problem( sourceLocation = UNKNOWN_PROBLEM_LOCATION, details = PlanningProblemDetails.UnknownFunction( "pos", - listOf(STRING) + listOf(StaticType.STRING) ) ) } @@ -2842,13 +2764,13 @@ class PartiQLSchemaInferencerTests { { 'a': <<>> } >> AS t """.trimIndent(), - expected = BagType(MISSING), + expected = BagType(StaticType.MISSING), problemHandler = assertProblemExists { Problem( sourceLocation = UNKNOWN_PROBLEM_LOCATION, details = PlanningProblemDetails.UnknownFunction( "pos", - listOf(unionOf(STRING, BAG)) + listOf(StaticType.unionOf(StaticType.STRING, StaticType.BAG)) ) ) } @@ -2864,13 +2786,13 @@ class PartiQLSchemaInferencerTests { { 'NOT_A': 1 } >> AS t """.trimIndent(), - expected = BagType(MISSING), + expected = BagType(StaticType.MISSING), problemHandler = assertProblemExists { Problem( sourceLocation = UNKNOWN_PROBLEM_LOCATION, details = PlanningProblemDetails.UnknownFunction( "pos", - listOf(MISSING) + listOf(StaticType.MISSING) ) ) } @@ -2886,7 +2808,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "x" to INT4, + "x" to StaticType.INT4, ), contentClosed = true, constraints = setOf( @@ -2904,7 +2826,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "x" to INT4, + "x" to StaticType.INT4, ), contentClosed = true, constraints = setOf( @@ -2922,12 +2844,12 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = mapOf( - "x" to INT4, - "y" to INT4, - "z" to INT4, - "a" to INT4, - "b" to INT4, - "c" to INT4, + "x" to StaticType.INT4, + "y" to StaticType.INT4, + "z" to StaticType.INT4, + "a" to StaticType.INT4, + "b" to StaticType.INT4, + "c" to StaticType.INT4, ), contentClosed = true, constraints = setOf( @@ -2942,64 +2864,241 @@ class PartiQLSchemaInferencerTests { name = "Subquery scalar coercion", catalog = "subqueries", key = PartiQLTest.Key("subquery", "subquery-03"), - expected = BOOL, + expected = StaticType.BOOL, ), ) + + // --------- Parameterized Test Source Finished ------------ } - sealed class TestCase { - fun toIgnored(reason: String) = - when (this) { - is IgnoredTestCase -> this - else -> IgnoredTestCase(this, reason) - } + private val testProvider = PartiQLTestProvider() - class SuccessTestCase( - val name: String, - val key: PartiQLTest.Key? = null, - val query: String? = null, - val catalog: String? = null, - val catalogPath: List = emptyList(), - val expected: StaticType, - val warnings: ProblemHandler? = null, - ) : TestCase() { - override fun toString(): String = "$name : $query" + init { + // load test inputs + testProvider.load() + } + + // + // Parameterized Tests + // + @ParameterizedTest + @ArgumentsSource(TestProvider::class) + fun test(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("collections") + @Execution(ExecutionMode.CONCURRENT) + fun testCollections(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("selectStar") + @Execution(ExecutionMode.CONCURRENT) + fun testSelectStar(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("sessionVariables") + @Execution(ExecutionMode.CONCURRENT) + fun testSessionVariables(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("bitwiseAnd") + @Execution(ExecutionMode.CONCURRENT) + fun testBitwiseAnd(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("unpivotCases") + @Execution(ExecutionMode.CONCURRENT) + fun testUnpivot(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("joinCases") + @Execution(ExecutionMode.CONCURRENT) + fun testJoins(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("excludeCases") + @Execution(ExecutionMode.CONCURRENT) + fun testExclude(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("orderByCases") + @Execution(ExecutionMode.CONCURRENT) + fun testOrderBy(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("tupleUnionCases") + @Execution(ExecutionMode.CONCURRENT) + fun testTupleUnion(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("aggregationCases") + @Execution(ExecutionMode.CONCURRENT) + fun testAggregations(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("scalarFunctions") + @Execution(ExecutionMode.CONCURRENT) + fun testScalarFunctions(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("pathExpressions") + @Execution(ExecutionMode.CONCURRENT) + fun testPathExpressions(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("caseWhens") + @Execution(ExecutionMode.CONCURRENT) + fun testCaseWhens(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("subqueryCases") + @Execution(ExecutionMode.CONCURRENT) + fun testSubqueries(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("dynamicCalls") + @Execution(ExecutionMode.CONCURRENT) + fun testDynamicCalls(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("scanCases") + @Execution(ExecutionMode.CONCURRENT) + fun testScan(tc: TestCase) = runTest(tc) + + @ParameterizedTest + @MethodSource("pivotCases") + @Execution(ExecutionMode.CONCURRENT) + fun testPivot(tc: TestCase) = runTest(tc) + + // --------- Finish Parameterized Tests ------ + + // + // Testing Utility + // + private fun infer( + query: String, + session: PartiQLPlanner.Session, + problemCollector: ProblemCollector + ): PartiQLPlan { + val parser = PartiQLParser.default() + val planner = PartiQLPlannerBuilder() + .catalogs(*catalogs.toTypedArray()) + .build() + val ast = parser.parse(query).root + return planner.plan(ast, session, problemCollector).plan + } + + private fun runTest(tc: TestCase) = when (tc) { + is SuccessTestCase -> runTest(tc) + is ErrorTestCase -> runTest(tc) + is ThrowingExceptionTestCase -> runTest(tc) + } + + private fun runTest(tc: SuccessTestCase) { + val session = PartiQLPlanner.Session( + tc.query.hashCode().toString(), + USER_ID, + tc.catalog, + tc.catalogPath, + ) + + val hasQuery = tc.query != null + val hasKey = tc.key != null + if (hasQuery == hasKey) { + error("Test must have one of either `query` or `key`") } + val input = tc.query ?: testProvider[tc.key!!]!!.statement - class ErrorTestCase( - val name: String, - val key: PartiQLTest.Key? = null, - val query: String? = null, - val catalog: String? = null, - val catalogPath: List = emptyList(), - val note: String? = null, - val expected: StaticType? = null, - val problemHandler: ProblemHandler? = null, - ) : TestCase() { - override fun toString(): String = "$name : $query" + val collector = ProblemCollector() + val plan = infer(input, session, collector) + when (val statement = plan.statement) { + is Statement.Query -> { + assert(collector.problems.isEmpty()) { + buildString { + appendLine(collector.problems.toString()) + appendLine() + PlanPrinter.append(this, statement) + } + } + val actual = statement.root.type + assert(tc.expected == actual) { + buildString { + appendLine() + appendLine("Expect: ${tc.expected}") + appendLine("Actual: $actual") + appendLine() + PlanPrinter.append(this, statement) + } + } + } } + } - class ThrowingExceptionTestCase( - val name: String, - val query: String, - val catalog: String? = null, - val catalogPath: List = emptyList(), - val note: String? = null, - val expectedThrowable: KClass, - ) : TestCase() { - override fun toString(): String { - return "$name : $query" + private fun runTest(tc: ErrorTestCase) { + val session = PartiQLPlanner.Session( + tc.query.hashCode().toString(), + USER_ID, + tc.catalog, + tc.catalogPath, + ) + val collector = ProblemCollector() + + val hasQuery = tc.query != null + val hasKey = tc.key != null + if (hasQuery == hasKey) { + error("Test must have one of either `query` or `key`") + } + val input = tc.query ?: testProvider[tc.key!!]!!.statement + val plan = infer(input, session, collector) + + when (val statement = plan.statement) { + is Statement.Query -> { + assert(collector.problems.isNotEmpty()) { + buildString { + appendLine("Expected to find problems, but none were found.") + appendLine() + PlanPrinter.append(this, plan) + } + } + if (tc.expected != null) { + assert(tc.expected == statement.root.type) { + buildString { + appendLine() + appendLine("Expect: ${tc.expected}") + appendLine("Actual: ${statement.root.type}") + appendLine() + PlanPrinter.append(this, plan) + } + } + } + assert(collector.problems.isNotEmpty()) { + "Expected to find problems, but none were found." + } + tc.problemHandler?.handle(collector.problems, true) } } + } - class IgnoredTestCase( - val shouldBe: TestCase, - reason: String, - ) : TestCase() { - override fun toString(): String = "Disabled - $shouldBe" + private fun runTest(tc: ThrowingExceptionTestCase) { + val session = PartiQLPlanner.Session( + tc.query.hashCode().toString(), + USER_ID, + tc.catalog, + tc.catalogPath, + ) + val collector = ProblemCollector() + val exception = assertThrows { + infer(tc.query, session, collector) + Unit } + val cause = exception.cause + assertNotNull(cause) + assertEquals(tc.expectedThrowable, cause::class) } + // + // Additional Test + // class TestProvider : ArgumentsProvider { override fun provideArguments(context: ExtensionContext?): Stream { return parameters.map { Arguments.of(it) }.stream() @@ -3010,7 +3109,7 @@ class PartiQLSchemaInferencerTests { name = "Pets should not be accessible #1", query = "SELECT * FROM pets", expected = BagType( - unionOf( + StaticType.unionOf( StructType( fields = emptyMap(), contentClosed = false, @@ -3038,12 +3137,12 @@ class PartiQLSchemaInferencerTests { ) } ), - ErrorTestCase( + TestCase.ErrorTestCase( name = "Pets should not be accessible #2", catalog = CATALOG_AWS, query = "SELECT * FROM pets", expected = BagType( - unionOf( + StaticType.unionOf( StructType( fields = emptyMap(), contentClosed = false, @@ -3071,46 +3170,46 @@ class PartiQLSchemaInferencerTests { ) } ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Project all explicitly", catalog = CATALOG_AWS, catalogPath = listOf("ddb"), query = "SELECT * FROM pets", expected = TABLE_AWS_DDB_PETS ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Project all implicitly", catalog = CATALOG_AWS, catalogPath = listOf("ddb"), query = "SELECT id, breed FROM pets", expected = TABLE_AWS_DDB_PETS ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #4", catalog = CATALOG_B, catalogPath = listOf("b"), query = "b", expected = TYPE_B_B_B ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #5", catalog = CATALOG_AWS, catalogPath = listOf("ddb"), query = "SELECT * FROM b", expected = TABLE_AWS_DDB_B ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #6", catalog = CATALOG_AWS, catalogPath = listOf("b"), query = "SELECT * FROM b", expected = TABLE_AWS_B_B ), - ErrorTestCase( + TestCase.ErrorTestCase( name = "Test #7", query = "SELECT * FROM ddb.pets", expected = BagType( - unionOf( + StaticType.unionOf( StructType( fields = emptyMap(), contentClosed = false, @@ -3138,64 +3237,64 @@ class PartiQLSchemaInferencerTests { ) } ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #10", catalog = CATALOG_B, query = "b.b", expected = TYPE_B_B_B ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #11", catalog = CATALOG_B, catalogPath = listOf("b"), query = "b.b", expected = TYPE_B_B_B ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #12", catalog = CATALOG_AWS, catalogPath = listOf("ddb"), query = "SELECT * FROM b.b", expected = TABLE_AWS_B_B ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #13", catalog = CATALOG_AWS, catalogPath = listOf("ddb"), query = "SELECT * FROM ddb.b", expected = TABLE_AWS_DDB_B ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #14", query = "SELECT * FROM aws.ddb.pets", expected = TABLE_AWS_DDB_PETS ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #15", catalog = CATALOG_AWS, query = "SELECT * FROM aws.b.b", expected = TABLE_AWS_B_B ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #16", catalog = CATALOG_B, query = "b.b.b", expected = TYPE_B_B_B ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #17", catalog = CATALOG_B, query = "b.b.c", expected = TYPE_B_B_C ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #18", catalog = CATALOG_B, catalogPath = listOf("b"), query = "b.b.b", expected = TYPE_B_B_B ), - SuccessTestCase( + TestCase.SuccessTestCase( name = "Test #19", query = "b.b.b.c", expected = TYPE_B_B_B_C @@ -3289,13 +3388,13 @@ class PartiQLSchemaInferencerTests { catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "order_info.customer_id IN 'hello'", - expected = MISSING, + expected = StaticType.MISSING, problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, PlanningProblemDetails.UnknownFunction( "in_collection", - listOf(INT4, STRING), + listOf(StaticType.INT4, StaticType.STRING), ) ) } @@ -3312,16 +3411,16 @@ class PartiQLSchemaInferencerTests { catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "order_info.customer_id BETWEEN 1 AND 'a'", - expected = MISSING, + expected = StaticType.MISSING, problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, PlanningProblemDetails.UnknownFunction( "between", listOf( - INT4, - INT4, - STRING + StaticType.INT4, + StaticType.INT4, + StaticType.STRING ), ) ) @@ -3339,13 +3438,13 @@ class PartiQLSchemaInferencerTests { catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "order_info.ship_option LIKE 3", - expected = MISSING, + expected = StaticType.MISSING, problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, PlanningProblemDetails.UnknownFunction( "like", - listOf(STRING, INT4), + listOf(StaticType.STRING, StaticType.INT4), ) ) } @@ -3363,7 +3462,7 @@ class PartiQLSchemaInferencerTests { catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "order_info.\"CUSTOMER_ID\" = 1", - expected = NULL + expected = StaticType.NULL ), SuccessTestCase( name = "Case Sensitive success", @@ -3377,14 +3476,14 @@ class PartiQLSchemaInferencerTests { catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "(order_info.customer_id = 1) AND (order_info.marketplace_id = 2)", - expected = StaticType.unionOf(BOOL, NULL) + expected = StaticType.unionOf(StaticType.BOOL, StaticType.NULL) ), SuccessTestCase( name = "2-Level Junction", catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "(order_info.customer_id = 1) AND (order_info.marketplace_id = 2) OR (order_info.customer_id = 3) AND (order_info.marketplace_id = 4)", - expected = StaticType.unionOf(BOOL, NULL) + expected = StaticType.unionOf(StaticType.BOOL, StaticType.NULL) ), SuccessTestCase( name = "INT and STR Comparison", @@ -3400,7 +3499,7 @@ class PartiQLSchemaInferencerTests { query = "non_existing_column = 1", // Function resolves to EQ__ANY_ANY__BOOL // Which can return BOOL Or NULL - expected = StaticType.unionOf(BOOL, NULL), + expected = StaticType.unionOf(StaticType.BOOL, StaticType.NULL), problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, @@ -3413,13 +3512,13 @@ class PartiQLSchemaInferencerTests { catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "order_info.customer_id = 1 AND 1", - expected = MISSING, + expected = StaticType.MISSING, problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, PlanningProblemDetails.UnknownFunction( "and", - listOf(StaticType.BOOL, INT4), + listOf(StaticType.BOOL, StaticType.INT4), ) ) } @@ -3429,13 +3528,13 @@ class PartiQLSchemaInferencerTests { catalog = CATALOG_DB, catalogPath = DB_SCHEMA_MARKETS, query = "1 AND order_info.customer_id = 1", - expected = MISSING, + expected = StaticType.MISSING, problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, PlanningProblemDetails.UnknownFunction( "and", - listOf(INT4, StaticType.BOOL), + listOf(StaticType.INT4, StaticType.BOOL), ) ) } @@ -3479,7 +3578,7 @@ class PartiQLSchemaInferencerTests { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnexpectedType(STRING, setOf(INT)) + PlanningProblemDetails.UnexpectedType(StaticType.STRING, setOf(StaticType.INT)) ) } ), @@ -3499,7 +3598,7 @@ class PartiQLSchemaInferencerTests { problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, - PlanningProblemDetails.UnexpectedType(STRING, setOf(INT)) + PlanningProblemDetails.UnexpectedType(StaticType.STRING, setOf(StaticType.INT)) ) } ), @@ -3510,7 +3609,7 @@ class PartiQLSchemaInferencerTests { query = "SELECT CAST(breed AS INT) AS cast_breed FROM pets", expected = BagType( StructType( - fields = mapOf("cast_breed" to unionOf(INT, MISSING)), + fields = mapOf("cast_breed" to StaticType.unionOf(StaticType.INT, StaticType.MISSING)), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -3527,7 +3626,7 @@ class PartiQLSchemaInferencerTests { query = "SELECT UPPER(breed) AS upper_breed FROM pets", expected = BagType( StructType( - fields = mapOf("upper_breed" to STRING), + fields = mapOf("upper_breed" to StaticType.STRING), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -3542,7 +3641,7 @@ class PartiQLSchemaInferencerTests { query = "SELECT a FROM << [ 1, 1.0 ] >> AS a", expected = BagType( StructType( - fields = mapOf("a" to ListType(unionOf(INT4, StaticType.DECIMAL))), + fields = mapOf("a" to ListType(StaticType.unionOf(StaticType.INT4, StaticType.DECIMAL))), contentClosed = true, constraints = setOf( TupleConstraint.Open(false), @@ -3556,13 +3655,13 @@ class PartiQLSchemaInferencerTests { name = "Non-tuples in SELECT VALUE", query = "SELECT VALUE a FROM << [ 1, 1.0 ] >> AS a", expected = - BagType(ListType(unionOf(INT4, StaticType.DECIMAL))) + BagType(ListType(StaticType.unionOf(StaticType.INT4, StaticType.DECIMAL))) ), SuccessTestCase( name = "SELECT VALUE", query = "SELECT VALUE [1, 1.0] FROM <<>>", expected = - BagType(ListType(unionOf(INT4, StaticType.DECIMAL))) + BagType(ListType(StaticType.unionOf(StaticType.INT4, StaticType.DECIMAL))) ), SuccessTestCase( name = "Duplicate fields in struct", @@ -3575,7 +3674,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("a", unionOf(INT4, STRING)) + StructType.Field("a", StaticType.unionOf(StaticType.INT4, StaticType.STRING)) ), contentClosed = true, constraints = setOf( @@ -3595,7 +3694,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("e", INT4) + StructType.Field("e", StaticType.INT4) ), contentClosed = true, constraints = setOf( @@ -3617,7 +3716,7 @@ class PartiQLSchemaInferencerTests { expected = BagType( StructType( fields = listOf( - StructType.Field("a", unionOf(INT4, STRING)) + StructType.Field("a", StaticType.unionOf(StaticType.INT4, StaticType.STRING)) ), contentClosed = true, constraints = setOf( @@ -3631,175 +3730,56 @@ class PartiQLSchemaInferencerTests { SuccessTestCase( name = "Current User", query = "CURRENT_USER", - expected = unionOf(STRING, NULL) + expected = StaticType.unionOf(StaticType.STRING, StaticType.NULL) ), SuccessTestCase( name = "Trim", query = "trim(' ')", - expected = STRING + expected = StaticType.STRING ), SuccessTestCase( name = "Current User Concat", query = "CURRENT_USER || 'hello'", - expected = unionOf(STRING, NULL) + expected = StaticType.unionOf(StaticType.STRING, StaticType.NULL) ), SuccessTestCase( name = "Current User Concat in WHERE", query = "SELECT VALUE a FROM [ 0 ] AS a WHERE CURRENT_USER = 'hello'", - expected = BagType(INT4) + expected = BagType(StaticType.INT4) ), SuccessTestCase( name = "TRIM_2", query = "trim(' ' FROM ' Hello, World! ')", - expected = STRING + expected = StaticType.STRING ), SuccessTestCase( name = "TRIM_1", query = "trim(' Hello, World! ')", - expected = STRING + expected = StaticType.STRING ), SuccessTestCase( name = "TRIM_3", query = "trim(LEADING ' ' FROM ' Hello, World! ')", - expected = STRING + expected = StaticType.STRING ), ErrorTestCase( name = "TRIM_2_error", query = "trim(2 FROM ' Hello, World! ')", - expected = MISSING, + expected = StaticType.MISSING, problemHandler = assertProblemExists { Problem( UNKNOWN_PROBLEM_LOCATION, PlanningProblemDetails.UnknownFunction( "trim_chars", - args = listOf(STRING, INT4) + args = listOf(StaticType.STRING, StaticType.INT4) ) ) } ), ) } +} - private fun runTest(tc: TestCase) = when (tc) { - is SuccessTestCase -> runTest(tc) - is ErrorTestCase -> runTest(tc) - is ThrowingExceptionTestCase -> runTest(tc) - is TestCase.IgnoredTestCase -> runTest(tc) - } - - @OptIn(ExperimentalPartiQLSchemaInferencer::class) - private fun runTest(tc: ThrowingExceptionTestCase) { - val session = PartiQLPlanner.Session( - tc.query.hashCode().toString(), - USER_ID, - tc.catalog, - tc.catalogPath, - catalogConfig, - Instant.now() - ) - val collector = ProblemCollector() - val ctx = PartiQLSchemaInferencer.Context(session, PLUGINS, collector) - val exception = assertThrows { - PartiQLSchemaInferencer.infer(tc.query, ctx) - Unit - } - val cause = exception.cause - assertNotNull(cause) - assertEquals(tc.expectedThrowable, cause::class) - } - - @OptIn(ExperimentalPartiQLSchemaInferencer::class) - private fun runTest(tc: SuccessTestCase) { - val session = PartiQLPlanner.Session( - tc.query.hashCode().toString(), - USER_ID, - tc.catalog, - tc.catalogPath, - catalogConfig, - Instant.now() - ) - val collector = ProblemCollector() - val ctx = PartiQLSchemaInferencer.Context(session, PLUGINS, collector) - - val hasQuery = tc.query != null - val hasKey = tc.key != null - if (hasQuery == hasKey) { - error("Test must have one of either `query` or `key`") - } - val input = tc.query ?: testProvider[tc.key!!]!!.statement - - val result = PartiQLSchemaInferencer.inferInternal(input, ctx) - assert(collector.problems.isEmpty()) { - buildString { - appendLine(collector.problems.toString()) - appendLine() - PlanPrinter.append(this, result.first) - } - } - val actual = result.second - assert(tc.expected == actual) { - buildString { - appendLine() - appendLine("Expect: ${tc.expected}") - appendLine("Actual: $actual") - appendLine() - PlanPrinter.append(this, result.first) - } - } - } - - @OptIn(ExperimentalPartiQLSchemaInferencer::class) - private fun runTest(tc: ErrorTestCase) { - val session = PartiQLPlanner.Session( - tc.query.hashCode().toString(), - USER_ID, - tc.catalog, - tc.catalogPath, - catalogConfig, - Instant.now() - ) - val collector = ProblemCollector() - val ctx = PartiQLSchemaInferencer.Context(session, PLUGINS, collector) - - val hasQuery = tc.query != null - val hasKey = tc.key != null - if (hasQuery == hasKey) { - error("Test must have one of either `query` or `key`") - } - val input = tc.query ?: testProvider[tc.key!!]!!.statement - val result = PartiQLSchemaInferencer.inferInternal(input, ctx) - - assert(collector.problems.isNotEmpty()) { - buildString { - appendLine("Expected to find problems, but none were found.") - appendLine() - PlanPrinter.append(this, result.first) - } - } - if (tc.expected != null) { - assert(tc.expected == result.second) { - buildString { - appendLine() - appendLine("Expect: ${tc.expected}") - appendLine("Actual: ${result.second}") - appendLine() - PlanPrinter.append(this, result.first) - } - } - } - assert(collector.problems.isNotEmpty()) { - "Expected to find problems, but none were found." - } - tc.problemHandler?.handle(collector.problems, true) - } - - private fun runTest(tc: TestCase.IgnoredTestCase) { - assertThrows { - runTest(tc.shouldBe) - } - } - - fun interface ProblemHandler { - fun handle(problems: List, ignoreSourceLocation: Boolean) - } +fun interface ProblemHandler { + fun handle(problems: List, ignoreSourceLocation: Boolean) } diff --git a/partiql-planner/src/test/kotlin/org/partiql/planner/util/ProblemCollector.kt b/partiql-planner/src/test/kotlin/org/partiql/planner/util/ProblemCollector.kt new file mode 100644 index 0000000000..23069adc35 --- /dev/null +++ b/partiql-planner/src/test/kotlin/org/partiql/planner/util/ProblemCollector.kt @@ -0,0 +1,30 @@ +package org.partiql.planner.util + +import org.partiql.errors.Problem +import org.partiql.errors.ProblemCallback +import org.partiql.errors.ProblemHandler +import org.partiql.errors.ProblemSeverity + +/** + * A [ProblemHandler] that collects all the encountered [Problem]s without throwing. + * + * This is intended to be used when wanting to collect multiple problems that may be encountered (e.g. a static type + * inference pass that can result in multiple errors and/or warnings). This handler does not collect other exceptions + * that may be thrown. + */ +internal class ProblemCollector : ProblemCallback { + private val problemList = mutableListOf() + + val problems: List + get() = problemList + + val hasErrors: Boolean + get() = problemList.any { it.details.severity == ProblemSeverity.ERROR } + + val hasWarnings: Boolean + get() = problemList.any { it.details.severity == ProblemSeverity.WARNING } + + override fun invoke(problem: Problem) { + problemList.add(problem) + } +} diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/Plugin.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/Plugin.kt index 007bb8a66a..4df4c5dda9 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/Plugin.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/Plugin.kt @@ -22,11 +22,15 @@ import org.partiql.spi.function.PartiQLFunctionExperimental * A singular unit of external logic. */ public interface Plugin { - public fun getConnectorFactories(): List /** - * Represents custom built-in functions to be accessed during execution. - **/ - @OptIn(PartiQLFunctionExperimental::class) - public fun getFunctions(): List + * A [Connector.Factory] is used to instantiate a connector. + */ + public val factory: Connector.Factory + + /** + * Functions defined by this plugin. + */ + @PartiQLFunctionExperimental + public val functions: List } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/Connector.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/Connector.kt index 56914ebe8d..01ac354974 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/connector/Connector.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/connector/Connector.kt @@ -20,10 +20,33 @@ import com.amazon.ionelement.api.StructElement * A mechanism by which PartiQL can access a Catalog. */ public interface Connector { + + /** + * Returns a [ConnectorMetadata] for the given [ConnectorSession]. The [ConnectorMetadata] is responsible + * for accessing catalog metadata. + * + * @param session + * @return + */ public fun getMetadata(session: ConnectorSession): ConnectorMetadata + /** + * A Plugin leverages a [Factory] to produce a [Connector] which is used for catalog metadata and data access. + */ public interface Factory { - public fun getName(): String - public fun create(catalogName: String, config: StructElement): Connector + + /** + * The connector name used to register the factory. + */ + public val name: String + + /** + * The connector factory method. + * + * @param catalogName + * @param config + * @return + */ + public fun create(catalogName: String, config: StructElement? = null): Connector } } diff --git a/partiql-spi/src/main/kotlin/org/partiql/spi/function/PartiQLFunction.kt b/partiql-spi/src/main/kotlin/org/partiql/spi/function/PartiQLFunction.kt index 52e79a13d8..b95f460613 100644 --- a/partiql-spi/src/main/kotlin/org/partiql/spi/function/PartiQLFunction.kt +++ b/partiql-spi/src/main/kotlin/org/partiql/spi/function/PartiQLFunction.kt @@ -8,7 +8,7 @@ import org.partiql.value.PartiQLValueExperimental * The [PartiQLFunction] interface is used to implement user-defined-functions (UDFs). * UDFs can be registered to a plugin for use in the query planner and evaluator. */ -@OptIn(PartiQLValueExperimental::class) +@PartiQLFunctionExperimental public sealed interface PartiQLFunction { /** @@ -32,6 +32,7 @@ public sealed interface PartiQLFunction { * @param args * @return */ + @OptIn(PartiQLValueExperimental::class) public fun invoke(args: Array): PartiQLValue } @@ -45,12 +46,23 @@ public sealed interface PartiQLFunction { */ override val signature: FunctionSignature.Aggregation + /** + * Instantiates an accumulator for this aggregation function. + * + * @return + */ + public fun accumulator(): Accumulator + } + + public interface Accumulator { + /** * Apply args to the accumulator. * * @param args * @return */ + @OptIn(PartiQLValueExperimental::class) public fun next(args: Array): PartiQLValue /** @@ -58,6 +70,7 @@ public sealed interface PartiQLFunction { * * @return */ + @OptIn(PartiQLValueExperimental::class) public fun value(): PartiQLValue } } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/PartiQL.kt b/partiql-types/src/main/kotlin/org/partiql/value/PartiQL.kt index adc7af49fe..2f951db0ff 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/PartiQL.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/PartiQL.kt @@ -38,12 +38,12 @@ import org.partiql.value.impl.Int64ValueImpl import org.partiql.value.impl.Int8ValueImpl import org.partiql.value.impl.IntValueImpl import org.partiql.value.impl.IntervalValueImpl +import org.partiql.value.impl.IterableStructValueImpl import org.partiql.value.impl.ListValueImpl import org.partiql.value.impl.MapStructValueImpl import org.partiql.value.impl.MissingValueImpl import org.partiql.value.impl.MultiMapStructValueImpl import org.partiql.value.impl.NullValueImpl -import org.partiql.value.impl.SequenceStructValueImpl import org.partiql.value.impl.SexpValueImpl import org.partiql.value.impl.StringValueImpl import org.partiql.value.impl.SymbolValueImpl @@ -344,10 +344,25 @@ public fun intervalValue( @JvmOverloads @PartiQLValueExperimental public fun bagValue( - elements: Sequence?, + elements: Iterable?, annotations: Annotations = emptyList(), ): BagValue = BagValueImpl(elements, annotations.toPersistentList()) +/** + * BAG type value. + * + * @param T + * @param elements + * @param annotations + * @return + */ +@JvmOverloads +@PartiQLValueExperimental +public fun bagValue( + vararg elements: T, + annotations: Annotations = emptyList(), +): BagValue = BagValueImpl(elements.asIterable(), annotations.toPersistentList()) + /** * LIST type value. * @@ -359,10 +374,25 @@ public fun bagValue( @JvmOverloads @PartiQLValueExperimental public fun listValue( - elements: Sequence?, + elements: Iterable?, annotations: Annotations = emptyList(), ): ListValue = ListValueImpl(elements, annotations.toPersistentList()) +/** + * LIST type value. + * + * @param T + * @param elements + * @param annotations + * @return + */ +@JvmOverloads +@PartiQLValueExperimental +public fun listValue( + vararg elements: T, + annotations: Annotations = emptyList(), +): ListValue = ListValueImpl(elements.asIterable(), annotations.toPersistentList()) + /** * SEXP type value. * @@ -374,12 +404,42 @@ public fun listValue( @JvmOverloads @PartiQLValueExperimental public fun sexpValue( - elements: Sequence?, + elements: Iterable?, annotations: Annotations = emptyList(), ): SexpValue = SexpValueImpl(elements, annotations.toPersistentList()) /** - * STRUCT type value. + * SEXP type value. + * + * @param T + * @param elements + * @param annotations + * @return + */ +@JvmOverloads +@PartiQLValueExperimental +public fun sexpValue( + vararg elements: T, + annotations: Annotations = emptyList(), +): SexpValue = SexpValueImpl(elements.asIterable(), annotations.toPersistentList()) + +/** + * Create a PartiQL struct value backed by an iterable of key-value field pairs. + * + * @param T + * @param fields + * @param annotations + * @return + */ +@JvmOverloads +@PartiQLValueExperimental +public fun structValue( + fields: Iterable>?, + annotations: Annotations = emptyList(), +): StructValue = IterableStructValueImpl(fields, annotations.toPersistentList()) + +/** + * Create a PartiQL struct value backed by an iterable of key-value field pairs. * * @param T * @param fields @@ -389,12 +449,13 @@ public fun sexpValue( @JvmOverloads @PartiQLValueExperimental public fun structValue( - fields: Sequence>?, + vararg fields: Pair, annotations: Annotations = emptyList(), -): StructValue = SequenceStructValueImpl(fields, annotations.toPersistentList()) +): StructValue = IterableStructValueImpl(fields.toList(), annotations.toPersistentList()) /** - * STRUCT type value. + * Create a PartiQL struct value backed by a multimap of keys with a list of values. This supports having multiple + * values per key, while improving lookup performance compared to using an iterable. * * @param T * @param fields @@ -403,13 +464,14 @@ public fun structValue( */ @JvmOverloads @PartiQLValueExperimental -public fun structValueWithDuplicates( +public fun structValueMultiMap( fields: Map>?, annotations: Annotations = emptyList(), ): StructValue = MultiMapStructValueImpl(fields, annotations.toPersistentList()) /** - * STRUCT type value. + * Create a PartiQL struct value backed by a map of keys with a list of values. This does not support having multiple + * values per key, but uses a Java HashMap for quicker lookup than an iterable backed StructValue. * * @param T * @param fields @@ -418,7 +480,7 @@ public fun structValueWithDuplicates( */ @JvmOverloads @PartiQLValueExperimental -public fun structValueNoDuplicates( +public fun structValueMap( fields: Map?, annotations: Annotations = emptyList(), ): StructValue = MapStructValueImpl(fields, annotations.toPersistentList()) diff --git a/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt b/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt index 5c0eb7a207..bca12407cb 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/PartiQLValue.kt @@ -66,14 +66,11 @@ public sealed interface ScalarValue : PartiQLValue { } @PartiQLValueExperimental -public sealed interface CollectionValue : PartiQLValue, Sequence { - - public val elements: Sequence? +public sealed interface CollectionValue : PartiQLValue, Iterable { override val isNull: Boolean - get() = elements == null - override fun iterator(): Iterator = elements!!.iterator() + override fun iterator(): Iterator override fun copy(annotations: Annotations): CollectionValue @@ -389,8 +386,8 @@ public abstract class BagValue : CollectionValue { if (this.isNull || other.isNull) return this.isNull == other.isNull // both not null, compare values - val lhs = this.elements!!.groupingBy { it }.eachCount() - val rhs = other.elements!!.groupingBy { it }.eachCount() + val lhs = this.toList() + val rhs = other.toList() // this is incorrect as it assumes ordered-ness, but we don't have a sort or hash yet return lhs == rhs } @@ -422,8 +419,8 @@ public abstract class ListValue : CollectionValue { if (this.isNull || other.isNull) return this.isNull == other.isNull // both not null, compare values - val lhs = this.elements!!.toList() - val rhs = other.elements!!.toList() + val lhs = this.toList() + val rhs = other.toList() return lhs == rhs } @@ -453,8 +450,8 @@ public abstract class SexpValue : CollectionValue { if (this.isNull || other.isNull) return this.isNull == other.isNull // both not null, compare values - val lhs = this.elements!!.toList() - val rhs = other.elements!!.toList() + val lhs = this.toList() + val rhs = other.toList() return lhs == rhs } @@ -465,19 +462,15 @@ public abstract class SexpValue : CollectionValue { } @PartiQLValueExperimental -public abstract class StructValue : PartiQLValue, Sequence> { +public abstract class StructValue : PartiQLValue { override val type: PartiQLValueType = PartiQLValueType.STRUCT - public abstract val fields: Sequence>? - - // TODO: This is a temporary solution to not exhaust the underlying fields upon evaluation - private lateinit var _fields: List> + public abstract val fields: Iterable - override val isNull: Boolean - get() = fields == null + public abstract val values: Iterable - override fun iterator(): Iterator> = getFields()!!.iterator() + public abstract val entries: Iterable> public abstract operator fun get(key: String): T? @@ -490,9 +483,7 @@ public abstract class StructValue : PartiQLValue, Sequence /** - * See equality of IonElement StructElementImpl - * - * https://github.com/amazon-ion/ion-element-kotlin/blob/master/src/com/amazon/ionelement/impl/StructElementImpl.kt + * Checks equality of struct entries, ignoring ordering. * * @param other * @return @@ -506,15 +497,15 @@ public abstract class StructValue : PartiQLValue, Sequence + lhs.entries.forEach { (key, values) -> val lGroup: Map = values.groupingBy { it }.eachCount() val rGroup: Map = rhs[key]!!.groupingBy { it }.eachCount() if (lGroup != rGroup) return false @@ -524,17 +515,14 @@ public abstract class StructValue : PartiQLValue, Sequence>? { - if (fields == null) { - return null - } - if (this::_fields.isInitialized.not()) { - _fields = fields?.toList() ?: emptyList() + override fun toString(): String { + if (isNull) { + return "null" } - return _fields + return super.toString() } } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/helpers/ToIon.kt b/partiql-types/src/main/kotlin/org/partiql/value/helpers/ToIon.kt index d682b716ee..e46330a1e4 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/helpers/ToIon.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/helpers/ToIon.kt @@ -258,31 +258,31 @@ internal object ToIon : PartiQLValueBaseVisitor() { override fun visitInterval(v: IntervalValue, ctx: Unit): IonElement = TODO("Not Yet supported") override fun visitBag(v: BagValue<*>, ctx: Unit): IonElement = v.annotate { - when (val elements = v.elements) { - null -> ionNull(ElementType.LIST) - else -> ionListOf(elements.map { it.accept(ToIon, Unit) }.toList()) + when (v.isNull) { + true -> ionNull(ElementType.LIST) + else -> ionListOf(v.map { it.accept(ToIon, Unit) }.toList()) } }.withAnnotations(BAG_ANNOTATION) override fun visitList(v: ListValue<*>, ctx: Unit): IonElement = v.annotate { - when (val elements = v.elements) { - null -> ionNull(ElementType.LIST) - else -> ionListOf(elements.map { it.accept(ToIon, Unit) }.toList()) + when (v.isNull) { + true -> ionNull(ElementType.LIST) + else -> ionListOf(v.map { it.accept(ToIon, Unit) }.toList()) } } override fun visitSexp(v: SexpValue<*>, ctx: Unit): IonElement = v.annotate { - when (val elements = v.elements) { - null -> ionNull(ElementType.SEXP) - else -> ionSexpOf(elements.map { it.accept(ToIon, Unit) }.toList()) + when (v.isNull) { + true -> ionNull(ElementType.SEXP) + else -> ionSexpOf(v.map { it.accept(ToIon, Unit) }.toList()) } } override fun visitStruct(v: StructValue<*>, ctx: Unit): IonElement = v.annotate { - when (val fields = v.fields) { - null -> ionNull(ElementType.STRUCT) + when (v.isNull) { + true -> ionNull(ElementType.STRUCT) else -> { - val ionFields = fields.map { + val ionFields = entries.map { val fk = it.first val fv = it.second.accept(ToIon, ctx) field(fk, fv) diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/BagValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/BagValueImpl.kt index 98c8fbcd33..02c2d61e97 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/BagValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/BagValueImpl.kt @@ -24,11 +24,13 @@ import org.partiql.value.util.PartiQLValueVisitor @OptIn(PartiQLValueExperimental::class) internal class BagValueImpl( - private val delegate: Sequence?, + private val delegate: Iterable?, override val annotations: PersistentList, ) : BagValue() { - override val elements: Sequence? = delegate + override val isNull: Boolean = delegate == null + + override fun iterator(): Iterator = delegate!!.iterator() override fun copy(annotations: Annotations) = BagValueImpl(delegate, annotations.toPersistentList()) diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/ListValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/ListValueImpl.kt index d7780b50b2..304dadae3c 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/ListValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/ListValueImpl.kt @@ -24,11 +24,13 @@ import org.partiql.value.util.PartiQLValueVisitor @OptIn(PartiQLValueExperimental::class) internal class ListValueImpl( - private val delegate: Sequence?, + private val delegate: Iterable?, override val annotations: PersistentList, ) : ListValue() { - override val elements: Sequence? = delegate + override val isNull: Boolean = delegate == null + + override fun iterator(): Iterator = delegate!!.iterator() override fun copy(annotations: Annotations) = ListValueImpl(delegate, annotations.toPersistentList()) diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/SexpValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/SexpValueImpl.kt index 8e6838e263..c477dca3fd 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/SexpValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/SexpValueImpl.kt @@ -24,11 +24,13 @@ import org.partiql.value.util.PartiQLValueVisitor @OptIn(PartiQLValueExperimental::class) internal class SexpValueImpl( - private val delegate: Sequence?, + private val delegate: Iterable?, override val annotations: PersistentList, ) : SexpValue() { - override val elements: Sequence? = delegate + override val isNull: Boolean = delegate == null + + override fun iterator(): Iterator = delegate!!.iterator() override fun copy(annotations: Annotations) = SexpValueImpl(delegate, annotations.toPersistentList()) diff --git a/partiql-types/src/main/kotlin/org/partiql/value/impl/StructValueImpl.kt b/partiql-types/src/main/kotlin/org/partiql/value/impl/StructValueImpl.kt index a59480c7dd..a70623e20a 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/impl/StructValueImpl.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/impl/StructValueImpl.kt @@ -23,19 +23,28 @@ import org.partiql.value.StructValue import org.partiql.value.util.PartiQLValueVisitor /** - * Implementation of a [StructValue] backed by a Sequence. + * Implementation of a [StructValue] backed by an iterator. * * @param T * @property delegate * @property annotations */ @OptIn(PartiQLValueExperimental::class) -internal class SequenceStructValueImpl( - private val delegate: Sequence>?, +internal class IterableStructValueImpl( + private val delegate: Iterable>?, override val annotations: PersistentList, ) : StructValue() { - override val fields: Sequence>? = delegate + override val isNull: Boolean = delegate == null + + override val fields: Iterable + get() = delegate!!.map { it.first } + + override val values: Iterable + get() = delegate!!.map { it.second } + + override val entries: Iterable> + get() = delegate!! override operator fun get(key: String): T? { if (delegate == null) { @@ -51,7 +60,7 @@ internal class SequenceStructValueImpl( return delegate.filter { it.first == key }.map { it.second }.asIterable() } - override fun copy(annotations: Annotations) = SequenceStructValueImpl(delegate, annotations.toPersistentList()) + override fun copy(annotations: Annotations) = IterableStructValueImpl(delegate, annotations.toPersistentList()) override fun withAnnotations(annotations: Annotations): StructValue = _withAnnotations(annotations) @@ -73,13 +82,14 @@ internal class MultiMapStructValueImpl( override val annotations: PersistentList, ) : StructValue() { - override val fields: Sequence>? - get() { - if (delegate == null) { - return null - } - return delegate.asSequence().map { f -> f.value.map { v -> f.key to v } }.flatten() - } + override val isNull: Boolean = delegate == null + + override val fields: Iterable = delegate!!.map { it.key } + + override val values: Iterable = delegate!!.flatMap { it.value } + + override val entries: Iterable> = + delegate!!.entries.map { f -> f.value.map { v -> f.key to v } }.flatten() override operator fun get(key: String): T? = getAll(key).firstOrNull() @@ -112,13 +122,13 @@ internal class MapStructValueImpl( override val annotations: PersistentList, ) : StructValue() { - override val fields: Sequence>? - get() { - if (delegate == null) { - return null - } - return delegate.asSequence().map { f -> f.key to f.value } - } + override val isNull: Boolean = delegate == null + + override val fields: Iterable = delegate!!.map { it.key } + + override val values: Iterable = delegate!!.map { it.value } + + override val entries: Iterable> = delegate!!.entries.map { it.key to it.value } override operator fun get(key: String): T? { if (delegate == null) { diff --git a/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueIonReader.kt b/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueIonReader.kt index 37cdf54c97..732738ffec 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueIonReader.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueIonReader.kt @@ -167,7 +167,7 @@ internal class PartiQLValueIonReader( } } reader.stepOut() - listValue(elements.asSequence(), annotations) + listValue(elements, annotations) } IonType.SEXP -> { @@ -179,7 +179,7 @@ internal class PartiQLValueIonReader( } } reader.stepOut() - sexpValue(elements.asSequence(), annotation) + sexpValue(elements, annotation) } IonType.STRUCT -> { @@ -192,7 +192,7 @@ internal class PartiQLValueIonReader( } } reader.stepOut() - structValue(elements.asSequence(), annotations) + structValue(elements, annotations) } IonType.DATAGRAM -> throw IllegalArgumentException("Datagram not supported") @@ -394,7 +394,7 @@ internal class PartiQLValueIonReader( } } reader.stepOut() - bagValue(elements.asSequence(), annotations.dropLast(1)) + bagValue(elements, annotations.dropLast(1)) } } PARTIQL_ANNOTATION.DATE_ANNOTATION -> throw IllegalArgumentException("DATE_ANNOTATION with List Value") @@ -412,7 +412,7 @@ internal class PartiQLValueIonReader( } } reader.stepOut() - listValue(elements.asSequence(), annotations) + listValue(elements, annotations) } } } @@ -437,7 +437,7 @@ internal class PartiQLValueIonReader( } } reader.stepOut() - sexpValue(elements.asSequence(), annotations) + sexpValue(elements, annotations) } } } @@ -525,8 +525,7 @@ internal class PartiQLValueIonReader( PARTIQL_ANNOTATION.GRAPH_ANNOTATION -> TODO("Not yet implemented") null -> { if (reader.isNullValue) { - val nullSequence: Sequence>? = null - structValue(nullSequence, annotations) + structValue(null, annotations) } else { reader.stepIn() val elements = mutableListOf>().also { elements -> @@ -536,7 +535,7 @@ internal class PartiQLValueIonReader( } } reader.stepOut() - structValue(elements.asSequence(), annotations) + structValue(elements, annotations) } } } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueTextWriter.kt b/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueTextWriter.kt index 75d59c06f0..e0d1ca63bb 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueTextWriter.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/io/PartiQLValueTextWriter.kt @@ -199,14 +199,17 @@ public class PartiQLValueTextWriter( override fun visitSexp(v: SexpValue<*>, format: Format?) = collection(v, format, "(" to ")", " ") override fun visitStruct(v: StructValue<*>, format: Format?): String = buildString { + if (v.isNull) { + return "null" + } // null.struct - val fields = v.fields?.toList() ?: return "null" + val entries = v.entries.toList() // {} - if (fields.isEmpty() || format == null) { + if (entries.isEmpty() || format == null) { format?.let { append(it.prefix) } annotate(v, this) append("{") - val items = fields.map { + val items = entries.map { val fk = it.first val fv = it.second.accept(ToString, null) // it.toString() "$fk:$fv" @@ -219,10 +222,10 @@ public class PartiQLValueTextWriter( append(format.prefix) annotate(v, this) appendLine("{") - fields.forEachIndexed { i, e -> + entries.forEachIndexed { i, e -> val fk = e.first val fv = e.second.accept(ToString, format.nest()).trimStart() // e.toString(format) - val suffix = if (i == fields.size - 1) "" else "," + val suffix = if (i == entries.size - 1) "" else "," append(format.prefix + format.indent) append("$fk: $fv") appendLine(suffix) @@ -238,7 +241,10 @@ public class PartiQLValueTextWriter( separator: CharSequence = ",", ) = buildString { // null.bag, null.list, null.sexp - val elements = v.elements?.toList() ?: return "null" + if (v.isNull) { + return "null" + } + val elements = v.toList() // skip empty if (elements.isEmpty() || format == null) { format?.let { append(it.prefix) } diff --git a/partiql-types/src/main/kotlin/org/partiql/value/util/PartiQLValueBaseVisitor.kt b/partiql-types/src/main/kotlin/org/partiql/value/util/PartiQLValueBaseVisitor.kt index 9db4b742b1..41e4bdba19 100644 --- a/partiql-types/src/main/kotlin/org/partiql/value/util/PartiQLValueBaseVisitor.kt +++ b/partiql-types/src/main/kotlin/org/partiql/value/util/PartiQLValueBaseVisitor.kt @@ -53,10 +53,14 @@ public abstract class PartiQLValueBaseVisitor : PartiQLValueVisitor public open fun defaultVisit(v: PartiQLValue, ctx: C): R { when (v) { is CollectionValue<*> -> { - v.elements?.forEach { it?.accept(this, ctx) } + if (!v.isNull) { + v.forEach { it.accept(this, ctx) } + } } is StructValue<*> -> { - v.fields?.forEach { it.second.accept(this, ctx) } + if (!v.isNull) { + v.entries.forEach { it.second.accept(this, ctx) } + } } else -> {} } diff --git a/partiql-types/src/test/kotlin/org/partiql/value/io/PartiQLValueIonSerdeTest.kt b/partiql-types/src/test/kotlin/org/partiql/value/io/PartiQLValueIonSerdeTest.kt index 6213446663..a2a7f0dcc0 100644 --- a/partiql-types/src/test/kotlin/org/partiql/value/io/PartiQLValueIonSerdeTest.kt +++ b/partiql-types/src/test/kotlin/org/partiql/value/io/PartiQLValueIonSerdeTest.kt @@ -360,24 +360,22 @@ class PartiQLValueIonSerdeTest { @JvmStatic fun collections() = listOf( roundTrip( - bagValue(emptySequence()), + bagValue(), ION.newEmptyList().apply { addTypeAnnotation("\$bag") }, ), roundTrip( - listValue(emptySequence()), + listValue(), ION.newEmptyList() ), roundTrip( - sexpValue(emptySequence()), + sexpValue(), ION.newEmptySexp() ), oneWayTrip( bagValue( - sequenceOf( - int32Value(1), - int32Value(2), - int32Value(3), - ) + int32Value(1), + int32Value(2), + int32Value(3), ), ION.newList( ION.newInt(1), @@ -385,20 +383,16 @@ class PartiQLValueIonSerdeTest { ION.newInt(3) ).apply { addTypeAnnotation("\$bag") }, bagValue( - sequenceOf( - intValue(BigInteger.ONE), - intValue(BigInteger.valueOf(2L)), - intValue(BigInteger.valueOf(3L)), - ) + intValue(BigInteger.ONE), + intValue(BigInteger.valueOf(2L)), + intValue(BigInteger.valueOf(3L)), ) ), roundTrip( listValue( - sequenceOf( - stringValue("a"), - stringValue("b"), - stringValue("c"), - ) + stringValue("a"), + stringValue("b"), + stringValue("c"), ), ION.newList( ION.newString("a"), @@ -408,11 +402,9 @@ class PartiQLValueIonSerdeTest { ), oneWayTrip( sexpValue( - sequenceOf( - int32Value(1), - int32Value(2), - int32Value(3), - ) + int32Value(1), + int32Value(2), + int32Value(3), ), ION.newSexp( ION.newInt(1), @@ -420,19 +412,15 @@ class PartiQLValueIonSerdeTest { ION.newInt(3) ), sexpValue( - sequenceOf( - intValue(BigInteger.ONE), - intValue(BigInteger.valueOf(2L)), - intValue(BigInteger.valueOf(3L)), - ) + intValue(BigInteger.ONE), + intValue(BigInteger.valueOf(2L)), + intValue(BigInteger.valueOf(3L)), ) ), oneWayTrip( structValue( - sequenceOf( - "a" to int32Value(1), - "b" to stringValue("x"), - ) + "a" to int32Value(1), + "b" to stringValue("x"), ), ION.newEmptyStruct() .apply { @@ -440,10 +428,8 @@ class PartiQLValueIonSerdeTest { add("b", ION.newString("x")) }, structValue( - sequenceOf( - "a" to intValue(BigInteger.ONE), - "b" to stringValue("x"), - ) + "a" to intValue(BigInteger.ONE), + "b" to stringValue("x"), ), ) ) diff --git a/partiql-types/src/test/kotlin/org/partiql/value/io/PartiQLValueTextWriterTest.kt b/partiql-types/src/test/kotlin/org/partiql/value/io/PartiQLValueTextWriterTest.kt index bf4531e6f8..309d832a64 100644 --- a/partiql-types/src/test/kotlin/org/partiql/value/io/PartiQLValueTextWriterTest.kt +++ b/partiql-types/src/test/kotlin/org/partiql/value/io/PartiQLValueTextWriterTest.kt @@ -322,44 +322,38 @@ class PartiQLValueTextWriterTest { @JvmStatic fun collections() = listOf( case( - value = bagValue(emptySequence()), + value = bagValue(), expected = "<<>>", ), case( - value = listValue(emptySequence()), + value = listValue(), expected = "[]", ), case( - value = sexpValue(emptySequence()), + value = sexpValue(), expected = "()", ), case( value = bagValue( - sequenceOf( - int32Value(1), - int32Value(2), - int32Value(3), - ) + int32Value(1), + int32Value(2), + int32Value(3), ), expected = "<<1,2,3>>", ), case( value = listValue( - sequenceOf( - stringValue("a"), - stringValue("b"), - stringValue("c"), - ) + stringValue("a"), + stringValue("b"), + stringValue("c"), ), expected = "['a','b','c']", ), case( value = sexpValue( - sequenceOf( - int32Value(1), - int32Value(2), - int32Value(3), - ) + int32Value(1), + int32Value(2), + int32Value(3), ), expected = "(1 2 3)", ), @@ -385,15 +379,13 @@ class PartiQLValueTextWriterTest { @JvmStatic fun struct() = listOf( case( - value = structValue(emptySequence()), + value = structValue(), expected = "{}", ), case( value = structValue( - sequenceOf( - "a" to int32Value(1), - "b" to stringValue("x"), - ) + "a" to int32Value(1), + "b" to stringValue("x"), ), expected = "{a:1,b:'x'}", ), @@ -403,11 +395,9 @@ class PartiQLValueTextWriterTest { fun collectionsFormatted() = listOf( formatted( value = bagValue( - sequenceOf( - int32Value(1), - int32Value(2), - int32Value(3), - ) + int32Value(1), + int32Value(2), + int32Value(3), ), expected = """ |<< @@ -419,11 +409,9 @@ class PartiQLValueTextWriterTest { ), formatted( value = listValue( - sequenceOf( - stringValue("a"), - stringValue("b"), - stringValue("c"), - ) + stringValue("a"), + stringValue("b"), + stringValue("c"), ), expected = """ |[ @@ -435,11 +423,9 @@ class PartiQLValueTextWriterTest { ), formatted( value = sexpValue( - sequenceOf( - int32Value(1), - int32Value(2), - int32Value(3), - ) + int32Value(1), + int32Value(2), + int32Value(3), ), expected = """ |( @@ -454,15 +440,13 @@ class PartiQLValueTextWriterTest { @JvmStatic fun structFormatted() = listOf( formatted( - value = structValue(emptySequence()), + value = structValue(), expected = "{}", ), formatted( value = structValue( - sequenceOf( - "a" to int32Value(1), - "b" to stringValue("x"), - ) + "a" to int32Value(1), + "b" to stringValue("x"), ), expected = """ |{ @@ -477,29 +461,21 @@ class PartiQLValueTextWriterTest { fun nestedCollectionsFormatted() = listOf( formatted( value = structValue( - sequenceOf( - "bag" to bagValue( - sequenceOf( - int32Value(1), - int32Value(2), - int32Value(3), - ) - ), - "list" to listValue( - sequenceOf( - stringValue("a"), - stringValue("b"), - stringValue("c"), - ) - ), - "sexp" to sexpValue( - sequenceOf( - int32Value(1), - int32Value(2), - int32Value(3), - ) - ), - ) + "bag" to bagValue( + int32Value(1), + int32Value(2), + int32Value(3), + ), + "list" to listValue( + stringValue("a"), + stringValue("b"), + stringValue("c"), + ), + "sexp" to sexpValue( + int32Value(1), + int32Value(2), + int32Value(3), + ), ), expected = """ |{ @@ -523,28 +499,20 @@ class PartiQLValueTextWriterTest { ), formatted( value = bagValue( - sequenceOf( - listValue( - sequenceOf( - stringValue("a"), - stringValue("b"), - stringValue("c"), - ) - ), - sexpValue( - sequenceOf( - int32Value(1), - int32Value(2), - int32Value(3), - ) - ), - structValue( - sequenceOf( - "a" to int32Value(1), - "b" to stringValue("x"), - ) - ), - ) + listValue( + stringValue("a"), + stringValue("b"), + stringValue("c"), + ), + sexpValue( + int32Value(1), + int32Value(2), + int32Value(3), + ), + structValue( + "a" to int32Value(1), + "b" to stringValue("x"), + ), ), expected = """ |<< @@ -567,11 +535,9 @@ class PartiQLValueTextWriterTest { ), formatted( value = structValue( - sequenceOf( - "bag" to bagValue(emptySequence()), - "list" to listValue(emptySequence()), - "sexp" to sexpValue(emptySequence()), - ) + "bag" to bagValue(), + "list" to listValue(), + "sexp" to sexpValue(), ), expected = """ |{ @@ -583,11 +549,9 @@ class PartiQLValueTextWriterTest { ), formatted( value = bagValue( - sequenceOf( - listValue(emptySequence()), - sexpValue(emptySequence()), - structValue(emptySequence()), - ) + listValue(), + sexpValue(), + structValue(), ), expected = """ |<< @@ -690,39 +654,31 @@ class PartiQLValueTextWriterTest { // TODO TIMESTAMP // TODO INTERVAL case( - value = bagValue(emptySequence(), annotations), + value = bagValue(annotations = annotations), expected = "x::y::<<>>", ), case( - value = listValue(emptySequence(), annotations), + value = listValue(annotations = annotations), expected = "x::y::[]", ), case( - value = sexpValue(emptySequence(), annotations), + value = sexpValue(annotations = annotations), expected = "x::y::()", ), formatted( value = bagValue( - sequenceOf( - listValue( - sequenceOf( - stringValue("a", listOf("x")), - ), - listOf("list") - ), - sexpValue( - sequenceOf( - int32Value(1, listOf("y")), - ), - listOf("sexp") - ), - structValue( - sequenceOf( - "a" to int32Value(1, listOf("z")), - ), - listOf("struct") - ), - ) + listValue( + stringValue("a", listOf("x")), + annotations = listOf("list") + ), + sexpValue( + int32Value(1, listOf("y")), + annotations = listOf("sexp") + ), + structValue( + "a" to int32Value(1, listOf("z")), + annotations = listOf("struct") + ), ), expected = """ |<< diff --git a/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalConnector.kt b/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalConnector.kt index 649520f300..5ec5808ec1 100644 --- a/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalConnector.kt +++ b/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalConnector.kt @@ -64,10 +64,11 @@ class LocalConnector( private val default: Path = Paths.get(System.getProperty("user.home")).resolve(".partiql/local") - override fun getName(): String = CONNECTOR_NAME + override val name: String = CONNECTOR_NAME - override fun create(catalogName: String, config: StructElement): Connector { - val root = config.getOptional(ROOT_KEY)?.stringValueOrNull?.let { Paths.get(it) } + override fun create(catalogName: String, config: StructElement?): Connector { + assert(config != null) { "Local plugin requires non-null config" } + val root = config!!.getOptional(ROOT_KEY)?.stringValueOrNull?.let { Paths.get(it) } val catalogRoot = root ?: default if (catalogRoot.notExists()) { error("Invalid catalog `$catalogRoot`") diff --git a/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalPlugin.kt b/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalPlugin.kt index 2bec07bf2d..94ad833a39 100644 --- a/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalPlugin.kt +++ b/plugins/partiql-local/src/main/kotlin/org/partiql/plugins/local/LocalPlugin.kt @@ -20,14 +20,14 @@ import org.partiql.spi.function.PartiQLFunction import org.partiql.spi.function.PartiQLFunctionExperimental /** - * FsPlugin is a PartiQL plugin that provides schemas written in PartiQL Value Schema. + * LocalPlugin is a PartiQL plugin that provides schemas written in PartiQL Value Schema. * * Backed by a memoized catalog tree from the given root dir; global bindings are files. */ class LocalPlugin : Plugin { - override fun getConnectorFactories(): List = listOf(LocalConnector.Factory()) + override val factory: Connector.Factory = LocalConnector.Factory() - @PartiQLFunctionExperimental - override fun getFunctions(): List = listOf() + @OptIn(PartiQLFunctionExperimental::class) + override val functions: List = listOf() } diff --git a/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryCatalog.kt b/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryCatalog.kt deleted file mode 100644 index cc6a578422..0000000000 --- a/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryCatalog.kt +++ /dev/null @@ -1,76 +0,0 @@ -package org.partiql.plugins.memory - -import org.partiql.spi.BindingCase -import org.partiql.spi.BindingPath -import org.partiql.spi.connector.ConnectorObjectPath -import org.partiql.types.StaticType - -class MemoryCatalog( - private val map: Map -) { - operator fun get(key: String): StaticType? = map[key] - - public fun lookup(path: BindingPath): MemoryObject? { - val kPath = ConnectorObjectPath( - path.steps.map { - when (it.bindingCase) { - BindingCase.SENSITIVE -> it.name - BindingCase.INSENSITIVE -> it.loweredName - } - } - ) - val k = kPath.steps.joinToString(".") - if (this[k] != null) { - return this[k]?.let { MemoryObject(kPath.steps, it) } - } else { - val candidatePath = this.map.keys.map { it.split(".") } - val kPathIter = kPath.steps.listIterator() - while (kPathIter.hasNext()) { - val currKPath = kPathIter.next() - candidatePath.forEach { - val match = mutableListOf() - val candidateIterator = it.iterator() - while (candidateIterator.hasNext()) { - if (candidateIterator.next() == currKPath) { - match.add(currKPath) - val pathIteratorCopy = kPath.steps.listIterator(kPathIter.nextIndex()) - candidateIterator.forEachRemaining { - val nextPath = pathIteratorCopy.next() - if (it != nextPath) { - match.clear() - return@forEachRemaining - } - match.add(it) - } - } else { - return@forEach - } - } - if (match.isNotEmpty()) { - return this[match.joinToString(".")]?.let { it1 -> - MemoryObject( - match, - it1 - ) - } - } - } - } - return null - } - } - - companion object { - fun of(vararg entities: Pair) = MemoryCatalog(mapOf(*entities)) - } - - class Provider { - private val catalogs = mutableMapOf() - - operator fun get(path: String): MemoryCatalog = catalogs[path] ?: error("invalid catalog path") - - operator fun set(path: String, catalog: MemoryCatalog) { - catalogs[path] = catalog - } - } -} diff --git a/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryConnector.kt b/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryConnector.kt index 9fce2172ea..830e1b0d02 100644 --- a/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryConnector.kt +++ b/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryConnector.kt @@ -1,6 +1,7 @@ package org.partiql.plugins.memory import com.amazon.ionelement.api.StructElement +import org.partiql.spi.BindingCase import org.partiql.spi.BindingPath import org.partiql.spi.connector.Connector import org.partiql.spi.connector.ConnectorMetadata @@ -9,38 +10,101 @@ import org.partiql.spi.connector.ConnectorObjectPath import org.partiql.spi.connector.ConnectorSession import org.partiql.types.StaticType -class MemoryConnector( - val catalog: MemoryCatalog -) : Connector { +/** + * This is a plugin used for testing and is not a versioned API per semver. + */ +public class MemoryConnector(private val metadata: ConnectorMetadata) : Connector { companion object { const val CONNECTOR_NAME = "memory" } - override fun getMetadata(session: ConnectorSession): ConnectorMetadata = Metadata() + override fun getMetadata(session: ConnectorSession): ConnectorMetadata = metadata - class Factory(private val provider: MemoryCatalog.Provider) : Connector.Factory { - override fun getName(): String = CONNECTOR_NAME + class Factory(private val catalogs: Map) : Connector.Factory { - override fun create(catalogName: String, config: StructElement): Connector { - val catalog = provider[catalogName] - return MemoryConnector(catalog) + override val name: String = CONNECTOR_NAME + + override fun create(catalogName: String, config: StructElement?): Connector { + return catalogs[catalogName] ?: error("Catalog $catalogName is not registered in the MemoryPlugin") } } - inner class Metadata : ConnectorMetadata { + /** + * Connector metadata uses dot-delimited identifiers and StaticType for catalog metadata. + * + * @property map + */ + class Metadata(private val map: Map) : ConnectorMetadata { + + companion object { + @JvmStatic + fun of(vararg entities: Pair) = Metadata(mapOf(*entities)) + } - override fun getObjectType(session: ConnectorSession, handle: ConnectorObjectHandle): StaticType? { + override fun getObjectType(session: ConnectorSession, handle: ConnectorObjectHandle): StaticType { val obj = handle.value as MemoryObject return obj.type } override fun getObjectHandle(session: ConnectorSession, path: BindingPath): ConnectorObjectHandle? { - val value = catalog.lookup(path) ?: return null + val value = lookup(path) ?: return null return ConnectorObjectHandle( absolutePath = ConnectorObjectPath(value.path), value = value, ) } + + operator fun get(key: String): StaticType? = map[key] + + public fun lookup(path: BindingPath): MemoryObject? { + val kPath = ConnectorObjectPath( + path.steps.map { + when (it.bindingCase) { + BindingCase.SENSITIVE -> it.name + BindingCase.INSENSITIVE -> it.loweredName + } + } + ) + val k = kPath.steps.joinToString(".") + if (this[k] != null) { + return this[k]?.let { MemoryObject(kPath.steps, it) } + } else { + val candidatePath = this.map.keys.map { it.split(".") } + val kPathIter = kPath.steps.listIterator() + while (kPathIter.hasNext()) { + val currKPath = kPathIter.next() + candidatePath.forEach { + val match = mutableListOf() + val candidateIterator = it.iterator() + while (candidateIterator.hasNext()) { + if (candidateIterator.next() == currKPath) { + match.add(currKPath) + val pathIteratorCopy = kPath.steps.listIterator(kPathIter.nextIndex()) + candidateIterator.forEachRemaining { + val nextPath = pathIteratorCopy.next() + if (it != nextPath) { + match.clear() + return@forEachRemaining + } + match.add(it) + } + } else { + return@forEach + } + } + if (match.isNotEmpty()) { + return this[match.joinToString(".")]?.let { it1 -> + MemoryObject( + match, + it1 + ) + } + } + } + } + return null + } + } } } diff --git a/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryPlugin.kt b/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryPlugin.kt index 808f8e72ae..3b269ff21b 100644 --- a/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryPlugin.kt +++ b/plugins/partiql-memory/src/main/kotlin/org/partiql/plugins/memory/MemoryPlugin.kt @@ -5,9 +5,10 @@ import org.partiql.spi.connector.Connector import org.partiql.spi.function.PartiQLFunction import org.partiql.spi.function.PartiQLFunctionExperimental -class MemoryPlugin(val provider: MemoryCatalog.Provider) : Plugin { - override fun getConnectorFactories(): List = listOf(MemoryConnector.Factory(provider)) +class MemoryPlugin(private val catalogs: Map) : Plugin { - @PartiQLFunctionExperimental - override fun getFunctions(): List = emptyList() + override val factory: Connector.Factory = MemoryConnector.Factory(catalogs) + + @OptIn(PartiQLFunctionExperimental::class) + override val functions: List = emptyList() } diff --git a/plugins/partiql-memory/src/test/kotlin/org/partiql/plugins/memory/InMemoryPluginTest.kt b/plugins/partiql-memory/src/test/kotlin/org/partiql/plugins/memory/InMemoryPluginTest.kt index d2fab9eae5..5085ed3f99 100644 --- a/plugins/partiql-memory/src/test/kotlin/org/partiql/plugins/memory/InMemoryPluginTest.kt +++ b/plugins/partiql-memory/src/test/kotlin/org/partiql/plugins/memory/InMemoryPluginTest.kt @@ -12,25 +12,24 @@ import org.partiql.types.StructType class InMemoryPluginTest { - private val session = object : ConnectorSession { - override fun getQueryId(): String = "mock_query_id" - override fun getUserId(): String = "mock_user" - } - companion object { - val provider = MemoryCatalog.Provider().also { - it["test"] = MemoryCatalog.of( - "a" to StaticType.INT2, - "struct" to StructType( + + private val session = object : ConnectorSession { + override fun getQueryId(): String = "mock_query_id" + override fun getUserId(): String = "mock_user" + } + + private val metadata = MemoryConnector.Metadata.of( + "a" to StaticType.INT2, + "struct" to StructType( + fields = listOf(StructType.Field("a", StaticType.INT2)) + ), + "schema.tbl" to BagType( + StructType( fields = listOf(StructType.Field("a", StaticType.INT2)) - ), - "schema.tbl" to BagType( - StructType( - fields = listOf(StructType.Field("a", StaticType.INT2)) - ) ) ) - } + ) } @Test @@ -41,13 +40,7 @@ class InMemoryPluginTest { ) ) val expected = StaticType.INT2 - - val connector = MemoryConnector(provider["test"]) - - val metadata = connector.Metadata() - val handle = metadata.getObjectHandle(session, requested) - val descriptor = metadata.getObjectType(session, handle!!) assert(requested.isEquivalentTo(handle.absolutePath)) @@ -61,11 +54,6 @@ class InMemoryPluginTest { BindingName("A", BindingCase.SENSITIVE) ) ) - - val connector = MemoryConnector(provider["test"]) - - val metadata = connector.Metadata() - val handle = metadata.getObjectHandle(session, requested) assert(null == handle) @@ -79,17 +67,9 @@ class InMemoryPluginTest { BindingName("a", BindingCase.INSENSITIVE) ) ) - - val connector = MemoryConnector(provider["test"]) - - val metadata = connector.Metadata() - val handle = metadata.getObjectHandle(session, requested) - val descriptor = metadata.getObjectType(session, handle!!) - val expectConnectorPath = ConnectorObjectPath(listOf("struct")) - val expectedObjectType = StructType(fields = listOf(StructType.Field("a", StaticType.INT2))) assert(expectConnectorPath == handle.absolutePath) @@ -104,15 +84,8 @@ class InMemoryPluginTest { BindingName("tbl", BindingCase.INSENSITIVE) ) ) - - val connector = MemoryConnector(provider["test"]) - - val metadata = connector.Metadata() - val handle = metadata.getObjectHandle(session, requested) - val descriptor = metadata.getObjectType(session, handle!!) - val expectedObjectType = BagType(StructType(fields = listOf(StructType.Field("a", StaticType.INT2)))) assert(requested.isEquivalentTo(handle.absolutePath)) @@ -128,17 +101,9 @@ class InMemoryPluginTest { BindingName("a", BindingCase.INSENSITIVE) ) ) - - val connector = MemoryConnector(provider["test"]) - - val metadata = connector.Metadata() - val handle = metadata.getObjectHandle(session, requested) - val descriptor = metadata.getObjectType(session, handle!!) - val expectedObjectType = BagType(StructType(fields = listOf(StructType.Field("a", StaticType.INT2)))) - val expectConnectorPath = ConnectorObjectPath(listOf("schema", "tbl")) assert(expectConnectorPath == handle.absolutePath)