Skip to content

Commit

Permalink
[SPARK-50032][SQL] Allow use of fully qualified collation name
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
In this PR, collations can now be identified by their fully qualified name, as per the collation project plan. The `Collation` expression has been changed to always return the fully qualified name. Currently, only predefined collations are supported.

### Why are the changes needed?
Make collation names behave as per the project spec.

### Does this PR introduce _any_ user-facing change?
Yes. Two user-facing changes are made:
1. The `collation` expression now returns the fully qualified name:
```sql
select collation('a' collate utf8_lcase) -- returns `SYSTEM.BUILTIN.UTF8_LCASE`
```
2. Collations can now be identified by their fully qualified name:
```sql
select contains('a' collate system.builtin.utf8_lcase, 'A') -- returns true
```

### How was this patch tested?
New tests in this PR.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48546 from stevomitric/stevomitric/fully-qualified-name.

Lead-authored-by: Stevo Mitric <stevo.mitric@databricks.com>
Co-authored-by: Wenchen Fan <cloud0fan@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
stevomitric and cloud-fan committed Dec 1, 2024
1 parent 3fab712 commit faf74ad
Show file tree
Hide file tree
Showing 20 changed files with 347 additions and 165 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -415,18 +415,6 @@ private static Collation fetchCollation(int collationId) {
}
}

/**
 * Builds (but does not throw) the exception reported when an invalid collation name is given.
 *
 * The message parameters carry the offending name under "collationName" and, under
 * "proposals", the result of {@code getClosestSuggestionsOnInvalidName} limited to three
 * suggestions.
 *
 * @param collationName the collation name that was rejected
 * @return a {@code SparkException} identified by {@code COLLATION_INVALID_NAME}
 */
protected static SparkException collationInvalidNameException(String collationName) {
  final int suggestionLimit = 3;
  Map<String, String> messageArgs = new HashMap<>();
  messageArgs.put("collationName", collationName);
  String closestNames = getClosestSuggestionsOnInvalidName(collationName, suggestionLimit);
  messageArgs.put("proposals", closestNames);
  return new SparkException(
    "COLLATION_INVALID_NAME",
    SparkException.constructMessageParams(messageArgs),
    null);
}

private static int collationNameToId(String collationName) throws SparkException {
// Collation names provided by user are treated as case-insensitive.
String collationNameUpper = collationName.toUpperCase();
Expand Down Expand Up @@ -1185,6 +1173,52 @@ public static int collationNameToId(String collationName) throws SparkException
return Collation.CollationSpec.collationNameToId(collationName);
}

/**
 * Resolves a (possibly multi-part) collation name to its single-part collation name.
 *
 * A one-part name is returned as is, without any resolution. Any other name must have exactly
 * three parts whose first two parts match {@code CollationFactory.CATALOG} and
 * {@code CollationFactory.SCHEMA} ignoring case; in that case the third part is returned.
 *
 * @param collationName the parts of the collation name, in order
 * @return the last (unqualified) part of the collation name
 * @throws SparkException the invalid-name exception built for the last name part (or for the
 *                        empty string when no parts are given) if the name cannot be resolved
 */
public static String resolveFullyQualifiedName(String[] collationName) throws SparkException {
  final int partCount = collationName.length;

  // If collation name has only one part, then we don't need to do any name resolution.
  if (partCount == 1) {
    return collationName[0];
  }

  // Currently we only support builtin collation names with fixed catalog `SYSTEM` and
  // schema `BUILTIN`.
  boolean isBuiltinQualified = partCount == 3 &&
    CollationFactory.CATALOG.equalsIgnoreCase(collationName[0]) &&
    CollationFactory.SCHEMA.equalsIgnoreCase(collationName[1]);
  if (isBuiltinQualified) {
    return collationName[2];
  }

  // Throw exception with original (before case conversion) collation name.
  String reportedName = partCount == 0 ? "" : collationName[partCount - 1];
  throw CollationFactory.collationInvalidNameException(reportedName);
}

