Skip to content

Commit

Permalink
fix: ensure only deserializable cmds are written to command topic (#5645
Browse files Browse the repository at this point in the history
)

* fix: ensure only deserializable cmds are written to command topic

fixes: #5643

Ensures all commands written to the command topic can be deserialized before writing them.

A non-deserializable command causes the command runner thread to die.
Even restarting the server won't help as the server will stop when it hits the non-deserializable command again.

This adds some level of protection.


Co-authored-by: Andy Coates <big-andy-coates@users.noreply.github.com>
  • Loading branch information
2 people authored and agavra committed Jun 19, 2020
1 parent a8e6630 commit 4ad2bde
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,20 @@

package io.confluent.ksql.rest.server.computation;

import com.fasterxml.jackson.core.JsonProcessingException;
import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.engine.KsqlPlan;
import io.confluent.ksql.execution.json.PlanJsonMapper;
import io.confluent.ksql.parser.tree.Statement;
import io.confluent.ksql.parser.tree.TerminateQuery;
import io.confluent.ksql.planner.plan.ConfiguredKsqlPlan;
import io.confluent.ksql.query.QueryId;
import io.confluent.ksql.rest.util.TerminateCluster;
import io.confluent.ksql.services.ServiceContext;
import io.confluent.ksql.statement.ConfiguredStatement;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlServerException;
import io.confluent.ksql.util.KsqlStatementException;
import io.confluent.ksql.util.PersistentQueryMetadata;
import java.util.Objects;
import java.util.Optional;

