Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1/2] Resolve functions from the catalog #1584

Merged
merged 2 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ import org.partiql.plan.v1.operator.rex.RexTable
import org.partiql.plan.v1.operator.rex.RexVar
import org.partiql.plan.v1.operator.rex.RexVisitor
import org.partiql.spi.catalog.Session
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.Aggregation
import org.partiql.spi.value.Datum
import org.partiql.types.PType
import org.partiql.plan.Rel as IRel
Expand Down Expand Up @@ -166,7 +166,7 @@ internal class SqlCompiler(
else -> Operator.Aggregation.SetQuantifier.ALL
}
object : Operator.Aggregation {
override val delegate: Agg = agg
override val delegate: Aggregation = agg
override val args: List<Operator.Expr> = args
override val setQuantifier: Operator.Aggregation.SetQuantifier = setq
}
Expand Down Expand Up @@ -382,7 +382,7 @@ internal class SqlCompiler(
val fn = rex.getFunction()
val args = rex.getArgs().map { compile(it, ctx) }
val fnTakesInMissing = fn.signature.parameters.any {
it.type.kind == PType.Kind.DYNAMIC // TODO: Is this needed?
it.getType().kind == PType.Kind.DYNAMIC // TODO: Is this needed?
}
return when (fnTakesInMissing) {
true -> ExprCallStatic(fn, args.map { it.catch() }.toTypedArray())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package org.partiql.eval.internal.operator

import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.Record
import org.partiql.spi.fn.Agg
import org.partiql.spi.value.Datum

internal sealed interface Operator {
Expand All @@ -27,7 +26,7 @@ internal sealed interface Operator {

interface Aggregation : Operator {

val delegate: Agg
val delegate: org.partiql.spi.fn.Aggregation

val args: List<Expr>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.partiql.eval.internal.operator.rel
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.Aggregation
import org.partiql.spi.value.Datum
import java.util.TreeMap
import java.util.TreeSet
Expand All @@ -19,12 +19,12 @@ internal class RelOpAggregate(
private val aggregationMap = TreeMap<Array<Datum>, List<AccumulatorWrapper>>(DatumArrayComparator)

/**
* Wraps an [Agg.Accumulator] to help with filtering distinct values.
* Wraps an [Aggregation.Accumulator] to help with filtering distinct values.
*
* @property seen maintains which values have already been seen. If null, we accumulate all values coming through.
*/
class AccumulatorWrapper(
val delegate: Agg.Accumulator,
val delegate: Aggregation.Accumulator,
val args: List<Operator.Expr>,
val seen: TreeSet<Array<Datum>>?
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.partiql.eval.internal.operator.rex
import org.partiql.errors.TypeCheckException
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.operator.Operator
import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.Function
import org.partiql.spi.value.Datum
import org.partiql.types.PType
import org.partiql.value.PartiQLValue
Expand All @@ -30,14 +30,14 @@ import org.partiql.value.PartiQLValue
*/
internal class ExprCallDynamic(
private val name: String,
candidateFns: Array<Fn>,
candidateFns: Array<Function>,
private val args: Array<Operator.Expr>
) : Operator.Expr {

private val candidates = Array(candidateFns.size) { Candidate(candidateFns[it]) }
private val paramIndices: IntRange = args.indices
private val paramTypes: List<List<PType>> = this.candidates.map { candidate -> candidate.fn.signature.parameters.map { it.type } }
private val paramFamilies: List<List<CoercionFamily>> = this.candidates.map { candidate -> candidate.fn.signature.parameters.map { family(it.type.kind) } }
private val paramTypes: List<List<PType>> = this.candidates.map { candidate -> candidate.fn.signature.parameters.map { it.getType() } }
private val paramFamilies: List<List<CoercionFamily>> = this.candidates.map { candidate -> candidate.fn.signature.parameters.map { family(it.getType().kind) } }
private val cachedMatches: MutableMap<List<PType>, Int> = mutableMapOf()

override fun eval(env: Environment): Datum {
Expand Down Expand Up @@ -154,7 +154,7 @@ internal class ExprCallDynamic(
* @see ExprCallDynamic
*/
private class Candidate(
val fn: Fn,
val fn: Function,
) {

/**
Expand All @@ -168,7 +168,7 @@ internal class ExprCallDynamic(
return nil.invoke()
}
val argType = arg.type
val paramType = fn.signature.parameters[i].type
val paramType = fn.signature.parameters[i].getType()
when (paramType == argType) {
true -> arg
false -> CastTable.cast(arg, paramType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@ package org.partiql.eval.internal.operator.rex

import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.operator.Operator
import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.Function
import org.partiql.spi.value.Datum
import org.partiql.value.PartiQLValueExperimental

@OptIn(PartiQLValueExperimental::class)
internal class ExprCallStatic(
private val fn: Fn,
private val fn: Function,
private val inputs: Array<Operator.Expr>,
) : Operator.Expr {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.MethodSource
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.helpers.ValueUtility.check
import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.FnParameter
import org.partiql.spi.fn.FnSignature
import org.partiql.spi.fn.Function
import org.partiql.spi.fn.Parameter
import org.partiql.spi.value.Datum
import org.partiql.spi.value.Datum.bag
import org.partiql.spi.value.Datum.bool
Expand Down Expand Up @@ -64,14 +64,14 @@ class ExprCallDynamicTest {
)

@OptIn(PartiQLValueExperimental::class)
internal val candidates: Array<Fn> = params.mapIndexed { index, it ->
object : Fn {
internal val candidates: Array<Function> = params.mapIndexed { index, it ->
object : Function {
override val signature: FnSignature = FnSignature(
name = "example_function",
returns = PType.integer(),
parameters = listOf(
FnParameter("first", type = it.first.toPType()),
FnParameter("second", type = it.second.toPType()),
Parameter("first", type = it.first.toPType()),
Parameter("second", type = it.second.toPType()),
)
)

Expand Down
14 changes: 7 additions & 7 deletions partiql-plan/api/partiql-plan.api
Original file line number Diff line number Diff line change
Expand Up @@ -2530,7 +2530,7 @@ public abstract interface class org/partiql/plan/v1/builder/PlanFactory {
public static final field Companion Lorg/partiql/plan/v1/builder/PlanFactory$Companion;
public static fun getSTANDARD ()Lorg/partiql/plan/v1/builder/PlanFactory;
public abstract fun relAggregate (Lorg/partiql/plan/v1/operator/rel/Rel;Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/v1/operator/rel/RelAggregate;
public abstract fun relAggregateCall (Lorg/partiql/spi/fn/Agg;Ljava/util/List;Z)Lorg/partiql/plan/v1/operator/rel/RelAggregateCall;
public abstract fun relAggregateCall (Lorg/partiql/spi/fn/Aggregation;Ljava/util/List;Z)Lorg/partiql/plan/v1/operator/rel/RelAggregateCall;
public abstract fun relCorrelate (Lorg/partiql/plan/v1/operator/rel/Rel;Lorg/partiql/plan/v1/operator/rel/Rel;)Lorg/partiql/plan/v1/operator/rel/RelCorrelate;
public abstract fun relCorrelate (Lorg/partiql/plan/v1/operator/rel/Rel;Lorg/partiql/plan/v1/operator/rel/Rel;Lorg/partiql/plan/v1/operator/rel/RelJoinType;)Lorg/partiql/plan/v1/operator/rel/RelCorrelate;
public abstract fun relDistinct (Lorg/partiql/plan/v1/operator/rel/Rel;)Lorg/partiql/plan/v1/operator/rel/RelDistinct;
Expand All @@ -2554,7 +2554,7 @@ public abstract interface class org/partiql/plan/v1/builder/PlanFactory {
public abstract fun rexArray (Ljava/util/Collection;)Lorg/partiql/plan/v1/operator/rex/RexArray;
public abstract fun rexBag (Ljava/util/Collection;)Lorg/partiql/plan/v1/operator/rex/RexBag;
public abstract fun rexCall (Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/v1/operator/rex/RexCallDynamic;
public abstract fun rexCall (Lorg/partiql/spi/fn/Fn;Ljava/util/List;)Lorg/partiql/plan/v1/operator/rex/RexCallStatic;
public abstract fun rexCall (Lorg/partiql/spi/fn/Function;Ljava/util/List;)Lorg/partiql/plan/v1/operator/rex/RexCallStatic;
public abstract fun rexCase (Ljava/util/List;Lorg/partiql/plan/v1/operator/rex/Rex;)Lorg/partiql/plan/v1/operator/rex/RexCase;
public abstract fun rexCase (Lorg/partiql/plan/v1/operator/rex/Rex;Ljava/util/List;Lorg/partiql/plan/v1/operator/rex/Rex;)Lorg/partiql/plan/v1/operator/rex/RexCase;
public abstract fun rexCast (Lorg/partiql/plan/v1/operator/rex/Rex;Lorg/partiql/types/PType;)Lorg/partiql/plan/v1/operator/rex/RexCast;
Expand Down Expand Up @@ -2586,8 +2586,8 @@ public final class org/partiql/plan/v1/builder/PlanFactory$Companion {

public final class org/partiql/plan/v1/builder/PlanFactory$DefaultImpls {
public static fun relAggregate (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/plan/v1/operator/rel/Rel;Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/v1/operator/rel/RelAggregate;
public static fun relAggregateCall (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/spi/fn/Agg;Ljava/util/List;Z)Lorg/partiql/plan/v1/operator/rel/RelAggregateCall;
public static synthetic fun relAggregateCall$default (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/spi/fn/Agg;Ljava/util/List;ZILjava/lang/Object;)Lorg/partiql/plan/v1/operator/rel/RelAggregateCall;
public static fun relAggregateCall (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/spi/fn/Aggregation;Ljava/util/List;Z)Lorg/partiql/plan/v1/operator/rel/RelAggregateCall;
public static synthetic fun relAggregateCall$default (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/spi/fn/Aggregation;Ljava/util/List;ZILjava/lang/Object;)Lorg/partiql/plan/v1/operator/rel/RelAggregateCall;
public static fun relCorrelate (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/plan/v1/operator/rel/Rel;Lorg/partiql/plan/v1/operator/rel/Rel;)Lorg/partiql/plan/v1/operator/rel/RelCorrelate;
public static fun relCorrelate (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/plan/v1/operator/rel/Rel;Lorg/partiql/plan/v1/operator/rel/Rel;Lorg/partiql/plan/v1/operator/rel/RelJoinType;)Lorg/partiql/plan/v1/operator/rel/RelCorrelate;
public static fun relDistinct (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/plan/v1/operator/rel/Rel;)Lorg/partiql/plan/v1/operator/rel/RelDistinct;
Expand All @@ -2612,7 +2612,7 @@ public final class org/partiql/plan/v1/builder/PlanFactory$DefaultImpls {
public static fun rexArray (Lorg/partiql/plan/v1/builder/PlanFactory;Ljava/util/Collection;)Lorg/partiql/plan/v1/operator/rex/RexArray;
public static fun rexBag (Lorg/partiql/plan/v1/builder/PlanFactory;Ljava/util/Collection;)Lorg/partiql/plan/v1/operator/rex/RexBag;
public static fun rexCall (Lorg/partiql/plan/v1/builder/PlanFactory;Ljava/util/List;Ljava/util/List;)Lorg/partiql/plan/v1/operator/rex/RexCallDynamic;
public static fun rexCall (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/spi/fn/Fn;Ljava/util/List;)Lorg/partiql/plan/v1/operator/rex/RexCallStatic;
public static fun rexCall (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/spi/fn/Function;Ljava/util/List;)Lorg/partiql/plan/v1/operator/rex/RexCallStatic;
public static fun rexCase (Lorg/partiql/plan/v1/builder/PlanFactory;Ljava/util/List;Lorg/partiql/plan/v1/operator/rex/Rex;)Lorg/partiql/plan/v1/operator/rex/RexCase;
public static fun rexCase (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/plan/v1/operator/rex/Rex;Ljava/util/List;Lorg/partiql/plan/v1/operator/rex/Rex;)Lorg/partiql/plan/v1/operator/rex/RexCase;
public static fun rexCast (Lorg/partiql/plan/v1/builder/PlanFactory;Lorg/partiql/plan/v1/operator/rex/Rex;Lorg/partiql/types/PType;)Lorg/partiql/plan/v1/operator/rex/RexCast;
Expand Down Expand Up @@ -2740,7 +2740,7 @@ public final class org/partiql/plan/v1/operator/rel/RelAggregate$DefaultImpls {
}

public abstract interface class org/partiql/plan/v1/operator/rel/RelAggregateCall {
public abstract fun getAgg ()Lorg/partiql/spi/fn/Agg;
public abstract fun getAgg ()Lorg/partiql/spi/fn/Aggregation;
public abstract fun getArgs ()Ljava/util/List;
public abstract fun isDistinct ()Z
}
Expand Down Expand Up @@ -3165,7 +3165,7 @@ public abstract interface class org/partiql/plan/v1/operator/rex/RexCallStatic :
public abstract fun accept (Lorg/partiql/plan/v1/operator/rex/RexVisitor;Ljava/lang/Object;)Ljava/lang/Object;
public abstract fun getArgs ()Ljava/util/List;
public abstract fun getChildren ()Ljava/util/Collection;
public abstract fun getFunction ()Lorg/partiql/spi/fn/Fn;
public abstract fun getFunction ()Lorg/partiql/spi/fn/Function;
}

public final class org/partiql/plan/v1/operator/rex/RexCallStatic$DefaultImpls {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ import org.partiql.plan.v1.operator.rex.RexTableImpl
import org.partiql.plan.v1.operator.rex.RexVar
import org.partiql.plan.v1.operator.rex.RexVarImpl
import org.partiql.spi.catalog.Table
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.Aggregation
import org.partiql.spi.fn.Function
import org.partiql.spi.value.Datum
import org.partiql.types.PType

Expand Down Expand Up @@ -134,7 +134,7 @@ public interface PlanFactory {
* @param isDistinct
* @return
*/
public fun relAggregateCall(aggregation: Agg, args: List<Rex>, isDistinct: Boolean = false): RelAggregateCall =
public fun relAggregateCall(aggregation: Aggregation, args: List<Rex>, isDistinct: Boolean = false): RelAggregateCall =
RelAggregateCallImpl(aggregation, args, isDistinct)

/**
Expand Down Expand Up @@ -347,7 +347,7 @@ public interface PlanFactory {
* @param args
* @return
*/
public fun rexCall(function: Fn, args: List<Rex>): RexCallStatic = RexCallStaticImpl(function, args)
public fun rexCall(function: Function, args: List<Rex>): RexCallStatic = RexCallStaticImpl(function, args)

/**
* Create a [RexCallDynamic] instance.
Expand All @@ -356,7 +356,7 @@ public interface PlanFactory {
* @param args
* @return
*/
public fun rexCall(functions: List<Fn>, args: List<Rex>): RexCallDynamic = RexCallDynamicImpl(functions, args)
public fun rexCall(functions: List<Function>, args: List<Rex>): RexCallDynamic = RexCallDynamicImpl(functions, args)

/**
* Create a [RexCase] instance for a searched case-when.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package org.partiql.plan.v1.operator.rel

import org.partiql.plan.v1.operator.rex.Rex
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.Aggregation

/**
* TODO DOCUMENTATION
*/
public interface RelAggregateCall {

public fun getAgg(): Agg
public fun getAgg(): Aggregation

public fun getArgs(): List<Rex>

Expand All @@ -25,11 +25,11 @@ public interface RelAggregateCall {
* @property isDistinct
*/
internal class RelAggregateCallImpl(
private var agg: Agg,
private var agg: Aggregation,
private var args: List<Rex>,
private var isDistinct: Boolean,
) : RelAggregateCall {
override fun getAgg(): Agg = agg
override fun getAgg(): Aggregation = agg
override fun getArgs(): List<Rex> = args
override fun isDistinct(): Boolean = isDistinct
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.partiql.plan.v1.operator.rex

import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.Function
import org.partiql.types.PType

/**
Expand All @@ -13,7 +13,7 @@ public interface RexCallDynamic : Rex {
*
* @return
*/
public fun getFunctions(): List<Fn>
public fun getFunctions(): List<Function>

/**
* Returns the list of function arguments.
Expand All @@ -28,13 +28,13 @@ public interface RexCallDynamic : Rex {
/**
* Default [RexCallDynamic] implementation meant for extension.
*/
internal class RexCallDynamicImpl(functions: List<Fn>, args: List<Rex>) : RexCallDynamic {
internal class RexCallDynamicImpl(functions: List<Function>, args: List<Rex>) : RexCallDynamic {

// DO NOT USE FINAL
private var _functions = functions
private var _args = args

override fun getFunctions(): List<Fn> = _functions
override fun getFunctions(): List<Function> = _functions

override fun getArgs(): List<Rex> = _args

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package org.partiql.plan.v1.operator.rex

import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.Function
import org.partiql.types.PType

/**
Expand All @@ -13,7 +13,7 @@ public interface RexCallStatic : Rex {
*
* @return
*/
public fun getFunction(): Fn
public fun getFunction(): Function

/**
* Returns the list of function arguments.
Expand All @@ -28,13 +28,13 @@ public interface RexCallStatic : Rex {
/**
* Default [RexCallStatic] implementation meant for extension.
*/
internal class RexCallStaticImpl(function: Fn, args: List<Rex>) : RexCallStatic {
internal class RexCallStaticImpl(function: Function, args: List<Rex>) : RexCallStatic {

// DO NOT USE FINAL
private var _function = function
private var _args = args

override fun getFunction(): Fn = _function
override fun getFunction(): Function = _function

override fun getArgs(): List<Rex> = _args

Expand Down
Loading
Loading