Skip to content

Commit

Permalink
[SPARK-49811][SQL] Rename StringTypeAnyCollation
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Rename StringTypeAnyCollation to StringTypeWithCaseAccentSensitivity. The name StringTypeAnyCollation is unfortunate: whenever a new type of collation is added, it requires renaming.

### Why are the changes needed?
The name StringTypeAnyCollation is unfortunate: whenever a new specifier is added (for example, the trim specifier), the type has to be renamed (to something like AllCollationExceptTrimCollation) until the new collation is implemented in all functions. It gets even more confusing if multiple collations are not supported by some functions.
Instead, the naming convention should name only the specifiers that are supported and avoid using "all".

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
This is only a rename; all existing tests pass.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48265 from jovanpavl-db/rename-string-type-collations.

Authored-by: Jovan Pavlovic <jovan.pavlovic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
jovanpavl-db authored and cloud-fan committed Sep 30, 2024
1 parent 885c3fa commit d85e7bc
Show file tree
Hide file tree
Showing 24 changed files with 218 additions and 160 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

/**
* StringTypeCollated is an abstract class for StringType with collation support.
* AbstractStringType is an abstract class for StringType with collation support.
*/
abstract class AbstractStringType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
Expand All @@ -46,9 +46,10 @@ case object StringTypeBinaryLcase extends AbstractStringType {
}

/**
* Use StringTypeAnyCollation for expressions supporting all possible collation types.
* Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary
* and ICU) but limited to using case and accent sensitivity specifiers.
*/
case object StringTypeAnyCollation extends AbstractStringType {
case object StringTypeWithCaseAccentSensitivity extends AbstractStringType {
override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation}
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType,
StringTypeWithCaseAccentSensitivity}
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.UpCastRule.numericPrecedence

Expand Down Expand Up @@ -438,7 +439,7 @@ abstract class TypeCoercionBase {
}

case aj @ ArrayJoin(arr, d, nr)
if !AbstractArrayType(StringTypeAnyCollation).acceptsType(arr.dataType) &&
if !AbstractArrayType(StringTypeWithCaseAccentSensitivity).acceptsType(arr.dataType) &&
ArrayType.acceptsType(arr.dataType) =>
val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
implicitCast(arr, ArrayType(StringType, containsNull)) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
Expand Down Expand Up @@ -84,7 +84,7 @@ case class CallMethodViaReflection(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("class"),
"inputType" -> toSQLType(StringTypeAnyCollation),
"inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
"inputExpr" -> toSQLExpr(children.head)
)
)
Expand All @@ -97,7 +97,7 @@ case class CallMethodViaReflection(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("method"),
"inputType" -> toSQLType(StringTypeAnyCollation),
"inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity),
"inputExpr" -> toSQLExpr(children(1))
)
)
Expand All @@ -114,7 +114,8 @@ case class CallMethodViaReflection(
"paramIndex" -> ordinalNumber(idx),
"requiredType" -> toSQLType(
TypeCollection(BooleanType, ByteType, ShortType,
IntegerType, LongType, FloatType, DoubleType, StringTypeAnyCollation)),
IntegerType, LongType, FloatType, DoubleType,
StringTypeWithCaseAccentSensitivity)),
"inputSql" -> toSQLExpr(e),
"inputType" -> toSQLType(e.dataType))
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def dataType: DataType = BinaryType

final lazy val collationId: Int = expr.dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation}
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseAccentSensitivity}
import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType}
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -61,7 +61,8 @@ object ExprUtils extends QueryErrorsBase {

def convertToMapData(exp: Expression): Map[String, String] = exp match {
case m: CreateMap
if AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation).acceptsType(m.dataType) =>
if AbstractMapType(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity)
.acceptsType(m.dataType) =>
val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData]
ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) =>
key.toString -> value.toString
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression,
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, IntegerType, LongType, StringType, TypeCollection}
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -105,7 +105,9 @@ case class HllSketchAgg(
override def prettyName: String = "hll_sketch_agg"

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(IntegerType, LongType, StringTypeAnyCollation, BinaryType), IntegerType)
Seq(
TypeCollection(IntegerType, LongType, StringTypeWithCaseAccentSensitivity, BinaryType),
IntegerType)

override def dataType: DataType = BinaryType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.types._

// scalastyle:off line.contains.tab
Expand Down Expand Up @@ -73,7 +73,7 @@ case class Collate(child: Expression, collationName: String)
extends UnaryExpression with ExpectsInputTypes {
private val collationId = CollationFactory.collationNameToId(collationName)
override def dataType: DataType = StringType(collationId)
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)

override protected def withNewChildInternal(
newChild: Expression): Expression = copy(newChild)
Expand Down Expand Up @@ -111,5 +111,5 @@ case class Collation(child: Expression)
val collationName = CollationFactory.fetchCollation(collationId).collationName
Literal.create(collationName, SQLConf.get.defaultStringType)
}
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SQLOpenHashSet
import org.apache.spark.unsafe.UTF8StringBuilder
Expand Down Expand Up @@ -1348,7 +1348,7 @@ case class Reverse(child: Expression)

// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(StringTypeAnyCollation, ArrayType))
Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, ArrayType))

