diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaMergingUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaMergingUtils.scala index fd7172603c0..6cda18b7e78 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaMergingUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/schema/SchemaMergingUtils.scala @@ -127,6 +127,9 @@ object SchemaMergingUtils { } /** + * A variant of [[mergeDataTypes]] with common default values that enforces struct types + * as inputs for Delta table operations. + * * Check whether we can write to the Delta table, which has `tableSchema`, using a query that has * `dataSchema`. Our rules are that: * - `dataSchema` may be missing columns or have additional columns @@ -144,9 +147,29 @@ object SchemaMergingUtils { * * Schema merging occurs in a case insensitive manner. Hence, column names that only differ * by case are not accepted in the `dataSchema`. - * - * @param tableSchema The current schema of the table. - * @param dataSchema The schema of the new data being written. + */ + def mergeSchemas( + tableSchema: StructType, + dataSchema: StructType, + allowImplicitConversions: Boolean = false, + keepExistingType: Boolean = false, + allowTypeWidening: Boolean = false, + caseSensitive: Boolean = false): StructType = { + checkColumnNameDuplication(dataSchema, "in the data to save", caseSensitive) + mergeDataTypes( + tableSchema, + dataSchema, + allowImplicitConversions, + keepExistingType, + allowTypeWidening, + caseSensitive, + allowOverride = false + ).asInstanceOf[StructType] + } + + /** + * @param current The current data type. + * @param update The data type of the new data being written. * @param allowImplicitConversions Whether to allow Spark SQL implicit conversions. 
By default, * we merge according to Parquet write compatibility - for * example, an integer type data field will throw when merged to a @@ -157,15 +180,16 @@ object SchemaMergingUtils { * @param keepExistingType Whether to keep existing types instead of trying to merge types. * @param caseSensitive Whether we should keep field mapping case-sensitively. * This should default to false for Delta, which is case insensitive. + * @param allowOverride Whether to let incoming type override the existing type if unmatched. */ - def mergeSchemas( - tableSchema: StructType, - dataSchema: StructType, - allowImplicitConversions: Boolean = false, - keepExistingType: Boolean = false, - allowTypeWidening: Boolean = false, - caseSensitive: Boolean = false): StructType = { - checkColumnNameDuplication(dataSchema, "in the data to save", caseSensitive) + def mergeDataTypes( + current: DataType, + update: DataType, + allowImplicitConversions: Boolean, + keepExistingType: Boolean, + allowTypeWidening: Boolean, + caseSensitive: Boolean, + allowOverride: Boolean): DataType = { def merge(current: DataType, update: DataType): DataType = { (current, update) match { case (StructType(currentFields), StructType(updateFields)) => @@ -201,19 +225,20 @@ object SchemaMergingUtils { // Create the merged struct, the new fields are appended at the end of the struct. 
StructType(updatedCurrentFields ++ newFields) case (ArrayType(currentElementType, currentContainsNull), - ArrayType(updateElementType, _)) => + ArrayType(updateElementType, _)) => ArrayType( merge(currentElementType, updateElementType), currentContainsNull) case (MapType(currentKeyType, currentElementType, currentContainsNull), - MapType(updateKeyType, updateElementType, _)) => + MapType(updateKeyType, updateElementType, _)) => MapType( merge(currentKeyType, updateKeyType), merge(currentElementType, updateElementType), currentContainsNull) // Simply keeps the existing type for primitive types - case (current, update) if keepExistingType => current + case (current, _) if keepExistingType => current + case (_, update) if allowOverride => update case (current: AtomicType, update: AtomicType) if allowTypeWidening && TypeWidening.isTypeChangeSupportedForSchemaEvolution(current, update) => update @@ -221,11 +246,11 @@ object SchemaMergingUtils { // If implicit conversions are allowed, that means we can use any valid implicit cast to // perform the merge. 
case (current, update) - if allowImplicitConversions && typeForImplicitCast(update, current).isDefined => + if allowImplicitConversions && typeForImplicitCast(update, current).isDefined => typeForImplicitCast(update, current).get case (DecimalType.Fixed(leftPrecision, leftScale), - DecimalType.Fixed(rightPrecision, rightScale)) => + DecimalType.Fixed(rightPrecision, rightScale)) => if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) { current } else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) { @@ -267,7 +292,7 @@ object SchemaMergingUtils { messageParameters = Array(current.toString, update.toString)) } } - merge(tableSchema, dataSchema).asInstanceOf[StructType] + merge(current, update) } /** diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala index a8fbfd51ff8..8a58fdd160d 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/schema/SchemaUtilsSuite.scala @@ -2428,6 +2428,51 @@ class SchemaUtilsSuite extends QueryTest testParquetUpcast() } + + test("schema merging non struct root type") { + // Array root type: the element structs are merged + val base1 = ArrayType(new StructType().add("a", IntegerType)) + val update1 = ArrayType(new StructType().add("b", IntegerType)) + + assert(mergeDataTypes( + base1, update1, false, false, false, false, allowOverride = false) === + ArrayType(new StructType().add("a", IntegerType).add("b", IntegerType))) + + // Map root type: the key structs and the value structs are merged separately + val base2 = MapType( + new StructType().add("a", IntegerType), + new StructType().add("b", IntegerType) + ) + val update2 = MapType( + new StructType().add("b", IntegerType), + new StructType().add("c", IntegerType) + ) + + assert(mergeDataTypes( + base2, update2, false, false, false, false, allowOverride = false) === + MapType( + new StructType().add("a", IntegerType).add("b", IntegerType), 
+ new StructType().add("b", IntegerType).add("c", IntegerType) + )) + } + + test("schema merging allow override") { + // Override root type: the struct is replaced by the incoming array + val base1 = new StructType().add("a", IntegerType) + val update1 = ArrayType(LongType) + + assert(mergeDataTypes( + base1, update1, false, false, false, false, allowOverride = true) === ArrayType(LongType)) + + // Override nested type: field `a` is replaced by the map, field `b` is kept + val base2 = ArrayType(new StructType().add("a", IntegerType).add("b", StringType)) + val update2 = ArrayType(new StructType().add("a", MapType(StringType, StringType))) + + assert(mergeDataTypes( + base2, update2, false, false, false, false, allowOverride = true) === + ArrayType(new StructType().add("a", MapType(StringType, StringType)).add("b", StringType))) + } + //////////////////////////// // transformColumns ////////////////////////////