Skip to content

Commit

Permalink
Fixes list/bag ExprValue creation in plan evaluator (#969)
Browse files Browse the repository at this point in the history
* Adds IsOrderedMeta to determine if the sequence represents a bag or list

* Changes AggregationFinder to not be static
  • Loading branch information
rchowell authored and yliuuuu committed Jan 20, 2023
1 parent bddbe04 commit 97b3028
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 26 deletions.
9 changes: 9 additions & 0 deletions lang/src/main/kotlin/org/partiql/lang/ast/IsOrderedMeta.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.partiql.lang.ast

/**
* To reduce any extraneous passes over data, this [Meta] indicates whether the associated BindingsToValues Physical
* expression should be an ordered list or a bag.
*/
object IsOrderedMeta : Meta {
override val tag = "\$is_ordered"
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package org.partiql.lang.eval.physical
import com.amazon.ion.IonString
import com.amazon.ion.IonValue
import com.amazon.ion.Timestamp
import com.amazon.ion.system.IonSystemBuilder
import com.amazon.ionelement.api.MetaContainer
import com.amazon.ionelement.api.emptyMetaContainer
import com.amazon.ionelement.api.toIonValue
Expand Down Expand Up @@ -132,7 +133,6 @@ import java.util.regex.Pattern
* [1]: https://www.complang.tuwien.ac.at/anton/lvas/sem06w/fest.pdf
*/
internal class PhysicalPlanCompilerImpl(
private val valueFactory: ExprValueFactory,
private val functions: Map<String, ExprFunction>,
private val customTypedOpParameters: Map<String, TypedOpParameter>,
private val procedures: Map<String, StoredProcedure>,
Expand Down Expand Up @@ -275,34 +275,25 @@ internal class PhysicalPlanCompilerImpl(
val mapThunk = compileAstExpr(expr.exp)
val bexprThunk: RelationThunkEnv = bexperConverter.convert(expr.query)

fun createOutputSequence(relationType: RelationType?, elements: Sequence<ExprValue>) = when (relationType) {
RelationType.LIST -> valueFactory.newList(elements)
RelationType.BAG -> valueFactory.newBag(elements)
null -> throw EvaluationException(
message = "Unable to recover the output Relation Type",
errorCode = ErrorCode.EVALUATOR_GENERIC_EXCEPTION,
internal = false
)
val relationType = when (expr.metas.containsKey(IsOrderedMeta.tag)) {
true -> RelationType.LIST
false -> RelationType.BAG
}

return thunkFactory.thunkEnv(expr.metas) { env ->
var relationType: RelationType? = null
// we create a snapshot for currentRegister to use during the evaluation
// this is to avoid issue when iterator planner result
val currentRegister = env.registers.clone()
val elements = sequence {
env.load(currentRegister)
val relItr = bexprThunk(env)
relationType = relItr.relType
while (relItr.nextRow()) {
yield(mapThunk(env))
}
}

// Trick the compiler here to always initialize `relationType`
when (elements.firstOrNull()) {
null -> createOutputSequence(relationType, emptySequence())
else -> createOutputSequence(relationType, elements)
when (relationType) {
RelationType.LIST -> ExprValue.newList(elements)
RelationType.BAG -> ExprValue.newBag(elements)
}
}
}
Expand Down Expand Up @@ -1609,7 +1600,7 @@ internal class PhysicalPlanCompilerImpl(

fun matchRegexPattern(value: ExprValue, likePattern: (() -> Pattern)?): ExprValue {
return when {
likePattern == null || value.type.isUnknown -> valueFactory.nullValue
likePattern == null || value.type.isUnknown -> ExprValue.nullValue
!value.type.isText -> err(
"LIKE expression must be given non-null strings as input",
ErrorCode.EVALUATOR_LIKE_INVALID_INPUTS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ internal class AggregationVisitorTransform(

override fun transformExprSelect_group(node: PartiqlAst.Expr.Select): PartiqlAst.GroupBy? {
// Return with Empty Context if without Group
val containsAggregations = AggregationFinder.containsAggregations(node.project)
val containsAggregations = AggregationFinder().containsAggregations(node.project)
if (node.group == null) {
val context = VisitorContext(emptyList(), null, containsAggregations)
contextStack.add(context)
Expand Down Expand Up @@ -202,7 +202,8 @@ internal class AggregationVisitorTransform(
* Recursively searches through a [PartiqlAst.Projection] to find [PartiqlAst.Expr.CallAgg]'s, but does NOT recurse
* into [PartiqlAst.Expr.Select]. Designed to be called directly using [containsAggregations].
*/
private object AggregationFinder : PartiqlAst.Visitor() {
private class AggregationFinder : PartiqlAst.Visitor() {

var hasAggregations: Boolean = false

fun containsAggregations(node: PartiqlAst.Projection): Boolean {
Expand Down Expand Up @@ -276,7 +277,7 @@ internal class AggregationVisitorTransform(
}

/**
* IDs outside of aggregation functions should always be replaced with the Group Key uniue aliases. If no
* IDs outside of aggregation functions should always be replaced with the Group Key unique aliases. If no
* replacement is found, we throw an EvaluationException.
*/
private fun getReplacementForIdOutsideOfAggregationFunction(node: PartiqlAst.Expr.Id): PartiqlAst.Expr {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ package org.partiql.lang.planner.transforms
import com.amazon.ionelement.api.emptyMetaContainer
import com.amazon.ionelement.api.ionString
import com.amazon.ionelement.api.ionSymbol
import org.partiql.lang.ast.IsOrderedMeta
import org.partiql.lang.domains.PartiqlAst
import org.partiql.lang.domains.PartiqlAstToPartiqlLogicalVisitorTransform
import org.partiql.lang.domains.PartiqlLogical
import org.partiql.lang.domains.metaContainerOf
import org.partiql.lang.errors.Problem
import org.partiql.lang.errors.ProblemHandler
import org.partiql.lang.eval.builtins.CollectionAggregationFunction
Expand Down Expand Up @@ -263,20 +265,25 @@ internal class AstToLogicalVisitorTransform(
}

private fun transformProjection(node: PartiqlAst.Expr.Select, algebra: PartiqlLogical.Bexpr): PartiqlLogical.Expr {
val project = node.project
val metas = when (node.order) {
null -> project.metas
else -> project.metas + metaContainerOf(IsOrderedMeta)
}
return PartiqlLogical.build {
when (val project = node.project) {
when (project) {
is PartiqlAst.Projection.ProjectValue -> {
bindingsToValues(
exp = transformExpr(project.value),
query = algebra,
metas = project.metas
metas = metas
)
}
is PartiqlAst.Projection.ProjectList -> {
bindingsToValues(
exp = transformProjectList(project),
query = algebra,
metas = project.metas
metas = metas
)
}
is PartiqlAst.Projection.ProjectStar -> {
Expand All @@ -289,7 +296,7 @@ internal class AstToLogicalVisitorTransform(
input = algebra,
key = transformExpr(project.key),
value = transformExpr(project.value),
metas = project.metas
metas = metas
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,17 @@ class GetByKeyProjectRelationalOperatorFactory : ProjectRelationalOperatorFactor
// Parse the tableId so we don't have to at evaluation-time
val tableId = UUID.fromString(impl.staticArgs.single().textValue)

var exhausted = false

// Finally, return a RelationExpression which evaluates the key value expression and returns a
// RelationIterator containing a single row corresponding to the key (or no rows if nothing matches)
return RelationExpression { state ->
// this code runs at evaluation-time.

if (exhausted) {
throw IllegalStateException("Exhausted result set")
}

// Get the current database from the EvaluationSession context.
// Please note that the state.session.context map is immutable, therefore it is not possible
// for custom operators or functions to put stuff in there. (Hopefully that will reduce the
Expand All @@ -74,6 +80,8 @@ class GetByKeyProjectRelationalOperatorFactory : ProjectRelationalOperatorFactor
// get the record requested.
val record = db.getRecordByKey(tableId, keyValue)

exhausted = true

// if the record was not found, return an empty relation:
if (record == null)
relation(RelationType.BAG) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ internal class EvaluatingCompilerCollectionAggregationsTest : EvaluatorTestBase(
""",
expectedResult = """
<<
{'k': [2, 4], 'coll_sum_a': 6, 'coll_sum_inner': <<6>>, 'sum_b': 30},
{'k': [6, 7], 'coll_sum_a': 13, 'coll_sum_inner': <<13>>, 'sum_b': 20}
{'k': [2, 4], 'coll_sum_a': 6, 'sum_b': 30, 'coll_sum_inner': <<6>>},
{'k': [6, 7], 'coll_sum_a': 13, 'sum_b': 20, 'coll_sum_inner': <<13>>}
>>
"""
),
Expand Down

0 comments on commit 97b3028

Please sign in to comment.