diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index dc4ee013fd189..6feb662632763 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} /** - * StringTypeCollated is an abstract class for StringType with collation support. + * AbstractStringType is an abstract class for StringType with collation support. */ abstract class AbstractStringType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType @@ -46,9 +46,10 @@ case object StringTypeBinaryLcase extends AbstractStringType { } /** - * Use StringTypeAnyCollation for expressions supporting all possible collation types. + * Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary + * and ICU) but limited to using case and accent sensitivity specifiers. */ -case object StringTypeAnyCollation extends AbstractStringType { +case object StringTypeWithCaseAccentSensitivity extends AbstractStringType { override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 5983346ff1e27..e0298b19931c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -32,7 +32,8 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, + StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.sql.types.UpCastRule.numericPrecedence @@ -438,7 +439,7 @@ abstract class TypeCoercionBase { } case aj @ ArrayJoin(arr, d, nr) - if !AbstractArrayType(StringTypeAnyCollation).acceptsType(arr.dataType) && + if !AbstractArrayType(StringTypeWithCaseAccentSensitivity).acceptsType(arr.dataType) && ArrayType.acceptsType(arr.dataType) => val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull implicitCast(arr, ArrayType(StringType, containsNull)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 13ea8c77c41b4..6aa11b6fd16df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, import 
org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -84,7 +84,7 @@ case class CallMethodViaReflection( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("class"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(children.head) ) ) @@ -97,7 +97,7 @@ case class CallMethodViaReflection( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("method"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(children(1)) ) ) @@ -114,7 +114,8 @@ case class CallMethodViaReflection( "paramIndex" -> ordinalNumber(idx), "requiredType" -> toSQLType( TypeCollection(BooleanType, ByteType, ShortType, - IntegerType, LongType, FloatType, DoubleType, StringTypeAnyCollation)), + IntegerType, LongType, FloatType, DoubleType, + StringTypeWithCaseAccentSensitivity)), "inputSql" -> toSQLExpr(e), "inputType" -> toSQLType(e.dataType)) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala index 6e400d026e0ee..28ec8482e5cdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.CollationFactory -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = BinaryType final lazy val collationId: Int = expr.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 749152f135e92..08cb03edb78b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} -import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation} 
+import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType} import org.apache.spark.unsafe.types.UTF8String @@ -61,7 +61,8 @@ object ExprUtils extends QueryErrorsBase { def convertToMapData(exp: Expression): Map[String, String] = exp match { case m: CreateMap - if AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation).acceptsType(m.dataType) => + if AbstractMapType(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + .acceptsType(m.dataType) => val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => key.toString -> value.toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala index 2102428131f64..78bd02d5703cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, IntegerType, LongType, StringType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String @@ -105,7 +105,9 @@ case class HllSketchAgg( override def prettyName: String = "hll_sketch_agg" override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(IntegerType, LongType, StringTypeAnyCollation, BinaryType), IntegerType) + Seq( + TypeCollection(IntegerType, LongType, StringTypeWithCaseAccentSensitivity, BinaryType), + IntegerType) override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index d45ca533f9392..0cff70436db7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ // scalastyle:off line.contains.tab @@ -73,7 +73,7 @@ case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) override def dataType: DataType = StringType(collationId) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: 
Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override protected def withNewChildInternal( newChild: Expression): Expression = copy(newChild) @@ -111,5 +111,5 @@ case class Collation(child: Expression) val collationName = CollationFactory.fetchCollation(collationId).collationName Literal.create(collationName, SQLConf.get.defaultStringType) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5cdd3c7eb62d1..c091d51fc177f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SQLOpenHashSet import org.apache.spark.unsafe.UTF8StringBuilder @@ -1348,7 +1348,7 @@ case class Reverse(child: Expression) // Input types are utilized by type coercion in ImplicitTypeCasts. override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, ArrayType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, ArrayType)) override def dataType: DataType = child.dataType @@ -2134,9 +2134,12 @@ case class ArrayJoin( this(array, delimiter, Some(nullReplacement)) override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { - Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation, StringTypeAnyCollation) + Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity), + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity) } else { - Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation) + Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity), + StringTypeWithCaseAccentSensitivity) } override def children: Seq[Expression] = if (nullReplacement.isDefined) { @@ -2857,7 +2860,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio with QueryErrorsBase { private def allowedTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, BinaryType, ArrayType) + Seq(StringTypeWithCaseAccentSensitivity, BinaryType, ArrayType) final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index cb10440c48328..2f4462c0664f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils._ import 
org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -147,7 +147,7 @@ case class CsvToStructs( converter(parser.parse(csv)) } - override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil override def prettyName: String = "from_csv" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 36bd53001594e..b166d235557fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -961,7 +961,8 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(TimestampType, StringTypeWithCaseAccentSensitivity) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1269,8 +1270,10 @@ abstract class ToTimestamp override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType), - StringTypeAnyCollation) + Seq(TypeCollection( + StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType + ), + StringTypeWithCaseAccentSensitivity) override def dataType: DataType = LongType override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true @@ -1441,7 +1444,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(LongType, StringTypeWithCaseAccentSensitivity) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1549,7 +1553,8 @@ case class NextDay( def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(DateType, 
StringTypeWithCaseAccentSensitivity) override def dataType: DataType = DateType override def nullable: Boolean = true @@ -1760,7 +1765,8 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes w val func: (Long, String) => Long val funcName: String - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(TimestampType, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = TimestampType override def nullSafeEval(time: Any, timezone: Any): Any = { @@ -2100,8 +2106,9 @@ case class ParseToDate( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. - TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) +: - format.map(_ => StringTypeAnyCollation).toSeq + TypeCollection( + StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType) +: + format.map(_ => StringTypeWithCaseAccentSensitivity).toSeq } override protected def withNewChildrenInternal( @@ -2172,10 +2179,10 @@ case class ParseToTimestamp( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. - val types = Seq(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) + val types = Seq(StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType) TypeCollection( (if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _* - ) +: format.map(_ => StringTypeAnyCollation).toSeq + ) +: format.map(_ => StringTypeWithCaseAccentSensitivity).toSeq } override protected def withNewChildrenInternal( @@ -2305,7 +2312,8 @@ case class TruncDate(date: Expression, format: Expression) override def left: Expression = date override def right: Expression = format - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(DateType, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = DateType override def prettyName: String = "trunc" override val instant = date @@ -2374,7 +2382,8 @@ case class TruncTimestamp( override def left: Expression = format override def right: Expression = timestamp - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, TimestampType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, TimestampType) override def dataType: TimestampType = TimestampType override def prettyName: String = "date_trunc" override val instant = timestamp @@ -2675,7 +2684,7 @@ case class MakeTimestamp( // casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0). 
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++ - timezone.map(_ => StringTypeAnyCollation) + timezone.map(_ => StringTypeWithCaseAccentSensitivity) override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = @@ -3122,8 +3131,8 @@ case class ConvertTimezone( override def second: Expression = targetTz override def third: Expression = sourceTs - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, - StringTypeAnyCollation, TimestampNTZType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity, TimestampNTZType) override def dataType: DataType = TimestampNTZType override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 2037eb22fede6..bdcf3f0c1eeab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePatt import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{UTF8String, VariantVal} import org.apache.spark.util.Utils @@ -134,7 +134,7 @@ case class GetJsonObject(json: Expression, path: Expression) override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def prettyName: String = "get_json_object" @@ -489,7 +489,9 @@ case class JsonTuple(children: Seq[Expression]) throw QueryCompilationErrors.wrongNumArgsError( toSQLId(prettyName), Seq("> 1"), children.length ) - } else if (children.forall(child => StringTypeAnyCollation.acceptsType(child.dataType))) { + } else if ( + children.forall( + child => StringTypeWithCaseAccentSensitivity.acceptsType(child.dataType))) { TypeCheckResult.TypeCheckSuccess } else { DataTypeMismatch( @@ -726,7 +728,7 @@ case class JsonToStructs( converter(parser.parse(json.asInstanceOf[UTF8String])) } - override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil override def sql: String = schema match { case _: MapType => "entries" @@ -968,7 +970,7 @@ case class SchemaOfJson( case class LengthOfJsonArray(child: Expression) extends UnaryExpression with CodegenFallback with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType 
= IntegerType override def nullable: Boolean = true override def prettyName: String = "json_array_length" @@ -1041,7 +1043,7 @@ case class LengthOfJsonArray(child: Expression) extends UnaryExpression case class JsonObjectKeys(child: Expression) extends UnaryExpression with CodegenFallback with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType) override def nullable: Boolean = true override def prettyName: String = "json_object_keys" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index c11357352c79a..cb62fa2cc3bd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, DataType} import org.apache.spark.unsafe.types.UTF8String @@ -192,8 +192,12 @@ case class Mask( * NumericType, IntegralType, FractionalType. */ override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation, - StringTypeAnyCollation, StringTypeAnyCollation) + Seq( + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity) override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index ddba820414ae4..e46acf467db22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -453,7 +453,7 @@ case class Conv( override def second: Expression = fromBaseExpr override def third: Expression = toBaseExpr override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, IntegerType, IntegerType) + Seq(StringTypeWithCaseAccentSensitivity, IntegerType, IntegerType) override def dataType: DataType = first.dataType override def nullable: Boolean = true @@ -1114,7 +1114,7 @@ case class Hex(child: Expression) extends UnaryExpression with 
ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, BinaryType, StringTypeAnyCollation)) + Seq(TypeCollection(LongType, BinaryType, StringTypeWithCaseAccentSensitivity)) override def dataType: DataType = child.dataType match { case st: StringType => st @@ -1158,7 +1158,7 @@ case class Unhex(child: Expression, failOnError: Boolean = false) def this(expr: Expression) = this(expr, false) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nullable: Boolean = true override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 6629f724c4dda..cb846f606632b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def foldable: Boolean = false override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, MapType(StringType, StringType)) + Seq(StringTypeWithCaseAccentSensitivity, MapType(StringType, StringType)) override def left: Expression = errorClass override def right: Expression = errorParms @@ -415,7 +415,9 @@ case class AesEncrypt( override def prettyName: String = "aes_encrypt" override def inputTypes: Seq[AbstractDataType] = - Seq(BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, + Seq(BinaryType, BinaryType, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, BinaryType, BinaryType) override def children: Seq[Expression] = Seq(input, key, mode, padding, iv, aad) @@ -489,7 +491,10 @@ case class AesDecrypt( this(input, key, Literal("GCM")) override def inputTypes: Seq[AbstractDataType] = { - Seq(BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType) + Seq(BinaryType, + BinaryType, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, BinaryType) } override def prettyName: String = "aes_decrypt" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index e914190c06456..5bd2ab6035e10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.util.ToNumberParser import org.apache.spark.sql.errors.QueryCompilationErrors import 
org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, DatetimeType, Decimal, DecimalType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -50,7 +50,7 @@ abstract class ToNumberBase(left: Expression, right: Expression, errorOnFail: Bo } override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() @@ -284,7 +284,8 @@ case class ToCharacter(left: Expression, right: Expression) } override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(DecimalType, StringTypeWithCaseAccentSensitivity) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() if (inputTypeCheck.isSuccess) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 970397c76a1cd..fdc3c27890469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -35,7 +35,8 @@ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase} +import org.apache.spark.sql.internal.types.{ + StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -46,7 +47,7 @@ abstract class StringRegexExpression extends BinaryExpression def matches(regex: Pattern, str: String): Boolean override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId final lazy val collationRegexFlags: Int = CollationSupport.collationAwareRegexFlags(collationId) @@ -278,7 +279,7 @@ case class ILike( this(left, right, '\\') override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Expression = { @@ -567,7 +568,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) override def dataType: DataType = ArrayType(str.dataType, containsNull = false) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity, IntegerType) override def first: Expression = str 
override def second: Expression = regex override def third: Expression = limit @@ -711,7 +712,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def dataType: DataType = subject.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation, StringTypeBinaryLcase, IntegerType) + Seq(StringTypeBinaryLcase, + StringTypeWithCaseAccentSensitivity, StringTypeBinaryLcase, IntegerType) final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId override def prettyName: String = "regexp_replace" @@ -799,7 +801,7 @@ abstract class RegExpExtractBase final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity, IntegerType) override def first: Expression = subject override def second: Expression = regexp override def third: Expression = idx @@ -1052,7 +1054,7 @@ case class RegExpCount(left: Expression, right: Expression) override def children: Seq[Expression] = Seq(left, right) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpCount = @@ -1092,7 +1094,7 @@ case class RegExpSubStr(left: Expression, right: Expression) override def children: Seq[Expression] = Seq(left, right) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpSubStr = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 786c3968be0fe..c91c57ee1eb3e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -38,7 +38,8 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO import org.apache.spark.sql.catalyst.util.{ArrayData, CharsetProvider, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeNonCSAICollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, + StringTypeNonCSAICollation, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods @@ -81,8 +82,10 @@ case class ConcatWs(children: Seq[Expression]) /** The 1st child (separator) is str, and rest are either str or array of str. 
*/ override def inputTypes: Seq[AbstractDataType] = { val arrayOrStr = - TypeCollection(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation) - StringTypeAnyCollation +: Seq.fill(children.size - 1)(arrayOrStr) + TypeCollection(AbstractArrayType(StringTypeWithCaseAccentSensitivity), + StringTypeWithCaseAccentSensitivity + ) + StringTypeWithCaseAccentSensitivity +: Seq.fill(children.size - 1)(arrayOrStr) } override def dataType: DataType = children.head.dataType @@ -433,7 +436,7 @@ trait String2StringExpression extends ImplicitCastInputTypes { def convert(v: UTF8String): UTF8String override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) protected override def nullSafeEval(input: Any): Any = convert(input.asInstanceOf[UTF8String]) @@ -515,7 +518,7 @@ abstract class StringPredicate extends BinaryExpression def compare(l: UTF8String, r: UTF8String): Boolean override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) protected override def nullSafeEval(input1: Any, input2: Any): Any = compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String]) @@ -732,7 +735,7 @@ case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with Implic override lazy val replacement: Expression = Invoke(input, "isValid", BooleanType) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "is_valid_utf8" @@ -779,7 +782,7 @@ case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with Impl override lazy val replacement: Expression = Invoke(input, "makeValid", input.dataType) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "make_valid_utf8" @@ -824,7 +827,7 @@ case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with Impli Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "validate_utf8" @@ -873,7 +876,7 @@ case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with Im Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "try_validate_utf8" @@ -1008,8 +1011,8 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: override def dataType: DataType = input.dataType override def inputTypes: Seq[AbstractDataType] = Seq( - TypeCollection(StringTypeAnyCollation, BinaryType), - TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType, IntegerType) + TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), + TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType, IntegerType) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() @@ -1213,7 +1216,7 @@ case class FindInSet(left: Expression, right: 
Expression) extends BinaryExpressi final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override protected def nullSafeEval(word: Any, set: Any): Any = { CollationSupport.FindInSet. @@ -1241,7 +1244,8 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = srcStr +: trimStr.toSeq override def dataType: DataType = srcStr.dataType - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq.fill(children.size)(StringTypeWithCaseAccentSensitivity) final lazy val collationId: Int = srcStr.dataType.asInstanceOf[StringType].collationId @@ -1846,7 +1850,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, IntegerType, StringTypeWithCaseAccentSensitivity) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1926,7 +1930,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, IntegerType, StringTypeWithCaseAccentSensitivity) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1971,7 +1975,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC override def dataType: DataType = children(0).dataType override def inputTypes: Seq[AbstractDataType] = - StringTypeAnyCollation :: List.fill(children.size - 1)(AnyDataType) + StringTypeWithCaseAccentSensitivity :: List.fill(children.size - 1)(AnyDataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { @@ -2082,7 +2086,7 @@ case class InitCap(child: Expression) // Flag to indicate whether to use ICU instead of JVM case mappings for UTF8_BINARY collation. 
private final lazy val useICU = SQLConf.get.getConf(SQLConf.ICU_CASE_MAPPINGS_ENABLED) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = child.dataType override def nullSafeEval(string: Any): Any = { @@ -2114,7 +2118,8 @@ case class StringRepeat(str: Expression, times: Expression) override def left: Expression = str override def right: Expression = times override def dataType: DataType = str.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, IntegerType) override def nullSafeEval(string: Any, n: Any): Any = { string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer]) @@ -2207,7 +2212,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType, IntegerType) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType, IntegerType) override def first: Expression = str override def second: Expression = pos @@ -2265,7 +2270,8 @@ case class Right(str: Expression, len: Expression) extends RuntimeReplaceable ) ) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, IntegerType) override def left: Expression = str override def right: Expression = len override protected def withNewChildrenInternal( @@ -2296,7 +2302,7 @@ case class Left(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = Substring(str, Literal(1), len) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType) } override def left: Expression = str @@ -2332,7 +2338,7 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numChars @@ -2367,7 +2373,7 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numBytes * 8 @@ -2406,7 +2412,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) protected override def 
nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numBytes @@ -2466,8 +2472,9 @@ case class Levenshtein( } override def inputTypes: Seq[AbstractDataType] = threshold match { - case Some(_) => Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) - case _ => Seq(StringTypeAnyCollation, StringTypeAnyCollation) + case Some(_) => + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity, IntegerType) + case _ => Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) } override def children: Seq[Expression] = threshold match { @@ -2592,7 +2599,7 @@ case class SoundEx(child: Expression) override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() @@ -2622,7 +2629,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) protected override def nullSafeEval(string: Any): Any = { // only pick the first character to reduce the `toString` cost @@ -2767,7 +2774,7 @@ case class UnBase64(child: Expression, failOnError: Boolean = false) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BinaryType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) def this(expr: Expression) = this(expr, false) @@ -2946,7 +2953,8 @@ case class StringDecode( this(bin, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction) override val dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(BinaryType, StringTypeWithCaseAccentSensitivity) override def prettyName: String = "decode" override def toString: String = s"$prettyName($bin, $charset)" @@ -2955,7 +2963,7 @@ case class StringDecode( SQLConf.get.defaultStringType, "decode", Seq(bin, charset, Literal(legacyCharsets), Literal(legacyErrorAction)), - Seq(BinaryType, StringTypeAnyCollation, BooleanType, BooleanType)) + Seq(BinaryType, StringTypeWithCaseAccentSensitivity, BooleanType, BooleanType)) override def children: Seq[Expression] = Seq(bin, charset) override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = @@ -3012,15 +3020,20 @@ case class Encode( override def dataType: DataType = BinaryType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override lazy val replacement: Expression = StaticInvoke( classOf[Encode], BinaryType, "encode", Seq( - str, charset, Literal(legacyCharsets, BooleanType), Literal(legacyErrorAction, BooleanType)), - Seq(StringTypeAnyCollation, StringTypeAnyCollation, BooleanType, BooleanType)) + str, charset, Literal(legacyCharsets, BooleanType), Literal(legacyErrorAction, BooleanType) + ), + Seq( + 
StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + BooleanType, + BooleanType)) override def toString: String = s"$prettyName($str, $charset)" @@ -3104,7 +3117,8 @@ case class ToBinary( override def children: Seq[Expression] = expr +: format.toSeq - override def inputTypes: Seq[AbstractDataType] = children.map(_ => StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + children.map(_ => StringTypeWithCaseAccentSensitivity) override def checkInputDataTypes(): TypeCheckResult = { def isValidFormat: Boolean = { @@ -3120,7 +3134,8 @@ case class ToBinary( errorSubClass = "INVALID_ARG_VALUE", messageParameters = Map( "inputName" -> "fmt", - "requireType" -> s"case-insensitive ${toSQLType(StringTypeAnyCollation)}", + "requireType" -> + s"case-insensitive ${toSQLType(StringTypeWithCaseAccentSensitivity)}", "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", "inputValue" -> toSQLValue(fmt, f.dataType) ) @@ -3131,7 +3146,7 @@ case class ToBinary( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("fmt"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(f) ) ) @@ -3140,7 +3155,8 @@ case class ToBinary( errorSubClass = "INVALID_ARG_VALUE", messageParameters = Map( "inputName" -> "fmt", - "requireType" -> s"case-insensitive ${toSQLType(StringTypeAnyCollation)}", + "requireType" -> + s"case-insensitive ${toSQLType(StringTypeWithCaseAccentSensitivity)}", "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", "inputValue" -> toSQLValue(f.eval(), f.dataType) ) @@ -3189,7 +3205,7 @@ case class FormatNumber(x: Expression, d: Expression) override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(NumericType, TypeCollection(IntegerType, StringTypeAnyCollation)) + Seq(NumericType, TypeCollection(IntegerType, StringTypeWithCaseAccentSensitivity)) private val defaultFormat = "#,###,###,###,###,###,##0" @@ -3394,7 +3410,9 @@ case class Sentences( override def dataType: DataType = ArrayType(ArrayType(str.dataType, containsNull = false), containsNull = false) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq( + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override def first: Expression = str override def second: Expression = language override def third: Expression = country @@ -3540,10 +3558,9 @@ case class Luhncheck(input: Expression) extends RuntimeReplaceable with Implicit classOf[ExpressionImplUtils], BooleanType, "isLuhnNumber", - Seq(input), - inputTypes) + Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def prettyName: String = "luhn_check" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index 3e4e4f992002a..09e91da65484f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -29,7 +29,7 @@ import 
org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType} import org.apache.spark.unsafe.types.UTF8String @@ -59,13 +59,13 @@ case class UrlEncode(child: Expression) SQLConf.get.defaultStringType, "encode", Seq(child), - Seq(StringTypeAnyCollation)) + Seq(StringTypeWithCaseAccentSensitivity)) override protected def withNewChildInternal(newChild: Expression): Expression = { copy(child = newChild) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def prettyName: String = "url_encode" } @@ -98,13 +98,13 @@ case class UrlDecode(child: Expression, failOnError: Boolean = true) SQLConf.get.defaultStringType, "decode", Seq(child, Literal(failOnError)), - Seq(StringTypeAnyCollation, BooleanType)) + Seq(StringTypeWithCaseAccentSensitivity, BooleanType)) override protected def withNewChildInternal(newChild: Expression): Expression = { copy(child = newChild) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def prettyName: String = "url_decode" } @@ -190,7 +190,8 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq.fill(children.size)(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "parse_url" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 2c8ca1e8bb2bb..323f6e42f3e50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.types.variant._ import org.apache.spark.types.variant.VariantUtil.{IntervalFields, Type} @@ -66,7 +66,7 @@ case class ParseJson(child: Expression, failOnError: Boolean = true) inputTypes :+ BooleanType :+ BooleanType, returnNullable = !failOnError) - override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = 
 
   override def dataType: DataType = VariantType
 
@@ -271,7 +271,8 @@ case class VariantGet(
 
   final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET)
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringTypeAnyCollation)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(VariantType, StringTypeWithCaseAccentSensitivity)
 
   override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get"
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index 31e65cf0abc95..6c38bd88144b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -42,7 +42,7 @@ abstract class XPathExtract
   override def nullable: Boolean = true
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(StringTypeAnyCollation, StringTypeAnyCollation)
+    Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity)
 
   override def checkInputDataTypes(): TypeCheckResult = {
     if (!path.foldable) {
@@ -50,7 +50,7 @@ abstract class XPathExtract
         errorSubClass = "NON_FOLDABLE_INPUT",
         messageParameters = Map(
           "inputName" -> toSQLId("path"),
-          "inputType" -> toSQLType(StringTypeAnyCollation),
+          "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
           "inputExpr" -> toSQLExpr(path)
         )
       )
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
index 48a87db291a8d..6f1430b04ed67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils._
 import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -124,7 +124,7 @@ case class XmlToStructs(
     defineCodeGen(ctx, ev, input => s"(InternalRow) $expr.nullSafeEval($input)")
   }
 
-  override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil
+  override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil
 
   override def prettyName: String = "from_xml"
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
index de600d881b624..342dcbd8e6b6d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
+import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity}
 import org.apache.spark.sql.types._
 
 class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