/**
 * Constructs the error returned to callers that provided an invalid collation name.
 *
 * The exception's message parameters contain the rejected name ("collationName") and the
 * closest suggestions for it ("proposals"), capped at three suggestions. The exception is
 * only created here; throwing it is left to the caller.
 *
 * @param collationName the collation name that was rejected
 * @return a {@code SparkException} identified by {@code COLLATION_INVALID_NAME}
 */
public static SparkException collationInvalidNameException(String collationName) {
  final int suggestionCap = 3;
  String proposals = getClosestSuggestionsOnInvalidName(collationName, suggestionCap);

  Map<String, String> messageParams = new HashMap<>();
  messageParams.put("collationName", collationName);
  messageParams.put("proposals", proposals);

  return new SparkException(
    "COLLATION_INVALID_NAME", SparkException.constructMessageParams(messageParams), null);
}



/**
 * Builds the fully qualified name of the collation with the given ID, i.e. the catalog,
 * schema and collation name joined by dots.
 *
 * @param collationId ID of the collation to name
 * @return the name in the form {@code CATALOG.SCHEMA.collationName}
 */
public static String fullyQualifiedName(int collationId) {
  // Currently only predefined collations are supported.
  Collation.CollationSpec.DefinitionOrigin origin =
    Collation.CollationSpec.getDefinitionOrigin(collationId);
  assert origin == Collation.CollationSpec.DefinitionOrigin.PREDEFINED;

  String unqualifiedName = Collation.CollationSpec.fetchCollation(collationId).collationName;
  return String.join(".", CATALOG, SCHEMA, unqualifiedName);
}

