Skip to content

Commit

Permalink
feat: move filters to plan builder (#3346)
Browse files Browse the repository at this point in the history
This patch moves SqlPredicate to ksql-execution and moves the
code that filters kstreams into execution step builders. There
is also a change to pull out rewriting rowtime filters from
SqlPredicate into schemakstream, since the plan should just
have the rewritten filters.
  • Loading branch information
rodesai authored Sep 17, 2019
1 parent 06aa252 commit d4d52f3
Show file tree
Hide file tree
Showing 19 changed files with 605 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
import io.confluent.ksql.execution.context.QueryLoggerUtil;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.sqlpredicate.SqlPredicate;
import io.confluent.ksql.execution.streams.SelectValueMapperFactory;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogContext;
import io.confluent.ksql.logging.processing.ProcessingLogger;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.structured.SqlPredicate;
import io.confluent.ksql.util.KsqlConfig;
import java.util.List;
import java.util.function.Function;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
aggregated = aggregated.filter(
havingExpression.get(),
contextStacker.push(FILTER_OP_NAME),
builder.getProcessingLogContext());
builder
);
}

final List<SelectExpression> finalSelects = internalSchema
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public SchemaKStream<?> buildStream(final KsqlQueryBuilder builder) {
.filter(
getPredicate(),
builder.buildNodeContext(getId().toString()),
builder.getProcessingLogContext()
builder
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,21 @@
import io.confluent.ksql.execution.plan.JoinType;
import io.confluent.ksql.execution.plan.LogicalSchemaWithMetaAndKeyFields;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.plan.StreamFilter;
import io.confluent.ksql.execution.plan.StreamMapValues;
import io.confluent.ksql.execution.plan.StreamSource;
import io.confluent.ksql.execution.plan.StreamToTable;
import io.confluent.ksql.execution.streams.ExecutionStepFactory;
import io.confluent.ksql.execution.streams.StreamFilterBuilder;
import io.confluent.ksql.execution.streams.StreamMapValuesBuilder;
import io.confluent.ksql.execution.streams.StreamSourceBuilder;
import io.confluent.ksql.execution.streams.StreamToTableBuilder;
import io.confluent.ksql.execution.streams.StreamsUtil;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogContext;
import io.confluent.ksql.metastore.model.DataSource;
import io.confluent.ksql.metastore.model.KeyField;
import io.confluent.ksql.metastore.model.KeyField.LegacyField;
import io.confluent.ksql.parser.rewrite.StatementRewriteForRowtime;
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.FormatOptions;
import io.confluent.ksql.schema.ksql.LogicalSchema;
Expand Down Expand Up @@ -312,27 +314,15 @@ public SchemaKStream<K> into(
public SchemaKStream<K> filter(
final Expression filterExpression,
final QueryContext.Stacker contextStacker,
final ProcessingLogContext processingLogContext
final KsqlQueryBuilder queryBuilder
) {
final SqlPredicate predicate = new SqlPredicate(
filterExpression,
getSchema(),
ksqlConfig,
functionRegistry,
processingLogContext.getLoggerFactory().getLogger(
QueryLoggerUtil.queryLoggerName(
contextStacker.push(Type.FILTER.name()).getQueryContext())
)
);

final KStream<K, GenericRow> filteredKStream = kstream.filter(predicate.getPredicate());
final ExecutionStep<KStream<K, GenericRow>> step = ExecutionStepFactory.streamFilter(
final StreamFilter<KStream<K, GenericRow>> step = ExecutionStepFactory.streamFilter(
contextStacker,
sourceStep,
filterExpression
rewriteTimeComparisonForFilter(filterExpression)
);
return new SchemaKStream<>(
filteredKStream,
StreamFilterBuilder.build(kstream, step, queryBuilder),
step,
keyFormat,
keySerde,
Expand All @@ -344,6 +334,13 @@ public SchemaKStream<K> filter(
);
}

Expression rewriteTimeComparisonForFilter(final Expression expression) {
if (StatementRewriteForRowtime.requiresRewrite(expression)) {
return new StatementRewriteForRowtime(expression).rewriteForRowtime();
}
return expression;
}

public SchemaKStream<K> select(
final List<SelectExpression> selectExpressions,
final QueryContext.Stacker contextStacker,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.context.QueryLoggerUtil;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.plan.ExecutionStep;
import io.confluent.ksql.execution.plan.Formats;
import io.confluent.ksql.execution.plan.JoinType;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.plan.TableFilter;
import io.confluent.ksql.execution.plan.TableMapValues;
import io.confluent.ksql.execution.streams.ExecutionStepFactory;
import io.confluent.ksql.execution.streams.StreamsUtil;
import io.confluent.ksql.execution.streams.TableFilterBuilder;
import io.confluent.ksql.execution.streams.TableMapValuesBuilder;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogContext;
import io.confluent.ksql.metastore.model.KeyField;
import io.confluent.ksql.metastore.model.KeyField.LegacyField;
import io.confluent.ksql.schema.ksql.Column;
Expand Down Expand Up @@ -168,26 +168,15 @@ public SchemaKTable<K> into(
public SchemaKTable<K> filter(
final Expression filterExpression,
final QueryContext.Stacker contextStacker,
final ProcessingLogContext processingLogContext
final KsqlQueryBuilder queryBuilder
) {
final SqlPredicate predicate = new SqlPredicate(
filterExpression,
getSchema(),
ksqlConfig,
functionRegistry,
processingLogContext.getLoggerFactory().getLogger(
QueryLoggerUtil.queryLoggerName(
contextStacker.push(Type.FILTER.name()).getQueryContext()))
);

final KTable filteredKTable = ktable.filter(predicate.getPredicate());
final ExecutionStep<KTable<K, GenericRow>> step = ExecutionStepFactory.tableFilter(
final TableFilter<KTable<K, GenericRow>> step = ExecutionStepFactory.tableFilter(
contextStacker,
sourceTableStep,
filterExpression
rewriteTimeComparisonForFilter(filterExpression)
);
return new SchemaKTable<>(
filteredKTable,
TableFilterBuilder.build(ktable, step, queryBuilder),
step,
keyFormat,
keySerde,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.confluent.ksql.execution.context.QueryContext.Stacker;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.sqlpredicate.SqlPredicate;
import io.confluent.ksql.function.FunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogContext;
import io.confluent.ksql.logging.processing.ProcessingLogger;
Expand All @@ -41,7 +42,6 @@
import io.confluent.ksql.query.QueryId;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import io.confluent.ksql.structured.SqlPredicate;
import io.confluent.ksql.util.KsqlConfig;
import java.util.List;
import java.util.Optional;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ public class FilterNodeTest {
@Mock
private SchemaKStream schemaKStream;
@Mock
private ProcessingLogContext processingLogContext;
@Mock
private KsqlQueryBuilder ksqlStreamBuilder;
@Mock
private Stacker stacker;
Expand All @@ -62,10 +60,8 @@ public void setup() {
when(schemaKStream.filter(any(), any(), any()))
.thenReturn(schemaKStream);

when(ksqlStreamBuilder.getProcessingLogContext()).thenReturn(processingLogContext);
when(ksqlStreamBuilder.buildNodeContext(nodeId.toString())).thenReturn(stacker);


node = new FilterNode(nodeId, sourceNode, predicate);
}

Expand All @@ -76,6 +72,6 @@ public void shouldApplyFilterCorrectly() {

// Then:
verify(sourceNode).buildStream(ksqlStreamBuilder);
verify(schemaKStream).filter(predicate, stacker, processingLogContext);
verify(schemaKStream).filter(predicate, stacker, ksqlStreamBuilder);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ public class ProjectNodeTest {
private KsqlQueryBuilder ksqlStreamBuilder;
@Mock
private Stacker stacker;
@Mock
private ProcessingLogContext processingLogContext;

private ProjectNode projectNode;

Expand All @@ -79,7 +77,6 @@ public void init() {
when(source.getKeyField()).thenReturn(SOURCE_KEY_FIELD);
when(source.buildStream(any())).thenReturn((SchemaKStream) stream);
when(source.getNodeOutputType()).thenReturn(DataSourceType.KSTREAM);
when(ksqlStreamBuilder.getProcessingLogContext()).thenReturn(processingLogContext);
when(ksqlStreamBuilder.buildNodeContext(NODE_ID.toString())).thenReturn(stacker);
when(stream.select(anyList(), any(), any())).thenReturn((SchemaKStream) stream);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.FunctionCall;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedName;
import io.confluent.ksql.execution.expression.tree.QualifiedNameReference;
import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties;
Expand All @@ -51,6 +53,7 @@
import io.confluent.ksql.execution.plan.Formats;
import io.confluent.ksql.execution.plan.JoinType;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.plan.StreamFilter;
import io.confluent.ksql.execution.streams.ExecutionStepFactory;
import io.confluent.ksql.function.InternalFunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogContext;
Expand Down Expand Up @@ -457,7 +460,7 @@ public void testSelectWithExpression() {
}

@Test
public void testFilter() {
public void shouldReturnSchemaKStreamWithCorrectSchemaForFilter() {
// Given:
final PlanNode logicalPlan = givenInitialKStreamOf(
"SELECT col0, col2, col3 FROM test1 WHERE col0 > 100 EMIT CHANGES;");
Expand All @@ -467,7 +470,7 @@ public void testFilter() {
final SchemaKStream filteredSchemaKStream = initialSchemaKStream.filter(
filterNode.getPredicate(),
childContextStacker,
processingLogContext);
queryBuilder);

// Then:
assertThat(filteredSchemaKStream.getSchema().value(), contains(
Expand All @@ -484,6 +487,35 @@ public void testFilter() {
assertThat(filteredSchemaKStream.getSourceSchemaKStreams().get(0), is(initialSchemaKStream));
}

@Test
public void shouldRewriteTimeComparisonInFilter() {
// Given:
final PlanNode logicalPlan = givenInitialKStreamOf(
"SELECT col0, col2, col3 FROM test1 "
+ "WHERE ROWTIME = '1984-01-01T00:00:00+00:00' EMIT CHANGES;");
final FilterNode filterNode = (FilterNode) logicalPlan.getSources().get(0).getSources().get(0);

// When:
final SchemaKStream filteredSchemaKStream = initialSchemaKStream.filter(
filterNode.getPredicate(),
childContextStacker,
queryBuilder);

// Then:
final StreamFilter step = (StreamFilter) filteredSchemaKStream.getSourceStep();
assertThat(
step.getFilterExpression(),
equalTo(
new ComparisonExpression(
ComparisonExpression.Type.EQUAL,
new DereferenceExpression(
new QualifiedNameReference(QualifiedName.of("TEST1")), "ROWTIME"),
new LongLiteral(441763200000L)
)
)
);
}

@Test
public void shouldBuildStepForFilter() {
// Given:
Expand All @@ -495,7 +527,7 @@ public void shouldBuildStepForFilter() {
final SchemaKStream filteredSchemaKStream = initialSchemaKStream.filter(
filterNode.getPredicate(),
childContextStacker,
processingLogContext);
queryBuilder);

// Then:
assertThat(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,18 @@
import io.confluent.ksql.execution.builder.KsqlQueryBuilder;
import io.confluent.ksql.execution.context.QueryContext;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.ComparisonExpression;
import io.confluent.ksql.execution.expression.tree.DereferenceExpression;
import io.confluent.ksql.execution.expression.tree.Expression;
import io.confluent.ksql.execution.expression.tree.LongLiteral;
import io.confluent.ksql.execution.expression.tree.QualifiedName;
import io.confluent.ksql.execution.expression.tree.QualifiedNameReference;
import io.confluent.ksql.execution.plan.DefaultExecutionStepProperties;
import io.confluent.ksql.execution.plan.ExecutionStep;
import io.confluent.ksql.execution.plan.Formats;
import io.confluent.ksql.execution.plan.JoinType;
import io.confluent.ksql.execution.plan.SelectExpression;
import io.confluent.ksql.execution.plan.TableFilter;
import io.confluent.ksql.execution.streams.ExecutionStepFactory;
import io.confluent.ksql.function.InternalFunctionRegistry;
import io.confluent.ksql.logging.processing.ProcessingLogContext;
Expand Down Expand Up @@ -341,7 +344,7 @@ public void testSelectWithExpression() {
}

@Test
public void testFilter() {
public void shouldBuildSchemaKTableWithCorrectSchemaForFilter() {
// Given:
final String selectQuery = "SELECT col0, col2, col3 FROM test2 WHERE col0 > 100 EMIT CHANGES;";
final PlanNode logicalPlan = buildLogicalPlan(selectQuery);
Expand All @@ -352,7 +355,7 @@ public void testFilter() {
final SchemaKTable filteredSchemaKStream = initialSchemaKTable.filter(
filterNode.getPredicate(),
childContextStacker,
processingLogContext
queryBuilder
);

// Then:
Expand All @@ -369,6 +372,37 @@ public void testFilter() {
assertThat(filteredSchemaKStream.getSourceSchemaKStreams().get(0), is(initialSchemaKTable));
}

@Test
public void shouldRewriteTimeComparisonInFilter() {
// Given:
final String selectQuery = "SELECT col0, col2, col3 FROM test2 "
+ "WHERE ROWTIME = '1984-01-01T00:00:00+00:00' EMIT CHANGES;";
final PlanNode logicalPlan = buildLogicalPlan(selectQuery);
final FilterNode filterNode = (FilterNode) logicalPlan.getSources().get(0).getSources().get(0);
initialSchemaKTable = buildSchemaKTableFromPlan(logicalPlan);

// When:
final SchemaKTable filteredSchemaKTable = initialSchemaKTable.filter(
filterNode.getPredicate(),
childContextStacker,
queryBuilder
);

// Then:
final TableFilter step = (TableFilter) filteredSchemaKTable.getSourceTableStep();
assertThat(
step.getFilterExpression(),
Matchers.equalTo(
new ComparisonExpression(
ComparisonExpression.Type.EQUAL,
new DereferenceExpression(
new QualifiedNameReference(QualifiedName.of("TEST2")), "ROWTIME"),
new LongLiteral(441763200000L)
)
)
);
}

@Test
public void shouldBuildStepForFilter() {
// Given:
Expand All @@ -381,7 +415,7 @@ public void shouldBuildStepForFilter() {
final SchemaKTable filteredSchemaKStream = initialSchemaKTable.filter(
filterNode.getPredicate(),
childContextStacker,
processingLogContext
queryBuilder
);

// Then:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ public List<ExecutionStep<?>> getSources() {
return Collections.singletonList(source);
}

public Expression getFilterExpression() {
return filterExpression;
}

public ExecutionStep<S> getSource() {
return source;
}

@Override
public S build(final KsqlQueryBuilder streamsBuilder) {
throw new UnsupportedOperationException();
Expand Down
Loading

0 comments on commit d4d52f3

Please sign in to comment.