[SPARK-51290][SQL] Enable filling default values in DSv2 writes #50044

Status: Open. Wants to merge 2 commits into base: master.
@@ -214,6 +214,10 @@ case class StructField(
}
}

private[sql] def hasExistenceDefaultValue: Boolean = {
metadata.contains(EXISTS_DEFAULT_COLUMN_METADATA_KEY)
}

private def getDDLDefault = getCurrentDefaultValue()
.map(" DEFAULT " + _)
.getOrElse("")
@@ -3534,7 +3534,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
TableOutputResolver.suitableForByNameCheck(v2Write.isByName,
expected = v2Write.table.output, queryOutput = v2Write.query.output)
val projection = TableOutputResolver.resolveOutputColumns(
- v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf)
+ v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf,
+ supportColDefaultValue = true)
Contributor Author: I don't think there is value in validating if the catalog defines SUPPORT_COLUMN_DEFAULT_VALUE in capabilities during writes. If a connector includes default value metadata in its schema, it should be enough to fill default values. The flag exists for ALTER and CREATE/REPLACE statements.

Contributor: Yeah, true. Spark fills the default values during table writing, and it works for all catalogs.

Member: You mean whether supportColDefaultValue is true or false doesn't matter for v2 here?

Member: Oh, I see. You mean checking the SUPPORT_COLUMN_DEFAULT_VALUE flag here for the catalog.

Contributor Author: Correct, I don't see value in checking SUPPORT_COLUMN_DEFAULT_VALUE here.
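For illustration only, not part of this diff: a rough sketch of the kind of schema a connector could report so that Spark can fill defaults during DSv2 writes without any capability flag. The EXISTS_DEFAULT / CURRENT_DEFAULT metadata keys are the standard column-default keys; the column names and default values here are hypothetical.

    import org.apache.spark.sql.types._

    // Default expressions are carried as SQL text in the column metadata.
    val txtMetadata = new MetadataBuilder()
      .putString("CURRENT_DEFAULT", "'initial-text'") // filled into new rows that omit the column
      .putString("EXISTS_DEFAULT", "'initial-text'")  // assumed for rows written before the column existed
      .build()

    val reportedSchema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("dep", StringType),
      StructField("txt", StringType, nullable = true, metadata = txtMetadata)))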

if (projection != v2Write.query) {
val cleanedTable = v2Write.table match {
case r: DataSourceV2Relation =>
@@ -80,7 +80,6 @@ object TableOutputResolver extends SQLConfHelper with Logging {
query: LogicalPlan,
byName: Boolean,
conf: SQLConf,
- // TODO: Only DS v1 writing will set it to true. We should enable in for DS v2 as well.
supportColDefaultValue: Boolean = false): LogicalPlan = {

if (expected.size < query.output.size) {
@@ -459,15 +459,17 @@ object ResolveDefaultColumns extends QueryErrorsBase
* Any type suitable for assigning into a row using the InternalRow.update method.
*/
def getExistenceDefaultValues(schema: StructType): Array[Any] = {
- schema.fields.map { field: StructField =>
- val defaultValue: Option[String] = field.getExistenceDefaultValue()
- defaultValue.map { _: String =>
- val expr = analyzeExistenceDefaultValue(field)
-
- // The expression should be a literal value by this point, possibly wrapped in a cast
- // function. This is enforced by the execution of commands that assign default values.
- expr.eval()
- }.orNull
+ schema.fields.map(getExistenceDefaultValue)
}

+ def getExistenceDefaultValue(field: StructField): Any = {
+ if (field.hasExistenceDefaultValue) {
+ val expr = analyzeExistenceDefaultValue(field)
+ // The expression should be a literal value by this point, possibly wrapped in a cast
+ // function. This is enforced by the execution of commands that assign default values.
+ expr.eval()
+ } else {
+ null
+ }
+ }
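For illustration only, not part of this diff: a hedged sketch of what the extracted helper evaluates to, assuming one field that carries EXISTS_DEFAULT metadata and one that does not (field names and the default value are hypothetical).

    import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
    import org.apache.spark.sql.types._

    val withDefault = StructField("txt", StringType,
      metadata = new MetadataBuilder().putString("EXISTS_DEFAULT", "'initial-text'").build())
    val withoutDefault = StructField("id", IntegerType)

    // Evaluates the stored default expression to a literal (a UTF8String for a string column).
    ResolveDefaultColumns.getExistenceDefaultValue(withDefault)
    // No EXISTS_DEFAULT metadata, so the helper returns null.
    ResolveDefaultColumns.getExistenceDefaultValue(withoutDefault)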

@@ -423,8 +423,8 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
assertNotResolved(parsedPlan)
assertAnalysisErrorCondition(
parsedPlan,
- expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
+ expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
Contributor Author: This is because of spark.sql.defaultColumn.useNullsForMissingDefaultValues and is aligned with V1 writes.

+ expectedMessageParameters = Map("tableName" -> "`table-name`", "extraColumns" -> "`a`, `b`")
)
}
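For illustration only, not part of this diff: a hedged sketch, written in the style of this suite, of the two behaviors controlled by the conf mentioned in the comment above (the assertions are illustrative, not copied from the suite).

    withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") {
      // With the conf enabled (the default), a by-name write that omits target column `x`
      // fills it with NULL or its default, so only unmatched query columns surface as
      // INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS.
    }
    withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") {
      // With the conf disabled, the same write still fails with
      // INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA for `x`, as a later test in this
      // suite pins down under withSQLConf.
    }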

@@ -438,8 +438,8 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {
assertNotResolved(parsedPlan)
assertAnalysisErrorCondition(
parsedPlan,
- expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
+ expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
+ expectedMessageParameters = Map("tableName" -> "`table-name`", "extraColumns" -> "`X`")
)
}

@@ -513,12 +513,14 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest {

val parsedPlan = byName(table, query)

- assertNotResolved(parsedPlan)
- assertAnalysisErrorCondition(
- parsedPlan,
- expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
- )
+ withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") {
+ assertNotResolved(parsedPlan)
+ assertAnalysisErrorCondition(
+ parsedPlan,
+ expectedErrorCondition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+ expectedMessageParameters = Map("tableName" -> "`table-name`", "colName" -> "`x`")
+ )
+ }
}

test("byName: insert safe cast") {
@@ -28,7 +28,7 @@ import com.google.common.base.Objects

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow, MetadataStructFieldWithLogicalName}
- import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
+ import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions._
import org.apache.spark.sql.connector.metric.{CustomMetric, CustomSumMetric, CustomTaskMetric}
@@ -141,7 +141,8 @@ abstract class InMemoryBaseTable(
schema: StructType,
row: InternalRow): (Any, DataType) = {
val index = schema.fieldIndex(fieldNames(0))
- val value = row.toSeq(schema).apply(index)
+ val field = schema(index)
+ val value = row.get(index, field.dataType)
if (fieldNames.length > 1) {
(value, schema(index).dataType) match {
case (row: InternalRow, nestedSchema: StructType) =>
@@ -400,18 +401,23 @@ abstract class InMemoryBaseTable(
val sizeInBytes = numRows * rowSizeInBytes

val numOfCols = tableSchema.fields.length
- val dataTypes = tableSchema.fields.map(_.dataType)
- val colValueSets = new Array[util.HashSet[Object]](numOfCols)
+ val colValueSets = new Array[util.HashSet[Any]](numOfCols)
val numOfNulls = new Array[Long](numOfCols)
for (i <- 0 until numOfCols) {
- colValueSets(i) = new util.HashSet[Object]
+ colValueSets(i) = new util.HashSet[Any]
}

inputPartitions.foreach(inputPartition =>
inputPartition.rows.foreach(row =>
for (i <- 0 until numOfCols) {
- colValueSets(i).add(row.get(i, dataTypes(i)))
- if (row.isNullAt(i)) {
+ val field = tableSchema(i)
+ val colValue = if (i < row.numFields) {
+ row.get(i, field.dataType)
+ } else {
+ ResolveDefaultColumns.getExistenceDefaultValue(field)
+ }
+ colValueSets(i).add(colValue)
+ if (colValue == null) {
numOfNulls(i) += 1
}
}
@@ -718,6 +724,11 @@ private class BufferedRowsReader(
schema: StructType,
row: InternalRow): Any = {
val index = schema.fieldIndex(field.name)

if (index >= row.numFields) {
Contributor Author: This is needed to support adding columns with default values at the end of the schema.

Member: This is the extractFieldValue method, and it looks like it is only used by get. Why is this needed for adding columns?

Contributor Author: This is needed to read data that was inserted prior to adding columns to the schema. In that case, the schema has extra columns relative to the stored rows, and we have to fill the new columns using the existence default value.

return ResolveDefaultColumns.getExistenceDefaultValue(field)
}
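// Illustration only, not part of this diff (values are hypothetical): a row buffered
// before `txt` was added has fewer fields than the current read schema, so the check
// above falls back to the column's existence default instead of reading the row.
//
//   read schema: id INT, txt STRING with EXISTS_DEFAULT = 'initial-text'
//   stored row:  InternalRow(1)                  // numFields == 1
//   fieldIndex("txt") == 1 >= row.numFields      // branch above taken
//   => ResolveDefaultColumns.getExistenceDefaultValue(txtField) yields "initial-text"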

field.dataType match {
case StructType(fields) =>
if (row.isNullAt(index)) {
@@ -122,7 +122,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
override def alterTable(ident: Identifier, changes: TableChange*): Table = {
val table = loadTable(ident).asInstanceOf[InMemoryTable]
val properties = CatalogV2Util.applyPropertiesChanges(table.properties, changes)
- val schema = CatalogV2Util.applySchemaChanges(table.schema, changes, None, "ALTER TABLE")
+ val schema = CatalogV2Util.applySchemaChanges(table.schema, changes, Some(name), "ALTER TABLE")
Member: Do we need to add the in-memory table provider to DEFAULT_COLUMN_ALLOWED_PROVIDERS?

Member: Oh, never mind, the name is given when initializing the catalog.

Contributor Author: Correct.

val finalPartitioning = CatalogV2Util.applyClusterByChanges(table.partitioning, schema, changes)

// fail if the last column in the schema was dropped
@@ -146,8 +146,8 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
exception = intercept[AnalysisException] {
spark.table("source").withColumnRenamed("data", "d").writeTo("testcat.table_name").append()
},
- condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- parameters = Map("tableName" -> "`testcat`.`table_name`", "colName" -> "`data`")
+ condition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
+ parameters = Map("tableName" -> "`testcat`.`table_name`", "extraColumns" -> "`d`")
)

checkAnswer(
@@ -251,8 +251,8 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.table("source").withColumnRenamed("data", "d")
.writeTo("testcat.table_name").overwrite(lit(true))
},
- condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- parameters = Map("tableName" -> "`testcat`.`table_name`", "colName" -> "`data`")
+ condition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
+ parameters = Map("tableName" -> "`testcat`.`table_name`", "extraColumns" -> "`d`")
)

checkAnswer(
@@ -356,8 +356,8 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
spark.table("source").withColumnRenamed("data", "d")
.writeTo("testcat.table_name").overwritePartitions()
},
- condition = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
- parameters = Map("tableName" -> "`testcat`.`table_name`", "colName" -> "`data`")
+ condition = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
+ parameters = Map("tableName" -> "`testcat`.`table_name`", "extraColumns" -> "`d`")
)

checkAnswer(
@@ -218,10 +218,10 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils with AdaptiveSparkP
processInsert("t1", df, overwrite = false, byName = true)
},
v1ErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
- v2ErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA",
+ v2ErrorClass = "INCOMPATIBLE_DATA_FOR_TABLE.EXTRA_COLUMNS",
v1Parameters = Map("tableName" -> "`spark_catalog`.`default`.`t1`",
"extraColumns" -> "`x1`"),
- v2Parameters = Map("tableName" -> "`testcat`.`t1`", "colName" -> "`c1`")
+ v2Parameters = Map("tableName" -> "`testcat`.`t1`", "extraColumns" -> "`x1`")
)
val df2 = Seq((3, 2, 1, 0)).toDF(Seq("c3", "c2", "c1", "c0"): _*)
checkV1AndV2Error(
@@ -37,6 +37,14 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase {

protected val catalogAndNamespace: String

protected def catalog: String = {
if (catalogAndNamespace.nonEmpty) {
catalogAndNamespace.split('.').headOption.getOrElse("spark_catalog")
} else {
"spark_catalog"
}
}

protected val v2Format: String

private def fullTableName(tableName: String): String = {
@@ -328,7 +336,7 @@ trait AlterTableTests extends SharedSparkSession with QueryErrorsBase {
}

test("SPARK-39383 DEFAULT columns on V2 data sources with ALTER TABLE ADD/ALTER COLUMN") {
- withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> s"$v2Format, ") {
+ withSQLConf(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key -> s"$v2Format,$catalog") {
Contributor: How does this conf affect the v2 in-memory catalog under test? I thought it was only for the v1 file source.

Contributor Author: @viirya is correct. We previously passed "" as the provider in the in-memory connector, which required workarounds like this. That is no longer needed, as we now pass the catalog name as the provider, which simplifies testing.

val t = fullTableName("table_name")
withTable("t") {
sql(s"create table $t (a string) using $v2Format")
@@ -38,6 +38,7 @@ class DataSourceV2DataFrameSuite
before {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
spark.conf.set("spark.sql.catalog.testcat2", classOf[InMemoryTableCatalog].getName)
spark.conf.set(SQLConf.DEFAULT_COLUMN_ALLOWED_PROVIDERS.key, "testcat")
}

after {
@@ -263,4 +264,38 @@ class DataSourceV2DataFrameSuite
spark.listenerManager.unregister(listener)
}
}

test("columns with default values") {
val tableName = "testcat.ns1.ns2.tbl"
withTable(tableName) {
sql(s"CREATE TABLE $tableName (id INT, dep STRING) USING foo")

val df1 = Seq((1, "hr")).toDF("id", "dep")
df1.writeTo(tableName).append()

sql(s"ALTER TABLE $tableName ADD COLUMN txt STRING DEFAULT 'initial-text'")

val df2 = Seq((2, "hr"), (3, "software")).toDF("id", "dep")
df2.writeTo(tableName).append()

sql(s"ALTER TABLE $tableName ALTER COLUMN txt SET DEFAULT 'new-text'")

val df3 = Seq((4, "hr"), (5, "hr")).toDF("id", "dep")
df3.writeTo(tableName).append()

val df4 = Seq((6, "hr", null), (7, "hr", "explicit-text")).toDF("id", "dep", "txt")
df4.writeTo(tableName).append()

checkAnswer(
sql(s"SELECT * FROM $tableName"),
Seq(
Row(1, "hr", "initial-text"),
Row(2, "hr", "initial-text"),
Row(3, "software", "initial-text"),
Row(4, "hr", "new-text"),
Row(5, "hr", "new-text"),
Row(6, "hr", null),
Row(7, "hr", "explicit-text")))
}
}
}