Skip to content

Commit

Permalink
fix: support NULL return values from CASE statements (#3531)
Browse files Browse the repository at this point in the history
* fix(3405): support NULL return values from CASE statements

This commit enhances the processing of `CASE` statements so that the result of any `THEN` branch, and of the default (`ELSE`) branch, can be `NULL`.

At least one branch must be non-null so that KSQL can determine the result type of the statement.

All non-null branches must have the same result type, i.e. KSQL does not (yet) implicitly cast between numeric types, so branches returning different numeric types (for example `INTEGER` and `BIGINT`) are rejected.

The commit also improves some error messages by using KSQL types rather than connect schema types so that error messages use, for example, `BIGINT` rather than `Schema{INT64}`.
  • Loading branch information
big-andy-coates authored Oct 15, 2019
1 parent 8ef82eb commit eb9e41b
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ public String getSymbol() {
*/
public SqlType resultType(final SqlType left, final SqlType right) {
if (left.baseType().isNumber() && right.baseType().isNumber()) {
if (left.baseType().canUpCast(right.baseType())) {
if (left.baseType().canImplicitlyCast(right.baseType())) {
if (right.baseType() != SqlBaseType.DECIMAL) {
return right;
}

return binaryResolver.apply(toDecimal(left), (SqlDecimal) right);
}

if (right.baseType().canUpCast(left.baseType())) {
if (right.baseType().canImplicitlyCast(left.baseType())) {
if (left.baseType() != SqlBaseType.DECIMAL) {
return left;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,15 @@ public boolean isNumber() {
}

/**
* Test to see if this type can be up-cast to another.
* Test to see if this type can be <i>implicitly</i> cast to another.
*
* <p>This defines if KSQL supports <i>implicitly</i> converting one numeric type to another.
*
* <p>Types can always be upcast to themselves. Only numeric types can be upcast to different
* numeric types. Note: STRING to DECIMAL handling is not seen as up-casting, it's parsing.
* <p>Types can always be cast to themselves. Only numeric types can be implicitly cast to other
* numeric types. Note: STRING to DECIMAL handling is not seen as casting: it's parsing.
*
* @param to the target type.
* @return true if this type can be upcast to the supplied type.
* @return true if this type can be implicitly cast to the supplied type.
*/
public boolean canUpCast(final SqlBaseType to) {
public boolean canImplicitlyCast(final SqlBaseType to) {
return this.equals(to)
|| (isNumber() && to.isNumber() && this.ordinal() <= to.ordinal());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public void shouldBeNumber() {
public void shouldNotUpCastIfNotNumber() {
nonNumberTypes().forEach(sqlType -> assertThat(
sqlType + " should not upcast",
sqlType.canUpCast(SqlBaseType.DOUBLE),
sqlType.canImplicitlyCast(SqlBaseType.DOUBLE),
is(false))
);
}
Expand All @@ -65,51 +65,51 @@ public void shouldNotUpCastIfNotNumber() {
public void shouldUpCastIfNumber() {
numberTypes().forEach(sqlType -> assertThat(
sqlType + " should upcast",
sqlType.canUpCast(SqlBaseType.DOUBLE),
sqlType.canImplicitlyCast(SqlBaseType.DOUBLE),
is(true))
);
}

@Test
public void shouldUpCastToSelf() {
allTypes().forEach(sqlType ->
assertThat(sqlType + " should upcast to self", sqlType.canUpCast(sqlType), is(true)));
assertThat(sqlType + " should upcast to self", sqlType.canImplicitlyCast(sqlType), is(true)));
}

@Test
public void shouldUpCastInt() {
assertThat(SqlBaseType.INTEGER.canUpCast(SqlBaseType.BIGINT), is(true));
assertThat(SqlBaseType.INTEGER.canUpCast(SqlBaseType.DECIMAL), is(true));
assertThat(SqlBaseType.INTEGER.canUpCast(SqlBaseType.DOUBLE), is(true));
assertThat(SqlBaseType.INTEGER.canImplicitlyCast(SqlBaseType.BIGINT), is(true));
assertThat(SqlBaseType.INTEGER.canImplicitlyCast(SqlBaseType.DECIMAL), is(true));
assertThat(SqlBaseType.INTEGER.canImplicitlyCast(SqlBaseType.DOUBLE), is(true));
}

@Test
public void shouldUpCastBigInt() {
assertThat(SqlBaseType.BIGINT.canUpCast(SqlBaseType.DECIMAL), is(true));
assertThat(SqlBaseType.BIGINT.canUpCast(SqlBaseType.DOUBLE), is(true));
assertThat(SqlBaseType.BIGINT.canImplicitlyCast(SqlBaseType.DECIMAL), is(true));
assertThat(SqlBaseType.BIGINT.canImplicitlyCast(SqlBaseType.DOUBLE), is(true));
}

@Test
public void shouldUpCastDecimal() {
assertThat(SqlBaseType.DECIMAL.canUpCast(SqlBaseType.DOUBLE), is(true));
assertThat(SqlBaseType.DECIMAL.canImplicitlyCast(SqlBaseType.DOUBLE), is(true));
}

@Test
public void shouldNotDownCastBigInt() {
assertThat(SqlBaseType.BIGINT.canUpCast(SqlBaseType.INTEGER), is(false));
assertThat(SqlBaseType.BIGINT.canImplicitlyCast(SqlBaseType.INTEGER), is(false));
}

@Test
public void shouldNotDownCastDecimal() {
assertThat(SqlBaseType.DECIMAL.canUpCast(SqlBaseType.INTEGER), is(false));
assertThat(SqlBaseType.DECIMAL.canUpCast(SqlBaseType.BIGINT), is(false));
assertThat(SqlBaseType.DECIMAL.canImplicitlyCast(SqlBaseType.INTEGER), is(false));
assertThat(SqlBaseType.DECIMAL.canImplicitlyCast(SqlBaseType.BIGINT), is(false));
}

@Test
public void shouldNotDownCastDouble() {
assertThat(SqlBaseType.DOUBLE.canUpCast(SqlBaseType.INTEGER), is(false));
assertThat(SqlBaseType.DOUBLE.canUpCast(SqlBaseType.BIGINT), is(false));
assertThat(SqlBaseType.DOUBLE.canUpCast(SqlBaseType.DECIMAL), is(false));
assertThat(SqlBaseType.DOUBLE.canImplicitlyCast(SqlBaseType.INTEGER), is(false));
assertThat(SqlBaseType.DOUBLE.canImplicitlyCast(SqlBaseType.BIGINT), is(false));
assertThat(SqlBaseType.DOUBLE.canImplicitlyCast(SqlBaseType.DECIMAL), is(false));
}

private static Stream<SqlBaseType> numberTypes() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import io.confluent.ksql.schema.ksql.Column;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.SqlBaseType;
import io.confluent.ksql.schema.ksql.types.SqlArray;
import io.confluent.ksql.schema.ksql.types.SqlMap;
import io.confluent.ksql.schema.ksql.types.SqlType;
Expand All @@ -64,6 +65,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.kafka.connect.data.Schema;

Expand Down Expand Up @@ -302,10 +304,32 @@ public Void visitIsNullPredicate(
@Override
public Void visitSearchedCaseExpression(
final SearchedCaseExpression node,
final ExpressionTypeContext expressionTypeContext
final ExpressionTypeContext context
) {
validateSearchedCaseExpression(node);
process(node.getWhenClauses().get(0).getResult(), expressionTypeContext);
final Optional<SqlType> whenType = validateWhenClauses(node.getWhenClauses(), context);

final Optional<SqlType> defaultType = node.getDefaultValue()
.map(ExpressionTypeManager.this::getExpressionSqlType);

if (whenType.isPresent() && defaultType.isPresent()) {
if (!whenType.get().equals(defaultType.get())) {
throw new KsqlException("Invalid Case expression. "
+ "Type for the default clause should be the same as for 'THEN' clauses."
+ System.lineSeparator()
+ "THEN type: " + whenType.get() + "."
+ System.lineSeparator()
+ "DEFAULT type: " + defaultType.get() + "."
);
}

context.setSqlType(whenType.get());
} else if (whenType.isPresent()) {
context.setSqlType(whenType.get());
} else if (defaultType.isPresent()) {
context.setSqlType(defaultType.get());
} else {
throw new KsqlException("Invalid Case expression. All case branches have NULL type");
}
return null;
}

Expand Down Expand Up @@ -449,38 +473,46 @@ public Void visitWhenClause(
throw VisitorUtil.illegalState(this, whenClause);
}

/**
 * Validates that all branches of a searched CASE expression agree on their result schema.
 *
 * <p>The result schema of the first WHEN clause is used as the reference: every WHEN clause
 * is checked against it via {@code validateWhenClause}, and the default value, if present,
 * must resolve to an equal schema.
 *
 * @param searchedCaseExpression the CASE expression to validate.
 * @throws KsqlException if the default clause's schema differs from the reference schema.
 */
private void validateSearchedCaseExpression(
final SearchedCaseExpression searchedCaseExpression) {
// Reference schema: the result of the first WHEN clause.
// NOTE(review): get(0) assumes at least one WHEN clause is present - confirm with the parser.
final Schema firstResultSchema = getExpressionSchema(
searchedCaseExpression.getWhenClauses().get(0).getResult());
// Check every WHEN clause (including the first) against the reference schema.
searchedCaseExpression.getWhenClauses()
.forEach(whenClause -> validateWhenClause(whenClause, firstResultSchema));
// Only a default value whose schema differs from the reference survives the filter.
searchedCaseExpression.getDefaultValue()
.map(ExpressionTypeManager.this::getExpressionSchema)
.filter(defaultSchema -> !firstResultSchema.equals(defaultSchema))
.ifPresent(badSchema -> {
throw new KsqlException("Invalid Case expression."
+ " Schema for the default clause should be the same as schema for THEN clauses."
+ " Result scheme: " + firstResultSchema + "."
+ " Schema for default expression is " + badSchema);
});
}

private void validateWhenClause(final WhenClause whenClause,
final Schema expectedResultSchema) {
final Schema operandSchema = getExpressionSchema(whenClause.getOperand());
if (!operandSchema.equals(Schema.OPTIONAL_BOOLEAN_SCHEMA)) {
throw new KsqlException("When operand schema should be boolean. Schema for ("
+ whenClause.getOperand() + ") is " + operandSchema);
}
final Schema resultSchema = getExpressionSchema(whenClause.getResult());
if (!expectedResultSchema.equals(resultSchema)) {
throw new KsqlException("Invalid Case expression."
+ " Schemas for 'THEN' clauses should be the same."
+ " Result schema: " + expectedResultSchema + "."
+ " Schema for THEN expression '" + whenClause + "'"
+ " is " + resultSchema);
/**
 * Validates the {@code WHEN ... THEN ...} clauses of a searched CASE expression and
 * determines their common result type.
 *
 * <p>Each operand must resolve to a BOOLEAN type. THEN results whose type resolves to
 * {@code null} (i.e. a NULL result) are skipped; every other result must have the same
 * type as the first non-null result encountered.
 *
 * @param whenClauses the WHEN clauses, in declaration order.
 * @param context the type context used to resolve the type of each sub-expression.
 * @return the type shared by all non-null THEN results, or empty if every result is NULL.
 * @throws KsqlException if an operand is not BOOLEAN, or the non-null THEN results differ.
 */
private Optional<SqlType> validateWhenClauses(
    final List<WhenClause> whenClauses,
    final ExpressionTypeContext context
) {
  Optional<SqlType> previousResult = Optional.empty();
  for (final WhenClause whenClause : whenClauses) {
    process(whenClause.getOperand(), context);

    final SqlType operandType = context.getSqlType();

    // A NULL-typed operand resolves to a null SqlType (just like a NULL result below), so
    // guard against it here: report it as a type error rather than failing with an NPE.
    if (operandType == null || operandType.baseType() != SqlBaseType.BOOLEAN) {
      throw new KsqlException("WHEN operand type should be boolean."
          + System.lineSeparator()
          + "Type for '" + whenClause.getOperand() + "' is " + operandType
      );
    }

    process(whenClause.getResult(), context);

    final SqlType resultType = context.getSqlType();
    if (resultType == null) {
      continue; // `null` type: does not constrain the CASE result type
    }

    if (!previousResult.isPresent()) {
      // First non-null result: becomes the reference type for all later branches.
      previousResult = Optional.of(resultType);
      continue;
    }

    if (!previousResult.get().equals(resultType)) {
      throw new KsqlException("Invalid Case expression. "
          + "Type for all 'THEN' clauses should be the same."
          + System.lineSeparator()
          + "THEN expression '" + whenClause + "' has type: " + resultType + "."
          + System.lineSeparator()
          + "Previous THEN expression(s) type: " + previousResult.get() + ".");
    }
  }

  return previousResult;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,11 @@ public void shouldFailIfWhenIsNotBoolean() {
Optional.empty()
);
expectedException.expect(KsqlException.class);
expectedException.expectMessage("When operand schema should be boolean. Schema for ((TEST1.COL0 + 10)) is Schema{INT64}");
expectedException.expectMessage(
"WHEN operand type should be boolean."
+ System.lineSeparator()
+ "Type for '(TEST1.COL0 + 10)' is BIGINT"
);

// When:
expressionTypeManager.getExpressionSqlType(expression);
Expand All @@ -444,7 +448,13 @@ public void shouldFailOnInconsistentWhenResultType() {
Optional.empty()
);
expectedException.expect(KsqlException.class);
expectedException.expectMessage("Invalid Case expression. Schemas for 'THEN' clauses should be the same. Result schema: Schema{STRING}. Schema for THEN expression 'WHEN (TEST1.COL0 = 10) THEN 10' is Schema{INT32}");
expectedException.expectMessage(
"Invalid Case expression. Type for all 'THEN' clauses should be the same."
+ System.lineSeparator()
+ "THEN expression 'WHEN (TEST1.COL0 = 10) THEN 10' has type: INTEGER."
+ System.lineSeparator()
+ "Previous THEN expression(s) type: STRING."
);

// When:
expressionTypeManager.getExpressionSqlType(expression);
Expand All @@ -463,7 +473,13 @@ public void shouldFailIfDefaultHasDifferentTypeToWhen() {
Optional.of(new BooleanLiteral("true"))
);
expectedException.expect(KsqlException.class);
expectedException.expectMessage("Invalid Case expression. Schema for the default clause should be the same as schema for THEN clauses. Result scheme: Schema{STRING}. Schema for default expression is Schema{BOOLEAN}");
expectedException.expectMessage(
"Invalid Case expression. Type for the default clause should be the same as for 'THEN' clauses."
+ System.lineSeparator()
+ "THEN type: STRING."
+ System.lineSeparator()
+ "DEFAULT type: BOOLEAN."
);

// When:
expressionTypeManager.getExpressionSqlType(expression);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
{
"name": "searched case with arithmetic expression in result",
"statements": [
"CREATE STREAM orders (orderid bigint, ORDERUNITS double) WITH (kafka_topic='test_topic', key='orderid', value_format='JSON');",
"CREATE STREAM orders (orderid bigint, ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits < 2.0 THEN orderid + 2 END AS case_resault FROM orders;"
],
"inputs": [
Expand All @@ -43,7 +43,7 @@
{
"name": "searched case with null in when",
"statements": [
"CREATE STREAM orders (orderid bigint, ORDERUNITS double) WITH (kafka_topic='test_topic', key='orderid', value_format='JSON');",
"CREATE STREAM orders (orderid bigint, ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits > 2.0 THEN 'foo' ELSE 'default' END AS case_resault FROM orders;"
],
"inputs": [
Expand All @@ -58,6 +58,66 @@
}
]
},
{
"name": "searched case returning null in first branch",
"statements": [
"CREATE STREAM orders (ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits < 2.0 THEN null WHEN orderunits < 4.0 THEN 'medium' ELSE 'large' END AS case_result FROM orders;"
],
"inputs": [
{"topic": "test_topic", "value": {"ORDERUNITS": 4.2}},
{"topic": "test_topic", "value": {"ORDERUNITS": 3.99}},
{"topic": "test_topic", "value": {"ORDERUNITS": 1.1}}
],
"outputs": [
{"topic": "S1", "value": {"CASE_RESULT": "large"}},
{"topic": "S1", "value": {"CASE_RESULT": "medium"}},
{"topic": "S1", "value": {"CASE_RESULT": null}}
]
},
{
"name": "searched case returning null in later branch",
"statements": [
"CREATE STREAM orders (ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits < 2.0 THEN 'small' WHEN orderunits < 4.0 THEN null ELSE 'large' END AS case_result FROM orders;"
],
"inputs": [
{"topic": "test_topic", "value": {"ORDERUNITS": 4.2}},
{"topic": "test_topic", "value": {"ORDERUNITS": 3.99}},
{"topic": "test_topic", "value": {"ORDERUNITS": 1.1}}
],
"outputs": [
{"topic": "S1", "value": {"CASE_RESULT": "large"}},
{"topic": "S1", "value": {"CASE_RESULT": null}},
{"topic": "S1", "value": {"CASE_RESULT": "small"}}
]
},
{
"name": "searched case returning null in default branch",
"statements": [
"CREATE STREAM orders (ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits < 2.0 THEN 'small' ELSE null END AS case_result FROM orders;"
],
"inputs": [
{"topic": "test_topic", "value": {"ORDERUNITS": 4.2}},
{"topic": "test_topic", "value": {"ORDERUNITS": 1.1}}
],
"outputs": [
{"topic": "S1", "value": {"CASE_RESULT": null}},
{"topic": "S1", "value": {"CASE_RESULT": "small"}}
]
},
{
"name": "searched case returning null in all branch",
"statements": [
"CREATE STREAM orders (ORDERUNITS double) WITH (kafka_topic='test_topic', value_format='JSON');",
"CREATE STREAM S1 AS SELECT CASE WHEN orderunits < 2.0 THEN null ELSE null END AS case_result FROM orders;"
],
"expectedException": {
"type": "io.confluent.ksql.util.KsqlStatementException",
"message": "Invalid Case expression. All case branches have NULL type"
}
},
{
"name": "searched case expression with structs, multiple expression and the same type",
"statements": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public <T> Optional<T> coerce(final Object value, final SqlType targetType) {
return coerceDecimal(value, (SqlDecimal) targetType);
}

if (!(value instanceof Number) || !valueSqlType.canUpCast(targetType.baseType())) {
if (!(value instanceof Number) || !valueSqlType.canImplicitlyCast(targetType.baseType())) {
return Optional.empty();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ private static boolean coercionShouldBeSupported(
// Handled by parsing the string to a decimal:
return true;
}
return fromBaseType.canUpCast(toBaseType);
return fromBaseType.canImplicitlyCast(toBaseType);
}

private static List<SqlBaseType> supportedTypes() {
Expand Down

0 comments on commit eb9e41b

Please sign in to comment.