From 603af2dde9b0c93dc68bf8e9b7b55923a30bc554 Mon Sep 17 00:00:00 2001 From: "R. C. Howell" Date: Wed, 3 Jan 2024 12:53:30 -0800 Subject: [PATCH] Adds eq and not --- .../partiql/plugin/internal/fn/scalar/FnEq.kt | 243 ++++++++++++++++-- .../plugin/internal/fn/scalar/FnNot.kt | 14 +- .../plugin/internal/fn/scalar/FnScalarTest.kt | 2 +- 3 files changed, 228 insertions(+), 31 deletions(-) diff --git a/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnEq.kt b/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnEq.kt index 86c81a4d8e..2a7933cec3 100644 --- a/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnEq.kt +++ b/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnEq.kt @@ -7,6 +7,24 @@ import org.partiql.spi.function.PartiQLFunction import org.partiql.spi.function.PartiQLFunctionExperimental import org.partiql.types.function.FunctionParameter import org.partiql.types.function.FunctionSignature +import org.partiql.value.BagValue +import org.partiql.value.BinaryValue +import org.partiql.value.BlobValue +import org.partiql.value.BoolValue +import org.partiql.value.ByteValue +import org.partiql.value.CharValue +import org.partiql.value.ClobValue +import org.partiql.value.DateValue +import org.partiql.value.DecimalValue +import org.partiql.value.Float32Value +import org.partiql.value.Float64Value +import org.partiql.value.Int16Value +import org.partiql.value.Int32Value +import org.partiql.value.Int64Value +import org.partiql.value.Int8Value +import org.partiql.value.IntValue +import org.partiql.value.IntervalValue +import org.partiql.value.ListValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType.ANY @@ -37,6 +55,14 @@ import org.partiql.value.PartiQLValueType.STRUCT import org.partiql.value.PartiQLValueType.SYMBOL import org.partiql.value.PartiQLValueType.TIME import org.partiql.value.PartiQLValueType.TIMESTAMP +import org.partiql.value.SexpValue +import org.partiql.value.StringValue +import org.partiql.value.StructValue +import org.partiql.value.SymbolValue +import org.partiql.value.TimeValue +import org.partiql.value.TimestampValue +import org.partiql.value.boolValue +import org.partiql.value.check @OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) internal object Fn_EQ__ANY_ANY__BOOL : PartiQLFunction.Scalar { @@ -52,8 +78,15 @@ internal object Fn_EQ__ANY_ANY__BOOL : PartiQLFunction.Scalar { isNullable = false, ) + // TODO ANY, ANY equals not clearly defined at the moment. override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0] + val rhs = args[1] + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -72,7 +105,13 @@ internal object Fn_EQ__BOOL_BOOL__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -91,7 +130,13 @@ internal object Fn_EQ__INT8_INT8__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -110,7 +155,13 @@ internal object Fn_EQ__INT16_INT16__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -129,7 +180,13 @@ internal object Fn_EQ__INT32_INT32__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -148,7 +205,13 @@ internal object Fn_EQ__INT64_INT64__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -167,7 +230,13 @@ internal object Fn_EQ__INT_INT__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -186,7 +255,13 @@ internal object Fn_EQ__DECIMAL_DECIMAL__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -205,7 +280,13 @@ internal object Fn_EQ__DECIMAL_ARBITRARY_DECIMAL_ARBITRARY__BOOL : PartiQLFuncti ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -224,7 +305,13 @@ internal object Fn_EQ__FLOAT32_FLOAT32__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -243,7 +330,13 @@ internal object Fn_EQ__FLOAT64_FLOAT64__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -262,7 +355,13 @@ internal object Fn_EQ__CHAR_CHAR__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -281,7 +380,13 @@ internal object Fn_EQ__STRING_STRING__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -300,7 +405,13 @@ internal object Fn_EQ__SYMBOL_SYMBOL__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -319,7 +430,13 @@ internal object Fn_EQ__BINARY_BINARY__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -338,7 +455,13 @@ internal object Fn_EQ__BYTE_BYTE__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -357,7 +480,13 @@ internal object Fn_EQ__BLOB_BLOB__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -376,7 +505,13 @@ internal object Fn_EQ__CLOB_CLOB__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -395,7 +530,13 @@ internal object Fn_EQ__DATE_DATE__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -414,7 +555,13 @@ internal object Fn_EQ__TIME_TIME__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -433,7 +580,13 @@ internal object Fn_EQ__TIMESTAMP_TIMESTAMP__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -452,7 +605,13 @@ internal object Fn_EQ__INTERVAL_INTERVAL__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check() + val rhs = args[1].check() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -471,7 +630,13 @@ internal object Fn_EQ__BAG_BAG__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check>() + val rhs = args[1].check>() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -490,7 +655,13 @@ internal object Fn_EQ__LIST_LIST__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check>() + val rhs = args[1].check>() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -509,7 +680,13 @@ internal object Fn_EQ__SEXP_SEXP__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check>() + val rhs = args[1].check>() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -528,7 +705,13 @@ internal object Fn_EQ__STRUCT_STRUCT__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0].check>() + val rhs = args[1].check>() + return if (lhs.isNull || rhs.isNull) { + boolValue(null) + } else { + boolValue(lhs == rhs) + } } } @@ -546,8 +729,11 @@ internal object Fn_EQ__NULL_NULL__BOOL : PartiQLFunction.Scalar { isNullable = false, ) + // TODO how does null comparison work? ie null.null == null.null or int8.null == null.null ?? override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + val lhs = args[0] + val rhs = args[1] + return boolValue(lhs.isNull == rhs.isNull) } } @@ -565,7 +751,8 @@ internal object Fn_EQ__MISSING_MISSING__BOOL : PartiQLFunction.Scalar { isNullable = false, ) + // TODO how does `=` work with MISSING? As of now, always false. override fun invoke(args: Array): PartiQLValue { - TODO("Function eq not implemented") + return boolValue(false) } } diff --git a/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnNot.kt b/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnNot.kt index 7ae8b1d919..9615880a50 100644 --- a/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnNot.kt +++ b/plugins/partiql-plugin/src/main/kotlin/org/partiql/plugin/internal/fn/scalar/FnNot.kt @@ -3,14 +3,18 @@ package org.partiql.plugin.internal.fn.scalar +import org.partiql.errors.TypeCheckException import org.partiql.spi.function.PartiQLFunction import org.partiql.spi.function.PartiQLFunctionExperimental import org.partiql.types.function.FunctionParameter import org.partiql.types.function.FunctionSignature +import org.partiql.value.BoolValue import org.partiql.value.PartiQLValue import org.partiql.value.PartiQLValueExperimental import org.partiql.value.PartiQLValueType.BOOL import org.partiql.value.PartiQLValueType.MISSING +import org.partiql.value.boolValue +import org.partiql.value.check @OptIn(PartiQLValueExperimental::class, PartiQLFunctionExperimental::class) internal object Fn_NOT__BOOL__BOOL : PartiQLFunction.Scalar { @@ -24,7 +28,12 @@ internal object Fn_NOT__BOOL__BOOL : PartiQLFunction.Scalar { ) override fun invoke(args: Array): PartiQLValue { - TODO("Function not not implemented") + val value = args[0].check().value + return if (value == null) { + boolValue(null) + } else { + boolValue(!value) + } } } @@ -39,7 +48,8 @@ internal object Fn_NOT__MISSING__BOOL : PartiQLFunction.Scalar { isNullable = false, ) + // TODO determine what this behavior should be override fun invoke(args: Array): PartiQLValue { - TODO("Function not not implemented") + throw TypeCheckException() } } diff --git a/plugins/partiql-plugin/src/test/kotlin/org/partiql/plugin/internal/fn/scalar/FnScalarTest.kt b/plugins/partiql-plugin/src/test/kotlin/org/partiql/plugin/internal/fn/scalar/FnScalarTest.kt index d4c94e8f0d..0bdd678c46 100644 --- a/plugins/partiql-plugin/src/test/kotlin/org/partiql/plugin/internal/fn/scalar/FnScalarTest.kt +++ b/plugins/partiql-plugin/src/test/kotlin/org/partiql/plugin/internal/fn/scalar/FnScalarTest.kt @@ -24,7 +24,7 @@ class FnScalarTest { @JvmStatic fun fnGt() = listOf( - FnGt0.tests { + Fn_GT__INT8_INT8__BOOL.tests { case(arrayOf(int8Value(1), int8Value(0)), result = boolValue(true)) case(arrayOf(int8Value(0), int8Value(1)), result = boolValue(false)) case(arrayOf(int8Value(0), int8Value(0)), result = boolValue(false))