[SPARK-47009][SQL] Enable create table support for collation
### What changes were proposed in this pull request?

Adding support for create table with collated columns using parquet.

We will map collated string types to a regular Parquet string type. This means that we won't support cross-engine compatibility for now.

A follow-up PR will fix Parquet filter pushdown. At first we will disable it completely for collated strings, but we should later look into using sort keys as min/max values to support pushdown.
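As a hedged sketch of what this mapping means in practice (assuming a throwaway table `t`; the collation name is one used in the tests below), the collation survives in the Spark schema while the Parquet files themselves carry a plain string column:

```scala
// Sketch: the collation is tracked only in the Spark schema; the Parquet
// footer records a plain string column, hence no cross-engine compatibility yet.
spark.sql("CREATE TABLE t (c1 STRING COLLATE 'UCS_BASIC_LCASE') USING PARQUET")
spark.sql("INSERT INTO t VALUES ('aaa'), ('AAA')")

// Reading the same files back as raw Parquet yields an uncollated string.
val raw = spark.read.parquet(spark.table("t").inputFiles.head)
assert(raw.schema.head.dataType == org.apache.spark.sql.types.StringType)
```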

### Why are the changes needed?

To support basic DDL operations for collations.

### Does this PR introduce _any_ user-facing change?

Yes, users are now able to create tables with collated columns.
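For example (mirroring the new tests below):

```scala
spark.sql("CREATE TABLE parquet_dummy_tbl (c1 STRING COLLATE 'UCS_BASIC_LCASE') USING PARQUET")
spark.sql("INSERT INTO parquet_dummy_tbl VALUES ('aaa'), ('AAA')")
// Returns a single row with the declared collation: UCS_BASIC_LCASE.
spark.sql("SELECT DISTINCT COLLATION(c1) FROM parquet_dummy_tbl").show()
```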

### How was this patch tested?

With unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#45105 from stefankandic/SPARK-47009-createTableCollation.

Authored-by: Stefan Kandic <stefan.kandic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
stefankandic authored and ericm-db committed Mar 5, 2024
1 parent d0af9bc commit d9f61cc
Showing 13 changed files with 161 additions and 23 deletions.
SqlBaseParser.g4
@@ -989,7 +989,7 @@ primaryExpression
| CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase
| CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase
| name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN #cast
| primaryExpression COLLATE stringLit #collate
| primaryExpression collateClause #collate
| primaryExpression DOUBLE_COLON dataType #castByColon
| STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct
| FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #first
@@ -1095,6 +1095,10 @@ colPosition
: position=FIRST | position=AFTER afterCol=errorCapturingIdentifier
;

collateClause
: COLLATE collationName=stringLit
;

type
: BOOLEAN
| TINYINT | BYTE
@@ -1105,7 +1109,7 @@ type
| DOUBLE
| DATE
| TIMESTAMP | TIMESTAMP_NTZ | TIMESTAMP_LTZ
| STRING
| STRING collateClause?
| CHARACTER | CHAR
| VARCHAR
| BINARY
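The two new grammar paths can be exercised as follows (a sketch; the collation names are the ones used in the tests at the end of this commit):

```scala
// Expression-level collate: primaryExpression collateClause
spark.sql("SELECT 'aaa' COLLATE 'unicode_ci'")
// Type-level collate: STRING collateClause? inside a data type
spark.sql("CREATE TABLE tbl (c1 STRING COLLATE 'UCS_BASIC_LCASE') USING PARQUET")
```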
DataTypeAstBuilder.scala
@@ -24,6 +24,7 @@ import org.antlr.v4.runtime.Token
import org.antlr.v4.runtime.tree.ParseTree

import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin}
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.internal.SqlApiConf
@@ -58,8 +59,8 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
* Resolve/create a primitive type.
*/
override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
val typeName = ctx.`type`.start.getType
(typeName, ctx.INTEGER_VALUE().asScala.toList) match {
val typeCtx = ctx.`type`
(typeCtx.start.getType, ctx.INTEGER_VALUE().asScala.toList) match {
case (BOOLEAN, Nil) => BooleanType
case (TINYINT | BYTE, Nil) => ByteType
case (SMALLINT | SHORT, Nil) => ShortType
@@ -71,7 +72,14 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
case (TIMESTAMP, Nil) => SqlApiConf.get.timestampType
case (TIMESTAMP_NTZ, Nil) => TimestampNTZType
case (TIMESTAMP_LTZ, Nil) => TimestampType
case (STRING, Nil) => StringType
case (STRING, Nil) =>
typeCtx.children.asScala.toSeq match {
case Seq(_) => StringType
case Seq(_, ctx: CollateClauseContext) =>
val collationName = visitCollateClause(ctx)
val collationId = CollationFactory.collationNameToId(collationName)
StringType(collationId)
}
case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt)
case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt)
case (BINARY, Nil) => BinaryType
@@ -205,4 +213,11 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) {
string(visitStringLit(ctx.stringLit))
}

/**
* Returns a collation name.
*/
override def visitCollateClause(ctx: CollateClauseContext): String = withOrigin(ctx) {
string(visitStringLit(ctx.stringLit))
}
}
DataType.scala
@@ -31,8 +31,8 @@ import org.apache.spark.{SparkIllegalArgumentException, SparkThrowable}
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.analysis.SqlApiAnalysis
import org.apache.spark.sql.catalyst.parser.DataTypeParser
import org.apache.spark.sql.catalyst.util.{CollationFactory, StringConcat}
import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer}
import org.apache.spark.sql.catalyst.util.StringConcat
import org.apache.spark.sql.errors.DataTypeErrors
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.DayTimeIntervalType._
@@ -117,6 +117,7 @@ object DataType {
private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r
private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r
private val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r
private val COLLATED_STRING_TYPE = """string\s+COLLATE\s+([\w_]+)""".r

def fromDDL(ddl: String): DataType = {
parseTypeWithFallback(
@@ -181,6 +182,9 @@
/** Given the string representation of a type, return its DataType */
private def nameToType(name: String): DataType = {
name match {
case COLLATED_STRING_TYPE(collation) =>
val collationId = CollationFactory.collationNameToId(collation)
StringType(collationId)
case "decimal" => DecimalType.USER_DEFAULT
case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt)
case CHAR_TYPE(length) => CharType(length.toInt)
StringType.scala
@@ -39,7 +39,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
*/
override def typeName: String =
if (isDefaultCollation) "string"
else s"string(${CollationFactory.fetchCollation(collationId).collationName})"
else s"string COLLATE ${CollationFactory.fetchCollation(collationId).collationName}"

override def equals(obj: Any): Boolean =
obj.isInstanceOf[StringType] && obj.asInstanceOf[StringType].collationId == collationId
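Together with the new `COLLATED_STRING_TYPE` pattern in `DataType.nameToType`, the `typeName` change lets collated types round-trip through their string form (a sketch, assuming `fromJson` resolves primitive type names via `nameToType`):

```scala
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types.{DataType, StringType}

val st = StringType(CollationFactory.collationNameToId("UCS_BASIC_LCASE"))
assert(st.typeName == "string COLLATE UCS_BASIC_LCASE")
// nameToType parses the name back via the COLLATED_STRING_TYPE regex above.
assert(DataType.fromJson("\"string COLLATE UCS_BASIC_LCASE\"") == st)
```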
Cast.scala
@@ -93,7 +93,7 @@ object Cast extends QueryErrorsBase {

case (NullType, _) => true

case (_, StringType) => true
case (_, _: StringType) => true

case (StringType, _: BinaryType) => true

@@ -301,8 +301,8 @@ object Cast extends QueryErrorsBase {
case _ if from == to => true
case (NullType, _) => true
case (_: NumericType, _: NumericType) => true
case (_: AtomicType, StringType) => true
case (_: CalendarIntervalType, StringType) => true
case (_: AtomicType, _: StringType) => true
case (_: CalendarIntervalType, _: StringType) => true
case (_: DatetimeType, _: DatetimeType) => true

case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
@@ -574,7 +574,7 @@ case class Cast(

// BinaryConverter
private[this] def castToBinary(from: DataType): Any => Any = from match {
case StringType => buildCast[UTF8String](_, _.getBytes)
case _: StringType => buildCast[UTF8String](_, _.getBytes)
case ByteType => buildCast[Byte](_, NumberConverter.toBinary)
case ShortType => buildCast[Short](_, NumberConverter.toBinary)
case IntegerType => buildCast[Int](_, NumberConverter.toBinary)
@@ -1109,7 +1109,7 @@ case class Cast(
} else {
to match {
case dt if dt == from => identity[Any]
case StringType => castToString(from)
case _: StringType => castToString(from)
case BinaryType => castToBinary(from)
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
@@ -1198,7 +1198,7 @@

case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;"
case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;"
case StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim)
case _: StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim)
case BinaryType => castToBinaryCode(from)
case DateType => castToDateCode(from, ctx)
case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
hash.scala
@@ -491,7 +491,7 @@ abstract class HashExpression[E] extends Expression {
case _: DayTimeIntervalType => genHashLong(input, result)
case _: YearMonthIntervalType => genHashInt(input, result)
case BinaryType => genHashBytes(input, result)
case StringType => genHashString(input, result)
case _: StringType => genHashString(input, result)
case ArrayType(et, containsNull) => genHashForArray(ctx, input, result, et, containsNull)
case MapType(kt, vt, valueContainsNull) =>
genHashForMap(ctx, input, result, kt, vt, valueContainsNull)
AstBuilder.scala
@@ -2186,8 +2186,8 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
* Create a [[Collate]] expression.
*/
override def visitCollate(ctx: CollateContext): Expression = withOrigin(ctx) {
val collation = string(visitStringLit(ctx.stringLit))
Collate(expression(ctx.primaryExpression), collation)
val collationName = visitCollateClause(ctx.collateClause())
Collate(expression(ctx.primaryExpression), collationName)
}

/**
ParquetVectorUpdaterFactory.java
@@ -198,7 +198,7 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
}
}
case BINARY -> {
if (sparkType == DataTypes.StringType || sparkType == DataTypes.BinaryType ||
if (sparkType instanceof StringType || sparkType == DataTypes.BinaryType ||
canReadAsBinaryDecimal(descriptor, sparkType)) {
return new BinaryUpdater();
} else if (canReadAsDecimal(descriptor, sparkType)) {
HashMapGenerator.scala
@@ -173,7 +173,7 @@ abstract class HashMapGenerator(
${hashBytes(bytes)}
"""
}
case StringType => hashBytes(s"$input.getBytes()")
case _: StringType => hashBytes(s"$input.getBytes()")
case CalendarIntervalType => hashInt(s"$input.hashCode()")
}
}
ParquetRowConverter.scala
@@ -395,7 +395,7 @@ private[parquet] class ParquetRowConverter(
throw QueryExecutionErrors.cannotCreateParquetConverterForDecimalTypeError(
t, parquetType.toString)

case StringType =>
case _: StringType =>
new ParquetStringConverter(updater)

// As long as the parquet type is INT64 timestamp, whether logical annotation
ParquetSchemaConverter.scala
@@ -544,7 +544,7 @@ class SparkToParquetSchemaConverter(
case DoubleType =>
Types.primitive(DOUBLE, repetition).named(field.name)

case StringType =>
case _: StringType =>
Types.primitive(BINARY, repetition)
.as(LogicalTypeAnnotation.stringType()).named(field.name)

ParquetWriteSupport.scala
@@ -199,7 +199,7 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
(row: SpecializedGetters, ordinal: Int) =>
recordConsumer.addDouble(row.getDouble(ordinal))

case StringType =>
case _: StringType =>
(row: SpecializedGetters, ordinal: Int) =>
recordConsumer.addBinary(
Binary.fromReusedByteArray(row.getUTF8String(ordinal).getBytes))
119 changes: 117 additions & 2 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -17,13 +17,21 @@

package org.apache.spark.sql

import scala.collection.immutable.Seq
import scala.jdk.CollectionConverters.MapHasAsJava

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
import org.apache.spark.sql.types.StringType

class CollationSuite extends QueryTest with SharedSparkSession {
class CollationSuite extends DatasourceV2SQLBase {
protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName

test("collate returns proper type") {
Seq("ucs_basic", "ucs_basic_lcase", "unicode", "unicode_ci").foreach { collationName =>
checkAnswer(sql(s"select 'aaa' collate '$collationName'"), Row("aaa"))
@@ -174,4 +182,111 @@ class CollationSuite extends QueryTest with SharedSparkSession {
Row(expected))
}
}

test("create table with collation") {
val tableName = "parquet_dummy_tbl"
val collationName = "UCS_BASIC_LCASE"
val collationId = CollationFactory.collationNameToId(collationName)

withTable(tableName) {
sql(
s"""
|CREATE TABLE $tableName (c1 STRING COLLATE '$collationName')
|USING PARQUET
|""".stripMargin)

sql(s"INSERT INTO $tableName VALUES ('aaa')")
sql(s"INSERT INTO $tableName VALUES ('AAA')")

checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"), Seq(Row(collationName)))
assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId))
}
}

test("create table with collations inside a struct") {
val tableName = "struct_collation_tbl"
val collationName = "UCS_BASIC_LCASE"
val collationId = CollationFactory.collationNameToId(collationName)

withTable(tableName) {
sql(
s"""
|CREATE TABLE $tableName
|(c1 STRUCT<name: STRING COLLATE '$collationName', age: INT>)
|USING PARQUET
|""".stripMargin)

sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'aaa', 'id', 1))")
sql(s"INSERT INTO $tableName VALUES (named_struct('name', 'AAA', 'id', 2))")

checkAnswer(sql(s"SELECT DISTINCT collation(c1.name) FROM $tableName"),
Seq(Row(collationName)))
assert(sql(s"SELECT c1.name FROM $tableName").schema.head.dataType == StringType(collationId))
}
}

test("add collated column with alter table") {
val tableName = "alter_column_tbl"
val defaultCollation = "UCS_BASIC"
val collationName = "UCS_BASIC_LCASE"
val collationId = CollationFactory.collationNameToId(collationName)

withTable(tableName) {
sql(
s"""
|CREATE TABLE $tableName (c1 STRING)
|USING PARQUET
|""".stripMargin)

sql(s"INSERT INTO $tableName VALUES ('aaa')")
sql(s"INSERT INTO $tableName VALUES ('AAA')")

checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
Seq(Row(defaultCollation)))

sql(
s"""
|ALTER TABLE $tableName
|ADD COLUMN c2 STRING COLLATE '$collationName'
|""".stripMargin)

sql(s"INSERT INTO $tableName VALUES ('aaa', 'aaa')")
sql(s"INSERT INTO $tableName VALUES ('AAA', 'AAA')")

checkAnswer(sql(s"SELECT DISTINCT COLLATION(c2) FROM $tableName"),
Seq(Row(collationName)))
assert(sql(s"select c2 FROM $tableName").schema.head.dataType == StringType(collationId))
}
}

test("create v2 table with collation column") {
val tableName = "testcat.table_name"
val collationName = "UCS_BASIC_LCASE"
val collationId = CollationFactory.collationNameToId(collationName)

withTable(tableName) {
sql(
s"""
|CREATE TABLE $tableName (c1 string COLLATE '$collationName')
|USING $v2Source
|""".stripMargin)

val testCatalog = catalog("testcat").asTableCatalog
val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))

assert(table.name == tableName)
assert(table.partitioning.isEmpty)
assert(table.properties == withDefaultOwnership(Map("provider" -> v2Source)).asJava)
assert(table.columns().head.dataType() == StringType(collationId))

val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty)

sql(s"INSERT INTO $tableName VALUES ('a'), ('A')")

checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"),
Seq(Row(collationName)))
assert(sql(s"select c1 FROM $tableName").schema.head.dataType == StringType(collationId))
}
}
}