@@ -1057,11 +1057,11 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
       ArrayType(IntegerType))
     shouldCast(
       ArrayType(StringType),
-      AbstractArrayType(StringTypeAnyCollation),
+      AbstractArrayType(StringTypeWithCaseAccentSensitivity),
       ArrayType(StringType))
     shouldCast(
       ArrayType(IntegerType),
-      AbstractArrayType(StringTypeAnyCollation),
+      AbstractArrayType(StringTypeWithCaseAccentSensitivity),
       ArrayType(StringType))
     shouldCast(
       ArrayType(StringType),
@@ -1075,11 +1075,11 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
       ArrayType(ArrayType(IntegerType)))
     shouldCast(
       ArrayType(ArrayType(StringType)),
-      AbstractArrayType(AbstractArrayType(StringTypeAnyCollation)),
+      AbstractArrayType(AbstractArrayType(StringTypeWithCaseAccentSensitivity)),
       ArrayType(ArrayType(StringType)))
     shouldCast(
       ArrayType(ArrayType(IntegerType)),
-      AbstractArrayType(AbstractArrayType(StringTypeAnyCollation)),
+      AbstractArrayType(AbstractArrayType(StringTypeWithCaseAccentSensitivity)),
       ArrayType(ArrayType(StringType)))
     shouldCast(
       ArrayType(ArrayType(StringType)),
@@ -1088,14 +1088,16 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
 
     // Invalid casts involving casting arrays into non-complex types.
     shouldNotCast(ArrayType(IntegerType), IntegerType)
-    shouldNotCast(ArrayType(StringType), StringTypeAnyCollation)
+    shouldNotCast(ArrayType(StringType), StringTypeWithCaseAccentSensitivity)
     shouldNotCast(ArrayType(StringType), IntegerType)
-    shouldNotCast(ArrayType(IntegerType), StringTypeAnyCollation)
+    shouldNotCast(ArrayType(IntegerType), StringTypeWithCaseAccentSensitivity)
 
     // Invalid casts involving casting arrays of arrays into arrays of non-complex types.
     shouldNotCast(ArrayType(ArrayType(IntegerType)), AbstractArrayType(IntegerType))
-    shouldNotCast(ArrayType(ArrayType(StringType)), AbstractArrayType(StringTypeAnyCollation))
+    shouldNotCast(ArrayType(ArrayType(StringType)),
+      AbstractArrayType(StringTypeWithCaseAccentSensitivity))
     shouldNotCast(ArrayType(ArrayType(StringType)), AbstractArrayType(IntegerType))
-    shouldNotCast(ArrayType(ArrayType(IntegerType)), AbstractArrayType(StringTypeAnyCollation))
+    shouldNotCast(ArrayType(ArrayType(IntegerType)),
+      AbstractArrayType(StringTypeWithCaseAccentSensitivity))
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 9b454ba764f92..1aae2f10b7326 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio
 import org.apache.spark.sql.catalyst.util.CharsetProvider
 import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -1466,7 +1466,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       errorSubClass = "NON_FOLDABLE_INPUT",
       messageParameters = Map(
         "inputName" -> toSQLId("fmt"),
-        "inputType" -> toSQLType(StringTypeAnyCollation),
+        "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
         "inputExpr" -> toSQLExpr(wrongFmt)
       )
     )
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
index 1d23774a51692..879c0c480943d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala
@@ -66,10 +66,10 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
       collationType: CollationType): Any =
     inputEntry match {
       case e: Class[_] if e.isAssignableFrom(classOf[Expression]) =>
-        generateLiterals(StringTypeAnyCollation, collationType)
+        generateLiterals(StringTypeWithCaseAccentSensitivity, collationType)
       case se: Class[_] if se.isAssignableFrom(classOf[Seq[Expression]]) =>
-        CreateArray(Seq(generateLiterals(StringTypeAnyCollation, collationType),
-          generateLiterals(StringTypeAnyCollation, collationType)))
+        CreateArray(Seq(generateLiterals(StringTypeWithCaseAccentSensitivity, collationType),
+          generateLiterals(StringTypeWithCaseAccentSensitivity, collationType)))
       case oe: Class[_] if oe.isAssignableFrom(classOf[Option[Any]]) => None
       case b: Class[_] if b.isAssignableFrom(classOf[Boolean]) => false
       case dt: Class[_] if dt.isAssignableFrom(classOf[DataType]) => StringType
@@ -142,12 +142,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
           lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType))
         ).head
       case ArrayType =>
