From 6eae15753d7441f318c92b4bf97ff1e1a3edadf6 Mon Sep 17 00:00:00 2001
From: Jovan Pavlovic
Date: Mon, 30 Sep 2024 17:41:42 +0800
Subject: [PATCH] [SPARK-49811][SQL] Rename StringTypeAnyCollation

### What changes were proposed in this pull request?

Rename StringTypeAnyCollation to StringTypeWithCaseAccentSensitivity. The name
StringTypeAnyCollation is unfortunate: every time a new kind of collation is
added, it forces a rename.

### Why are the changes needed?

The name StringTypeAnyCollation is unfortunate: whenever a new specifier is
added (for example, a trim specifier), the type would have to be renamed (to
something like AllCollationExceptTrimCollation) until the new collation is
implemented in all functions. It gets even more confusing when multiple
collations are unsupported by some functions. Instead, the naming convention
should list only the specifiers that are supported, and avoid "any"/"all"
names.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This is just a rename; all existing tests pass.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48265 from jovanpavl-db/rename-string-type-collations.

Authored-by: Jovan Pavlovic
Signed-off-by: Wenchen Fan
---
 .../internal/types/AbstractStringType.scala   |   7 +-
 .../sql/catalyst/analysis/TypeCoercion.scala  |   5 +-
 .../expressions/CallMethodViaReflection.scala |   9 +-
 .../catalyst/expressions/CollationKey.scala   |   4 +-
 .../sql/catalyst/expressions/ExprUtils.scala  |   5 +-
 .../aggregate/datasketchesAggregates.scala    |   6 +-
 .../expressions/collationExpressions.scala    |   6 +-
 .../expressions/collectionOperations.scala    |  13 ++-
 .../catalyst/expressions/csvExpressions.scala |   4 +-
 .../expressions/datetimeExpressions.scala     |  41 ++++---
 .../expressions/jsonExpressions.scala         |  14 ++-
 .../expressions/maskExpressions.scala         |  10 +-
 .../expressions/mathExpressions.scala         |   8 +-
 .../spark/sql/catalyst/expressions/misc.scala |  13 ++-
 .../expressions/numberFormatExpressions.scala |   7 +-
 .../expressions/regexpExpressions.scala       |  18 +--
 .../expressions/stringExpressions.scala       | 103 ++++++++++--------
 .../catalyst/expressions/urlExpressions.scala |  13 ++-
 .../variant/variantExpressions.scala          |   7 +-
 .../sql/catalyst/expressions/xml/xpath.scala  |   6 +-
 .../catalyst/expressions/xmlExpressions.scala |   4 +-
 .../analysis/AnsiTypeCoercionSuite.scala      |  20 ++--
 .../expressions/StringExpressionsSuite.scala  |   4 +-
 .../sql/CollationExpressionWalkerSuite.scala  |  51 +++++----
 24 files changed, 218 insertions(+), 160 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
index dc4ee013fd189..6feb662632763 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}
 
 /**
- * StringTypeCollated is an abstract class for StringType with collation support.
+ * AbstractStringType is an abstract class for StringType with collation support.
  */
 abstract class AbstractStringType extends AbstractDataType {
   override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
@@ -46,9 +46,10 @@ case object StringTypeBinaryLcase extends AbstractStringType {
 }
 
 /**
- * Use StringTypeAnyCollation for expressions supporting all possible collation types.
+ * Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary + * and ICU) but limited to using case and accent sensitivity specifiers. */ -case object StringTypeAnyCollation extends AbstractStringType { +case object StringTypeWithCaseAccentSensitivity extends AbstractStringType { override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 5983346ff1e27..e0298b19931c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -32,7 +32,8 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, + StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.sql.types.UpCastRule.numericPrecedence @@ -438,7 +439,7 @@ abstract class TypeCoercionBase { } case aj @ ArrayJoin(arr, d, nr) - if !AbstractArrayType(StringTypeAnyCollation).acceptsType(arr.dataType) && + if !AbstractArrayType(StringTypeWithCaseAccentSensitivity).acceptsType(arr.dataType) && ArrayType.acceptsType(arr.dataType) => val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull implicitCast(arr, ArrayType(StringType, containsNull)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 13ea8c77c41b4..6aa11b6fd16df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -84,7 +84,7 @@ case class CallMethodViaReflection( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("class"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(children.head) ) ) @@ -97,7 +97,7 @@ case class CallMethodViaReflection( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("method"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(children(1)) ) ) @@ -114,7 +114,8 @@ 
case class CallMethodViaReflection( "paramIndex" -> ordinalNumber(idx), "requiredType" -> toSQLType( TypeCollection(BooleanType, ByteType, ShortType, - IntegerType, LongType, FloatType, DoubleType, StringTypeAnyCollation)), + IntegerType, LongType, FloatType, DoubleType, + StringTypeWithCaseAccentSensitivity)), "inputSql" -> toSQLExpr(e), "inputType" -> toSQLType(e.dataType)) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala index 6e400d026e0ee..28ec8482e5cdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.CollationFactory -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = BinaryType final lazy val collationId: Int = expr.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 749152f135e92..08cb03edb78b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} -import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType} import org.apache.spark.unsafe.types.UTF8String @@ -61,7 +61,8 @@ object ExprUtils extends QueryErrorsBase { def convertToMapData(exp: Expression): Map[String, String] = exp match { case m: CreateMap - if AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation).acceptsType(m.dataType) => + if AbstractMapType(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + .acceptsType(m.dataType) => val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => key.toString -> value.toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala index 2102428131f64..78bd02d5703cd 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, IntegerType, LongType, StringType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String @@ -105,7 +105,9 @@ case class HllSketchAgg( override def prettyName: String = "hll_sketch_agg" override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(IntegerType, LongType, StringTypeAnyCollation, BinaryType), IntegerType) + Seq( + TypeCollection(IntegerType, LongType, StringTypeWithCaseAccentSensitivity, BinaryType), + IntegerType) override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index d45ca533f9392..0cff70436db7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ // scalastyle:off line.contains.tab @@ -73,7 +73,7 @@ case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) override def dataType: DataType = StringType(collationId) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override protected def withNewChildInternal( newChild: Expression): Expression = copy(newChild) @@ -111,5 +111,5 @@ case class Collation(child: Expression) val collationName = CollationFactory.fetchCollation(collationId).collationName Literal.create(collationName, SQLConf.get.defaultStringType) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5cdd3c7eb62d1..c091d51fc177f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -38,7 +38,7 @@ import 
org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SQLOpenHashSet import org.apache.spark.unsafe.UTF8StringBuilder @@ -1348,7 +1348,7 @@ case class Reverse(child: Expression) // Input types are utilized by type coercion in ImplicitTypeCasts. override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, ArrayType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, ArrayType)) override def dataType: DataType = child.dataType @@ -2134,9 +2134,12 @@ case class ArrayJoin( this(array, delimiter, Some(nullReplacement)) override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { - Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation, StringTypeAnyCollation) + Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity), + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity) } else { - Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation) + Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity), + StringTypeWithCaseAccentSensitivity) } override def children: Seq[Expression] = if (nullReplacement.isDefined) { @@ -2857,7 +2860,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio with QueryErrorsBase { private def allowedTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, BinaryType, ArrayType) + Seq(StringTypeWithCaseAccentSensitivity, BinaryType, ArrayType) final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index cb10440c48328..2f4462c0664f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -147,7 +147,7 @@ case class CsvToStructs( converter(parser.parse(csv)) } - override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil override def prettyName: String = "from_csv" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 36bd53001594e..b166d235557fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -961,7 +961,8 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(TimestampType, StringTypeWithCaseAccentSensitivity) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1269,8 +1270,10 @@ abstract class ToTimestamp override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType), - StringTypeAnyCollation) + Seq(TypeCollection( + StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType + ), + StringTypeWithCaseAccentSensitivity) override def dataType: DataType = LongType override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true @@ -1441,7 +1444,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(LongType, StringTypeWithCaseAccentSensitivity) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1549,7 +1553,8 @@ case class NextDay( def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(DateType, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = DateType override def nullable: Boolean = true @@ -1760,7 +1765,8 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes w val func: (Long, String) => Long val funcName: String - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(TimestampType, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = TimestampType override def nullSafeEval(time: Any, timezone: Any): Any = { @@ -2100,8 +2106,9 @@ case class ParseToDate( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. 
- TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) +: - format.map(_ => StringTypeAnyCollation).toSeq + TypeCollection( + StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType) +: + format.map(_ => StringTypeWithCaseAccentSensitivity).toSeq } override protected def withNewChildrenInternal( @@ -2172,10 +2179,10 @@ case class ParseToTimestamp( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. - val types = Seq(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) + val types = Seq(StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType) TypeCollection( (if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _* - ) +: format.map(_ => StringTypeAnyCollation).toSeq + ) +: format.map(_ => StringTypeWithCaseAccentSensitivity).toSeq } override protected def withNewChildrenInternal( @@ -2305,7 +2312,8 @@ case class TruncDate(date: Expression, format: Expression) override def left: Expression = date override def right: Expression = format - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(DateType, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = DateType override def prettyName: String = "trunc" override val instant = date @@ -2374,7 +2382,8 @@ case class TruncTimestamp( override def left: Expression = format override def right: Expression = timestamp - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, TimestampType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, TimestampType) override def dataType: TimestampType = TimestampType override def prettyName: String = "date_trunc" override val instant = timestamp @@ -2675,7 +2684,7 @@ case class MakeTimestamp( // casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0). 
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++ - timezone.map(_ => StringTypeAnyCollation) + timezone.map(_ => StringTypeWithCaseAccentSensitivity) override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = @@ -3122,8 +3131,8 @@ case class ConvertTimezone( override def second: Expression = targetTz override def third: Expression = sourceTs - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, - StringTypeAnyCollation, TimestampNTZType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity, TimestampNTZType) override def dataType: DataType = TimestampNTZType override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 2037eb22fede6..bdcf3f0c1eeab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePatt import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{UTF8String, VariantVal} import org.apache.spark.util.Utils @@ -134,7 +134,7 @@ case class GetJsonObject(json: Expression, path: Expression) override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def prettyName: String = "get_json_object" @@ -489,7 +489,9 @@ case class JsonTuple(children: Seq[Expression]) throw QueryCompilationErrors.wrongNumArgsError( toSQLId(prettyName), Seq("> 1"), children.length ) - } else if (children.forall(child => StringTypeAnyCollation.acceptsType(child.dataType))) { + } else if ( + children.forall( + child => StringTypeWithCaseAccentSensitivity.acceptsType(child.dataType))) { TypeCheckResult.TypeCheckSuccess } else { DataTypeMismatch( @@ -726,7 +728,7 @@ case class JsonToStructs( converter(parser.parse(json.asInstanceOf[UTF8String])) } - override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil override def sql: String = schema match { case _: MapType => "entries" @@ -968,7 +970,7 @@ case class SchemaOfJson( case class LengthOfJsonArray(child: Expression) extends UnaryExpression with CodegenFallback with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType 
= IntegerType override def nullable: Boolean = true override def prettyName: String = "json_array_length" @@ -1041,7 +1043,7 @@ case class LengthOfJsonArray(child: Expression) extends UnaryExpression case class JsonObjectKeys(child: Expression) extends UnaryExpression with CodegenFallback with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType) override def nullable: Boolean = true override def prettyName: String = "json_object_keys" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index c11357352c79a..cb62fa2cc3bd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, DataType} import org.apache.spark.unsafe.types.UTF8String @@ -192,8 +192,12 @@ case class Mask( * NumericType, IntegralType, FractionalType. */ override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation, - StringTypeAnyCollation, StringTypeAnyCollation) + Seq( + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity) override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index ddba820414ae4..e46acf467db22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -453,7 +453,7 @@ case class Conv( override def second: Expression = fromBaseExpr override def third: Expression = toBaseExpr override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, IntegerType, IntegerType) + Seq(StringTypeWithCaseAccentSensitivity, IntegerType, IntegerType) override def dataType: DataType = first.dataType override def nullable: Boolean = true @@ -1114,7 +1114,7 @@ case class Hex(child: Expression) extends UnaryExpression with 
ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, BinaryType, StringTypeAnyCollation)) + Seq(TypeCollection(LongType, BinaryType, StringTypeWithCaseAccentSensitivity)) override def dataType: DataType = child.dataType match { case st: StringType => st @@ -1158,7 +1158,7 @@ case class Unhex(child: Expression, failOnError: Boolean = false) def this(expr: Expression) = this(expr, false) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nullable: Boolean = true override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 6629f724c4dda..cb846f606632b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def foldable: Boolean = false override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, MapType(StringType, StringType)) + Seq(StringTypeWithCaseAccentSensitivity, MapType(StringType, StringType)) override def left: Expression = errorClass override def right: Expression = errorParms @@ -415,7 +415,9 @@ case class AesEncrypt( override def prettyName: String = "aes_encrypt" override def inputTypes: Seq[AbstractDataType] = - Seq(BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, + Seq(BinaryType, BinaryType, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, BinaryType, BinaryType) override def children: Seq[Expression] = Seq(input, key, mode, padding, iv, aad) @@ -489,7 +491,10 @@ case class AesDecrypt( this(input, key, Literal("GCM")) override def inputTypes: Seq[AbstractDataType] = { - Seq(BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType) + Seq(BinaryType, + BinaryType, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, BinaryType) } override def prettyName: String = "aes_decrypt" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index e914190c06456..5bd2ab6035e10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.util.ToNumberParser import org.apache.spark.sql.errors.QueryCompilationErrors import 
org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, DatetimeType, Decimal, DecimalType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -50,7 +50,7 @@ abstract class ToNumberBase(left: Expression, right: Expression, errorOnFail: Bo } override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() @@ -284,7 +284,8 @@ case class ToCharacter(left: Expression, right: Expression) } override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(DecimalType, StringTypeWithCaseAccentSensitivity) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() if (inputTypeCheck.isSuccess) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 970397c76a1cd..fdc3c27890469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -35,7 +35,8 @@ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase} +import org.apache.spark.sql.internal.types.{ + StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -46,7 +47,7 @@ abstract class StringRegexExpression extends BinaryExpression def matches(regex: Pattern, str: String): Boolean override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId final lazy val collationRegexFlags: Int = CollationSupport.collationAwareRegexFlags(collationId) @@ -278,7 +279,7 @@ case class ILike( this(left, right, '\\') override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Expression = { @@ -567,7 +568,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) override def dataType: DataType = ArrayType(str.dataType, containsNull = false) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity, IntegerType) override def first: Expression = str 
override def second: Expression = regex override def third: Expression = limit @@ -711,7 +712,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def dataType: DataType = subject.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation, StringTypeBinaryLcase, IntegerType) + Seq(StringTypeBinaryLcase, + StringTypeWithCaseAccentSensitivity, StringTypeBinaryLcase, IntegerType) final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId override def prettyName: String = "regexp_replace" @@ -799,7 +801,7 @@ abstract class RegExpExtractBase final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity, IntegerType) override def first: Expression = subject override def second: Expression = regexp override def third: Expression = idx @@ -1052,7 +1054,7 @@ case class RegExpCount(left: Expression, right: Expression) override def children: Seq[Expression] = Seq(left, right) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpCount = @@ -1092,7 +1094,7 @@ case class RegExpSubStr(left: Expression, right: Expression) override def children: Seq[Expression] = Seq(left, right) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpSubStr = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 786c3968be0fe..c91c57ee1eb3e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -38,7 +38,8 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO import org.apache.spark.sql.catalyst.util.{ArrayData, CharsetProvider, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeNonCSAICollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, + StringTypeNonCSAICollation, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods @@ -81,8 +82,10 @@ case class ConcatWs(children: Seq[Expression]) /** The 1st child (separator) is str, and rest are either str or array of str. 
*/ override def inputTypes: Seq[AbstractDataType] = { val arrayOrStr = - TypeCollection(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation) - StringTypeAnyCollation +: Seq.fill(children.size - 1)(arrayOrStr) + TypeCollection(AbstractArrayType(StringTypeWithCaseAccentSensitivity), + StringTypeWithCaseAccentSensitivity + ) + StringTypeWithCaseAccentSensitivity +: Seq.fill(children.size - 1)(arrayOrStr) } override def dataType: DataType = children.head.dataType @@ -433,7 +436,7 @@ trait String2StringExpression extends ImplicitCastInputTypes { def convert(v: UTF8String): UTF8String override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) protected override def nullSafeEval(input: Any): Any = convert(input.asInstanceOf[UTF8String]) @@ -515,7 +518,7 @@ abstract class StringPredicate extends BinaryExpression def compare(l: UTF8String, r: UTF8String): Boolean override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) protected override def nullSafeEval(input1: Any, input2: Any): Any = compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String]) @@ -732,7 +735,7 @@ case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with Implic override lazy val replacement: Expression = Invoke(input, "isValid", BooleanType) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "is_valid_utf8" @@ -779,7 +782,7 @@ case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with Impl override lazy val replacement: Expression = Invoke(input, "makeValid", input.dataType) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "make_valid_utf8" @@ -824,7 +827,7 @@ case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with Impli Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "validate_utf8" @@ -873,7 +876,7 @@ case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with Im Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "try_validate_utf8" @@ -1008,8 +1011,8 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: override def dataType: DataType = input.dataType override def inputTypes: Seq[AbstractDataType] = Seq( - TypeCollection(StringTypeAnyCollation, BinaryType), - TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType, IntegerType) + TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), + TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType, IntegerType) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() @@ -1213,7 +1216,7 @@ case class FindInSet(left: Expression, right: 
Expression) extends BinaryExpressi final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override protected def nullSafeEval(word: Any, set: Any): Any = { CollationSupport.FindInSet. @@ -1241,7 +1244,8 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = srcStr +: trimStr.toSeq override def dataType: DataType = srcStr.dataType - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq.fill(children.size)(StringTypeWithCaseAccentSensitivity) final lazy val collationId: Int = srcStr.dataType.asInstanceOf[StringType].collationId @@ -1846,7 +1850,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, IntegerType, StringTypeWithCaseAccentSensitivity) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1926,7 +1930,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, IntegerType, StringTypeWithCaseAccentSensitivity) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1971,7 +1975,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC override def dataType: DataType = children(0).dataType override def inputTypes: Seq[AbstractDataType] = - StringTypeAnyCollation :: List.fill(children.size - 1)(AnyDataType) + StringTypeWithCaseAccentSensitivity :: List.fill(children.size - 1)(AnyDataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { @@ -2082,7 +2086,7 @@ case class InitCap(child: Expression) // Flag to indicate whether to use ICU instead of JVM case mappings for UTF8_BINARY collation. 
private final lazy val useICU = SQLConf.get.getConf(SQLConf.ICU_CASE_MAPPINGS_ENABLED) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = child.dataType override def nullSafeEval(string: Any): Any = { @@ -2114,7 +2118,8 @@ case class StringRepeat(str: Expression, times: Expression) override def left: Expression = str override def right: Expression = times override def dataType: DataType = str.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, IntegerType) override def nullSafeEval(string: Any, n: Any): Any = { string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer]) @@ -2207,7 +2212,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType, IntegerType) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType, IntegerType) override def first: Expression = str override def second: Expression = pos @@ -2265,7 +2270,8 @@ case class Right(str: Expression, len: Expression) extends RuntimeReplaceable ) ) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, IntegerType) override def left: Expression = str override def right: Expression = len override protected def withNewChildrenInternal( @@ -2296,7 +2302,7 @@ case class Left(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = Substring(str, Literal(1), len) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType) } override def left: Expression = str @@ -2332,7 +2338,7 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numChars @@ -2367,7 +2373,7 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numBytes * 8 @@ -2406,7 +2412,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) protected override def 
nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numBytes @@ -2466,8 +2472,9 @@ case class Levenshtein( } override def inputTypes: Seq[AbstractDataType] = threshold match { - case Some(_) => Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) - case _ => Seq(StringTypeAnyCollation, StringTypeAnyCollation) + case Some(_) => + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity, IntegerType) + case _ => Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) } override def children: Seq[Expression] = threshold match { @@ -2592,7 +2599,7 @@ case class SoundEx(child: Expression) override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() @@ -2622,7 +2629,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) protected override def nullSafeEval(string: Any): Any = { // only pick the first character to reduce the `toString` cost @@ -2767,7 +2774,7 @@ case class UnBase64(child: Expression, failOnError: Boolean = false) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BinaryType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) def this(expr: Expression) = this(expr, false) @@ -2946,7 +2953,8 @@ case class StringDecode( this(bin, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction) override val dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(BinaryType, StringTypeWithCaseAccentSensitivity) override def prettyName: String = "decode" override def toString: String = s"$prettyName($bin, $charset)" @@ -2955,7 +2963,7 @@ case class StringDecode( SQLConf.get.defaultStringType, "decode", Seq(bin, charset, Literal(legacyCharsets), Literal(legacyErrorAction)), - Seq(BinaryType, StringTypeAnyCollation, BooleanType, BooleanType)) + Seq(BinaryType, StringTypeWithCaseAccentSensitivity, BooleanType, BooleanType)) override def children: Seq[Expression] = Seq(bin, charset) override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = @@ -3012,15 +3020,20 @@ case class Encode( override def dataType: DataType = BinaryType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override lazy val replacement: Expression = StaticInvoke( classOf[Encode], BinaryType, "encode", Seq( - str, charset, Literal(legacyCharsets, BooleanType), Literal(legacyErrorAction, BooleanType)), - Seq(StringTypeAnyCollation, StringTypeAnyCollation, BooleanType, BooleanType)) + str, charset, Literal(legacyCharsets, BooleanType), Literal(legacyErrorAction, BooleanType) + ), + Seq( + 
StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + BooleanType, + BooleanType)) override def toString: String = s"$prettyName($str, $charset)" @@ -3104,7 +3117,8 @@ case class ToBinary( override def children: Seq[Expression] = expr +: format.toSeq - override def inputTypes: Seq[AbstractDataType] = children.map(_ => StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + children.map(_ => StringTypeWithCaseAccentSensitivity) override def checkInputDataTypes(): TypeCheckResult = { def isValidFormat: Boolean = { @@ -3120,7 +3134,8 @@ case class ToBinary( errorSubClass = "INVALID_ARG_VALUE", messageParameters = Map( "inputName" -> "fmt", - "requireType" -> s"case-insensitive ${toSQLType(StringTypeAnyCollation)}", + "requireType" -> + s"case-insensitive ${toSQLType(StringTypeWithCaseAccentSensitivity)}", "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", "inputValue" -> toSQLValue(fmt, f.dataType) ) @@ -3131,7 +3146,7 @@ case class ToBinary( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("fmt"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(f) ) ) @@ -3140,7 +3155,8 @@ case class ToBinary( errorSubClass = "INVALID_ARG_VALUE", messageParameters = Map( "inputName" -> "fmt", - "requireType" -> s"case-insensitive ${toSQLType(StringTypeAnyCollation)}", + "requireType" -> + s"case-insensitive ${toSQLType(StringTypeWithCaseAccentSensitivity)}", "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", "inputValue" -> toSQLValue(f.eval(), f.dataType) ) @@ -3189,7 +3205,7 @@ case class FormatNumber(x: Expression, d: Expression) override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(NumericType, TypeCollection(IntegerType, StringTypeAnyCollation)) + Seq(NumericType, TypeCollection(IntegerType, StringTypeWithCaseAccentSensitivity)) private val defaultFormat = "#,###,###,###,###,###,##0" @@ -3394,7 +3410,9 @@ case class Sentences( override def dataType: DataType = ArrayType(ArrayType(str.dataType, containsNull = false), containsNull = false) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq( + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override def first: Expression = str override def second: Expression = language override def third: Expression = country @@ -3540,10 +3558,9 @@ case class Luhncheck(input: Expression) extends RuntimeReplaceable with Implicit classOf[ExpressionImplUtils], BooleanType, "isLuhnNumber", - Seq(input), - inputTypes) + Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def prettyName: String = "luhn_check" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index 3e4e4f992002a..09e91da65484f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -29,7 +29,7 @@ import 
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
 import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -59,13 +59,13 @@ case class UrlEncode(child: Expression)
     SQLConf.get.defaultStringType,
     "encode",
     Seq(child),
-    Seq(StringTypeAnyCollation))
+    Seq(StringTypeWithCaseAccentSensitivity))
 
   override protected def withNewChildInternal(newChild: Expression): Expression = {
     copy(child = newChild)
   }
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
 
   override def prettyName: String = "url_encode"
 }
@@ -98,13 +98,13 @@ case class UrlDecode(child: Expression, failOnError: Boolean = true)
     SQLConf.get.defaultStringType,
     "decode",
     Seq(child, Literal(failOnError)),
-    Seq(StringTypeAnyCollation, BooleanType))
+    Seq(StringTypeWithCaseAccentSensitivity, BooleanType))
 
   override protected def withNewChildInternal(newChild: Expression): Expression = {
     copy(child = newChild)
   }
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
 
   override def prettyName: String = "url_decode"
 }
@@ -190,7 +190,8 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge
   def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled)
 
   override def nullable: Boolean = true
-  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringTypeAnyCollation)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq.fill(children.size)(StringTypeWithCaseAccentSensitivity)
   override def dataType: DataType = SQLConf.get.defaultStringType
   override def prettyName: String = "parse_url"
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 2c8ca1e8bb2bb..323f6e42f3e50 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData,
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
 import org.apache.spark.sql.types._
 import org.apache.spark.types.variant._
 import org.apache.spark.types.variant.VariantUtil.{IntervalFields, Type}
@@ -66,7 +66,7 @@ case class ParseJson(child: Expression, failOnError: Boolean = true)
     inputTypes :+ BooleanType :+ BooleanType,
     returnNullable = !failOnError)
 
-  override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil
+  override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil
 
   override def dataType: DataType = VariantType
 
@@ -271,7 +271,8 @@ case class VariantGet(
 
   final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringTypeAnyCollation)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(VariantType, StringTypeWithCaseAccentSensitivity)
 
   override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index 31e65cf0abc95..6c38bd88144b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -42,7 +42,7 @@ abstract class XPathExtract
   override def nullable: Boolean = true
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringTypeAnyCollation, StringTypeAnyCollation)
+    Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (!path.foldable) {
@@ -50,7 +50,7 @@ abstract class XPathExtract
         errorSubClass = "NON_FOLDABLE_INPUT",
         messageParameters = Map(
           "inputName" -> toSQLId("path"),
-          "inputType" -> toSQLType(StringTypeAnyCollation),
+          "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
           "inputExpr" -> toSQLExpr(path)
         )
       )
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
index 48a87db291a8d..6f1430b04ed67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils._
 import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -124,7 +124,7 @@ case class XmlToStructs(
     defineCodeGen(ctx, ev, input => s"(InternalRow) $expr.nullSafeEval($input)")
   }
 
-  override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil
+  override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil
 
   override def prettyName: String = "from_xml"
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
index de600d881b624..342dcbd8e6b6d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
+import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity}
 import org.apache.spark.sql.types._
 
 class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
@@ -1057,11 +1057,11 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
       ArrayType(IntegerType))
     shouldCast(
       ArrayType(StringType),
-      AbstractArrayType(StringTypeAnyCollation),
+      AbstractArrayType(StringTypeWithCaseAccentSensitivity),
       ArrayType(StringType))
     shouldCast(
       ArrayType(IntegerType),
-      AbstractArrayType(StringTypeAnyCollation),
+      AbstractArrayType(StringTypeWithCaseAccentSensitivity),
       ArrayType(StringType))
     shouldCast(
       ArrayType(StringType),
@@ -1075,11 +1075,11 @@
       ArrayType(ArrayType(IntegerType)))
     shouldCast(
       ArrayType(ArrayType(StringType)),
-      AbstractArrayType(AbstractArrayType(StringTypeAnyCollation)),
+      AbstractArrayType(AbstractArrayType(StringTypeWithCaseAccentSensitivity)),
       ArrayType(ArrayType(StringType)))
     shouldCast(
       ArrayType(ArrayType(IntegerType)),
-      AbstractArrayType(AbstractArrayType(StringTypeAnyCollation)),
+      AbstractArrayType(AbstractArrayType(StringTypeWithCaseAccentSensitivity)),
       ArrayType(ArrayType(StringType)))
     shouldCast(
       ArrayType(ArrayType(StringType)),
@@ -1088,14 +1088,16 @@
 
     // Invalid casts involving casting arrays into non-complex types.
     shouldNotCast(ArrayType(IntegerType), IntegerType)
-    shouldNotCast(ArrayType(StringType), StringTypeAnyCollation)
+    shouldNotCast(ArrayType(StringType), StringTypeWithCaseAccentSensitivity)
     shouldNotCast(ArrayType(StringType), IntegerType)
-    shouldNotCast(ArrayType(IntegerType), StringTypeAnyCollation)
+    shouldNotCast(ArrayType(IntegerType), StringTypeWithCaseAccentSensitivity)
 
     // Invalid casts involving casting arrays of arrays into arrays of non-complex types.
     shouldNotCast(ArrayType(ArrayType(IntegerType)), AbstractArrayType(IntegerType))
-    shouldNotCast(ArrayType(ArrayType(StringType)), AbstractArrayType(StringTypeAnyCollation))
+    shouldNotCast(ArrayType(ArrayType(StringType)),
+      AbstractArrayType(StringTypeWithCaseAccentSensitivity))
     shouldNotCast(ArrayType(ArrayType(StringType)), AbstractArrayType(IntegerType))
-    shouldNotCast(ArrayType(ArrayType(IntegerType)), AbstractArrayType(StringTypeAnyCollation))
+    shouldNotCast(ArrayType(ArrayType(IntegerType)),
+      AbstractArrayType(StringTypeWithCaseAccentSensitivity))
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 9b454ba764f92..1aae2f10b7326 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio
 import org.apache.spark.sql.catalyst.util.CharsetProvider
 import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -1466,7 +1466,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
         errorSubClass = "NON_FOLDABLE_INPUT",
         messageParameters = Map(
           "inputName" -> toSQLId("fmt"),
-          "inputType" -> toSQLType(StringTypeAnyCollation),
+          "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
           "inputExpr" -> toSQLExpr(wrongFmt)
         )
       )
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
index 1d23774a51692..879c0c480943d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
@@ -66,10 +66,10 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
       collationType: CollationType): Any = inputEntry match {
       case e: Class[_] if e.isAssignableFrom(classOf[Expression]) =>
-        generateLiterals(StringTypeAnyCollation, collationType)
+        generateLiterals(StringTypeWithCaseAccentSensitivity, collationType)
       case se: Class[_] if se.isAssignableFrom(classOf[Seq[Expression]]) =>
-        CreateArray(Seq(generateLiterals(StringTypeAnyCollation, collationType),
-          generateLiterals(StringTypeAnyCollation, collationType)))
+        CreateArray(Seq(generateLiterals(StringTypeWithCaseAccentSensitivity, collationType),
+          generateLiterals(StringTypeWithCaseAccentSensitivity, collationType)))
       case oe: Class[_] if oe.isAssignableFrom(classOf[Option[Any]]) => None
       case b: Class[_] if b.isAssignableFrom(classOf[Boolean]) => false
       case dt: Class[_] if dt.isAssignableFrom(classOf[DataType]) => StringType
@@ -142,12 +142,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
          lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType))
        ).head
      case ArrayType =>
-        generateLiterals(StringTypeAnyCollation, collationType).map(
+        generateLiterals(StringTypeWithCaseAccentSensitivity, collationType).map(
          lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType))
        ).head
      case MapType =>
-        val key = generateLiterals(StringTypeAnyCollation, collationType)
-        val value = generateLiterals(StringTypeAnyCollation, collationType)
+        val key = generateLiterals(StringTypeWithCaseAccentSensitivity, collationType)
+        val value = generateLiterals(StringTypeWithCaseAccentSensitivity, collationType)
        CreateMap(Seq(key, value))
      case MapType(keyType, valueType, _) =>
        val key = generateLiterals(keyType, collationType)
@@ -159,8 +159,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
        CreateMap(Seq(key, value))
      case StructType =>
        CreateNamedStruct(
-          Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, collationType),
-            Literal("end"), generateLiterals(StringTypeAnyCollation, collationType)))
+          Seq(Literal("start"),
+            generateLiterals(StringTypeWithCaseAccentSensitivity, collationType),
+            Literal("end"), generateLiterals(StringTypeWithCaseAccentSensitivity, collationType)))
    }
 
  /**
@@ -209,10 +210,10 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
      case ArrayType(elementType, _) =>
        "array(" + generateInputAsString(elementType, collationType) + ")"
      case ArrayType =>
-        "array(" + generateInputAsString(StringTypeAnyCollation, collationType) + ")"
+        "array(" + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")"
      case MapType =>
-        "map(" + generateInputAsString(StringTypeAnyCollation, collationType) + ", " +
-          generateInputAsString(StringTypeAnyCollation, collationType) + ")"
+        "map(" + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ", " +
+          generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")"
      case MapType(keyType, valueType, _) =>
        "map(" + generateInputAsString(keyType, collationType) + ", " +
          generateInputAsString(valueType, collationType) + ")"
@@ -220,8 +221,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
        "map(" + generateInputAsString(keyType, collationType) + ", " +
          generateInputAsString(valueType, collationType) + ")"
      case StructType =>
-        "named_struct( 'start', " + generateInputAsString(StringTypeAnyCollation, collationType) +
-          ", 'end', " + generateInputAsString(StringTypeAnyCollation, collationType) + ")"
+        "named_struct( 'start', " +
+          generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ", 'end', " +
+          generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")"
      case StructType(fields) =>
        "named_struct(" + fields.map(f => "'" + f.name + "', " +
          generateInputAsString(f.dataType, collationType)).mkString(", ") + ")"
@@ -267,10 +269,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
      case ArrayType(elementType, _) =>
        "array<" + generateInputTypeAsStrings(elementType, collationType) + ">"
      case ArrayType =>
-        "array<" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">"
+        "array<" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) +
+          ">"
      case MapType =>
-        "map<" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ", " +
-          generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">"
+        "map<" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) +
+          ", " +
+          generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + ">"
      case MapType(keyType, valueType, _) =>
        "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " +
          generateInputTypeAsStrings(valueType, collationType) + ">"
@@ -278,9 +282,10 @@
        "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " +
          generateInputTypeAsStrings(valueType, collationType) + ">"
      case StructType =>
-        "struct<start:" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) +
-          ", end:" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">"
+        "struct<start:" +
+          generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + ", end:" +
+          generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + ">"
      case StructType(fields) =>
        "named_struct<" + fields.map(f => "'" + f.name + "', " +
          generateInputTypeAsStrings(f.dataType, collationType)).mkString(", ") + ">"
@@ -293,8 +298,8 @@
   */
  def hasStringType(inputType: AbstractDataType): Boolean = {
    inputType match {
-      case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType =>
-        true
+      case _: StringType | StringTypeWithCaseAccentSensitivity | StringTypeBinaryLcase | AnyDataType
+        => true
      case ArrayType => true
      case MapType => true
      case MapType(keyType, valueType, _) => hasStringType(keyType) || hasStringType(valueType)
@@ -408,7 +413,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
    var input: Seq[Expression] = Seq.empty
    var i = 0
    for (_ <- 1 to 10) {
-      input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
+      input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary)
      try {
        method.invoke(null, funInfo.getClassName, input).asInstanceOf[ExpectsInputTypes]
      }
@@ -498,7 +503,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
    var input: Seq[Expression] = Seq.empty
    var result: Expression = null
    for (_ <- 1 to 10) {
-      input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
+      input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary)
      try {
        val tempResult = method.invoke(null, f.getClassName, input)
        if (result == null) result = tempResult.asInstanceOf[Expression]
@@ -609,7 +614,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
    var input: Seq[Expression] = Seq.empty
    var result: Expression = null
    for (_ <- 1 to 10) {
-      input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
+      input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary)
      try {
        val tempResult = method.invoke(null, f.getClassName, input)
        if (result == null) result = tempResult.asInstanceOf[Expression]
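
For reviewers less familiar with this corner of Catalyst, the pattern touched at nearly every
call site above is the same: an expression advertises the abstract string type in `inputTypes`,
and implicit casting then accepts any collated StringType. The following is a minimal
illustrative sketch, not part of this patch: `MyReverse` is a made-up expression, while the
traits, `UTF8String.reverse`, and `StringTypeWithCaseAccentSensitivity` are the real APIs the
rename applies to.

    import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnaryExpression}
    import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
    import org.apache.spark.sql.types.{AbstractDataType, DataType}
    import org.apache.spark.unsafe.types.UTF8String

    // Hypothetical expression, for illustration only; it does not exist in Spark.
    case class MyReverse(child: Expression)
      extends UnaryExpression with ImplicitCastInputTypes {

      // Accepts all collations (binary and ICU) limited to case and accent
      // specifiers, i.e. the contract StringTypeAnyCollation expressed before
      // the rename.
      override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)

      // Preserve the input's collation in the result type.
      override def dataType: DataType = child.dataType

      override protected def nullSafeEval(input: Any): Any =
        input.asInstanceOf[UTF8String].reverse()

      override protected def withNewChildInternal(newChild: Expression): Expression =
        copy(child = newChild)
    }

Under the new naming convention, a future trim-sensitivity specifier would be introduced as a
separate abstract type naming exactly the specifiers it supports, instead of forcing a rename of
an "any collation" type.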