Skip to content

Commit

Permalink
Merge pull request #1388 from partiql/partiql-eval-perf-dynamic
Browse files Browse the repository at this point in the history
Adds performance optimizations for ExprCallDynamic
  • Loading branch information
johnedquinn authored Apr 2, 2024
2 parents 609f8b8 + 25d56dd commit 810035c
Show file tree
Hide file tree
Showing 31 changed files with 553 additions and 280 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,11 @@ internal class Compiler(
}
}

@OptIn(FnExperimental::class, PartiQLValueExperimental::class)
@OptIn(FnExperimental::class)
override fun visitRexOpCallDynamic(node: Rex.Op.Call.Dynamic, ctx: StaticType?): Operator {
val args = node.args.map { visitRex(it, ctx).modeHandled() }.toTypedArray()
val candidates = node.candidates.map { candidate ->
val candidates = Array(node.candidates.size) {
val candidate = node.candidates[it]
val fn = symbols.getFn(candidate.fn)
val coercions = candidate.coercions.toTypedArray()
ExprCallDynamic.Candidate(fn, coercions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,21 @@ import org.partiql.value.PartiQLValueType
*/
@OptIn(PartiQLValueExperimental::class, FnExperimental::class)
internal class ExprCallDynamic(
private val candidates: List<Candidate>,
candidates: Array<Candidate>,
private val args: Array<Operator.Expr>
) : Operator.Expr {

private val candidateIndex = CandidateIndex.All(candidates)

override fun eval(env: Environment): PartiQLValue {
val actualArgs = args.map { it.eval(env) }.toTypedArray()
candidates.forEach { candidate ->
if (candidate.matches(actualArgs)) {
return candidate.eval(actualArgs, env)
}
val actualTypes = actualArgs.map { it.type }
candidateIndex.get(actualTypes)?.let {
return it.eval(actualArgs, env)
}
val errorString = buildString {
val argString = actualArgs.joinToString(", ")
append("Could not dynamically find function for arguments $argString in $candidates.")
append("Could not dynamically find function (${candidateIndex.name}) for arguments $argString.")
}
throw TypeCheckException(errorString)
}
Expand All @@ -47,13 +48,11 @@ internal class ExprCallDynamic(
*
* @see ExprCallDynamic
*/
internal class Candidate(
data class Candidate(
val fn: Fn,
val coercions: Array<Ref.Cast?>
) {

private val signatureParameters = fn.signature.parameters.map { it.type }.toTypedArray()

fun eval(originalArgs: Array<PartiQLValue>, env: Environment): PartiQLValue {
val args = originalArgs.mapIndexed { i, arg ->
when (val c = coercions[i]) {
Expand All @@ -63,32 +62,156 @@ internal class ExprCallDynamic(
}.toTypedArray()
return fn.invoke(args)
}
}

private sealed interface CandidateIndex {

/**
 * Looks up a candidate for the given runtime argument types.
 *
 * @param args the runtime types of the already-evaluated arguments, in positional order.
 * @return the candidate selected by this index, or `null` when this index holds no match.
 */
public fun get(args: List<PartiQLValueType>): Candidate?

/**
* Preserves the original ordering of the passed-in candidates while making it faster to look up matching
* functions. Utilizes both [Direct] and [Indirect].
*
* Say a user passes in the following ordered candidates:
* [
* foo(int16, int16) -> int16,
* foo(int32, int32) -> int32,
* foo(int64, int64) -> int64,
* foo(string, string) -> string,
* foo(struct, struct) -> struct,
* foo(numeric, numeric) -> numeric,
* foo(int64, dynamic) -> dynamic,
* foo(struct, dynamic) -> dynamic,
* foo(bool, bool) -> bool
* ]
*
* With the above candidates, the [CandidateIndex.All] will maintain the original ordering by utilizing:
* - [CandidateIndex.Direct] to match hashable runtime types
* - [CandidateIndex.Indirect] to match the dynamic type
*
* For the above example, the internal representation of [CandidateIndex.All] is a list of
* [CandidateIndex.Direct] and [CandidateIndex.Indirect] that looks like:
* ALL listOf(
* DIRECT hashMap(
* [int16, int16] --> foo(int16, int16) -> int16,
* [int32, int32] --> foo(int32, int32) -> int32,
* [int64, int64] --> foo(int64, int64) -> int64,
* [string, string] --> foo(string, string) -> string,
* [struct, struct] --> foo(struct, struct) -> struct,
* [numeric, numeric] --> foo(numeric, numeric) -> numeric
* ),
* INDIRECT listOf(
* foo(int64, dynamic) -> dynamic,
* foo(struct, dynamic) -> dynamic
* ),
* DIRECT hashMap(
* [bool, bool] --> foo(bool, bool) -> bool
* )
* )
*
* @param candidates
*/
class All(
candidates: Array<Candidate>,
) : CandidateIndex {

private val lookups: List<CandidateIndex>
internal val name: String = candidates.first().fn.signature.name

internal fun matches(inputs: Array<PartiQLValue>): Boolean {
for (i in inputs.indices) {
val inputType = inputs[i].type
val parameterType = signatureParameters[i]
val c = coercions[i]
when (c) {
// coercion might be null if one of the following is true
// Function parameter is ANY,
// Input type is null
// input type is the same as function parameter
null -> {
if (!(inputType == parameterType || inputType == PartiQLValueType.NULL || parameterType == PartiQLValueType.ANY)) {
return false
init {
val lookupsMutable = mutableListOf<CandidateIndex>()
val accumulator = mutableListOf<Pair<List<PartiQLValueType>, Candidate>>()

// Indicates that we are currently processing dynamic candidates that accept ANY.
var activelyProcessingAny = true

candidates.forEach { candidate ->
// Gather the input types to the dynamic invocation
val lookupTypes = candidate.coercions.mapIndexed { index, cast ->
when (cast) {
null -> candidate.fn.signature.parameters[index].type
else -> cast.input
}
}
else -> {
// checking the input type is expected by the coercion
if (inputType != c.input) return false
// checking the result is expected by the function signature
// this branch should never be reached, but leave it here for clarity
if (c.target != parameterType) error("Internal Error: Cast Target does not match Function Parameter")
val parametersIncludeAny = lookupTypes.any { it == PartiQLValueType.ANY }
// A way to simplify logic further below. If it's empty, add something and set the processing type.
if (accumulator.isEmpty()) {
activelyProcessingAny = parametersIncludeAny
accumulator.add(lookupTypes to candidate)
return@forEach
}
when (parametersIncludeAny) {
true -> when (activelyProcessingAny) {
true -> accumulator.add(lookupTypes to candidate)
false -> {
activelyProcessingAny = true
lookupsMutable.add(Direct.of(accumulator.toList()))
accumulator.clear()
accumulator.add(lookupTypes to candidate)
}
}
false -> when (activelyProcessingAny) {
false -> accumulator.add(lookupTypes to candidate)
true -> {
activelyProcessingAny = false
lookupsMutable.add(Indirect(accumulator.toList()))
accumulator.clear()
accumulator.add(lookupTypes to candidate)
}
}
}
}
// Flush whatever is still in the accumulator: the final group is never followed by a switch, so the loop did not submit it.
when (accumulator.isEmpty()) {
true -> { /* Do nothing! */ }
false -> when (activelyProcessingAny) {
true -> lookupsMutable.add(Indirect(accumulator.toList()))
false -> lookupsMutable.add(Direct.of(accumulator.toList()))
}
}
this.lookups = lookupsMutable
}

override fun get(args: List<PartiQLValueType>): Candidate? {
    // Consult each sub-index in its stored order; the first one that yields a candidate wins.
    for (lookup in this.lookups) {
        val match = lookup.get(args)
        if (match != null) {
            return match
        }
    }
    return null
}
}

/**
 * A hash-based index for candidates whose expected input types can be compared by plain equality
 * (e.g. int32, int64). It does NOT hold candidates that accept [PartiQLValueType.ANY].
 *
 * Lookup is a single hash-map access keyed by the full list of runtime argument types.
 *
 * @property directCandidates map from the expected input types to the candidate to invoke.
 */
data class Direct private constructor(val directCandidates: HashMap<List<PartiQLValueType>, Candidate>) : CandidateIndex {

    companion object {
        /**
         * Builds a [Direct] index from ordered (input types, candidate) pairs.
         *
         * The pairs are given in precedence order. If two candidates share the same input types,
         * the EARLIEST one is kept, matching the first-match semantics of a linear scan over the
         * candidates. (A plain `putAll` would let the last duplicate silently win.)
         */
        internal fun of(candidates: List<Pair<List<PartiQLValueType>, Candidate>>): Direct {
            val candidateMap = java.util.HashMap<List<PartiQLValueType>, Candidate>()
            for ((types, candidate) in candidates) {
                // First writer wins so that earlier (higher-precedence) candidates are never shadowed.
                candidateMap.putIfAbsent(types, candidate)
            }
            return Direct(candidateMap)
        }
    }

    /**
     * @return the candidate registered for exactly these argument types, or `null` if none is.
     */
    override fun get(args: List<PartiQLValueType>): Candidate? {
        return directCandidates[args]
    }
}

/**
* Holds all candidates that expect a [PartiQLValueType.ANY] on input. This maintains the original
* precedence order.
*/
data class Indirect(private val candidates: List<Pair<List<PartiQLValueType>, Candidate>>) : CandidateIndex {
override fun get(args: List<PartiQLValueType>): Candidate? {
    // Scan in stored order and take the first entry whose expected types accept every argument.
    val hit = candidates.firstOrNull { (types, _) ->
        args.indices.all { i ->
            // A position is accepted on an exact type match or when the expected type is ANY.
            args[i] == types[i] || types[i] == PartiQLValueType.ANY
        }
    }
    return hit?.second
}
return true
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.partiql.value.Int64Value
import org.partiql.value.Int8Value
import org.partiql.value.IntValue
import org.partiql.value.ListValue
import org.partiql.value.NullValue
import org.partiql.value.NumericValue
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
Expand All @@ -30,7 +31,13 @@ import org.partiql.value.StringValue
import org.partiql.value.SymbolValue
import org.partiql.value.TextValue
import org.partiql.value.bagValue
import org.partiql.value.binaryValue
import org.partiql.value.blobValue
import org.partiql.value.boolValue
import org.partiql.value.byteValue
import org.partiql.value.charValue
import org.partiql.value.clobValue
import org.partiql.value.dateValue
import org.partiql.value.decimalValue
import org.partiql.value.float32Value
import org.partiql.value.float64Value
Expand All @@ -40,9 +47,13 @@ import org.partiql.value.int64Value
import org.partiql.value.int8Value
import org.partiql.value.intValue
import org.partiql.value.listValue
import org.partiql.value.missingValue
import org.partiql.value.sexpValue
import org.partiql.value.stringValue
import org.partiql.value.structValue
import org.partiql.value.symbolValue
import org.partiql.value.timeValue
import org.partiql.value.timestampValue
import java.math.BigDecimal
import java.math.BigInteger

Expand Down Expand Up @@ -79,14 +90,48 @@ internal class ExprCast(val arg: Operator.Expr, val cast: Ref.Cast) : Operator.E
PartiQLValueType.LIST -> castFromCollection(arg as ListValue<*>, cast.target)
PartiQLValueType.SEXP -> castFromCollection(arg as SexpValue<*>, cast.target)
PartiQLValueType.STRUCT -> TODO("CAST FROM STRUCT not yet implemented")
PartiQLValueType.NULL -> error("cast from NULL should be handled by Typer")
PartiQLValueType.NULL -> castFromNull(arg as NullValue, cast.target)
PartiQLValueType.MISSING -> error("cast from MISSING should be handled by Typer")
}
} catch (e: DataException) {
throw TypeCheckException()
}
}

/**
 * Casts a NULL value to the target type [t] by producing the null-valued instance of that type.
 *
 * @param value the incoming NULL value; returned unchanged when the target is ANY or NULL.
 * @param t the target type of the cast.
 * @return a value of type [t] holding null (or MISSING when the target is MISSING).
 */
@OptIn(PartiQLValueExperimental::class)
private fun castFromNull(value: NullValue, t: PartiQLValueType): PartiQLValue = when (t) {
    // Targets for which the untyped NULL itself is already the answer.
    PartiQLValueType.ANY, PartiQLValueType.NULL -> value
    // TODO: Is casting NULL to MISSING actually allowed?
    PartiQLValueType.MISSING -> missingValue()

    // Boolean
    PartiQLValueType.BOOL -> boolValue(null)

    // Integral numerics
    PartiQLValueType.INT8 -> int8Value(null)
    PartiQLValueType.INT16 -> int16Value(null)
    PartiQLValueType.INT32 -> int32Value(null)
    PartiQLValueType.INT64 -> int64Value(null)
    PartiQLValueType.INT -> intValue(null)

    // Exact and approximate numerics
    PartiQLValueType.DECIMAL, PartiQLValueType.DECIMAL_ARBITRARY -> decimalValue(null)
    PartiQLValueType.FLOAT32 -> float32Value(null)
    PartiQLValueType.FLOAT64 -> float64Value(null)

    // Text
    PartiQLValueType.CHAR -> charValue(null)
    PartiQLValueType.STRING -> stringValue(null)
    PartiQLValueType.SYMBOL -> symbolValue(null)

    // Binary and large objects
    PartiQLValueType.BINARY -> binaryValue(null)
    PartiQLValueType.BYTE -> byteValue(null)
    PartiQLValueType.BLOB -> blobValue(null)
    PartiQLValueType.CLOB -> clobValue(null)

    // Date/time
    PartiQLValueType.DATE -> dateValue(null)
    PartiQLValueType.TIME -> timeValue(null)
    PartiQLValueType.TIMESTAMP -> timestampValue(null)
    PartiQLValueType.INTERVAL -> TODO("Not yet supported")

    // Collections and structs
    PartiQLValueType.BAG -> bagValue<PartiQLValue>(null)
    PartiQLValueType.LIST -> listValue<PartiQLValue>(null)
    PartiQLValueType.SEXP -> sexpValue<PartiQLValue>(null)
    PartiQLValueType.STRUCT -> structValue<PartiQLValue>(null)
}

@OptIn(PartiQLValueExperimental::class)
private fun castFromBool(value: BoolValue, t: PartiQLValueType): PartiQLValue {
val v = value.value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import org.partiql.value.structValue
import java.io.ByteArrayOutputStream
import java.math.BigDecimal
import java.math.BigInteger
import kotlin.test.assertEquals
import kotlin.test.assertNotNull

/**
Expand Down Expand Up @@ -1253,10 +1252,12 @@ class PartiQLEngineDefaultTest {

internal fun assert() {
val permissiveResult = run(mode = PartiQLEngine.Mode.PERMISSIVE)
assertEquals(expectedPermissive, permissiveResult, comparisonString(expectedPermissive, permissiveResult))
assert(expectedPermissive == permissiveResult.first) {
comparisonString(expectedPermissive, permissiveResult.first, permissiveResult.second)
}
var error: Throwable? = null
try {
when (val result = run(mode = PartiQLEngine.Mode.STRICT)) {
when (val result = run(mode = PartiQLEngine.Mode.STRICT).first) {
is CollectionValue<*> -> result.toList()
else -> result
}
Expand All @@ -1266,7 +1267,7 @@ class PartiQLEngineDefaultTest {
assertNotNull(error)
}

private fun run(mode: PartiQLEngine.Mode): PartiQLValue {
private fun run(mode: PartiQLEngine.Mode): Pair<PartiQLValue, PartiQLPlan> {
val statement = parser.parse(input).root
val catalog = MemoryCatalog.PartiQL().name("memory").build()
val connector = MemoryConnector(catalog)
Expand All @@ -1283,17 +1284,18 @@ class PartiQLEngineDefaultTest {
val plan = planner.plan(statement, session)
val prepared = engine.prepare(plan.plan, PartiQLEngine.Session(mapOf("memory" to connector), mode = mode))
when (val result = engine.execute(prepared)) {
is PartiQLResult.Value -> return result.value
is PartiQLResult.Value -> return result.value to plan.plan
is PartiQLResult.Error -> throw result.cause
}
}

@OptIn(PartiQLValueExperimental::class)
private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue): String {
private fun comparisonString(expected: PartiQLValue, actual: PartiQLValue, plan: PartiQLPlan): String {
val expectedBuffer = ByteArrayOutputStream()
val expectedWriter = PartiQLValueIonWriterBuilder.standardIonTextBuilder().build(expectedBuffer)
expectedWriter.append(expected)
return buildString {
PlanPrinter.append(this, plan)
appendLine("Expected : $expectedBuffer")
expectedBuffer.reset()
expectedWriter.append(actual)
Expand Down Expand Up @@ -1444,6 +1446,7 @@ class PartiQLEngineDefaultTest {
).assert()

@Test
@Disabled("This broke in its introduction to the codebase on merge. See 5fb9a1ccbc7e630b0df62aa8b161d319c763c1f6.")
// TODO: Add to conformance tests
fun wildCard() =
SuccessTestCase(
Expand Down Expand Up @@ -1487,6 +1490,7 @@ class PartiQLEngineDefaultTest {
).assert()

@Test
@Disabled("This broke in its introduction to the codebase on merge. See 5fb9a1ccbc7e630b0df62aa8b161d319c763c1f6.")
// TODO: add to conformance tests
// Note that the existing pipeline produced identical result when supplying with
// SELECT VALUE v2.name FROM e as v0, v0.books as v1, unpivot v1.authors as v2;
Expand Down
Loading

0 comments on commit 810035c

Please sign in to comment.