From 8d17dde25147b08474fc662252b957f82d007cb5 Mon Sep 17 00:00:00 2001
From: Andy Coates
Date: Mon, 16 Mar 2020 17:51:35 +0000
Subject: [PATCH] refactor: projection expression handling

Prep for https://github.com/confluentinc/ksql/issues/4749.

This commit changes the way the engine resolves `*` in a projection, e.g. `SELECT * FROM X;`.

Previously, the `Analyzer` was responsible for expanding the `*` into the set of columns of each source. However, this code was getting complicated and would have become much more complicated once the key column can have any name (https://github.com/confluentinc/ksql/issues/3536). The complexity arises because the `Analyzer` would need to determine the presence of joins, group bys, partition bys, etc., which can affect how `*` is resolved. This duplicates logic in the `LogicalPlanner` and `PlanNode` sub-classes.

With this commit, the logical plan and planner are responsible for resolving any `*` in the projection. This is achieved by asking the parent of the projection node to resolve the `*` into the set of columns. Parent node types that do not know how to resolve the `*`, e.g. `FilterNode`, forward requests to their parents. In this way, the resolution request ripples up the logical plan until it reaches a `DataSourceNode`, which can resolve the `*` into a list of columns. `JoinNode` knows how to forward `*`, `left.*` and `right.*` appropriately.

Previously, the list of `SelectExpressions` was passed down from parent `PlanNode` to child, allowing some nodes to rewrite the expressions. For example, `FlatMapNode` would rewrite any expression involving a table function to use internal names like `KSQL_SYNTH_0`. With this commit, this is no longer necessary. Instead, when building a projection node, the planner asks its parent node to resolve any selects, allowing the parent to perform any rewrite. At the moment, the planner is still responsible for much of this work. In the future, this logic may move into the plan itself; however, such a change would increase the complexity of this commit.
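To illustrate the ripple-up resolution described above, here is a minimal, self-contained sketch in plain Java. The class, alias and column names (`SourceNode`, `JoinNode`, `FilterNode`, `O`, `I`, etc.) are hypothetical stand-ins for this example, not ksqlDB's actual `PlanNode` API: nodes that cannot resolve `*` forward the request to their parent, a join fans the request out to one or both sides, and a data source answers with its own columns.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

abstract class Node {
  private final Optional<Node> parent;

  Node(final Node parent) {
    this.parent = Optional.ofNullable(parent);
  }

  // Default behaviour, mirroring nodes (like a filter) that cannot resolve
  // '*' themselves: forward the request up the plan to the parent node.
  Stream<String> resolveSelectStar(final Optional<String> sourceName) {
    return parent
        .orElseThrow(() -> new IllegalStateException("no node could resolve *"))
        .resolveSelectStar(sourceName);
  }
}

// A data source resolves '*' into its own column list, provided the request
// is unqualified or qualified with this source's alias.
class SourceNode extends Node {
  private final String alias;
  private final List<String> columns;

  SourceNode(final String alias, final String... columns) {
    super(null);
    this.alias = alias;
    this.columns = Arrays.asList(columns);
  }

  @Override
  Stream<String> resolveSelectStar(final Optional<String> sourceName) {
    final boolean matches = !sourceName.isPresent() || sourceName.get().equals(alias);
    return matches ? columns.stream() : Stream.empty();
  }
}

// A join forwards '*' to both sides; 'left.*' / 'right.*' yield columns from
// the matching side only, because the other side returns an empty stream.
class JoinNode extends Node {
  private final Node left;
  private final Node right;

  JoinNode(final Node left, final Node right) {
    super(null);
    this.left = left;
    this.right = right;
  }

  @Override
  Stream<String> resolveSelectStar(final Optional<String> sourceName) {
    return Stream.concat(
        left.resolveSelectStar(sourceName),
        right.resolveSelectStar(sourceName));
  }
}

// A filter adds no columns of its own, so it inherits the forwarding default.
class FilterNode extends Node {
  FilterNode(final Node parent) {
    super(parent);
  }
}

public class ResolveStarDemo {
  public static void main(final String[] args) {
    final Node plan = new FilterNode(new JoinNode(
        new SourceNode("O", "ORDERID", "ITEMID"),
        new SourceNode("I", "ITEMID", "NAME")));

    // SELECT * ...   -> columns from both sides of the join:
    System.out.println(plan.resolveSelectStar(Optional.empty())
        .collect(Collectors.toList()));   // [ORDERID, ITEMID, ITEMID, NAME]

    // SELECT I.* ... -> columns from the right source only:
    System.out.println(plan.resolveSelectStar(Optional.of("I"))
        .collect(Collectors.toList()));   // [ITEMID, NAME]
  }
}
```

In the patch itself this shape appears as `PlanNode.resolveSelectStar`: `LogicalPlanner.resolveSelectItem` calls it on the parent node for each `AllColumns` item and wraps each column name it yields in a `SelectExpression` (see the `LogicalPlanner.java` hunks below).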
---
 .../ksql/analyzer/AggregateAnalyzer.java | 189 ++-
 .../io/confluent/ksql/analyzer/Analysis.java | 47 +-
 .../io/confluent/ksql/analyzer/Analyzer.java | 180 +--
 .../ksql/analyzer/ImmutableAnalysis.java | 14 +-
 .../analyzer/MutableAggregateAnalysis.java | 5 +-
 .../ksql/analyzer/QueryAnalyzer.java | 192 ---
 .../ksql/analyzer/RewrittenAnalysis.java | 61 +-
 .../io/confluent/ksql/engine/QueryEngine.java | 8 +-
 .../ksql/planner/LogicalPlanner.java | 172 ++-
 .../ksql/planner/plan/AggregateNode.java | 74 +-
 .../ksql/planner/plan/DataSourceNode.java | 49 +-
 .../ksql/planner/plan/FilterNode.java | 2 +-
 .../ksql/planner/plan/FlatMapNode.java | 41 +-
 .../confluent/ksql/planner/plan/JoinNode.java | 16 +-
 .../plan/KsqlStructuredDataOutputNode.java | 20 +-
 .../ksql/planner/plan/OutputNode.java | 2 +-
 .../confluent/ksql/planner/plan/PlanNode.java | 60 +-
 .../ksql/planner/plan/ProjectNode.java | 56 +-
 .../ksql/planner/plan/RepartitionNode.java | 2 +-
 .../ksql/analyzer/AggregateAnalyzerTest.java | 293 +++-
 .../ksql/analyzer/AnalyzerFunctionalTest.java | 306 +---
 .../ColumnReferenceValidatorTest.java | 8 -
 .../analyzer/QueryAnalyzerFunctionalTest.java | 428 +------
 .../ksql/analyzer/QueryAnalyzerTest.java | 4 -
 .../ksql/codegen/CodeGenRunnerTest.java | 1086 -----------------
 .../ksql/function/udf/WhenCondition.java | 31 +
 .../ksql/function/udf/WhenResult.java | 31 +
 .../ksql/planner/plan/DataSourceNodeTest.java | 7 +-
 .../ksql/planner/plan/JoinNodeTest.java | 18 -
 .../KsqlStructuredDataOutputNodeTest.java | 18 +-
 .../ksql/testutils/AnalysisTestUtil.java | 27 +-
 .../ksql/execution/util/ComparisonUtil.java | 96 +-
 .../execution/util/ComparisonUtilTest.java | 27 +-
 .../query-validation-tests/array.json | 13 +
 .../binary-arithmetic.json | 23 +
 .../binary-comparison.json | 257 ++++
 .../case-expression.json | 47 +-
 .../query-validation-tests/cast.json | 39 +
 .../query-validation-tests/concat.json | 30 +-
 .../query-validation-tests/group-by.json | 102 +-
 .../resources/query-validation-tests/map.json | 39 +
 .../query-validation-tests/null.json | 81 ++
 .../query-validation-tests/string.json | 32 +
 .../table-functions.json | 51 +
 .../resources/query-validation-tests/udf.json | 92 ++
 .../query-validation-tests/window-bounds.json | 12 +-
 ...eries-against-materialized-aggregates.json | 20 -
 .../push-queries.json | 4 +-
 .../server/execution/PullQueryExecutor.java | 41 +-
 49 files changed, 1826 insertions(+), 2627 deletions(-)
 delete mode 100644 ksqldb-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java
 create mode 100644 ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/WhenCondition.java
 create mode 100644 ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/WhenResult.java
 create mode 100644 ksqldb-functional-tests/src/test/resources/query-validation-tests/binary-arithmetic.json
 create mode 100644 ksqldb-functional-tests/src/test/resources/query-validation-tests/binary-comparison.json
 create mode 100644 ksqldb-functional-tests/src/test/resources/query-validation-tests/map.json
 create mode 100644 ksqldb-functional-tests/src/test/resources/query-validation-tests/null.json
 create mode 100644 ksqldb-functional-tests/src/test/resources/query-validation-tests/udf.json

diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/AggregateAnalyzer.java b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/AggregateAnalyzer.java
index 3ddf8ab90ca0..964f8d8ce7fc 100644
--- a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/AggregateAnalyzer.java
+++
b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/AggregateAnalyzer.java @@ -1,5 +1,5 @@ /* - * Copyright 2018 Confluent Inc. + * Copyright 2020 Confluent Inc. * * Licensed under the Confluent Community License (the "License"); you may not use * this file except in compliance with the License. You may obtain a copy of the @@ -15,47 +15,74 @@ package io.confluent.ksql.analyzer; +import static java.util.Objects.requireNonNull; + import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; +import com.google.common.collect.Sets.SetView; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor; import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; +import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.SchemaUtil; import java.util.HashSet; -import java.util.Objects; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.BiConsumer; +import java.util.stream.Collectors; -class AggregateAnalyzer { +public class AggregateAnalyzer { - private final MutableAggregateAnalysis aggregateAnalysis; - private final QualifiedColumnReferenceExp defaultArgument; private final FunctionRegistry functionRegistry; - private final boolean hasWindowExpression; - AggregateAnalyzer( - final MutableAggregateAnalysis aggregateAnalysis, - final QualifiedColumnReferenceExp defaultArgument, - final boolean hasWindowExpression, + public AggregateAnalyzer( final FunctionRegistry functionRegistry ) { - this.aggregateAnalysis = Objects.requireNonNull(aggregateAnalysis, "aggregateAnalysis"); - this.defaultArgument = Objects.requireNonNull(defaultArgument, "defaultArgument"); - this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry"); - this.hasWindowExpression = hasWindowExpression; + this.functionRegistry = requireNonNull(functionRegistry, "functionRegistry"); } - void processSelect(final Expression expression) { + public AggregateAnalysisResult analyze( + final ImmutableAnalysis analysis, + final List finalProjection + ) { + if (analysis.getGroupByExpressions().isEmpty()) { + throw new IllegalArgumentException("Not an aggregate query"); + } + + final Context context = new Context(analysis); + + finalProjection.stream() + .map(SelectExpression::getExpression) + .forEach(exp -> processSelect(exp, context)); + + analysis.getWhereExpression() + .ifPresent(exp -> processWhere(exp, context)); + + analysis.getGroupByExpressions() + .forEach(exp -> processGroupBy(exp, context)); + + analysis.getHavingExpression() + .ifPresent(exp -> processHaving(exp, context)); + + enforceAggregateRules(context); + + return context.aggregateAnalysis; + } + + private void processSelect(final Expression expression, final Context context) { final Set nonAggParams = new HashSet<>(); - final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> { + final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) -> { if (aggFuncName.isPresent()) { - 
throwOnWindowBoundColumnIfWindowedAggregate(node); + throwOnWindowBoundColumnIfWindowedAggregate(node, context); } else { nonAggParams.add(node); } @@ -64,45 +91,51 @@ void processSelect(final Expression expression) { visitor.process(expression, null); if (visitor.visitedAggFunction) { - aggregateAnalysis.addAggregateSelectField(nonAggParams); + context.aggregateAnalysis.addAggregateSelectField(nonAggParams); } else { - aggregateAnalysis.addNonAggregateSelectExpression(expression, nonAggParams); + context.aggregateAnalysis.addNonAggregateSelectExpression(expression, nonAggParams); } + + context.aggregateAnalysis.addFinalSelectExpression(expression); } - void processGroupBy(final Expression expression) { - final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> { + private void processGroupBy(final Expression expression, final Context context) { + final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) -> { if (aggFuncName.isPresent()) { throw new KsqlException("GROUP BY does not support aggregate functions: " + aggFuncName.get().text() + " is an aggregate function."); } - throwOnWindowBoundColumnIfWindowedAggregate(node); + throwOnWindowBoundColumnIfWindowedAggregate(node, context); }); visitor.process(expression, null); } - void processWhere(final Expression expression) { - final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> { - throwOnWindowBoundColumnIfWindowedAggregate(node); - }); + private void processWhere(final Expression expression, final Context context) { + final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) -> + throwOnWindowBoundColumnIfWindowedAggregate(node, context)); visitor.process(expression, null); } - void processHaving(final Expression expression) { - final AggregateVisitor visitor = new AggregateVisitor((aggFuncName, node) -> { - throwOnWindowBoundColumnIfWindowedAggregate(node); + private void processHaving(final Expression expression, final Context context) { + final AggregateVisitor visitor = new AggregateVisitor(context, (aggFuncName, node) -> { + throwOnWindowBoundColumnIfWindowedAggregate(node, context); if (!aggFuncName.isPresent()) { - aggregateAnalysis.addNonAggregateHavingField(node); + context.aggregateAnalysis.addNonAggregateHavingField(node); } }); visitor.process(expression, null); + + context.aggregateAnalysis.setHavingExpression(expression); } - private void throwOnWindowBoundColumnIfWindowedAggregate(final ColumnReferenceExp node) { + private static void throwOnWindowBoundColumnIfWindowedAggregate( + final ColumnReferenceExp node, + final Context context + ) { // Window bounds are supported for operations on windowed sources - if (!hasWindowExpression) { + if (!context.analysis.getWindowExpression().isPresent()) { return; } @@ -117,18 +150,91 @@ private void throwOnWindowBoundColumnIfWindowedAggregate(final ColumnReferenceEx } } + private static void enforceAggregateRules( + final Context context + ) { + if (context.aggregateAnalysis.getAggregateFunctions().isEmpty()) { + throw new KsqlException( + "GROUP BY requires columns using aggregate functions in SELECT clause."); + } + + final Set groupByExprs = getGroupByExpressions(context.analysis); + + final List unmatchedSelects = context.aggregateAnalysis + .getNonAggregateSelectExpressions() + .entrySet() + .stream() + // Remove any that exactly match a group by expression: + .filter(e -> !groupByExprs.contains(e.getKey())) + // Remove any that are constants, + // or expressions where all params 
exactly match a group by expression: + .filter(e -> !Sets.difference(e.getValue(), groupByExprs).isEmpty()) + .map(Map.Entry::getKey) + .map(Expression::toString) + .sorted() + .collect(Collectors.toList()); + + if (!unmatchedSelects.isEmpty()) { + throw new KsqlException( + "Non-aggregate SELECT expression(s) not part of GROUP BY: " + unmatchedSelects); + } + + final SetView unmatchedSelectsAgg = Sets + .difference(context.aggregateAnalysis.getAggregateSelectFields(), groupByExprs); + if (!unmatchedSelectsAgg.isEmpty()) { + throw new KsqlException( + "Column used in aggregate SELECT expression(s) " + + "outside of aggregate functions not part of GROUP BY: " + unmatchedSelectsAgg); + } + + final Set havingColumns = context.aggregateAnalysis + .getNonAggregateHavingFields().stream() + .map(ref -> new UnqualifiedColumnReferenceExp(ref.getColumnName())) + .collect(Collectors.toSet()); + + final Set havingOnly = Sets.difference(havingColumns, groupByExprs); + if (!havingOnly.isEmpty()) { + throw new KsqlException( + "Non-aggregate HAVING expression not part of GROUP BY: " + havingOnly); + } + } + + private static Set getGroupByExpressions( + final ImmutableAnalysis analysis + ) { + if (!analysis.getWindowExpression().isPresent()) { + return ImmutableSet.copyOf(analysis.getGroupByExpressions()); + } + + // Add in window bounds columns as implicit group by columns: + final Set windowBoundColumnRefs = + SchemaUtil.windowBoundsColumnNames().stream() + .map(UnqualifiedColumnReferenceExp::new) + .collect(Collectors.toSet()); + + return ImmutableSet.builder() + .addAll(analysis.getGroupByExpressions()) + .addAll(windowBoundColumnRefs) + .build(); + } + private final class AggregateVisitor extends TraversalExpressionVisitor { private final BiConsumer, ColumnReferenceExp> dereferenceCollector; + private final ColumnReferenceExp defaultArgument; + private final MutableAggregateAnalysis aggregateAnalysis; + private Optional aggFunctionName = Optional.empty(); private boolean visitedAggFunction = false; private AggregateVisitor( + final Context context, final BiConsumer, ColumnReferenceExp> dereferenceCollector ) { - this.dereferenceCollector = - Objects.requireNonNull(dereferenceCollector, "dereferenceCollector"); + this.defaultArgument = context.analysis.getDefaultArgument(); + this.aggregateAnalysis = context.aggregateAnalysis; + this.dereferenceCollector = requireNonNull(dereferenceCollector, "dereferenceCollector"); } @Override @@ -180,12 +286,17 @@ public Void visitQualifiedColumnReference( final QualifiedColumnReferenceExp node, final Void context ) { - dereferenceCollector.accept(aggFunctionName, node); + throw new UnsupportedOperationException("Should have been converted to unqualified"); + } + } - if (!SchemaUtil.isWindowBound(node.getColumnName())) { - aggregateAnalysis.addRequiredColumn(node); - } - return null; + private static final class Context { + + final ImmutableAnalysis analysis; + final MutableAggregateAnalysis aggregateAnalysis = new MutableAggregateAnalysis(); + + Context(final ImmutableAnalysis analysis) { + this.analysis = requireNonNull(analysis, "analysis"); } } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java index d8aa5c2ea8d9..42ec8e1babf0 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/Analysis.java @@ -25,7 +25,6 @@ import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.metastore.model.DataSource; import io.confluent.ksql.metastore.model.KsqlStream; import io.confluent.ksql.metastore.model.KsqlTable; @@ -33,6 +32,7 @@ import io.confluent.ksql.name.SourceName; import io.confluent.ksql.parser.properties.with.CreateSourceAsProperties; import io.confluent.ksql.parser.tree.ResultMaterialization; +import io.confluent.ksql.parser.tree.SelectItem; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.parser.tree.WithinExpression; import io.confluent.ksql.planner.plan.JoinNode; @@ -61,12 +61,11 @@ public class Analysis implements ImmutableAnalysis { private final List fromDataSources = new ArrayList<>(); private Optional joinInfo = Optional.empty(); private Optional whereExpression = Optional.empty(); - private final List selectExpressions = new ArrayList<>(); + private final List selectItems = new ArrayList<>(); private final Set selectColumnNames = new HashSet<>(); private final List groupByExpressions = new ArrayList<>(); private Optional windowExpression = Optional.empty(); private Optional partitionBy = Optional.empty(); - private ImmutableSet serdeOptions = ImmutableSet.of(); private Optional havingExpression = Optional.empty(); private OptionalInt limitClause = OptionalInt.empty(); private CreateSourceAsProperties withProperties = CreateSourceAsProperties.none(); @@ -89,8 +88,8 @@ ResultMaterialization getResultMaterialization() { return resultMaterialization; } - void addSelectItem(final Expression expression, final ColumnName alias) { - selectExpressions.add(SelectExpression.of(alias, expression)); + void addSelectItem(final SelectItem selectItem) { + selectItems.add(Objects.requireNonNull(selectItem, "selectItem")); } void addSelectColumnRefs(final Collection columnNames) { @@ -116,12 +115,12 @@ void setWhereExpression(final Expression whereExpression) { } @Override - public List getSelectExpressions() { - return Collections.unmodifiableList(selectExpressions); + public List getSelectItems() { + return Collections.unmodifiableList(selectItems); } @Override - public Set getSelectColumnRefs() { + public Set getSelectColumnNames() { return Collections.unmodifiableSet(selectColumnNames); } @@ -143,7 +142,8 @@ void setWindowExpression(final WindowExpression windowExpression) { this.windowExpression = Optional.of(windowExpression); } - Optional getHavingExpression() { + @Override + public Optional getHavingExpression() { return havingExpression; } @@ -209,20 +209,12 @@ void addDataSource(final SourceName alias, final DataSource dataSource) { fromDataSources.add(new AliasedDataSource(alias, dataSource)); } - QualifiedColumnReferenceExp getDefaultArgument() { + @Override + public QualifiedColumnReferenceExp getDefaultArgument() { final SourceName alias = fromDataSources.get(0).getAlias(); return new QualifiedColumnReferenceExp(alias, SchemaUtil.ROWTIME_NAME); } - void setSerdeOptions(final Set serdeOptions) { - this.serdeOptions = ImmutableSet.copyOf(serdeOptions); - } - - @Override - public Set getSerdeOptions() { - return serdeOptions; - } - void setProperties(final CreateSourceAsProperties properties) { withProperties = requireNonNull(properties, "properties"); } @@ -265,23 +257,28 @@ public static final class Into { private final SourceName name; private final KsqlTopic topic; private final boolean create; + 
private final ImmutableSet defaultSerdeOptions; - public static Into of( + public static Into of( final SourceName name, final boolean create, - final KsqlTopic topic + final KsqlTopic topic, + final Set defaultSerdeOptions ) { - return new Into(name, create, topic); + return new Into(name, create, topic, defaultSerdeOptions); } private Into( final SourceName name, final boolean create, - final KsqlTopic topic + final KsqlTopic topic, + final Set defaultSerdeOptions ) { this.name = requireNonNull(name, "name"); this.create = create; this.topic = requireNonNull(topic, "topic"); + this.defaultSerdeOptions = ImmutableSet + .copyOf(requireNonNull(defaultSerdeOptions, "defaultSerdeOptions")); } public SourceName getName() { @@ -295,6 +292,10 @@ public boolean isCreate() { public KsqlTopic getKsqlTopic() { return topic; } + + public ImmutableSet getDefaultSerdeOptions() { + return defaultSerdeOptions; + } } @Immutable diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java index 7c04029dfa11..aaa71943c86f 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/Analyzer.java @@ -17,7 +17,6 @@ import static java.util.Objects.requireNonNull; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; @@ -31,7 +30,6 @@ import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor; import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.windows.KsqlWindowExpression; import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.metastore.model.DataSource; @@ -40,7 +38,6 @@ import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.name.SourceName; import io.confluent.ksql.parser.DefaultTraversalVisitor; -import io.confluent.ksql.parser.NodeLocation; import io.confluent.ksql.parser.tree.AliasedRelation; import io.confluent.ksql.parser.tree.AllColumns; import io.confluent.ksql.parser.tree.AstNode; @@ -56,22 +53,18 @@ import io.confluent.ksql.parser.tree.Table; import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.planner.plan.JoinNode; -import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.FormatOptions; -import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatFactory; import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.KeyFormat; import io.confluent.ksql.serde.SerdeOption; -import io.confluent.ksql.serde.SerdeOptions; import io.confluent.ksql.serde.ValueFormat; import io.confluent.ksql.serde.WindowInfo; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.SchemaUtil; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -99,7 +92,6 @@ class Analyzer { private final MetaStore metaStore; private final String topicPrefix; - private final SerdeOptionsSupplier serdeOptionsSupplier; private final Set defaultSerdeOptions; /** @@ -111,26 +103,11 @@ class Analyzer { final MetaStore metaStore, final String 
topicPrefix, final Set defaultSerdeOptions - ) { - this( - metaStore, - topicPrefix, - defaultSerdeOptions, - SerdeOptions::buildForCreateAsStatement); - } - - @VisibleForTesting - Analyzer( - final MetaStore metaStore, - final String topicPrefix, - final Set defaultSerdeOptions, - final SerdeOptionsSupplier serdeOptionsSupplier ) { this.metaStore = requireNonNull(metaStore, "metaStore"); this.topicPrefix = requireNonNull(topicPrefix, "topicPrefix"); this.defaultSerdeOptions = ImmutableSet .copyOf(requireNonNull(defaultSerdeOptions, "defaultSerdeOptions")); - this.serdeOptionsSupplier = requireNonNull(serdeOptionsSupplier, "serdeOptionsSupplier"); } /** @@ -161,12 +138,10 @@ private final class Visitor extends DefaultTraversalVisitor { private final Analysis analysis; private final boolean persistent; - private final boolean pullQuery; private boolean isJoin = false; private boolean isGroupBy = false; Visitor(final Query query, final boolean persistent) { - this.pullQuery = query.isPullQuery(); this.analysis = new Analysis(query.getResultMaterialization()); this.persistent = persistent; } @@ -174,8 +149,6 @@ private final class Visitor extends DefaultTraversalVisitor { private void analyzeNonStdOutSink(final Sink sink) { analysis.setProperties(sink.getProperties()); - setSerdeOptions(sink); - if (!sink.shouldCreateSink()) { final DataSource existing = metaStore.getSource(sink.getName()); if (existing == null) { @@ -186,7 +159,8 @@ private void analyzeNonStdOutSink(final Sink sink) { analysis.setInto(Into.of( sink.getName(), false, - existing.getKsqlTopic() + existing.getKsqlTopic(), + defaultSerdeOptions )); return; } @@ -224,7 +198,8 @@ private void analyzeNonStdOutSink(final Sink sink) { analysis.setInto(Into.of( sink.getName(), true, - intoKsqlTopic + intoKsqlTopic, + defaultSerdeOptions )); } @@ -243,27 +218,6 @@ private KeyFormat buildKeyFormat() { .getKeyFormat()); } - private void setSerdeOptions(final Sink sink) { - final List columnNames = getColumnNames(); - - final Format valueFormat = getValueFormat(sink); - - final Set serdeOptions = serdeOptionsSupplier.build( - columnNames, - valueFormat, - sink.getProperties().getWrapSingleValues(), - defaultSerdeOptions - ); - - analysis.setSerdeOptions(serdeOptions); - } - - private List getColumnNames() { - return analysis.getSelectExpressions().stream() - .map(SelectExpression::getAlias) - .collect(Collectors.toList()); - } - private Format getValueFormat(final Sink sink) { return sink.getProperties().getValueFormat() .orElseGet(() -> FormatFactory.of(getSourceInfo())); @@ -287,8 +241,6 @@ protected AstNode visitQuery( ) { process(node.getFrom(), context); - process(node.getSelect(), context); - node.getWhere().ifPresent(this::analyzeWhere); node.getGroupBy().ifPresent(this::analyzeGroupBy); node.getPartitionBy().ifPresent(this::analyzePartitionBy); @@ -296,6 +248,8 @@ protected AstNode visitQuery( node.getHaving().ifPresent(this::analyzeHaving); node.getLimit().ifPresent(analysis::setLimitClause); + process(node.getSelect(), context); + throwOnUnknownColumnReference(); return null; @@ -315,8 +269,10 @@ private void throwOnUnknownColumnReference() { analysis.getHavingExpression() .ifPresent(columnValidator::analyzeExpression); - analysis.getSelectExpressions().stream() - .map(SelectExpression::getExpression) + analysis.getSelectItems().stream() + .filter(si -> si instanceof SingleColumn) + .map(SingleColumn.class::cast) + .map(SingleColumn::getExpression) .forEach(columnValidator::analyzeExpression); } @@ -508,13 +464,14 @@ protected 
AstNode visitAliasedRelation(final AliasedRelation node, final Void co @Override protected AstNode visitSelect(final Select node, final Void context) { for (final SelectItem selectItem : node.getSelectItems()) { - if (selectItem instanceof AllColumns) { - visitSelectStar((AllColumns) selectItem); - } else if (selectItem instanceof SingleColumn) { + analysis.addSelectItem(selectItem); + + if (selectItem instanceof SingleColumn) { final SingleColumn column = (SingleColumn) selectItem; - addSelectItem(column.getExpression(), column.getAlias().get()); + validateSelect(column); + captureReferencedSourceColumns(column.getExpression()); visitTableFunctions(column.getExpression()); - } else { + } else if (!(selectItem instanceof AllColumns)) { throw new IllegalArgumentException( "Unsupported SelectItem type: " + selectItem.getClass().getName()); } @@ -552,63 +509,36 @@ private void analyzeHaving(final Expression node) { analysis.setHavingExpression(node); } - private void visitSelectStar(final AllColumns allColumns) { + private void validateSelect(final SingleColumn column) { + final ColumnName columnName = column.getAlias() + .orElseThrow(IllegalStateException::new); - final Optional location = allColumns.getLocation(); - - final Optional prefix = allColumns.getSource(); - - for (final AliasedDataSource source : analysis.getFromDataSources()) { - - if (prefix.isPresent() && !prefix.get().equals(source.getAlias())) { - continue; + if (persistent) { + if (SchemaUtil.isSystemColumn(columnName)) { + throw new KsqlException("Reserved column name in select: " + columnName + ". " + + "Please remove or alias the column."); } + } - final String aliasPrefix = analysis.isJoin() - ? source.getAlias().text() + "_" - : ""; - - final LogicalSchema schema = source.getDataSource().getSchema(); - final boolean windowed = source.getDataSource().getKsqlTopic().getKeyFormat().isWindowed(); - - // Non-join persistent queries only require value columns on SELECT * - // where as joins and transient queries require all columns in the select: - // See https://github.com/confluentinc/ksql/issues/3731 for more info - final List valueColumns = persistent && !analysis.isJoin() - ? schema.value() - : orderColumns(schema.withMetaAndKeyColsInValue(windowed).value(), schema); - - for (final Column column : valueColumns) { + if (analysis.getGroupByExpressions().isEmpty()) { + throwOnUdafs(column.getExpression()); + } + } - if (pullQuery && schema.isMetaColumn(column.name())) { - continue; + private void throwOnUdafs(final Expression expression) { + new TraversalExpressionVisitor() { + @Override + public Void visitFunctionCall(final FunctionCall functionCall, final Void context) { + final FunctionName functionName = functionCall.getName(); + if (metaStore.isAggregate(functionName)) { + throw new KsqlException("Use of aggregate function " + + functionName.text() + " requires a GROUP BY clause."); } - final QualifiedColumnReferenceExp selectItem = new QualifiedColumnReferenceExp( - location, - source.getAlias(), - column.name()); - - final String alias = aliasPrefix + column.name().text(); - - addSelectItem(selectItem, ColumnName.of(alias)); + super.visitFunctionCall(functionCall, context); + return null; } - } - } - - private List orderColumns( - final List columns, - final LogicalSchema schema - ) { - // When doing a `select *` system and key columns should be at the front of the column list - // but are added at the back during processing for performance reasons. 
- // Switch them around here: - final Map> partitioned = columns.stream().collect(Collectors - .groupingBy(c -> SchemaUtil.isSystemColumn(c.name()) || schema.isKeyColumn(c.name()))); - - final List all = partitioned.get(true); - all.addAll(partitioned.get(false)); - return all; + }.process(expression, null); } public void validate() { @@ -636,16 +566,10 @@ public void validate() { } } - private void addSelectItem(final Expression exp, final ColumnName columnName) { - if (persistent) { - if (SchemaUtil.isSystemColumn(columnName)) { - throw new KsqlException("Reserved column name in select: " + columnName + ". " - + "Please remove or alias the column."); - } - } - + private void captureReferencedSourceColumns(final Expression exp) { final Set columnNames = new HashSet<>(); - final TraversalExpressionVisitor visitor = new TraversalExpressionVisitor() { + + new TraversalExpressionVisitor() { @Override public Void visitColumnReference( final UnqualifiedColumnReferenceExp node, @@ -663,11 +587,8 @@ public Void visitQualifiedColumnReference( columnNames.add(node.getColumnName()); return null; } - }; - - visitor.process(exp, null); + }.process(exp, null); - analysis.addSelectItem(exp, columnName); analysis.addSelectColumnRefs(columnNames); } @@ -693,6 +614,10 @@ public Void visitFunctionCall(final FunctionCall functionCall, final Void contex tableFunctionName = Optional.of(functionName); + if (!analysis.getGroupByExpressions().isEmpty()) { + throw new KsqlException("Table functions cannot be used with aggregations."); + } + analysis.addTableFunction(functionCall); } @@ -706,15 +631,4 @@ public Void visitFunctionCall(final FunctionCall functionCall, final Void contex } } } - - @FunctionalInterface - interface SerdeOptionsSupplier { - - Set build( - List valueColumnNames, - Format valueFormat, - Optional wrapSingleValues, - Set singleFieldDefaults - ); - } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/ImmutableAnalysis.java b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/ImmutableAnalysis.java index 449a5a7909b1..2350110e01e2 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/ImmutableAnalysis.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/ImmutableAnalysis.java @@ -18,13 +18,13 @@ import io.confluent.ksql.analyzer.Analysis.AliasedDataSource; import io.confluent.ksql.analyzer.Analysis.Into; import io.confluent.ksql.analyzer.Analysis.JoinInfo; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.parser.properties.with.CreateSourceAsProperties; +import io.confluent.ksql.parser.tree.SelectItem; import io.confluent.ksql.parser.tree.WindowExpression; -import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.testing.EffectivelyImmutable; import java.util.List; import java.util.Optional; @@ -36,18 +36,22 @@ public interface ImmutableAnalysis { List getTableFunctions(); - List getSelectExpressions(); + List getSelectItems(); Optional getWhereExpression(); Optional getInto(); - Set getSelectColumnRefs(); + Set getSelectColumnNames(); List getGroupByExpressions(); + Optional getHavingExpression(); + Optional getWindowExpression(); + ColumnReferenceExp getDefaultArgument(); + Optional getPartitionBy(); OptionalInt getLimitClause(); @@ -56,8 +60,6 @@ public 
interface ImmutableAnalysis { List getFromDataSources(); - Set getSerdeOptions(); - CreateSourceAsProperties getProperties(); SourceSchemas getFromSourceSchemas(boolean postAggregate); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/MutableAggregateAnalysis.java b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/MutableAggregateAnalysis.java index 46ba0e41879e..46f4053bef23 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/MutableAggregateAnalysis.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/MutableAggregateAnalysis.java @@ -27,7 +27,7 @@ import java.util.Optional; import java.util.Set; -public class MutableAggregateAnalysis implements AggregateAnalysis { +public class MutableAggregateAnalysis implements AggregateAnalysisResult { private final List requiredColumns = new ArrayList<>(); private final Map> nonAggSelectExpressions @@ -50,17 +50,14 @@ public List getRequiredColumns() { return Collections.unmodifiableList(requiredColumns); } - @Override public Map> getNonAggregateSelectExpressions() { return Collections.unmodifiableMap(nonAggSelectExpressions); } - @Override public Set getAggregateSelectFields() { return Collections.unmodifiableSet(aggSelectFields); } - @Override public Set getNonAggregateHavingFields() { return Collections.unmodifiableSet(nonAggHavingFields); } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/QueryAnalyzer.java b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/QueryAnalyzer.java index 185938d576d2..d3572814f5e9 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/QueryAnalyzer.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/QueryAnalyzer.java @@ -18,35 +18,19 @@ import static java.util.Objects.requireNonNull; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; -import com.google.common.collect.Sets; -import com.google.common.collect.Sets.SetView; import io.confluent.ksql.analyzer.Analysis.AliasedDataSource; -import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter; -import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; -import io.confluent.ksql.execution.expression.tree.Expression; -import io.confluent.ksql.execution.expression.tree.FunctionCall; -import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.metastore.model.DataSource.DataSourceType; -import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Sink; import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.util.KsqlException; -import io.confluent.ksql.util.SchemaUtil; -import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; public class QueryAnalyzer { private final Analyzer analyzer; - private final MetaStore metaStore; private final QueryValidator pullQueryValidator; private final QueryValidator pushQueryValidator; @@ -56,7 +40,6 @@ public QueryAnalyzer( final Set defaultSerdeOptions ) { this( - metaStore, new Analyzer(metaStore, outputTopicPrefix, defaultSerdeOptions), new PushQueryValidator(), new PullQueryValidator() @@ -65,12 +48,10 @@ public QueryAnalyzer( @VisibleForTesting QueryAnalyzer( - final MetaStore metaStore, final Analyzer analyzer, final 
QueryValidator pullQueryValidator, final QueryValidator pushQueryValidator ) { - this.metaStore = requireNonNull(metaStore, "metaStore"); this.analyzer = requireNonNull(analyzer, "analyzer"); this.pullQueryValidator = requireNonNull(pullQueryValidator, "pullQueryValidator"); this.pushQueryValidator = requireNonNull(pushQueryValidator, "pushQueryValidator"); @@ -97,177 +78,4 @@ public Analysis analyze( return analysis; } - - public AggregateAnalysis analyzeAggregate(final Query query, final Analysis analysis) { - final MutableAggregateAnalysis aggregateAnalysis = new MutableAggregateAnalysis(); - final QualifiedColumnReferenceExp defaultArgument = analysis.getDefaultArgument(); - - final AggregateAnalyzer aggregateAnalyzer = new AggregateAnalyzer( - aggregateAnalysis, - defaultArgument, - analysis.getWindowExpression().isPresent(), - metaStore - ); - - final AggregateExpressionRewriter aggregateExpressionRewriter = - new AggregateExpressionRewriter(metaStore); - - processSelectExpressions( - analysis, - aggregateAnalysis, - aggregateAnalyzer, - aggregateExpressionRewriter - ); - - if (!aggregateAnalysis.getAggregateFunctions().isEmpty() - && analysis.getGroupByExpressions().isEmpty()) { - final String aggFuncs = aggregateAnalysis.getAggregateFunctions().stream() - .map(FunctionCall::getName) - .map(FunctionName::text) - .collect(Collectors.joining(", ")); - throw new KsqlException("Use of aggregate functions requires a GROUP BY clause. " - + "Aggregate function(s): " + aggFuncs); - } - - processWhereExpression( - analysis, - aggregateAnalyzer - ); - - processGroupByExpression( - analysis, - aggregateAnalyzer - ); - - analysis.getHavingExpression().ifPresent(having -> - processHavingExpression( - having, - aggregateAnalysis, - aggregateAnalyzer, - aggregateExpressionRewriter - ) - ); - - enforceAggregateRules(query, analysis, aggregateAnalysis); - return aggregateAnalysis; - } - - private static void processHavingExpression( - final Expression having, - final MutableAggregateAnalysis aggregateAnalysis, - final AggregateAnalyzer aggregateAnalyzer, - final AggregateExpressionRewriter aggregateExpressionRewriter - ) { - aggregateAnalyzer.processHaving(having); - - aggregateAnalysis.setHavingExpression( - ExpressionTreeRewriter.rewriteWith(aggregateExpressionRewriter::process, having)); - } - - private static void processWhereExpression( - final Analysis analysis, - final AggregateAnalyzer aggregateAnalyzer - ) { - analysis.getWhereExpression() - .ifPresent(aggregateAnalyzer::processWhere); - } - - private static void processGroupByExpression( - final Analysis analysis, - final AggregateAnalyzer aggregateAnalyzer - ) { - for (final Expression exp : analysis.getGroupByExpressions()) { - aggregateAnalyzer.processGroupBy(exp); - } - } - - private static void processSelectExpressions( - final Analysis analysis, - final MutableAggregateAnalysis aggregateAnalysis, - final AggregateAnalyzer aggregateAnalyzer, - final AggregateExpressionRewriter aggregateExpressionRewriter - ) { - for (final SelectExpression select : analysis.getSelectExpressions()) { - final Expression exp = select.getExpression(); - aggregateAnalyzer.processSelect(exp); - - aggregateAnalysis.addFinalSelectExpression( - ExpressionTreeRewriter.rewriteWith(aggregateExpressionRewriter::process, exp)); - } - } - - private static void enforceAggregateRules( - final Query query, - final Analysis analysis, - final AggregateAnalysis aggregateAnalysis - ) { - if (!query.getGroupBy().isPresent()) { - return; - } - - if 
(!analysis.getTableFunctions().isEmpty()) { - throw new KsqlException("Table functions cannot be used with aggregations."); - } - - if (aggregateAnalysis.getAggregateFunctions().isEmpty()) { - throw new KsqlException( - "GROUP BY requires columns using aggregate functions in SELECT clause."); - } - - final Set groupByExprs = getGroupByExpressions(analysis); - - final List unmatchedSelects = aggregateAnalysis.getNonAggregateSelectExpressions() - .entrySet() - .stream() - // Remove any that exactly match a group by expression: - .filter(e -> !groupByExprs.contains(e.getKey())) - // Remove any that are constants, - // or expressions where all params exactly match a group by expression: - .filter(e -> !Sets.difference(e.getValue(), groupByExprs).isEmpty()) - .map(Map.Entry::getKey) - .map(Expression::toString) - .sorted() - .collect(Collectors.toList()); - - if (!unmatchedSelects.isEmpty()) { - throw new KsqlException( - "Non-aggregate SELECT expression(s) not part of GROUP BY: " + unmatchedSelects); - } - - final SetView unmatchedSelectsAgg = Sets - .difference(aggregateAnalysis.getAggregateSelectFields(), groupByExprs); - if (!unmatchedSelectsAgg.isEmpty()) { - throw new KsqlException( - "Field used in aggregate SELECT expression(s) " - + "outside of aggregate functions not part of GROUP BY: " + unmatchedSelectsAgg); - } - - final Set havingColumns = aggregateAnalysis - .getNonAggregateHavingFields(); - - final Set havingOnly = Sets.difference(havingColumns, groupByExprs); - if (!havingOnly.isEmpty()) { - throw new KsqlException( - "Non-aggregate HAVING expression not part of GROUP BY: " + havingOnly); - } - } - - private static Set getGroupByExpressions(final Analysis analysis) { - if (!analysis.getWindowExpression().isPresent()) { - return ImmutableSet.copyOf(analysis.getGroupByExpressions()); - } - - // Add in window bounds columns as implicit group by columns: - final AliasedDataSource source = Iterables.getOnlyElement(analysis.getFromDataSources()); - - final Set windowBoundColumnRefs = - SchemaUtil.windowBoundsColumnNames().stream() - .map(cn -> new QualifiedColumnReferenceExp(source.getAlias(), cn)) - .collect(Collectors.toSet()); - - return ImmutableSet.builder() - .addAll(analysis.getGroupByExpressions()) - .addAll(windowBoundColumnRefs) - .build(); - } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/RewrittenAnalysis.java b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/RewrittenAnalysis.java index 630279f0fc60..7551d03569bc 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/RewrittenAnalysis.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/analyzer/RewrittenAnalysis.java @@ -20,14 +20,15 @@ import io.confluent.ksql.analyzer.Analysis.JoinInfo; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.parser.properties.with.CreateSourceAsProperties; +import io.confluent.ksql.parser.tree.SelectItem; +import io.confluent.ksql.parser.tree.SingleColumn; import io.confluent.ksql.parser.tree.WindowExpression; -import io.confluent.ksql.serde.SerdeOption; 
import java.util.List; import java.util.Objects; import java.util.Optional; @@ -44,6 +45,7 @@ * transformations needed to execute the query. */ public class RewrittenAnalysis implements ImmutableAnalysis { + private final ImmutableAnalysis original; private final BiFunction, Optional> rewriter; @@ -52,7 +54,7 @@ public RewrittenAnalysis( final BiFunction, Optional> rewriter ) { this.original = Objects.requireNonNull(original, "original"); - this.rewriter = Objects.requireNonNull(rewriter ,"rewriter"); + this.rewriter = Objects.requireNonNull(rewriter, "rewriter"); } public ImmutableAnalysis getOriginal() { @@ -65,11 +67,21 @@ public List getTableFunctions() { } @Override - public List getSelectExpressions() { - return original.getSelectExpressions().stream() - .map(e -> SelectExpression.of( - e.getAlias(), - ExpressionTreeRewriter.rewriteWith(rewriter, e.getExpression()))) + public List getSelectItems() { + return original.getSelectItems().stream() + .map(si -> { + if (!(si instanceof SingleColumn)) { + return si; + } + + final SingleColumn singleColumn = (SingleColumn) si; + return new SingleColumn( + singleColumn.getLocation(), + rewrite(singleColumn.getExpression()), + singleColumn.getAlias() + ); + } + ) .collect(Collectors.toList()); } @@ -84,10 +96,10 @@ public Optional getInto() { } @Override - public Set getSelectColumnRefs() { - return original.getSelectColumnRefs().stream() + public Set getSelectColumnNames() { + return original.getSelectColumnNames().stream() .map(UnqualifiedColumnReferenceExp::new) - .map(r -> ExpressionTreeRewriter.rewriteWith(rewriter, r)) + .map(this::rewrite) .map(UnqualifiedColumnReferenceExp::getColumnName) .collect(Collectors.toSet()); } @@ -97,11 +109,21 @@ public List getGroupByExpressions() { return rewriteList(original.getGroupByExpressions()); } + @Override + public Optional getHavingExpression() { + return rewriteOptional(original.getHavingExpression()); + } + @Override public Optional getWindowExpression() { return original.getWindowExpression(); } + @Override + public ColumnReferenceExp getDefaultArgument() { + return rewrite(original.getDefaultArgument()); + } + @Override public Optional getPartitionBy() { return rewriteOptional(original.getPartitionBy()); @@ -116,8 +138,8 @@ public OptionalInt getLimitClause() { public Optional getJoin() { return original.getJoin().map( j -> new JoinInfo( - ExpressionTreeRewriter.rewriteWith(rewriter, j.getLeftJoinExpression()), - ExpressionTreeRewriter.rewriteWith(rewriter, j.getRightJoinExpression()), + rewrite(j.getLeftJoinExpression()), + rewrite(j.getRightJoinExpression()), j.getType(), j.getWithinExpression() ) @@ -129,11 +151,6 @@ public List getFromDataSources() { return original.getFromDataSources(); } - @Override - public Set getSerdeOptions() { - return original.getSerdeOptions(); - } - @Override public CreateSourceAsProperties getProperties() { return original.getProperties(); @@ -145,12 +162,16 @@ public SourceSchemas getFromSourceSchemas(final boolean postAggregate) { } private Optional rewriteOptional(final Optional expression) { - return expression.map(e -> ExpressionTreeRewriter.rewriteWith(rewriter, e)); + return expression.map(this::rewrite); } private List rewriteList(final List expressions) { return expressions.stream() - .map(e -> ExpressionTreeRewriter.rewriteWith(rewriter, e)) + .map(this::rewrite) .collect(Collectors.toList()); } + + private T rewrite(final T expression) { + return ExpressionTreeRewriter.rewriteWith(rewriter, expression); + } } diff --git 
a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/QueryEngine.java b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/QueryEngine.java index 841e8125d1ab..b90087396e08 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/engine/QueryEngine.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/engine/QueryEngine.java @@ -15,7 +15,6 @@ package io.confluent.ksql.engine; -import io.confluent.ksql.analyzer.AggregateAnalysisResult; import io.confluent.ksql.analyzer.Analysis; import io.confluent.ksql.analyzer.QueryAnalyzer; import io.confluent.ksql.logging.processing.ProcessingLogContext; @@ -38,15 +37,11 @@ import java.util.Optional; import java.util.Set; import org.apache.kafka.streams.StreamsBuilder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; // CHECKSTYLE_RULES.OFF: ClassDataAbstractionCoupling class QueryEngine { // CHECKSTYLE_RULES.ON: ClassDataAbstractionCoupling - private static final Logger LOG = LoggerFactory.getLogger(QueryEngine.class); - private final ServiceContext serviceContext; private final ProcessingLogContext processingLogContext; private final QueryIdGenerator queryIdGenerator; @@ -78,9 +73,8 @@ static OutputNode buildQueryLogicalPlan( new QueryAnalyzer(metaStore, outputPrefix, defaultSerdeOptions); final Analysis analysis = queryAnalyzer.analyze(query, sink); - final AggregateAnalysisResult aggAnalysis = queryAnalyzer.analyzeAggregate(query, analysis); - return new LogicalPlanner(config, analysis, aggAnalysis, metaStore).buildPlan(); + return new LogicalPlanner(config, analysis, metaStore).buildPlan(); } PhysicalPlan buildPhysicalPlan( diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java index c3324e8d43d2..0ae527057123 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/LogicalPlanner.java @@ -16,12 +16,12 @@ package io.confluent.ksql.planner; import io.confluent.ksql.analyzer.AggregateAnalysisResult; +import io.confluent.ksql.analyzer.AggregateAnalyzer; import io.confluent.ksql.analyzer.Analysis.AliasedDataSource; import io.confluent.ksql.analyzer.Analysis.Into; import io.confluent.ksql.analyzer.Analysis.JoinInfo; import io.confluent.ksql.analyzer.ImmutableAnalysis; import io.confluent.ksql.analyzer.RewrittenAnalysis; -import io.confluent.ksql.analyzer.SourceSchemas; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context; import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; @@ -38,6 +38,9 @@ import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.SourceName; +import io.confluent.ksql.parser.tree.AllColumns; +import io.confluent.ksql.parser.tree.SelectItem; +import io.confluent.ksql.parser.tree.SingleColumn; import io.confluent.ksql.planner.plan.AggregateNode; import io.confluent.ksql.planner.plan.DataSourceNode; import io.confluent.ksql.planner.plan.FilterNode; @@ -57,14 +60,20 @@ import io.confluent.ksql.schema.ksql.LogicalSchema.Builder; import io.confluent.ksql.schema.ksql.types.SqlType; import io.confluent.ksql.schema.ksql.types.SqlTypes; +import io.confluent.ksql.serde.Format; +import io.confluent.ksql.serde.SerdeOption; +import io.confluent.ksql.serde.SerdeOptions; import io.confluent.ksql.util.KsqlConfig; import io.confluent.ksql.util.KsqlException; import 
io.confluent.ksql.util.SchemaUtil; import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.Set; import java.util.function.BiFunction; import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; // CHECKSTYLE_RULES.OFF: ClassDataAbstractionCoupling public class LogicalPlanner { @@ -72,25 +81,21 @@ public class LogicalPlanner { private final KsqlConfig ksqlConfig; private final RewrittenAnalysis analysis; - private final AggregateAnalysisResult aggregateAnalysis; private final FunctionRegistry functionRegistry; + private final AggregateAnalyzer aggregateAnalyzer; + private final ColumnReferenceRewriter refRewriter; public LogicalPlanner( final KsqlConfig ksqlConfig, final ImmutableAnalysis analysis, - final AggregateAnalysisResult aggregateAnalysis, final FunctionRegistry functionRegistry ) { this.ksqlConfig = Objects.requireNonNull(ksqlConfig, "ksqlConfig"); - Objects.requireNonNull(analysis, "analysis"); - final ColumnReferenceRewriter refRewriter = - new ColumnReferenceRewriter(analysis.getFromSourceSchemas(false)); + this.refRewriter = + new ColumnReferenceRewriter(analysis.getFromSourceSchemas(false).isJoin()); this.analysis = new RewrittenAnalysis(analysis, refRewriter::process); - this.aggregateAnalysis = new RewrittenAggregateAnalysis( - Objects.requireNonNull(aggregateAnalysis, "aggregateAnalysis"), - refRewriter::process - ); this.functionRegistry = Objects.requireNonNull(functionRegistry, "functionRegistry"); + this.aggregateAnalyzer = new AggregateAnalyzer(functionRegistry); } public OutputNode buildPlan() { @@ -110,7 +115,7 @@ public OutputNode buildPlan() { } if (analysis.getGroupByExpressions().isEmpty()) { - currentNode = buildProjectNode(currentNode, "Project", currentNode.getSelectExpressions()); + currentNode = buildProjectNode(currentNode, "Project"); } else { currentNode = buildAggregateNode(currentNode); } @@ -143,11 +148,31 @@ private OutputNode buildOutputNode(final PlanNode sourcePlanNode) { intoDataSource.getKsqlTopic(), analysis.getLimitClause(), intoDataSource.isCreate(), - analysis.getSerdeOptions(), + getSerdeOptions(sourcePlanNode, intoDataSource), intoDataSource.getName() ); } + private Set getSerdeOptions( + final PlanNode sourcePlanNode, + final Into intoDataSource + ) { + final List columnNames = sourcePlanNode.getSchema().value().stream() + .map(Column::name) + .collect(Collectors.toList()); + + final Format valueFormat = intoDataSource.getKsqlTopic() + .getValueFormat() + .getFormat(); + + return SerdeOptions.buildForCreateAsStatement( + columnNames, + valueFormat, + analysis.getProperties().getWrapSingleValues(), + intoDataSource.getDefaultSerdeOptions() + ); + } + private Optional getTimestampColumn( final LogicalSchema inputSchema, final ImmutableAnalysis analysis @@ -168,7 +193,10 @@ private Optional getTimestampColumn( private AggregateNode buildAggregateNode(final PlanNode sourcePlanNode) { final List groupByExps = analysis.getGroupByExpressions(); - final LogicalSchema schema = buildAggregateSchema(sourcePlanNode, groupByExps); + final List projectionExpressions = buildSelectExpressions(sourcePlanNode); + + final LogicalSchema schema = + buildAggregateSchema(sourcePlanNode, groupByExps, projectionExpressions); final Expression groupBy = groupByExps.size() == 1 ? 
groupByExps.get(0) @@ -179,7 +207,12 @@ private AggregateNode buildAggregateNode(final PlanNode sourcePlanNode) { expression.equals(groupBy) && !SchemaUtil.isSystemColumn(alias) && !schema.isKeyColumn(alias), - sourcePlanNode.getSelectExpressions()); + projectionExpressions); + + final RewrittenAggregateAnalysis aggregateAnalysis = new RewrittenAggregateAnalysis( + aggregateAnalyzer.analyze(analysis, projectionExpressions), + refRewriter::process + ); return new AggregateNode( new PlanNodeId("Aggregate"), @@ -187,19 +220,25 @@ private AggregateNode buildAggregateNode(final PlanNode sourcePlanNode) { schema, keyFieldName, groupByExps, - analysis.getWindowExpression(), - aggregateAnalysis.getAggregateFunctionArguments(), - aggregateAnalysis.getAggregateFunctions(), - aggregateAnalysis.getRequiredColumns(), - aggregateAnalysis.getFinalSelectExpressions(), - aggregateAnalysis.getHavingExpression().orElse(null) + functionRegistry, + analysis, + aggregateAnalysis, + projectionExpressions ); } + private ProjectNode buildProjectNode( + final PlanNode parentNode, + final String id + ) { + return buildProjectNode(parentNode, id, buildSelectExpressions(parentNode)); + } + private ProjectNode buildProjectNode( final PlanNode sourcePlanNode, final String id, - final List projection) { + final List projection + ) { final ColumnName sourceKeyFieldName = sourcePlanNode .getKeyField() .ref() @@ -210,7 +249,7 @@ private ProjectNode buildProjectNode( final Optional keyFieldName = getSelectAliasMatching( (expression, alias) -> expression instanceof UnqualifiedColumnReferenceExp && ((UnqualifiedColumnReferenceExp) expression).getColumnName().equals( - sourceKeyFieldName), + sourceKeyFieldName), projection ); @@ -223,6 +262,46 @@ private ProjectNode buildProjectNode( ); } + private List buildSelectExpressions(final PlanNode parentNode) { + return IntStream.range(0, analysis.getSelectItems().size()) + .boxed() + .flatMap(idx -> resolveSelectItem(idx, parentNode)) + .collect(Collectors.toList()); + } + + private Stream resolveSelectItem( + final int idx, + final PlanNode parentNode + ) { + final SelectItem selectItem = analysis.getSelectItems().get(idx); + + if (selectItem instanceof SingleColumn) { + final SingleColumn column = (SingleColumn) selectItem; + final Expression expression = parentNode.resolveSelect(idx, column.getExpression()); + final ColumnName alias = column.getAlias() + .orElseThrow(() -> new IllegalStateException("Alias should be present by this point")); + + return Stream.of(SelectExpression.of(alias, expression)); + } + + if (selectItem instanceof AllColumns) { + final AllColumns allColumns = (AllColumns) selectItem; + + final Stream columns = parentNode + .resolveSelectStar(allColumns.getSource(), analysis.getInto().isPresent()); + + // Only need to take value columns as value schema includes key schema by this point + return columns + .map(name -> SelectExpression.of(name, new UnqualifiedColumnReferenceExp( + allColumns.getLocation(), + name + ))); + } + + throw new IllegalArgumentException( + "Unsupported SelectItem type: " + selectItem.getClass().getName()); + } + private static FilterNode buildFilterNode( final PlanNode sourcePlanNode, final Expression filterExpression @@ -280,13 +359,14 @@ private FlatMapNode buildFlatMapNode(final PlanNode sourcePlanNode) { private PlanNode buildSourceForJoin( final AliasedDataSource source, final String side, - final Expression joinExpression) { + final Expression joinExpression + ) { final DataSourceNode sourceNode = new DataSourceNode( new 
PlanNodeId("KafkaTopic_" + side), source.getDataSource(), - source.getAlias(), - analysis.getSelectExpressions() + source.getAlias() ); + // it is always safe to build the repartition node - this operation will be // a no-op if a repartition is not required. if the source is a table, and // a repartition is needed, then an exception will be thrown @@ -344,7 +424,6 @@ private PlanNode buildSourceNode() { return new JoinNode( new PlanNodeId("Join"), - analysis.getSelectExpressions(), joinInfo.get().getType(), leftSourceNode, rightSourceNode, @@ -361,8 +440,7 @@ private DataSourceNode buildNonJoinNode(final List sources) { return new DataSourceNode( new PlanNodeId("KsqlTopic"), dataSource.getDataSource(), - dataSource.getAlias(), - analysis.getSelectExpressions() + dataSource.getAlias() ); } @@ -370,9 +448,7 @@ private static Optional getSelectAliasMatching( final BiFunction matcher, final List projection ) { - for (int i = 0; i < projection.size(); i++) { - final SelectExpression select = projection.get(i); - + for (final SelectExpression select : projection) { if (matcher.apply(select.getExpression(), select.getAlias())) { return Optional.of(select.getAlias()); } @@ -395,9 +471,7 @@ private LogicalSchema buildProjectionSchema( builder.keyColumns(schema.key()); - for (int i = 0; i < projection.size(); i++) { - final SelectExpression select = projection.get(i); - + for (final SelectExpression select : projection) { final SqlType expressionType = expressionTypeManager .getExpressionSqlType(select.getExpression()); @@ -409,7 +483,8 @@ private LogicalSchema buildProjectionSchema( private LogicalSchema buildAggregateSchema( final PlanNode sourcePlanNode, - final List groupByExps + final List groupByExps, + final List projectionExpressions ) { final LogicalSchema sourceSchema = sourcePlanNode.getSchema(); @@ -434,7 +509,7 @@ private LogicalSchema buildAggregateSchema( final LogicalSchema projectionSchema = buildProjectionSchema( sourceSchema .withMetaAndKeyColsInValue(analysis.getWindowExpression().isPresent()), - sourcePlanNode.getSelectExpressions() + projectionExpressions ); return LogicalSchema.builder() @@ -504,27 +579,12 @@ private static List selectWithPrependAlias( private static final class ColumnReferenceRewriter extends VisitParentExpressionVisitor, Context> { - final SourceSchemas sourceSchemas; - ColumnReferenceRewriter(final SourceSchemas sourceSchemas) { - super(Optional.empty()); - this.sourceSchemas = Objects.requireNonNull(sourceSchemas, "sourceSchemas"); - } + private final boolean isJoin; - @Override - public Optional visitColumnReference( - final UnqualifiedColumnReferenceExp node, - final Context ctx - ) { - if (sourceSchemas.isJoin()) { - final SourceName sourceName = sourceSchemas - .sourcesWithField(Optional.empty(), node.getColumnName()).iterator().next(); - - return Optional.of(new UnqualifiedColumnReferenceExp( - ColumnName.generatedJoinColumnAlias(sourceName, node.getColumnName()) - )); - } - return Optional.empty(); + ColumnReferenceRewriter(final boolean isJoin) { + super(Optional.empty()); + this.isJoin = isJoin; } @Override @@ -532,7 +592,7 @@ public Optional visitQualifiedColumnReference( final QualifiedColumnReferenceExp node, final Context ctx ) { - if (sourceSchemas.isJoin()) { + if (isJoin) { return Optional.of(new UnqualifiedColumnReferenceExp( ColumnName.generatedJoinColumnAlias(node.getQualifier(), node.getColumnName()) )); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java 
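Seen end to end, the LogicalPlanner changes above replace the Analyzer's eager '*' expansion with a per-item negotiation against the parent plan node: single columns are offered to the parent for rewriting, while '*' asks the parent to enumerate columns. A minimal, self-contained sketch of that dispatch, using toy stand-ins rather than the real ksqlDB types (none of the identifiers below are the actual API):

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

final class SelectResolutionSketch {

  // Toy parent node: may rewrite an expression (e.g. to a synthetic column name)
  // and can enumerate the columns that '*' expands to.
  interface Node {
    String resolveSelect(int idx, String expression);
    Stream<String> resolveSelectStar(Optional<String> sourceName);
  }

  // Toy single-column select item: "expr AS alias". Anything else is treated as '*'.
  static final class Single {
    final String alias;
    final String expr;
    Single(final String alias, final String expr) { this.alias = alias; this.expr = expr; }
  }

  static List<String> buildSelects(final List<Object> items, final Node parent) {
    final List<String> out = new ArrayList<>();
    for (int idx = 0; idx < items.size(); idx++) {
      final Object item = items.get(idx);
      if (item instanceof Single) {
        // Single columns are offered to the parent, which may rewrite them.
        final Single s = (Single) item;
        out.add(s.alias + " := " + parent.resolveSelect(idx, s.expr));
      } else {
        // '*' asks the parent, which forwards the request up the plan
        // until a node that knows the source columns answers.
        parent.resolveSelectStar(Optional.empty())
            .forEach(col -> out.add(col + " := " + col));
      }
    }
    return out;
  }
}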
b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java index 6a821b847df1..74a68683fcc0 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/AggregateNode.java @@ -19,6 +19,9 @@ import static java.util.Objects.requireNonNull; import com.google.common.collect.ImmutableList; +import io.confluent.ksql.analyzer.AggregateAnalysisResult; +import io.confluent.ksql.analyzer.AggregateExpressionRewriter; +import io.confluent.ksql.analyzer.ImmutableAnalysis; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; @@ -30,6 +33,7 @@ import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor; import io.confluent.ksql.execution.plan.SelectExpression; +import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.parser.tree.WindowExpression; @@ -71,41 +75,45 @@ public class AggregateNode extends PlanNode { private final ImmutableList aggregateFunctionArguments; private final ImmutableList functionList; private final ImmutableList requiredColumns; - private final Expression havingExpressions; + private final Optional havingExpressions; + private final ImmutableList finalSelectExpressions; - // CHECKSTYLE_RULES.OFF: ParameterNumberCheck public AggregateNode( final PlanNodeId id, final PlanNode source, final LogicalSchema schema, final Optional keyFieldName, final List groupByExpressions, - final Optional windowExpression, - final List aggregateFunctionArguments, - final List functionList, - final List requiredColumns, - final List finalSelectExpressions, - final Expression havingExpressions + final FunctionRegistry functionRegistry, + final ImmutableAnalysis analysis, + final AggregateAnalysisResult rewrittenAggregateAnalysis, + final List projectionExpressions ) { - // CHECKSTYLE_RULES.ON: ParameterNumberCheck - super( - id, - DataSourceType.KTABLE, - schema, - buildSelectExpressions(schema, finalSelectExpressions) - ); + super(id, DataSourceType.KTABLE, schema, Optional.empty()); this.source = requireNonNull(source, "source"); this.groupByExpressions = ImmutableList .copyOf(requireNonNull(groupByExpressions, "groupByExpressions")); - this.windowExpression = requireNonNull(windowExpression, "windowExpression"); + this.windowExpression = requireNonNull(analysis, "analysis").getWindowExpression(); + + final AggregateExpressionRewriter aggregateExpressionRewriter = + new AggregateExpressionRewriter(functionRegistry); + this.aggregateFunctionArguments = ImmutableList - .copyOf(requireNonNull(aggregateFunctionArguments, "aggregateFunctionArguments")); + .copyOf(rewrittenAggregateAnalysis.getAggregateFunctionArguments()); this.functionList = ImmutableList - .copyOf(requireNonNull(functionList, "functionList")); + .copyOf(rewrittenAggregateAnalysis.getAggregateFunctions()); this.requiredColumns = ImmutableList - .copyOf(requireNonNull(requiredColumns, "requiredColumns")); - this.havingExpressions = havingExpressions; + .copyOf(rewrittenAggregateAnalysis.getRequiredColumns()); + this.finalSelectExpressions = ImmutableList.copyOf(projectionExpressions.stream() + .map(se -> SelectExpression.of( + se.getAlias(), + ExpressionTreeRewriter + 
.rewriteWith(aggregateExpressionRewriter::process, se.getExpression()) + )) + .collect(Collectors.toList())); + this.havingExpressions = rewrittenAggregateAnalysis.getHavingExpression() + .map(exp -> ExpressionTreeRewriter.rewriteWith(aggregateExpressionRewriter::process, exp)); this.keyField = KeyField.of(requireNonNull(keyFieldName, "keyFieldName")) .validateKeyExistsIn(schema); } @@ -206,7 +214,7 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { aggregationContext ); - final Optional havingExpression = Optional.ofNullable(havingExpressions) + final Optional havingExpression = havingExpressions .map(internalSchema::resolveToInternal); if (havingExpression.isPresent()) { @@ -217,7 +225,7 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { } final List finalSelects = internalSchema - .updateFinalSelectExpressions(getSelectExpressions()); + .updateFinalSelectExpressions(finalSelectExpressions); return aggregated.select( finalSelects, @@ -230,28 +238,6 @@ protected int getPartitions(final KafkaTopicClient kafkaTopicClient) { return source.getPartitions(kafkaTopicClient); } - private static List buildSelectExpressions( - final LogicalSchema schema, - final List finalSelectExpressions - ) { - final List finalSelectExpressionList = new ArrayList<>(); - if (finalSelectExpressions.size() != schema.value().size()) { - throw new RuntimeException( - "Incompatible aggregate schema, field count must match, " - + "selected field count:" - + finalSelectExpressions.size() - + " schema field count:" - + schema.value().size()); - } - for (int i = 0; i < finalSelectExpressions.size(); i++) { - finalSelectExpressionList.add(SelectExpression.of( - schema.value().get(i).name(), - finalSelectExpressions.get(i) - )); - } - return finalSelectExpressionList; - } - private static class InternalSchema { private final Optional singleKeyColumn; diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/DataSourceNode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/DataSourceNode.java index c14fc192d392..c387e14d563a 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/DataSourceNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/DataSourceNode.java @@ -22,16 +22,22 @@ import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.context.QueryContext.Stacker; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.metastore.model.DataSource; import io.confluent.ksql.metastore.model.DataSource.DataSourceType; import io.confluent.ksql.metastore.model.KeyField; +import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.SourceName; +import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.services.KafkaTopicClient; import io.confluent.ksql.structured.SchemaKSourceFactory; import io.confluent.ksql.structured.SchemaKStream; +import io.confluent.ksql.util.SchemaUtil; import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; @Immutable public class DataSourceNode extends PlanNode { @@ -39,29 +45,25 @@ public class DataSourceNode extends PlanNode { private static final String SOURCE_OP_NAME = "Source"; private final DataSource dataSource; - private final SourceName alias; private final KeyField keyField; private final SchemaKStreamFactory 
schemaKStreamFactory; public DataSourceNode( final PlanNodeId id, final DataSource dataSource, - final SourceName alias, - final List<SelectExpression> selectExpressions + final SourceName alias ) { - this(id, dataSource, alias, selectExpressions, SchemaKSourceFactory::buildSource); + this(id, dataSource, alias, SchemaKSourceFactory::buildSource); } DataSourceNode( final PlanNodeId id, final DataSource dataSource, final SourceName alias, - final List<SelectExpression> selectExpressions, final SchemaKStreamFactory schemaKStreamFactory ) { - super(id, dataSource.getDataSourceType(), buildSchema(dataSource), selectExpressions); + super(id, dataSource.getDataSourceType(), buildSchema(dataSource), Optional.of(alias)); this.dataSource = requireNonNull(dataSource, "dataSource"); - this.alias = requireNonNull(alias, "alias"); this.keyField = dataSource.getKeyField() .validateKeyExistsIn(getSchema()); @@ -79,7 +81,7 @@ public DataSource getDataSource() { } public SourceName getAlias() { - return alias; + return getSourceName().orElseThrow(IllegalStateException::new); } public DataSourceType getDataSourceType() { @@ -116,6 +118,20 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) { ); } + @Override + public Stream<ColumnName> resolveSelectStar( + final Optional<SourceName> sourceName, final boolean valueOnly + ) { + if (sourceName.isPresent() && !sourceName.equals(getSourceName())) { + throw new IllegalArgumentException("Expected alias of " + getAlias() + + ", but was " + sourceName.get()); + } + + return valueOnly + ? getSchema().withoutMetaAndKeyColsInValue().value().stream().map(Column::name) + : orderColumns(getSchema().value(), getSchema()); + } + private static LogicalSchema buildSchema(final DataSource dataSource) { // DataSourceNode copies implicit and key fields into the value schema // It uses a KS valueMapper to add the key fields @@ -124,6 +140,21 @@ private static LogicalSchema buildSchema(final DataSource dataSource) { .withMetaAndKeyColsInValue(dataSource.getKsqlTopic().getKeyFormat().isWindowed()); } + private static Stream<ColumnName> orderColumns( + final List<Column> columns, + final LogicalSchema schema + ) { + // When doing a `select *`, system and key columns should be at the front of the column list + // but are added at the back during processing for performance reasons.
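The comment above states the reordering contract; an equivalent formulation uses Collectors.partitioningBy, which, unlike groupingBy, guarantees both buckets exist even when one side is empty. A standalone sketch, with plain strings standing in for Column and for the system/key lookup (only the JDK APIs here are real):

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class OrderColumnsSketch {

  static Stream<String> orderColumns(final List<String> names, final Set<String> systemOrKey) {
    // Partition into (system/key) and (value) columns; partitioningBy always
    // populates both the TRUE and FALSE entries, so no null-check is needed.
    final Map<Boolean, List<String>> parts = names.stream()
        .collect(Collectors.partitioningBy(systemOrKey::contains));

    // SELECT * expects system and key columns first, value columns after.
    return Stream.concat(parts.get(true).stream(), parts.get(false).stream());
  }
}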
+ // Switch them around here: + final Map> partitioned = columns.stream().collect(Collectors + .groupingBy(c -> SchemaUtil.isSystemColumn(c.name()) || schema.isKeyColumn(c.name()))); + + final List all = partitioned.get(true); + all.addAll(partitioned.get(false)); + return all.stream().map(Column::name); + } + @Immutable interface SchemaKStreamFactory { diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/FilterNode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/FilterNode.java index 7971c487eb80..a8ff1a96484d 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/FilterNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/FilterNode.java @@ -37,7 +37,7 @@ public FilterNode( final PlanNode source, final Expression predicate ) { - super(id, source.getNodeOutputType(), source.getSchema(), source.getSelectExpressions()); + super(id, source.getNodeOutputType(), source.getSchema(), source.getSourceName()); this.source = Objects.requireNonNull(source, "source"); this.predicate = Objects.requireNonNull(predicate, "predicate"); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/FlatMapNode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/FlatMapNode.java index 952557380266..44e72a7893e9 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/FlatMapNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/FlatMapNode.java @@ -16,6 +16,7 @@ package io.confluent.ksql.planner.plan; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.Immutable; import io.confluent.ksql.analyzer.ImmutableAnalysis; import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter; @@ -26,12 +27,13 @@ import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.streams.StreamFlatMapBuilder; import io.confluent.ksql.function.FunctionRegistry; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; +import io.confluent.ksql.parser.tree.SelectItem; +import io.confluent.ksql.parser.tree.SingleColumn; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.services.KafkaTopicClient; import io.confluent.ksql.structured.SchemaKStream; @@ -49,6 +51,7 @@ public class FlatMapNode extends PlanNode { private final PlanNode source; private final ImmutableList tableFunctions; + private final ImmutableMap columnMappings; public FlatMapNode( final PlanNodeId id, @@ -60,10 +63,11 @@ public FlatMapNode( id, source.getNodeOutputType(), buildSchema(source, functionRegistry, analysis), - buildFinalSelectExpressions(functionRegistry, analysis) + Optional.empty() ); this.source = Objects.requireNonNull(source, "source"); this.tableFunctions = ImmutableList.copyOf(analysis.getTableFunctions()); + this.columnMappings = buildColumnMappings(functionRegistry, analysis); } @Override @@ -90,6 +94,10 @@ protected int getPartitions(final KafkaTopicClient kafkaTopicClient) { return source.getPartitions(kafkaTopicClient); } + public Expression resolveSelect(final int idx, final Expression expression) { + return columnMappings.getOrDefault(idx, expression); + } + @Override 
public SchemaKStream buildStream(final KsqlQueryBuilder builder) { @@ -101,24 +109,31 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { ); } - private static ImmutableList buildFinalSelectExpressions( + private static ImmutableMap buildColumnMappings( final FunctionRegistry functionRegistry, final ImmutableAnalysis analysis ) { final TableFunctionExpressionRewriter tableFunctionExpressionRewriter = new TableFunctionExpressionRewriter(functionRegistry); - final ImmutableList.Builder selectExpressions = ImmutableList.builder(); - for (final SelectExpression select : analysis.getSelectExpressions()) { - final Expression exp = select.getExpression(); - selectExpressions.add( - SelectExpression.of( - select.getAlias(), - ExpressionTreeRewriter.rewriteWith( - tableFunctionExpressionRewriter::process, exp) - )); + final ImmutableMap.Builder builder = ImmutableMap.builder(); + + for (int idx = 0; idx < analysis.getSelectItems().size(); idx++) { + final SelectItem selectItem = analysis.getSelectItems().get(idx); + if (!(selectItem instanceof SingleColumn)) { + continue; + } + + final SingleColumn singleColumn = (SingleColumn) selectItem; + final Expression rewritten = ExpressionTreeRewriter.rewriteWith( + tableFunctionExpressionRewriter::process, singleColumn.getExpression()); + + if (!rewritten.equals(singleColumn.getExpression())) { + builder.put(idx, rewritten); + } } - return selectExpressions.build(); + + return builder.build(); } private static class TableFunctionExpressionRewriter diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java index b447efceb472..ae77a5848969 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/JoinNode.java @@ -18,10 +18,11 @@ import com.google.common.collect.ImmutableMap; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.streams.JoinParamsFactory; import io.confluent.ksql.metastore.model.DataSource.DataSourceType; import io.confluent.ksql.metastore.model.KeyField; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.SourceName; import io.confluent.ksql.parser.tree.WithinExpression; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.serde.ValueFormat; @@ -37,6 +38,7 @@ import java.util.Objects; import java.util.Optional; import java.util.function.Supplier; +import java.util.stream.Stream; import org.apache.kafka.clients.consumer.ConsumerConfig; public class JoinNode extends PlanNode { @@ -53,13 +55,12 @@ public enum JoinType { public JoinNode( final PlanNodeId id, - final List selectExpressions, final JoinType joinType, final PlanNode left, final PlanNode right, final Optional withinExpression ) { - super(id, calculateSinkType(left, right), buildJoinSchema(left, right), selectExpressions); + super(id, calculateSinkType(left, right), buildJoinSchema(left, right), Optional.empty()); this.joinType = Objects.requireNonNull(joinType, "joinType"); this.left = Objects.requireNonNull(left, "left"); this.right = Objects.requireNonNull(right, "right"); @@ -111,6 +112,15 @@ protected int getPartitions(final KafkaTopicClient kafkaTopicClient) { return right.getPartitions(kafkaTopicClient); } + @Override + public Stream resolveSelectStar( + final 
Optional sourceName, final boolean valueOnly + ) { + return getSources().stream() + .filter(s -> !sourceName.isPresent() || sourceName.equals(s.getSourceName())) + .flatMap(s -> s.resolveSelectStar(sourceName, false)); + } + private void ensureMatchingPartitionCounts(final KafkaTopicClient kafkaTopicClient) { final int leftPartitions = left.getPartitions(kafkaTopicClient); final int rightPartitions = right.getPartitions(kafkaTopicClient); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNode.java index 9e6751f5501b..212cb79abfcd 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNode.java @@ -22,20 +22,22 @@ import io.confluent.ksql.execution.builder.KsqlQueryBuilder; import io.confluent.ksql.execution.context.QueryContext; import io.confluent.ksql.execution.ddl.commands.KsqlTopic; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.execution.timestamp.TimestampColumn; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.SourceName; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.query.id.QueryIdGenerator; +import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.serde.SerdeOption; import io.confluent.ksql.structured.SchemaKStream; import io.confluent.ksql.util.KsqlException; +import java.util.Map.Entry; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; public class KsqlStructuredDataOutputNode extends OutputNode { @@ -78,7 +80,7 @@ public KsqlStructuredDataOutputNode( this.doCreateInto = doCreateInto; this.intoSourceName = requireNonNull(intoSourceName, "intoSourceName"); - validate(); + validate(source); } public boolean isDoCreateInto() { @@ -130,12 +132,16 @@ public SchemaKStream buildStream(final KsqlQueryBuilder builder) { ); } - private void validate() { - final LogicalSchema schema = getSchema(); + private static void validate(final PlanNode source) { + final LogicalSchema schema = source.getSchema(); - final String duplicates = getSelectExpressions().stream() - .map(SelectExpression::getAlias) - .filter(schema::isKeyColumn) + final String duplicates = schema.columns().stream() + .map(Column::name) + .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())) + .entrySet() + .stream() + .filter(e -> e.getValue() > 1) + .map(Entry::getKey) .map(ColumnName::toString) .collect(Collectors.joining(", ")); diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/OutputNode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/OutputNode.java index 01b1a4d28080..04607e808c15 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/OutputNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/OutputNode.java @@ -44,7 +44,7 @@ protected OutputNode( final OptionalInt limit, final Optional timestampColumn ) { - super(id, source.getNodeOutputType(), schema, source.getSelectExpressions()); + super(id, source.getNodeOutputType(), schema, source.getSourceName()); this.source = requireNonNull(source, "source"); this.limit = requireNonNull(limit, "limit"); diff 
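The rewritten validate(source) in KsqlStructuredDataOutputNode above no longer compares select aliases against key columns; it scans the whole schema for duplicate column names. The grouping-and-counting idiom it relies on, shown standalone on toy data:

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

final class DuplicateNamesSketch {

  static String duplicates(final List<String> columnNames) {
    return columnNames.stream()
        .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()))
        .entrySet().stream()
        .filter(e -> e.getValue() > 1)
        .map(Map.Entry::getKey)
        .collect(Collectors.joining(", "));
  }

  public static void main(final String[] args) {
    // Prints "K0": the output node would reject such a schema, naming the column.
    System.out.println(duplicates(Arrays.asList("K0", "V0", "K0", "V1")));
  }
}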
--git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PlanNode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PlanNode.java index eb934467d919..d4b32345e5b2 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PlanNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/PlanNode.java @@ -17,16 +17,19 @@ import static java.util.Objects.requireNonNull; -import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; -import io.confluent.ksql.execution.plan.SelectExpression; +import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.metastore.model.DataSource.DataSourceType; import io.confluent.ksql.metastore.model.KeyField; +import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.SourceName; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.services.KafkaTopicClient; import io.confluent.ksql.structured.SchemaKStream; import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; @Immutable public abstract class PlanNode { @@ -34,19 +37,18 @@ public abstract class PlanNode { private final PlanNodeId id; private final DataSourceType nodeOutputType; private final LogicalSchema schema; - private final ImmutableList<SelectExpression> selectExpressions; + private final Optional<SourceName> sourceName; protected PlanNode( final PlanNodeId id, final DataSourceType nodeOutputType, final LogicalSchema schema, - final List<SelectExpression> selectExpressions + final Optional<SourceName> sourceName ) { this.id = requireNonNull(id, "id"); this.nodeOutputType = requireNonNull(nodeOutputType, "nodeOutputType"); this.schema = requireNonNull(schema, "schema"); - this.selectExpressions = ImmutableList .copyOf(requireNonNull(selectExpressions, "projectExpressions")); + this.sourceName = requireNonNull(sourceName, "sourceName"); } public final PlanNodeId getId() { @@ -61,10 +63,6 @@ public final LogicalSchema getSchema() { return schema; } - public final List<SelectExpression> getSelectExpressions() { - return selectExpressions; - } - public abstract KeyField getKeyField(); public abstract List<PlanNode> getSources(); @@ -76,13 +74,51 @@ public <C, R> R accept(final PlanVisitor<C, R> visitor, final C context) { public DataSourceNode getTheSourceNode() { if (this instanceof DataSourceNode) { return (DataSourceNode) this; - } else if (this.getSources() != null && !this.getSources().isEmpty()) { + } else if (!getSources().isEmpty()) { return this.getSources().get(0).getTheSourceNode(); } - return null; + throw new IllegalStateException("No source node in hierarchy"); } protected abstract int getPartitions(KafkaTopicClient kafkaTopicClient); public abstract SchemaKStream<?> buildStream(KsqlQueryBuilder builder); + + Optional<SourceName> getSourceName() { + return sourceName; + } + + /** + * Called to resolve an {@link io.confluent.ksql.parser.tree.AllColumns} instance into a + * corresponding set of columns. + * + * @param sourceName the name of the source + * @param valueOnly {@code false} if key & system columns should be included. + * @return the list of columns. + */ + public Stream<ColumnName> resolveSelectStar( + final Optional<SourceName> sourceName, + final boolean valueOnly + ) { + return getSources().stream() + .filter(s -> !sourceName.isPresent() || sourceName.equals(s.getSourceName())) + .flatMap(s -> s.resolveSelectStar(sourceName, valueOnly)); + } + + /** + * Called to resolve the supplied {@code expression} into an expression that matches the node's + * schema. + * + *
<p>
{@link AggregateNode} and {@link FlatMapNode} replace UDAFs and UDTFs with synthetic column + * names. Where a select is a UDAF or UDTF this method will return the appropriate synthetic + * {@link io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp} + * + * + * @param idx the index of the select within the projection. + * @param expression the expression to resolve. + * @return the resolved expression. + */ + public Expression resolveSelect(final int idx, final Expression expression) { + return expression; + } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/ProjectNode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/ProjectNode.java index 32df132305fb..5d28a55f517a 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/ProjectNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/ProjectNode.java @@ -18,24 +18,33 @@ import static java.util.Objects.requireNonNull; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.Immutable; import io.confluent.ksql.execution.builder.KsqlQueryBuilder; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.name.ColumnName; +import io.confluent.ksql.name.SourceName; import io.confluent.ksql.schema.ksql.Column; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.services.KafkaTopicClient; import io.confluent.ksql.structured.SchemaKStream; import io.confluent.ksql.util.KsqlException; +import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.stream.Stream; @Immutable public class ProjectNode extends PlanNode { private final PlanNode source; + private final ImmutableList projectExpressions; private final KeyField keyField; + private final ImmutableMap> aliases; public ProjectNode( final PlanNodeId id, @@ -44,11 +53,15 @@ public ProjectNode( final LogicalSchema schema, final Optional keyFieldName ) { - super(id, source.getNodeOutputType(), schema, projectExpressions); + super(id, source.getNodeOutputType(), schema, source.getSourceName()); this.source = requireNonNull(source, "source"); + this.projectExpressions = ImmutableList.copyOf( + requireNonNull(projectExpressions, "projectExpressions") + ); this.keyField = KeyField.of(requireNonNull(keyFieldName, "keyFieldName")) .validateKeyExistsIn(schema); + this.aliases = buildAliasMapping(projectExpressions); validate(); } @@ -72,6 +85,10 @@ public KeyField getKeyField() { return keyField; } + public List getSelectExpressions() { + return projectExpressions; + } + @Override public R accept(final PlanVisitor visitor, final C context) { return visitor.visitProject(this, context); @@ -81,25 +98,54 @@ public R accept(final PlanVisitor visitor, final C context) { public SchemaKStream buildStream(final KsqlQueryBuilder builder) { return getSource().buildStream(builder) .select( - getSelectExpressions(), + projectExpressions, builder.buildNodeContext(getId().toString()), builder ); } + public Stream resolveSelectStar( + final Optional sourceName, + final boolean valueOnly + ) { + return source.resolveSelectStar(sourceName, valueOnly) + .map(name -> aliases.getOrDefault(name, ImmutableList.of())) + .flatMap(Collection::stream); + } + private void validate() { - if 
(getSchema().value().size() != getSelectExpressions().size()) { + if (getSchema().value().size() != projectExpressions.size()) { throw new KsqlException("Error in projection. Schema fields and expression list are not " + "compatible."); } - for (int i = 0; i < getSelectExpressions().size(); i++) { + for (int i = 0; i < projectExpressions.size(); i++) { final Column column = getSchema().value().get(i); - final SelectExpression selectExpression = getSelectExpressions().get(i); + final SelectExpression selectExpression = projectExpressions.get(i); if (!column.name().equals(selectExpression.getAlias())) { throw new IllegalArgumentException("Mismatch between schema and selects"); } } } + + private static ImmutableMap> buildAliasMapping( + final List projectExpressions + ) { + final Map> aliases = new HashMap<>(); + + projectExpressions.stream() + .filter(se -> se.getExpression() instanceof ColumnReferenceExp) + .forEach(se -> aliases.computeIfAbsent( + ((ColumnReferenceExp) se.getExpression()).getColumnName(), + k -> ImmutableList.builder()) + .add(se.getAlias())); + + final ImmutableMap.Builder> builder = + ImmutableMap.builder(); + + aliases.forEach((k, v) -> builder.put(k, v.build())); + + return builder.build(); + } } diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/RepartitionNode.java b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/RepartitionNode.java index 8a08784226c7..d6d168c47473 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/RepartitionNode.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/planner/plan/RepartitionNode.java @@ -42,7 +42,7 @@ public RepartitionNode( final Expression partitionBy, final KeyField keyField ) { - super(id, source.getNodeOutputType(), schema, source.getSelectExpressions()); + super(id, source.getNodeOutputType(), schema, source.getSourceName()); this.source = requireNonNull(source, "source"); this.partitionBy = requireNonNull(partitionBy, "partitionBy"); this.keyField = requireNonNull(keyField, "keyField"); diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/AggregateAnalyzerTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/AggregateAnalyzerTest.java index 5e8ce2301a1e..40e3fd1f61b8 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/AggregateAnalyzerTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/AggregateAnalyzerTest.java @@ -17,35 +17,46 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasItems; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertThrows; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; +import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.FunctionCall; import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; +import io.confluent.ksql.execution.plan.SelectExpression; import 
io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.name.SourceName; +import io.confluent.ksql.parser.tree.WindowExpression; import io.confluent.ksql.util.KsqlException; import io.confluent.ksql.util.SchemaUtil; import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +@RunWith(MockitoJUnitRunner.class) public class AggregateAnalyzerTest { - private static final SourceName ORDERS = SourceName.of("ORDERS"); - - private static final QualifiedColumnReferenceExp DEFAULT_ARGUMENT = - new QualifiedColumnReferenceExp(ORDERS, SchemaUtil.ROWTIME_NAME); + private static final UnqualifiedColumnReferenceExp DEFAULT_ARGUMENT = + new UnqualifiedColumnReferenceExp(SchemaUtil.ROWTIME_NAME); private static final UnqualifiedColumnReferenceExp COL0 = new UnqualifiedColumnReferenceExp(ColumnName.of("COL0")); @@ -62,152 +73,213 @@ public class AggregateAnalyzerTest { private static final FunctionCall AGG_FUNCTION_CALL = new FunctionCall(FunctionName.of("MAX"), ImmutableList.of(COL0, COL1)); + private static final FunctionCall REQUIRED_AGG_FUNC_CALL = new FunctionCall( + FunctionName.of("MAX"), + ImmutableList.of(new UnqualifiedColumnReferenceExp(ColumnName.of("AGG_COL"))) + ); + + @Mock + private ImmutableAnalysis analysis; + private final InternalFunctionRegistry functionRegistry = new InternalFunctionRegistry(); - private MutableAggregateAnalysis analysis; private AggregateAnalyzer analyzer; - @Rule - public final ExpectedException expectedException = ExpectedException.none(); + private List selects; @Before public void init() { - analysis = new MutableAggregateAnalysis(); - analyzer = new AggregateAnalyzer(analysis, DEFAULT_ARGUMENT, false, functionRegistry); + analyzer = new AggregateAnalyzer(functionRegistry); + + givenGroupByExpressions(COL0, COL1); + + selects = new ArrayList<>(); + // Aggregate requires at least one aggregation column: + selects.add(SelectExpression.of(ColumnName.of("AGG_COLUMN"), REQUIRED_AGG_FUNC_CALL)); + + when(analysis.getDefaultArgument()).thenReturn(DEFAULT_ARGUMENT); } @Test public void shouldCaptureSelectNonAggregateFunctionArguments() { + // Given: + givenSelectExpression(FUNCTION_CALL); + // When: - analyzer.processSelect(FUNCTION_CALL); + final MutableAggregateAnalysis result = (MutableAggregateAnalysis) analyzer + .analyze(analysis, selects); // Then: - assertThat(analysis.getNonAggregateSelectExpressions().get(FUNCTION_CALL), contains(COL0)); + assertThat(result.getNonAggregateSelectExpressions().get(FUNCTION_CALL), contains(COL0)); } @Test public void shouldCaptureSelectDereferencedExpression() { + // Given: + givenSelectExpression(COL0); + // When: - analyzer.processSelect(COL0); + final MutableAggregateAnalysis result = (MutableAggregateAnalysis) analyzer + .analyze(analysis, selects); // Then: - assertThat(analysis.getNonAggregateSelectExpressions().get(COL0), contains(COL0)); + assertThat(result.getNonAggregateSelectExpressions().get(COL0), contains(COL0)); } @Test public void shouldCaptureOtherSelectsWithEmptySet() { // Given: final Expression someExpression = mock(Expression.class); + givenSelectExpression(someExpression); // When: - analyzer.processSelect(someExpression); + 
final MutableAggregateAnalysis result = (MutableAggregateAnalysis) analyzer + .analyze(analysis, selects); // Then: - assertThat(analysis.getNonAggregateSelectExpressions().get(someExpression), is(empty())); + assertThat(result.getNonAggregateSelectExpressions().get(someExpression), is(empty())); } @Test public void shouldNotCaptureOtherNonAggregateFunctionArgumentsAsNonAggSelectColumns() { + // Given: + givenGroupByExpressions(COL0); + + givenHavingExpression(FUNCTION_CALL); + // When: - analyzer.processGroupBy(FUNCTION_CALL); - analyzer.processHaving(FUNCTION_CALL); + final MutableAggregateAnalysis result = (MutableAggregateAnalysis) analyzer + .analyze(analysis, selects); // Then: - assertThat(analysis.getNonAggregateSelectExpressions().keySet(), is(empty())); + assertThat(result.getNonAggregateSelectExpressions().keySet(), is(empty())); } @Test public void shouldNotCaptureAggregateFunctionArgumentsAsNonAggSelectColumns() { + // Given: + givenHavingExpression(AGG_FUNCTION_CALL); + // When: - analyzer.processSelect(AGG_FUNCTION_CALL); - analyzer.processHaving(AGG_FUNCTION_CALL); + final MutableAggregateAnalysis result = (MutableAggregateAnalysis) analyzer + .analyze(analysis, selects); // Then: - assertThat(analysis.getNonAggregateSelectExpressions().keySet(), is(empty())); + assertThat(result.getNonAggregateSelectExpressions().keySet(), is(empty())); } @Test public void shouldThrowOnGroupByAggregateFunction() { - // Then: - expectedException.expect(KsqlException.class); - expectedException.expectMessage( - "GROUP BY does not support aggregate functions: MAX is an aggregate function."); + // Given: + givenGroupByExpressions(AGG_FUNCTION_CALL); // When: - analyzer.processGroupBy(AGG_FUNCTION_CALL); - } - - @Test - public void shouldCaptureSelectNonAggregateFunctionArgumentsAsRequired() { - // When: - analyzer.processSelect(FUNCTION_CALL); + final KsqlException e = assertThrows(KsqlException.class, + () -> analyzer.analyze(analysis, selects)); // Then: - assertThat(analysis.getRequiredColumns(), contains(COL0)); + assertThat(e.getMessage(), containsString( + "GROUP BY does not support aggregate functions: MAX is an aggregate function.")); } @Test public void shouldCaptureHavingNonAggregateFunctionArgumentsAsRequired() { + // Given: + when(analysis.getHavingExpression()).thenReturn(Optional.of( + new FunctionCall(FunctionName.of("MAX"), + ImmutableList.of(COL2)) + )); + // When: - analyzer.processHaving(FUNCTION_CALL); + final AggregateAnalysisResult result = analyzer + .analyze(analysis, selects); // Then: - assertThat(analysis.getRequiredColumns(), contains(COL0)); + assertThat(result.getRequiredColumns(), hasItem(COL2)); } @Test public void shouldCaptureGroupByNonAggregateFunctionArgumentsAsRequired() { + // Given: + givenGroupByExpressions(COL0, COL1); + // When: - analyzer.processGroupBy(FUNCTION_CALL); + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); // Then: - assertThat(analysis.getRequiredColumns(), contains(COL0)); + assertThat(result.getRequiredColumns(), hasItems(COL0, COL1)); } @Test public void shouldCaptureSelectAggregateFunctionArgumentsAsRequired() { + // Given: + givenSelectExpression(AGG_FUNCTION_CALL); + // When: - analyzer.processSelect(AGG_FUNCTION_CALL); + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); // Then: - assertThat(analysis.getRequiredColumns(), contains(COL0, COL1)); + assertThat(result.getRequiredColumns(), hasItems(COL0, COL1)); } @Test public void 
shouldCaptureHavingAggregateFunctionArgumentsAsRequired() { + // Given: + givenHavingExpression(AGG_FUNCTION_CALL); + // When: - analyzer.processHaving(AGG_FUNCTION_CALL); + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); // Then: - assertThat(analysis.getRequiredColumns(), contains(COL0, COL1)); + assertThat(result.getRequiredColumns(), hasItems(COL0, COL1)); } @Test public void shouldNotCaptureNonAggregateFunction() { + // given: + givenSelectExpression(FUNCTION_CALL); + givenHavingExpression(FUNCTION_CALL); + // When: - analyzer.processSelect(FUNCTION_CALL); - analyzer.processHaving(FUNCTION_CALL); - analyzer.processGroupBy(FUNCTION_CALL); + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); // Then: - assertThat(analysis.getAggregateFunctions(), is(empty())); + assertThat(result.getAggregateFunctions(), contains(REQUIRED_AGG_FUNC_CALL)); + } + + @Test + public void shouldNotCaptureNonAggregateGroupByFunction() { + // given: + givenGroupByExpressions(FUNCTION_CALL); + + // When: + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); + + // Then: + assertThat(result.getAggregateFunctions(), contains(REQUIRED_AGG_FUNC_CALL)); } @Test public void shouldCaptureSelectAggregateFunction() { + // Given: + givenSelectExpression(AGG_FUNCTION_CALL); + // When: - analyzer.processSelect(AGG_FUNCTION_CALL); + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); // Then: - assertThat(analysis.getAggregateFunctions(), contains(AGG_FUNCTION_CALL)); + assertThat(result.getAggregateFunctions(), hasItem(AGG_FUNCTION_CALL)); } @Test public void shouldCaptureHavingAggregateFunction() { + // Given: + givenHavingExpression(AGG_FUNCTION_CALL); + // When: - analyzer.processHaving(AGG_FUNCTION_CALL); + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); // Then: - assertThat(analysis.getAggregateFunctions(), contains(AGG_FUNCTION_CALL)); + assertThat(result.getAggregateFunctions(), hasItem(AGG_FUNCTION_CALL)); } @Test @@ -216,12 +288,15 @@ public void shouldThrowOnNestedSelectAggFunctions() { final FunctionCall nestedCall = new FunctionCall(FunctionName.of("MIN"), ImmutableList.of(AGG_FUNCTION_CALL, COL2)); - // Then: - expectedException.expect(KsqlException.class); - expectedException.expectMessage("Aggregate functions can not be nested: MIN(MAX())"); + givenSelectExpression(nestedCall); // When: - analyzer.processSelect(nestedCall); + final KsqlException e = assertThrows(KsqlException.class, + () -> analyzer.analyze(analysis, selects)); + + // Then: + assertThat(e.getMessage(), + containsString("Aggregate functions can not be nested: MIN(MAX())")); } @Test @@ -230,25 +305,15 @@ public void shouldThrowOnNestedHavingAggFunctions() { final FunctionCall nestedCall = new FunctionCall(FunctionName.of("MIN"), ImmutableList.of(AGG_FUNCTION_CALL, COL2)); - // Then: - expectedException.expect(KsqlException.class); - expectedException.expectMessage("Aggregate functions can not be nested: MIN(MAX())"); - - // When: - analyzer.processHaving(nestedCall); - } - - @Test - public void shouldCaptureNonAggregateFunctionArgumentsWithNestedAggFunction() { - // Given: - final FunctionCall nonAggWithNestedAggFunc = new FunctionCall(FunctionName.of("SUBSTRING"), - ImmutableList.of(COL2, AGG_FUNCTION_CALL, COL1)); + givenHavingExpression(nestedCall); // When: - analyzer.processSelect(nonAggWithNestedAggFunc); + final KsqlException e = assertThrows(KsqlException.class, + () -> analyzer.analyze(analysis, selects)); 
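The test migrations in this file follow one pattern: the ExpectedException rule is dropped in favour of JUnit 4.13's org.junit.Assert.assertThrows, which lets assertions sit in the Then block after the throwing call. The bare pattern in isolation, with a deliberately failing lambda standing in for the analyzer call under test:

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.junit.Assert.assertThrows;

import org.junit.Test;

public class AssertThrowsPatternTest {

  @Test
  public void shouldExposeThrownMessage() {
    // When: only the call expected to fail goes inside the lambda
    final IllegalStateException e = assertThrows(
        IllegalStateException.class,
        () -> { throw new IllegalStateException("boom"); });

    // Then: assertions run after the throw, which ExpectedException could not order this way
    assertThat(e.getMessage(), containsString("boom"));
  }
}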
// Then: - assertThat(analysis.getAggregateSelectFields(), containsInAnyOrder(COL1, COL2)); + assertThat(e.getMessage(), + containsString("Aggregate functions can not be nested: MIN(MAX())")); } @Test @@ -260,11 +325,14 @@ public void shouldNotCaptureNonAggregateFunctionArgumentsWhenNestedInsideAggFunc final FunctionCall aggFuncWithNestedNonAgg = new FunctionCall(FunctionName.of("MAX"), ImmutableList.of(COL1, nonAggFunc)); + givenSelectExpression(aggFuncWithNestedNonAgg); + // When: - analyzer.processSelect(aggFuncWithNestedNonAgg); + final MutableAggregateAnalysis result = (MutableAggregateAnalysis) analyzer + .analyze(analysis, selects); // Then: - assertThat(analysis.getNonAggregateSelectExpressions().keySet(), is(empty())); + assertThat(result.getNonAggregateSelectExpressions().keySet(), is(empty())); } @Test @@ -272,12 +340,14 @@ public void shouldCaptureDefaultFunctionArguments() { // Given: final FunctionCall emptyFunc = new FunctionCall(FunctionName.of("COUNT"), new ArrayList<>()); + givenSelectExpression(emptyFunc); + // When: - analyzer.processSelect(emptyFunc); + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); // Then: - assertThat(analysis.getRequiredColumns(), contains(DEFAULT_ARGUMENT)); - assertThat(analysis.getAggregateFunctionArguments(), contains(DEFAULT_ARGUMENT)); + assertThat(result.getRequiredColumns(), hasItem(DEFAULT_ARGUMENT)); + assertThat(result.getAggregateFunctionArguments(), hasItem(DEFAULT_ARGUMENT)); } @Test @@ -285,36 +355,81 @@ public void shouldAddDefaultArgToFunctionCallWithNoArgs() { // Given: final FunctionCall emptyFunc = new FunctionCall(FunctionName.of("COUNT"), new ArrayList<>()); + givenSelectExpression(emptyFunc); + // When: - analyzer.processSelect(emptyFunc); + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); // Then: - assertThat(analysis.getAggregateFunctions(), hasSize(1)); - assertThat(analysis.getAggregateFunctions().get(0).getName(), is(emptyFunc.getName())); - assertThat(analysis.getAggregateFunctions().get(0).getArguments(), contains(DEFAULT_ARGUMENT)); + assertThat(result.getAggregateFunctions(), hasSize(2)); + assertThat(result.getAggregateFunctions().get(1).getName(), is(emptyFunc.getName())); + assertThat(result.getAggregateFunctions().get(1).getArguments(), contains(DEFAULT_ARGUMENT)); } @Test public void shouldNotCaptureWindowStartAsRequiredColumn() { + // Given: + givenWindowExpression(); + givenSelectExpression(new UnqualifiedColumnReferenceExp(SchemaUtil.WINDOWSTART_NAME)); + // When: - analyzer.processSelect(new QualifiedColumnReferenceExp( - ORDERS, - SchemaUtil.WINDOWSTART_NAME - )); + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); // Then: - assertThat(analysis.getRequiredColumns(), is(empty())); + final List requiredColumnNames = result.getRequiredColumns().stream() + .map(ColumnReferenceExp::getColumnName) + .collect(Collectors.toList()); + + assertThat(requiredColumnNames, not(hasItem(SchemaUtil.WINDOWSTART_NAME))); } @Test public void shouldNotCaptureWindowEndAsRequiredColumn() { + // Given: + givenWindowExpression(); + givenSelectExpression(new UnqualifiedColumnReferenceExp(SchemaUtil.WINDOWEND_NAME)); + // When: - analyzer.processSelect(new QualifiedColumnReferenceExp( - ORDERS, + final AggregateAnalysisResult result = analyzer.analyze(analysis, selects); + + // Then: + final List requiredColumnNames = result.getRequiredColumns().stream() + .map(ColumnReferenceExp::getColumnName) + .collect(Collectors.toList()); + + 
assertThat(requiredColumnNames, not(hasItem(SchemaUtil.WINDOWEND_NAME))); + } + + @Test + public void shouldThrowOnQualifiedColumnReference() { + // Given: + givenSelectExpression(new QualifiedColumnReferenceExp( + SourceName.of("Fred"), SchemaUtil.WINDOWEND_NAME )); - // Then: - assertThat(analysis.getRequiredColumns(), is(empty())); + // When: + assertThrows(UnsupportedOperationException.class, + () -> analyzer.analyze(analysis, selects)); + } + + private void givenSelectExpression(final Expression expression) { + selects.add(SelectExpression.of(ColumnName.of("x"), expression)); + } + + private void givenGroupByExpressions(final Expression... expressions) { + when(analysis.getGroupByExpressions()) + .thenReturn(ImmutableList.copyOf(expressions)); + } + + private void givenHavingExpression(final Expression expression) { + when(analysis.getHavingExpression()) + .thenReturn(Optional.of(expression)); + } + + private void givenWindowExpression() { + final WindowExpression windowExpression = mock(WindowExpression.class); + when(analysis.getWindowExpression()) + .thenReturn(Optional.of(windowExpression)); } } diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java index 043f36a325c6..51044397f533 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/AnalyzerFunctionalTest.java @@ -15,42 +15,21 @@ package io.confluent.ksql.analyzer; -import static io.confluent.ksql.testutils.AnalysisTestUtil.analyzeQuery; -import static io.confluent.ksql.util.SchemaUtil.ROWTIME_NAME; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.hasItem; -import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; -import static org.junit.Assert.assertEquals; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; -import io.confluent.ksql.analyzer.Analysis.Into; -import io.confluent.ksql.analyzer.Analysis.JoinInfo; -import io.confluent.ksql.analyzer.Analyzer.SerdeOptionsSupplier; import io.confluent.ksql.execution.ddl.commands.KsqlTopic; -import io.confluent.ksql.execution.expression.tree.BooleanLiteral; -import io.confluent.ksql.execution.expression.tree.FunctionCall; -import io.confluent.ksql.execution.expression.tree.IntegerLiteral; -import io.confluent.ksql.execution.expression.tree.Literal; -import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; -import io.confluent.ksql.execution.expression.tree.StringLiteral; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.metastore.MutableMetaStore; import io.confluent.ksql.metastore.model.KeyField; import io.confluent.ksql.metastore.model.KsqlStream; import io.confluent.ksql.name.ColumnName; -import io.confluent.ksql.name.FunctionName; import io.confluent.ksql.name.SourceName; import 
io.confluent.ksql.parser.KsqlParser.PreparedStatement; import io.confluent.ksql.parser.properties.with.CreateSourceAsProperties; @@ -58,10 +37,8 @@ import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Sink; import io.confluent.ksql.parser.tree.Statement; -import io.confluent.ksql.planner.plan.JoinNode.JoinType; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.schema.ksql.types.SqlTypes; -import io.confluent.ksql.serde.Format; import io.confluent.ksql.serde.FormatFactory; import io.confluent.ksql.serde.FormatInfo; import io.confluent.ksql.serde.KeyFormat; @@ -72,10 +49,7 @@ import io.confluent.ksql.util.KsqlParserTestUtil; import io.confluent.ksql.util.MetaStoreFixture; import io.confluent.ksql.util.SchemaUtil; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -99,11 +73,9 @@ public class AnalyzerFunctionalTest { private static final Set DEFAULT_SERDE_OPTIONS = SerdeOption.none(); - private static final SourceName TEST1 = SourceName.of("TEST1"); private static final ColumnName COL0 = ColumnName.of("COL0"); private static final ColumnName COL1 = ColumnName.of("COL1"); private static final ColumnName COL2 = ColumnName.of("COL2"); - private static final ColumnName COL3 = ColumnName.of("COL3"); private MutableMetaStore jsonMetaStore; private MutableMetaStore avroMetaStore; @@ -111,15 +83,11 @@ public class AnalyzerFunctionalTest { @Rule public final ExpectedException expectedException = ExpectedException.none(); - @Mock - private SerdeOptionsSupplier serdeOptionsSupplier; @Mock private Sink sink; private Query query; private Analyzer analyzer; - private Optional sinkFormat = Optional.empty(); - private Optional sinkWrapSingleValues = Optional.empty(); @Before public void init() { @@ -132,8 +100,7 @@ public void init() { analyzer = new Analyzer( jsonMetaStore, "", - DEFAULT_SERDE_OPTIONS, - serdeOptionsSupplier + DEFAULT_SERDE_OPTIONS ); when(sink.getName()).thenReturn(SourceName.of("TEST0")); @@ -144,166 +111,6 @@ public void init() { registerKafkaSource(); } - @Test - public void testSimpleQueryAnalysis() { - final String simpleQuery = "SELECT col0, col2, col3 FROM test1 WHERE col0 > 100 EMIT CHANGES;"; - final Analysis analysis = analyzeQuery(simpleQuery, jsonMetaStore); - assertEquals("FROM was not analyzed correctly.", - analysis.getFromDataSources().get(0).getDataSource().getName(), - TEST1); - assertThat(analysis.getWhereExpression().get().toString(), is("(TEST1.COL0 > 100)")); - - final List selects = analysis.getSelectExpressions(); - assertThat(selects.get(0).getExpression().toString(), is("TEST1.COL0")); - assertThat(selects.get(1).getExpression().toString(), is("TEST1.COL2")); - assertThat(selects.get(2).getExpression().toString(), is("TEST1.COL3")); - - assertThat(selects.get(0).getAlias(), is(COL0)); - assertThat(selects.get(1).getAlias(), is(COL2)); - assertThat(selects.get(2).getAlias(), is(COL3)); - } - - @Test - public void testSimpleLeftJoinAnalysis() { - // When: - final Analysis analysis = analyzeQuery( - "SELECT t1.col1, t2.col1, t2.col4, col5, t2.col2 " - + "FROM test1 t1 LEFT JOIN test2 t2 " - + "ON t1.col1 = t2.col1 EMIT CHANGES;", jsonMetaStore); - - // Then: - assertThat(analysis.getFromDataSources(), hasSize(2)); - assertThat(analysis.getFromDataSources().get(0).getAlias(), is(SourceName.of("T1"))); - assertThat(analysis.getFromDataSources().get(1).getAlias(), 
is(SourceName.of("T2"))); - - assertThat(analysis.getJoin(), is(not(Optional.empty()))); - assertThat( - analysis.getJoin().get().getLeftJoinExpression(), - is(new QualifiedColumnReferenceExp(SourceName.of("T1"), ColumnName.of("COL1"))) - ); - assertThat( - analysis.getJoin().get().getRightJoinExpression(), - is(new QualifiedColumnReferenceExp(SourceName.of("T2"), ColumnName.of("COL1"))) - ); - - final List selects = analysis.getSelectExpressions().stream() - .map(SelectExpression::getExpression) - .map(Objects::toString) - .collect(Collectors.toList()); - - assertThat(selects, contains("T1.COL1", "T2.COL1", "T2.COL4", "T1.COL5", "T2.COL2")); - - final List aliases = analysis.getSelectExpressions().stream() - .map(SelectExpression::getAlias) - .collect(Collectors.toList()); - - assertThat(aliases.stream().map(ColumnName::text).collect(Collectors.toList()), - contains("T1_COL1", "T2_COL1", "T2_COL4", "COL5", "T2_COL2")); - } - - @Test - public void testExpressionLeftJoinAnalysis() { - // When: - final Analysis analysis = analyzeQuery( - "SELECT t1.col1, t2.col1, t2.col4, col5, t2.col2 " - + "FROM test1 t1 LEFT JOIN test2 t2 " - + "ON t1.col1 = SUBSTRING(t2.col1, 2) EMIT CHANGES;", jsonMetaStore); - - // Then: - assertThat(analysis.getFromDataSources(), hasSize(2)); - assertThat(analysis.getFromDataSources().get(0).getAlias(), is(SourceName.of("T1"))); - assertThat(analysis.getFromDataSources().get(1).getAlias(), is(SourceName.of("T2"))); - - assertThat(analysis.getJoin(), is(not(Optional.empty()))); - assertThat( - analysis.getJoin().get().getLeftJoinExpression(), - is(new QualifiedColumnReferenceExp(SourceName.of("T1"), ColumnName.of("COL1")))); - assertThat( - analysis.getJoin().get().getRightJoinExpression(), - is(new FunctionCall( - FunctionName.of("SUBSTRING"), - ImmutableList.of( - new QualifiedColumnReferenceExp(SourceName.of("T2"), ColumnName.of("COL1")), - new IntegerLiteral(2) - )))); - - final List selects = analysis.getSelectExpressions().stream() - .map(SelectExpression::getExpression) - .map(Objects::toString) - .collect(Collectors.toList()); - - assertThat(selects, contains("T1.COL1", "T2.COL1", "T2.COL4", "T1.COL5", "T2.COL2")); - - final List aliases = analysis.getSelectExpressions().stream() - .map(SelectExpression::getAlias) - .collect(Collectors.toList()); - - assertThat(aliases.stream().map(ColumnName::text).collect(Collectors.toList()), - contains("T1_COL1", "T2_COL1", "T2_COL4", "COL5", "T2_COL2")); - } - - @Test - public void shouldHandleJoinOnRowKey() { - // When: - final Optional join = analyzeQuery( - "SELECT * FROM test1 t1 LEFT JOIN test2 t2 ON t1.ROWKEY = t2.ROWKEY EMIT CHANGES;", - jsonMetaStore) - .getJoin(); - - // Then: - assertThat(join, is(not(Optional.empty()))); - assertThat(join.get().getType(), is(JoinType.LEFT)); - assertThat( - join.get().getLeftJoinExpression(), - is(new QualifiedColumnReferenceExp(SourceName.of("T1"), ColumnName.of("ROWKEY")))); - assertThat( - join.get().getRightJoinExpression(), - is(new QualifiedColumnReferenceExp(SourceName.of("T2"), ColumnName.of("ROWKEY")))); - } - - @Test - public void testBooleanExpressionAnalysis() { - final String queryStr = "SELECT col0 = 10, col2, col3 > col1 FROM test1 EMIT CHANGES;"; - final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore); - - assertEquals("FROM was not analyzed correctly.", - analysis.getFromDataSources().get(0).getDataSource().getName(), TEST1); - - final List selects = analysis.getSelectExpressions(); - assertThat(selects.get(0).getExpression().toString(), 
is("(TEST1.COL0 = 10)")); - assertThat(selects.get(1).getExpression().toString(), is("TEST1.COL2")); - assertThat(selects.get(2).getExpression().toString(), is("(TEST1.COL3 > TEST1.COL1)")); - } - - @Test - public void testFilterAnalysis() { - final String queryStr = "SELECT col0 = 10, col2, col3 > col1 FROM test1 WHERE col0 > 20 EMIT CHANGES;"; - final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore); - - assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(TEST1)); - - final List selects = analysis.getSelectExpressions(); - assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)")); - assertThat(selects.get(1).getExpression().toString(), is("TEST1.COL2")); - assertThat(selects.get(2).getExpression().toString(), is("(TEST1.COL3 > TEST1.COL1)")); - assertThat(analysis.getWhereExpression().get().toString(), is("(TEST1.COL0 > 20)")); - } - - @Test - public void shouldCreateCorrectSinkKsqlTopic() { - final String simpleQuery = "CREATE STREAM FOO WITH (KAFKA_TOPIC='TEST_TOPIC1') AS SELECT col0, col2, col3 FROM test1 WHERE col0 > 100;"; - final List statements = parse(simpleQuery, jsonMetaStore); - final CreateStreamAsSelect createStreamAsSelect = (CreateStreamAsSelect) statements.get(0); - final Query query = createStreamAsSelect.getQuery(); - - final Analyzer analyzer = new Analyzer(jsonMetaStore, "", DEFAULT_SERDE_OPTIONS); - final Analysis analysis = analyzer.analyze(query, Optional.of(createStreamAsSelect.getSink())); - - final Optional into = analysis.getInto(); - assertThat(into, is(not((Optional.empty())))); - final KsqlTopic createdKsqlTopic = into.get().getKsqlTopic(); - assertThat(createdKsqlTopic.getKafkaTopicName(), is("TEST_TOPIC1")); - } @Test public void shouldUseExplicitNamespaceForAvroSchema() { @@ -445,62 +252,6 @@ public void shouldFailIfExplicitNamespaceIsProvidedButEmpty() { analyzer.analyze(query, Optional.of(createStreamAsSelect.getSink())); } - @Test - public void shouldGetSerdeOptions() { - // Given: - final Set serdeOptions = ImmutableSet.of(SerdeOption.UNWRAP_SINGLE_VALUES); - when(serdeOptionsSupplier.build(any(), any(), any(), any())).thenReturn(serdeOptions); - - givenSinkValueFormat(FormatFactory.AVRO); - givenWrapSingleValues(true); - - // When: - final Analysis result = analyzer.analyze(query, Optional.of(sink)); - - // Then: - verify(serdeOptionsSupplier).build( - ImmutableList.of("COL0", "COL1").stream().map(ColumnName::of).collect(Collectors.toList()), - FormatFactory.AVRO, - Optional.of(true), - DEFAULT_SERDE_OPTIONS); - - assertThat(result.getSerdeOptions(), is(serdeOptions)); - } - - @Test - public void shouldThrowOnGroupByIfKafkaFormat() { - // Given: - query = parseSingle("Select COL0 from KAFKA_SOURCE GROUP BY COL0;"); - - givenSinkValueFormat(FormatFactory.KAFKA); - - // Then: - expectedException.expect(KsqlException.class); - expectedException.expectMessage("Source(s) KAFKA_SOURCE are using the 'KAFKA' value format." - + " This format does not yet support GROUP BY."); - - // When: - analyzer.analyze(query, Optional.of(sink)); - } - - @Test - public void shouldThrowOnJoinIfKafkaFormat() { - // Given: - query = parseSingle("Select TEST1.COL0 from TEST1 JOIN KAFKA_SOURCE " - + "WITHIN 1 SECOND ON " - + "TEST1.COL0 = KAFKA_SOURCE.COL0;"); - - givenSinkValueFormat(FormatFactory.KAFKA); - - // Then: - expectedException.expect(KsqlException.class); - expectedException.expectMessage("Source(s) KAFKA_SOURCE are using the 'KAFKA' value format." 
- + " This format does not yet support JOIN."); - - // When: - analyzer.analyze(query, Optional.of(sink)); - } - @Test public void shouldCaptureProjectionColumnRefs() { // Given: @@ -510,45 +261,13 @@ public void shouldCaptureProjectionColumnRefs() { final Analysis analysis = analyzer.analyze(query, Optional.empty()); // Then: - assertThat(analysis.getSelectColumnRefs(), containsInAnyOrder( + assertThat(analysis.getSelectColumnNames(), containsInAnyOrder( COL0, COL1, COL2 )); } - @Test - public void shouldIncludeMetaColumnsForSelectStarOnContinuousQueries() { - // Given: - query = parseSingle("Select * from TEST1 EMIT CHANGES;"); - - // When: - final Analysis analysis = analyzer.analyze(query, Optional.empty()); - - // Then: - assertThat(analysis.getSelectExpressions(), hasItem( - SelectExpression.of( - ROWTIME_NAME, - new QualifiedColumnReferenceExp(TEST1, ROWTIME_NAME) - ) - )); - } - - @Test - public void shouldNotIncludeMetaColumnsForSelectStartOnStaticQueries() { - // Given: - query = parseSingle("Select * from TEST1;"); - - // When: - final Analysis analysis = analyzer.analyze(query, Optional.empty()); - - // Then: - assertThat(analysis.getSelectExpressions(), not(hasItem( - SelectExpression.of( - ROWTIME_NAME, new QualifiedColumnReferenceExp(TEST1, ROWTIME_NAME)) - ))); - } - @Test public void shouldThrowOnSelfJoin() { // Given: @@ -639,27 +358,6 @@ private T parseSingle(final String simpleQuery) { return (T) Iterables.getOnlyElement(parse(simpleQuery, jsonMetaStore)); } - private void givenSinkValueFormat(final Format format) { - this.sinkFormat = Optional.of(format); - buildProps(); - } - - private void givenWrapSingleValues(final boolean wrap) { - this.sinkWrapSingleValues = Optional.of(wrap); - buildProps(); - } - - private void buildProps() { - final Map props = new HashMap<>(); - sinkFormat.ifPresent(f -> props.put("VALUE_FORMAT", new StringLiteral(f.name()))); - sinkWrapSingleValues.ifPresent(b -> props.put("WRAP_SINGLE_VALUE", new BooleanLiteral(Boolean.toString(b)))); - - final CreateSourceAsProperties properties = CreateSourceAsProperties.from(props); - - when(sink.getProperties()).thenReturn(properties); - - } - private void registerKafkaSource() { final LogicalSchema schema = LogicalSchema.builder() .withRowTime() diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/ColumnReferenceValidatorTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/ColumnReferenceValidatorTest.java index 8c89b2bf7472..c0a46367a0ef 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/ColumnReferenceValidatorTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/ColumnReferenceValidatorTest.java @@ -25,12 +25,10 @@ import com.google.common.collect.Iterables; import io.confluent.ksql.execution.expression.tree.Expression; import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; -import io.confluent.ksql.execution.expression.tree.StringLiteral; import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; import io.confluent.ksql.name.ColumnName; import io.confluent.ksql.name.SourceName; import io.confluent.ksql.util.KsqlException; -import io.confluent.ksql.util.SchemaUtil; import java.util.Arrays; import java.util.Optional; import java.util.Set; @@ -46,12 +44,6 @@ @RunWith(MockitoJUnitRunner.class) public class ColumnReferenceValidatorTest { - private static final Expression WINDOW_START_EXP = new UnqualifiedColumnReferenceExp( - SchemaUtil.WINDOWSTART_NAME - ); - - private static final Expression 
OTHER_EXP = new StringLiteral("foo"); - @Rule public final ExpectedException expectedException = ExpectedException.none(); diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/QueryAnalyzerFunctionalTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/QueryAnalyzerFunctionalTest.java index 8eee43f6938a..e39d168e262a 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/QueryAnalyzerFunctionalTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/QueryAnalyzerFunctionalTest.java @@ -1,5 +1,5 @@ /* - * Copyright 2018 Confluent Inc. + * Copyright 2020 Confluent Inc. * * Licensed under the Confluent Community License (the "License"); you may not use * this file except in compliance with the License. You may obtain a copy of the @@ -15,37 +15,16 @@ package io.confluent.ksql.analyzer; -import static io.confluent.ksql.util.ExpressionMatchers.qualifiedNameExpressions; import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.hasItem; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; -import io.confluent.ksql.analyzer.Analysis.AliasedDataSource; -import io.confluent.ksql.analyzer.Analysis.Into; -import io.confluent.ksql.execution.expression.tree.ComparisonExpression; -import io.confluent.ksql.execution.expression.tree.Expression; -import io.confluent.ksql.execution.expression.tree.IntegerLiteral; -import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; -import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; -import io.confluent.ksql.execution.plan.SelectExpression; import io.confluent.ksql.function.InternalFunctionRegistry; import io.confluent.ksql.function.UserFunctionLoader; import io.confluent.ksql.metastore.MetaStore; -import io.confluent.ksql.metastore.model.DataSource; -import io.confluent.ksql.metastore.model.KsqlStream; -import io.confluent.ksql.metastore.model.KsqlTable; -import io.confluent.ksql.name.ColumnName; -import io.confluent.ksql.name.SourceName; import io.confluent.ksql.parser.KsqlParser.PreparedStatement; import io.confluent.ksql.parser.tree.CreateStreamAsSelect; -import io.confluent.ksql.parser.tree.CreateTableAsSelect; -import io.confluent.ksql.parser.tree.InsertInto; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Sink; import io.confluent.ksql.serde.FormatFactory; @@ -54,8 +33,6 @@ import io.confluent.ksql.util.KsqlParserTestUtil; import io.confluent.ksql.util.MetaStoreFixture; import java.io.File; -import java.util.Arrays; -import java.util.Collections; import java.util.Optional; import org.junit.Rule; import org.junit.Test; @@ -71,27 +48,6 @@ @SuppressWarnings("OptionalGetWithoutIsPresent") public class QueryAnalyzerFunctionalTest { - private static final SourceName ORDERS = SourceName.of("ORDERS"); - private static final SourceName TEST1 = SourceName.of("TEST1"); - - private static final QualifiedColumnReferenceExp ITEM_ID = - new QualifiedColumnReferenceExp(ORDERS, ColumnName.of("ITEMID")); - - private static final QualifiedColumnReferenceExp ORDER_ID = - new QualifiedColumnReferenceExp( - ORDERS, - ColumnName.of("ORDERID") - ); - - private static final QualifiedColumnReferenceExp ORDER_UNITS = - new 
QualifiedColumnReferenceExp( - ORDERS, - ColumnName.of("ORDERUNITS") - ); - - private static final QualifiedColumnReferenceExp TEST_COL1 = - new QualifiedColumnReferenceExp(TEST1, ColumnName.of("COL1")); - @Rule public final ExpectedException expectedException = ExpectedException.none(); @@ -100,110 +56,6 @@ public class QueryAnalyzerFunctionalTest { private final QueryAnalyzer queryAnalyzer = new QueryAnalyzer(metaStore, "prefix-~", SerdeOption.none()); - @Test - public void shouldCreateAnalysisForSimpleQuery() { - // Given: - final Query query = givenQuery("select orderid from orders EMIT CHANGES;"); - - // When: - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - // Then: - final AliasedDataSource fromDataSource = analysis.getFromDataSources().get(0); - assertThat( - analysis.getSelectExpressions(), - contains(SelectExpression.of(ColumnName.of("ORDERID"), ORDER_ID)) - ); - assertThat(analysis.getFromDataSources(), hasSize(1)); - assertThat(fromDataSource.getDataSource(), instanceOf(KsqlStream.class)); - assertThat(fromDataSource.getAlias(), equalTo(SourceName.of("ORDERS"))); - } - - @Test - public void shouldCreateAnalysisForCsas() { - // Given: - final PreparedStatement statement = KsqlParserTestUtil.buildSingleAst( - "create stream s as select col1 from test1 EMIT CHANGES;", metaStore); - final Query query = statement.getStatement().getQuery(); - final Optional sink = Optional.of(statement.getStatement().getSink()); - - // When: - final Analysis analysis = queryAnalyzer.analyze(query, sink); - - // Then: - assertThat( - analysis.getSelectExpressions(), - contains(SelectExpression.of(ColumnName.of("COL1"), TEST_COL1)) - ); - - assertThat(analysis.getFromDataSources(), hasSize(1)); - - final AliasedDataSource fromDataSource = analysis.getFromDataSources().get(0); - assertThat(fromDataSource.getDataSource(), instanceOf(KsqlStream.class)); - assertThat(fromDataSource.getAlias(), equalTo(SourceName.of("TEST1"))); - assertThat(analysis.getInto().get().getName(), is(SourceName.of("S"))); - } - - @Test - public void shouldCreateAnalysisForCtas() { - // Given: - final PreparedStatement statement = KsqlParserTestUtil.buildSingleAst( - "create table t as select col1 from test2 EMIT CHANGES;", metaStore); - final Query query = statement.getStatement().getQuery(); - final Optional sink = Optional.of(statement.getStatement().getSink()); - - // When: - final Analysis analysis = queryAnalyzer.analyze(query, sink); - - // Then: - assertThat( - analysis.getSelectExpressions(), - contains(SelectExpression.of( - ColumnName.of("COL1"), - new QualifiedColumnReferenceExp( - SourceName.of("TEST2"), - ColumnName.of("COL1") - ) - )) - ); - - assertThat(analysis.getFromDataSources(), hasSize(1)); - - final AliasedDataSource fromDataSource = analysis.getFromDataSources().get(0); - assertThat(fromDataSource.getDataSource(), instanceOf(KsqlTable.class)); - assertThat(fromDataSource.getAlias(), equalTo(SourceName.of("TEST2"))); - assertThat(analysis.getInto().get().getName(), is(SourceName.of("T"))); - } - - @Test - public void shouldCreateAnalysisForInsertInto() { - // Given: - final PreparedStatement statement = KsqlParserTestUtil.buildSingleAst( - "insert into test0 select col1 from test1 EMIT CHANGES;", metaStore); - final Query query = statement.getStatement().getQuery(); - final Optional sink = Optional.of(statement.getStatement().getSink()); - - // When: - final Analysis analysis = queryAnalyzer.analyze(query, sink); - - // Then: - assertThat( - analysis.getSelectExpressions(), - 
contains(SelectExpression.of(ColumnName.of("COL1"), TEST_COL1)) - ); - - assertThat(analysis.getFromDataSources(), hasSize(1)); - - final AliasedDataSource fromDataSource = analysis.getFromDataSources().get(0); - assertThat(fromDataSource.getDataSource(), instanceOf(KsqlStream.class)); - assertThat(fromDataSource.getAlias(), equalTo(SourceName.of("TEST1"))); - assertThat(analysis.getInto(), is(not(Optional.empty()))); - final Into into = analysis.getInto().get(); - final DataSource test0 = metaStore.getSource(SourceName.of("TEST0")); - assertThat(into.getName(), is(test0.getName())); - assertThat(into.getKsqlTopic(), is(test0.getKsqlTopic())); - } - @Test public void shouldAnalyseTableFunctions() { @@ -225,290 +77,18 @@ public void shouldAnalyseTableFunctions() { assertThat(analysis.getTableFunctions(), hasSize(1)); assertThat(analysis.getTableFunctions().get(0).getName().text(), equalTo("EXPLODE")); } - - @Test - public void shouldAnalyseWindowedAggregate() { - // Given: - final Query query = givenQuery( - "select itemid, sum(orderunits) from orders window TUMBLING ( size 30 second) " + - "where orderunits > 5 group by itemid EMIT CHANGES;"); - - // When: - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - final AggregateAnalysis aggregateAnalysis = queryAnalyzer.analyzeAggregate(query, analysis); - - // Then: - assertThat(aggregateAnalysis.getNonAggregateSelectExpressions().get(ITEM_ID), contains(ITEM_ID)); - assertThat(aggregateAnalysis.getFinalSelectExpressions(), equalTo(Arrays.asList(ITEM_ID, new UnqualifiedColumnReferenceExp( - ColumnName.of("KSQL_AGG_VARIABLE_0"))))); - assertThat(aggregateAnalysis.getAggregateFunctionArguments(), equalTo(Collections.singletonList(ORDER_UNITS))); - assertThat(aggregateAnalysis.getRequiredColumns(), containsInAnyOrder(ITEM_ID, ORDER_UNITS)); - } @Test - public void shouldThrowIfAggregateAnalysisDoesNotHaveGroupBy() { + public void shouldThrowIfUdafsAndNoGroupBy() { // Given: final Query query = givenQuery("select itemid, sum(orderunits) from orders EMIT CHANGES;"); - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - expectedException.expect(KsqlException.class); - expectedException.expectMessage( - "Use of aggregate functions requires a GROUP BY clause. 
Aggregate function(s): SUM"); - - // When: - queryAnalyzer.analyzeAggregate(query, analysis); - } - - @Test - public void shouldThrowOnAdditionalNonAggregateSelects() { - // Given: - final Query query = givenQuery( - "select itemid, orderid, sum(orderunits) from orders group by itemid EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - expectedException.expect(KsqlException.class); - expectedException.expectMessage( - "Non-aggregate SELECT expression(s) not part of GROUP BY: [ORDERS.ORDERID]"); - - // When: - queryAnalyzer.analyzeAggregate(query, analysis); - } - - @Test - public void shouldThrowOnAdditionalNonAggregateHavings() { - // Given: - final Query query = givenQuery( - "select sum(orderunits) from orders group by itemid having orderid = 1 EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - expectedException.expect(KsqlException.class); - expectedException - .expectMessage("Non-aggregate HAVING expression not part of GROUP BY: [ORDERS.ORDERID]"); - - // When: - queryAnalyzer.analyzeAggregate(query, analysis); - } - - @Test - public void shouldProcessGroupByExpression() { - // Given: - final Query query = givenQuery( - "select sum(orderunits) from orders group by itemid EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - // When: - final AggregateAnalysis aggregateAnalysis = queryAnalyzer.analyzeAggregate(query, analysis); - - // Then: - assertThat(aggregateAnalysis.getRequiredColumns(), hasItem(ITEM_ID)); - } - - @Test - public void shouldProcessGroupByArithmetic() { - // Given: - final Query query = givenQuery( - "select sum(orderunits) from orders group by itemid + 1 EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - // When: - final AggregateAnalysis aggregateAnalysis = queryAnalyzer.analyzeAggregate(query, analysis); - - // Then: - assertThat(aggregateAnalysis.getRequiredColumns(), hasItem(ITEM_ID)); - } - - @Test - public void shouldProcessGroupByFunction() { - // Given: - final Query query = givenQuery( - "select sum(orderunits) from orders group by ucase(itemid) EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - // When: - final AggregateAnalysis aggregateAnalysis = queryAnalyzer.analyzeAggregate(query, analysis); - - // Then: - assertThat(aggregateAnalysis.getRequiredColumns(), hasItem(ITEM_ID)); - } - - @Test - public void shouldProcessGroupByConstant() { - // Given: - final Query query = givenQuery( - "select sum(orderunits) from orders group by 1 EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - // When: - queryAnalyzer.analyzeAggregate(query, analysis); - - // Then: did not throw. 
- } - - @Test - public void shouldThrowIfGroupByAggFunction() { - // Given: - final Query query = givenQuery( - "select sum(orderunits) from orders group by sum(orderid) EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - // Then: - expectedException.expect(KsqlException.class); - expectedException.expectMessage( - "GROUP BY does not support aggregate functions: SUM is an aggregate function."); - - // When: - queryAnalyzer.analyzeAggregate(query, analysis); - } - - @Test - public void shouldProcessHavingExpression() { - // Given: - final Query query = givenQuery( - "select itemid, sum(orderunits) from orders window TUMBLING ( size 30 second) " + - "where orderunits > 5 group by itemid having count(itemid) > 10 EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - // When: - final AggregateAnalysis aggregateAnalysis = queryAnalyzer.analyzeAggregate(query, analysis); - - // Then: - final Expression havingExpression = aggregateAnalysis.getHavingExpression().get(); - assertThat(havingExpression, equalTo(new ComparisonExpression( - ComparisonExpression.Type.GREATER_THAN, - new UnqualifiedColumnReferenceExp(ColumnName.of("KSQL_AGG_VARIABLE_1")), - new IntegerLiteral(10)))); - } - - @Test - public void shouldFailOnSelectStarWithGroupBy() { - // Given: - final Query query = givenQuery("select *, count() from orders group by itemid EMIT CHANGES;"); - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - expectedException.expect(KsqlException.class); expectedException.expectMessage( - "Non-aggregate SELECT expression(s) not part of GROUP BY: " - + "[ORDERS.ADDRESS, ORDERS.ARRAYCOL, ORDERS.ITEMINFO, ORDERS.MAPCOL, ORDERS.ORDERID, " - + "ORDERS.ORDERTIME, ORDERS.ORDERUNITS, ORDERS.ROWKEY, ORDERS.ROWTIME]" - ); - - // When: - queryAnalyzer.analyzeAggregate(query, analysis); - } - - @Test - public void shouldHandleSelectStarWithCorrectGroupBy() { - // Given: - final Query query = givenQuery("select *, count() from orders group by " - + "ROWTIME, ROWKEY, ITEMID, ORDERTIME, ORDERUNITS, MAPCOL, ORDERID, ITEMINFO, ARRAYCOL, ADDRESS" - + " EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - // When: - final AggregateAnalysis aggregateAnalysis = queryAnalyzer.analyzeAggregate(query, analysis); - - // Then: - assertThat(aggregateAnalysis.getNonAggregateSelectExpressions().keySet(), containsInAnyOrder( - qualifiedNameExpressions( - "ORDERS.ROWTIME", "ORDERS.ROWKEY", "ORDERS.ITEMID", "ORDERS.ORDERTIME", - "ORDERS.ORDERUNITS", "ORDERS.MAPCOL", "ORDERS.ORDERID", "ORDERS.ITEMINFO", - "ORDERS.ARRAYCOL", "ORDERS.ADDRESS") - )); - } - - @Test - public void shouldThrowIfSelectContainsUdfNotInGroupBy() { - // Given: - final Query query = givenQuery("select substring(orderid, 1, 2), count(*) " - + "from orders group by substring(orderid, 2, 5) EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - expectedException.expect(KsqlException.class); - expectedException.expectMessage( - "Non-aggregate SELECT expression(s) not part of GROUP BY: [SUBSTRING(ORDERS.ORDERID, 1, 2)]" - ); - - // When: - queryAnalyzer.analyzeAggregate(query, analysis); - } - - @Test - public void shouldThrowIfSelectContainsReversedStringConcatExpression() { - // Given: - final Query query = givenQuery("select itemid + address->street, count(*) " - + "from orders group by address->street + itemid EMIT CHANGES;"); - - final Analysis analysis = 
queryAnalyzer.analyze(query, Optional.empty()); - - expectedException.expect(KsqlException.class); - expectedException.expectMessage( - "Non-aggregate SELECT expression(s) not part of GROUP BY: " - + "[(ORDERS.ITEMID + ORDERS.ADDRESS->STREET)]" - ); - - // When: - queryAnalyzer.analyzeAggregate(query, analysis); - } - - @Test - public void shouldThrowIfSelectContainsFieldsUsedInExpressionInGroupBy() { - // Given: - final Query query = givenQuery("select orderId, count(*) " - + "from orders group by orderid + orderunits EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - expectedException.expect(KsqlException.class); - expectedException.expectMessage( - "Non-aggregate SELECT expression(s) not part of GROUP BY: [ORDERS.ORDERID]" - ); - - // When: - queryAnalyzer.analyzeAggregate(query, analysis); - } - - @Test - public void shouldThrowIfSelectContainsIncompatibleBinaryArithmetic() { - // Given: - final Query query = givenQuery("SELECT orderId - ordertime, COUNT(*) " - + "FROM ORDERS GROUP BY ordertime - orderId EMIT CHANGES;"); - - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - expectedException.expect(KsqlException.class); - expectedException.expectMessage( - "Non-aggregate SELECT expression(s) not part of GROUP BY: " - + "[(ORDERS.ORDERID - ORDERS.ORDERTIME)]" - ); - - // When: - queryAnalyzer.analyzeAggregate(query, analysis); - } - - @Test - public void shouldThrowIfGroupByMissingAggregateSelectExpressions() { - // Given: - final Query query = givenQuery("select orderid from orders group by orderid EMIT CHANGES;"); - final Analysis analysis = queryAnalyzer.analyze(query, Optional.empty()); - - expectedException.expect(KsqlException.class); - expectedException.expectMessage( - "GROUP BY requires columns using aggregate functions in SELECT clause." - ); + "Use of aggregate function SUM requires a GROUP BY clause."); // When: - queryAnalyzer.analyzeAggregate(query, analysis); + queryAnalyzer.analyze(query, Optional.empty()); } @Test diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/QueryAnalyzerTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/QueryAnalyzerTest.java index 6489f92d6bc3..771f74f22b5b 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/QueryAnalyzerTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/analyzer/QueryAnalyzerTest.java @@ -20,7 +20,6 @@ import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; -import io.confluent.ksql.metastore.MetaStore; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Sink; import java.util.Optional; @@ -38,8 +37,6 @@ public class QueryAnalyzerTest { @Rule public final ExpectedException expectedException = ExpectedException.none(); - @Mock - private MetaStore metaStore; @Mock private Analyzer analyzer; @Mock @@ -57,7 +54,6 @@ public class QueryAnalyzerTest { @Before public void setUp() { queryAnalyzer = new QueryAnalyzer( - metaStore, analyzer, continuousValidator, staticValidator diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java deleted file mode 100644 index e7b7759e7186..000000000000 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/codegen/CodeGenRunnerTest.java +++ /dev/null @@ -1,1086 +0,0 @@ -/* - * Copyright 2018 Confluent Inc. 
- * - * Licensed under the Confluent Community License (the "License"); you may not use - * this file except in compliance with the License. You may obtain a copy of the - * License at - * - * http://www.confluent.io/confluent-community-license - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package io.confluent.ksql.codegen; - -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.both; -import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThanOrEqualTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.lessThanOrEqualTo; -import static org.hamcrest.Matchers.nullValue; -import static org.junit.internal.matchers.ThrowableMessageMatcher.hasMessage; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import io.confluent.ksql.GenericRow; -import io.confluent.ksql.analyzer.Analysis; -import io.confluent.ksql.analyzer.ImmutableAnalysis; -import io.confluent.ksql.analyzer.RewrittenAnalysis; -import io.confluent.ksql.engine.rewrite.ExpressionTreeRewriter.Context; -import io.confluent.ksql.execution.codegen.CodeGenRunner; -import io.confluent.ksql.execution.codegen.ExpressionMetadata; -import io.confluent.ksql.execution.ddl.commands.KsqlTopic; -import io.confluent.ksql.execution.expression.tree.Expression; -import io.confluent.ksql.execution.expression.tree.QualifiedColumnReferenceExp; -import io.confluent.ksql.execution.expression.tree.UnqualifiedColumnReferenceExp; -import io.confluent.ksql.execution.expression.tree.VisitParentExpressionVisitor; -import io.confluent.ksql.function.InternalFunctionRegistry; -import io.confluent.ksql.function.KsqlScalarFunction; -import io.confluent.ksql.function.MutableFunctionRegistry; -import io.confluent.ksql.function.UdfLoaderUtil; -import io.confluent.ksql.function.types.ParamTypes; -import io.confluent.ksql.function.udf.Kudf; -import io.confluent.ksql.metastore.MetaStore; -import io.confluent.ksql.metastore.MutableMetaStore; -import io.confluent.ksql.metastore.model.KeyField; -import io.confluent.ksql.metastore.model.KsqlStream; -import io.confluent.ksql.name.ColumnName; -import io.confluent.ksql.name.FunctionName; -import io.confluent.ksql.name.SourceName; -import io.confluent.ksql.schema.ksql.LogicalSchema; -import io.confluent.ksql.schema.ksql.SchemaConverters; -import io.confluent.ksql.schema.ksql.types.SqlTypes; -import io.confluent.ksql.serde.FormatFactory; -import io.confluent.ksql.serde.FormatInfo; -import io.confluent.ksql.serde.KeyFormat; -import io.confluent.ksql.serde.SerdeOption; -import io.confluent.ksql.serde.ValueFormat; -import io.confluent.ksql.testutils.AnalysisTestUtil; -import io.confluent.ksql.util.KsqlConfig; -import io.confluent.ksql.util.KsqlException; -import io.confluent.ksql.util.MetaStoreFixture; -import java.math.BigDecimal; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; -import org.apache.kafka.connect.data.Schema; -import 
org.apache.kafka.connect.data.Struct; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; - - -@SuppressWarnings({"SameParameterValue", "OptionalGetWithoutIsPresent"}) -public class CodeGenRunnerTest { - - private static final String COL_INVALID_JAVA = "col!Invalid:("; - - private static final LogicalSchema META_STORE_SCHEMA = LogicalSchema.builder() - .keyColumn(ColumnName.of("K0"), SqlTypes.BIGINT) - .valueColumn(ColumnName.of("COL0"), SqlTypes.BIGINT) - .valueColumn(ColumnName.of("COL1"), SqlTypes.STRING) - .valueColumn(ColumnName.of("COL2"), SqlTypes.STRING) - .valueColumn(ColumnName.of("COL3"), SqlTypes.DOUBLE) - .valueColumn(ColumnName.of("COL4"), SqlTypes.DOUBLE) - .valueColumn(ColumnName.of("COL5"), SqlTypes.INTEGER) - .valueColumn(ColumnName.of("COL6"), SqlTypes.BOOLEAN) - .valueColumn(ColumnName.of("COL7"), SqlTypes.BOOLEAN) - .valueColumn(ColumnName.of("COL8"), SqlTypes.BIGINT) - .valueColumn(ColumnName.of("COL9"), SqlTypes.array(SqlTypes.INTEGER)) - .valueColumn(ColumnName.of("COL10"), SqlTypes.array(SqlTypes.INTEGER)) - .valueColumn(ColumnName.of("COL11"), SqlTypes.map(SqlTypes.STRING)) - .valueColumn(ColumnName.of("COL12"), SqlTypes.map(SqlTypes.INTEGER)) - .valueColumn(ColumnName.of("COL13"), SqlTypes.array(SqlTypes.STRING)) - .valueColumn(ColumnName.of("COL14"), SqlTypes.array(SqlTypes.array(SqlTypes.STRING))) - .valueColumn(ColumnName.of("COL15"), SqlTypes - .struct() - .field("A", SqlTypes.STRING) - .build()) - .valueColumn(ColumnName.of("COL16"), SqlTypes.decimal(10, 10)) - .valueColumn(ColumnName.of(COL_INVALID_JAVA), SqlTypes.BIGINT) - .build(); - - private static final int INT64_INDEX1 = 0; - private static final int STRING_INDEX1 = 1; - private static final int STRING_INDEX2 = 2; - private static final int FLOAT64_INDEX1 = 3; - private static final int FLOAT64_INDEX2 = 4; - private static final int INT32_INDEX1 = 5; - private static final int BOOLEAN_INDEX1 = 6; - private static final int BOOLEAN_INDEX2 = 7; - private static final int INT64_INDEX2 = 8; - private static final int ARRAY_INDEX1 = 9; - private static final int ARRAY_INDEX2 = 10; - private static final int MAP_INDEX1 = 11; - private static final int MAP_INDEX2 = 12; - private static final int STRUCT_INDEX = 15; - private static final int DECIMAL_INDEX = 16; - private static final int INVALID_JAVA_IDENTIFIER_INDEX = 17; - - private static final Schema STRUCT_SCHEMA = SchemaConverters.sqlToConnectConverter() - .toConnectSchema( - META_STORE_SCHEMA.findValueColumn(ColumnName.of("COL15")) - .get() - .type()); - - private static final List ONE_ROW = ImmutableList.of( - 0L, "S1", "S2", 3.1, 4.2, 5, true, false, 8L, - ImmutableList.of(1, 2), ImmutableList.of(2, 4), - ImmutableMap.of("key1", "value1", "address", "{\"city\":\"adelaide\",\"country\":\"oz\"}"), - ImmutableMap.of("k1", 4), - ImmutableList.of("one", "two"), - ImmutableList.of(ImmutableList.of("1", "2"), ImmutableList.of("3")), - new Struct(STRUCT_SCHEMA).put("A", "VALUE"), - new BigDecimal("12345.6789"), - (long) INVALID_JAVA_IDENTIFIER_INDEX); - - @Rule - public final ExpectedException expectedException = ExpectedException.none(); - - private MutableMetaStore metaStore; - private CodeGenRunner codeGenRunner; - private final MutableFunctionRegistry functionRegistry = new InternalFunctionRegistry(); - private final KsqlConfig ksqlConfig = new KsqlConfig(Collections.emptyMap()); - - @Before - public void init() { - final KsqlScalarFunction whenCondition = 
KsqlScalarFunction.createLegacyBuiltIn( - SqlTypes.BOOLEAN, - ImmutableList.of(ParamTypes.BOOLEAN, ParamTypes.BOOLEAN), - FunctionName.of("WHENCONDITION"), - WhenCondition.class - ); - final KsqlScalarFunction whenResult = KsqlScalarFunction.createLegacyBuiltIn( - SqlTypes.INTEGER, - ImmutableList.of(ParamTypes.INTEGER, ParamTypes.BOOLEAN), - FunctionName.of("WHENRESULT"), - WhenResult.class - ); - functionRegistry.ensureFunctionFactory( - UdfLoaderUtil.createTestUdfFactory(whenCondition)); - functionRegistry.addFunction(whenCondition); - functionRegistry.ensureFunctionFactory( - UdfLoaderUtil.createTestUdfFactory(whenResult)); - functionRegistry.addFunction(whenResult); - metaStore = MetaStoreFixture.getNewMetaStore(functionRegistry); - // load substring function - UdfLoaderUtil.load(functionRegistry); - - final KsqlTopic ksqlTopic = new KsqlTopic( - "codegen_test", - KeyFormat.nonWindowed(FormatInfo.of(FormatFactory.KAFKA.name())), - ValueFormat.of(FormatInfo.of(FormatFactory.JSON.name())) - ); - - final KsqlStream ksqlStream = new KsqlStream<>( - "sqlexpression", - SourceName.of("CODEGEN_TEST"), - META_STORE_SCHEMA, - SerdeOption.none(), - KeyField.of(ColumnName.of("COL0")), - Optional.empty(), - false, - ksqlTopic - ); - - metaStore.putSource(ksqlStream); - - codeGenRunner = new CodeGenRunner(META_STORE_SCHEMA, ksqlConfig, functionRegistry); - } - - @Test - public void testNullEquals() { - assertThat(evalBooleanExprEq(INT32_INDEX1, INT64_INDEX1, new Object[]{null, 12344L}),is(false)); - assertThat(evalBooleanExprEq(INT32_INDEX1, INT64_INDEX1, new Object[]{null, null}), is(false)); - } - - @Test - public void testIsDistinctFrom() { - assertThat(evalBooleanExprIsDistinctFrom(INT32_INDEX1, INT64_INDEX1, new Object[]{12344, 12344L}), is(false)); - assertThat(evalBooleanExprIsDistinctFrom(INT32_INDEX1, INT64_INDEX1, new Object[]{12345, 12344L}), is(true)); - assertThat(evalBooleanExprIsDistinctFrom(INT32_INDEX1, INT64_INDEX1, new Object[]{null, 12344L}), is(true)); - assertThat(evalBooleanExprIsDistinctFrom(INT32_INDEX1, INT64_INDEX1, new Object[]{null, null}), is(false)); - } - - @Test - public void testIsNull() { - final String simpleQuery = "SELECT col0 IS NULL FROM CODEGEN_TEST EMIT CHANGES;"; - final ImmutableAnalysis analysis = analyzeQuery(simpleQuery, metaStore); - - final ExpressionMetadata expressionEvaluatorMetadata0 = codeGenRunner.buildCodeGenFromParseTree - (analysis.getSelectExpressions().get(0).getExpression(), "Select"); - - assertThat(expressionEvaluatorMetadata0.arguments(), hasSize(1)); - - Object result0 = expressionEvaluatorMetadata0.evaluate(genericRow(null, 1)); - assertThat(result0, is(true)); - - result0 = expressionEvaluatorMetadata0.evaluate(genericRow(12345L)); - assertThat(result0, is(false)); - } - - @Test - public void shouldHandleMultiDimensionalArray() { - // Given: - final String simpleQuery = "SELECT col14[1][1] FROM CODEGEN_TEST EMIT CHANGES;"; - final ImmutableAnalysis analysis = analyzeQuery(simpleQuery, metaStore); - - // When: - final Object result = codeGenRunner.buildCodeGenFromParseTree - (analysis.getSelectExpressions().get(0).getExpression(), "Select") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is("1")); - } - - @Test - public void testIsNotNull() { - final String simpleQuery = "SELECT col0 IS NOT NULL FROM CODEGEN_TEST EMIT CHANGES;"; - final ImmutableAnalysis analysis = analyzeQuery(simpleQuery, metaStore); - - final ExpressionMetadata expressionEvaluatorMetadata0 = - codeGenRunner.buildCodeGenFromParseTree( - 
analysis.getSelectExpressions().get(0).getExpression(), "Filter"); - - assertThat(expressionEvaluatorMetadata0.arguments(), hasSize(1)); - - Object result0 = expressionEvaluatorMetadata0.evaluate(genericRow(null, "1")); - assertThat(result0, is(false)); - - result0 = expressionEvaluatorMetadata0.evaluate(genericRow(12345L)); - assertThat(result0, is(true)); - } - - @Test - public void testBooleanExprScalarEq() { - // int32 - assertThat(evalBooleanExprEq(INT32_INDEX1, INT64_INDEX1, new Object[]{12345, 12344L}), is(false)); - assertThat(evalBooleanExprEq(INT32_INDEX1, INT64_INDEX1, new Object[]{12345, 12345L}), is(true)); - // int64 - assertThat(evalBooleanExprEq(INT64_INDEX2, INT32_INDEX1, new Object[]{12345L, 12344}), is(false)); - assertThat(evalBooleanExprEq(INT64_INDEX2, INT32_INDEX1, new Object[]{12345L, 12345}), is(true)); - // double - assertThat(evalBooleanExprEq(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12345.0, 12344.0}), is(false)); - assertThat(evalBooleanExprEq(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12345.0, 12345.0}), is(true)); - } - - @Test - public void testBooleanExprBooleanEq() { - assertThat(evalBooleanExprEq(BOOLEAN_INDEX2, BOOLEAN_INDEX1, new Object[]{false, true}), is(false)); - assertThat(evalBooleanExprEq(BOOLEAN_INDEX2, BOOLEAN_INDEX1, new Object[]{true, true}), is(true)); - } - - @Test - public void testBooleanExprStringEq() { - assertThat(evalBooleanExprEq(STRING_INDEX1, STRING_INDEX2, new Object[]{"abc", "def"}), is(false)); - assertThat(evalBooleanExprEq(STRING_INDEX1, STRING_INDEX2, new Object[]{"abc", "abc"}), is(true)); - } - - @Test - public void testBooleanExprArrayComparisonFails() { - // Given: - expectedException.expect(KsqlException.class); - expectedException.expectMessage("Code generation failed for Filter: " - + "Cannot compare ARRAY values. " - + "expression:(COL9 = COL10)"); - expectedException.expectCause(hasMessage(equalTo("Cannot compare ARRAY values"))); - - // When: - evalBooleanExprEq(ARRAY_INDEX1, ARRAY_INDEX2, - new Object[]{new Integer[]{1}, new Integer[]{1}}); - } - - @Test - public void testBooleanExprMapComparisonFails() { - // Given: - expectedException.expect(KsqlException.class); - expectedException.expectMessage("Code generation failed for Filter: " - + "Cannot compare MAP values. 
" - + "expression:(COL11 = COL12)"); - expectedException.expectCause(hasMessage(equalTo("Cannot compare MAP values"))); - - // When: - evalBooleanExprEq(MAP_INDEX1, MAP_INDEX2, - new Object[]{ImmutableMap.of(1, 2), ImmutableMap.of(1, 2)}); - } - - @Test - public void testBooleanExprScalarNeq() { - // int32 - assertThat(evalBooleanExprNeq(INT32_INDEX1, INT64_INDEX1, new Object[]{12345, 12344L}), is(true)); - assertThat(evalBooleanExprNeq(INT32_INDEX1, INT64_INDEX1, new Object[]{12345, 12345L}), is(false)); - // int64 - assertThat(evalBooleanExprNeq(INT64_INDEX2, INT32_INDEX1, new Object[]{12345L, 12344}), is(true)); - assertThat(evalBooleanExprNeq(INT64_INDEX2, INT32_INDEX1, new Object[]{12345L, 12345}), is(false)); - // double - assertThat(evalBooleanExprNeq(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12345.0, 12344.0}), is(true)); - assertThat(evalBooleanExprNeq(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12345.0, 12345.0}), is(false)); - } - - @Test - public void testBooleanExprBooleanNeq() { - assertThat(evalBooleanExprNeq(BOOLEAN_INDEX2, BOOLEAN_INDEX1, new Object[]{false, true}), is(true)); - assertThat(evalBooleanExprNeq(BOOLEAN_INDEX2, BOOLEAN_INDEX1, new Object[]{true, true}), is(false)); - } - - @Test - public void testBooleanExprStringNeq() { - assertThat(evalBooleanExprNeq(STRING_INDEX1, STRING_INDEX2, new Object[]{"abc", "def"}), is(true)); - assertThat(evalBooleanExprNeq(STRING_INDEX1, STRING_INDEX2, new Object[]{"abc", "abc"}), is(false)); - } - - @Test - public void testBooleanExprScalarLessThan() { - // int32 - assertThat(evalBooleanExprLessThan(INT32_INDEX1, INT64_INDEX1, new Object[]{12344, 12345L}), is(true)); - assertThat(evalBooleanExprLessThan(INT32_INDEX1, INT64_INDEX1, new Object[]{12346, 12345L}), is(false)); - // int64 - assertThat(evalBooleanExprLessThan(INT64_INDEX2, INT32_INDEX1, new Object[]{12344L, 12345}), is(true)); - assertThat(evalBooleanExprLessThan(INT64_INDEX2, INT32_INDEX1, new Object[]{12346L, 12345}), is(false)); - // double - assertThat(evalBooleanExprLessThan(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12344.0, 12345.0}), is(true)); - assertThat(evalBooleanExprLessThan(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12346.0, 12345.0}), is(false)); - } - - @Test - public void testBooleanExprStringLessThan() { - assertThat(evalBooleanExprLessThan(STRING_INDEX1, STRING_INDEX2, new Object[]{"abc", "def"}), is(true)); - assertThat(evalBooleanExprLessThan(STRING_INDEX1, STRING_INDEX2, new Object[]{"abc", "abc"}), is(false)); - } - - @Test - public void testBooleanExprScalarLessThanEq() { - // int32 - assertThat(evalBooleanExprLessThanEq(INT32_INDEX1, INT64_INDEX1, new Object[]{12345, 12345L}), is(true)); - assertThat(evalBooleanExprLessThanEq(INT32_INDEX1, INT64_INDEX1, new Object[]{12346, 12345L}), is(false)); - // int64 - assertThat(evalBooleanExprLessThanEq(INT64_INDEX2, INT32_INDEX1, new Object[]{12345L, 12345}), is(true)); - assertThat(evalBooleanExprLessThanEq(INT64_INDEX2, INT32_INDEX1, new Object[]{12346L, 12345}), is(false)); - // double - assertThat(evalBooleanExprLessThanEq(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12344.0, 12345.0}), is(true)); - assertThat(evalBooleanExprLessThanEq(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12346.0, 12345.0}), is(false)); - } - - @Test - public void testBooleanExprStringLessThanEq() { - assertThat(evalBooleanExprLessThanEq(STRING_INDEX1, STRING_INDEX2, new Object[]{"abc", "abc"}), is(true)); - assertThat(evalBooleanExprLessThanEq(STRING_INDEX1, STRING_INDEX2, new Object[]{"abc", "abb"}), is(false)); - 
} - - @Test - public void testBooleanExprScalarGreaterThan() { - // int32 - assertThat(evalBooleanExprGreaterThan(INT32_INDEX1, INT64_INDEX1, new Object[]{12346, 12345L}), is(true)); - assertThat(evalBooleanExprGreaterThan(INT32_INDEX1, INT64_INDEX1, new Object[]{12345, 12345L}), is(false)); - // int64 - assertThat(evalBooleanExprGreaterThan(INT64_INDEX2, INT32_INDEX1, new Object[]{12346L, 12345}), is(true)); - assertThat(evalBooleanExprGreaterThan(INT64_INDEX2, INT32_INDEX1, new Object[]{12345L, 12345}), is(false)); - // double - assertThat(evalBooleanExprGreaterThan(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12346.0, 12345.0}), is(true)); - assertThat(evalBooleanExprGreaterThan(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12344.0, 12345.0}), is(false)); - } - - @Test - public void testBooleanExprStringGreaterThan() { - assertThat(evalBooleanExprGreaterThan(STRING_INDEX1, STRING_INDEX2, new Object[]{"def", "abc"}), is(true)); - assertThat(evalBooleanExprGreaterThan(STRING_INDEX1, STRING_INDEX2, new Object[]{"abc", "abc"}), is(false)); - } - - @Test - public void testBooleanExprScalarGreaterThanEq() { - // int32 - assertThat(evalBooleanExprGreaterThanEq(INT32_INDEX1, INT64_INDEX1, new Object[]{12345, 12345L}), is(true)); - assertThat(evalBooleanExprGreaterThanEq(INT32_INDEX1, INT64_INDEX1, new Object[]{12344, 12345L}), is(false)); - // int64 - assertThat(evalBooleanExprGreaterThanEq(INT64_INDEX2, INT32_INDEX1, new Object[]{12345L, 12345}), is(true)); - assertThat(evalBooleanExprGreaterThanEq(INT64_INDEX2, INT32_INDEX1, new Object[]{12344L, 12345}), is(false)); - // double - assertThat(evalBooleanExprGreaterThanEq(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12346.0, 12345.0}), is(true)); - assertThat(evalBooleanExprGreaterThanEq(FLOAT64_INDEX2, FLOAT64_INDEX1, new Object[]{12344.0, 12345.0}), is(false)); - } - - @Test - public void testBooleanExprStringGreaterThanEq() { - assertThat(evalBooleanExprGreaterThanEq(STRING_INDEX1, STRING_INDEX2, new Object[]{"def", "abc"}), is(true)); - assertThat(evalBooleanExprGreaterThanEq(STRING_INDEX1, STRING_INDEX2, new Object[]{"abc", "def"}), is(false)); - } - - @Test - public void testBetweenExprScalar() { - // int - assertThat(evalBetweenClauseScalar(INT32_INDEX1, 1, 0, 2), is(true)); - assertThat(evalBetweenClauseScalar(INT32_INDEX1, 0, 0, 2), is(true)); - assertThat(evalBetweenClauseScalar(INT32_INDEX1, 3, 0, 2), is(false)); - assertThat(evalBetweenClauseScalar(INT32_INDEX1, null, 0, 2), is(false)); - - // long - assertThat(evalBetweenClauseScalar(INT64_INDEX1, 12345L, 12344L, 12346L), is(true)); - assertThat(evalBetweenClauseScalar(INT64_INDEX1, 12344L, 12344L, 12346L), is(true)); - assertThat(evalBetweenClauseScalar(INT64_INDEX1, 12345L, 0, 2L), is(false)); - assertThat(evalBetweenClauseScalar(INT64_INDEX1, null, 0, 2L), is(false)); - - // double - assertThat(evalBetweenClauseScalar(FLOAT64_INDEX1, 1.0d, 0.1d, 1.9d), is(true)); - assertThat(evalBetweenClauseScalar(FLOAT64_INDEX1, 0.1d, 0.1d, 1.9d), is(true)); - assertThat(evalBetweenClauseScalar(FLOAT64_INDEX1, 2.0d, 0.1d, 1.9d), is(false)); - assertThat(evalBetweenClauseScalar(FLOAT64_INDEX1, null, 0.1d, 1.9d), is(false)); - - // decimal - assertThat(evalBetweenClauseScalar(DECIMAL_INDEX, new BigDecimal("1.0"), new BigDecimal("0.1"), new BigDecimal("1.9")), is(true)); - assertThat(evalBetweenClauseScalar(DECIMAL_INDEX, new BigDecimal("0.1"), new BigDecimal("0.1"), new BigDecimal("1.9")), is(true)); - assertThat(evalBetweenClauseScalar(DECIMAL_INDEX, new BigDecimal("2.0"), new BigDecimal("0.1"), 
new BigDecimal("1.9")), is(false)); - assertThat(evalBetweenClauseScalar(DECIMAL_INDEX, null, new BigDecimal("0.1"), new BigDecimal("1.9")), is(false)); - } - - @Test - public void testNotBetweenScalar() { - // int - assertThat(evalNotBetweenClauseScalar(INT32_INDEX1, 1, 0, 2), is(false)); - assertThat(evalNotBetweenClauseScalar(INT32_INDEX1, 0, 0, 2), is(false)); - assertThat(evalNotBetweenClauseScalar(INT32_INDEX1, 3, 0, 2), is(true)); - assertThat(evalNotBetweenClauseScalar(INT32_INDEX1, null, 0, 2), is(true)); - - // long - assertThat(evalNotBetweenClauseScalar(INT64_INDEX1, 12345L, 12344L, 12346L), is(false)); - assertThat(evalNotBetweenClauseScalar(INT64_INDEX1, 12344L, 12344L, 12346L), is(false)); - assertThat(evalNotBetweenClauseScalar(INT64_INDEX1, 12345L, 0, 2L), is(true)); - assertThat(evalNotBetweenClauseScalar(INT64_INDEX1, null, 0, 2L), is(true)); - - // double - assertThat(evalNotBetweenClauseScalar(FLOAT64_INDEX1, 1.0d, 0.1d, 1.9d), is(false)); - assertThat(evalNotBetweenClauseScalar(FLOAT64_INDEX1, 0.1d, 0.1d, 1.9d), is(false)); - assertThat(evalNotBetweenClauseScalar(FLOAT64_INDEX1, 2.0d, 0.1d, 1.9d), is(true)); - assertThat(evalNotBetweenClauseScalar(FLOAT64_INDEX1, null, 0.1d, 1.9d), is(true)); - - // decimal - assertThat(evalNotBetweenClauseScalar(DECIMAL_INDEX, new BigDecimal("1.0"), new BigDecimal("0.1"), new BigDecimal("1.9")), is(false)); - assertThat(evalNotBetweenClauseScalar(DECIMAL_INDEX, new BigDecimal("0.1"), new BigDecimal("0.1"), new BigDecimal("1.9")), is(false)); - assertThat(evalNotBetweenClauseScalar(DECIMAL_INDEX, new BigDecimal("2.0"), new BigDecimal("0.1"), new BigDecimal("1.9")), is(true)); - assertThat(evalNotBetweenClauseScalar(DECIMAL_INDEX, null, new BigDecimal("0.1"), new BigDecimal("1.9")), is(true)); - } - - @Test - public void testBetweenExprString() { - // constants - assertThat(evalBetweenClauseString(STRING_INDEX1, "b", "'a'", "'c'"), is(true)); - assertThat(evalBetweenClauseString(STRING_INDEX1, "a", "'a'", "'c'"), is(true)); - assertThat(evalBetweenClauseString(STRING_INDEX1, "d", "'a'", "'c'"), is(false)); - assertThat(evalBetweenClauseString(STRING_INDEX1, null, "'a'", "'c'"), is(false)); - - // columns - assertThat(evalBetweenClauseString(STRING_INDEX1, "S2", "col" + STRING_INDEX2, "'S3'"), is(true)); - assertThat(evalBetweenClauseString(STRING_INDEX1, "S3", "col" + STRING_INDEX2, "'S3'"), is(true)); - assertThat(evalBetweenClauseString(STRING_INDEX1, "S4", "col" + STRING_INDEX2, "'S3'"), is(false)); - assertThat(evalBetweenClauseString(STRING_INDEX1, null, "col" + STRING_INDEX2, "'S3'"), is(false)); - } - - @Test - public void testNotBetweenExprString() { - // constants - assertThat(evalNotBetweenClauseString(STRING_INDEX1, "b", "'a'", "'c'"), is(false)); - assertThat(evalNotBetweenClauseString(STRING_INDEX1, "a", "'a'", "'c'"), is(false)); - assertThat(evalNotBetweenClauseString(STRING_INDEX1, "d", "'a'", "'c'"), is(true)); - assertThat(evalNotBetweenClauseString(STRING_INDEX1, null, "'a'", "'c'"), is(true)); - - // columns - assertThat(evalNotBetweenClauseString(STRING_INDEX1, "S2", "col" + STRING_INDEX2, "'S3'"), is(false)); - assertThat(evalNotBetweenClauseString(STRING_INDEX1, "S3", "col" + STRING_INDEX2, "'S3'"), is(false)); - assertThat(evalNotBetweenClauseString(STRING_INDEX1, "S4", "col" + STRING_INDEX2, "'S3'"), is(true)); - assertThat(evalNotBetweenClauseString(STRING_INDEX1, null, "col" + STRING_INDEX2, "'S3'"), is(true)); - } - - @Test - public void testInvalidBetweenArrayValue() { - // Given: - 
expectedException.expect(KsqlException.class); - expectedException.expectMessage("Code generation failed for Filter:"); - expectedException.expectMessage("expression:(NOT (COL9 BETWEEN 'a' AND 'c'))"); - expectedException.expectCause(hasMessage( - equalTo("Cannot compare ARRAY values"))); - - // When: - evalNotBetweenClauseObject(ARRAY_INDEX1, new Object[]{1, 2}, "'a'", "'c'"); - } - - @Test - public void testInvalidBetweenMapValue() { - // Given: - expectedException.expect(KsqlException.class); - expectedException.expectMessage("Code generation failed for Filter: "); - expectedException.expectMessage("expression:(NOT (COL11 BETWEEN 'a' AND 'c'))"); - expectedException.expectCause(hasMessage( - equalTo("Cannot compare MAP values"))); - - // When: - evalNotBetweenClauseObject(MAP_INDEX1, ImmutableMap.of(1, 2), "'a'", "'c'"); - } - - @Test - public void testInvalidBetweenBooleanValue() { - // Given: - expectedException.expect(KsqlException.class); - expectedException.expectMessage("Code generation failed for Filter: "); - expectedException.expectMessage("expression:(NOT (COL6 BETWEEN 'a' AND 'c'))"); - expectedException.expectCause(hasMessage( - equalTo("Unexpected boolean comparison: >="))); - - // When: - evalNotBetweenClauseObject(BOOLEAN_INDEX1, true, "'a'", "'c'"); - } - - @Test - public void shouldHandleArithmeticExpr() { - // Given: - final String query = - "SELECT col0+col3, col3+10, col0*25, 12*4+2 FROM codegen_test WHERE col0 > 100 EMIT CHANGES;"; - - final Map inputValues = ImmutableMap.of(0, 5L, 3, 15.0); - - // When: - final List columns = executeExpression(query, inputValues); - - // Then: - assertThat(columns, contains(20.0, 25.0, 125L, 50)); - } - - @Test - public void testCastNumericArithmeticExpressions() { - final Map inputValues = - ImmutableMap.of(0, 1L, 3, 3.0D, 4, 4.0D, 5, 5); - - // INT - BIGINT - assertThat(executeExpression( - "SELECT " - + "CAST((col5 - col0) AS INTEGER)," - + "CAST((col5 - col0) AS BIGINT)," - + "CAST((col5 - col0) AS DOUBLE)," - + "CAST((col5 - col0) AS STRING)" - + "FROM codegen_test EMIT CHANGES;", - inputValues), contains(4, 4L, 4.0, "4")); - - // DOUBLE - DOUBLE - assertThat(executeExpression( - "SELECT " - + "CAST((col4 - col3) AS INTEGER)," - + "CAST((col4 - col3) AS BIGINT)," - + "CAST((col4 - col3) AS DOUBLE)," - + "CAST((col4 - col3) AS STRING)" - + "FROM codegen_test EMIT CHANGES;", - inputValues), contains(1, 1L, 1.0, "1.0")); - - // DOUBLE - INT - assertThat(executeExpression( - "SELECT " - + "CAST((col4 - col0) AS INTEGER)," - + "CAST((col4 - col0) AS BIGINT)," - + "CAST((col4 - col0) AS DOUBLE)," - + "CAST((col4 - col0) AS STRING)" - + "FROM codegen_test EMIT CHANGES;", - inputValues), contains(3, 3L, 3.0, "3.0")); - } - - @Test - public void shouldHandleStringLiteralWithCharactersThatMustBeEscaped() { - // Given: - final String query = "SELECT CONCAT(CONCAT('\\\"', 'foo'), '\\\"') FROM CODEGEN_TEST EMIT CHANGES;"; - - // When: - final List columns = executeExpression(query, Collections.emptyMap()); - - // Then: - assertThat(columns, contains("\\\"foo\\\"")); - } - - @Test - public void shouldHandleMathUdfs() { - // Given: - final String query = - "SELECT FLOOR(col3), CEIL(col3*3), ABS(col0+1.34E0), ROUND(col3*2)+12 FROM codegen_test EMIT CHANGES;"; - - final Map inputValues = ImmutableMap.of(0, 15L, 3, 1.5); - - // When: - final List columns = executeExpression(query, inputValues); - - // Then: - assertThat(columns, contains(1.0, 5.0, 16.34, 15L)); - } - - @Test - public void shouldHandleRandomUdf() { - // Given: - final String 
query = "SELECT RANDOM()+10, RANDOM()+col0 FROM codegen_test EMIT CHANGES;"; - final Map inputValues = ImmutableMap.of(0, 15L); - - // When: - final List columns = executeExpression(query, inputValues); - - // Then: - assertThat(columns.get(0), is(instanceOf(Double.class))); - assertThat((Double)columns.get(0), - is(both(greaterThanOrEqualTo(10.0)).and(lessThanOrEqualTo(11.0)))); - - assertThat(columns.get(1), is(instanceOf(Double.class))); - assertThat((Double)columns.get(1), - is(both(greaterThanOrEqualTo(15.0)).and(lessThanOrEqualTo(16.0)))); - } - - @Test - public void shouldHandleStringUdfs() { - // Given: - final String query = - "SELECT LCASE(col1), UCASE(col1), TRIM(col1), CONCAT(col1,'_test'), SUBSTRING(col1, 2, 4)" - + " FROM codegen_test EMIT CHANGES;"; - - final Map inputValues = ImmutableMap.of(1, " Hello "); - - // When: - final List columns = executeExpression(query, inputValues); - - // Then: - assertThat(columns, contains(" hello ", " HELLO ", "Hello", " Hello _test", "Hell")); - } - - @Test - public void shouldHandleNestedUdfs() { - final String query = - "SELECT " - + "CONCAT(EXTRACTJSONFIELD(col1,'$.name'),CONCAT('-',EXTRACTJSONFIELD(col1,'$.value')))" - + " FROM codegen_test EMIT CHANGES;"; - - final Map inputValues = ImmutableMap.of(1, "{\"name\":\"fred\",\"value\":1}"); - - // When: - executeExpression(query, inputValues); - } - - @Test - public void shouldHandleMaps() { - // Given: - final Expression expression = analyzeQuery( - "SELECT col11['key1'] as Address FROM codegen_test EMIT CHANGES;", metaStore) - .getSelectExpressions() - .get(0) - .getExpression(); - - // When: - final Object result = codeGenRunner - .buildCodeGenFromParseTree(expression, "Group By") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is("value1")); - } - - @Test - public void shouldHandleCreateArray() { - // Given: - final Expression expression = analyzeQuery( - "SELECT ARRAY['foo', COL" + STRING_INDEX1 + "] FROM codegen_test EMIT CHANGES;", metaStore) - .getSelectExpressions() - .get(0) - .getExpression(); - - // When: - final Object result = codeGenRunner - .buildCodeGenFromParseTree(expression, "Array") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is(ImmutableList.of("foo", "S1"))); - } - - @Test - public void shouldHandleCreateMap() { - // Given: - final Expression expression = analyzeQuery( - "SELECT MAP('foo' := 'foo', 'bar' := COL" + STRING_INDEX1 + ") FROM codegen_test EMIT CHANGES;", metaStore) - .getSelectExpressions() - .get(0) - .getExpression(); - - // When: - final Object result = codeGenRunner - .buildCodeGenFromParseTree(expression, "Map") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is(ImmutableMap.of("foo", "foo", "bar", "S1"))); - } - - @Test - public void shouldHandleInvalidJavaIdentifiers() { - // Given: - final Expression expression = analyzeQuery( - "SELECT `" + COL_INVALID_JAVA + "` FROM codegen_test EMIT CHANGES;", - metaStore) - .getSelectExpressions() - .get(0) - .getExpression(); - - // When: - final Object result = codeGenRunner - .buildCodeGenFromParseTree(expression, "math") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is((long) INVALID_JAVA_IDENTIFIER_INDEX)); - } - - @Test - public void shouldHandleCaseStatement() { - // Given: - final Expression expression = analyzeQuery( - "SELECT CASE " - + " WHEN col0 < 10 THEN 'small' " - + " WHEN col0 < 100 THEN 'medium' " - + " ELSE 'large' " - + "END " - + "FROM codegen_test EMIT CHANGES;", metaStore) - .getSelectExpressions() - 
.get(0) - .getExpression(); - - // When: - final Object result = codeGenRunner - .buildCodeGenFromParseTree(expression, "Case") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is("small")); - } - - @Test - public void shouldHandleCaseStatementLazily() { - // Given: - final Expression expression = analyzeQuery( - "SELECT CASE " - + " WHEN WHENCONDITION(true, true) THEN WHENRESULT(100, true) " - + " WHEN WHENCONDITION(true, false) THEN WHENRESULT(200, false) " - + " ELSE WHENRESULT(300, false) " - + "END " - + "FROM codegen_test EMIT CHANGES;", metaStore) - .getSelectExpressions() - .get(0) - .getExpression(); - - // When: - final Object result = codeGenRunner - .buildCodeGenFromParseTree(expression, "Case") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is(100)); - } - - @Test - public void shouldOnlyRunElseIfNoMatchInWhen() { - // Given: - final Expression expression = analyzeQuery( - "SELECT CASE " - + " WHEN WHENCONDITION(false, true) THEN WHENRESULT(100, false) " - + " WHEN WHENCONDITION(false, true) THEN WHENRESULT(200, false) " - + " ELSE WHENRESULT(300, true) " - + "END " - + "FROM codegen_test EMIT CHANGES;", metaStore) - .getSelectExpressions() - .get(0) - .getExpression(); - - // When: - final Object result = codeGenRunner - .buildCodeGenFromParseTree(expression, "Case") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is(300)); - } - - @Test - public void shouldReturnDefaultForCaseCorrectly() { - // Given: - final Expression expression = analyzeQuery( - "SELECT CASE " - + " WHEN col0 > 10 THEN 'small' " - + " ELSE 'large' " - + "END " - + "FROM codegen_test EMIT CHANGES;", metaStore) - .getSelectExpressions() - .get(0) - .getExpression(); - - // When: - final Object result = codeGenRunner - .buildCodeGenFromParseTree(expression, "Case") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is("large")); - } - - @Test - public void shouldReturnNullForCaseIfNoDefault() { - // Given: - final Expression expression = analyzeQuery( - "SELECT CASE " - + " WHEN col0 > 10 THEN 'small' " - + "END " - + "FROM codegen_test EMIT CHANGES;", metaStore) - .getSelectExpressions() - .get(0) - .getExpression(); - - // When: - final Object result = codeGenRunner - .buildCodeGenFromParseTree(expression, "Case") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is(nullValue())); - } - - - @Test - public void shouldHandleUdfsExtractingFromMaps() { - // Given: - final Expression expression = analyzeQuery( - "SELECT EXTRACTJSONFIELD(col11['address'], '$.city') FROM codegen_test EMIT CHANGES;", - metaStore) - .getSelectExpressions() - .get(0) - .getExpression(); - - // When: - final Object result = codeGenRunner - .buildCodeGenFromParseTree(expression, "Select") - .evaluate(genericRow(ONE_ROW)); - - // Then: - assertThat(result, is("adelaide")); - } - - @Test - public void shouldHandleFunctionWithNullArgument() { - final String query = - "SELECT test_udf(col0, NULL) FROM codegen_test EMIT CHANGES;"; - - final Map inputValues = ImmutableMap.of(0, 0L); - final List columns = executeExpression(query, inputValues); - // test - assertThat(columns, equalTo(Collections.singletonList("doStuffLongString"))); - } - - @Test - public void shouldHandleFunctionWithVarargs() { - final String query = - "SELECT test_udf(col0, col0, col0, col0, col0) FROM codegen_test EMIT CHANGES;"; - - final Map inputValues = ImmutableMap.of(0, 0L); - final List columns = executeExpression(query, inputValues); - // test - assertThat(columns, 
equalTo(Collections.singletonList("doStuffLongVarargs"))); - } - - @Test - public void shouldHandleFunctionWithStruct() { - // Given: - final String query = - "SELECT test_udf(col" + STRUCT_INDEX + ") FROM codegen_test EMIT CHANGES;"; - - // When: - final List columns = executeExpression(query, ImmutableMap.of()); - - // Then: - assertThat(columns, equalTo(Collections.singletonList("VALUE"))); - } - - @Test - public void shouldChoseFunctionWithCorrectNumberOfArgsWhenNullArgument() { - final String query = - "SELECT test_udf(col0, col0, NULL) FROM codegen_test EMIT CHANGES;"; - - final Map inputValues = ImmutableMap.of(0, 0L); - final List columns = executeExpression(query, inputValues); - // test - assertThat(columns, equalTo(Collections.singletonList("doStuffLongLongString"))); - } - - private List executeExpression(final String query, - final Map inputValues) { - final ImmutableAnalysis analysis = analyzeQuery(query, metaStore); - - final GenericRow input = buildRow(inputValues); - - return analysis.getSelectExpressions().stream() - .map(exp -> codeGenRunner.buildCodeGenFromParseTree(exp.getExpression(), "Select")) - .map(md -> md.evaluate(input)) - .collect(Collectors.toList()); - } - - private boolean evalBooleanExprEq(final int cola, final int colb, final Object[] values) { - return evalBooleanExpr("SELECT col%d = col%d FROM CODEGEN_TEST EMIT CHANGES;", cola, colb, values); - } - - private boolean evalBooleanExprNeq(final int cola, final int colb, final Object[] values) { - return evalBooleanExpr("SELECT col%d != col%d FROM CODEGEN_TEST EMIT CHANGES;", cola, colb, values); - } - - private boolean evalBooleanExprIsDistinctFrom(final int cola, final int colb, - final Object[] values) { - return evalBooleanExpr("SELECT col%d IS DISTINCT FROM col%d FROM CODEGEN_TEST EMIT CHANGES;", cola, colb, values); - } - - private boolean evalBooleanExprLessThan(final int cola, final int colb, final Object[] values) { - return evalBooleanExpr("SELECT col%d < col%d FROM CODEGEN_TEST EMIT CHANGES;", cola, colb, values); - } - - private boolean evalBooleanExprLessThanEq(final int cola, final int colb, - final Object[] values) { - return evalBooleanExpr("SELECT col%d <= col%d FROM CODEGEN_TEST EMIT CHANGES;", cola, colb, values); - } - - private boolean evalBooleanExprGreaterThan(final int cola, final int colb, - final Object[] values) { - return evalBooleanExpr("SELECT col%d > col%d FROM CODEGEN_TEST EMIT CHANGES;", cola, colb, values); - } - - private boolean evalBooleanExprGreaterThanEq(final int cola, final int colb, - final Object[] values) { - return evalBooleanExpr("SELECT col%d >= col%d FROM CODEGEN_TEST EMIT CHANGES;", cola, colb, values); - } - - private boolean evalBooleanExpr( - final String queryFormat, final int cola, final int colb, final Object[] values) { - final String simpleQuery = String.format(queryFormat, cola, colb); - final ImmutableAnalysis analysis = analyzeQuery(simpleQuery, metaStore); - - final ExpressionMetadata expressionEvaluatorMetadata0 = codeGenRunner.buildCodeGenFromParseTree - (analysis.getSelectExpressions().get(0).getExpression(), "Filter"); - - assertThat(expressionEvaluatorMetadata0.arguments(), hasSize(2)); - - final List columns = new ArrayList<>(ONE_ROW); - columns.set(cola, values[0]); - columns.set(colb, values[1]); - - final Object result0 = expressionEvaluatorMetadata0.evaluate(genericRow(columns)); - assertThat(result0, instanceOf(Boolean.class)); - return (Boolean)result0; - } - - private boolean evalBetweenClauseScalar(final int col, final Number val, 
final Number min, final Number max) { - final String simpleQuery; - if (val instanceof Double) { - simpleQuery = String.format("SELECT * FROM CODEGEN_TEST WHERE col%d BETWEEN %e AND %E EMIT CHANGES;", col, min, max); - } else { - simpleQuery = String.format("SELECT * FROM CODEGEN_TEST WHERE col%d BETWEEN %s AND %s EMIT CHANGES;", col, min.toString(), max.toString()); - } - - return evalBetweenClause(simpleQuery, col, val); - } - - private boolean evalNotBetweenClauseScalar(final int col, final Number val, final Number min, final Number max) { - final String simpleQuery; - if (val instanceof Double) { - simpleQuery = String.format("SELECT * FROM CODEGEN_TEST WHERE col%d NOT BETWEEN %e AND %E EMIT CHANGES;", col, min, max); - } else { - simpleQuery = String.format("SELECT * FROM CODEGEN_TEST WHERE col%d NOT BETWEEN %s AND %s EMIT CHANGES;", col, min.toString(), max.toString()); - } - - return evalBetweenClause(simpleQuery, col, val); - } - - private boolean evalBetweenClauseString(final int col, final String val, final String min, final String max) { - final String simpleQuery = String.format("SELECT * FROM CODEGEN_TEST WHERE col%d BETWEEN %s AND %s EMIT CHANGES;", col, min, max); - return evalBetweenClause(simpleQuery, col, val); - } - - private boolean evalNotBetweenClauseString(final int col, final String val, final String min, final String max) { - final String simpleQuery = String.format("SELECT * FROM CODEGEN_TEST WHERE col%d NOT BETWEEN %s AND %s EMIT CHANGES;", col, min, max); - return evalBetweenClause(simpleQuery, col, val); - } - - private void evalNotBetweenClauseObject(final int col, final Object val, final String min, final String max) { - final String simpleQuery = String.format("SELECT * FROM CODEGEN_TEST WHERE col%d NOT BETWEEN %s AND %s EMIT CHANGES;", col, min, max); - evalBetweenClause(simpleQuery, col, val); - } - - private boolean evalBetweenClause(final String simpleQuery, final int col, final Object val) { - final ImmutableAnalysis analysis = analyzeQuery(simpleQuery, metaStore); - - final ExpressionMetadata expressionEvaluatorMetadata0 = codeGenRunner - .buildCodeGenFromParseTree(analysis.getWhereExpression().get(), "Filter"); - - final List columns = new ArrayList<>(ONE_ROW); - columns.set(col, val); - - final Object result0 = expressionEvaluatorMetadata0.evaluate(genericRow(columns)); - assertThat(result0, instanceOf(Boolean.class)); - return (Boolean)result0; - } - - private static GenericRow buildRow(final Map overrides) { - final List columns = new ArrayList<>(ONE_ROW); - overrides.forEach(columns::set); - return genericRow(columns); - } - - private static GenericRow genericRow(final Object... columns) { - return GenericRow.genericRow(columns); - } - - private static GenericRow genericRow(final List columns) { - return new GenericRow().appendAll(columns); - } - - public static final class WhenCondition implements Kudf { - - @Override - public Object evaluate(final Object... args) { - final boolean shouldBeEvaluated = (boolean) args[1]; - if (!shouldBeEvaluated) { - throw new KsqlException("When condition in case is not running lazily!"); - } - return args[0]; - } - } - - public static final class WhenResult implements Kudf { - @Override - public Object evaluate(final Object... 
args) { - final boolean shouldBeEvaluated = (boolean) args[1]; - if (!shouldBeEvaluated) { - throw new KsqlException("Then expression in case is not running lazily!"); - } - return args[0]; - } - } - - private ImmutableAnalysis analyzeQuery(final String query, final MetaStore metaStore) { - final Analysis analysis = AnalysisTestUtil.analyzeQuery(query, metaStore); - return new RewrittenAnalysis( - analysis, - new VisitParentExpressionVisitor, Context>(Optional.empty()) { - @Override - public Optional visitQualifiedColumnReference( - final QualifiedColumnReferenceExp node, - final Context ctx - ) { - return Optional.of(new UnqualifiedColumnReferenceExp(node.getColumnName())); - } - }::process - ); - } -} diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/WhenCondition.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/WhenCondition.java new file mode 100644 index 000000000000..8cc462d41535 --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/WhenCondition.java @@ -0,0 +1,31 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.ksql.function.udf; + +import io.confluent.ksql.util.KsqlException; + +@SuppressWarnings({"unused", "MethodMayBeStatic"}) // Invoked via reflection +@UdfDescription(name = "WhenCondition", description = "UDF used in case-expression.json") +public class WhenCondition { + + @Udf + public boolean evaluate(final boolean retValue, final boolean shouldBeEvaluated) { + if (!shouldBeEvaluated) { + throw new KsqlException("When condition in case is not running lazily!"); + } + return retValue; + } +} \ No newline at end of file diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/WhenResult.java b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/WhenResult.java new file mode 100644 index 000000000000..23297625bcf1 --- /dev/null +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/function/udf/WhenResult.java @@ -0,0 +1,31 @@ +/* + * Copyright 2020 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package io.confluent.ksql.function.udf; + +import io.confluent.ksql.util.KsqlException; + +@SuppressWarnings({"unused", "MethodMayBeStatic"}) // Invoked via reflection +@UdfDescription(name="WhenResult", description = "UDF used in case-expression.json") +public class WhenResult { + + @Udf + public int evaluate(final int retVal, final boolean shouldBeEvaluated) { + if (!shouldBeEvaluated) { + throw new KsqlException("Then expression in case is not running lazily!"); + } + return retVal; + } +} \ No newline at end of file diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java index 719f744d376a..76ffb03e4e92 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/DataSourceNodeTest.java @@ -189,8 +189,7 @@ public void before() { node = new DataSourceNode( PLAN_NODE_ID, SOME_SOURCE, - SOME_SOURCE.getName(), - Collections.emptyList() + SOME_SOURCE.getName() ); } @@ -256,8 +255,7 @@ public void shouldBuildSchemaKTableWhenKTableSource() { node = new DataSourceNode( PLAN_NODE_ID, table, - table.getName(), - Collections.emptyList() + table.getName() ); // When: @@ -378,7 +376,6 @@ private void givenNodeWithMockSource() { PLAN_NODE_ID, dataSource, SOURCE_NAME, - Collections.emptyList(), schemaKStreamFactory ); } diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java index 566dfc9d4dfd..faba1cc81c25 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/JoinNodeTest.java @@ -202,7 +202,6 @@ public void shouldReturnLeftJoinKeyAsKeyField() { // When: final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinType.LEFT, left, right, @@ -264,7 +263,6 @@ public void shouldPerformStreamToStreamLeftJoin() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.LEFT, left, right, @@ -293,7 +291,6 @@ public void shouldPerformStreamToStreamInnerJoin() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.INNER, left, right, @@ -322,7 +319,6 @@ public void shouldPerformStreamToStreamOuterJoin() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.OUTER, left, right, @@ -351,7 +347,6 @@ public void shouldNotPerformStreamStreamJoinWithoutJoinWindow() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.INNER, left, right, @@ -379,7 +374,6 @@ public void shouldNotPerformJoinIfInputPartitionsMisMatch() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.OUTER, left, right, @@ -404,7 +398,6 @@ public void shouldHandleJoinIfTableHasNoKeyAndJoinFieldIsRowKey() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.LEFT, left, right, @@ -431,7 +424,6 @@ public void shouldPerformStreamToTableLeftJoin() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.LEFT, left, right, @@ -458,7 +450,6 @@ public void shouldPerformStreamToTableInnerJoin() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.INNER, left, right, @@ -485,7 +476,6 
@@ public void shouldNotAllowStreamToTableOuterJoin() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.OUTER, left, right, @@ -512,7 +502,6 @@ public void shouldNotPerformStreamToTableJoinIfJoinWindowIsSpecified() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.OUTER, left, right, @@ -537,7 +526,6 @@ public void shouldPerformTableToTableInnerJoin() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.INNER, left, right, @@ -562,7 +550,6 @@ public void shouldPerformTableToTableLeftJoin() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.LEFT, left, right, @@ -587,7 +574,6 @@ public void shouldPerformTableToTableOuterJoin() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.OUTER, left, right, @@ -614,7 +600,6 @@ public void shouldNotPerformTableToTableJoinIfJoinWindowIsSpecified() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.OUTER, left, right, @@ -636,7 +621,6 @@ public void shouldHaveFullyQualifiedJoinSchema() { // When: final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.OUTER, left, right, @@ -667,7 +651,6 @@ public void shouldNotUseSourceSerdeOptionsForInternalTopics() { final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.LEFT, left, right, @@ -687,7 +670,6 @@ public void shouldReturnCorrectSchema() { // When: final JoinNode joinNode = new JoinNode( nodeId, - Collections.emptyList(), JoinNode.JoinType.LEFT, left, right, diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNodeTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNodeTest.java index 4739c33b0eaa..f8e9a0f9bcd0 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNodeTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/planner/plan/KsqlStructuredDataOutputNodeTest.java @@ -109,6 +109,7 @@ public void before() { when(queryIdGenerator.getNext()).thenReturn(QUERY_ID_VALUE); + when(sourceNode.getSchema()).thenReturn(LogicalSchema.builder().build()); when(sourceNode.getNodeOutputType()).thenReturn(DataSourceType.KSTREAM); when(sourceNode.buildStream(ksqlStreamBuilder)).thenReturn((SchemaKStream) sourceStream); @@ -127,10 +128,13 @@ public void before() { @Test public void shouldThrowIfSelectExpressionsHaveSameNameAsAnyKeyColumn() { // Given: - givenSourceSelectExpressions( - selectExpression("field1"), - selectExpression("k0"), - selectExpression("field2") + givenSourceSchema( + LogicalSchema.builder() + .keyColumn(ColumnName.of("k0"), SqlTypes.INTEGER) + .valueColumn(ColumnName.of("field1"), SqlTypes.STRING) + .valueColumn(ColumnName.of("k0"), SqlTypes.STRING) + .valueColumn(ColumnName.of("field2"), SqlTypes.STRING) + .build() ); // Expect: @@ -244,9 +248,9 @@ private void buildNode() { SourceName.of(PLAN_NODE_ID.toString())); } - private void givenSourceSelectExpressions(final SelectExpression... 
selectExpressions) { - when(sourceNode.getSelectExpressions()) - .thenReturn(ImmutableList.copyOf(selectExpressions)); + private void givenSourceSchema(final LogicalSchema schema) { + when(sourceNode.getSchema()) + .thenReturn(schema); } private static SelectExpression selectExpression(final String alias) { diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/testutils/AnalysisTestUtil.java b/ksqldb-engine/src/test/java/io/confluent/ksql/testutils/AnalysisTestUtil.java index 7094b93d1053..67b1e4b7dd96 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/testutils/AnalysisTestUtil.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/testutils/AnalysisTestUtil.java @@ -18,7 +18,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.hasSize; -import io.confluent.ksql.analyzer.AggregateAnalysisResult; import io.confluent.ksql.analyzer.Analysis; import io.confluent.ksql.analyzer.QueryAnalyzer; import io.confluent.ksql.metastore.MetaStore; @@ -40,10 +39,6 @@ public final class AnalysisTestUtil { private AnalysisTestUtil() { } - public static Analysis analyzeQuery(final String queryStr, final MetaStore metaStore) { - return new Analyzer(queryStr, metaStore).analysis; - } - public static OutputNode buildLogicalPlan( final KsqlConfig ksqlConfig, final String queryStr, @@ -51,26 +46,22 @@ public static OutputNode buildLogicalPlan( ) { final Analyzer analyzer = new Analyzer(queryStr, metaStore); - final LogicalPlanner logicalPlanner = new LogicalPlanner( - ksqlConfig, - analyzer.analysis, - analyzer.aggregateAnalysis(), - metaStore); + final LogicalPlanner logicalPlanner = + new LogicalPlanner(ksqlConfig, analyzer.analysis, metaStore); return logicalPlanner.buildPlan(); } private static class Analyzer { - private final Query query; + private final Analysis analysis; - private final QueryAnalyzer queryAnalyzer; private Analyzer(final String queryStr, final MetaStore metaStore) { - this.queryAnalyzer = new QueryAnalyzer(metaStore, "", SerdeOption.none()); + final QueryAnalyzer queryAnalyzer = new QueryAnalyzer(metaStore, "", SerdeOption.none()); final Statement statement = parseStatement(queryStr, metaStore); - this.query = statement instanceof QueryContainer - ? ((QueryContainer)statement).getQuery() - : (Query) statement; + final Query query = statement instanceof QueryContainer + ? ((QueryContainer) statement).getQuery() + : (Query) statement; final Optional sink = statement instanceof QueryContainer ? 
Optional.of(((QueryContainer)statement).getSink())
@@ -88,9 +79,5 @@ private static Statement parseStatement(
     assertThat(statements, hasSize(1));
     return statements.get(0).getStatement();
   }
-
-  AggregateAnalysisResult aggregateAnalysis() {
-    return queryAnalyzer.analyzeAggregate(query, analysis);
-  }
 }
}
diff --git a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ComparisonUtil.java b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ComparisonUtil.java
index a1ecc4e3fe2b..1d66d6ce87ed 100644
--- a/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ComparisonUtil.java
+++ b/ksqldb-execution/src/main/java/io/confluent/ksql/execution/util/ComparisonUtil.java
@@ -15,38 +15,104 @@
 package io.confluent.ksql.execution.util;
 
+import static java.util.Objects.requireNonNull;
+
+import com.google.common.collect.ImmutableList;
 import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
+import io.confluent.ksql.execution.expression.tree.ComparisonExpression.Type;
 import io.confluent.ksql.schema.ksql.SqlBaseType;
 import io.confluent.ksql.schema.ksql.types.SqlType;
 import io.confluent.ksql.util.KsqlException;
+import java.util.List;
+import java.util.function.BiPredicate;
+import java.util.function.Predicate;
 
 final class ComparisonUtil {
 
-  private ComparisonUtil() {
+  private static final List<Handler> HANDLERS = ImmutableList.<Handler>builder()
+      .add(handler(SqlBaseType::isNumber, ComparisonUtil::handleNumber))
+      .add(handler(SqlBaseType.STRING, ComparisonUtil::handleString))
+      .add(handler(SqlBaseType.BOOLEAN, ComparisonUtil::handleBoolean))
+      .build();
+
+  private ComparisonUtil() {
   }
 
-  static boolean isValidComparison(
+  static void isValidComparison(
       final SqlType left, final ComparisonExpression.Type operator, final SqlType right
   ) {
-    if (left.baseType().isNumber() && right.baseType().isNumber()) {
-      return true;
+    if (left == null || right == null) {
+      throw nullSchemaException(left, operator, right);
     }
 
-    if (left.baseType() == SqlBaseType.STRING && right.baseType() == SqlBaseType.STRING) {
-      return true;
-    }
+    final boolean valid = HANDLERS.stream()
+        .filter(h -> h.handles.test(left.baseType()))
+        .findFirst()
+        .map(h -> h.validator.test(operator, right))
+        .orElse(false);
 
-    if (left.baseType() == SqlBaseType.BOOLEAN && right.baseType() == SqlBaseType.BOOLEAN) {
-      if (operator == ComparisonExpression.Type.EQUAL
-          || operator == ComparisonExpression.Type.NOT_EQUAL) {
-        return true;
-      }
+    if (!valid) {
+      throw new KsqlException(
+          "Operator " + operator + " cannot be used to compare "
+              + left.baseType() + " and " + right.baseType()
+      );
     }
+  }
+
+  private static KsqlException nullSchemaException(
+      final SqlType left,
+      final Type operator,
+      final SqlType right
+  ) {
+    final String leftType = left == null ? "NULL" : left.baseType().name();
+    final String rightType = right == null ? "NULL" : right.baseType().name();
 
-    throw new KsqlException(
-        "Operator " + operator + " cannot be used to compare "
-            + left.baseType() + " and " + right.baseType()
+    return new KsqlException(
+        "Comparison with NULL not supported: "
+            + leftType + " " + operator.getValue() + " " + rightType
+            + System.lineSeparator()
+            + "Use 'IS NULL' or 'IS NOT NULL' instead."
    );
  }
+
+  private static boolean handleNumber(final Type operator, final SqlType right) {
+    return right.baseType().isNumber();
+  }
+
+  private static boolean handleString(final Type operator, final SqlType right) {
+    return right.baseType() == SqlBaseType.STRING;
+  }
+
+  private static boolean handleBoolean(final Type operator, final SqlType right) {
+    return right.baseType() == SqlBaseType.BOOLEAN
+        && (operator == Type.EQUAL || operator == Type.NOT_EQUAL);
+  }
+
+  private static Handler handler(
+      final SqlBaseType baseType,
+      final BiPredicate<Type, SqlType> validator
+  ) {
+    return handler(t -> t == baseType, validator);
+  }
+
+  private static Handler handler(
+      final Predicate<SqlBaseType> handles,
+      final BiPredicate<Type, SqlType> validator
+  ) {
+    return new Handler(handles, validator);
+  }
+
+  private static final class Handler {
+
+    final Predicate<SqlBaseType> handles;
+    final BiPredicate<Type, SqlType> validator;
+
+    private Handler(
+        final Predicate<SqlBaseType> handles,
+        final BiPredicate<Type, SqlType> validator
+    ) {
+      this.handles = requireNonNull(handles, "handles");
+      this.validator = requireNonNull(validator, "validator");
    }
  }
}
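A minimal usage sketch of the reworked check may help review (a hypothetical snippet, not part of the patch; it assumes the SqlTypes factory class used by the tests below and package-private access to ComparisonUtil):

    // isValidComparison now returns void and throws KsqlException on an invalid pairing:
    ComparisonUtil.isValidComparison(
        SqlTypes.INTEGER, ComparisonExpression.Type.GREATER_THAN, SqlTypes.DOUBLE); // ok: numbers are inter-comparable
    ComparisonUtil.isValidComparison(
        SqlTypes.BOOLEAN, ComparisonExpression.Type.EQUAL, SqlTypes.BOOLEAN);       // ok: booleans allow = and <> only

    try {
      ComparisonUtil.isValidComparison(
          SqlTypes.BOOLEAN, ComparisonExpression.Type.LESS_THAN, SqlTypes.BOOLEAN);
    } catch (final KsqlException e) {
      // "Operator LESS_THAN cannot be used to compare BOOLEAN and BOOLEAN"
    }

    try {
      ComparisonUtil.isValidComparison(
          null, ComparisonExpression.Type.EQUAL, SqlTypes.STRING);
    } catch (final KsqlException e) {
      // "Comparison with NULL not supported: NULL = STRING", plus the IS [NOT] NULL hint
    }

Each base type's rule now lives in a single Handler entry, so supporting a new comparable type means adding one handler rather than extending the old if-chain.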
diff --git a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ComparisonUtilTest.java b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ComparisonUtilTest.java
index 6b774fa101a1..eb1af7f41e72 100644
--- a/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ComparisonUtilTest.java
+++ b/ksqldb-execution/src/test/java/io/confluent/ksql/execution/util/ComparisonUtilTest.java
@@ -15,7 +15,6 @@
 package io.confluent.ksql.execution.util;
 
-import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
 
@@ -73,9 +72,7 @@ public void shouldAssertTrueForValidComparisons() {
     for (final SqlType leftType: typesTable) {
       for (final SqlType rightType: typesTable) {
         if (expectedResults.get(i).get(j)) {
-          assertThat(
-              ComparisonUtil.isValidComparison(leftType, ComparisonExpression.Type.EQUAL, rightType)
-              , equalTo(true));
+          ComparisonUtil.isValidComparison(leftType, ComparisonExpression.Type.EQUAL, rightType);
         }
 
         j++;
@@ -111,4 +108,26 @@ public void shouldThrowForInvalidComparisons() {
       j = 0;
     }
   }
+
+  @SuppressWarnings("ConstantConditions")
+  @Test
+  public void shouldNotCompareLeftNullSchema() {
+    // Expect:
+    expectedException.expect(KsqlException.class);
+    expectedException.expectMessage("Comparison with NULL not supported: NULL = STRING");
+
+    // When:
+    ComparisonUtil.isValidComparison(null, ComparisonExpression.Type.EQUAL, SqlTypes.STRING);
+  }
+
+  @SuppressWarnings("ConstantConditions")
+  @Test
+  public void shouldNotCompareRightNullSchema() {
+    // Expect:
+    expectedException.expect(KsqlException.class);
+    expectedException.expectMessage("Comparison with NULL not supported: STRING = NULL");
+
+    // When:
+    ComparisonUtil.isValidComparison(SqlTypes.STRING, ComparisonExpression.Type.EQUAL, null);
+  }
 }
\ No newline at end of file
diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/array.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/array.json
index 8131916ecf3f..4cd7ab91f8f0 100644
--- a/ksqldb-functional-tests/src/test/resources/query-validation-tests/array.json
+++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/array.json
@@ -106,6 +106,19 @@
     {"topic": "OUTPUT", "value": {"ARRAY_LEN": 1, "MAP_LEN": 2, "STRUCT_LEN": 3}},
     {"topic": "OUTPUT", "value": {"ARRAY_LEN": 0, "MAP_LEN": 0, "STRUCT_LEN": 0}}
   ]
+  },
+  {
+    "name": "multi-dimensional",
"statements": [ + "CREATE STREAM INPUT (col0 ARRAY>) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT col0[1][2] FROM INPUT;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"col0": [[0, 1],[2]]}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"KSQL_COL_0": 1}} + ] } ] } \ No newline at end of file diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/binary-arithmetic.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/binary-arithmetic.json new file mode 100644 index 000000000000..96f0ed1ba42f --- /dev/null +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/binary-arithmetic.json @@ -0,0 +1,23 @@ +{ + "comments": [ + "Test cases covering Binary Arithmetic" + ], + "tests": [ + { + "name": "in projection", + "statements": [ + "CREATE STREAM INPUT (col0 INT KEY, col1 BIGINT, col2 DOUBLE) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT col0+col1, col2+10, col1*25, 12*4+2 FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 6, "value": {"col1": 25, "col2": 3.21}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 6, "value": {"KSQL_COL_0": 31, "KSQL_COL_1": 13.21, "KSQL_COL_2": 625, "KSQL_COL_3": 50}} + ] + } + ] +} \ No newline at end of file diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/binary-comparison.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/binary-comparison.json new file mode 100644 index 000000000000..ed08a48ea07c --- /dev/null +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/binary-comparison.json @@ -0,0 +1,257 @@ +{ + "comments": [ + "Tests covering SQL binary comparisons" + ], + "tests": [ + { + "name": "equals", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B BOOLEAN, C BIGINT, D DOUBLE, E DECIMAL(4,3), F STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT A = 1, B = true, C = 11, D = 1.1, E = 1.20, F = 'foo' FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 1, "value": {"B": true, "C": 11, "D": 1.1, "E": 1.20, "F": "foo"}}, + {"topic": "test_topic", "key": 2, "value": {"B": false, "C": 10, "D": 1.0, "E": 1.21, "F": "Foo"}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 1, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true, "KSQL_COL_5": true}}, + {"topic": "OUTPUT", "key": 2, "value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false, "KSQL_COL_5": false}} + ] + }, + { + "name": "not equals", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B BOOLEAN, C BIGINT, D DOUBLE, E DECIMAL(4,3), F STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT A <> 1, B <> true, C <> 11, D <> 1.1, E <> 1.20, F <> 'foo' FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 1, "value": {"B": true, "C": 11, "D": 1.1, "E": 1.20, "F": "foo"}}, + {"topic": "test_topic", "key": 2, "value": {"B": false, "C": 10, "D": 1.0, "E": 1.21, "F": "Foo"}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 1, "value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false, "KSQL_COL_5": false}}, + 
{"topic": "OUTPUT", "key": 2, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true, "KSQL_COL_5": true}} + ] + }, + { + "name": "less than", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B BIGINT, C DOUBLE, D DECIMAL(4,3), E STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT A < 1, B < 11, C < 1.1, D < 1.20, E < 'foo' FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 0, "value": {"B": 10, "C": 1.0, "D": 1.19, "E": "Foo"}}, + {"topic": "test_topic", "key": 1, "value": {"B": 11, "C": 1.1, "D": 1.20, "E": "foo"}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 0, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}}, + {"topic": "OUTPUT", "key": 1, "value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false}} + ] + }, + { + "name": "less than or equal", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B BIGINT, C DOUBLE, D DECIMAL(4,3), E STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT A <= 1, B <= 11, C <= 1.1, D <= 1.20, E <= 'foo' FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 0, "value": {"B": 10, "C": 1.0, "D": 1.19, "E": "Foo"}}, + {"topic": "test_topic", "key": 1, "value": {"B": 11, "C": 1.1, "D": 1.20, "E": "foo"}}, + {"topic": "test_topic", "key": 2, "value": {"B": 12, "C": 1.11, "D": 1.21, "E": "goo"}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 0, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}}, + {"topic": "OUTPUT", "key": 1, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}}, + {"topic": "OUTPUT", "key": 2, "value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false}} + ] + }, + { + "name": "greater than", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B BIGINT, C DOUBLE, D DECIMAL(4,3), E STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT A > 1, B > 11, C > 1.1, D > 1.20, E > 'foo' FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 1, "value": {"B": 11, "C": 1.1, "D": 1.20, "E": "foo"}}, + {"topic": "test_topic", "key": 2, "value": {"B": 12, "C": 1.11, "D": 1.21, "E": "goo"}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 1, "value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false}}, + {"topic": "OUTPUT", "key": 2, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}} + ] + }, + { + "name": "greater than or equal", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B BIGINT, C DOUBLE, D DECIMAL(4,3), E STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT A >= 1, B >= 11, C >= 1.1, D >= 1.20, E >= 'foo' FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 0, "value": {"B": 10, "C": 1.0, "D": 1.19, "E": "Foo"}}, + {"topic": "test_topic", "key": 1, "value": {"B": 11, "C": 1.1, "D": 1.20, "E": "foo"}}, + {"topic": "test_topic", 
"key": 2, "value": {"B": 12, "C": 1.11, "D": 1.21, "E": "goo"}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 0, "value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false}}, + {"topic": "OUTPUT", "key": 1, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}}, + {"topic": "OUTPUT", "key": 2, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}} + ] + }, + { + "name": "between", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B BIGINT, C DOUBLE, D DECIMAL(4,3), E STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT A BETWEEN 0 AND 2, B BETWEEN 10 AND 12, C BETWEEN 1.0 AND 1.11, D BETWEEN 1.19 AND 1.21, E BETWEEN 'eoo' AND 'goo' FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": -1, "value": {"B": 9, "C": 0.99, "D": 1.18, "E": "doo"}}, + {"topic": "test_topic", "key": 0, "value": {"B": 10, "C": 1.0, "D": 1.19, "E": "eoo"}}, + {"topic": "test_topic", "key": 1, "value": {"B": 11, "C": 1.1, "D": 1.20, "E": "foo"}}, + {"topic": "test_topic", "key": 2, "value": {"B": 12, "C": 1.11, "D": 1.21, "E": "goo"}}, + {"topic": "test_topic", "key": 3, "value": {"B": 13, "C": 1.12, "D": 1.22, "E": "hoo"}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": -1, "value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false}}, + {"topic": "OUTPUT", "key": 0, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}}, + {"topic": "OUTPUT", "key": 1, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}}, + {"topic": "OUTPUT", "key": 2, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}}, + {"topic": "OUTPUT", "key": 3, "value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false}} + ] + }, + { + "name": "not between", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B BIGINT, C DOUBLE, D DECIMAL(4,3), E STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT A NOT BETWEEN 0 AND 2, B NOT BETWEEN 10 AND 12, C NOT BETWEEN 1.0 AND 1.11, D NOT BETWEEN 1.19 AND 1.21, E NOT BETWEEN 'eoo' AND 'goo' FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": -1, "value": {"B": 9, "C": 0.99, "D": 1.18, "E": "doo"}}, + {"topic": "test_topic", "key": 0, "value": {"B": 10, "C": 1.0, "D": 1.19, "E": "eoo"}}, + {"topic": "test_topic", "key": 1, "value": {"B": 11, "C": 1.1, "D": 1.20, "E": "foo"}}, + {"topic": "test_topic", "key": 2, "value": {"B": 12, "C": 1.11, "D": 1.21, "E": "goo"}}, + {"topic": "test_topic", "key": 3, "value": {"B": 13, "C": 1.12, "D": 1.22, "E": "hoo"}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": -1, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}}, + {"topic": "OUTPUT", "key": 0, "value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false}}, + {"topic": "OUTPUT", "key": 1, "value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false}}, + {"topic": "OUTPUT", "key": 2, 
"value": {"KSQL_COL_0": false, "KSQL_COL_1": false, "KSQL_COL_2": false, "KSQL_COL_3": false, "KSQL_COL_4": false}}, + {"topic": "OUTPUT", "key": 3, "value": {"KSQL_COL_0": true, "KSQL_COL_1": true, "KSQL_COL_2": true, "KSQL_COL_3": true, "KSQL_COL_4": true}} + ] + }, + { + "name": "is distinct from", + "statements": [ + "CREATE STREAM INPUT (ID INT KEY, ID2 INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ID IS DISTINCT FROM ID2 FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 1, "value": {"ID2": 1}}, + {"topic": "test_topic", "key": 2, "value": {"ID2": 1}}, + {"topic": "test_topic", "key": 3, "value": {"ID2": null}}, + {"topic": "test_topic", "key": null, "value": {"ID2": 1}}, + {"topic": "test_topic", "key": null, "value": {"ID2": null}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 1, "value": {"KSQL_COL_0": false}}, + {"topic": "OUTPUT", "key": 2, "value": {"KSQL_COL_0": true}}, + {"topic": "OUTPUT", "key": 3, "value": {"KSQL_COL_0": true}}, + {"topic": "OUTPUT", "key": null, "value": {"KSQL_COL_0": true}}, + {"topic": "OUTPUT", "key": null, "value": {"KSQL_COL_0": false}} + ] + }, + { + "name": "is not distinct from", + "statements": [ + "CREATE STREAM INPUT (ID INT KEY, ID2 INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ID IS NOT DISTINCT FROM ID2 FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 1, "value": {"ID2": 1}}, + {"topic": "test_topic", "key": 2, "value": {"ID2": 1}}, + {"topic": "test_topic", "key": 3, "value": {"ID2": null}}, + {"topic": "test_topic", "key": null, "value": {"ID2": 1}}, + {"topic": "test_topic", "key": null, "value": {"ID2": null}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 1, "value": {"KSQL_COL_0": true}}, + {"topic": "OUTPUT", "key": 2, "value": {"KSQL_COL_0": false}}, + {"topic": "OUTPUT", "key": 3, "value": {"KSQL_COL_0": false}}, + {"topic": "OUTPUT", "key": null, "value": {"KSQL_COL_0": false}}, + {"topic": "OUTPUT", "key": null, "value": {"KSQL_COL_0": true}} + ] + }, + { + "name": "comparison array fails", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B ARRAY, C ARRAY) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT B = C FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "expectedException": { + "type": "io.confluent.ksql.util.KsqlStatementException", + "message": "Operator EQUAL cannot be used to compare ARRAY and ARRAY" + } + }, + { + "name": "comparison map fails", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B MAP, C MAP) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT B = C FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "expectedException": { + "type": "io.confluent.ksql.util.KsqlStatementException", + "message": "Operator EQUAL cannot be used to compare MAP and MAP" + } + }, + { + "name": "comparison struct fails", + "statements": [ + "CREATE STREAM INPUT (A INT KEY, B STRUCT, C STRUCT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT B = C FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "expectedException": { + "type": "io.confluent.ksql.util.KsqlStatementException", + "message": "Operator EQUAL cannot be used to compare STRUCT and STRUCT" + } + } + ] +} \ No 
diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/case-expression.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/case-expression.json
index 56edfbb0880b..ceb7e3805b38 100644
--- a/ksqldb-functional-tests/src/test/resources/query-validation-tests/case-expression.json
+++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/case-expression.json
@@ -1,11 +1,6 @@
 {
   "comments": [
-    "You can specify multiple statements per test case, i.e., to set up the various streams needed",
-    "for joins etc, but currently only the final topology will be verified. This should be enough",
-    "for most tests as we can simulate the outputs from previous stages into the final stage. If we",
-    "take a modular approach to testing we can still verify that it all works correctly, i.e, if we",
-    "verify the output of a select or aggregate is correct, we can use simulated output to feed into",
-    "a join or another aggregate."
+    "Test cases covering SQL CASE statements"
   ],
   "tests": [
     {
@@ -33,11 +28,13 @@
       ],
       "inputs": [
         {"topic": "test_topic", "value": {"ORDERID": 4, "ORDERUNITS": 1.9}},
-        {"topic": "test_topic", "value": {"ORDERID": 5, "ORDERUNITS": 1.0}}
+        {"topic": "test_topic", "value": {"ORDERID": 5, "ORDERUNITS": 1.0}},
+        {"topic": "test_topic", "value": {"ORDERID": 5, "ORDERUNITS": 2.0}}
       ],
       "outputs": [
         {"topic": "S1", "value": {"CASE_RESAULT": 6}},
-        {"topic": "S1", "value": {"CASE_RESAULT": 7}}
+        {"topic": "S1", "value": {"CASE_RESAULT": 7}},
+        {"topic": "S1", "value": {"CASE_RESAULT": null}}
       ]
     },
     {
@@ -138,6 +135,40 @@
         {"topic": "S1", "value": {"CASE_RESAULT": "default"}},
         {"topic": "S1", "value": {"CASE_RESAULT": "CITY_6"}}
       ]
+    },
+    {
+      "name": "should execute branches lazily",
+      "comment": [
+        "The test UDFs 'WHENCONDITION' and 'WHENRESULT' return their first arg and fail if their second arg is false.",
+        "Hence this test case would throw if execution was not lazy"
+      ],
+      "statements": [
+        "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');",
+        "CREATE STREAM OUTPUT AS SELECT CASE WHEN WHENCONDITION(true, true) THEN WHENRESULT(100, true) WHEN WHENCONDITION(true, false) THEN WHENRESULT(200, false) ELSE WHENRESULT(300, false) END FROM input;"
+      ],
+      "inputs": [
+        {"topic": "test_topic", "value": {}}
+      ],
+      "outputs": [
+        {"topic": "OUTPUT", "value": {"KSQL_COL_0": 100}}
+      ]
+    },
+    {
+      "name": "should only execute ELSE if no WHENs match",
+      "comment": [
+        "The test UDFs 'WHENCONDITION' and 'WHENRESULT' return their first arg and fail if their second arg is false.",
+        "Hence this test case would throw if execution was not lazy"
+      ],
+      "statements": [
+        "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');",
+        "CREATE STREAM OUTPUT AS SELECT CASE WHEN WHENCONDITION(false, true) THEN WHENRESULT(100, false) WHEN WHENCONDITION(false, true) THEN WHENRESULT(200, false) ELSE WHENRESULT(300, true) END FROM input;"
+      ],
+      "inputs": [
+        {"topic": "test_topic", "value": {}}
+      ],
+      "outputs": [
+        {"topic": "OUTPUT", "value": {"KSQL_COL_0": 300}}
+      ]
+    }
   ]
 }
\ No newline at end of file
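The laziness contract these two cases pin down can be stated in a few lines of standalone Java (a hypothetical sketch, not the engine's generated code): conditions and results are modelled as suppliers, so only the matching branch's result is ever computed.

    import java.util.function.BooleanSupplier;
    import java.util.function.IntSupplier;

    // Hypothetical sketch of lazy CASE evaluation: each WHEN is tested in
    // order and only the matching THEN (or the ELSE) is ever invoked.
    public final class LazyCaseSketch {

      static int evalCase(
          final BooleanSupplier[] whens,
          final IntSupplier[] thens,
          final IntSupplier elseResult
      ) {
        for (int i = 0; i < whens.length; i++) {
          if (whens[i].getAsBoolean()) {
            return thens[i].getAsInt(); // only the matching THEN is evaluated
          }
        }
        return elseResult.getAsInt();   // ELSE runs only when no WHEN matched
      }

      public static void main(final String[] args) {
        final int result = evalCase(
            new BooleanSupplier[]{() -> true, () -> { throw new AssertionError("not lazy!"); }},
            new IntSupplier[]{() -> 100, () -> { throw new AssertionError("not lazy!"); }},
            () -> { throw new AssertionError("not lazy!"); });
        System.out.println(result); // prints 100; no other supplier was invoked
      }
    }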
diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/cast.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/cast.json
index 89291d61de62..ba29c8e451a9 100644
--- a/ksqldb-functional-tests/src/test/resources/query-validation-tests/cast.json
+++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/cast.json
@@ -162,6 +162,45 @@
     {"topic": "OUT", "value": {"I": 10, "L": 10, "D": 10.00, "S": "10.00"}},
     {"topic": "OUT", "value": {"I": 10, "L": 10, "D": 10.01, "S": "10.01"}}
   ]
+    },
+    {
+      "name": "integer to bigint",
+      "statements": [
+        "CREATE STREAM INPUT (col0 INT, col1 INT) WITH (kafka_topic='test_topic', value_format='AVRO');",
+        "CREATE STREAM OUT AS SELECT cast((col0 - col1) AS BIGINT) as VAL FROM INPUT;"
+      ],
+      "inputs": [
+        {"topic": "test_topic", "value": {"col0": 1, "col1": 2}}
+      ],
+      "outputs": [
+        {"topic": "OUT", "value": {"VAL": -1}}
+      ]
+    },
+    {
+      "name": "integer to string",
+      "statements": [
+        "CREATE STREAM INPUT (col0 INT, col1 INT) WITH (kafka_topic='test_topic', value_format='AVRO');",
+        "CREATE STREAM OUT AS SELECT cast((col0 - col1) AS STRING) as VAL FROM INPUT;"
+      ],
+      "inputs": [
+        {"topic": "test_topic", "value": {"col0": 1, "col1": 2}}
+      ],
+      "outputs": [
+        {"topic": "OUT", "value": {"VAL": "-1"}}
+      ]
+    },
+    {
+      "name": "double to int",
+      "statements": [
+        "CREATE STREAM INPUT (col0 DOUBLE, col1 DOUBLE) WITH (kafka_topic='test_topic', value_format='AVRO');",
+        "CREATE STREAM OUT AS SELECT cast((col0 - col1) AS INT) as VAL FROM INPUT;"
+      ],
+      "inputs": [
+        {"topic": "test_topic", "value": {"col0": 3.3, "col1": 2.1}}
+      ],
+      "outputs": [
+        {"topic": "OUT", "value": {"VAL": 1}}
+      ]
+    }
   }
 ]
}
\ No newline at end of file
diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/concat.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/concat.json
index cfcdafe491e4..2e54b0824f68 100644
--- a/ksqldb-functional-tests/src/test/resources/query-validation-tests/concat.json
+++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/concat.json
@@ -5,7 +5,6 @@
   "tests": [
     {
       "name": "concat fields using CONCAT",
-      "format": ["JSON", "PROTOBUF"],
       "statements": [
         "CREATE STREAM TEST (source VARCHAR) WITH (kafka_topic='test_topic', value_format='JSON');",
         "CREATE STREAM OUTPUT AS SELECT CONCAT('prefix-', CONCAT(source, '-postfix')) AS THING FROM TEST;"
@@ -21,7 +20,6 @@
     },
     {
       "name": "concat fields using '+' operator",
-      "format": ["JSON", "PROTOBUF"],
       "statements": [
         "CREATE STREAM TEST (source VARCHAR) WITH (kafka_topic='test_topic', value_format='JSON');",
         "CREATE STREAM OUTPUT AS SELECT 'prefix-' + source + '-postfix' AS THING FROM TEST;"
@@ -34,6 +32,34 @@
       ],
       "outputs": [
         {"topic": "OUTPUT", "value": {"THING":"prefix-s1-postfix"}},
         {"topic": "OUTPUT", "value": {"THING":"prefix-s2-postfix"}}
       ]
+    },
+    {
+      "name": "should handle characters that must be escaped in java",
+      "statements": [
+        "CREATE STREAM INPUT (source VARCHAR) WITH (kafka_topic='test_topic', value_format='JSON');",
+        "CREATE STREAM OUTPUT AS SELECT CONCAT('\"', CONCAT(source, '\\')) AS THING FROM INPUT;"
+      ],
+      "inputs": [
+        {"topic": "test_topic", "value": {"source": "foo"}},
+        {"topic": "test_topic", "value": {"source": "\\foo\""}}
+      ],
+      "outputs": [
+        {"topic": "OUTPUT", "value": {"THING":"\"foo\\"}},
+        {"topic": "OUTPUT", "value": {"THING":"\"\\foo\"\\"}}
+      ]
+    },
+    {
+      "name": "should handle characters that must be escaped in sql",
+      "statements": [
+        "CREATE STREAM INPUT (source VARCHAR) WITH (kafka_topic='test_topic', value_format='JSON');",
+        "CREATE STREAM OUTPUT AS SELECT CONCAT('''', CONCAT(source, '''')) AS THING FROM INPUT;"
+      ],
+      "inputs": [
+        {"topic": "test_topic", "value": {"source": "foo"}}
+      ],
+      "outputs": [
+        {"topic": "OUTPUT", "value": {"THING":"'foo'"}}
+      ]
+    }
   }
 ]
}
\ No newline at end of file
diff --git
a/ksqldb-functional-tests/src/test/resources/query-validation-tests/group-by.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/group-by.json index e4fe5ea5e304..6e786506c144 100644 --- a/ksqldb-functional-tests/src/test/resources/query-validation-tests/group-by.json +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/group-by.json @@ -523,7 +523,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Field used in aggregate SELECT expression(s) outside of aggregate functions not part of GROUP BY: [TEST.F1]" + "message": "Column used in aggregate SELECT expression(s) outside of aggregate functions not part of GROUP BY: [F1]" } }, { @@ -616,7 +616,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [SUBSTRING(TEST.SOURCE, 0, 1)]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [SUBSTRING(SOURCE, 0, 1)]" } }, { @@ -627,7 +627,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [SUBSTRING(TEST.REGION, 7, 1)]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [SUBSTRING(REGION, 7, 1)]" } }, { @@ -638,7 +638,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [SUBSTRING(TEST.SOURCE, 0, 3)]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [SUBSTRING(SOURCE, 0, 3)]" } }, { @@ -649,7 +649,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [SUBSTRING(TEST.REGION, 7, 3)]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [SUBSTRING(REGION, 7, 3)]" } }, { @@ -899,7 +899,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [(TEST.F1 + TEST.F2)]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [(F1 + F2)]" } }, { @@ -910,7 +910,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [(TEST.SUBREGION + TEST.REGION)]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [(SUBREGION + REGION)]" } }, { @@ -921,7 +921,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [TEST.F1, TEST.F2]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [F1, F2]" } }, { @@ -932,7 +932,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [TEST.REGION, TEST.SUBREGION]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [REGION, SUBREGION]" } }, { @@ -984,7 +984,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [(TEST.F1 - TEST.F2)]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [(F1 - F2)]" } }, { @@ -995,7 +995,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - 
"message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [(TEST.F1 - TEST.F0)]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [(F1 - F0)]" } }, { @@ -1192,6 +1192,26 @@ {"topic": "OUTPUT", "key": "d1", "value": {"DATA": "d1", "KSQL_COL_1": 2, "KSQL_COL_2": 2, "COPY": "d1"}, "timestamp": 3} ] }, + { + "name": "duplicate udafs (stream->table)", + "statements": [ + "CREATE STREAM TEST (data VARCHAR) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE OUTPUT AS SELECT COUNT(1), COUNT(1) FROM TEST GROUP BY data;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"DATA": "d1"}, "timestamp": 1}, + {"topic": "test_topic", "value": {"DATA": "d2"}, "timestamp": 2}, + {"topic": "test_topic", "value": {"DATA": "d1"}, "timestamp": 3} + ], + "outputs": [ + {"topic": "_confluent-ksql-some.ksql.service.idquery_CTAS_OUTPUT_0-Aggregate-Aggregate-Materialize-changelog", "key": "d1", "value": {"KSQL_INTERNAL_COL_0": "d1", "KSQL_AGG_VARIABLE_0": 1, "KSQL_AGG_VARIABLE_1": 1}, "timestamp": 1}, + {"topic": "_confluent-ksql-some.ksql.service.idquery_CTAS_OUTPUT_0-Aggregate-Aggregate-Materialize-changelog", "key": "d2", "value": {"KSQL_INTERNAL_COL_0": "d2", "KSQL_AGG_VARIABLE_0": 1, "KSQL_AGG_VARIABLE_1": 1}, "timestamp": 2}, + {"topic": "_confluent-ksql-some.ksql.service.idquery_CTAS_OUTPUT_0-Aggregate-Aggregate-Materialize-changelog", "key": "d1", "value": {"KSQL_INTERNAL_COL_0": "d1", "KSQL_AGG_VARIABLE_0": 2, "KSQL_AGG_VARIABLE_1": 2}, "timestamp": 3}, + {"topic": "OUTPUT", "key": "d1", "value": {"KSQL_COL_0": 1, "KSQL_COL_1": 1}, "timestamp": 1}, + {"topic": "OUTPUT", "key": "d2", "value": {"KSQL_COL_0": 1, "KSQL_COL_1": 1}, "timestamp": 2}, + {"topic": "OUTPUT", "key": "d1", "value": {"KSQL_COL_0": 2, "KSQL_COL_1": 2}, "timestamp": 3} + ] + }, { "name": "with non-aggregate projection field not in group by (stream->table)", "statements": [ @@ -1200,7 +1220,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [TEST.D1]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [D1]" } }, { @@ -1211,7 +1231,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [TEST.D1]" + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [D1]" } }, { @@ -1251,7 +1271,7 @@ "name": "with projection without aggregate functions (stream->table)", "statements": [ "CREATE STREAM TEST (d1 VARCHAR, d2 INT) WITH (kafka_topic='test_topic', value_format='DELIMITED');", - "CREATE TABLE OUTPUT AS SELECT SUBSTRING(d2, 1, 2) FROM TEST GROUP BY d2;" + "CREATE TABLE OUTPUT AS SELECT SUBSTRING(d1, 1, 2) FROM TEST GROUP BY d2;" ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", @@ -1266,7 +1286,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Use of aggregate functions requires a GROUP BY clause. 
Aggregate function(s): COUNT" + "message": "Use of aggregate function COUNT requires a GROUP BY clause" } }, { @@ -1502,6 +1522,58 @@ } ] } + }, + { + "name": "with select * where all columns in group by", + "statements": [ + "CREATE STREAM TEST (id INT, id2 INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE OUTPUT AS SELECT *, COUNT() FROM TEST GROUP BY id, id2;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"ID": 1, "ID2": 2}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": "1|+|2", "value": {"ID": 1, "ID2": 2, "KSQL_COL_0": 1}} + ] + }, + { + "name": "with select * where not all columns in group by", + "statements": [ + "CREATE STREAM TEST (id INT, id2 INT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE TABLE OUTPUT AS SELECT *, COUNT() FROM TEST GROUP BY id;" + ], + "expectedException": { + "type": "io.confluent.ksql.util.KsqlStatementException", + "message": "Non-aggregate SELECT expression(s) not part of GROUP BY: [ID2]" + } + }, + { + "name": "on join", + "statements": [ + "CREATE TABLE t1 (ROWKEY BIGINT KEY, TOTAL integer) WITH (kafka_topic='T1', value_format='AVRO');", + "CREATE TABLE t2 (ROWKEY BIGINT KEY, TOTAL integer) WITH (kafka_topic='T2', value_format='AVRO');", + "CREATE TABLE OUTPUT AS SELECT SUM(t1.total + CASE WHEN t2.total IS NULL THEN 0 ELSE t2.total END) as SUM FROM T1 LEFT JOIN T2 ON (t1.rowkey = t2.rowkey) GROUP BY t1.rowkey HAVING COUNT(1) > 0;" + ], + "inputs": [ + {"topic": "T1", "key": 0, "value": {"total": 100}}, + {"topic": "T1", "key": 1, "value": {"total": 101}}, + {"topic": "T2", "key": 0, "value": {"total": 5}}, + {"topic": "T2", "key": 1, "value": {"total": 10}}, + {"topic": "T2", "key": 0, "value": {"total": 20}}, + {"topic": "T2", "key": 0, "value": null} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 0,"value": {"SUM": 100}}, + {"topic": "OUTPUT", "key": 1,"value": {"SUM": 101}}, + {"topic": "OUTPUT", "key": 0,"value": null}, + {"topic": "OUTPUT", "key": 0,"value": {"SUM": 105}}, + {"topic": "OUTPUT", "key": 1,"value": null}, + {"topic": "OUTPUT", "key": 1,"value": {"SUM": 111}}, + {"topic": "OUTPUT", "key": 0,"value": null}, + {"topic": "OUTPUT", "key": 0,"value": {"SUM": 120}}, + {"topic": "OUTPUT", "key": 0,"value": null}, + {"topic": "OUTPUT", "key": 0,"value": {"SUM": 100}} + ] } ] } \ No newline at end of file diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/map.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/map.json new file mode 100644 index 000000000000..6eee293cc363 --- /dev/null +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/map.json @@ -0,0 +1,39 @@ +{ + "comments": [ + "Tests covering SQL MAP" + ], + "tests": [ + { + "name": "string map", + "statements": [ + "CREATE STREAM INPUT (ID STRING KEY, A_MAP MAP) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT A_MAP['expected'], A_MAP['missing'] FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "value": {"A_MAP": {"expected": 10}}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"KSQL_COL_0": 10, "KSQL_COL_1": null}} + ] + }, + { + "name": "map value as UDF param", + "statements": [ + "CREATE STREAM INPUT (ID STRING KEY, col11 MAP) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT EXTRACTJSONFIELD(col11['address'], '$.city') FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + 
"inputs": [ + {"topic": "test_topic", "value": {"col11": {"address": "{\"city\": \"London\"}"}}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"KSQL_COL_0": "London"}} + ] + } + ] +} \ No newline at end of file diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/null.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/null.json new file mode 100644 index 000000000000..7b498fedceba --- /dev/null +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/null.json @@ -0,0 +1,81 @@ +{ + "comments": [ + "Tests covering SQL NULL" + ], + "tests": [ + { + "name": "is null", + "statements": [ + "CREATE STREAM INPUT (ID INT KEY, NAME STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ID IS NULL AS ID_NULL, NAME IS NULL AS NAME_NULL FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 1, "value": {"NAME": "not null"}}, + {"topic": "test_topic", "key": null, "value": {"NAME": null}}, + {"topic": "test_topic", "key": 0, "value": {}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 1, "value": {"ID_NULL": false, "NAME_NULL": false}}, + {"topic": "OUTPUT", "key": null, "value": {"ID_NULL": true, "NAME_NULL": true}}, + {"topic": "OUTPUT", "key": 0, "value": {"ID_NULL": false, "NAME_NULL": true}} + ] + }, + { + "name": "is not null", + "statements": [ + "CREATE STREAM INPUT (ID INT KEY, NAME STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ID IS NOT NULL AS ID_NULL, NAME IS NOT NULL AS NAME_NULL FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": 1, "value": {"NAME": "not null"}}, + {"topic": "test_topic", "key": null, "value": {"NAME": null}}, + {"topic": "test_topic", "key": 0, "value": {}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": 1, "value": {"ID_NULL": true, "NAME_NULL": true}}, + {"topic": "OUTPUT", "key": null, "value": {"ID_NULL": false, "NAME_NULL": false}}, + {"topic": "OUTPUT", "key": 0, "value": {"ID_NULL": true, "NAME_NULL": false}} + ] + }, + { + "name": "null equals", + "statements": [ + "CREATE STREAM INPUT (ID INT KEY, COL0 BIGINT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ID = COL0, NULL IS NULL FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "inputs": [ + {"topic": "test_topic", "key": null, "value": {"COL0": 12344}}, + {"topic": "test_topic", "key": null, "value": {"COL0": null}}, + {"topic": "test_topic", "key": 0, "value": {}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": null, "value": {"KSQL_COL_0": false, "KSQL_COL_1": true}}, + {"topic": "OUTPUT", "key": null, "value": {"KSQL_COL_0": false, "KSQL_COL_1": true}}, + {"topic": "OUTPUT", "key": 0, "value": {"KSQL_COL_0": false, "KSQL_COL_1": true}} + ] + }, + { + "name": "comparison with null", + "statements": [ + "CREATE STREAM INPUT (ID INT KEY, COL0 BIGINT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT NULL <> NULL FROM INPUT;" + ], + "properties": { + "ksql.any.key.name.enabled": true + }, + "expectedException": { + "type": "io.confluent.ksql.util.KsqlStatementException", + "message": "Comparison with NULL not supported: NULL <> NULL" + } + } + ] +} \ No newline at end of file diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/string.json 
b/ksqldb-functional-tests/src/test/resources/query-validation-tests/string.json index 34a26bc09b6b..dc535a730633 100644 --- a/ksqldb-functional-tests/src/test/resources/query-validation-tests/string.json +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/string.json @@ -33,6 +33,38 @@ } ] } + }, + { + "name": "LCASE, UCASE, TRIM SUBSTRING", + "statements": [ + "CREATE STREAM INPUT (text STRING) WITH (kafka_topic='test_topic', value_format='DELIMITED');", + "CREATE STREAM OUTPUT AS select LCASE(text), UCASE(text), TRIM(text), SUBSTRING(text, 2, 5) from INPUT;" + ], + "inputs": [ + {"topic": "test_topic", "value": "lower"}, + {"topic": "test_topic", "value": "UPPER"}, + {"topic": "test_topic", "value": "MiXeD"}, + {"topic": "test_topic", "value": " \t with white space \t"}, + {"topic": "test_topic", "value": "s"}, + {"topic": "test_topic", "value": "long enough"} + ], + "outputs": [ + {"topic": "OUTPUT", "value": "lower,LOWER,lower,ower"}, + {"topic": "OUTPUT", "value": "upper,UPPER,UPPER,PPER"}, + {"topic": "OUTPUT", "value": "mixed,MIXED,MiXeD,iXeD"}, + {"topic": "OUTPUT", "value": "\" \t with white space \t\",\" \t WITH WHITE SPACE \t\",with white space,\"\t wit\""}, + {"topic": "OUTPUT", "value": "s,S,s,"}, + {"topic": "OUTPUT", "value": "long enough,LONG ENOUGH,long enough,ong e"} + ], + "post": { + "sources": [ + { + "name": "OUTPUT", + "type": "stream", + "schema": "ROWKEY STRING KEY, KSQL_COL_0 STRING, KSQL_COL_1 STRING, KSQL_COL_2 STRING, KSQL_COL_3 STRING" + } + ] + } } ] } \ No newline at end of file diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/table-functions.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/table-functions.json index 4f17a0e21700..73fb3e8f699f 100644 --- a/ksqldb-functional-tests/src/test/resources/query-validation-tests/table-functions.json +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/table-functions.json @@ -269,6 +269,57 @@ "type": "io.confluent.ksql.util.KsqlException", "message": "Table source is not supported with table functions" } + }, + { + "name": "with select *", + "statements": [ + "CREATE STREAM TEST (ROWKEY STRING KEY, ID BIGINT, MY_ARR ARRAY<BIGINT>) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT *, EXPLODE(MY_ARR) VAL FROM TEST;" + ], + "inputs": [ + {"topic": "test_topic", "key": "0", "value": {"ID": 0, "MY_ARR": [1, 2]}}, + {"topic": "test_topic", "key": "1", "value": {"ID": 1, "MY_ARR": [3, 4]}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": "0", "value": {"ID": 0, "MY_ARR": [1, 2], "VAL": 1}}, + {"topic": "OUTPUT", "key": "0", "value": {"ID": 0, "MY_ARR": [1, 2], "VAL": 2}}, + {"topic": "OUTPUT", "key": "1", "value": {"ID": 1, "MY_ARR": [3, 4], "VAL": 3}}, + {"topic": "OUTPUT", "key": "1", "value": {"ID": 1, "MY_ARR": [3, 4], "VAL": 4}} + ] + }, + { + "name": "with duplicate columns", + "statements": [ + "CREATE STREAM TEST (ID BIGINT, MY_ARR ARRAY<BIGINT>) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT ID, EXPLODE(MY_ARR) VAL, ID AS ID2 FROM TEST;" + ], + "inputs": [ + {"topic": "test_topic", "key": "0", "value": {"ID": 0, "MY_ARR": [1, 2]}}, + {"topic": "test_topic", "key": "1", "value": {"ID": 1, "MY_ARR": [3, 4]}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": "0", "value": {"ID": 0, "VAL": 1, "ID2": 0}}, + {"topic": "OUTPUT", "key": "0", "value": {"ID": 0, "VAL": 2, "ID2": 0}}, + {"topic": "OUTPUT", "key": "1", "value": {"ID": 1, "VAL": 3, "ID2": 1}}, +
{"topic": "OUTPUT", "key": "1", "value": {"ID": 1, "VAL": 4, "ID2": 1}} + ] + }, + { + "name": "with duplicate udtfs", + "statements": [ + "CREATE STREAM TEST (ID BIGINT, MY_ARR ARRAY) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT EXPLODE(MY_ARR), EXPLODE(MY_ARR) FROM TEST;" + ], + "inputs": [ + {"topic": "test_topic", "key": "0", "value": {"ID": 0, "MY_ARR": [1, 2]}}, + {"topic": "test_topic", "key": "1", "value": {"ID": 1, "MY_ARR": [3, 4]}} + ], + "outputs": [ + {"topic": "OUTPUT", "key": "0", "value": {"KSQL_COL_0": 1, "KSQL_COL_1": 1}}, + {"topic": "OUTPUT", "key": "0", "value": {"KSQL_COL_0": 2, "KSQL_COL_1": 2}}, + {"topic": "OUTPUT", "key": "1", "value": {"KSQL_COL_0": 3, "KSQL_COL_1": 3}}, + {"topic": "OUTPUT", "key": "1", "value": {"KSQL_COL_0": 4, "KSQL_COL_1": 4}} + ] } ] } \ No newline at end of file diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/udf.json b/ksqldb-functional-tests/src/test/resources/query-validation-tests/udf.json new file mode 100644 index 000000000000..9708ff5e3f99 --- /dev/null +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/udf.json @@ -0,0 +1,92 @@ +{ + "tests": [ + { + "name": "nested", + "statements": [ + "CREATE STREAM INPUT (text STRING) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT CONCAT(EXTRACTJSONFIELD(text,'$.name'),CONCAT('-',EXTRACTJSONFIELD(text,'$.value'))) from INPUT;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"text": "{\"name\":\"fred\",\"value\":1}"}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"KSQL_COL_0": "fred-1"}} + ], + "post": { + "sources": [ + { + "name": "OUTPUT", + "type": "stream", + "schema": "ROWKEY STRING KEY, KSQL_COL_0 STRING" + } + ] + } + }, + { + "name": "null args", + "statements": [ + "CREATE STREAM INPUT (ID BIGINT) WITH (kafka_topic='test_topic', value_format='DELIMITED');", + "CREATE STREAM OUTPUT AS SELECT test_udf(ID, NULL), test_udf(ID, ID, NULL) from INPUT;" + ], + "inputs": [ + {"topic": "test_topic", "value": "0"} + ], + "outputs": [ + {"topic": "OUTPUT", "value": "doStuffLongString,doStuffLongLongString"} + ], + "post": { + "sources": [ + { + "name": "OUTPUT", + "type": "stream", + "schema": "ROWKEY STRING KEY, KSQL_COL_0 STRING, KSQL_COL_1 STRING" + } + ] + } + }, + { + "name": "var args", + "statements": [ + "CREATE STREAM INPUT (col0 BIGINT) WITH (kafka_topic='test_topic', value_format='DELIMITED');", + "CREATE STREAM OUTPUT AS SELECT test_udf(col0, col0, col0, col0, col0) from INPUT;" + ], + "inputs": [ + {"topic": "test_topic", "value": "0"} + ], + "outputs": [ + {"topic": "OUTPUT", "value": "doStuffLongVarargs"} + ], + "post": { + "sources": [ + { + "name": "OUTPUT", + "type": "stream", + "schema": "`ROWKEY` STRING KEY, `KSQL_COL_0` STRING" + } + ] + } + }, + { + "name": "struct args", + "statements": [ + "CREATE STREAM INPUT (col0 STRUCT) WITH (kafka_topic='test_topic', value_format='JSON');", + "CREATE STREAM OUTPUT AS SELECT test_udf(col0) from INPUT;" + ], + "inputs": [ + {"topic": "test_topic", "value": {"col0": {"A": "expect-result"}}} + ], + "outputs": [ + {"topic": "OUTPUT", "value": {"KSQL_COL_0": "expect-result"}} + ], + "post": { + "sources": [ + { + "name": "OUTPUT", + "type": "stream", + "schema": "ROWKEY STRING KEY, KSQL_COL_0 STRING" + } + ] + } + } + ] +} \ No newline at end of file diff --git a/ksqldb-functional-tests/src/test/resources/query-validation-tests/window-bounds.json 
b/ksqldb-functional-tests/src/test/resources/query-validation-tests/window-bounds.json index 9b9102f02039..4a009d3825f5 100644 --- a/ksqldb-functional-tests/src/test/resources/query-validation-tests/window-bounds.json +++ b/ksqldb-functional-tests/src/test/resources/query-validation-tests/window-bounds.json @@ -126,7 +126,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Window bounds column TEST.WINDOWSTART can only be used in the SELECT clause of windowed aggregations" + "message": "Window bounds column WINDOWSTART can only be used in the SELECT clause of windowed aggregations" } }, { @@ -164,7 +164,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Window bounds column TEST.WINDOWEND can only be used in the SELECT clause of windowed aggregations" + "message": "Window bounds column WINDOWEND can only be used in the SELECT clause of windowed aggregations" } }, { @@ -176,7 +176,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Window bounds column TEST.WINDOWSTART can only be used in the SELECT clause of windowed aggregations" + "message": "Window bounds column WINDOWSTART can only be used in the SELECT clause of windowed aggregations" } }, { @@ -188,7 +188,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Window bounds column TEST.WINDOWSTART can only be used in the SELECT clause of windowed aggregations" + "message": "Window bounds column WINDOWSTART can only be used in the SELECT clause of windowed aggregations" } }, { @@ -199,7 +199,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Window bounds column TEST.WINDOWSTART can only be used in the SELECT clause of windowed aggregations" + "message": "Window bounds column WINDOWSTART can only be used in the SELECT clause of windowed aggregations" } }, { @@ -210,7 +210,7 @@ ], "expectedException": { "type": "io.confluent.ksql.util.KsqlStatementException", - "message": "Window bounds column TEST.WINDOWSTART can only be used in the SELECT clause of windowed aggregations" + "message": "Window bounds column WINDOWSTART can only be used in the SELECT clause of windowed aggregations" } } ] diff --git a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json index d6af70e88c5f..1c18a728e703 100644 --- a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json +++ b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/pull-queries-against-materialized-aggregates.json @@ -610,26 +610,6 @@ ]} ] }, - { - "name": "projection with ROWTIME and star", - "statements": [ - "CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');", - "CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT GROUP BY ROWKEY;", - "SELECT *, ROWTIME FROM AGGREGATE WHERE ROWKEY='10';" - ], - "inputs": [ - {"topic": "test_topic", "timestamp": 99, "key": "11", "value": {"val": 1}}, - {"topic": "test_topic", "timestamp": 100, "key": "10", "value": {"val": 2}} - ], - "responses": [ - {"admin": {"@type": "currentStatus"}}, - {"admin": {"@type": "currentStatus"}}, - {"query": [ - {"header":{"schema":"`ROWKEY` STRING KEY, `COUNT` 
BIGINT, `ROWTIME` BIGINT"}}, - {"row":{"columns":["10", 1, 100]}} - ]} - ] - }, { "name": "non-windowed projection with ROWKEY and more columns in aggregate", "statements": [ diff --git a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/push-queries.json b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/push-queries.json index 78197583384a..b0b54490933d 100644 --- a/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/push-queries.json +++ b/ksqldb-functional-tests/src/test/resources/rest-query-validation-tests/push-queries.json @@ -348,7 +348,7 @@ ], "expectedError": { "type": "io.confluent.ksql.rest.entity.KsqlErrorMessage", - "message": "Window bounds column TEST.WINDOWSTART can only be used in the SELECT clause of windowed aggregations", + "message": "Window bounds column WINDOWSTART can only be used in the SELECT clause of windowed aggregations", "status": 400 } }, @@ -360,7 +360,7 @@ ], "expectedError": { "type": "io.confluent.ksql.rest.entity.KsqlErrorMessage", - "message": "Window bounds column TEST.WINDOWSTART can only be used in the SELECT clause of windowed aggregations", + "message": "Window bounds column WINDOWSTART can only be used in the SELECT clause of windowed aggregations", "status": 400 } } diff --git a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/PullQueryExecutor.java b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/PullQueryExecutor.java index 88ff4a9f05bc..aedf7fa9a06d 100644 --- a/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/PullQueryExecutor.java +++ b/ksqldb-rest-app/src/main/java/io/confluent/ksql/rest/server/execution/PullQueryExecutor.java @@ -15,7 +15,6 @@ package io.confluent.ksql.rest.server.execution; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -67,7 +66,7 @@ import io.confluent.ksql.parser.tree.AllColumns; import io.confluent.ksql.parser.tree.Query; import io.confluent.ksql.parser.tree.Select; -import io.confluent.ksql.parser.tree.SelectItem; +import io.confluent.ksql.parser.tree.SingleColumn; import io.confluent.ksql.query.QueryId; import io.confluent.ksql.rest.Errors; import io.confluent.ksql.rest.SessionProperties; @@ -246,7 +245,7 @@ private TableRowsEntity handlePullQuery( "Unable to execute pull query: %s", statement.getStatementText())); } - private TableRowsEntity routeQuery( + private static TableRowsEntity routeQuery( final KsqlNode node, final ConfiguredStatement<Query> statement, final KsqlExecutionContext executionContext, @@ -267,8 +266,7 @@ private TableRowsEntity routeQuery( } } - @VisibleForTesting - static TableRowsEntity queryRowsLocally( + private static TableRowsEntity queryRowsLocally( final ConfiguredStatement<Query> statement, final KsqlExecutionContext executionContext, final PullQueryContext pullQueryContext @@ -297,14 +295,22 @@ static TableRowsEntity queryRowsLocally( result.schema, pullQueryContext.mat.windowType().isPresent()); rows = TableRowsEntityFactory.createRows(result.rows); } else { + final List<SelectExpression> projection = pullQueryContext.analysis.getSelectItems().stream() + .map(SingleColumn.class::cast) + .map(si -> SelectExpression + .of(si.getAlias().orElseThrow(IllegalStateException::new), si.getExpression())) + .collect(Collectors.toList()); + + outputSchema = selectOutputSchema( - result, executionContext, pullQueryContext.analysis,
pullQueryContext.mat.windowType()); + result, executionContext, projection, pullQueryContext.mat.windowType()); + rows = handleSelects( result, statement, executionContext, pullQueryContext.analysis, outputSchema, + projection, pullQueryContext.mat.windowType(), pullQueryContext.queryId, pullQueryContext.contextStacker @@ -757,8 +763,15 @@ private static ComparisonTarget extractWhereClauseTarget( } private static boolean isSelectStar(final Select select) { - final List<SelectItem> selects = select.getSelectItems(); - return selects.size() == 1 && selects.get(0) instanceof AllColumns; + final boolean someStars = select.getSelectItems().stream() + .anyMatch(s -> s instanceof AllColumns); + + if (someStars && select.getSelectItems().size() != 1) { + throw new KsqlException("Pull queries only support wildcards in the projects " + + "if they are the only expression"); + } + + return someStars; } private static List<List<?>> handleSelects( @@ -767,14 +780,15 @@ private static List<List<?>> handleSelects( final Result input, final ConfiguredStatement<Query> statement, final KsqlExecutionContext executionContext, final ImmutableAnalysis analysis, final LogicalSchema outputSchema, + final List<SelectExpression> projection, final Optional<WindowType> windowType, final QueryId queryId, final Stacker contextStacker ) { - final boolean noSystemColumns = analysis.getSelectColumnRefs().stream() + final boolean noSystemColumns = analysis.getSelectColumnNames().stream() .noneMatch(SchemaUtil::isSystemColumn); - final boolean noKeyColumns = analysis.getSelectColumnRefs().stream() + final boolean noKeyColumns = analysis.getSelectColumnNames().stream() .noneMatch(input.schema::isKeyColumn); final LogicalSchema intermediateSchema; @@ -819,7 +833,7 @@ private static List<List<?>> handleSelects( .cloneWithPropertyOverwrite(statement.getConfigOverrides()); final SelectValueMapper select = SelectValueMapperFactory.create( - analysis.getSelectExpressions(), + projection, intermediateSchema, ksqlConfig, executionContext.getMetaStore() @@ -869,7 +883,7 @@ private static void validateProjection( private static LogicalSchema selectOutputSchema( final Result input, final KsqlExecutionContext executionContext, - final ImmutableAnalysis analysis, + final List<SelectExpression> selectExpressions, final Optional<WindowType> windowType ) { final Builder schemaBuilder = LogicalSchema.builder(); @@ -881,8 +895,7 @@ private static LogicalSchema selectOutputSchema( final ExpressionTypeManager expressionTypeManager = new ExpressionTypeManager(schema, executionContext.getMetaStore()); - for (int idx = 0; idx < analysis.getSelectExpressions().size(); idx++) { - final SelectExpression select = analysis.getSelectExpressions().get(idx); + for (final SelectExpression select : selectExpressions) { final SqlType type = expressionTypeManager.getExpressionSqlType(select.getExpression()); if (input.schema.isKeyColumn(select.getAlias())
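
For reference, the reworked pull-query projection handling above boils down to two steps: reject projections that mix a wildcard with other expressions, then flatten the remaining SingleColumns into a list of SelectExpressions that both selectOutputSchema and handleSelects consume. A rough, self-contained sketch of that flow, using simplified stand-in types rather than the real ksqlDB classes:

    import java.util.List;
    import java.util.Optional;
    import java.util.stream.Collectors;

    final class ProjectionSketch {
      interface SelectItem {}

      static final class AllColumns implements SelectItem {}

      static final class SingleColumn implements SelectItem {
        final Optional<String> alias;   // resolved by the analyzer before this point
        final String expression;
        SingleColumn(final String alias, final String expression) {
          this.alias = Optional.ofNullable(alias);
          this.expression = expression;
        }
      }

      static final class SelectExpression {
        final String alias;
        final String expression;
        private SelectExpression(final String alias, final String expression) {
          this.alias = alias;
          this.expression = expression;
        }
        static SelectExpression of(final String alias, final String expression) {
          return new SelectExpression(alias, expression);
        }
      }

      // Mirrors the new isSelectStar: a wildcard is only legal as the sole item.
      static boolean isSelectStar(final List<SelectItem> items) {
        final boolean someStars = items.stream().anyMatch(s -> s instanceof AllColumns);
        if (someStars && items.size() != 1) {
          throw new IllegalArgumentException(
              "Pull queries only support wildcards in the projects "
                  + "if they are the only expression");
        }
        return someStars;
      }

      // Mirrors the projection built in queryRowsLocally: every item is expected
      // to be a SingleColumn whose alias was already assigned, so a missing alias
      // is a programming error rather than a user error.
      static List<SelectExpression> projection(final List<SelectItem> items) {
        return items.stream()
            .map(SingleColumn.class::cast)
            .map(si -> SelectExpression.of(
                si.alias.orElseThrow(IllegalStateException::new), si.expression))
            .collect(Collectors.toList());
      }
    }

Building the projection once and handing the same List<SelectExpression> to both selectOutputSchema and handleSelects keeps the computed schema and the value mapper in lockstep, which is what allows the diff to drop the index-based loop over analysis.getSelectExpressions().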