public static boolean isCaseInsensitive(int collationId) {
return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity ==
Collation.CollationSpecICU.CaseSensitivity.CI;
Expand Down
12 changes: 6 additions & 6 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16756,12 +16756,12 @@ def collation(col: "ColumnOrName") -> Column:
Examples
--------
>>> df = spark.createDataFrame([('name',)], ['dt'])
>>> df.select(collation('dt').alias('collation')).show()
+-----------+
| collation|
+-----------+
|UTF8_BINARY|
+-----------+
>>> df.select(collation('dt').alias('collation')).show(truncate=False)
+--------------------------+
|collation |
+--------------------------+
|SYSTEM.BUILTIN.UTF8_BINARY|
+--------------------------+
"""
return _invoke_function_over_columns("collation", col)

Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def test_string_functions(self):
def test_collation(self):
df = self.spark.createDataFrame([("a",), ("b",)], ["name"])
actual = df.select(F.collation(F.collate("name", "UNICODE"))).distinct().collect()
self.assertEqual([Row("UNICODE")], actual)
self.assertEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual)

def test_try_make_interval(self):
df = self.spark.createDataFrame([(2147483647,)], ["num"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1233,7 +1233,7 @@ colPosition
;

collateClause
: COLLATE collationName=identifier
: COLLATE collationName=multipartIdentifier
;

type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
}
}

/**
 * Create a multi-part identifier: the text of each part of `ctx`, in the order the parts
 * appear.
 */
override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
  withOrigin(ctx) {
    val identifierParts = ctx.parts.asScala
    identifierParts.map(part => part.getText).toSeq
  }

/**
* Resolve/create a primitive type.
*/
Expand All @@ -78,8 +86,9 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
typeCtx.children.asScala.toSeq match {
case Seq(_) => StringType
case Seq(_, ctx: CollateClauseContext) =>
val collationName = visitCollateClause(ctx)
val collationId = CollationFactory.collationNameToId(collationName)
val collationNameParts = visitCollateClause(ctx).toArray
val collationId = CollationFactory.collationNameToId(
CollationFactory.resolveFullyQualifiedName(collationNameParts))
StringType(collationId)
}
case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt)
Expand Down Expand Up @@ -219,8 +228,8 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
/**
* Returns a collation name.
*/
override def visitCollateClause(ctx: CollateClauseContext): String = withOrigin(ctx) {
ctx.identifier.getText
override def visitCollateClause(ctx: CollateClauseContext): Seq[String] = withOrigin(ctx) {
visitMultipartIdentifier(ctx.collationName)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ResolveFieldNameAndPosition ::
AddMetadataColumns ::
DeduplicateRelations ::
ResolveCollationName ::
new ResolveReferences(catalogManager) ::
// Please do not insert any other rules in between. See the TODO comments in rule
// ResolveLateralColumnAliasReference for more details.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_COLLATION
import org.apache.spark.sql.catalyst.util.CollationFactory

/**
 * Analyzer rule that replaces every [[UnresolvedCollation]] with a [[ResolvedCollation]]
 * holding the name returned by `CollationFactory.resolveFullyQualifiedName` for the
 * unresolved name parts.
 *
 * Only expressions containing the `UNRESOLVED_COLLATION` tree pattern are visited.
 */
object ResolveCollationName extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = {
    plan.resolveExpressionsWithPruning(_.containsPattern(UNRESOLVED_COLLATION), ruleId) {
      case unresolved: UnresolvedCollation =>
        val nameParts = unresolved.collationName.toArray
        val resolvedName = CollationFactory.resolveFullyQualifiedName(nameParts)
        ResolvedCollation(resolvedName)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_COLLATION}
import org.apache.spark.sql.catalyst.util.{AttributeNameParser, CollationFactory}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
Expand All @@ -37,7 +39,7 @@ import org.apache.spark.sql.types._
examples = """
Examples:
> SELECT COLLATION('Spark SQL' _FUNC_ UTF8_LCASE);
UTF8_LCASE
SYSTEM.BUILTIN.UTF8_LCASE
""",
since = "4.0.0",
group = "string_funcs")
Expand All @@ -56,7 +58,8 @@ object CollateExpressionBuilder extends ExpressionBuilder {
evalCollation.toString.toUpperCase().contains("TRIM")) {
throw QueryCompilationErrors.trimCollationNotEnabledError()
}
Collate(e, evalCollation.toString)
Collate(e, UnresolvedCollation(
AttributeNameParser.parseAttributeName(evalCollation.toString)))
}
case (_: StringType, false) => throw QueryCompilationErrors.nonFoldableArgumentError(
funcName, "collationName", StringType)
Expand All @@ -73,24 +76,63 @@ object CollateExpressionBuilder extends ExpressionBuilder {
* This function is pass-through, it will not modify the input data.
* Only type metadata will be updated.
*/
case class Collate(child: Expression, collationName: String)
extends UnaryExpression with ExpectsInputTypes {
private val collationId = CollationFactory.collationNameToId(collationName)
override def dataType: DataType = StringType(collationId)
case class Collate(child: Expression, collation: Expression)
extends BinaryExpression with ExpectsInputTypes {
override def left: Expression = child
override def right: Expression = collation
override def dataType: DataType = collation.dataType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))

override protected def withNewChildInternal(
newChild: Expression): Expression = copy(newChild)
Seq(StringTypeWithCollation(supportsTrimCollation = true), AnyDataType)

override def eval(row: InternalRow): Any = child.eval(row)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
defineCodeGen(ctx, ev, (in) => in)
/** Just a simple passthrough for code generation. */
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
throw SparkException.internalError("Collate.doGenCode should not be called.")
}

override def sql: String = s"$prettyName(${child.sql}, $collation)"

override def toString: String =
s"$prettyName($child, $collation)"

override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Expression =
copy(child = newLeft, collation = newRight)

override def foldable: Boolean = child.foldable
}

/**
 * Placeholder expression for a collation that is referenced by name but not yet resolved.
 *
 * It carries the parts of a (possibly fully qualified) collation name as given, before the
 * name has been validated or resolved during analysis. The expression always reports itself
 * as unresolved, is tagged with the `UNRESOLVED_COLLATION` tree pattern, and throws an
 * [[UnresolvedException]] when asked for its data type.
 *
 * @param collationName the parts of the collation name, in order
 */
case class UnresolvedCollation(collationName: Seq[String])
  extends LeafExpression with Unevaluable {
  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_COLLATION)

  override lazy val resolved: Boolean = false

  override def nullable: Boolean = false

  // No data type exists until the name has been resolved.
  override def dataType: DataType = throw new UnresolvedException("dataType")
}

/**
* An expression that represents a resolved collation name.
*/
case class ResolvedCollation(collationName: String) extends LeafExpression with Unevaluable {
override def nullable: Boolean = false

override def dataType: DataType = StringType(CollationFactory.collationNameToId(collationName))

override def sql: String = s"$prettyName(${child.sql}, $collationName)"
override def toString: String = collationName

override def toString: String = s"$prettyName($child, $collationName)"
override def sql: String = collationName
}

// scalastyle:off line.contains.tab
Expand All @@ -103,7 +145,7 @@ case class Collate(child: Expression, collationName: String)
examples = """
Examples:
> SELECT _FUNC_('Spark SQL');
UTF8_BINARY
SYSTEM.BUILTIN.UTF8_BINARY
""",
since = "4.0.0",
group = "string_funcs")
Expand All @@ -113,8 +155,8 @@ case class Collation(child: Expression)
override protected def withNewChildInternal(newChild: Expression): Collation = copy(newChild)
override lazy val replacement: Expression = {
val collationId = child.dataType.asInstanceOf[StringType].collationId
val collationName = CollationFactory.fetchCollation(collationId).collationName
Literal.create(collationName, SQLConf.get.defaultStringType)
val fullyQualifiedCollationName = CollationFactory.fullyQualifiedName(collationId)
Literal.create(fullyQualifiedCollationName, SQLConf.get.defaultStringType)
}
override def inputTypes: Seq[AbstractDataType] =
Seq(StringTypeWithCollation(supportsTrimCollation = true))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2286,14 +2286,6 @@ class AstBuilder extends DataTypeAstBuilder
FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText))
}

/**
 * Create a multi-part identifier by collecting the text of every part of `ctx` in order.
 */
override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
  withOrigin(ctx) {
    (for (part <- ctx.parts.asScala) yield part.getText).toSeq
  }

/* ********************************************************************************************
* Expression parsing
* ******************************************************************************************** */
Expand Down Expand Up @@ -2706,15 +2698,16 @@ class AstBuilder extends DataTypeAstBuilder
*/
override def visitCollate(ctx: CollateContext): Expression = withOrigin(ctx) {
val collationName = visitCollateClause(ctx.collateClause())
Collate(expression(ctx.primaryExpression), collationName)

Collate(expression(ctx.primaryExpression), UnresolvedCollation(collationName))
}

override def visitCollateClause(ctx: CollateClauseContext): String = withOrigin(ctx) {
val collationName = ctx.collationName.getText
if (!SQLConf.get.trimCollationEnabled && collationName.toUpperCase().contains("TRIM")) {
override def visitCollateClause(ctx: CollateClauseContext): Seq[String] = withOrigin(ctx) {
val collationName = visitMultipartIdentifier(ctx.collationName)
if (!SQLConf.get.trimCollationEnabled && collationName.last.toUpperCase().contains("TRIM")) {
throw QueryCompilationErrors.trimCollationNotEnabledError()
}
ctx.identifier.getText
collationName
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveBinaryArithmetic" ::
"org.apache.spark.sql.catalyst.analysis.ResolveCollationName" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions" ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ object TreePattern extends Enumeration {
// Unresolved expression patterns (Alphabetically ordered)
val UNRESOLVED_ALIAS: Value = Value
val UNRESOLVED_ATTRIBUTE: Value = Value
val UNRESOLVED_COLLATION: Value = Value
val UNRESOLVED_DESERIALIZER: Value = Value
val UNRESOLVED_DF_STAR: Value = Value
val UNRESOLVED_HAVING: Value = Value
Expand Down
Loading

0 comments on commit faf74ad

Please sign in to comment.