override def dataType: DataType = child.dataType

Expand Down Expand Up @@ -2134,9 +2134,12 @@ case class ArrayJoin(
this(array, delimiter, Some(nullReplacement))

override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation, StringTypeAnyCollation)
Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity),
StringTypeWithCaseAccentSensitivity,
StringTypeWithCaseAccentSensitivity)
} else {
Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation)
Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity),
StringTypeWithCaseAccentSensitivity)
}

override def children: Seq[Expression] = if (nullReplacement.isDefined) {
Expand Down Expand Up @@ -2857,7 +2860,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
with QueryErrorsBase {

private def allowedTypes: Seq[AbstractDataType] =
Seq(StringTypeAnyCollation, BinaryType, ArrayType)
Seq(StringTypeWithCaseAccentSensitivity, BinaryType, ArrayType)

final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

Expand Down Expand Up @@ -147,7 +147,7 @@ case class CsvToStructs(
converter(parser.parse(csv))
}

override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil
override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil

override def prettyName: String = "from_csv"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType.DAY
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
Expand Down Expand Up @@ -961,7 +961,8 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti

override def dataType: DataType = SQLConf.get.defaultStringType

override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation)
override def inputTypes: Seq[AbstractDataType] =
Seq(TimestampType, StringTypeWithCaseAccentSensitivity)

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
Expand Down Expand Up @@ -1269,8 +1270,10 @@ abstract class ToTimestamp
override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType

override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType),
StringTypeAnyCollation)
Seq(TypeCollection(
StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType
),
StringTypeWithCaseAccentSensitivity)

override def dataType: DataType = LongType
override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true
Expand Down Expand Up @@ -1441,7 +1444,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[
override def dataType: DataType = SQLConf.get.defaultStringType
override def nullable: Boolean = true

override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringTypeAnyCollation)
override def inputTypes: Seq[AbstractDataType] =
Seq(LongType, StringTypeWithCaseAccentSensitivity)

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
copy(timeZoneId = Option(timeZoneId))
Expand Down Expand Up @@ -1549,7 +1553,8 @@ case class NextDay(

def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled)

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation)
override def inputTypes: Seq[AbstractDataType] =
Seq(DateType, StringTypeWithCaseAccentSensitivity)

override def dataType: DataType = DateType
override def nullable: Boolean = true
Expand Down Expand Up @@ -1760,7 +1765,8 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes w
val func: (Long, String) => Long
val funcName: String

override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation)
override def inputTypes: Seq[AbstractDataType] =
Seq(TimestampType, StringTypeWithCaseAccentSensitivity)
override def dataType: DataType = TimestampType

override def nullSafeEval(time: Any, timezone: Any): Any = {
Expand Down Expand Up @@ -2100,8 +2106,9 @@ case class ParseToDate(
override def inputTypes: Seq[AbstractDataType] = {
// Note: ideally this function should only take string input, but we allow more types here to
// be backward compatible.
TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) +:
format.map(_ => StringTypeAnyCollation).toSeq
TypeCollection(
StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType) +:
format.map(_ => StringTypeWithCaseAccentSensitivity).toSeq
}

override protected def withNewChildrenInternal(
Expand Down Expand Up @@ -2172,10 +2179,10 @@ case class ParseToTimestamp(
override def inputTypes: Seq[AbstractDataType] = {
// Note: ideally this function should only take string input, but we allow more types here to
// be backward compatible.
val types = Seq(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType)
val types = Seq(StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType)
TypeCollection(
(if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _*
) +: format.map(_ => StringTypeAnyCollation).toSeq
) +: format.map(_ => StringTypeWithCaseAccentSensitivity).toSeq
}

override protected def withNewChildrenInternal(
Expand Down Expand Up @@ -2305,7 +2312,8 @@ case class TruncDate(date: Expression, format: Expression)
override def left: Expression = date
override def right: Expression = format

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation)
override def inputTypes: Seq[AbstractDataType] =
Seq(DateType, StringTypeWithCaseAccentSensitivity)
override def dataType: DataType = DateType
override def prettyName: String = "trunc"
override val instant = date
Expand Down Expand Up @@ -2374,7 +2382,8 @@ case class TruncTimestamp(
override def left: Expression = format
override def right: Expression = timestamp

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, TimestampType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity, TimestampType)
override def dataType: TimestampType = TimestampType
override def prettyName: String = "date_trunc"
override val instant = timestamp
Expand Down Expand Up @@ -2675,7 +2684,7 @@ case class MakeTimestamp(
// casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0).
override def inputTypes: Seq[AbstractDataType] =
Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++
timezone.map(_ => StringTypeAnyCollation)
timezone.map(_ => StringTypeWithCaseAccentSensitivity)
override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true

override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
Expand Down Expand Up @@ -3122,8 +3131,8 @@ case class ConvertTimezone(
override def second: Expression = targetTz
override def third: Expression = sourceTs

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation,
StringTypeAnyCollation, TimestampNTZType)
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity, TimestampNTZType)
override def dataType: DataType = TimestampNTZType

override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = {
Expand Down
Loading

0 comments on commit d85e7bc

Please sign in to comment.