Skip to content

Commit

Permalink
[FLINK-22666][table] Make structured type's fields more lenient durin…
Browse files Browse the repository at this point in the history
…g casting

Compare children individually for anonymous structured types. This
fixes issues with primitive fields and Scala case classes.

This closes #15935.
  • Loading branch information
twalthr committed May 17, 2021
1 parent 5eebab4 commit 099a18e
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

import static org.apache.flink.table.types.logical.LogicalTypeFamily.BINARY_STRING;
import static org.apache.flink.table.types.logical.LogicalTypeFamily.CHARACTER_STRING;
Expand Down Expand Up @@ -315,8 +317,8 @@ private static boolean supportsCasting(
} else if (hasFamily(sourceType, CONSTRUCTED) || hasFamily(targetType, CONSTRUCTED)) {
return supportsConstructedCasting(sourceType, targetType, allowExplicit);
} else if (sourceRoot == STRUCTURED_TYPE || targetRoot == STRUCTURED_TYPE) {
// inheritance is not supported yet, so structured type must be fully equal
return false;
return supportsStructuredCasting(
sourceType, targetType, (s, t) -> supportsCasting(s, t, allowExplicit));
} else if (sourceRoot == RAW || targetRoot == RAW) {
// the two raw types are not equal (from initial invariant), casting is not possible
return false;
Expand All @@ -334,6 +336,51 @@ private static boolean supportsCasting(
return false;
}

private static boolean supportsStructuredCasting(
LogicalType sourceType,
LogicalType targetType,
BiFunction<LogicalType, LogicalType, Boolean> childPredicate) {
final LogicalTypeRoot sourceRoot = sourceType.getTypeRoot();
final LogicalTypeRoot targetRoot = targetType.getTypeRoot();
if (sourceRoot != STRUCTURED_TYPE || targetRoot != STRUCTURED_TYPE) {
return false;
}
final StructuredType sourceStructuredType = (StructuredType) sourceType;
final StructuredType targetStructuredType = (StructuredType) targetType;
// non-anonymous structured types must be fully equal
if (sourceStructuredType.getObjectIdentifier().isPresent()
|| targetStructuredType.getObjectIdentifier().isPresent()) {
return false;
}
// for anonymous structured types we are a bit more lenient, if they provide similar fields
// e.g. this is necessary when structured types derived from type information and
// structured types derived within Table API are slightly different
final Class<?> sourceClass = sourceStructuredType.getImplementationClass().orElse(null);
final Class<?> targetClass = targetStructuredType.getImplementationClass().orElse(null);
if (sourceClass != targetClass) {
return false;
}
final List<String> sourceNames =
sourceStructuredType.getAttributes().stream()
.map(StructuredType.StructuredAttribute::getName)
.collect(Collectors.toList());
final List<String> targetNames =
sourceStructuredType.getAttributes().stream()
.map(StructuredType.StructuredAttribute::getName)
.collect(Collectors.toList());
if (!sourceNames.equals(targetNames)) {
return false;
}
final List<LogicalType> sourceChildren = sourceType.getChildren();
final List<LogicalType> targetChildren = targetType.getChildren();
for (int i = 0; i < sourceChildren.size(); i++) {
if (!childPredicate.apply(sourceChildren.get(i), targetChildren.get(i))) {
return false;
}
}
return true;
}

private static boolean supportsConstructedCasting(
LogicalType sourceType, LogicalType targetType, boolean allowExplicit) {
final LogicalTypeRoot sourceRoot = sourceType.getTypeRoot();
Expand Down Expand Up @@ -493,8 +540,11 @@ public Boolean visit(StructuredType targetType) {
final List<LogicalType> targetChildren = targetType.getChildren();
return supportsAvoidingCast(sourceChildren, targetChildren);
}
// structured types should be equal (modulo nullability)
return sourceType.equals(targetType) || sourceType.copy(true).equals(targetType);
if (sourceType.equals(targetType) || sourceType.copy(true).equals(targetType)) {
return true;
}
return supportsStructuredCasting(
sourceType, targetType, LogicalTypeCasts::supportsAvoidingCast);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ public static List<Object[]> testData() {
true
},

// row and structure type
// row and structured type
{
RowType.of(new IntType(), new VarCharType()),
createUserType("User2", new IntType(), new VarCharType()),
Expand All @@ -251,6 +251,46 @@ public static List<Object[]> testData() {
RowType.of(new BigIntType(), new VarCharType()),
false
},

// test slightly different children of anonymous structured types
{
StructuredType.newBuilder(Void.class)
.attributes(
Arrays.asList(
new StructuredType.StructuredAttribute(
"f1", new TimestampType()),
new StructuredType.StructuredAttribute(
"diff", new TinyIntType(false))))
.build(),
StructuredType.newBuilder(Void.class)
.attributes(
Arrays.asList(
new StructuredType.StructuredAttribute(
"f1", new TimestampType()),
new StructuredType.StructuredAttribute(
"diff", new TinyIntType(true))))
.build(),
true
},
{
StructuredType.newBuilder(Void.class)
.attributes(
Arrays.asList(
new StructuredType.StructuredAttribute(
"f1", new TimestampType()),
new StructuredType.StructuredAttribute(
"diff", new TinyIntType(true))))
.build(),
StructuredType.newBuilder(Void.class)
.attributes(
Arrays.asList(
new StructuredType.StructuredAttribute(
"f1", new TimestampType()),
new StructuredType.StructuredAttribute(
"diff", new TinyIntType(false))))
.build(),
false
}
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,42 @@ public static List<Object[]> testData() {
false,
true
},

// test slightly different children of anonymous structured types
{
StructuredType.newBuilder(Void.class)
.attributes(
Arrays.asList(
new StructuredAttribute("f1", new TimestampType()),
new StructuredAttribute(
"diff", new TinyIntType(false))))
.build(),
StructuredType.newBuilder(Void.class)
.attributes(
Arrays.asList(
new StructuredAttribute("f1", new TimestampType()),
new StructuredAttribute(
"diff", new TinyIntType(true))))
.build(),
true,
true
},
{
StructuredType.newBuilder(Void.class)
.attributes(
Arrays.asList(
new StructuredAttribute("f1", new TimestampType()),
new StructuredAttribute("diff", new IntType())))
.build(),
StructuredType.newBuilder(Void.class)
.attributes(
Arrays.asList(
new StructuredAttribute("f1", new TimestampType()),
new StructuredAttribute("diff", new TinyIntType())))
.build(),
false,
true
}
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.runtime.stream.sql

import org.apache.flink.streaming.api.scala.{CloseableIterator, DataStream, StreamExecutionEnvironment}
import org.apache.flink.table.api.bridge.scala.StreamTableEnvironment
import org.apache.flink.test.util.AbstractTestBase
import org.apache.flink.api.scala._
import org.apache.flink.table.api.{DataTypes, Table, TableResult}
import org.apache.flink.table.catalog.{Column, ResolvedSchema}
import org.apache.flink.table.planner.runtime.stream.sql.DataStreamScalaITCase.{ComplexCaseClass, ImmutableCaseClass}
import org.apache.flink.types.Row
import org.apache.flink.util.CollectionUtil

import org.hamcrest.Matchers.containsInAnyOrder
import org.junit.Assert.{assertEquals, assertThat}
import org.junit.{Before, Test}

import java.util
import scala.collection.JavaConverters._

/** Tests for connecting to the Scala [[DataStream]] API. */
class DataStreamScalaITCase extends AbstractTestBase {

private var env: StreamExecutionEnvironment = _

private var tableEnv: StreamTableEnvironment = _

@Before
def before(): Unit = {
env = StreamExecutionEnvironment.getExecutionEnvironment
env.setParallelism(4)
tableEnv = StreamTableEnvironment.create(env)
}

@Test
def testFromAndToDataStreamWithCaseClass(): Unit = {
val caseClasses = Array(
ComplexCaseClass(42, "hello", ImmutableCaseClass(42.0, b = true)),
ComplexCaseClass(42, null, ImmutableCaseClass(42.0, b = false)))

val dataStream = env.fromElements(caseClasses: _*)

val table = tableEnv.fromDataStream(dataStream)

testSchema(
table,
Column.physical("c", DataTypes.INT().notNull().bridgedTo(classOf[Int])),
Column.physical("a", DataTypes.STRING()),
Column.physical(
"p",
DataTypes.STRUCTURED(
classOf[ImmutableCaseClass],
DataTypes.FIELD(
"d",
DataTypes.DOUBLE().notNull()), // serializer doesn't support null
DataTypes.FIELD(
"b",
DataTypes.BOOLEAN().notNull().bridgedTo(classOf[Boolean]))).notNull()))

testResult(
table.execute(),
Row.of(Int.box(42), "hello", ImmutableCaseClass(42.0, b = true)),
Row.of(Int.box(42), null, ImmutableCaseClass(42.0, b = false)))

val resultStream = tableEnv.toDataStream(table, classOf[ComplexCaseClass])

testResult(resultStream, caseClasses: _*)
}

// --------------------------------------------------------------------------------------------
// Helper methods
// --------------------------------------------------------------------------------------------

private def testSchema(table: Table, expectedColumns: Column*): Unit = {
assertEquals(ResolvedSchema.of(expectedColumns: _*), table.getResolvedSchema)
}

private def testResult(result: TableResult, expectedRows: Row*): Unit = {
val actualRows: util.List[Row] = CollectionUtil.iteratorToList(result.collect)
assertThat(actualRows, containsInAnyOrder(expectedRows: _*))
}

private def testResult[T](dataStream: DataStream[T], expectedResult: T*): Unit = {
var iterator: CloseableIterator[T] = null
try {
iterator = dataStream.executeAndCollect()
val list: util.List[T] = iterator.toList.asJava
assertThat(list, containsInAnyOrder(expectedResult: _*))
} finally {
if (iterator != null) {
iterator.close()
}
}
}
}

object DataStreamScalaITCase {

case class ComplexCaseClass(var c: Int, var a: String, var p: ImmutableCaseClass)

case class ImmutableCaseClass(d: java.lang.Double, b: Boolean)
}

0 comments on commit 099a18e

Please sign in to comment.