-        generateLiterals(StringTypeAnyCollation, collationType).map(
+        generateLiterals(StringTypeWithCaseAccentSensitivity, collationType).map(
           lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType))
         ).head
       case MapType =>
-        val key = generateLiterals(StringTypeAnyCollation, collationType)
-        val value = generateLiterals(StringTypeAnyCollation, collationType)
+        val key = generateLiterals(StringTypeWithCaseAccentSensitivity, collationType)
+        val value = generateLiterals(StringTypeWithCaseAccentSensitivity, collationType)
         CreateMap(Seq(key, value))
       case MapType(keyType, valueType, _) =>
         val key = generateLiterals(keyType, collationType)
@@ -159,8 +159,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
         CreateMap(Seq(key, value))
       case StructType =>
         CreateNamedStruct(
-          Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, collationType),
-            Literal("end"), generateLiterals(StringTypeAnyCollation, collationType)))
+          Seq(Literal("start"),
+            generateLiterals(StringTypeWithCaseAccentSensitivity, collationType),
+            Literal("end"), generateLiterals(StringTypeWithCaseAccentSensitivity, collationType)))
   }
 
   /**
@@ -209,10 +210,10 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
       case ArrayType(elementType, _) =>
         "array(" + generateInputAsString(elementType, collationType) + ")"
       case ArrayType =>
-        "array(" + generateInputAsString(StringTypeAnyCollation, collationType) + ")"
+        "array(" + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")"
       case MapType =>
-        "map(" + generateInputAsString(StringTypeAnyCollation, collationType) + ", " +
-          generateInputAsString(StringTypeAnyCollation, collationType) + ")"
+        "map(" + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ", " +
+          generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")"
       case MapType(keyType, valueType, _) =>
         "map(" + generateInputAsString(keyType, collationType) + ", " +
           generateInputAsString(valueType, collationType) + ")"
@@ -220,8 +221,9 @@
         "map(" + generateInputAsString(keyType, collationType) + ", " +
           generateInputAsString(valueType, collationType) + ")"
       case StructType =>
-        "named_struct( 'start', " + generateInputAsString(StringTypeAnyCollation, collationType) +
-          ", 'end', " + generateInputAsString(StringTypeAnyCollation, collationType) + ")"
+        "named_struct( 'start', " +
+          generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ", 'end', " +
+          generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")"
       case StructType(fields) =>
         "named_struct(" + fields.map(f => "'" + f.name + "', " +
           generateInputAsString(f.dataType, collationType)).mkString(", ") + ")"
@@ -267,10 +269,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
       case ArrayType(elementType, _) =>
         "array<" + generateInputTypeAsStrings(elementType, collationType) + ">"
       case ArrayType =>
-        "array<" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">"
+        "array<" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) +
+          ">"
       case MapType =>
-        "map<" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ", " +
-          generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">"
+        "map<" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) +
+          ", " +
+          generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + ">"
       case MapType(keyType, valueType, _) =>
         "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " +
           generateInputTypeAsStrings(valueType, collationType) + ">"
@@ -278,9 +282,10 @@
         "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " +
           generateInputTypeAsStrings(valueType, collationType) + ">"
       case StructType =>
-        "struct<start:" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) +
-          ", end:" +
-          generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">"
+        "struct<start:" +
+          generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) +
+          ", end:" +
+          generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + ">"
       case StructType(fields) =>
         "named_struct<" + fields.map(f => "'" + f.name + "', " +
           generateInputTypeAsStrings(f.dataType, collationType)).mkString(", ") + ">"
@@ -293,8 +298,8 @@
    */
   def hasStringType(inputType: AbstractDataType): Boolean = {
     inputType match {
-      case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType =>
-        true
+      case _: StringType | StringTypeWithCaseAccentSensitivity | StringTypeBinaryLcase | AnyDataType
+        => true
       case ArrayType => true
       case MapType => true
       case MapType(keyType, valueType, _) => hasStringType(keyType) || hasStringType(valueType)
@@ -408,7 +413,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
       var input: Seq[Expression] = Seq.empty
       var i = 0
       for (_ <- 1 to 10) {
-        input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
+        input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary)
         try {
           method.invoke(null, funInfo.getClassName, input).asInstanceOf[ExpectsInputTypes]
         }
@@ -498,7 +503,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
       var input: Seq[Expression] = Seq.empty
       var result: Expression = null
       for (_ <- 1 to 10) {
-        input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
+        input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary)
         try {
           val tempResult = method.invoke(null, f.getClassName, input)
           if (result == null) result = tempResult.asInstanceOf[Expression]
@@ -609,7 +614,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi
       var input: Seq[Expression] = Seq.empty
       var result: Expression = null
       for (_ <- 1 to 10) {
-        input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary)
+        input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary)
         try {
           val tempResult = method.invoke(null, f.getClassName, input)
           if (result == null) result = tempResult.asInstanceOf[Expression]