diff --git a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java index 84ee7386dbdc..4ceaede1bdbc 100644 --- a/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java +++ b/ksql-engine/src/main/java/io/confluent/ksql/structured/SchemaKStream.java @@ -23,6 +23,8 @@ import com.google.common.collect.ImmutableList; import io.confluent.ksql.engine.rewrite.StatementRewriteForRowtime; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.codegen.CodeGenRunner; +import io.confluent.ksql.execution.codegen.ExpressionMetadata; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.context.QueryLoggerUtil; import io.confluent.ksql.execution.ddl.commands.KsqlTopic; @@ -75,7 +77,6 @@ import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.OptionalInt; import java.util.Set; import java.util.stream.Collectors; import org.apache.kafka.connect.data.Struct; @@ -690,22 +691,20 @@ public SchemaKStream flatMap( final List tableFunctions, final QueryContext.Stacker contextStacker ) { - final List tableFunctionAppliers = new ArrayList<>(); + final List tableFunctionAppliers = new ArrayList<>(tableFunctions.size()); + final CodeGenRunner codeGenRunner = + new CodeGenRunner(getSchema(), ksqlConfig, functionRegistry); for (FunctionCall functionCall: tableFunctions) { - final ColumnReferenceExp exp = (ColumnReferenceExp)functionCall.getArguments().get(0); - final ColumnName columnName = exp.getReference().name(); - final ColumnRef ref = ColumnRef.withoutSource(columnName); - final OptionalInt indexInInput = getSchema().valueColumnIndex(ref); - if (!indexInInput.isPresent()) { - throw new IllegalArgumentException("Can't find input column " + columnName); - } + final Expression expression = functionCall.getArguments().get(0); + final ExpressionMetadata expressionMetadata = + codeGenRunner.buildCodeGenFromParseTree(expression, "Table function"); final KsqlTableFunction tableFunction = UdtfUtil.resolveTableFunction( functionRegistry, functionCall, getSchema() ); final TableFunctionApplier tableFunctionApplier = - new TableFunctionApplier(tableFunction, indexInInput.getAsInt()); + new TableFunctionApplier(tableFunction, expressionMetadata); tableFunctionAppliers.add(tableFunctionApplier); } final StreamFlatMap step = ExecutionStepFactory.streamFlatMap( diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udtf/KudtfFlatMapper.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udtf/KudtfFlatMapper.java index 65a470af0c58..af4c9812a7f9 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udtf/KudtfFlatMapper.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udtf/KudtfFlatMapper.java @@ -37,7 +37,7 @@ public KudtfFlatMapper(final List tableFunctionAppliers) { /* This function zips results from multiple table functions together as described in KLIP-9 - in the design-proposals directory + in the design-proposals directory. */ @Override public Iterable apply(final GenericRow row) { diff --git a/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udtf/TableFunctionApplier.java b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udtf/TableFunctionApplier.java index 30e78afa981c..c872db939081 100644 --- a/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udtf/TableFunctionApplier.java +++ b/ksql-execution/src/main/java/io/confluent/ksql/execution/function/udtf/TableFunctionApplier.java @@ -16,6 +16,7 @@ import com.google.errorprone.annotations.Immutable; import io.confluent.ksql.GenericRow; +import io.confluent.ksql.execution.codegen.ExpressionMetadata; import io.confluent.ksql.function.KsqlTableFunction; import java.util.List; import java.util.Objects; @@ -26,16 +27,17 @@ @Immutable public class TableFunctionApplier { private final KsqlTableFunction tableFunction; - private final int argColumnIndex; + private final ExpressionMetadata expressionMetadata; - public TableFunctionApplier(final KsqlTableFunction tableFunction, final int argColumnIndex) { + public TableFunctionApplier(final KsqlTableFunction tableFunction, + final ExpressionMetadata expressionMetadata) { this.tableFunction = Objects.requireNonNull(tableFunction); - this.argColumnIndex = argColumnIndex; + this.expressionMetadata = Objects.requireNonNull(expressionMetadata); } @SuppressWarnings("unchecked") List apply(final GenericRow row) { - final List unexplodedValue = row.getColumnValue(argColumnIndex); - return tableFunction.flatMap(unexplodedValue); + final Object unexplodedVal = expressionMetadata.evaluate(row); + return tableFunction.flatMap(unexplodedVal); } } diff --git a/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json b/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json index 0f3b4c784c2d..84636c2fd495 100644 --- a/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json +++ b/ksql-functional-tests/src/test/resources/query-validation-tests/table-functions.json @@ -115,6 +115,21 @@ {"topic": "OUTPUT", "key": "1", "value": {"KSQL_COL_0": 3, "KSQL_COL_1": 20}}, {"topic": "OUTPUT", "key": "1", "value": {"KSQL_COL_0": 4, "KSQL_COL_1": null}} ] + }, + { + "name": "table functions with complex expressions", + "statements": [ + "CREATE STREAM TEST (F0 INT, F1 INT, F2 INT, F3 INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT F0, EXPLODE(AS_ARRAY(ABS(F1 + F2), ABS(F2 + F3), ABS(F3 + F1))) FROM TEST;" + ], + "inputs": [ + {"topic": "test_topic", "key": 0, "value": {"ID": 0, "F0": 1, "F1": 10, "F2": 11, "F3": 12}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": "0", "value": {"F0": 1, "KSQL_COL_1": 21.0}}, + {"topic": "OUTPUT", "key": "0", "value": {"F0": 1, "KSQL_COL_1": 23.0}}, + {"topic": "OUTPUT", "key": "0", "value": {"F0": 1, "KSQL_COL_1": 22.0}} + ] } ] } \ No newline at end of file