From 099a18e1b36f12c3a61f1267feebd82a72a65b85 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Mon, 17 May 2021 09:15:49 +0200 Subject: [PATCH] [FLINK-22666][table] Make structured type's fields more lenient during casting Compare children individually for anonymous structured types. This fixes issues with primitive fields and Scala case classes. This closes #15935. --- .../types/logical/utils/LogicalTypeCasts.java | 58 ++++++++- .../types/LogicalTypeCastAvoidanceTest.java | 42 ++++++- .../table/types/LogicalTypeCastsTest.java | 36 ++++++ .../stream/sql/DataStreamScalaITCase.scala | 119 ++++++++++++++++++ 4 files changed, 250 insertions(+), 5 deletions(-) create mode 100644 flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DataStreamScalaITCase.scala diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java index 8c37f56a745f1..d9591c442b24a 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeCasts.java @@ -35,6 +35,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.BiFunction; +import java.util.stream.Collectors; import static org.apache.flink.table.types.logical.LogicalTypeFamily.BINARY_STRING; import static org.apache.flink.table.types.logical.LogicalTypeFamily.CHARACTER_STRING; @@ -315,8 +317,8 @@ private static boolean supportsCasting( } else if (hasFamily(sourceType, CONSTRUCTED) || hasFamily(targetType, CONSTRUCTED)) { return supportsConstructedCasting(sourceType, targetType, allowExplicit); } else if (sourceRoot == STRUCTURED_TYPE || targetRoot == STRUCTURED_TYPE) { - // inheritance is not supported yet, so 
structured type must be fully equal
-            return false;
+            return supportsStructuredCasting(
+                    sourceType, targetType, (s, t) -> supportsCasting(s, t, allowExplicit));
         } else if (sourceRoot == RAW || targetRoot == RAW) {
             // the two raw types are not equal (from initial invariant), casting is not possible
             return false;
@@ -334,6 +336,51 @@ private static boolean supportsCasting(
         return false;
     }
 
+    private static boolean supportsStructuredCasting(
+            LogicalType sourceType,
+            LogicalType targetType,
+            BiFunction<LogicalType, LogicalType, Boolean> childPredicate) {
+        final LogicalTypeRoot sourceRoot = sourceType.getTypeRoot();
+        final LogicalTypeRoot targetRoot = targetType.getTypeRoot();
+        if (sourceRoot != STRUCTURED_TYPE || targetRoot != STRUCTURED_TYPE) {
+            return false;
+        }
+        final StructuredType sourceStructuredType = (StructuredType) sourceType;
+        final StructuredType targetStructuredType = (StructuredType) targetType;
+        // non-anonymous structured types must be fully equal
+        if (sourceStructuredType.getObjectIdentifier().isPresent()
+                || targetStructuredType.getObjectIdentifier().isPresent()) {
+            return false;
+        }
+        // for anonymous structured types we are a bit more lenient, if they provide similar fields
+        // e.g. this is necessary when structured types derived from type information and
+        // structured types derived within Table API are slightly different
+        final Class<?> sourceClass = sourceStructuredType.getImplementationClass().orElse(null);
+        final Class<?> targetClass = targetStructuredType.getImplementationClass().orElse(null);
+        if (sourceClass != targetClass) {
+            return false;
+        }
+        final List<String> sourceNames =
+                sourceStructuredType.getAttributes().stream()
+                        .map(StructuredType.StructuredAttribute::getName)
+                        .collect(Collectors.toList());
+        final List<String> targetNames =
+                targetStructuredType.getAttributes().stream()
+                        .map(StructuredType.StructuredAttribute::getName)
+                        .collect(Collectors.toList());
+        if (!sourceNames.equals(targetNames)) { // attribute names must match pairwise, in order
+            return false;
+        }
+        final List<LogicalType> sourceChildren = sourceType.getChildren();
+        final List<LogicalType> targetChildren = targetType.getChildren();
+        for (int i = 0; i < sourceChildren.size(); i++) {
+            if (!childPredicate.apply(sourceChildren.get(i), targetChildren.get(i))) {
+                return false;
+            }
+        }
+        return true;
+    }
+
     private static boolean supportsConstructedCasting(
             LogicalType sourceType, LogicalType targetType, boolean allowExplicit) {
         final LogicalTypeRoot sourceRoot = sourceType.getTypeRoot();
@@ -493,8 +540,11 @@ public Boolean visit(StructuredType targetType) {
                 final List<LogicalType> targetChildren = targetType.getChildren();
                 return supportsAvoidingCast(sourceChildren, targetChildren);
             }
-            // structured types should be equal (modulo nullability)
-            return sourceType.equals(targetType) || sourceType.copy(true).equals(targetType);
+            if (sourceType.equals(targetType) || sourceType.copy(true).equals(targetType)) {
+                return true;
+            }
+            return supportsStructuredCasting(
+                    sourceType, targetType, LogicalTypeCasts::supportsAvoidingCast);
         }
 
         @Override
diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastAvoidanceTest.java
b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastAvoidanceTest.java index a994f22af056c..470b7208ce7d0 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastAvoidanceTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastAvoidanceTest.java @@ -230,7 +230,7 @@ public static List testData() { true }, - // row and structure type + // row and structured type { RowType.of(new IntType(), new VarCharType()), createUserType("User2", new IntType(), new VarCharType()), @@ -251,6 +251,46 @@ public static List testData() { RowType.of(new BigIntType(), new VarCharType()), false }, + + // test slightly different children of anonymous structured types + { + StructuredType.newBuilder(Void.class) + .attributes( + Arrays.asList( + new StructuredType.StructuredAttribute( + "f1", new TimestampType()), + new StructuredType.StructuredAttribute( + "diff", new TinyIntType(false)))) + .build(), + StructuredType.newBuilder(Void.class) + .attributes( + Arrays.asList( + new StructuredType.StructuredAttribute( + "f1", new TimestampType()), + new StructuredType.StructuredAttribute( + "diff", new TinyIntType(true)))) + .build(), + true + }, + { + StructuredType.newBuilder(Void.class) + .attributes( + Arrays.asList( + new StructuredType.StructuredAttribute( + "f1", new TimestampType()), + new StructuredType.StructuredAttribute( + "diff", new TinyIntType(true)))) + .build(), + StructuredType.newBuilder(Void.class) + .attributes( + Arrays.asList( + new StructuredType.StructuredAttribute( + "f1", new TimestampType()), + new StructuredType.StructuredAttribute( + "diff", new TinyIntType(false)))) + .build(), + false + } }); } diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java index 
e908e6c729bf0..64eade90d028b 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/LogicalTypeCastsTest.java @@ -218,6 +218,42 @@ public static List testData() { false, true }, + + // test slightly different children of anonymous structured types + { + StructuredType.newBuilder(Void.class) + .attributes( + Arrays.asList( + new StructuredAttribute("f1", new TimestampType()), + new StructuredAttribute( + "diff", new TinyIntType(false)))) + .build(), + StructuredType.newBuilder(Void.class) + .attributes( + Arrays.asList( + new StructuredAttribute("f1", new TimestampType()), + new StructuredAttribute( + "diff", new TinyIntType(true)))) + .build(), + true, + true + }, + { + StructuredType.newBuilder(Void.class) + .attributes( + Arrays.asList( + new StructuredAttribute("f1", new TimestampType()), + new StructuredAttribute("diff", new IntType()))) + .build(), + StructuredType.newBuilder(Void.class) + .attributes( + Arrays.asList( + new StructuredAttribute("f1", new TimestampType()), + new StructuredAttribute("diff", new TinyIntType()))) + .build(), + false, + true + } }); } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DataStreamScalaITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DataStreamScalaITCase.scala new file mode 100644 index 0000000000000..9cc1ce8e66fd6 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/DataStreamScalaITCase.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.runtime.stream.sql + +import org.apache.flink.streaming.api.scala.{CloseableIterator, DataStream, StreamExecutionEnvironment} +import org.apache.flink.table.api.bridge.scala.StreamTableEnvironment +import org.apache.flink.test.util.AbstractTestBase +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.{DataTypes, Table, TableResult} +import org.apache.flink.table.catalog.{Column, ResolvedSchema} +import org.apache.flink.table.planner.runtime.stream.sql.DataStreamScalaITCase.{ComplexCaseClass, ImmutableCaseClass} +import org.apache.flink.types.Row +import org.apache.flink.util.CollectionUtil + +import org.hamcrest.Matchers.containsInAnyOrder +import org.junit.Assert.{assertEquals, assertThat} +import org.junit.{Before, Test} + +import java.util +import scala.collection.JavaConverters._ + +/** Tests for connecting to the Scala [[DataStream]] API. 
*/
+class DataStreamScalaITCase extends AbstractTestBase {
+
+  private var env: StreamExecutionEnvironment = _
+
+  private var tableEnv: StreamTableEnvironment = _
+
+  @Before
+  def before(): Unit = {
+    env = StreamExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(4)
+    tableEnv = StreamTableEnvironment.create(env)
+  }
+
+  @Test
+  def testFromAndToDataStreamWithCaseClass(): Unit = {
+    val caseClasses = Array(
+      ComplexCaseClass(42, "hello", ImmutableCaseClass(42.0, b = true)),
+      ComplexCaseClass(42, null, ImmutableCaseClass(42.0, b = false)))
+
+    val dataStream = env.fromElements(caseClasses: _*)
+
+    val table = tableEnv.fromDataStream(dataStream)
+
+    testSchema(
+      table,
+      Column.physical("c", DataTypes.INT().notNull().bridgedTo(classOf[Int])),
+      Column.physical("a", DataTypes.STRING()),
+      Column.physical(
+        "p",
+        DataTypes.STRUCTURED(
+          classOf[ImmutableCaseClass],
+          DataTypes.FIELD(
+            "d",
+            DataTypes.DOUBLE().notNull()), // serializer doesn't support null
+          DataTypes.FIELD(
+            "b",
+            DataTypes.BOOLEAN().notNull().bridgedTo(classOf[Boolean]))).notNull()))
+
+    testResult(
+      table.execute(),
+      Row.of(Int.box(42), "hello", ImmutableCaseClass(42.0, b = true)),
+      Row.of(Int.box(42), null, ImmutableCaseClass(42.0, b = false)))
+
+    val resultStream = tableEnv.toDataStream(table, classOf[ComplexCaseClass])
+
+    testResult(resultStream, caseClasses: _*) // round trip must return the input elements
+  }
+
+  // --------------------------------------------------------------------------------------------
+  // Helper methods
+  // --------------------------------------------------------------------------------------------
+
+  private def testSchema(table: Table, expectedColumns: Column*): Unit = {
+    assertEquals(ResolvedSchema.of(expectedColumns: _*), table.getResolvedSchema)
+  }
+
+  private def testResult(result: TableResult, expectedRows: Row*): Unit = {
+    val actualRows: util.List[Row] = CollectionUtil.iteratorToList(result.collect)
+    assertThat(actualRows, containsInAnyOrder(expectedRows: _*)) // order-insensitive
+  }
+
+  private def testResult[T](dataStream: DataStream[T], expectedResult: T*): Unit = {
+    var iterator: CloseableIterator[T] = null
+    try {
+      iterator = dataStream.executeAndCollect()
+      val list: util.List[T] = iterator.toList.asJava
+      assertThat(list, containsInAnyOrder(expectedResult: _*)) // order-insensitive
+    } finally { // close the iterator even when the assertion fails
+      if (iterator != null) { // stays null when executeAndCollect() throws
+        iterator.close()
+      }
+    }
+  }
+}
+
+object DataStreamScalaITCase {
+
+  case class ComplexCaseClass(var c: Int, var a: String, var p: ImmutableCaseClass)
+
+  case class ImmutableCaseClass(d: java.lang.Double, b: Boolean) // val-only fields
+}