feat(static): fail on ROWTIME in projection #3430

Merged
Changes from 3 commits
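
Summary of the change, based on the diff and the new REST validation test cases below: the analyzer now records every column referenced by a query's projection, and StaticQueryValidator rejects static (pull) queries whose projection references ROWTIME, whether directly or inside an expression. A sketch in KSQL, reusing the AGGREGATE table and error text from the new test cases:

    -- Rejected: "Static queries don't support ROWTIME in the projection." (HTTP 400)
    SELECT ROWTIME + 10, COUNT FROM AGGREGATE WHERE ROWKEY='10';

    -- Continuous (push) queries are not subject to this rule
    SELECT * FROM AGGREGATE EMIT CHANGES;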
File: Analysis.java
@@ -40,7 +40,9 @@
import io.confluent.ksql.serde.SerdeOption;
import io.confluent.ksql.util.SchemaUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -56,6 +58,7 @@ public class Analysis {
private Optional<JoinInfo> joinInfo = Optional.empty();
private Optional<Expression> whereExpression = Optional.empty();
private final List<SelectExpression> selectExpressions = new ArrayList<>();
private final Set<ColumnRef> selectColumnRefs = new HashSet<>();
private final List<Expression> groupByExpressions = new ArrayList<>();
private Optional<WindowExpression> windowExpression = Optional.empty();
private Optional<ColumnName> partitionBy = Optional.empty();
@@ -76,6 +79,10 @@ void addSelectItem(final Expression expression, final ColumnName alias) {
selectExpressions.add(SelectExpression.of(alias, expression));
}

void addSelectColumnRefs(final Collection<ColumnRef> columnRefs) {
selectColumnRefs.addAll(columnRefs);
}

public Optional<Into> getInto() {
return into;
}
@@ -96,6 +103,10 @@ public List<SelectExpression> getSelectExpressions() {
return Collections.unmodifiableList(selectExpressions);
}

Set<ColumnRef> getSelectColumnRefs() {
return Collections.unmodifiableSet(selectColumnRefs);
}

public List<Expression> getGroupByExpressions() {
return ImmutableList.copyOf(groupByExpressions);
}
@@ -156,7 +167,7 @@ public List<AliasedDataSource> getFromDataSources() {
return ImmutableList.copyOf(fromDataSources);
}

public SourceSchemas getFromSourceSchemas() {
SourceSchemas getFromSourceSchemas() {
final Map<SourceName, LogicalSchema> schemaBySource = fromDataSources.stream()
.collect(Collectors.toMap(
AliasedDataSource::getAlias,
File: Analyzer.java
@@ -26,6 +26,7 @@
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.TraversalExpressionVisitor;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.windows.KsqlWindowExpression;
import io.confluent.ksql.metastore.MetaStore;
@@ -62,6 +63,7 @@
import io.confluent.ksql.serde.ValueFormat;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@@ -511,7 +513,7 @@ protected AstNode visitSelect(final Select node, final Void context) {
visitSelectStar((AllColumns) selectItem);
} else if (selectItem instanceof SingleColumn) {
final SingleColumn column = (SingleColumn) selectItem;
analysis.addSelectItem(column.getExpression(), column.getAlias());
addSelectItem(column.getExpression(), column.getAlias());
} else {
throw new IllegalArgumentException(
"Unsupported SelectItem type: " + selectItem.getClass().getName());
@@ -562,14 +564,19 @@ private void visitSelectStar(final AllColumns allColumns) {
? source.getAlias().name() + "_"
: "";

for (final Column column : source.getDataSource().getSchema().columns()) {
final LogicalSchema schema = source.getDataSource().getSchema();
for (final Column column : schema.columns()) {

if (staticQuery && schema.isMetaColumn(column.name())) {
continue;
}

final ColumnReferenceExp selectItem = new ColumnReferenceExp(location,
ColumnRef.of(source.getAlias(), column.name()));

final String alias = aliasPrefix + column.name().name();

analysis.addSelectItem(selectItem, ColumnName.of(alias));
addSelectItem(selectItem, ColumnName.of(alias));
}
}
}
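
The staticQuery guard above also changes how SELECT * is expanded: for static queries the meta columns (such as ROWTIME) are skipped, while continuous queries keep them, as the new select-star tests in AnalyzerFunctionalTest below verify. Illustrative queries against the TEST1 source used by those tests:

    -- Static query: the expanded projection omits ROWTIME
    SELECT * FROM TEST1;

    -- Continuous query: the expanded projection still includes ROWTIME
    SELECT * FROM TEST1 EMIT CHANGES;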
@@ -598,6 +605,25 @@ public void validate() {
+ System.lineSeparator() + KAFKA_VALUE_FORMAT_LIMITATION_DETAILS);
}
}

private void addSelectItem(final Expression exp, final ColumnName columnName) {
final Set<ColumnRef> columnRefs = new HashSet<>();
final TraversalExpressionVisitor<Void> visitor = new TraversalExpressionVisitor<Void>() {
@Override
public Void visitColumnReference(
final ColumnReferenceExp node,
final Void context
) {
columnRefs.add(node.getReference());
return null;
}
};

visitor.process(exp, null);

analysis.addSelectItem(exp, columnName);
analysis.addSelectColumnRefs(columnRefs);
}
}

@FunctionalInterface
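
The new addSelectItem helper above walks each projection expression with a TraversalExpressionVisitor and records every ColumnRef it finds, so a column buried inside arithmetic or a function call is captured just like a bare reference; StaticQueryValidator (next file) then checks the captured set for ROWTIME. For example, per the shouldCaptureProjectionColumnRefs test added below, this projection is recorded as referencing TEST1.COL0, TEST1.COL1 and TEST1.COL2:

    SELECT COL0, COL0 + COL1, SUBSTRING(COL2, 1) FROM TEST1;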
File: StaticQueryValidator.java
@@ -17,7 +17,9 @@

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.parser.tree.ResultMaterialization;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
@@ -89,6 +91,12 @@ public class StaticQueryValidator implements QueryValidator {
Rule.of(
analysis -> !analysis.getLimitClause().isPresent(),
"Static queries don't support LIMIT clauses."
),
Rule.of(
analysis -> analysis.getSelectColumnRefs().stream()
.map(ColumnRef::name)
.noneMatch(n -> n.equals(SchemaUtil.ROWTIME_NAME)),
"Static queries don't support ROWTIME in the projection."
)
);

File: AnalyzerFunctionalTest.java
@@ -16,8 +16,11 @@
package io.confluent.ksql.analyzer;

import static io.confluent.ksql.testutils.AnalysisTestUtil.analyzeQuery;
import static io.confluent.ksql.util.SchemaUtil.ROWTIME_NAME;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
@@ -35,6 +38,7 @@
import io.confluent.ksql.analyzer.Analyzer.SerdeOptionsSupplier;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.BooleanLiteral;
import io.confluent.ksql.execution.expression.tree.ColumnReferenceExp;
import io.confluent.ksql.execution.expression.tree.Literal;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.execution.plan.SelectExpression;
@@ -53,6 +57,7 @@
import io.confluent.ksql.parser.tree.Sink;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.planner.plan.JoinNode.JoinType;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.serde.Format;
@@ -90,6 +95,11 @@
public class AnalyzerFunctionalTest {

private static final Set<SerdeOption> DEFAULT_SERDE_OPTIONS = SerdeOption.none();
private static final SourceName TEST1 = SourceName.of("TEST1");
private static final ColumnName COL0 = ColumnName.of("COL0");
private static final ColumnName COL1 = ColumnName.of("COL1");
private static final ColumnName COL2 = ColumnName.of("COL2");
private static final ColumnName COL3 = ColumnName.of("COL3");

private MutableMetaStore jsonMetaStore;
private MutableMetaStore avroMetaStore;
@@ -136,17 +146,17 @@ public void testSimpleQueryAnalysis() {
final Analysis analysis = analyzeQuery(simpleQuery, jsonMetaStore);
assertEquals("FROM was not analyzed correctly.",
analysis.getFromDataSources().get(0).getDataSource().getName(),
SourceName.of("TEST1"));
TEST1);
assertThat(analysis.getWhereExpression().get().toString(), is("(TEST1.COL0 > 100)"));

final List<SelectExpression> selects = analysis.getSelectExpressions();
assertThat(selects.get(0).getExpression().toString(), is("TEST1.COL0"));
assertThat(selects.get(1).getExpression().toString(), is("TEST1.COL2"));
assertThat(selects.get(2).getExpression().toString(), is("TEST1.COL3"));

assertThat(selects.get(0).getName(), is(ColumnName.of("COL0")));
assertThat(selects.get(1).getName(), is(ColumnName.of("COL2")));
assertThat(selects.get(2).getName(), is(ColumnName.of("COL3")));
assertThat(selects.get(0).getName(), is(COL0));
assertThat(selects.get(1).getName(), is(COL2));
assertThat(selects.get(2).getName(), is(COL3));
}

@Test
@@ -202,7 +212,7 @@ public void testBooleanExpressionAnalysis() {
final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore);

assertEquals("FROM was not analyzed correctly.",
analysis.getFromDataSources().get(0).getDataSource().getName(), SourceName.of("TEST1"));
analysis.getFromDataSources().get(0).getDataSource().getName(), TEST1);

final List<SelectExpression> selects = analysis.getSelectExpressions();
assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)"));
@@ -215,7 +225,7 @@ public void testFilterAnalysis() {
final String queryStr = "SELECT col0 = 10, col2, col3 > col1 FROM test1 WHERE col0 > 20 EMIT CHANGES;";
final Analysis analysis = analyzeQuery(queryStr, jsonMetaStore);

assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(SourceName.of("TEST1")));
assertThat(analysis.getFromDataSources().get(0).getDataSource().getName(), is(TEST1));

final List<SelectExpression> selects = analysis.getSelectExpressions();
assertThat(selects.get(0).getExpression().toString(), is("(TEST1.COL0 = 10)"));
@@ -450,6 +460,50 @@ public void shouldThrowOnJoinIfKafkaFormat() {
analyzer.analyze(query, Optional.of(sink));
}

@Test
public void shouldCaptureProjectionColumnRefs() {
// Given:
query = parseSingle("Select COL0, COL0 + COL1, SUBSTRING(COL2, 1) from TEST1;");

// When:
final Analysis analysis = analyzer.analyze(query, Optional.empty());

// Then:
assertThat(analysis.getSelectColumnRefs(), containsInAnyOrder(
ColumnRef.of(TEST1, COL0),
ColumnRef.of(TEST1, COL1),
ColumnRef.of(TEST1, COL2)
));
}

@Test
public void shouldIncludeMetaColumnsForSelectStarOnContinuousQueries() {
// Given:
query = parseSingle("Select * from TEST1 EMIT CHANGES;");

// When:
final Analysis analysis = analyzer.analyze(query, Optional.empty());

// Then:
assertThat(analysis.getSelectExpressions(), hasItem(
SelectExpression.of(ROWTIME_NAME, new ColumnReferenceExp(ColumnRef.of(TEST1, ROWTIME_NAME)))
));
}

@Test
public void shouldNotIncludeMetaColumnsForSelectStarOnStaticQueries() {
// Given:
query = parseSingle("Select * from TEST1;");

// When:
final Analysis analysis = analyzer.analyze(query, Optional.empty());

// Then:
assertThat(analysis.getSelectExpressions(), not(hasItem(
SelectExpression.of(ROWTIME_NAME, new ColumnReferenceExp(ColumnRef.of(TEST1, ROWTIME_NAME)))
)));
}

@SuppressWarnings("unchecked")
private <T extends Statement> T parseSingle(final String simpleQuery) {
return (T) Iterables.getOnlyElement(parse(simpleQuery, jsonMetaStore));
@@ -478,7 +532,7 @@ private void buildProps() {

private void registerKafkaSource() {
final LogicalSchema schema = LogicalSchema.builder()
.valueColumn(ColumnName.of("COL0"), SqlTypes.BIGINT)
.valueColumn(COL0, SqlTypes.BIGINT)
.build();

final KsqlTopic topic = new KsqlTopic(
File: StaticQueryValidatorTest.java
@@ -19,12 +19,15 @@
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.analyzer.Analysis.Into;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.parser.tree.ResultMaterialization;
import io.confluent.ksql.parser.tree.WindowExpression;
import io.confluent.ksql.schema.ksql.ColumnRef;
import io.confluent.ksql.util.KsqlException;
import io.confluent.ksql.util.SchemaUtil;
import java.util.Optional;
import java.util.OptionalInt;
import org.junit.Before;
@@ -109,7 +112,7 @@ public void shouldThrowOnStaticQueryThatIsWindowed() {
}

@Test
public void shouldThrowOnStaticQueryThatHasGroupBy() {
public void shouldThrowOnGroupBy() {
// Given:
when(analysis.getGroupByExpressions()).thenReturn(ImmutableList.of(AN_EXPRESSION));

@@ -122,7 +125,7 @@ public void shouldThrowOnStaticQueryThatHasGroupBy() {
}

@Test
public void shouldThrowOnStaticQueryThatHasPartitionBy() {
public void shouldThrowOnPartitionBy() {
// Given:
when(analysis.getPartitionBy()).thenReturn(Optional.of(ColumnName.of("Something")));

@@ -135,7 +138,7 @@ public void shouldThrowOnStaticQueryThatHasHavingClause() {
}

@Test
public void shouldThrowOnStaticQueryThatHasHavingClause() {
public void shouldThrowOnHavingClause() {
// Given:
when(analysis.getHavingExpression()).thenReturn(Optional.of(AN_EXPRESSION));

@@ -148,7 +151,7 @@ public void shouldThrowOnStaticQueryThatHasLimitClause() {
}

@Test
public void shouldThrowOnStaticQueryThatHasLimitClause() {
public void shouldThrowOnLimitClause() {
// Given:
when(analysis.getLimitClause()).thenReturn(OptionalInt.of(1));

@@ -159,4 +162,18 @@ public void shouldThrowOnStaticQueryThatHasLimitClause() {
// When:
validator.validate(analysis);
}

@Test
public void shouldThrowOnRowTimeInProjection() {
// Given:
when(analysis.getSelectColumnRefs())
.thenReturn(ImmutableSet.of(ColumnRef.of(SchemaUtil.ROWTIME_NAME)));

// Then:
expectedException.expect(KsqlException.class);
expectedException.expectMessage("Static queries don't support ROWTIME in the projection.");

// When:
validator.validate(analysis);
}
}
File: REST query validation test cases (JSON)
@@ -235,6 +235,32 @@
}
]
},
{
"name": "non-windowed projection WITH ROWTIME",
"statements": [
"CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT GROUP BY ROWKEY;",
"SELECT ROWTIME + 10, COUNT FROM AGGREGATE WHERE ROWKEY='10';"
],
"expectedError": {
"type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage",
"message": "Static queries don't support ROWTIME in the projection.",
"status": 400
}
},
{
"name": "windowed with projection with ROWTIME",
"statements": [
"CREATE STREAM INPUT (IGNORED INT) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE TABLE AGGREGATE AS SELECT COUNT(1) AS COUNT FROM INPUT WINDOW TUMBLING(SIZE 1 SECOND) GROUP BY ROWKEY;",
"SELECT COUNT, ROWTIME + 10 FROM AGGREGATE WHERE ROWKEY='10' AND WindowStart=12000;"
],
"expectedError": {
"type": "io.confluent.ksql.rest.entity.KsqlStatementErrorMessage",
"message": "Static queries don't support ROWTIME in the projection.",
"status": 400
}
},
{
"name": "text datetime window bounds",
"enabled": false,