From 1a90eeb71ae8292410fc99d49c24fdafadcf9de7 Mon Sep 17 00:00:00 2001 From: Almog Gavra Date: Mon, 27 Apr 2020 12:01:22 -0700 Subject: [PATCH] fix: use schema in annotation as schema provider if present --- .../schema/ksql/SchemaConvertersTest.java | 63 +++++++++++++++++++ .../ksql/function/FunctionLoaderUtils.java | 25 +++++--- .../io/confluent/ksql/function/UdfLoader.java | 2 + .../confluent/ksql/function/UdtfLoader.java | 2 + .../ksql/function/UdfLoaderTest.java | 45 +++++++++++-- .../ksql/function/UdtfLoaderTest.java | 2 +- .../io/confluent/ksql/function/udf/Udf.java | 7 ++- 7 files changed, 128 insertions(+), 18 deletions(-) diff --git a/ksqldb-common/src/test/java/io/confluent/ksql/schema/ksql/SchemaConvertersTest.java b/ksqldb-common/src/test/java/io/confluent/ksql/schema/ksql/SchemaConvertersTest.java index 08280e2f7fd9..7c285ad3c47f 100644 --- a/ksqldb-common/src/test/java/io/confluent/ksql/schema/ksql/SchemaConvertersTest.java +++ b/ksqldb-common/src/test/java/io/confluent/ksql/schema/ksql/SchemaConvertersTest.java @@ -27,6 +27,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.confluent.ksql.function.types.ArrayType; +import io.confluent.ksql.function.types.MapType; +import io.confluent.ksql.function.types.ParamType; +import io.confluent.ksql.function.types.ParamTypes; +import io.confluent.ksql.function.types.StructType; import io.confluent.ksql.schema.ksql.types.SqlArray; import io.confluent.ksql.schema.ksql.types.SqlDecimal; import io.confluent.ksql.schema.ksql.types.SqlMap; @@ -96,6 +101,28 @@ public class SchemaConvertersTest { .put(SqlBaseType.STRUCT, Struct.class) .build(); + private static final BiMap SQL_TO_FUNCTION = ImmutableBiMap + .builder() + .put(SqlTypes.BOOLEAN, ParamTypes.BOOLEAN) + .put(SqlTypes.INTEGER, ParamTypes.INTEGER) + .put(SqlTypes.BIGINT, ParamTypes.LONG) + .put(SqlTypes.DOUBLE, ParamTypes.DOUBLE) + .put(SqlTypes.STRING, ParamTypes.STRING) + .put(SqlArray.of(SqlTypes.INTEGER), ArrayType.of(ParamTypes.INTEGER)) + .put(SqlDecimal.of(2, 1), ParamTypes.DECIMAL) + .put(SqlMap.of(SqlTypes.INTEGER), MapType.of(ParamTypes.INTEGER)) + .put(SqlStruct.builder() + .field("f0", SqlTypes.INTEGER) + .build(), + StructType.builder() + .field("f0", ParamTypes.INTEGER) + .build()) + .build(); + + private static final Set REQUIRES_SCHEMA_SPEC = ImmutableSet.of( + ParamTypes.DECIMAL + ); + private static final Schema STRUCT_LOGICAL_TYPE = SchemaBuilder.struct() .field("F0", SchemaBuilder.int32().optional().build()) .optional() @@ -277,4 +304,40 @@ public void shouldThrowOnUnknownJavaType() { // Then: assertThat(e.getMessage(), containsString("Unexpected java type: " + double.class)); } + + @Test + public void shouldCoverAllSqlToFunction() { + final Set tested = SQL_TO_FUNCTION.keySet().stream() + .map(SqlType::baseType) + .collect(Collectors.toSet()); + + final ImmutableSet allTypes = ImmutableSet.copyOf(SqlBaseType.values()); + + assertThat("If this test fails then there has been a new SQL type added and this test " + + "file needs updating to cover that new type", tested, is(allTypes)); + } + + @Test + public void shouldGetParamTypesForAllSqlTypes() { + for (final Entry entry : SQL_TO_FUNCTION.entrySet()) { + final SqlType sqlType = entry.getKey(); + final ParamType javaType = entry.getValue(); + final ParamType result = SchemaConverters.sqlToFunctionConverter().toFunctionType(sqlType); + assertThat(result, equalTo(javaType)); + } + } + + @Test + public void shouldGetSqlTypeForAllParamTypes() { + for (Entry entry : SQL_TO_FUNCTION.inverse().entrySet()) { + ParamType param = entry.getKey(); + if (REQUIRES_SCHEMA_SPEC.contains(param)) { + continue; + } + + SqlType sqlType = entry.getValue(); + assertThat(SchemaConverters.functionToSqlConverter().toSqlType(param), is(sqlType)); + } + } + } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java index d754ec8511ac..82189be1b20e 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/FunctionLoaderUtils.java @@ -17,9 +17,9 @@ import com.google.common.annotations.VisibleForTesting; import io.confluent.ksql.execution.function.UdfUtil; -import io.confluent.ksql.function.types.DecimalType; import io.confluent.ksql.function.types.GenericType; import io.confluent.ksql.function.types.ParamType; +import io.confluent.ksql.function.udf.Udf; import io.confluent.ksql.function.udf.UdfParameter; import io.confluent.ksql.function.udf.UdfSchemaProvider; import io.confluent.ksql.schema.ksql.SchemaConverters; @@ -180,17 +180,27 @@ static ParamType getReturnType( static SchemaProvider handleUdfReturnSchema( final Class theClass, final ParamType javaReturnSchema, + final String annotationSchema, + final SqlTypeParser parser, final String schemaProviderFunctionName, final String functionName, final boolean isVariadic ) { final Function, SqlType> schemaProvider; - if (!schemaProviderFunctionName.equals("")) { + if (!Udf.NO_SCHEMA_PROVIDER.equals(schemaProviderFunctionName)) { schemaProvider = handleUdfSchemaProviderAnnotation( schemaProviderFunctionName, theClass, functionName); - } else if (javaReturnSchema instanceof DecimalType) { - throw new KsqlException(String.format("Cannot load UDF %s. BigDecimal return type " - + "is not supported without a schema provider method.", functionName)); + } else if (!Udf.NO_SCHEMA.equals(annotationSchema)) { + schemaProvider = args -> parser.parse(annotationSchema).getSqlType(); + } else if (!GenericsUtil.hasGenerics(javaReturnSchema)) { + final SqlType sqlType; + try { + sqlType = SchemaConverters.functionToSqlConverter().toSqlType(javaReturnSchema); + } catch (final Exception e) { + throw new KsqlException("Cannot load UDF " + functionName + ". " + + javaReturnSchema + " return type is not supported without a schema annotation."); + } + schemaProvider = args -> sqlType; } else { schemaProvider = null; } @@ -210,10 +220,6 @@ static SchemaProvider handleUdfReturnSchema( return returnType; } - if (!GenericsUtil.hasGenerics(javaReturnSchema)) { - return SchemaConverters.functionToSqlConverter().toSqlType(javaReturnSchema); - } - final Map genericMapping = new HashMap<>(); for (int i = 0; i < Math.min(parameters.size(), arguments.size()); i++) { final ParamType schema = parameters.get(i); @@ -283,5 +289,4 @@ private static SqlType invokeSchemaProviderMethod( ), e); } } - } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java index 2eb6f4875fbc..c81fae6c6c47 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdfLoader.java @@ -151,6 +151,8 @@ private KsqlScalarFunction createFunction( .handleUdfReturnSchema( theClass, javaReturnSchema, + udfAnnotation.schema(), + typeParser, udfAnnotation.schemaProvider(), udfDescriptionAnnotation.name(), method.isVarArgs() diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java index 3dc6bec4bfc8..e99c4ab6160c 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/function/UdtfLoader.java @@ -142,6 +142,8 @@ private KsqlTableFunction createTableFunction( .handleUdfReturnSchema( method.getDeclaringClass(), outputType, + udtfAnnotation.schema(), + typeParser, udtfAnnotation.schemaProvider(), functionName.name(), method.isVarArgs()); diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java index 522fee0caf89..180d1f8f67b0 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdfLoaderTest.java @@ -61,6 +61,7 @@ import java.io.File; import java.lang.reflect.Field; import java.math.BigDecimal; +import java.math.RoundingMode; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -74,6 +75,7 @@ import org.apache.kafka.common.metrics.KafkaMetric; import org.apache.kafka.common.metrics.Metrics; import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.connect.data.Decimal; import org.apache.kafka.connect.data.Schema; import org.apache.kafka.connect.data.SchemaBuilder; import org.apache.kafka.connect.data.Struct; @@ -254,6 +256,21 @@ public void shouldLoadFunctionWithSchemaProvider() { assertThat(function.getReturnType(args), equalTo(decimal)); } + @Test + public void shouldLoadFunctionWithNestedDecimalSchema() { + // Given: + final UdfFactory returnDecimal = FUNC_REG.getUdfFactory(FunctionName.of("decimalstruct")); + + // When: + final KsqlScalarFunction function = returnDecimal.getFunction(ImmutableList.of()); + + // Then: + assertThat( + function.getReturnType(ImmutableList.of()), + equalTo(SqlStruct.builder().field("VAL", SqlDecimal.of(64, 2)).build())); + } + + @Test public void shouldThrowOnReturnTypeMismatch() { // Given: @@ -299,8 +316,8 @@ public void shouldThrowOnMissingAnnotation() throws ClassNotFoundException { // Then: assertThat(e.getMessage(), containsString( - "Cannot load UDF MissingAnnotation. BigDecimal return type " - + "is not supported without a schema provider method.")); + "Cannot load UDF MissingAnnotation. DECIMAL return type is " + + "not supported without a schema annotation.")); } @@ -358,9 +375,8 @@ public void shouldThrowOnReturnDecimalWithoutSchemaProvider() throws ClassNotFou // Then: assertThat(e.getMessage(), containsString( - "Cannot load UDF ReturnDecimalWithoutSchemaProvider. " - + "BigDecimal return type is not supported without a " - + "schema provider method.")); + "Cannot load UDF ReturnDecimalWithoutSchemaProvider. DECIMAL return type is not " + + "supported without a schema annotation.")); } @Test @@ -1356,6 +1372,25 @@ public SqlType provideSchema(final List params) { } } + @UdfDescription( + name = "DecimalStruct", + description = "A test-only UDF for testing nested DECIMAL in schema annotation") + public static class DecimalStructUdf { + + @Udf(schema = "STRUCT") + public Struct getDecimalStruct() { + final Schema schema = SchemaBuilder.struct() + .optional() + .field("VAL", + Decimal.builder(2).optional().parameter("connect.decimal.precision", "64").build()) + .build(); + + Struct struct = new Struct(schema); + struct.put("VAL", BigDecimal.valueOf(123.45).setScale(2, RoundingMode.CEILING)); + return struct; + } + } + @SuppressWarnings({"unused", "MethodMayBeStatic"}) // Invoked via reflection in test. @UdfDescription( name = "ReturnIncompatible", diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java index dcc6eace63ee..77dfaca992a5 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/UdtfLoaderTest.java @@ -287,7 +287,7 @@ functionRegistry, empty(), typeParser, true // Then: assertThat(e.getMessage(), containsString( - "Cannot load UDF bigDecimalNoSchemaProvider. BigDecimal return type is not supported without a schema provider method.")); + "Cannot load UDF bigDecimalNoSchemaProvider. DECIMAL return type is not supported without a schema annotation")); } @UdtfDescription(name = "badReturnUdtf", description = "whatever") diff --git a/ksqldb-udf/src/main/java/io/confluent/ksql/function/udf/Udf.java b/ksqldb-udf/src/main/java/io/confluent/ksql/function/udf/Udf.java index d0d2488ac603..e383356a597d 100644 --- a/ksqldb-udf/src/main/java/io/confluent/ksql/function/udf/Udf.java +++ b/ksqldb-udf/src/main/java/io/confluent/ksql/function/udf/Udf.java @@ -29,6 +29,9 @@ @Target(ElementType.METHOD) public @interface Udf { + String NO_SCHEMA = ""; + String NO_SCHEMA_PROVIDER = ""; + /** * The function description. * @@ -45,11 +48,11 @@ * the return value itself. For complex return types (e.g. {@code Struct} types), * this is required and will fail if not supplied. */ - String schema() default ""; + String schema() default NO_SCHEMA; /** * The name of the method that provides the return type of the UDF. * @return the name of the other method */ - String schemaProvider() default ""; + String schemaProvider() default NO_SCHEMA_PROVIDER; }