diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 6feb662632763..c3643f4bd15be 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -21,43 +21,79 @@ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} /** - * AbstractStringType is an abstract class for StringType with collation support. + * AbstractStringType is an abstract class for StringType with collation support. As every type of + * collation can support the trim specifier, this class is parameterized with it. */ -abstract class AbstractStringType extends AbstractDataType { +abstract class AbstractStringType(private[sql] val supportsTrimCollation: Boolean = false) + extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType override private[sql] def simpleString: String = "string" + private[sql] def canUseTrimCollation(other: DataType): Boolean = + supportsTrimCollation || !other.asInstanceOf[StringType].usesTrimCollation } /** * Use StringTypeBinary for expressions supporting only binary collation. 
*/ -case object StringTypeBinary extends AbstractStringType { +case class StringTypeBinary(override val supportsTrimCollation: Boolean = false) + extends AbstractStringType(supportsTrimCollation) { override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].supportsBinaryEquality && + canUseTrimCollation(other) +} + +object StringTypeBinary extends StringTypeBinary(false) { + def apply(supportsTrimCollation: Boolean): StringTypeBinary = { + new StringTypeBinary(supportsTrimCollation) + } } /** * Use StringTypeBinaryLcase for expressions supporting only binary and lowercase collation. */ -case object StringTypeBinaryLcase extends AbstractStringType { +case class StringTypeBinaryLcase(override val supportsTrimCollation: Boolean = false) + extends AbstractStringType(supportsTrimCollation) { override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] && (other.asInstanceOf[StringType].supportsBinaryEquality || - other.asInstanceOf[StringType].isUTF8LcaseCollation) + other.asInstanceOf[StringType].isUTF8LcaseCollation) && canUseTrimCollation(other) +} + +object StringTypeBinaryLcase extends StringTypeBinaryLcase(false) { + def apply(supportsTrimCollation: Boolean): StringTypeBinaryLcase = { + new StringTypeBinaryLcase(supportsTrimCollation) + } } /** * Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary * and ICU) but limited to using case and accent sensitivity specifiers. 
*/ -case object StringTypeWithCaseAccentSensitivity extends AbstractStringType { - override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] +case class StringTypeWithCaseAccentSensitivity( + override val supportsTrimCollation: Boolean = false) + extends AbstractStringType(supportsTrimCollation) { + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && canUseTrimCollation(other) +} + +object StringTypeWithCaseAccentSensitivity extends StringTypeWithCaseAccentSensitivity(false) { + def apply(supportsTrimCollation: Boolean): StringTypeWithCaseAccentSensitivity = { + new StringTypeWithCaseAccentSensitivity(supportsTrimCollation) + } } /** * Use StringTypeNonCSAICollation for expressions supporting all possible collation types except * CS_AI collation types. */ -case object StringTypeNonCSAICollation extends AbstractStringType { +case class StringTypeNonCSAICollation(override val supportsTrimCollation: Boolean = false) + extends AbstractStringType(supportsTrimCollation) { override private[sql] def acceptsType(other: DataType): Boolean = - other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI && + canUseTrimCollation(other) +} + +object StringTypeNonCSAICollation extends StringTypeNonCSAICollation(false) { + def apply(supportsTrimCollation: Boolean): StringTypeNonCSAICollation = { + new StringTypeNonCSAICollation(supportsTrimCollation) + } } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index c2dd6cec7ba74..29d48e3d1f47f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -47,6 +47,9 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ private[sql] def 
isNonCSAI: Boolean = !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) + private[sql] def usesTrimCollation: Boolean = + CollationFactory.usesTrimCollation(collationId) + private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala index 28ec8482e5cdd..81bafda54135f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true)) override def dataType: DataType = BinaryType final lazy val collationId: Int = expr.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala index 78bd02d5703cd..a6448051a3996 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala @@ -106,7 +106,11 @@ case class HllSketchAgg( override def inputTypes: Seq[AbstractDataType] = Seq( - TypeCollection(IntegerType, LongType, StringTypeWithCaseAccentSensitivity, BinaryType), + TypeCollection( + IntegerType, + LongType, + 
StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true), + BinaryType), IntegerType) override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index b67e66323bbbd..effcdc4b038e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -77,7 +77,8 @@ case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) override def dataType: DataType = StringType(collationId) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true)) override protected def withNewChildInternal( newChild: Expression): Expression = copy(newChild) @@ -115,5 +116,6 @@ case class Collation(child: Expression) val collationName = CollationFactory.fetchCollation(collationId).collationName Literal.create(collationName, SQLConf.get.defaultStringType) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity(/* supportsTrimCollation = */ true)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 851160d2fbb94..4c3cd93873bd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala 
@@ -982,7 +982,11 @@ class CollationSQLExpressionsSuite StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI", Map("1" -> "A", "2" -> "B", "3" -> "C")) ) - val unsupportedTestCase = StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null) + val unsupportedTestCases = Seq( + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null), + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_RTRIM", null), + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_BINARY_RTRIM", null), + StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UTF8_LCASE_RTRIM", null)) testCases.foreach(t => { // Unit test. val text = Literal.create(t.text, StringType(t.collation)) @@ -998,28 +1002,30 @@ class CollationSQLExpressionsSuite } }) // Test unsupported collation. - withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { - val query = - s"select str_to_map('${unsupportedTestCase.text}', '${unsupportedTestCase.pairDelim}', " + - s"'${unsupportedTestCase.keyValueDelim}')" - checkError( - exception = intercept[AnalysisException] { - sql(query).collect() - }, - condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", - sqlState = Some("42K09"), - parameters = Map( - "sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate UNICODE_AI, " + - "'?' collate UNICODE_AI, '?' 
collate UNICODE_AI)\""), - "paramIndex" -> "first", - "inputSql" -> "\"'a:1,b:2,c:3' collate UNICODE_AI\"", - "inputType" -> "\"STRING COLLATE UNICODE_AI\"", - "requiredType" -> "\"STRING\""), - context = ExpectedContext( - fragment = "str_to_map('a:1,b:2,c:3', '?', '?')", - start = 7, - stop = 41)) - } + unsupportedTestCases.foreach(t => { + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> t.collation) { + val query = + s"select str_to_map('${t.text}', '${t.pairDelim}', " + + s"'${t.keyValueDelim}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate " + s"${t.collation}, " + + "'?' collate " + s"${t.collation}, '?' collate ${t.collation})" + "\""), + "paramIndex" -> "first", + "inputSql" -> ("\"'a:1,b:2,c:3' collate " + s"${t.collation}" + "\""), + "inputType" -> ("\"STRING COLLATE " + s"${t.collation}" + "\""), + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = "str_to_map('a:1,b:2,c:3', '?', '?')", + start = 7, + stop = 41)) + } + }) } test("Support RaiseError misc expression with collation") {