Skip to content

Commit

Permalink
Adds Datum comparator
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn committed Aug 13, 2024
1 parent 8717275 commit 0bbc039
Show file tree
Hide file tree
Showing 21 changed files with 1,378 additions and 209 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package org.partiql.eval.internal

import org.partiql.eval.value.Datum

internal data class Record(val values: Array<Datum>) {
internal class Record(val values: Array<Datum>) {

companion object {
val empty = Record(emptyArray())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package org.partiql.eval.internal.operator.rel

import org.partiql.eval.value.Datum

/**
 * Orders arrays of [Datum] by length first and then element by element.
 *
 * A shorter array always sorts before a longer one. Arrays of equal length are compared
 * position by position using the comparator obtained from `Datum.comparator(false)`; the
 * first non-zero element comparison becomes the result. If every position compares as
 * equal, the arrays compare as equal (0).
 */
internal object DatumArrayComparator : Comparator<Array<Datum>> {
    private val delegate = Datum.comparator(false)

    override fun compare(o1: Array<Datum>, o2: Array<Datum>): Int {
        // Length decides the ordering before any element is inspected.
        when {
            o1.size < o2.size -> return -1
            o1.size > o2.size -> return 1
        }
        // Both arrays have the same length here, so o1's indices are valid for o2 as well.
        for (position in o1.indices) {
            val result = delegate.compare(o1[position], o2[position])
            if (result != 0) {
                return result
            }
        }
        return 0
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,66 +5,39 @@ import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.eval.value.Datum
import org.partiql.spi.fn.Agg
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.nullValue
import java.util.TreeMap
import java.util.TreeSet

internal class RelAggregate(
val input: Operator.Relation,
val keys: List<Operator.Expr>,
val functions: List<Operator.Aggregation>
private val keys: List<Operator.Expr>,
private val functions: List<Operator.Aggregation>
) : Operator.Relation {

lateinit var records: Iterator<Record>
private lateinit var records: Iterator<Record>

@OptIn(PartiQLValueExperimental::class)
val aggregationMap = TreeMap<List<PartiQLValue>, List<AccumulatorWrapper>>(PartiQLValueListComparator)

@OptIn(PartiQLValueExperimental::class)
object PartiQLValueListComparator : Comparator<List<PartiQLValue>> {
private val delegate = PartiQLValue.comparator(nullsFirst = false)
override fun compare(o1: List<PartiQLValue>, o2: List<PartiQLValue>): Int {
if (o1.size < o2.size) {
return -1
}
if (o1.size > o2.size) {
return 1
}
for (index in 0..o2.lastIndex) {
val element1 = o1[index]
val element2 = o2[index]
val compared = delegate.compare(element1, element2)
if (compared != 0) {
return compared
}
}
return 0
}
}
private val aggregationMap = TreeMap<Array<Datum>, List<AccumulatorWrapper>>(DatumArrayComparator)

/**
* Wraps an [Agg.Accumulator] to help with filtering distinct values.
*
* @property seen maintains which values have already been seen. If null, we accumulate all values coming through.
*/
class AccumulatorWrapper @OptIn(PartiQLValueExperimental::class) constructor(
class AccumulatorWrapper(
val delegate: Agg.Accumulator,
val args: List<Operator.Expr>,
val seen: TreeSet<List<PartiQLValue>>?
val seen: TreeSet<Array<Datum>>?
)

@OptIn(PartiQLValueExperimental::class)
override fun open(env: Environment) {
input.open(env)
for (inputRecord in input) {
// Initialize the AggregationMap
val evaluatedGroupByKeys = keys.map {
val key = it.eval(env.push(inputRecord))
val evaluatedGroupByKeys = Array(keys.size) { keyIndex ->
val key = keys[keyIndex].eval(env.push(inputRecord))
when (key.isMissing) {
true -> nullValue()
false -> key.toPartiQLValue()
true -> Datum.nullValue()
false -> key
}
}
val accumulators = aggregationMap.getOrPut(evaluatedGroupByKeys) {
Expand All @@ -73,7 +46,7 @@ internal class RelAggregate(
delegate = it.delegate.accumulator(),
args = it.args,
seen = when (it.setQuantifier) {
Operator.Aggregation.SetQuantifier.DISTINCT -> TreeSet(PartiQLValueListComparator)
Operator.Aggregation.SetQuantifier.DISTINCT -> TreeSet(DatumArrayComparator)
Operator.Aggregation.SetQuantifier.ALL -> null
}
)
Expand All @@ -82,19 +55,19 @@ internal class RelAggregate(

// Aggregate Values in Aggregation State
accumulators.forEachIndexed { index, function ->
// TODO: Add support for aggregating PQLValues directly
val arguments = function.args.map { it.eval(env.push(inputRecord)) }
// Skip over aggregation if NULL/MISSING
if (arguments.any { it.isMissing || it.isNull }) {
return@forEachIndexed
val arguments = Array(function.args.size) {
val argument = function.args[it].eval(env.push(inputRecord))
// Skip over aggregation if NULL/MISSING
if (argument.isNull || argument.isMissing) {
return@forEachIndexed
}
argument
}
// TODO: Add support for a Datum comparator. Currently, this conversion is inefficient.
val valuesToCompare = arguments.map { it.toPartiQLValue() }
// Skip over aggregation if DISTINCT and SEEN
if (function.seen != null && (function.seen.add(valuesToCompare).not())) {
if (function.seen != null && (function.seen.add(arguments).not())) {
return@forEachIndexed
}
accumulators[index].delegate.next(arguments.toTypedArray())
accumulators[index].delegate.next(arguments)
}
}

Expand All @@ -111,7 +84,7 @@ internal class RelAggregate(

records = iterator {
aggregationMap.forEach { (keysEvaluated, accumulators) ->
val recordValues = accumulators.map { acc -> acc.delegate.value() } + keysEvaluated.map { value -> Datum.of(value) }
val recordValues = accumulators.map { acc -> acc.delegate.value() } + keysEvaluated
yield(Record.of(*recordValues.toTypedArray()))
}
}
Expand All @@ -125,7 +98,6 @@ internal class RelAggregate(
return records.next()
}

@OptIn(PartiQLValueExperimental::class)
override fun close() {
aggregationMap.clear()
input.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,21 @@ package org.partiql.eval.internal.operator.rel
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.value.ListValue
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.listValue
import java.util.TreeSet

internal class RelDistinct(
val input: Operator.Relation
) : RelPeeking() {

// TODO: Add hashcode/equals support for Datum. Then we can use Record directly.
@OptIn(PartiQLValueExperimental::class)
private val seen = TreeSet<ListValue<PartiQLValue>>(PartiQLValue.comparator())
private val seen = TreeSet(DatumArrayComparator)

override fun openPeeking(env: Environment) {
input.open(env)
}

@OptIn(PartiQLValueExperimental::class)
override fun peek(): Record? {
for (next in input) {
val transformed = listValue(List(next.values.size) { next.values[it].toPartiQLValue() })
val transformed = Array(next.values.size) { next.values[it] }
if (seen.contains(transformed).not()) {
seen.add(transformed)
return next
Expand All @@ -33,7 +26,6 @@ internal class RelDistinct(
return null
}

@OptIn(PartiQLValueExperimental::class)
override fun closePeeking() {
seen.clear()
input.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ package org.partiql.eval.internal.operator.rel
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.Record
import org.partiql.eval.internal.operator.Operator
import org.partiql.eval.value.Datum
import org.partiql.plan.Rel
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import java.util.Collections

@OptIn(PartiQLValueExperimental::class)
internal class RelSort(
private val input: Operator.Relation,
private val specs: List<Pair<Operator.Expr, Rel.Op.Sort.Order>>
Expand All @@ -17,8 +15,8 @@ internal class RelSort(
private var records: Iterator<Record> = Collections.emptyIterator()
private var init: Boolean = false

private val nullsFirstComparator = PartiQLValue.comparator(nullsFirst = true)
private val nullsLastComparator = PartiQLValue.comparator(nullsFirst = false)
private val nullsFirstComparator = Datum.comparator(true)
private val nullsLastComparator = Datum.comparator(false)

private lateinit var env: Environment

Expand All @@ -32,9 +30,8 @@ internal class RelSort(
private val comparator = object : Comparator<Record> {
override fun compare(l: Record, r: Record): Int {
specs.forEach { spec ->
// TODO: Write comparator for PQLValue
val lVal = spec.first.eval(env.push(l)).toPartiQLValue()
val rVal = spec.first.eval(env.push(r)).toPartiQLValue()
val lVal = spec.first.eval(env.push(l))
val rVal = spec.first.eval(env.push(r))

// DESC_NULLS_FIRST(l, r) == ASC_NULLS_LAST(r, l)
// DESC_NULLS_LAST(l, r) == ASC_NULLS_FIRST(r, l)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,18 @@ package org.partiql.eval.internal.operator.rex
import org.partiql.eval.internal.Environment
import org.partiql.eval.internal.operator.Operator
import org.partiql.eval.value.Datum
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental

internal class ExprNullIf(
private val valueExpr: Operator.Expr,
private val nullifierExpr: Operator.Expr
) : Operator.Expr {

@OptIn(PartiQLValueExperimental::class)
private val comparator = PartiQLValue.comparator()
private val comparator = Datum.comparator()

@PartiQLValueExperimental
override fun eval(env: Environment): Datum {
val value = valueExpr.eval(env)
val nullifier = nullifierExpr.eval(env)
return when (comparator.compare(value.toPartiQLValue(), nullifier.toPartiQLValue())) {
return when (comparator.compare(value, nullifier)) {
0 -> Datum.nullValue()
else -> value
}
Expand Down
5 changes: 5 additions & 0 deletions partiql-spi/api/partiql-spi.api
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,13 @@ public final class org/partiql/errors/TypeCheckException : java/lang/RuntimeExce

public abstract interface class org/partiql/eval/value/Datum : java/lang/Iterable {
public static fun bag (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum;
public static fun bag ([Lorg/partiql/eval/value/Datum;)Lorg/partiql/eval/value/Datum;
public static fun bigInt (J)Lorg/partiql/eval/value/Datum;
public static fun blob ([B)Lorg/partiql/eval/value/Datum;
public static fun bool (Z)Lorg/partiql/eval/value/Datum;
public static fun clob ([B)Lorg/partiql/eval/value/Datum;
public static fun comparator ()Ljava/util/Comparator;
public static fun comparator (Z)Ljava/util/Comparator;
public static fun date (Lorg/partiql/value/datetime/Date;)Lorg/partiql/eval/value/Datum;
public static fun decimal (Ljava/math/BigDecimal;II)Lorg/partiql/eval/value/Datum;
public static fun decimalArbitrary (Ljava/math/BigDecimal;)Lorg/partiql/eval/value/Datum;
Expand Down Expand Up @@ -306,13 +309,15 @@ public abstract interface class org/partiql/eval/value/Datum : java/lang/Iterabl
public fun isNull ()Z
public fun iterator ()Ljava/util/Iterator;
public static fun list (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum;
public static fun list ([Lorg/partiql/eval/value/Datum;)Lorg/partiql/eval/value/Datum;
public static fun missing ()Lorg/partiql/eval/value/Datum;
public static fun missing (Lorg/partiql/types/PType;)Lorg/partiql/eval/value/Datum;
public static fun nullValue ()Lorg/partiql/eval/value/Datum;
public static fun nullValue (Lorg/partiql/types/PType;)Lorg/partiql/eval/value/Datum;
public static fun of (Lorg/partiql/value/PartiQLValue;)Lorg/partiql/eval/value/Datum;
public static fun real (F)Lorg/partiql/eval/value/Datum;
public static fun sexp (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum;
public static fun sexp ([Lorg/partiql/eval/value/Datum;)Lorg/partiql/eval/value/Datum;
public static fun smallInt (S)Lorg/partiql/eval/value/Datum;
public static fun string (Ljava/lang/String;)Lorg/partiql/eval/value/Datum;
public static fun struct (Ljava/lang/Iterable;)Lorg/partiql/eval/value/Datum;
Expand Down
25 changes: 25 additions & 0 deletions partiql-spi/src/main/java/org/partiql/eval/value/Datum.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Comparator;
import java.util.Iterator;
import java.util.Objects;

Expand Down Expand Up @@ -291,6 +292,7 @@ default BigDecimal getBigDecimal() {
* @throws NullPointerException if this instance also returns true on {@link #isNull()}; callers should check that
* {@link #isNull()} returns false before attempting to invoke this method.
*/
@NotNull
@Override
default Iterator<Datum> iterator() {
throw new UnsupportedOperationException();
Expand Down Expand Up @@ -641,4 +643,27 @@ static Datum timestampWithoutTZ(@NotNull Timestamp value) {
static Datum date(@NotNull Date value) {
return new DatumDate(value);
}

/**
 * Returns the default comparator for {@link Datum}, intended for PartiQL's scalar comparison operator.
 * This is the same as calling {@link #comparator(boolean)} with {@code true}.
 * @return a comparator for {@link Datum} that orders null values first.
 */
@NotNull
static Comparator<Datum> comparator() {
    // The default ordering places nulls first.
    final boolean nullsFirst = true;
    return comparator(nullsFirst);
}

/**
 * Returns a comparator for {@link Datum}, intended for PartiQL's scalar comparison operator.
 * @param nullsFirst if true, nulls are ordered before non-null values; otherwise they are ordered after.
 * @return a new {@code DatumComparator.NullsFirst} when {@code nullsFirst} is true,
 * otherwise a new {@code DatumComparator.NullsLast}.
 */
@NotNull
static Comparator<Datum> comparator(boolean nullsFirst) {
    if (!nullsFirst) {
        return new DatumComparator.NullsLast();
    }
    return new DatumComparator.NullsFirst();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class DatumCollection implements Datum {
_type = type;
}

@NotNull
@Override
public Iterator<Datum> iterator() {
return _value.iterator();
Expand Down
Loading

0 comments on commit 0bbc039

Please sign in to comment.