allow schema to be merged on non struct root type
jackierwzhang committed Oct 2, 2024
1 parent 19c054b commit 2204f22
Showing 2 changed files with 87 additions and 17 deletions.
@@ -127,6 +127,9 @@ object SchemaMergingUtils {
}

/**
* A variant of [[mergeDataTypes]] with common default values that enforces struct types
* as inputs, intended for Delta table operations.
*
* Check whether we can write to the Delta table, which has `tableSchema`, using a query that has
* `dataSchema`. Our rules are that:
* - `dataSchema` may be missing columns or have additional columns
@@ -144,9 +147,29 @@ object SchemaMergingUtils {
*
* Schema merging occurs in a case insensitive manner. Hence, column names that only differ
* by case are not accepted in the `dataSchema`.
*
* @param tableSchema The current schema of the table.
* @param dataSchema The schema of the new data being written.
*/
def mergeSchemas(
tableSchema: StructType,
dataSchema: StructType,
allowImplicitConversions: Boolean = false,
keepExistingType: Boolean = false,
allowTypeWidening: Boolean = false,
caseSensitive: Boolean = false): StructType = {
checkColumnNameDuplication(dataSchema, "in the data to save", caseSensitive)
mergeDataTypes(
tableSchema,
dataSchema,
allowImplicitConversions,
keepExistingType,
allowTypeWidening,
caseSensitive,
allowOverride = false
).asInstanceOf[StructType]
}

/**
* @param current The current data type.
* @param update The data type of the new data being written.
* @param allowImplicitConversions Whether to allow Spark SQL implicit conversions. By default,
* we merge according to Parquet write compatibility - for
* example, an integer type data field will throw when merged to a
@@ -157,15 +180,16 @@ object SchemaMergingUtils {
* @param keepExistingType Whether to keep existing types instead of trying to merge types.
* @param caseSensitive Whether we should keep field mapping case-sensitively.
* This should default to false for Delta, which is case insensitive.
* @param allowOverride Whether to let incoming type override the existing type if unmatched.
*/
def mergeSchemas(
tableSchema: StructType,
dataSchema: StructType,
allowImplicitConversions: Boolean = false,
keepExistingType: Boolean = false,
allowTypeWidening: Boolean = false,
caseSensitive: Boolean = false): StructType = {
checkColumnNameDuplication(dataSchema, "in the data to save", caseSensitive)
def mergeDataTypes(
current: DataType,
update: DataType,
allowImplicitConversions: Boolean,
keepExistingType: Boolean,
allowTypeWidening: Boolean,
caseSensitive: Boolean,
allowOverride: Boolean): DataType = {
def merge(current: DataType, update: DataType): DataType = {
(current, update) match {
case (StructType(currentFields), StructType(updateFields)) =>
@@ -201,31 +225,32 @@
// Create the merged struct, the new fields are appended at the end of the struct.
StructType(updatedCurrentFields ++ newFields)
case (ArrayType(currentElementType, currentContainsNull),
ArrayType(updateElementType, _)) =>
ArrayType(
merge(currentElementType, updateElementType),
currentContainsNull)
case (MapType(currentKeyType, currentElementType, currentContainsNull),
MapType(updateKeyType, updateElementType, _)) =>
MapType(
merge(currentKeyType, updateKeyType),
merge(currentElementType, updateElementType),
currentContainsNull)

// Simply keeps the existing type for primitive types
case (current, update) if keepExistingType => current
case (current, _) if keepExistingType => current
// Let the incoming type override the existing type if unmatched
case (_, update) if allowOverride => update

case (current: AtomicType, update: AtomicType) if allowTypeWidening &&
TypeWidening.isTypeChangeSupportedForSchemaEvolution(current, update) => update

// If implicit conversions are allowed, that means we can use any valid implicit cast to
// perform the merge.
case (current, update)
if allowImplicitConversions && typeForImplicitCast(update, current).isDefined =>
typeForImplicitCast(update, current).get

case (DecimalType.Fixed(leftPrecision, leftScale),
DecimalType.Fixed(rightPrecision, rightScale)) =>
if ((leftPrecision == rightPrecision) && (leftScale == rightScale)) {
current
} else if ((leftPrecision != rightPrecision) && (leftScale != rightScale)) {
@@ -267,7 +292,7 @@
messageParameters = Array(current.toString, update.toString))
}
}
merge(tableSchema, dataSchema).asInstanceOf[StructType]
merge(current, update)
}

/**
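To make the new entry point concrete, here is a minimal, hypothetical sketch (not part of this commit) of merging two array-rooted schemas directly through mergeDataTypes, with all merge options left at their strict defaults. It assumes the delta-spark artifact is on the classpath and that SchemaMergingUtils is importable from the org.apache.spark.sql.delta.schema package.

import org.apache.spark.sql.types._
// Assumed import path for SchemaMergingUtils in the Delta source tree.
import org.apache.spark.sql.delta.schema.SchemaMergingUtils

object NonStructRootMergeSketch {
  def main(args: Array[String]): Unit = {
    // Two array-of-struct schemas whose element structs carry different columns.
    val current = ArrayType(new StructType().add("a", IntegerType))
    val update  = ArrayType(new StructType().add("b", IntegerType))

    // Merge at a non-struct root; no StructType wrapper is needed anymore.
    val merged = SchemaMergingUtils.mergeDataTypes(
      current,
      update,
      allowImplicitConversions = false,
      keepExistingType = false,
      allowTypeWidening = false,
      caseSensitive = false,
      allowOverride = false)

    // Expected result: array<struct<a:int,b:int>> -- new fields are appended last.
    println(merged.simpleString)
  }
}

mergeSchemas keeps its previous behaviour for struct roots, including the duplicate-column check, and now simply delegates to this function.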
@@ -2428,6 +2428,51 @@ class SchemaUtilsSuite extends QueryTest
testParquetUpcast()

}

test("schema merging non struct root type") {
// Array root type
val base1 = ArrayType(new StructType().add("a", IntegerType))
val update1 = ArrayType(new StructType().add("b", IntegerType))

assert(mergeDataTypes(
base1, update1, false, false, false, false, allowOverride = false) ===
ArrayType(new StructType().add("a", IntegerType).add("b", IntegerType)))

// Map root type
val base2 = MapType(
new StructType().add("a", IntegerType),
new StructType().add("b", IntegerType)
)
val update2 = MapType(
new StructType().add("b", IntegerType),
new StructType().add("c", IntegerType)
)

assert(mergeDataTypes(
base2, update2, false, false, false, false, allowOverride = false) ===
MapType(
new StructType().add("a", IntegerType).add("b", IntegerType),
new StructType().add("b", IntegerType).add("c", IntegerType)
))
}

test("schema merging allow override") {
// override root type
val base1 = new StructType().add("a", IntegerType)
val update1 = ArrayType(LongType)

assert(mergeDataTypes(
base1, update1, false, false, false, false, allowOverride = true) === ArrayType(LongType))

// override nested type
val base2 = ArrayType(new StructType().add("a", IntegerType).add("b", StringType))
val update2 = ArrayType(new StructType().add("a", MapType(StringType, StringType)))

assert(mergeDataTypes(
base2, update2, false, false, false, false, allowOverride = true) ===
ArrayType(new StructType().add("a", MapType(StringType, StringType)).add("b", StringType)))
}

////////////////////////////
// transformColumns
////////////////////////////
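Outside the test suite, the new allowOverride flag could be exercised the same way. The sketch below is hypothetical, under the same classpath and package assumptions as the earlier example: an unmatched nested type is replaced by the incoming one while untouched fields are preserved.

import org.apache.spark.sql.types._
import org.apache.spark.sql.delta.schema.SchemaMergingUtils  // assumed package, as above

object AllowOverrideSketch {
  def main(args: Array[String]): Unit = {
    val current = ArrayType(new StructType().add("a", IntegerType).add("b", StringType))
    val update  = ArrayType(new StructType().add("a", MapType(StringType, StringType)))

    // With allowOverride = true, the incoming type for `a` replaces IntegerType,
    // while `b`, which is absent from the update, is kept as-is.
    val overridden = SchemaMergingUtils.mergeDataTypes(
      current,
      update,
      allowImplicitConversions = false,
      keepExistingType = false,
      allowTypeWidening = false,
      caseSensitive = false,
      allowOverride = true)

    // Expected: array<struct<a:map<string,string>,b:string>>
    println(overridden.simpleString)
  }
}

With allowOverride left at false, the same int-to-map change would not be a legal merge and the call would be expected to fail with a type-merge error instead.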