/**
Expand All @@ -36,11 +37,6 @@
* command queue.
*/
public final class ValidatedCommandFactory {
private final KsqlConfig config;

public ValidatedCommandFactory(final KsqlConfig config) {
this.config = Objects.requireNonNull(config, "config");
}

/**
* Create a validated command.
Expand All @@ -61,19 +57,56 @@ public Command create(
* @param context The KSQL engine snapshot to validate the command against.
* @return A validated command, which is safe to enqueue onto the command topic.
*/
@SuppressWarnings("MethodMayBeStatic") // Not static to allow dependency injection
public Command create(
final ConfiguredStatement<? extends Statement> statement,
final ServiceContext serviceContext,
final KsqlExecutionContext context) {
final KsqlExecutionContext context
) {
return ensureDeserializable(createCommand(statement, serviceContext, context));
}

/**
* Ensure any command written to the command topic can be deserialized.
*
* <p>Any command that can't be deserialized is a bug. However, given a non-deserializable
* command will kill the command runner thread, this is a safety net to ensure commands written to
* the command topic can be deserialzied.
*
* @param command the command to test.
* @return the passed in command.
*/
private static Command ensureDeserializable(final Command command) {
try {
final String json = PlanJsonMapper.INSTANCE.get().writeValueAsString(command);
PlanJsonMapper.INSTANCE.get().readValue(json, Command.class);
return command;
} catch (final JsonProcessingException e) {
throw new KsqlServerException("Did not write the command to the command topic "
+ "as it could not be deserialized. This is a bug! Please raise a Github issue "
+ "containing the series of commands you ran to get to this point."
+ System.lineSeparator()
+ e.getMessage());
}
}

private static Command createCommand(
final ConfiguredStatement<? extends Statement> statement,
final ServiceContext serviceContext,
final KsqlExecutionContext context
) {
if (statement.getStatementText().equals(TerminateCluster.TERMINATE_CLUSTER_STATEMENT_TEXT)) {
return Command.of(statement);
} else if (statement.getStatement() instanceof TerminateQuery) {
}

if (statement.getStatement() instanceof TerminateQuery) {
return createForTerminateQuery(statement, context);
}

return createForPlannedQuery(statement, serviceContext, context);
}

private Command createForTerminateQuery(
private static Command createForTerminateQuery(
final ConfiguredStatement<? extends Statement> statement,
final KsqlExecutionContext context
) {
Expand All @@ -93,7 +126,7 @@ private Command createForTerminateQuery(
return Command.of(statement);
}

private Command createForPlannedQuery(
private static Command createForPlannedQuery(
final ConfiguredStatement<? extends Statement> statement,
final ServiceContext serviceContext,
final KsqlExecutionContext context
Expand All @@ -107,6 +140,7 @@ private Command createForPlannedQuery(
statement.getConfig()
)
);

return Command.of(
ConfiguredKsqlPlan.of(plan, statement.getConfigOverrides(), statement.getConfig()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ public void configure(final KsqlConfig config) {
injectorFactory,
ksqlEngine::createSandbox,
config,
new ValidatedCommandFactory(config)
new ValidatedCommandFactory()
);

this.handler = new RequestHandler(
Expand All @@ -176,7 +176,7 @@ public void configure(final KsqlConfig config) {
distributedCmdResponseTimeout,
injectorFactory,
authorizationValidator,
new ValidatedCommandFactory(config),
new ValidatedCommandFactory(),
errorHandler
),
ksqlEngine,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
import com.google.common.collect.ImmutableList;
import io.confluent.ksql.KsqlExecutionContext;
import io.confluent.ksql.engine.KsqlPlan;
import io.confluent.ksql.execution.ddl.commands.DdlCommand;
import io.confluent.ksql.execution.ddl.commands.DdlCommandResult;
import io.confluent.ksql.execution.ddl.commands.DropSourceCommand;
import io.confluent.ksql.execution.ddl.commands.Executor;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.parser.KsqlParser.PreparedStatement;
import io.confluent.ksql.parser.tree.CreateStream;
import io.confluent.ksql.parser.tree.Statement;
Expand All @@ -35,6 +40,7 @@
import io.confluent.ksql.services.ServiceContext;
import io.confluent.ksql.statement.ConfiguredStatement;
import io.confluent.ksql.util.KsqlConfig;
import io.confluent.ksql.util.KsqlServerException;
import io.confluent.ksql.util.KsqlStatementException;
import io.confluent.ksql.util.PersistentQueryMetadata;
import java.util.Map;
Expand All @@ -47,7 +53,12 @@

@RunWith(MockitoJUnitRunner.class)
public class ValidatedCommandFactoryTest {

private static final QueryId QUERY_ID = new QueryId("FOO");
private static final KsqlPlan A_PLAN = KsqlPlan.ddlPlanCurrent(
"DROP TABLE Bob",
new DropSourceCommand(SourceName.of("BOB"))
);

@Mock
private KsqlExecutionContext executionContext;
Expand All @@ -62,8 +73,6 @@ public class ValidatedCommandFactoryTest {
@Mock
private Map<String, Object> overrides;
@Mock
private KsqlPlan plan;
@Mock
private PersistentQueryMetadata query1;
@Mock
private PersistentQueryMetadata query2;
Expand All @@ -73,7 +82,7 @@ public class ValidatedCommandFactoryTest {

@Before
public void setup() {
commandFactory = new ValidatedCommandFactory(config);
commandFactory = new ValidatedCommandFactory();
}

@Test
Expand Down Expand Up @@ -171,7 +180,7 @@ public void shouldValidatePlannedQuery() {
verify(executionContext).plan(serviceContext, configuredStatement);
verify(executionContext).execute(
serviceContext,
ConfiguredKsqlPlan.of(plan, overrides, config)
ConfiguredKsqlPlan.of(A_PLAN, overrides, config)
);
}

Expand All @@ -184,7 +193,23 @@ public void shouldCreateCommandForPlannedQuery() {
final Command command = commandFactory.create(configuredStatement, executionContext);

// Then:
assertThat(command, is(Command.of(ConfiguredKsqlPlan.of(plan, overrides, config))));
assertThat(command, is(Command.of(ConfiguredKsqlPlan.of(A_PLAN, overrides, config))));
}

@Test
public void shouldThrowIfCommandCanNotBeDeserialized() {
// Given:
givenNonDeserializableCommand();

// When:
final Exception e = assertThrows(
KsqlServerException.class,
() -> commandFactory.create(configuredStatement, executionContext)
);

// Then:
assertThat(e.getMessage(), containsString("Did not write the command to the command topic "
+ "as it could not be deserialized."));
}

private void givenTerminate() {
Expand All @@ -201,8 +226,15 @@ private void givenTerminateAll() {

private void givenPlannedQuery() {
configuredStatement = configuredStatement("CREATE STREAM", plannedQuery);
when(plan.getStatementText()).thenReturn("CREATE STREAM ");
when(executionContext.plan(any(), any())).thenReturn(plan);
when(executionContext.plan(any(), any())).thenReturn(A_PLAN);
when(executionContext.getServiceContext()).thenReturn(serviceContext);
}

private void givenNonDeserializableCommand() {
configuredStatement = configuredStatement("CREATE STREAM", plannedQuery);
final KsqlPlan planThatFailsToDeserialize = KsqlPlan
.ddlPlanCurrent("some sql", new UnDeserializableCommand());
when(executionContext.plan(any(), any())).thenReturn(planThatFailsToDeserialize);
when(executionContext.getServiceContext()).thenReturn(serviceContext);
}

Expand All @@ -216,4 +248,13 @@ private <T extends Statement> ConfiguredStatement<T> configuredStatement(
config
);
}

// Not a known subtype so will fail to deserialize:
private static class UnDeserializableCommand implements DdlCommand {

@Override
public DdlCommandResult execute(final Executor executor) {
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
import io.confluent.ksql.engine.KsqlEngineTestUtil;
import io.confluent.ksql.engine.KsqlPlan;
import io.confluent.ksql.exception.KsqlTopicAuthorizationException;
import io.confluent.ksql.execution.ddl.commands.DdlCommand;
import io.confluent.ksql.execution.ddl.commands.DropSourceCommand;
import io.confluent.ksql.execution.ddl.commands.KsqlTopic;
import io.confluent.ksql.execution.expression.tree.StringLiteral;
import io.confluent.ksql.function.InternalFunctionRegistry;
Expand Down Expand Up @@ -1977,7 +1977,7 @@ private void givenMockEngine() {
when(sandbox.plan(any(), any())).thenAnswer(
i -> KsqlPlan.ddlPlanCurrent(
((ConfiguredStatement<?>) i.getArgument(1)).getStatementText(),
mock(DdlCommand.class)
new DropSourceCommand(SourceName.of("bob"))
)
);
when(ksqlEngine.createSandbox(any())).thenReturn(sandbox);
Expand Down

0 comments on commit 4ad2bde

Please sign in to comment.