Skip to content

Commit

Permalink
Adds support for COLL_AGGs
Browse files Browse the repository at this point in the history
  • Loading branch information
johnedquinn committed Feb 15, 2024
1 parent cf3cf5d commit 40a065f
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ internal object SqlBuiltins {
Fn_CHAR_LENGTH__STRING__INT,
Fn_CHAR_LENGTH__SYMBOL__INT,
Fn_CHAR_LENGTH__CLOB__INT,
Fn_COLL_AGG__BAG__ANY.ANY,
Fn_COLL_AGG__BAG__ANY.AVG,
Fn_COLL_AGG__BAG__ANY.COUNT,
Fn_COLL_AGG__BAG__ANY.EVERY,
Fn_COLL_AGG__BAG__ANY.GROUP_AS,
Fn_COLL_AGG__BAG__ANY.MAX,
Fn_COLL_AGG__BAG__ANY.MIN,
Fn_COLL_AGG__BAG__ANY.SOME,
Fn_COLL_AGG__BAG__ANY.SUM,
Fn_POS__INT8__INT8,
Fn_POS__INT16__INT16,
Fn_POS__INT32__INT32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

package org.partiql.spi.connector.sql.builtins

import org.partiql.spi.connector.sql.builtins.internal.AccumulatorCountStar
import org.partiql.spi.connector.sql.builtins.internal.AccumulatorCount
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.AggSignature
import org.partiql.spi.fn.FnExperimental
Expand All @@ -25,5 +25,5 @@ public object Agg_COUNT__ANY__INT32 : Agg {
isDecomposable = true
)

override fun accumulator(): Agg.Accumulator = AccumulatorCountStar()
override fun accumulator(): Agg.Accumulator = AccumulatorCount()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// ktlint-disable filename
@file:Suppress("ClassName")

package org.partiql.spi.connector.sql.builtins

import org.partiql.spi.connector.sql.builtins.internal.Accumulator
import org.partiql.spi.connector.sql.builtins.internal.AccumulatorAnySome
import org.partiql.spi.connector.sql.builtins.internal.AccumulatorAvg
import org.partiql.spi.connector.sql.builtins.internal.AccumulatorCount
import org.partiql.spi.connector.sql.builtins.internal.AccumulatorEvery
import org.partiql.spi.connector.sql.builtins.internal.AccumulatorGroupAs
import org.partiql.spi.connector.sql.builtins.internal.AccumulatorMax
import org.partiql.spi.connector.sql.builtins.internal.AccumulatorMin
import org.partiql.spi.connector.sql.builtins.internal.AccumulatorSum
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.Fn
import org.partiql.spi.fn.FnExperimental
import org.partiql.spi.fn.FnParameter
import org.partiql.spi.fn.FnSignature
import org.partiql.value.BagValue
import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.PartiQLValueType
import org.partiql.value.check

@OptIn(PartiQLValueExperimental::class, FnExperimental::class)
internal abstract class Fn_COLL_AGG__BAG__ANY : Fn {

abstract fun getAccumulator(): Agg.Accumulator

companion object {
@JvmStatic
internal fun createSignature(name: String) = FnSignature(
name = name,
returns = PartiQLValueType.ANY,
parameters = listOf(
FnParameter("value", PartiQLValueType.BAG),
),
isNullCall = true,
isNullable = true
)
}

override fun invoke(args: Array<PartiQLValue>): PartiQLValue {
val bag = args[0].check<BagValue<*>>()
val accumulator = getAccumulator()
bag.forEach { element -> accumulator.next(arrayOf(element)) }
return accumulator.value()
}

object SUM : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_sum")
override fun getAccumulator(): Accumulator = AccumulatorSum()
}

object AVG : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_avg")
override fun getAccumulator(): Accumulator = AccumulatorAvg()
}

object MIN : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_min")
override fun getAccumulator(): Accumulator = AccumulatorMin()
}

object MAX : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_max")
override fun getAccumulator(): Accumulator = AccumulatorMax()
}

object COUNT : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_count")
override fun getAccumulator(): Accumulator = AccumulatorCount()
}

object EVERY : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_every")
override fun getAccumulator(): Accumulator = AccumulatorEvery()
}

// TODO: Should we allow this?
object GROUP_AS : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_group_as")
override fun getAccumulator(): Accumulator = AccumulatorGroupAs()
}

object ANY : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_any")
override fun getAccumulator(): Accumulator = AccumulatorAnySome()
}

object SOME : Fn_COLL_AGG__BAG__ANY() {
override val signature = createSignature("coll_some")
override fun getAccumulator(): Accumulator = AccumulatorAnySome()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.partiql.spi.connector.sql.builtins.internal

import com.amazon.ion.Decimal
import org.partiql.errors.TypeCheckException
import org.partiql.spi.fn.Agg
import org.partiql.spi.fn.FnExperimental
import org.partiql.value.BoolValue
Expand Down Expand Up @@ -70,7 +71,7 @@ internal fun comparisonAccumulator(comparator: Comparator<PartiQLValue>): (Parti
@OptIn(PartiQLValueExperimental::class)
internal fun checkIsNumberType(funcName: String, value: PartiQLValue) {
if (!value.type.isNumber()) {
TODO("NEED TO HANDLE NUMBER TYPE CHECK FOR $value")
throw TypeCheckException("Expected NUMBER but received ${value.type}.")
}
}

Expand Down Expand Up @@ -123,7 +124,7 @@ private fun Long.checkOverflowPlus(other: Long): Number {
@OptIn(PartiQLValueExperimental::class)
internal fun checkIsBooleanType(funcName: String, value: PartiQLValue) {
if (value.type != PartiQLValueType.BOOL) {
TODO("NEED TO HANDLE")
throw TypeCheckException("Expected ${PartiQLValueType.BOOL} but received ${value.type}.")
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.partiql.spi.connector.sql.builtins.internal

import org.partiql.value.PartiQLValue
import org.partiql.value.PartiQLValueExperimental
import org.partiql.value.int64Value

@OptIn(PartiQLValueExperimental::class)
internal class AccumulatorCount : Accumulator() {

var count: Long = 0L

override fun nextValue(value: PartiQLValue) {
this.count += 1L
}

override fun value(): PartiQLValue = int64Value(count)
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ private val CONVERTERS = mapOf<Class<*>, (Number) -> Number>(
is Float -> bigDecimalOf(num)
is Double -> bigDecimalOf(num)
is BigDecimal -> bigDecimalOf(num)
is BigInteger -> bigDecimalOf(num)
else -> throw IllegalArgumentException(
"Unsupported number for decimal conversion: $num"
"Unsupported number for decimal conversion: $num (${num.javaClass.simpleName})"
)
}
}
Expand Down

0 comments on commit 40a065f

Please sign in to comment.