Skip to content

Commit

Permalink
fix: use schema in annotation as schema provider if present
Browse files Browse the repository at this point in the history
  • Loading branch information
agavra committed Apr 28, 2020
1 parent abd1392 commit 1a90eeb
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.MapType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.ParamTypes;
import io.confluent.ksql.function.types.StructType;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlDecimal;
import io.confluent.ksql.schema.ksql.types.SqlMap;
Expand Down Expand Up @@ -96,6 +101,28 @@ public class SchemaConvertersTest {
.put(SqlBaseType.STRUCT, Struct.class)
.build();

private static final BiMap<SqlType, ParamType> SQL_TO_FUNCTION = ImmutableBiMap
.<SqlType, ParamType>builder()
.put(SqlTypes.BOOLEAN, ParamTypes.BOOLEAN)
.put(SqlTypes.INTEGER, ParamTypes.INTEGER)
.put(SqlTypes.BIGINT, ParamTypes.LONG)
.put(SqlTypes.DOUBLE, ParamTypes.DOUBLE)
.put(SqlTypes.STRING, ParamTypes.STRING)
.put(SqlArray.of(SqlTypes.INTEGER), ArrayType.of(ParamTypes.INTEGER))
.put(SqlDecimal.of(2, 1), ParamTypes.DECIMAL)
.put(SqlMap.of(SqlTypes.INTEGER), MapType.of(ParamTypes.INTEGER))
.put(SqlStruct.builder()
.field("f0", SqlTypes.INTEGER)
.build(),
StructType.builder()
.field("f0", ParamTypes.INTEGER)
.build())
.build();

private static final Set<ParamType> REQUIRES_SCHEMA_SPEC = ImmutableSet.of(
ParamTypes.DECIMAL
);

private static final Schema STRUCT_LOGICAL_TYPE = SchemaBuilder.struct()
.field("F0", SchemaBuilder.int32().optional().build())
.optional()
Expand Down Expand Up @@ -277,4 +304,40 @@ public void shouldThrowOnUnknownJavaType() {
// Then:
assertThat(e.getMessage(), containsString("Unexpected java type: " + double.class));
}

@Test
public void shouldCoverAllSqlToFunction() {
final Set<SqlBaseType> tested = SQL_TO_FUNCTION.keySet().stream()
.map(SqlType::baseType)
.collect(Collectors.toSet());

final ImmutableSet<SqlBaseType> allTypes = ImmutableSet.copyOf(SqlBaseType.values());

assertThat("If this test fails then there has been a new SQL type added and this test "
+ "file needs updating to cover that new type", tested, is(allTypes));
}

@Test
public void shouldGetParamTypesForAllSqlTypes() {
for (final Entry<SqlType, ParamType> entry : SQL_TO_FUNCTION.entrySet()) {
final SqlType sqlType = entry.getKey();
final ParamType javaType = entry.getValue();
final ParamType result = SchemaConverters.sqlToFunctionConverter().toFunctionType(sqlType);
assertThat(result, equalTo(javaType));
}
}

@Test
public void shouldGetSqlTypeForAllParamTypes() {
for (Entry<ParamType, SqlType> entry : SQL_TO_FUNCTION.inverse().entrySet()) {
ParamType param = entry.getKey();
if (REQUIRES_SCHEMA_SPEC.contains(param)) {
continue;
}

SqlType sqlType = entry.getValue();
assertThat(SchemaConverters.functionToSqlConverter().toSqlType(param), is(sqlType));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

import com.google.common.annotations.VisibleForTesting;
import io.confluent.ksql.execution.function.UdfUtil;
import io.confluent.ksql.function.types.DecimalType;
import io.confluent.ksql.function.types.GenericType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.udf.Udf;
import io.confluent.ksql.function.udf.UdfParameter;
import io.confluent.ksql.function.udf.UdfSchemaProvider;
import io.confluent.ksql.schema.ksql.SchemaConverters;
Expand Down Expand Up @@ -180,17 +180,27 @@ static ParamType getReturnType(
static SchemaProvider handleUdfReturnSchema(
final Class theClass,
final ParamType javaReturnSchema,
final String annotationSchema,
final SqlTypeParser parser,
final String schemaProviderFunctionName,
final String functionName,
final boolean isVariadic
) {
final Function<List<SqlType>, SqlType> schemaProvider;
if (!schemaProviderFunctionName.equals("")) {
if (!Udf.NO_SCHEMA_PROVIDER.equals(schemaProviderFunctionName)) {
schemaProvider = handleUdfSchemaProviderAnnotation(
schemaProviderFunctionName, theClass, functionName);
} else if (javaReturnSchema instanceof DecimalType) {
throw new KsqlException(String.format("Cannot load UDF %s. BigDecimal return type "
+ "is not supported without a schema provider method.", functionName));
} else if (!Udf.NO_SCHEMA.equals(annotationSchema)) {
schemaProvider = args -> parser.parse(annotationSchema).getSqlType();
} else if (!GenericsUtil.hasGenerics(javaReturnSchema)) {
final SqlType sqlType;
try {
sqlType = SchemaConverters.functionToSqlConverter().toSqlType(javaReturnSchema);
} catch (final Exception e) {
throw new KsqlException("Cannot load UDF " + functionName + ". "
+ javaReturnSchema + " return type is not supported without a schema annotation.");
}
schemaProvider = args -> sqlType;
} else {
schemaProvider = null;
}
Expand All @@ -210,10 +220,6 @@ static SchemaProvider handleUdfReturnSchema(
return returnType;
}

if (!GenericsUtil.hasGenerics(javaReturnSchema)) {
return SchemaConverters.functionToSqlConverter().toSqlType(javaReturnSchema);
}

final Map<GenericType, SqlType> genericMapping = new HashMap<>();
for (int i = 0; i < Math.min(parameters.size(), arguments.size()); i++) {
final ParamType schema = parameters.get(i);
Expand Down Expand Up @@ -283,5 +289,4 @@ private static SqlType invokeSchemaProviderMethod(
), e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ private KsqlScalarFunction createFunction(
.handleUdfReturnSchema(
theClass,
javaReturnSchema,
udfAnnotation.schema(),
typeParser,
udfAnnotation.schemaProvider(),
udfDescriptionAnnotation.name(),
method.isVarArgs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,8 @@ private KsqlTableFunction createTableFunction(
.handleUdfReturnSchema(
method.getDeclaringClass(),
outputType,
udtfAnnotation.schema(),
typeParser,
udtfAnnotation.schemaProvider(),
functionName.name(),
method.isVarArgs());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import java.io.File;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
Expand All @@ -74,6 +75,7 @@
import org.apache.kafka.common.metrics.KafkaMetric;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.metrics.Sensor;
import org.apache.kafka.connect.data.Decimal;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;
Expand Down Expand Up @@ -254,6 +256,21 @@ public void shouldLoadFunctionWithSchemaProvider() {
assertThat(function.getReturnType(args), equalTo(decimal));
}

@Test
public void shouldLoadFunctionWithNestedDecimalSchema() {
// Given:
final UdfFactory returnDecimal = FUNC_REG.getUdfFactory(FunctionName.of("decimalstruct"));

// When:
final KsqlScalarFunction function = returnDecimal.getFunction(ImmutableList.of());

// Then:
assertThat(
function.getReturnType(ImmutableList.of()),
equalTo(SqlStruct.builder().field("VAL", SqlDecimal.of(64, 2)).build()));
}


@Test
public void shouldThrowOnReturnTypeMismatch() {
// Given:
Expand Down Expand Up @@ -299,8 +316,8 @@ public void shouldThrowOnMissingAnnotation() throws ClassNotFoundException {

// Then:
assertThat(e.getMessage(), containsString(
"Cannot load UDF MissingAnnotation. BigDecimal return type "
+ "is not supported without a schema provider method."));
"Cannot load UDF MissingAnnotation. DECIMAL return type is " +
"not supported without a schema annotation."));

}

Expand Down Expand Up @@ -358,9 +375,8 @@ public void shouldThrowOnReturnDecimalWithoutSchemaProvider() throws ClassNotFou

// Then:
assertThat(e.getMessage(), containsString(
"Cannot load UDF ReturnDecimalWithoutSchemaProvider. "
+ "BigDecimal return type is not supported without a "
+ "schema provider method."));
"Cannot load UDF ReturnDecimalWithoutSchemaProvider. DECIMAL return type is not " +
"supported without a schema annotation."));
}

@Test
Expand Down Expand Up @@ -1356,6 +1372,25 @@ public SqlType provideSchema(final List<SqlType> params) {
}
}

@UdfDescription(
name = "DecimalStruct",
description = "A test-only UDF for testing nested DECIMAL in schema annotation")
public static class DecimalStructUdf {

@Udf(schema = "STRUCT<VAL DECIMAL(64,2)>")
public Struct getDecimalStruct() {
final Schema schema = SchemaBuilder.struct()
.optional()
.field("VAL",
Decimal.builder(2).optional().parameter("connect.decimal.precision", "64").build())
.build();

Struct struct = new Struct(schema);
struct.put("VAL", BigDecimal.valueOf(123.45).setScale(2, RoundingMode.CEILING));
return struct;
}
}

@SuppressWarnings({"unused", "MethodMayBeStatic"}) // Invoked via reflection in test.
@UdfDescription(
name = "ReturnIncompatible",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ functionRegistry, empty(), typeParser, true

// Then:
assertThat(e.getMessage(), containsString(
"Cannot load UDF bigDecimalNoSchemaProvider. BigDecimal return type is not supported without a schema provider method."));
"Cannot load UDF bigDecimalNoSchemaProvider. DECIMAL return type is not supported without a schema annotation"));
}

@UdtfDescription(name = "badReturnUdtf", description = "whatever")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
@Target(ElementType.METHOD)
public @interface Udf {

String NO_SCHEMA = "";
String NO_SCHEMA_PROVIDER = "";

/**
* The function description.
*
Expand All @@ -45,11 +48,11 @@
* the return value itself. For complex return types (e.g. {@code Struct} types),
* this is required and will fail if not supplied.
*/
String schema() default "";
String schema() default NO_SCHEMA;

/**
* The name of the method that provides the return type of the UDF.
* @return the name of the other method
*/
String schemaProvider() default "";
String schemaProvider() default NO_SCHEMA_PROVIDER;
}

0 comments on commit 1a90eeb

Please sign in to comment.