Skip to content

Commit

Permalink
feat: allow for decimals to be used as input types for UDFs (#3217)
Browse files Browse the repository at this point in the history
  • Loading branch information
agavra authored Aug 15, 2019
1 parent 11618d3 commit 4a2e4b9
Show file tree
Hide file tree
Showing 13 changed files with 171 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.confluent.ksql.schema.connect.SqlSchemaFormatter;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import java.util.ArrayList;
import java.util.Collection;
Expand Down Expand Up @@ -262,6 +263,7 @@ static final class Parameter {
.put(Type.MAP, Parameter::mapEquals)
.put(Type.ARRAY, Parameter::arrayEquals)
.put(Type.STRUCT, Parameter::structEquals)
.put(Type.BYTES, Parameter::bytesEquals)
.build();

private final Schema schema;
Expand Down Expand Up @@ -316,7 +318,6 @@ boolean accepts(final Schema argument, final Map<Schema, Schema> reservedGeneric
return Objects.equals(type, argument.type())
&& CUSTOM_SCHEMA_EQ.getOrDefault(type, (a, b) -> true).test(schema, argument)
&& Objects.equals(schema.version(), argument.version())
&& Objects.equals(schema.parameters(), argument.parameters())
&& Objects.deepEquals(schema.defaultValue(), argument.defaultValue());
}
// CHECKSTYLE_RULES.ON: BooleanExpressionComplexity
Expand Down Expand Up @@ -356,6 +357,13 @@ private static boolean structEquals(final Schema structA, final Schema structB)
|| Objects.equals(structA.fields(), structB.fields());
}

/**
 * Compares two BYTES schemas for the purpose of UDF parameter matching.
 *
 * <p>Only decimals are considered: any two decimal schemas match each other, since the
 * check looks solely at whether each side is a decimal. Any other bytes schema on either
 * side yields {@code false}.
 *
 * @param bytesA the first bytes schema
 * @param bytesB the second bytes schema
 * @return {@code true} only if both schemas are decimals
 */
private static boolean bytesEquals(final Schema bytesA, final Schema bytesB) {
  // A non-decimal on the left settles the answer without inspecting the right side.
  if (!DecimalUtil.isDecimal(bytesA)) {
    return false;
  }
  return DecimalUtil.isDecimal(bytesB);
}

@Override
public String toString() {
return FORMATTER.format(schema) + (isVararg ? "(VARARG)" : "");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ protected String visitBytes(final Schema schema) {
+ DecimalUtil.scale(schema) + ")";
}

throw new KsqlException("Cannot format bytes type: " + schema);
return "BYTES";
}

private final class Converter implements SchemaWalker.Visitor<String, String> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import java.util.Arrays;
Expand All @@ -30,6 +31,8 @@ public class UdfIndexTest {
private static final Schema STRUCT3_PERMUTE = SchemaBuilder.struct().field("d", INT).field("c", INT).build();
private static final Schema MAP1 = SchemaBuilder.map(STRING, STRING).build();
private static final Schema MAP2 = SchemaBuilder.map(STRING, INT).build();
private static final Schema DECIMAL1 = DecimalUtil.builder(2, 1).build();
private static final Schema DECIMAL2 = DecimalUtil.builder(3, 1).build();

private static final Schema GENERIC_LIST = GenericsUtil.array("T").build();
private static final Schema STRING_LIST = SchemaBuilder.array(STRING).build();
Expand Down Expand Up @@ -156,6 +159,32 @@ public void shouldChooseCorrectMap() {
assertThat(fun.getFunctionName(), equalTo(EXPECTED));
}

/** A function registered with a decimal parameter resolves for that same decimal schema. */
@Test
public void shouldChooseCorrectDecimal() {
  // Given:
  udfIndex.addFunction(function(EXPECTED, false, DECIMAL1));

  // When:
  final KsqlFunction resolved = udfIndex.getFunction(ImmutableList.of(DECIMAL1));

  // Then:
  assertThat(resolved.getFunctionName(), equalTo(EXPECTED));
}

/** A function registered with DECIMAL1 also resolves for the differently built DECIMAL2. */
@Test
public void shouldAllowAnyDecimal() {
  // Given:
  udfIndex.addFunction(function(EXPECTED, false, DECIMAL1));

  // When:
  final KsqlFunction resolved = udfIndex.getFunction(ImmutableList.of(DECIMAL2));

  // Then:
  assertThat(resolved.getFunctionName(), equalTo(EXPECTED));
}

@Test
public void shouldChooseCorrectPermutedStruct() {
// Given:
Expand Down Expand Up @@ -683,6 +712,20 @@ public void shouldNotMatchNestedGenericMethodWithAlreadyReservedTypes() {
udfIndex.getFunction(ImmutableList.of(INT_LIST, STRING_LIST));
}

/** A plain (non-decimal) bytes argument must not resolve to a decimal-typed function. */
@Test
public void shouldNotFindArbitraryBytesTypes() {
  // Given:
  udfIndex.addFunction(function(EXPECTED, false, DECIMAL1));

  // Expect:
  expectedException.expect(KsqlException.class);
  expectedException.expectMessage(is("Function 'name' does not accept parameters of types:"
      + "[BYTES]"));

  // When:
  final Schema plainBytes = SchemaBuilder.bytes().build();
  udfIndex.getFunction(ImmutableList.of(plainBytes));
}

private static KsqlFunction function(
final String name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,19 @@ public void shouldFormatDecimal() {
assertThat(STRICT.format(DecimalUtil.builder(2, 1).build()), is("DECIMAL(2, 1)"));
}

/** Optional bytes render as plain {@code BYTES} under both the default and strict formatters. */
@Test
public void shouldFormatOptionalBytes() {
  final Schema optionalBytes = Schema.OPTIONAL_BYTES_SCHEMA;

  assertThat(DEFAULT.format(optionalBytes), is("BYTES"));
  assertThat(STRICT.format(optionalBytes), is("BYTES"));
}


/** Required bytes gain a {@code NOT NULL} suffix only under the strict formatter. */
@Test
public void shouldFormatBytes() {
  final Schema requiredBytes = Schema.BYTES_SCHEMA;

  assertThat(DEFAULT.format(requiredBytes), is("BYTES"));
  assertThat(STRICT.format(requiredBytes), is("BYTES NOT NULL"));
}

@Test
public void shouldFormatArray() {
// Given:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import io.confluent.ksql.function.udf.json.JsonExtractStringKudf;
import io.confluent.ksql.function.udf.math.AbsKudf;
import io.confluent.ksql.function.udf.math.CeilKudf;
import io.confluent.ksql.function.udf.math.FloorKudf;
import io.confluent.ksql.function.udf.math.RandomKudf;
import io.confluent.ksql.function.udf.math.RoundKudf;
import io.confluent.ksql.function.udf.string.ConcatKudf;
Expand Down Expand Up @@ -238,12 +237,6 @@ private void addMathFunctions() {
"CEIL",
CeilKudf.class));

addBuiltInFunction(KsqlFunction.createLegacyBuiltIn(
Schema.OPTIONAL_FLOAT64_SCHEMA,
Collections.singletonList(Schema.OPTIONAL_FLOAT64_SCHEMA),
"FLOOR",
FloorKudf.class));

addBuiltInFunction(KsqlFunction.createLegacyBuiltIn(
Schema.OPTIONAL_INT64_SCHEMA,
Collections.singletonList(Schema.OPTIONAL_FLOAT64_SCHEMA),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
package io.confluent.ksql.function;

import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
Expand All @@ -40,6 +42,10 @@ public final class UdfUtil {
.put(long.class, SchemaBuilder::int64)
.put(Double.class, () -> SchemaBuilder.float64().optional())
.put(double.class, SchemaBuilder::float64)
// From the UDF perspective, all DECIMAL schemas map to the same Java type (BigDecimal),
// so we arbitrarily choose DECIMAL(1,0). If we migrate to a type system dedicated
// to UDFs, we can update this to be a "generic decimal".
.put(BigDecimal.class, () -> DecimalUtil.builder(1, 0).optional())
.build();

private UdfUtil() {
Expand Down Expand Up @@ -93,7 +99,9 @@ static Schema getSchemaFromType(final Type type, final String name, final String
schema = GenericsUtil.generic(((TypeVariable) type).getName());
} else {
schema = typeToSchema.getOrDefault(type, () -> handleParametrizedType(type)).get();
schema.name(name);
if (schema.name() == null) {
schema.name(name);
}
}

schema.doc(doc);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2019 Confluent Inc.
*
* Licensed under the Confluent Community License (the "License"); you may not use
* this file except in compliance with the License. You may obtain a copy of the
* License at
*
* http://www.confluent.io/confluent-community-license
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*/

package io.confluent.ksql.function.udf.math;

import io.confluent.ksql.function.udf.Udf;
import io.confluent.ksql.function.udf.UdfDescription;
import io.confluent.ksql.function.udf.UdfParameter;
import java.math.BigDecimal;
import java.math.RoundingMode;

/**
 * UDF that floors a numeric value, with one overload per supported input type
 * ({@code Integer}, {@code Long}, {@code Double} and {@code BigDecimal}).
 *
 * <p>Every overload returns a {@code Double} and maps a {@code null} argument to
 * {@code null}.
 */
@UdfDescription(name = "Floor", description = Floor.DESCRIPTION)
public class Floor {

  static final String DESCRIPTION = "Returns the largest integer less than or equal to the "
      + "specified numeric expression. NOTE: for backwards compatibility, this returns a DOUBLE "
      + "that has a mantissa of zero.";

  /**
   * Floors an integer value.
   *
   * @param val the value to floor, may be {@code null}
   * @return {@code val} as a double, or {@code null} if {@code val} is {@code null}
   */
  @Udf
  public Double floor(@UdfParameter final Integer val) {
    return (val == null) ? null : Math.floor(val);
  }

  /**
   * Floors a long value.
   *
   * @param val the value to floor, may be {@code null}
   * @return {@code val} as a double, or {@code null} if {@code val} is {@code null}
   */
  @Udf
  public Double floor(@UdfParameter final Long val) {
    return (val == null) ? null : Math.floor(val);
  }

  /**
   * Floors a double value.
   *
   * @param val the value to floor, may be {@code null}
   * @return the largest integral double not greater than {@code val}, or {@code null} if
   *     {@code val} is {@code null}
   */
  @Udf
  public Double floor(@UdfParameter final Double val) {
    return (val == null) ? null : Math.floor(val);
  }

  /**
   * Floors a decimal value.
   *
   * @param val the value to floor, may be {@code null}
   * @return the floor of {@code val} converted to a double, or {@code null} if {@code val}
   *     is {@code null}
   */
  @Udf
  public Double floor(@UdfParameter final BigDecimal val) {
    // Floor in decimal space BEFORE narrowing to double. Converting first can round the
    // value up across an integer boundary (e.g. 0.99999999999999999999 becomes 1.0), which
    // would make the result larger than the input.
    return (val == null) ? null : val.setScale(0, RoundingMode.FLOOR).doubleValue();
  }

}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ public void shouldHaveBuiltInUDFRegistered() {
// String UDF
"LCASE", "UCASE", "CONCAT", "TRIM", "IFNULL", "LEN",
// Math UDF
"ABS", "CEIL", "FLOOR", "ROUND", "RANDOM",
"ABS", "CEIL", "ROUND", "RANDOM",
// JSON UDF
"EXTRACTJSONFIELD", "ARRAYCONTAINS",
// Struct UDF
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalToIgnoringCase;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
Expand All @@ -34,6 +35,7 @@
import io.confluent.ksql.function.udf.Udf;
import io.confluent.ksql.function.udf.UdfDescription;
import io.confluent.ksql.function.udf.UdfParameter;
import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlException;
import java.io.File;
Expand Down Expand Up @@ -133,6 +135,19 @@ public void shouldLoadStructUdafs() {
equalTo(new Struct(schema).put("A", 1).put("B", 2)));
}

/** The "floor" UDF factory resolves a function for an optional decimal argument. */
@Test
public void shouldLoadDecimalUdfs() {
  // Given:
  final Schema decimalSchema = DecimalUtil.builder(2, 1).optional().build();
  final UdfFactory floorFactory = FUNC_REG.getUdfFactory("floor");

  // When:
  final KsqlFunction resolved = floorFactory.getFunction(ImmutableList.of(decimalSchema));

  // Then:
  assertThat(resolved.getFunctionName(), equalToIgnoringCase("floor"));
}

@Test
public void shouldLoadFunctionsFromJarsInPluginDir() {
final UdfFactory toString = FUNC_REG.getUdfFactory("tostring");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;

import io.confluent.ksql.util.DecimalUtil;
import io.confluent.ksql.util.KsqlException;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.util.List;
import java.util.Map;
import org.apache.kafka.connect.data.Schema;
Expand Down Expand Up @@ -113,6 +115,12 @@ public void shouldGetFloatSchemaForDoublePrimitiveClass() {
equalTo(Schema.FLOAT64_SCHEMA));
}

/** The schema derived for {@code BigDecimal} carries the same name as a decimal schema. */
@Test
public void shouldGetDecimalSchemaForBigDecimalClass() {
  // When:
  final String actualName = UdfUtil.getSchemaFromType(BigDecimal.class).name();

  // Then: only the schema name is compared, not precision or scale
  final String decimalName = DecimalUtil.builder(2, 1).name();
  assertThat(actualName, is(decimalName));
}

@Test
public void shouldGetMapSchemaFromMapClass() throws NoSuchMethodException {
final Type type = getClass().getDeclaredMethod("mapType", Map.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.hamcrest.Matchers.is;

import io.confluent.ksql.function.InternalFunctionRegistry;
import io.confluent.ksql.function.TestFunctionRegistry;
import io.confluent.ksql.metastore.MetaStore;
import io.confluent.ksql.metastore.model.DataSource;
import io.confluent.ksql.metastore.model.MetaStoreMatchers.OptionalMatchers;
Expand Down Expand Up @@ -50,7 +51,7 @@ public class LogicalPlannerTest {

@Before
public void init() {
metaStore = MetaStoreFixture.getNewMetaStore(new InternalFunctionRegistry());
metaStore = MetaStoreFixture.getNewMetaStore(TestFunctionRegistry.INSTANCE.get());
ksqlConfig = new KsqlConfig(Collections.emptyMap());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,25 @@
{"topic": "OUTPUT", "value": {"I": 0, "L": 0, "D": 0}},
{"topic": "OUTPUT", "value": {"I": 1, "L": 1, "D": 1}}
]
},
{
"name": "floor",
"statements": [
"CREATE STREAM INPUT (i INT, l BIGINT, d DOUBLE, b DECIMAL(2,1)) WITH (kafka_topic='input', value_format='AVRO');",
"CREATE STREAM OUTPUT AS SELECT floor(i) i, floor(l) l, floor(d) d, floor(b) b FROM INPUT;"
],
"inputs": [
{"topic": "input", "value": {"i": null, "l": null, "d": null}},
{"topic": "input", "value": {"i": -1, "l": -2, "d": -3.1, "b": "-3.1"}},
{"topic": "input", "value": {"i": 0, "l": 0, "d": 0.0, "b": "0.0"}},
{"topic": "input", "value": {"i": 1, "l": 2, "d": 3.1, "b": "3.1"}}
],
"outputs": [
{"topic": "OUTPUT", "value": {"I": null, "L": null, "D": null, "B": null}},
{"topic": "OUTPUT", "value": {"I": -1.0, "L": -2.0, "D": -4.0, "B": -4.0}},
{"topic": "OUTPUT", "value": {"I": 0.0, "L": 0.0, "D": 0.0, "B": 0.0}},
{"topic": "OUTPUT", "value": {"I": 1.0, "L": 2.0, "D": 3.0, "B": 3.0}}
]
}
]
}

0 comments on commit 4a2e4b9

Please sign in to comment.