From a767f5cb1704075ee249169e8faf2ab3610b9dbc Mon Sep 17 00:00:00 2001
From: Herman van Hovell
Date: Sat, 17 Aug 2024 09:45:05 +0900
Subject: [PATCH] [SPARK-49022] Use Column Node API in Column

### What changes were proposed in this pull request?
This PR makes org.apache.spark.sql.Column and friends use the recently introduced ColumnNode API. This is a stepping stone towards making the Column API implementation-agnostic. Most of the changes are fairly mechanical, and they are mostly caused by the removal of the Column(Expression) constructor.

### Why are the changes needed?
We want to create a unified Scala interface for Classic and Connect. A language-agnostic Column API implementation is part of this.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #47688 from hvanhovell/SPARK-49022.

Authored-by: Herman van Hovell
Signed-off-by: Hyukjin Kwon
---
 R/pkg/R/functions.R                           |  75 +++----
 .../org/apache/spark/sql/avro/functions.scala |   8 +-
 .../CheckConnectJvmClientCompatibility.scala  |   6 +
 .../apache/spark/sql/protobuf/functions.scala |  20 +-
 .../spark/ml/feature/StringIndexer.scala      |   3 +-
 .../org/apache/spark/ml/stat/Summarizer.scala |   2 +-
 project/MimaExcludes.scala                    |   5 +-
 python/pyspark/pandas/internal.py             |   6 +-
 python/pyspark/sql/functions/builtin.py       |  62 ++----
 python/pyspark/sql/tests/test_dataframe.py    |   1 -
 python/pyspark/sql/udtf.py                    |   8 +-
 .../sql/catalyst/analysis/Analyzer.scala      |  14 +-
 .../connect/planner/SparkConnectPlanner.scala |  12 +-
 .../scala/org/apache/spark/sql/Column.scala   | 196 +++++++++---------
 .../spark/sql/DataFrameNaFunctions.scala      |   2 +-
 .../scala/org/apache/spark/sql/Dataset.scala  |   4 +-
 .../spark/sql/KeyValueGroupedDataset.scala    |   2 +-
 .../apache/spark/sql/UDFRegistration.scala    |  70 ++++---
 .../spark/sql/api/python/PythonSQLUtils.scala |  25 ++-
 .../spark/sql/execution/aggregate/udaf.scala  |  19 +-
 .../sql/execution/stat/StatFunctions.scala    |   6 +-
 .../spark/sql/expressions/Aggregator.scala    |  13 +-
 .../sql/expressions/UserDefinedFunction.scala |  40 +---
 .../apache/spark/sql/expressions/Window.scala |   3 +-
 .../spark/sql/expressions/WindowSpec.scala    |  63 +++---
 .../apache/spark/sql/expressions/udaf.scala   |  11 +-
 .../org/apache/spark/sql/functions.scala      | 104 ++++------
 .../internal/UserDefinedFunctionUtils.scala   |  44 ++++
 .../sql/internal/columnNodeSupport.scala      |  20 +-
 .../spark/sql/ColumnExpressionSuite.scala     |   8 +-
 .../spark/sql/DataFrameComplexTypeSuite.scala |   4 +-
 .../org/apache/spark/sql/DataFrameSuite.scala |   6 +-
 .../sql/DataFrameWindowFramesSuite.scala      |  35 ++--
 .../spark/sql/IntegratedUDFTestUtils.scala    |  48 ++---
 .../scala/org/apache/spark/sql/UDFSuite.scala |   3 +-
 .../errors/QueryExecutionErrorsSuite.scala    |   2 +-
 .../spark/sql/internal/ColumnNodeSuite.scala  |   6 +-
 ...ColumnNodeToExpressionConverterSuite.scala |  16 +-
 .../sql/streaming/StreamingQuerySuite.scala   |   9 +-
 39 files changed, 473 insertions(+), 508 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala

diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index b91124f96a6fa..9c825a99be180 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -3965,19 +3965,11 @@ setMethod("row_number",
 #' yields unresolved \code{a.b.c}
 #' @return Column object wrapping JVM UnresolvedNamedLambdaVariable
 #' @keywords internal
-unresolved_named_lambda_var <- function(...)
{ - jc <- newJObject( - "org.apache.spark.sql.Column", - newJObject( - "org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable", - lapply(list(...), function(x) { - handledCallJStatic( - "org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable", - "freshVarName", - x) - }) - ) - ) +unresolved_named_lambda_var <- function(name) { + jc <- handledCallJStatic( + "org.apache.spark.sql.api.python.PythonSQLUtils", + "unresolvedNamedLambdaVariable", + name) column(jc) } @@ -3990,7 +3982,6 @@ unresolved_named_lambda_var <- function(...) { #' @return JVM \code{LambdaFunction} object #' @keywords internal create_lambda <- function(fun) { - as_jexpr <- function(x) callJMethod(x@jc, "expr") # Process function arguments parameters <- formals(fun) @@ -4011,22 +4002,18 @@ create_lambda <- function(fun) { stopifnot(class(result) == "Column") # Convert both Columns to Scala expressions - jexpr <- as_jexpr(result) - jargs <- handledCallJStatic( "org.apache.spark.api.python.PythonUtils", "toSeq", - handledCallJStatic( - "java.util.Arrays", "asList", lapply(args, as_jexpr) - ) + handledCallJStatic("java.util.Arrays", "asList", lapply(args, function(x) { x@jc })) ) # Create Scala LambdaFunction - newJObject( - "org.apache.spark.sql.catalyst.expressions.LambdaFunction", - jexpr, - jargs, - FALSE + handledCallJStatic( + "org.apache.spark.sql.api.python.PythonSQLUtils", + "lambdaFunction", + result@jc, + jargs ) } @@ -4039,20 +4026,18 @@ create_lambda <- function(fun) { #' @return a \code{Column} representing name applied to cols with funs #' @keywords internal invoke_higher_order_function <- function(name, cols, funs) { - as_jexpr <- function(x) { + as_col <- function(x) { if (class(x) == "character") { x <- column(x) } - callJMethod(x@jc, "expr") + x@jc } - - jexpr <- do.call(newJObject, c( - paste("org.apache.spark.sql.catalyst.expressions", name, sep = "."), - lapply(cols, as_jexpr), - lapply(funs, create_lambda) - )) - - column(newJObject("org.apache.spark.sql.Column", jexpr)) + jcol <- handledCallJStatic( + "org.apache.spark.sql.api.python.PythonSQLUtils", + "fn", + name, + c(lapply(cols, as_col), lapply(funs, create_lambda))) # check varargs invocation + column(jcol) } #' @details @@ -4068,7 +4053,7 @@ setMethod("array_aggregate", signature(x = "characterOrColumn", initialValue = "Column", merge = "function"), function(x, initialValue, merge, finish = NULL) { invoke_higher_order_function( - "ArrayAggregate", + "aggregate", cols = list(x, initialValue), funs = if (is.null(finish)) { list(merge) @@ -4129,7 +4114,7 @@ setMethod("array_exists", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayExists", + "exists", cols = list(x), funs = list(f) ) @@ -4145,7 +4130,7 @@ setMethod("array_filter", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayFilter", + "filter", cols = list(x), funs = list(f) ) @@ -4161,7 +4146,7 @@ setMethod("array_forall", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "ArrayForAll", + "forall", cols = list(x), funs = list(f) ) @@ -4291,7 +4276,7 @@ setMethod("array_sort", column(callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc)) } else { invoke_higher_order_function( - "ArraySort", + "array_sort", cols = list(x), funs = list(comparator) ) @@ -4309,7 +4294,7 @@ setMethod("array_transform", signature(x = "characterOrColumn", f = "function"), function(x, f) { 
invoke_higher_order_function( - "ArrayTransform", + "transform", cols = list(x), funs = list(f) ) @@ -4374,7 +4359,7 @@ setMethod("arrays_zip_with", signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"), function(x, y, f) { invoke_higher_order_function( - "ZipWith", + "zip_with", cols = list(x, y), funs = list(f) ) @@ -4447,7 +4432,7 @@ setMethod("map_filter", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "MapFilter", + "map_filter", cols = list(x), funs = list(f)) }) @@ -4504,7 +4489,7 @@ setMethod("transform_keys", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "TransformKeys", + "transform_keys", cols = list(x), funs = list(f) ) @@ -4521,7 +4506,7 @@ setMethod("transform_values", signature(x = "characterOrColumn", f = "function"), function(x, f) { invoke_higher_order_function( - "TransformValues", + "transform_values", cols = list(x), funs = list(f) ) @@ -4552,7 +4537,7 @@ setMethod("map_zip_with", signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"), function(x, y, f) { invoke_higher_order_function( - "MapZipWith", + "map_zip_with", cols = list(x, y), funs = list(f) ) diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala index 5830b2ec42383..1af7558200de3 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala @@ -41,7 +41,7 @@ object functions { def from_avro( data: Column, jsonFormatSchema: String): Column = { - new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty)) + Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty)) } /** @@ -62,7 +62,7 @@ object functions { data: Column, jsonFormatSchema: String, options: java.util.Map[String, String]): Column = { - new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap)) + Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap)) } /** @@ -74,7 +74,7 @@ object functions { */ @Experimental def to_avro(data: Column): Column = { - new Column(CatalystDataToAvro(data.expr, None)) + Column(CatalystDataToAvro(data.expr, None)) } /** @@ -87,6 +87,6 @@ object functions { */ @Experimental def to_avro(data: Column, jsonFormatSchema: String): Column = { - new Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema))) + Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema))) } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 59b399da9a5c6..07c9e5190da00 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -300,6 +300,12 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.UDFRegistration.register"), + // Typed Column + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.*"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.sql.TypedColumn.expr"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.TypedColumn$"), + // Datasource V2 partition transforms ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"), diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index 91e87dee50482..2700764399606 100644 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -70,7 +70,7 @@ object functions { messageName: String, binaryFileDescriptorSet: Array[Byte], options: java.util.Map[String, String]): Column = { - new Column( + Column( ProtobufDataToCatalyst( data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap ) @@ -93,7 +93,7 @@ object functions { @Experimental def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - new Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent))) + Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent))) } /** @@ -112,7 +112,7 @@ object functions { @Experimental def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) : Column = { - new Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet))) + Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet))) } /** @@ -132,7 +132,7 @@ object functions { */ @Experimental def from_protobuf(data: Column, messageClassName: String): Column = { - new Column(ProtobufDataToCatalyst(data.expr, messageClassName)) + Column(ProtobufDataToCatalyst(data.expr, messageClassName)) } /** @@ -156,7 +156,7 @@ object functions { data: Column, messageClassName: String, options: java.util.Map[String, String]): Column = { - new Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap)) + Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap)) } /** @@ -194,7 +194,7 @@ object functions { @Experimental def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) : Column = { - new Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet))) + Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet))) } /** * Converts a column into binary of protobuf format. 
The Protobuf definition is provided @@ -216,7 +216,7 @@ object functions { descFilePath: String, options: java.util.Map[String, String]): Column = { val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - new Column( + Column( CatalystDataToProtobuf(data.expr, messageName, Some(fileContent), options.asScala.toMap) ) } @@ -242,7 +242,7 @@ object functions { binaryFileDescriptorSet: Array[Byte], options: java.util.Map[String, String] ): Column = { - new Column( + Column( CatalystDataToProtobuf( data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap ) @@ -266,7 +266,7 @@ object functions { */ @Experimental def to_protobuf(data: Column, messageClassName: String): Column = { - new Column(CatalystDataToProtobuf(data.expr, messageClassName)) + Column(CatalystDataToProtobuf(data.expr, messageClassName)) } /** @@ -288,6 +288,6 @@ object functions { @Experimental def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String]) : Column = { - new Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap)) + Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 838869a2b3952..7cf39d4750314 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -28,7 +28,6 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, Row} -import org.apache.spark.sql.catalyst.expressions.{If, Literal} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -195,7 +194,7 @@ class StringIndexer @Since("1.4.0") ( } else { // We don't count for NaN values. Because `StringIndexerAggregator` only processes strings, // we replace NaNs with null in advance. - new Column(If(col.isNaN.expr, Literal(null), col.expr)).cast(StringType) + when(!isnan(col), col).cast(StringType) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala index 7a27b32aa24c5..9388205a751ec 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala @@ -257,7 +257,7 @@ private[ml] class SummaryBuilderImpl( mutableAggBufferOffset = 0, inputAggBufferOffset = 0) - new Column(agg.toAggregateExpression()) + Column(agg.toAggregateExpression()) } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 20e50469e8568..03bf9c89aa2dc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -107,7 +107,10 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.JobWaiter.cancel"), ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.cancel"), // SPARK-48901: Add clusterBy() to DataStreamWriter. 
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy"), + // SPARK-49022: Use Column API + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.WindowSpec.this") ) // Default exclude rules diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index c5fef3b138254..92d4a3357319f 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -915,10 +915,8 @@ def attach_distributed_column(sdf: PySparkDataFrame, column_name: str) -> PySpar if is_remote(): return sdf.select(F.monotonically_increasing_id().alias(column_name), *scols) jvm = sdf.sparkSession._jvm - tag = jvm.org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS() - jexpr = F.monotonically_increasing_id()._jc.expr() - jexpr.setTagValue(tag, "distributed_index") - return sdf.select(PySparkColumn(jvm.Column(jexpr)).alias(column_name), *scols) + jcol = jvm.PythonSQLUtils.distributedIndex() + return sdf.select(PySparkColumn(jcol).alias(column_name), *scols) @staticmethod def attach_distributed_sequence_column( diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index fd684ddcb0f49..24b8ae82e99ad 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -17558,7 +17558,7 @@ def array_sort( if comparator is None: return _invoke_function_over_columns("array_sort", col) else: - return _invoke_higher_order_function("ArraySort", [col], [comparator]) + return _invoke_higher_order_function("array_sort", [col], [comparator]) @_try_remote_functions @@ -18559,7 +18559,7 @@ def from_csv( ) -def _unresolved_named_lambda_variable(*name_parts: Any) -> Column: +def _unresolved_named_lambda_variable(name: str) -> Column: """ Create `o.a.s.sql.expressions.UnresolvedNamedLambdaVariable`, convert it to o.s.sql.Column and wrap in Python `Column` @@ -18572,14 +18572,9 @@ def _unresolved_named_lambda_variable(*name_parts: Any) -> Column: name_parts : str """ from py4j.java_gateway import JVMView - from pyspark.sql.classic.column import _to_seq sc = _get_active_spark_context() - name_parts_seq = _to_seq(sc, name_parts) - expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions - return Column( - cast(JVMView, sc._jvm).Column(expressions.UnresolvedNamedLambdaVariable(name_parts_seq)) - ) + return Column(cast(JVMView, sc._jvm).PythonSQLUtils.unresolvedNamedLambdaVariable(name)) def _get_lambda_parameters(f: Callable) -> ValuesView[inspect.Parameter]: @@ -18628,15 +18623,9 @@ def _create_lambda(f: Callable) -> Callable: parameters = _get_lambda_parameters(f) sc = _get_active_spark_context() - expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions argnames = ["x", "y", "z"] - args = [ - _unresolved_named_lambda_variable( - expressions.UnresolvedNamedLambdaVariable.freshVarName(arg) - ) - for arg in argnames[: len(parameters)] - ] + args = [_unresolved_named_lambda_variable(arg) for arg in argnames[: len(parameters)]] result = f(*args) @@ -18646,10 +18635,9 @@ def _create_lambda(f: Callable) -> Callable: messageParameters={"func_name": f.__name__, "return_type": type(result).__name__}, ) - jexpr = result._jc.expr() - jargs = _to_seq(sc, [arg._jc.expr() for arg in args]) - - return 
expressions.LambdaFunction(jexpr, jargs, False) + jexpr = result._jc + jargs = _to_seq(sc, [arg._jc for arg in args]) + return cast(JVMView, sc._jvm).PythonSQLUtils.lambdaFunction(jexpr, jargs) def _invoke_higher_order_function( @@ -18669,16 +18657,12 @@ def _invoke_higher_order_function( :return: a Column """ from py4j.java_gateway import JVMView - from pyspark.sql.classic.column import _to_java_column + from pyspark.sql.classic.column import _to_seq, _to_java_column sc = _get_active_spark_context() - expressions = cast(JVMView, sc._jvm).org.apache.spark.sql.catalyst.expressions - expr = getattr(expressions, name) - - jcols = [_to_java_column(col).expr() for col in cols] jfuns = [_create_lambda(f) for f in funs] - - return Column(cast(JVMView, sc._jvm).Column(expr(*jcols + jfuns))) + jcols = [_to_java_column(c) for c in cols] + return Column(cast(JVMView, sc._jvm).PythonSQLUtils.fn(name, _to_seq(sc, jcols + jfuns))) @overload @@ -18746,7 +18730,7 @@ def transform( |[1, -2, 3, -4]| +--------------+ """ - return _invoke_higher_order_function("ArrayTransform", [col], [f]) + return _invoke_higher_order_function("transform", [col], [f]) @_try_remote_functions @@ -18787,7 +18771,7 @@ def exists(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: | true| +------------+ """ - return _invoke_higher_order_function("ArrayExists", [col], [f]) + return _invoke_higher_order_function("exists", [col], [f]) @_try_remote_functions @@ -18832,7 +18816,7 @@ def forall(col: "ColumnOrName", f: Callable[[Column], Column]) -> Column: | true| +-------+ """ - return _invoke_higher_order_function("ArrayForAll", [col], [f]) + return _invoke_higher_order_function("forall", [col], [f]) @overload @@ -18899,7 +18883,7 @@ def filter( |[2018-09-20, 2019-07-01]| +------------------------+ """ - return _invoke_higher_order_function("ArrayFilter", [col], [f]) + return _invoke_higher_order_function("filter", [col], [f]) @_try_remote_functions @@ -18972,10 +18956,10 @@ def aggregate( +----+ """ if finish is not None: - return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], [merge, finish]) + return _invoke_higher_order_function("aggregate", [col, initialValue], [merge, finish]) else: - return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], [merge]) + return _invoke_higher_order_function("aggregate", [col, initialValue], [merge]) @_try_remote_functions @@ -19045,10 +19029,10 @@ def reduce( +----+ """ if finish is not None: - return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], [merge, finish]) + return _invoke_higher_order_function("reduce", [col, initialValue], [merge, finish]) else: - return _invoke_higher_order_function("ArrayAggregate", [col, initialValue], [merge]) + return _invoke_higher_order_function("reduce", [col, initialValue], [merge]) @_try_remote_functions @@ -19103,7 +19087,7 @@ def zip_with( |[foo_1, bar_2, 3]| +-----------------+ """ - return _invoke_higher_order_function("ZipWith", [left, right], [f]) + return _invoke_higher_order_function("zip_with", [left, right], [f]) @_try_remote_functions @@ -19143,7 +19127,7 @@ def transform_keys(col: "ColumnOrName", f: Callable[[Column, Column], Column]) - >>> sorted(row["data_upper"].items()) [('BAR', 2.0), ('FOO', -2.0)] """ - return _invoke_higher_order_function("TransformKeys", [col], [f]) + return _invoke_higher_order_function("transform_keys", [col], [f]) @_try_remote_functions @@ -19183,7 +19167,7 @@ def transform_values(col: "ColumnOrName", f: Callable[[Column, Column], Column]) >>> 
sorted(row["new_data"].items()) [('IT', 20.0), ('OPS', 34.0), ('SALES', 2.0)] """ - return _invoke_higher_order_function("TransformValues", [col], [f]) + return _invoke_higher_order_function("transform_values", [col], [f]) @_try_remote_functions @@ -19246,7 +19230,7 @@ def map_filter(col: "ColumnOrName", f: Callable[[Column, Column], Column]) -> Co >>> sorted(row["data_filtered"].items()) [('baz', 32.0)] """ - return _invoke_higher_order_function("MapFilter", [col], [f]) + return _invoke_higher_order_function("map_filter", [col], [f]) @_try_remote_functions @@ -19327,7 +19311,7 @@ def map_zip_with( >>> sorted(row["updated_data"].items()) [('A', 1), ('B', 5), ('C', None)] """ - return _invoke_higher_order_function("MapZipWith", [col1, col2], [f]) + return _invoke_higher_order_function("map_zip_with", [col1, col2], [f]) @_try_remote_functions diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 2feba0b3b345f..7dd42eecde7f8 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -114,7 +114,6 @@ def test_count_star(self): self.assertEqual(df3.select(count(df3["*"])).columns, ["count(1)"]) self.assertEqual(df3.select(count(col("*"))).columns, ["count(1)"]) - self.assertEqual(df3.select(count(col("s.*"))).columns, ["count(1)"]) def test_self_join(self): df1 = self.spark.range(10).withColumn("a", lit(0)) diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index f56b8358699d3..5ce3e2dfd2a9e 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -373,7 +373,7 @@ def _create_judtf(self, func: Type) -> "JavaObject": return judtf def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFrame": - from pyspark.sql.classic.column import _to_java_column, _to_java_expr, _to_seq + from pyspark.sql.classic.column import _to_java_column, _to_seq from pyspark.sql import DataFrame, SparkSession @@ -382,11 +382,7 @@ def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFram assert sc._jvm is not None jcols = [_to_java_column(arg) for arg in args] + [ - sc._jvm.Column( - sc._jvm.org.apache.spark.sql.catalyst.expressions.NamedArgumentExpression( - key, _to_java_expr(value) - ) - ) + sc._jvm.PythonSQLUtils.namedArgumentExpression(key, _to_java_column(value)) for key, value in kwargs.items() ] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 73caac4cb1d3e..20546c5c5be3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1600,6 +1600,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } else { a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) } + case c: CollectMetrics if containsStar(c.metrics) => + c.copy(metrics = buildExpandedProjectList(c.metrics, c.child)) case g: Generate if containsStar(g.generator.children) => throw QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF", extractStar(g.generator.children)) @@ -1908,8 +1910,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } catch { case e: AnalysisException => AnalysisContext.get.outerPlan.map { - // Only Project and Aggregate can host star expressions. 
- case u @ (_: Project | _: Aggregate) => + // Only Project, Aggregate, CollectMetrics can host star expressions. + case u @ (_: Project | _: Aggregate | _: CollectMetrics) => Try(s.expand(u.children.head, resolver)) match { case Success(expanded) => expanded.map(wrapOuterReference) case Failure(_) => throw e @@ -1947,6 +1949,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private def extractStar(exprs: Seq[Expression]): Seq[Star] = exprs.flatMap(_.collect { case s: Star => s }) + private def isCountStarExpansionAllowed(arguments: Seq[Expression]): Boolean = arguments match { + case Seq(UnresolvedStar(None)) => true + case Seq(_: ResolvedStar) => true + case _ => false + } + /** * Expands the matching attribute.*'s in `child`'s output. */ @@ -1954,7 +1962,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor expr.transformUp { case f0: UnresolvedFunction if !f0.isDistinct && f0.nameParts.map(_.toLowerCase(Locale.ROOT)) == Seq("count") && - f0.arguments == Seq(UnresolvedStar(None)) => + isCountStarExpansionAllowed(f0.arguments) => // Transform COUNT(*) into COUNT(1). f0.copy(nameParts = Seq("count"), arguments = Seq(Literal(1))) case f1: UnresolvedFunction if containsStar(f1.arguments) => diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 9e86430792079..c8aba5d19fe7f 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -67,7 +67,7 @@ import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionHolder, Spark import org.apache.spark.sql.connect.utils.MetricGenerator import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, TypedAggregateExpression} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.command.CreateViewCommand import org.apache.spark.sql.execution.datasources.LogicalRelation @@ -78,7 +78,7 @@ import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeoutFromString import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} -import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils} +import org.apache.spark.sql.internal.{CatalogImpl, TypedAggUtils, UserDefinedFunctionUtils} import org.apache.spark.sql.protobuf.{CatalystDataToProtobuf, ProtobufDataToCatalyst} import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger} import org.apache.spark.sql.types._ @@ -1718,9 +1718,9 @@ class SparkConnectPlanner( val udf = fun.getScalarScalaUdf val udfPacket = unpackUdf(fun) if (udf.getAggregate) { - transformScalaFunction(fun) - .asInstanceOf[UserDefinedAggregator[Any, Any, Any]] - .scalaAggregator(fun.getArgumentsList.asScala.map(transformExpression).toSeq) + ScalaAggregator( + transformScalaFunction(fun).asInstanceOf[UserDefinedAggregator[Any, 
Any, Any]], + fun.getArgumentsList.asScala.map(transformExpression).toSeq) .toAggregateExpression() } else { ScalaUDF( @@ -1899,7 +1899,7 @@ class SparkConnectPlanner( fun: org.apache.spark.sql.expressions.UserDefinedFunction, exprs: Seq[Expression]): ScalaUDF = { val f = fun.asInstanceOf[org.apache.spark.sql.expressions.SparkUserDefinedFunction] - f.createScalaUDF(exprs) + UserDefinedFunctionUtils.toScalaUDF(f, exprs) } private def extractProtobufArgs(children: Seq[Expression]) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 4a59f8eccba54..26df8cd9294b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -27,22 +27,21 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} +import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.internal.TypedAggUtils +import org.apache.spark.sql.internal.{ColumnNode, ExpressionColumnNode, TypedAggUtils} import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ -private[sql] object Column { +private[spark] object Column { def apply(colName: String): Column = new Column(colName) - def apply(expr: Expression): Column = new Column(expr) + def apply(expr: Expression): Column = Column(ExpressionColumnNode(expr)) - def unapply(col: Column): Option[Expression] = Some(col.expr) + def apply(node: => ColumnNode): Column = withOrigin(new Column(node)) private[sql] def generateAlias(e: Expression): String = { e match { @@ -69,9 +68,9 @@ private[sql] object Column { isDistinct: Boolean, isInternal: Boolean, inputs: Seq[Column]): Column = withOrigin { - Column(UnresolvedFunction( - name :: Nil, - inputs.map(_.expr), + Column(internal.UnresolvedFunction( + name, + inputs.map(_.node), isDistinct = isDistinct, isInternal = isInternal)) } @@ -89,9 +88,20 @@ private[sql] object Column { */ @Stable class TypedColumn[-T, U]( - expr: Expression, - private[sql] val encoder: ExpressionEncoder[U]) - extends Column(expr) { + node: ColumnNode, + private[sql] val encoder: Encoder[U], + private[sql] val inputType: Option[(ExpressionEncoder[_], Seq[Attribute])] = None) + extends Column(node) { + + override lazy val expr: Expression = { + val expression = internal.ColumnNodeToExpressionConverter(node) + inputType match { + case Some((inputEncoder, inputAttributes)) => + TypedAggUtils.withInputType(expression, inputEncoder, inputAttributes) + case None => + expression + } + } /** * Inserts the specific input type and schema into any expressions that are expected to operate @@ -100,8 +110,7 @@ class TypedColumn[-T, U]( private[sql] def withInputType( inputEncoder: ExpressionEncoder[_], inputAttributes: Seq[Attribute]): TypedColumn[T, U] = { - val newExpr = TypedAggUtils.withInputType(expr, inputEncoder, inputAttributes) - new TypedColumn[T, U](newExpr, encoder) + new TypedColumn[T, U](node, encoder, Option((inputEncoder, inputAttributes))) } /** 
@@ -113,7 +122,7 @@ class TypedColumn[-T, U]( * @since 2.0.0 */ override def name(alias: String): TypedColumn[T, U] = - new TypedColumn[T, U](super.name(alias).expr, encoder) + new TypedColumn[T, U](super.name(alias).node, encoder) } @@ -137,9 +146,6 @@ class TypedColumn[-T, U]( * $"a" === $"b" * }}} * - * @note The internal Catalyst expression can be accessed via [[expr]], but this method is for - * debugging purposes only and can change in any future Spark releases. - * * @groupname java_expr_ops Java-specific expression operators * @groupname expr_ops Expression operators * @groupname df_ops DataFrame functions @@ -148,15 +154,14 @@ class TypedColumn[-T, U]( * @since 1.3.0 */ @Stable -class Column(val expr: Expression) extends Logging { +class Column(val node: ColumnNode) extends Logging { + lazy val expr: Expression = internal.ColumnNodeToExpressionConverter(node) def this(name: String) = this(withOrigin { name match { - case "*" => UnresolvedStar(None) - case _ if name.endsWith(".*") => - val parts = UnresolvedAttribute.parseAttributeName(name.dropRight(2)) - UnresolvedStar(Some(parts)) - case _ => UnresolvedAttribute.quotedString(name) + case "*" => internal.UnresolvedStar(None) + case _ if name.endsWith(".*") => internal.UnresolvedStar(Option(name.dropRight(2))) + case _ => internal.UnresolvedAttribute(name) } }) @@ -170,23 +175,14 @@ class Column(val expr: Expression) extends Logging { Column.fn(name, this, lit(other)) } - override def toString: String = toPrettySQL(expr) + override def toString: String = node.sql override def equals(that: Any): Boolean = that match { - case that: Column => that.normalizedExpr() == this.normalizedExpr() + case that: Column => that.node.normalized == this.node.normalized case _ => false } - override def hashCode: Int = this.normalizedExpr().hashCode() - - private def normalizedExpr(): Expression = expr transform { - case a: AttributeReference => DetectAmbiguousSelfJoin.stripColumnReferenceMetadata(a) - } - - /** Creates a column based on the given expression. */ - private def withExpr(newExpr: => Expression): Column = withOrigin { - new Column(newExpr) - } + override def hashCode: Int = this.node.normalized.hashCode() /** * Returns the expression for this column either with an existing or auto assigned name. @@ -217,7 +213,7 @@ class Column(val expr: Expression) extends Logging { * results into the correct JVM types. * @since 1.6.0 */ - def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](expr, encoderFor[U]) + def as[U : Encoder]: TypedColumn[Any, U] = new TypedColumn[Any, U](node, encoderFor[U]) /** * Extracts a value or values from a complex type. @@ -232,8 +228,8 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def apply(extraction: Any): Column = withExpr { - UnresolvedExtractValue(expr, lit(extraction).expr) + def apply(extraction: Any): Column = Column { + internal.UnresolvedExtractValue(node, lit(extraction).node) } /** @@ -283,14 +279,18 @@ class Column(val expr: Expression) extends Logging { * @since 1.3.0 */ def ===(other: Any): Column = { - val right = lit(other).expr - if (this.expr == right) { + val right = lit(other) + checkTrivialPredicate(right) + fn("=", other) + } + + private def checkTrivialPredicate(right: Column): Unit = { + if (this == right) { logWarning( log"Constructing trivially true equals predicate, " + - log"'${MDC(LEFT_EXPR, this.expr)} = ${MDC(RIGHT_EXPR, right)}'. " + + log"'${MDC(LEFT_EXPR, this)} == ${MDC(RIGHT_EXPR, right)}'. 
" + log"Perhaps you need to use aliases.") } - fn("=", other) } /** @@ -490,14 +490,9 @@ class Column(val expr: Expression) extends Logging { * @since 1.3.0 */ def <=>(other: Any): Column = { - val right = lit(other).expr - if (this.expr == right) { - logWarning( - log"Constructing trivially true equals predicate, " + - log"'${MDC(LEFT_EXPR, this.expr)} <=> ${MDC(RIGHT_EXPR, right)}'. " + - log"Perhaps you need to use aliases.") - } - fn("<=>", other) + val right = lit(other) + checkTrivialPredicate(right) + fn("<=>", right) } /** @@ -529,11 +524,11 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def when(condition: Column, value: Any): Column = withExpr { - this.expr match { - case CaseWhen(branches, None) => - CaseWhen(branches :+ ((condition.expr, lit(value).expr))) - case CaseWhen(_, Some(_)) => + def when(condition: Column, value: Any): Column = Column { + node match { + case internal.CaseWhenOtherwise(branches, None, _) => + internal.CaseWhenOtherwise(branches :+ ((condition.node, lit(value).node)), None) + case internal.CaseWhenOtherwise(_, Some(_), _) => throw new IllegalArgumentException( "when() cannot be applied once otherwise() is applied") case _ => @@ -563,11 +558,11 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def otherwise(value: Any): Column = withExpr { - this.expr match { - case CaseWhen(branches, None) => - CaseWhen(branches, Option(lit(value).expr)) - case CaseWhen(_, Some(_)) => + def otherwise(value: Any): Column = Column { + node match { + case internal.CaseWhenOtherwise(branches, None, _) => + internal.CaseWhenOtherwise(branches, Option(lit(value).node)) + case internal.CaseWhenOtherwise(_, Some(_), _) => throw new IllegalArgumentException( "otherwise() can only be applied once on a Column previously generated by when()") case _ => @@ -943,10 +938,10 @@ class Column(val expr: Expression) extends Logging { * @since 3.1.0 */ // scalastyle:on line.size.limit - def withField(fieldName: String, col: Column): Column = withExpr { + def withField(fieldName: String, col: Column): Column = { require(fieldName != null, "fieldName cannot be null") require(col != null, "col cannot be null") - UpdateFields(expr, fieldName, col.expr) + Column(internal.UpdateFields(node, fieldName, Option(col.node))) } // scalastyle:off line.size.limit @@ -1009,9 +1004,9 @@ class Column(val expr: Expression) extends Logging { * @since 3.1.0 */ // scalastyle:on line.size.limit - def dropFields(fieldNames: String*): Column = withExpr { - fieldNames.tail.foldLeft(UpdateFields(expr, fieldNames.head)) { - (resExpr, fieldName) => UpdateFields(resExpr, fieldName) + def dropFields(fieldNames: String*): Column = Column { + fieldNames.tail.foldLeft(internal.UpdateFields(node, fieldNames.head)) { + (resExpr, fieldName) => internal.UpdateFields(resExpr, fieldName) } } @@ -1121,7 +1116,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Seq[String]): Column = withExpr { MultiAlias(expr, aliases) } + def as(aliases: Seq[String]): Column = Column(internal.Alias(node, aliases)) /** * Assigns the given aliases to the results of a table generating function. 
@@ -1133,9 +1128,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.4.0 */ - def as(aliases: Array[String]): Column = withExpr { - MultiAlias(expr, aliases.toImmutableArraySeq) - } + def as(aliases: Array[String]): Column = as(aliases.toImmutableArraySeq) /** * Gives the column an alias. @@ -1163,9 +1156,8 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def as(alias: String, metadata: Metadata): Column = withExpr { - Alias(expr, alias)(explicitMetadata = Some(metadata)) - } + def as(alias: String, metadata: Metadata): Column = + Column(internal.Alias(node, alias :: Nil, metadata = Option(metadata))) /** * Gives the column a name (alias). @@ -1181,13 +1173,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.0.0 */ - def name(alias: String): Column = withExpr { - // SPARK-33536: an alias is no longer a column reference. Therefore, - // we should not inherit the column reference related metadata in an alias - // so that it is not caught as a column reference in DetectAmbiguousSelfJoin. - Alias(expr, alias)( - nonInheritableMetadataKeys = Seq(Dataset.DATASET_ID_KEY, Dataset.COL_POS_KEY)) - } + def name(alias: String): Column = Column(internal.Alias(node, alias :: Nil)) /** * Casts the column to a different data type. @@ -1203,11 +1189,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = withExpr { - val cast = Cast(expr, CharVarcharUtils.replaceCharVarcharWithStringForCast(to)) - cast.setTagValue(Cast.USER_SPECIFIED_CAST, ()) - cast - } + def cast(to: DataType): Column = Column(internal.Cast(node, to)) /** * Casts the column to a different data type, using the canonical string representation @@ -1237,13 +1219,8 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 4.0.0 */ - def try_cast(to: DataType): Column = withExpr { - val cast = Cast( - child = expr, - dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(to), - evalMode = EvalMode.TRY) - cast.setTagValue(Cast.USER_SPECIFIED_CAST, ()) - cast + def try_cast(to: DataType): Column = { + Column(internal.Cast(node, to, Option(internal.Cast.Try))) } /** @@ -1260,6 +1237,17 @@ class Column(val expr: Expression) extends Logging { try_cast(CatalystSqlParser.parseDataType(to)) } + private def sortOrder( + sortDirection: internal.SortOrder.SortDirection, + nullOrdering: internal.SortOrder.NullOrdering): Column = { + Column(internal.SortOrder(node, sortDirection, nullOrdering)) + } + + private[sql] def sortOrder: internal.SortOrder = node match { + case order: internal.SortOrder => order + case _ => asc.node.asInstanceOf[internal.SortOrder] + } + /** * Returns a sort expression based on the descending order of the column. 
* {{{ @@ -1273,7 +1261,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def desc: Column = withExpr { SortOrder(expr, Descending) } + def desc: Column = desc_nulls_last /** * Returns a sort expression based on the descending order of the column, @@ -1289,7 +1277,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Seq.empty) } + def desc_nulls_first: Column = sortOrder( + internal.SortOrder.Descending, + internal.SortOrder.NullsFirst) /** * Returns a sort expression based on the descending order of the column, @@ -1305,7 +1295,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Seq.empty) } + def desc_nulls_last: Column = sortOrder( + internal.SortOrder.Descending, + internal.SortOrder.NullsLast) /** * Returns a sort expression based on ascending order of the column. @@ -1320,7 +1312,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def asc: Column = withExpr { SortOrder(expr, Ascending) } + def asc: Column = asc_nulls_first /** * Returns a sort expression based on ascending order of the column, @@ -1336,7 +1328,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Seq.empty) } + def asc_nulls_first: Column = sortOrder( + internal.SortOrder.Ascending, + internal.SortOrder.NullsFirst) /** * Returns a sort expression based on ascending order of the column, @@ -1352,7 +1346,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Seq.empty) } + def asc_nulls_last: Column = sortOrder( + internal.SortOrder.Ascending, + internal.SortOrder.NullsLast) /** * Prints the expression to the console for debugging purposes. 
@@ -1363,9 +1359,9 @@ class Column(val expr: Expression) extends Logging { def explain(extended: Boolean): Unit = { // scalastyle:off println if (extended) { - println(expr) + println(node) } else { - println(expr.sql) + println(node.sql) } // scalastyle:on println } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index c083ee89db6f2..231d361810f84 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -468,7 +468,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val branches = replacementMap.flatMap { case (source, target) => Seq(Literal(source), buildExpr(target)) }.toSeq - new Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name) + Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name) } private def convertToDouble(v: Any): Double = v match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c7511737b2b3f..94129d2e8b58b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1640,7 +1640,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { - implicit val encoder: ExpressionEncoder[U1] = c1.encoder + implicit val encoder: ExpressionEncoder[U1] = encoderFor(c1.encoder) val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan) if (!encoder.isSerializedAsStructForTopLevel) { @@ -1657,7 +1657,7 @@ class Dataset[T] private[sql]( * that cast appropriately for the user facing interface. */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoders = columns.map(_.encoder) + val encoders = columns.map(c => encoderFor(c.encoder)) val namedColumns = columns.map(_.withInputType(exprEnc, logicalPlan.output).named) val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 52ab633cd75a7..a672f29966df7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -968,7 +968,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * that cast appropriately for the user facing interface. 
*/ protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoders = columns.map(_.encoder) + val encoders = columns.map(c => encoderFor(c.encoder)) val namedColumns = columns.map(_.withInputType(vExprEnc, dataAttributes).named) val keyColumn = TypedAggUtils.aggKeyColumn(kExprEnc, groupingAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index e5999355133e3..4fdb84836b0fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -30,10 +30,11 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.execution.aggregate.{ScalaAggregator, ScalaUDAF} import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.sql.internal.ToScalaUDF +import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -107,51 +108,52 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @since 2.2.0 */ def register(name: String, udf: UserDefinedFunction): UserDefinedFunction = { - udf.withName(name) match { - case udaf: UserDefinedAggregator[_, _, _] => - def builder(children: Seq[Expression]) = udaf.scalaAggregator(children) - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - udaf - case other => - def builder(children: Seq[Expression]) = other.apply(children.map(Column.apply) : _*).expr - functionRegistry.createOrReplaceTempFunction(name, builder, "scala_udf") - other - } + register(name, udf, "scala_udf", validateParameterCount = false) } private def registerScalaUDF( name: String, - f: AnyRef, + func: AnyRef, returnTypeTag: TypeTag[_], inputTypeTags: TypeTag[_]*): UserDefinedFunction = { - register(name, SparkUserDefinedFunction(f, returnTypeTag, inputTypeTags: _*), "scala_udf") + val udf = SparkUserDefinedFunction(func, returnTypeTag, inputTypeTags: _*) + register(name, udf, "scala_udf", validateParameterCount = true) } private def registerJavaUDF( name: String, - f: AnyRef, - returnType: DataType, + func: AnyRef, + returnDataType: DataType, cardinality: Int): UserDefinedFunction = { - val validatedReturnType = CharVarcharUtils.failIfHasCharVarchar(returnType) - register(name, SparkUserDefinedFunction(f, validatedReturnType, cardinality), "java_udf") + val validatedReturnDataType = CharVarcharUtils.failIfHasCharVarchar(returnDataType) + val udf = SparkUserDefinedFunction(func, validatedReturnDataType, cardinality) + register(name, udf, "java_udf", validateParameterCount = true) } private def register( name: String, - udf: SparkUserDefinedFunction, - source: String): UserDefinedFunction = { + udf: UserDefinedFunction, + source: String, + validateParameterCount: Boolean): UserDefinedFunction = { val named = udf.withName(name) - val expectedParameterCount = named.inputEncoders.size - val builder: Seq[Expression] => Expression = { children => - val actualParameterCount = children.length - if 
(expectedParameterCount == actualParameterCount) { - named.createScalaUDF(children) - } else { - throw QueryCompilationErrors.wrongNumArgsError( - name, - expectedParameterCount.toString, - actualParameterCount) - } + val builder: Seq[Expression] => Expression = named match { + case udaf: UserDefinedAggregator[_, _, _] => + ScalaAggregator(udaf, _) + case udf: SparkUserDefinedFunction if validateParameterCount => + val expectedParameterCount = udf.inputEncoders.size + children => { + val actualParameterCount = children.length + if (expectedParameterCount == actualParameterCount) { + toScalaUDF(udf, children) + } else { + throw QueryCompilationErrors.wrongNumArgsError( + name, + expectedParameterCount.toString, + actualParameterCount) + } + } + case udf: SparkUserDefinedFunction => + toScalaUDF(udf, _) } functionRegistry.createOrReplaceTempFunction(name, builder, source) named @@ -458,12 +460,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends throw QueryCompilationErrors.udfClassWithTooManyTypeArgumentsError(n) } } catch { - case e @ (_: InstantiationException | _: IllegalArgumentException) => + case _: InstantiationException | _: IllegalArgumentException => throw QueryCompilationErrors.classWithoutPublicNonArgumentConstructorError(className) } } } catch { - case e: ClassNotFoundException => throw QueryCompilationErrors.cannotLoadClassNotOnClassPathError(className) + case _: ClassNotFoundException => throw QueryCompilationErrors.cannotLoadClassNotOnClassPathError(className) } } @@ -483,8 +485,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val udaf = clazz.getConstructor().newInstance().asInstanceOf[UserDefinedAggregateFunction] register(name, udaf) } catch { - case e: ClassNotFoundException => throw QueryCompilationErrors.cannotLoadClassNotOnClassPathError(className) - case e @ (_: InstantiationException | _: IllegalArgumentException) => + case _: ClassNotFoundException => throw QueryCompilationErrors.cannotLoadClassNotOnClassPathError(className) + case _: InstantiationException | _: IllegalArgumentException => throw QueryCompilationErrors.classWithoutPublicNonArgumentConstructorError(className) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 84418a0ecc65f..dbb3a333bfb11 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -27,7 +27,7 @@ import org.apache.spark.api.python.DechunkedInputStream import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.CLASS_LOADER import org.apache.spark.security.SocketAuthServer -import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} +import org.apache.spark.sql.{internal, Column, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{ExpressionColumnNode, SQLConf} import 
org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{MutableURLClassLoader, Utils} @@ -174,6 +174,27 @@ private[sql] object PythonSQLUtils extends Logging { def pandasCovar(col1: Column, col2: Column, ddof: Int): Column = { Column(PandasCovar(col1.expr, col2.expr, ddof).toAggregateExpression(false)) } + + def unresolvedNamedLambdaVariable(name: String): Column = + Column(internal.UnresolvedNamedLambdaVariable.apply(name)) + + @scala.annotation.varargs + def lambdaFunction(function: Column, variables: Column*): Column = { + val arguments = variables.map(_.node.asInstanceOf[internal.UnresolvedNamedLambdaVariable]) + Column(internal.LambdaFunction(function.node, arguments)) + } + + def namedArgumentExpression(name: String, e: Column): Column = + Column(ExpressionColumnNode(NamedArgumentExpression(name, e.expr))) + + def distributedIndex(): Column = { + val expr = MonotonicallyIncreasingID() + expr.setTagValue(FunctionRegistry.FUNC_ALIAS, "distributed_index") + Column(ExpressionColumnNode(expr)) + } + + @scala.annotation.varargs + def fn(name: String, arguments: Column*): Column = Column.fn(name, arguments: _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index e517376bc5fc0..ffef4996fe052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.internal.Logging import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, _} import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes -import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction} +import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction, UserDefinedAggregator} import org.apache.spark.sql.types._ /** @@ -554,6 +554,21 @@ case class ScalaAggregator[IN, BUF, OUT]( copy(children = newChildren) } +object ScalaAggregator { + def apply[IN, BUF, OUT]( + uda: UserDefinedAggregator[IN, BUF, OUT], + children: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = { + new ScalaAggregator( + children = children, + agg = uda.aggregator, + inputEncoder = encoderFor(uda.inputEncoder), + bufferEncoder = encoderFor(uda.aggregator.bufferEncoder), + nullable = uda.nullable, + isDeterministic = uda.deterministic, + aggregatorName = Option(uda.name)) + } +} + /** * An extension rule to resolve encoder expressions from a [[ScalaAggregator]] */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 6b3b374ae9ad9..d059f5ada576b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -21,7 +21,7 @@ import java.util.Locale import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} -import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode} +import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.util.QuantileSummaries import org.apache.spark.sql.errors.QueryExecutionErrors @@ -205,7 +205,7 @@ object StatFunctions extends Logging { val column = col(field.name) var casted = column if (field.dataType.isInstanceOf[StringType]) { - casted = new Column(Cast(column.expr, DoubleType, evalMode = EvalMode.TRY)) + casted = column.try_cast(DoubleType) } val percentilesCol = if (percentiles.nonEmpty) { @@ -252,7 +252,7 @@ object StatFunctions extends Logging { .withColumnRenamed("_1", "summary") } else { val valueColumns = columnNames.map { columnName => - new Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName) + Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName) } import org.apache.spark.util.ArrayImplicits._ ds.select(mapColumns: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 725580ceaf17f..1a2fbdc1fd116 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -18,9 +18,7 @@ package org.apache.spark.sql.expressions import org.apache.spark.sql.{Encoder, TypedColumn} -import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.internal.UserDefinedFunctionLike +import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike} /** * A base class for user-defined aggregations, which can be used in `Dataset` operations to take @@ -95,11 +93,8 @@ abstract class Aggregator[-IN, BUF, OUT] extends Serializable with UserDefinedFu * @since 1.6.0 */ def toColumn: TypedColumn[IN, OUT] = { - implicit val bEncoder = bufferEncoder - implicit val cEncoder = outputEncoder - - val expr = TypedAggregateExpression(this).toAggregateExpression() - - new TypedColumn[IN, OUT](expr, encoderFor[OUT]) + new TypedColumn[IN, OUT]( + InvokeInlineUserDefinedFunction(this, Nil), + outputEncoder) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 97ec8e37d26d6..403eccfddffad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -23,11 +23,7 @@ import scala.util.Try import org.apache.spark.annotation.Stable import org.apache.spark.sql.{Column, Encoder} import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder -import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} -import org.apache.spark.sql.execution.aggregate.ScalaAggregator -import org.apache.spark.sql.internal.UserDefinedFunctionLike +import 
org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike} import org.apache.spark.sql.types.DataType /** @@ -68,7 +64,9 @@ sealed abstract class UserDefinedFunction extends UserDefinedFunctionLike { * @since 1.3.0 */ @scala.annotation.varargs - def apply(exprs: Column*): Column + def apply(exprs: Column*): Column = { + Column(InvokeInlineUserDefinedFunction(this, exprs.map(_.node))) + } /** * Updates UserDefinedFunction with a given name. @@ -101,23 +99,6 @@ private[spark] case class SparkUserDefinedFunction( nullable: Boolean = true, deterministic: Boolean = true) extends UserDefinedFunction { - @scala.annotation.varargs - override def apply(exprs: Column*): Column = { - Column(createScalaUDF(exprs.map(_.expr))) - } - - private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = { - ScalaUDF( - f, - dataType, - exprs, - inputEncoders.map(_.filter(_ != UnboundRowEncoder).map(e => encoderFor(e))), - outputEncoder.map(e => encoderFor(e)), - udfName = givenName, - nullable = nullable, - udfDeterministic = deterministic) - } - override def withName(name: String): SparkUserDefinedFunction = { copy(givenName = Option(name)) } @@ -177,19 +158,6 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT]( nullable: Boolean = true, deterministic: Boolean = true) extends UserDefinedFunction { - @scala.annotation.varargs - def apply(exprs: Column*): Column = { - Column(scalaAggregator(exprs.map(_.expr)).toAggregateExpression()) - } - - // This is also used by udf.register(...) when it detects a UserDefinedAggregator - def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = { - val iEncoder = encoderFor(inputEncoder) - val bEncoder = encoderFor(aggregator.bufferEncoder) - ScalaAggregator( - exprs, aggregator, iEncoder, bEncoder, nullable, deterministic, aggregatorName = givenName) - } - override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = { copy(givenName = Option(name)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index 93bf738a53daf..9c4499ee243f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Stable import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{WindowSpec => _, _} /** * Utility functions for defining window in DataFrames. 
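For orientation, the public Window API that these internals back is not changed by this patch; only the spec's internal representation moves from Catalyst expressions to ColumnNodes. A minimal, illustrative sketch follows (the column names and surrounding SparkSession are assumptions for the example, not taken from the patch):

  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions.{avg, col}

  // The empty spec behind Window.partitionBy/orderBy now starts from
  // WindowSpec(Seq.empty, Seq.empty, None) instead of UnspecifiedFrame.
  val w = Window.partitionBy(col("dept")).orderBy(col("salary").desc)
  val avgSalaryOverDept = avg(col("salary")).over(w)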
@@ -215,7 +214,7 @@ object Window {
   }
 
   private[sql] def spec: WindowSpec = {
-    new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame)
+    new WindowSpec(Seq.empty, Seq.empty, None)
   }
 
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index 32aa13a29cec3..7da8b8dbd4b9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.expressions
 
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.{ColumnNode, SortOrder, Window => EvalWindow, WindowFrame, WindowSpec => InternalWindowSpec}
 
 /**
  * A window specification that defines the partitioning, ordering, and frame boundaries.
@@ -31,9 +31,9 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
  */
 @Stable
 class WindowSpec private[sql](
-    partitionSpec: Seq[Expression],
+    partitionSpec: Seq[ColumnNode],
     orderSpec: Seq[SortOrder],
-    frame: WindowFrame) {
+    frame: Option[WindowFrame]) {
 
   /**
    * Defines the partitioning columns in a [[WindowSpec]].
@@ -50,7 +50,7 @@ class WindowSpec private[sql](
    */
   @scala.annotation.varargs
   def partitionBy(cols: Column*): WindowSpec = {
-    new WindowSpec(cols.map(_.expr), orderSpec, frame)
+    new WindowSpec(cols.map(_.node), orderSpec, frame)
   }
 
   /**
@@ -68,15 +68,7 @@ class WindowSpec private[sql](
    */
   @scala.annotation.varargs
   def orderBy(cols: Column*): WindowSpec = {
-    val sortOrder: Seq[SortOrder] = cols.map { col =>
-      col.expr match {
-        case expr: SortOrder =>
-          expr
-        case expr: Expression =>
-          SortOrder(expr, Ascending)
-      }
-    }
-    new WindowSpec(partitionSpec, sortOrder, frame)
+    new WindowSpec(partitionSpec, cols.map(_.sortOrder), frame)
   }
 
   /**
@@ -125,23 +117,20 @@ class WindowSpec private[sql](
   // Note: when updating the doc for this method, also update Window.rowsBetween.
   def rowsBetween(start: Long, end: Long): WindowSpec = {
     val boundaryStart = start match {
-      case 0 => CurrentRow
-      case Long.MinValue => UnboundedPreceding
-      case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt)
+      case 0 => WindowFrame.CurrentRow
+      case Long.MinValue => WindowFrame.UnboundedPreceding
+      case x if Int.MinValue <= x && x <= Int.MaxValue => WindowFrame.value(x.toInt)
       case x => throw QueryCompilationErrors.invalidBoundaryStartError(x)
     }
 
     val boundaryEnd = end match {
-      case 0 => CurrentRow
-      case Long.MaxValue => UnboundedFollowing
-      case x if Int.MinValue <= x && x <= Int.MaxValue => Literal(x.toInt)
+      case 0 => WindowFrame.CurrentRow
+      case Long.MaxValue => WindowFrame.UnboundedFollowing
+      case x if Int.MinValue <= x && x <= Int.MaxValue => WindowFrame.value(x.toInt)
       case x => throw QueryCompilationErrors.invalidBoundaryEndError(x)
     }
 
-    new WindowSpec(
-      partitionSpec,
-      orderSpec,
-      SpecifiedWindowFrame(RowFrame, boundaryStart, boundaryEnd))
+    withFrame(WindowFrame.Row, boundaryStart, boundaryEnd)
   }
 
   /**
@@ -193,28 +182,32 @@ class WindowSpec private[sql](
   // Note: when updating the doc for this method, also update Window.rangeBetween.
   def rangeBetween(start: Long, end: Long): WindowSpec = {
     val boundaryStart = start match {
-      case 0 => CurrentRow
-      case Long.MinValue => UnboundedPreceding
-      case x => Literal(x)
+      case 0 => WindowFrame.CurrentRow
+      case Long.MinValue => WindowFrame.UnboundedPreceding
+      case x => WindowFrame.value(x)
     }
 
     val boundaryEnd = end match {
-      case 0 => CurrentRow
-      case Long.MaxValue => UnboundedFollowing
-      case x => Literal(x)
+      case 0 => WindowFrame.CurrentRow
+      case Long.MaxValue => WindowFrame.UnboundedFollowing
+      case x => WindowFrame.value(x)
     }
+    withFrame(WindowFrame.Range, boundaryStart, boundaryEnd)
+  }
 
-    new WindowSpec(
-      partitionSpec,
-      orderSpec,
-      SpecifiedWindowFrame(RangeFrame, boundaryStart, boundaryEnd))
+  private[sql] def withFrame(
+      frameType: WindowFrame.FrameType,
+      lower: WindowFrame.FrameBoundary,
+      upper: WindowFrame.FrameBoundary): WindowSpec = {
+    val frame = WindowFrame(frameType, lower, upper)
+    new WindowSpec(partitionSpec, orderSpec, Some(frame))
   }
 
   /**
    * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression.
    */
   private[sql] def withAggregate(aggregate: Column): Column = {
-    val spec = WindowSpecDefinition(partitionSpec, orderSpec, frame)
-    new Column(WindowExpression(aggregate.expr, spec))
+    val spec = InternalWindowSpec(partitionSpec, sortColumns = orderSpec, frame = frame)
+    Column(EvalWindow(aggregate.node, spec))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index b387695ef2379..a4aa9c312aff2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.expressions
 
 import org.apache.spark.annotation.Stable
 import org.apache.spark.sql.{Column, Row}
-import org.apache.spark.sql.execution.aggregate.ScalaUDAF
+import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike}
 import org.apache.spark.sql.types._
 
 /**
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
 @Stable
 @deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
   " via the functions.udaf(agg) method.", "3.0.0")
-abstract class UserDefinedAggregateFunction extends Serializable {
+abstract class UserDefinedAggregateFunction extends Serializable with UserDefinedFunctionLike {
 
   /**
    * A `StructType` represents data types of input arguments of this aggregate function.
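As a hedged sketch of what the UDAF changes above feed into (not part of the patch itself): Aggregator-based UDAFs registered through functions.udaf, like the deprecated UserDefinedAggregateFunction, now produce a Column wrapping InvokeInlineUserDefinedFunction; the ScalaAggregator/ScalaUDAF aggregate expressions are only built later by the column-node converter. The aggregator and column name below are made up for illustration:

  import org.apache.spark.sql.{Encoder, Encoders}
  import org.apache.spark.sql.expressions.Aggregator
  import org.apache.spark.sql.functions.{col, udaf}

  // Trivial Aggregator summing longs, used only to show the invocation path.
  object LongSum extends Aggregator[Long, Long, Long] {
    def zero: Long = 0L
    def reduce(b: Long, a: Long): Long = b + a
    def merge(b1: Long, b2: Long): Long = b1 + b2
    def finish(reduction: Long): Long = reduction
    def bufferEncoder: Encoder[Long] = Encoders.scalaLong
    def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }

  val sumUdf = udaf(LongSum)     // wraps the Aggregator in a UserDefinedAggregator
  val total = sumUdf(col("id"))  // Column backed by InvokeInlineUserDefinedFunction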
@@ -130,8 +130,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { */ @scala.annotation.varargs def apply(exprs: Column*): Column = { - val aggregateExpression = ScalaUDAF(exprs.map(_.expr), this).toAggregateExpression() - Column(aggregateExpression) + Column(InvokeInlineUserDefinedFunction(this, exprs.map(_.node))) } /** @@ -142,9 +141,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { */ @scala.annotation.varargs def distinct(exprs: Column*): Column = { - val aggregateExpression = - ScalaUDAF(exprs.map(_.expr), this).toAggregateExpression(isDistinct = true) - Column(aggregateExpression) + Column(InvokeInlineUserDefinedFunction(this, exprs.map(_.node), isDistinct = true)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index fc9e5f3e8f72a..be83444a8fd33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,13 +24,9 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.annotation.Stable import org.apache.spark.sql.api.java._ -import org.apache.spark.sql.catalyst.ScalaReflection.encoderFor -import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, ResolvedHint} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveLongEncoder import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.{Aggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.sql.internal.{SQLConf, ToScalaUDF} import org.apache.spark.sql.types._ @@ -88,10 +84,6 @@ import org.apache.spark.util.Utils object functions { // scalastyle:on - private def withExpr(expr: => Expression): Column = withOrigin { - Column(expr) - } - /** * Returns a [[Column]] based on the given column name. * @@ -128,7 +120,7 @@ object functions { // method, `typedLit[Any](literal)` will always fail and fallback to `Literal.apply`. Hence, // we can just manually call `Literal.apply` to skip the expensive `ScalaReflection` code. // This is significantly better when there are many threads calling `lit` concurrently. 
- Column(Literal(literal)) + Column(internal.Literal(literal)) } } @@ -140,7 +132,7 @@ object functions { * @group normal_funcs * @since 2.2.0 */ - def typedLit[T : TypeTag](literal: T): Column = withOrigin { + def typedLit[T : TypeTag](literal: T): Column = { typedlit(literal) } @@ -159,11 +151,13 @@ object functions { * @group normal_funcs * @since 3.2.0 */ - def typedlit[T : TypeTag](literal: T): Column = withOrigin { + def typedlit[T : TypeTag](literal: T): Column = { literal match { case c: Column => c case s: Symbol => new ColumnName(s.name) - case _ => Column(Literal.create(literal)) + case _ => + val dataType = ScalaReflection.schemaFor[T].dataType + Column(internal.Literal(literal, Option(dataType))) } } @@ -415,14 +409,8 @@ object functions { * @group agg_funcs * @since 1.3.0 */ - def count(e: Column): Column = { - val withoutStar = e.expr match { - // Turn count(*) into count(1) - case _: Star => Column(Literal(1)) - case _ => e - } - Column.fn("count", withoutStar) - } + def count(e: Column): Column = + Column.fn("count", e) /** * Aggregate function: returns the number of items in a group. @@ -431,7 +419,7 @@ object functions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(AgnosticEncoders.PrimitiveLongEncoder) + count(Column(columnName)).as(PrimitiveLongEncoder) /** * Aggregate function: returns the number of distinct items in a group. @@ -1713,10 +1701,7 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def broadcast[T](df: Dataset[T]): Dataset[T] = { - Dataset[T](df.sparkSession, - ResolvedHint(df.logicalPlan, HintInfo(strategy = Some(BROADCAST))))(df.exprEnc) - } + def broadcast[T](df: Dataset[T]): Dataset[T] = df.hint("broadcast") /** * Returns the first column that is not null, or null if all inputs are null. @@ -2012,9 +1997,8 @@ object functions { * @group conditional_funcs * @since 1.4.0 */ - def when(condition: Column, value: Any): Column = withExpr { - CaseWhen(Seq((condition.expr, lit(value).expr))) - } + def when(condition: Column, value: Any): Column = + Column(internal.CaseWhenOtherwise(Seq(condition.node -> lit(value).node))) /** * Computes bitwise NOT (~) of a number. 
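The literal and conditional helpers above keep their public behaviour; they now build ColumnNodes (internal.Literal, internal.CaseWhenOtherwise) rather than Catalyst expressions. An illustrative use, with made-up column names:

  import org.apache.spark.sql.functions.{col, lit, when}

  // when(...) builds a CaseWhenOtherwise node; otherwise(...) supplies the else branch.
  val ageGroup = when(col("age") >= lit(18), lit("adult")).otherwise(lit("minor"))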
@@ -2072,12 +2056,7 @@ object functions { * * @group normal_funcs */ - def expr(expr: String): Column = withExpr { - val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse { - new SparkSqlParser() - } - parser.parseExpression(expr) - } + def expr(expr: String): Column = Column(internal.SqlExpression(expr)) ////////////////////////////////////////////////////////////////////////////////////////////// // Math Functions @@ -6065,31 +6044,26 @@ object functions { def array_except(col1: Column, col2: Column): Column = Column.fn("array_except", col1, col2) - private def createLambda(f: Column => Column) = withOrigin { - Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val function = f(Column(x)).expr - LambdaFunction(function, Seq(x)) - } + + private def createLambda(f: Column => Column) = { + val x = internal.UnresolvedNamedLambdaVariable("x") + val function = f(Column(x)).node + Column(internal.LambdaFunction(function, Seq(x))) } - private def createLambda(f: (Column, Column) => Column) = withOrigin { - Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val function = f(Column(x), Column(y)).expr - LambdaFunction(function, Seq(x, y)) - } + private def createLambda(f: (Column, Column) => Column) = { + val x = internal.UnresolvedNamedLambdaVariable("x") + val y = internal.UnresolvedNamedLambdaVariable("y") + val function = f(Column(x), Column(y)).node + Column(internal.LambdaFunction(function, Seq(x, y))) } - private def createLambda(f: (Column, Column, Column) => Column) = withOrigin { - Column { - val x = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("x"))) - val y = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("y"))) - val z = UnresolvedNamedLambdaVariable(Seq(UnresolvedNamedLambdaVariable.freshVarName("z"))) - val function = f(Column(x), Column(y), Column(z)).expr - LambdaFunction(function, Seq(x, y, z)) - } + private def createLambda(f: (Column, Column, Column) => Column) = { + val x = internal.UnresolvedNamedLambdaVariable("x") + val y = internal.UnresolvedNamedLambdaVariable("y") + val z = internal.UnresolvedNamedLambdaVariable("z") + val function = f(Column(x), Column(y), Column(z)).node + Column(internal.LambdaFunction(function, Seq(x, y, z))) } /** @@ -7966,7 +7940,7 @@ object functions { * @note The input encoder is inferred from the input type IN. */ def udaf[IN: TypeTag, BUF, OUT](agg: Aggregator[IN, BUF, OUT]): UserDefinedFunction = { - udaf(agg, encoderFor[IN]) + udaf(agg, ScalaReflection.encoderFor[IN]) } /** @@ -8332,7 +8306,7 @@ object functions { */ @scala.annotation.varargs @deprecated("Use call_udf") - def callUDF(udfName: String, cols: Column*): Column = call_udf(udfName, cols: _*) + def callUDF(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*) /** * Call an user-defined function. @@ -8350,7 +8324,7 @@ object functions { * @since 3.2.0 */ @scala.annotation.varargs - def call_udf(udfName: String, cols: Column*): Column = Column.fn(udfName, cols: _*) + def call_udf(udfName: String, cols: Column*): Column = call_function(udfName, cols: _*) /** * Call a SQL function. 
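A small usage sketch for the invocation helpers touched here: callUDF and call_udf now both delegate to call_function, which emits an internal.UnresolvedFunction, while the higher-order functions build internal.LambdaFunction nodes through createLambda. The session, registered name, and columns are assumptions for illustration:

  import org.apache.spark.sql.functions.{call_function, col, lit, transform}

  spark.udf.register("plus_one", (x: Int) => x + 1)
  val viaName = call_function("plus_one", col("id"))

  // Higher-order function: the lambda becomes an internal.LambdaFunction node.
  val incremented = transform(col("xs"), x => x + lit(1))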
@@ -8363,15 +8337,7 @@ object functions {
    */
   @scala.annotation.varargs
   def call_function(funcName: String, cols: Column*): Column = {
-    val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
-      new SparkSqlParser()
-    }
-    val nameParts = parser.parseMultipartIdentifier(funcName)
-    call_function(nameParts, cols: _*)
-  }
-
-  private def call_function(nameParts: Seq[String], cols: Column*): Column = withExpr {
-    UnresolvedFunction(nameParts, cols.map(_.expr), false)
+    Column(internal.UnresolvedFunction(funcName, cols.map(_.node), isUserDefinedFunction = true))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala
new file mode 100644
index 0000000000000..bd8735d15be13
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/UserDefinedFunctionUtils.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.internal
+
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
+import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.expressions.SparkUserDefinedFunction
+
+private[sql] object UserDefinedFunctionUtils {
+  /**
+   * Convert a UDF into an (executable) ScalaUDF expression.
+   *
+   * This function should be moved to ScalaUDF when we move SparkUserDefinedFunction to sql/api.
+   */
+  def toScalaUDF(udf: SparkUserDefinedFunction, children: Seq[Expression]): ScalaUDF = {
+    ScalaUDF(
+      udf.f,
+      udf.dataType,
+      children,
+      udf.inputEncoders.map(_.collect {
+        // At some point it would be nice if we were to support this.
+ case e if e != UnboundRowEncoder => encoderFor(e) + }), + udf.outputEncoder.map(encoderFor(_)), + udfName = udf.givenName, + nullable = udf.nullable, + udfDeterministic = udf.deterministic) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 4d4960d24d010..ea6e36680da45 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -16,10 +16,11 @@ */ package org.apache.spark.sql.internal +import UserDefinedFunctionUtils.toScalaUDF + import org.apache.spark.SparkException import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.{analysis, expressions, CatalystTypeConverters} -import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -162,23 +163,16 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres case InvokeInlineUserDefinedFunction( a: UserDefinedAggregator[Any @unchecked, Any @unchecked, Any @unchecked], arguments, isDistinct, _) => - ScalaAggregator( - agg = a.aggregator, - children = arguments.map(apply), - inputEncoder = encoderFor(a.inputEncoder), - bufferEncoder = encoderFor(a.aggregator.bufferEncoder), - aggregatorName = a.givenName, - nullable = a.nullable, - isDeterministic = a.deterministic).toAggregateExpression(isDistinct) + ScalaAggregator(a, arguments.map(apply)).toAggregateExpression(isDistinct) case InvokeInlineUserDefinedFunction( a: UserDefinedAggregateFunction, arguments, isDistinct, _) => ScalaUDAF(udaf = a, children = arguments.map(apply)).toAggregateExpression(isDistinct) case InvokeInlineUserDefinedFunction(udf: SparkUserDefinedFunction, arguments, _, _) => - udf.createScalaUDF(arguments.map(apply)) + toScalaUDF(udf, arguments.map(apply)) - case Wrapper(expression, _) => + case ExpressionColumnNode(expression, _) => expression case node => @@ -240,10 +234,10 @@ private[sql] object ColumnNodeToExpressionConverter extends ColumnNodeToExpressi /** * [[ColumnNode]] wrapper for an [[Expression]]. 
*/ -private[sql] case class Wrapper( +private[sql] case class ExpressionColumnNode( expression: Expression, override val origin: Origin = CurrentOrigin.get) extends ColumnNode { - override def normalize(): Wrapper = { + override def normalize(): ExpressionColumnNode = { val updated = expression.transform { case a: AttributeReference => DetectAmbiguousSelfJoin.stripColumnReferenceMetadata(a) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 936bcc21b763d..64b5128872610 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -978,15 +978,15 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("SPARK-37646: lit") { assert(lit($"foo") == $"foo") assert(lit($"foo") == $"foo") - assert(lit(1) == Column(Literal(1))) - assert(lit(null) == Column(Literal(null, NullType))) + assert(lit(1).expr == Column(Literal(1)).expr) + assert(lit(null).expr == Column(Literal(null, NullType)).expr) } test("typedLit") { assert(typedLit($"foo") == $"foo") assert(typedLit($"foo") == $"foo") - assert(typedLit(1) == Column(Literal(1))) - assert(typedLit[String](null) == Column(Literal(null, StringType))) + assert(typedLit(1).expr == Column(Literal(1)).expr) + assert(typedLit[String](null).expr == Column(Literal(null, StringType)).expr) val df = Seq(Tuple1(0)).toDF("a") // Only check the types `lit` cannot handle diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index d982a000ad374..48ac2cc5d4044 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -82,8 +82,8 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession { // items: Seq[Int] => items.map { item => Seq(Struct(item)) } val result = df.select( - new Column(MapObjects( - (item: Expression) => array(struct(new Column(item))).expr, + Column(MapObjects( + (item: Expression) => array(struct(Column(item))).expr, $"items".expr, df.schema("items").dataType.asInstanceOf[ArrayType].elementType )) as "items" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 155acc98cb33b..301ab28b9124b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Cast, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LocalRelation, LogicalPlan, 
OneRowRelation} @@ -1566,7 +1566,7 @@ class DataFrameSuite extends QueryTest test("SPARK-46794: exclude subqueries from LogicalRDD constraints") { withTempDir { checkpointDir => val subquery = - new Column(ScalarSubquery(spark.range(10).selectExpr("max(id)").logicalPlan)) + Column(ScalarSubquery(spark.range(10).selectExpr("max(id)").logicalPlan)) val df = spark.range(1000).filter($"id" === subquery) assert(df.logicalPlan.constraints.exists(_.exists(_.isInstanceOf[ScalarSubquery]))) @@ -1839,7 +1839,7 @@ class DataFrameSuite extends QueryTest } test("Uuid expressions should produce same results at retries in the same DataFrame") { - val df = spark.range(1).select($"id", new Column(Uuid())) + val df = spark.range(1).select($"id", uuid()) checkAnswer(df, df.collect()) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala index 95f4cc78d1564..c03c5e878427f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFramesSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.{Ascending, Literal, NonFoldableLiteral, RangeFrame, SortOrder, SpecifiedWindowFrame, UnaryMinus, UnspecifiedFrame} +import org.apache.spark.sql.catalyst.expressions.{Literal, NonFoldableLiteral} import org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions import org.apache.spark.sql.catalyst.plans.logical.{Window => WindowNode} -import org.apache.spark.sql.expressions.{Window, WindowSpec} +import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{ExpressionColumnNode, SQLConf} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.CalendarIntervalType @@ -503,11 +503,11 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { test("Window frame bounds lower and upper do not have the same type") { val df = Seq((1L, "1"), (1L, "1")).toDF("key", "value") - val windowSpec = new WindowSpec( - Seq(Column("value").expr), - Seq(SortOrder(Column("key").expr, Ascending)), - SpecifiedWindowFrame(RangeFrame, Literal.create(null, CalendarIntervalType), Literal(2)) - ) + + val windowSpec = Window.partitionBy($"value").orderBy($"key".asc).withFrame( + internal.WindowFrame.Range, + internal.WindowFrame.Value(ExpressionColumnNode(Literal.create(null, CalendarIntervalType))), + internal.WindowFrame.Value(lit(2).node)) checkError( exception = intercept[AnalysisException] { df.select($"key", count("key").over(windowSpec)).collect() @@ -526,11 +526,10 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { test("Window frame lower bound is not a literal") { val df = Seq((1L, "1"), (1L, "1")).toDF("key", "value") - val windowSpec = new WindowSpec( - Seq(Column("value").expr), - Seq(SortOrder(Column("key").expr, Ascending)), - SpecifiedWindowFrame(RangeFrame, NonFoldableLiteral(1), Literal(2)) - ) + val windowSpec = Window.partitionBy($"value").orderBy($"key".asc).withFrame( + internal.WindowFrame.Range, + internal.WindowFrame.Value(ExpressionColumnNode(NonFoldableLiteral(1))), + internal.WindowFrame.Value(lit(2).node)) checkError( exception = intercept[AnalysisException] { df.select($"key", count("key").over(windowSpec)).collect() @@ -546,8 +545,7 @@ class 
DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { test("SPARK-41805: Reuse expressions in WindowSpecDefinition") { val ds = Seq((1, 1), (1, 2), (1, 3), (2, 1), (2, 2)).toDF("n", "i") - val sortOrder = SortOrder($"n".cast("string").expr, Ascending) - val window = new WindowSpec(Seq($"n".expr), Seq(sortOrder), UnspecifiedFrame) + val window = Window.partitionBy($"n").orderBy($"n".cast("string").asc) val df = ds.select(sum("i").over(window), avg("i").over(window)) val ws = df.queryExecution.analyzed.collect { case w: WindowNode => w } assert(ws.size === 1) @@ -557,9 +555,10 @@ class DataFrameWindowFramesSuite extends QueryTest with SharedSparkSession { test("SPARK-41793: Incorrect result for window frames defined by a range clause on large " + "decimals") { - val window = new WindowSpec(Seq($"a".expr), Seq(SortOrder($"b".expr, Ascending)), - SpecifiedWindowFrame(RangeFrame, - UnaryMinus(Literal(BigDecimal(10.2345))), Literal(BigDecimal(6.7890)))) + val window = Window.partitionBy($"a").orderBy($"b".asc).withFrame( + internal.WindowFrame.Range, + internal.WindowFrame.Value((-lit(BigDecimal(10.2345))).node), + internal.WindowFrame.Value(lit(BigDecimal(10.2345)).node)) val df = Seq( 1 -> "11342371013783243717493546650944543.47", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 2a51e5113e14d..44709fd309cfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -28,11 +28,13 @@ import org.scalatest.Assertions._ import org.apache.spark.TestUtils import org.apache.spark.api.python.{PythonBroadcast, PythonEvalType, PythonFunction, PythonUtils, SimplePythonFunction} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExprId, PythonUDF} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.datasources.v2.python.UserDefinedPythonDataSource import org.apache.spark.sql.execution.python.{UserDefinedPythonFunction, UserDefinedPythonTableFunction} import org.apache.spark.sql.expressions.SparkUserDefinedFunction +import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType, StructType, VariantType} import org.apache.spark.util.ArrayImplicits._ @@ -1566,40 +1568,30 @@ object IntegratedUDFTestUtils extends SQLHelper { * casted_col.cast(df.schema("col").dataType) * }}} */ - class TestInternalScalaUDF( - name: String, - returnType: Option[DataType] = None) extends SparkUserDefinedFunction( - (input: Any) => if (input == null) { - null - } else { - input.toString - }, - StringType, - inputEncoders = Seq.fill(1)(None), - givenName = Some(name)) { + case class TestScalaUDF(name: String, returnType: Option[DataType] = None) extends TestUDF { + private val udf: SparkUserDefinedFunction = { + val unnamed = functions.udf { (input: Any) => + if (input == null) { + null + } else { + input.toString + } + } + unnamed.withName(name).asInstanceOf[SparkUserDefinedFunction] + } - override def apply(exprs: Column*): Column = { + val builder: FunctionRegistry.FunctionBuilder = { exprs => assert(exprs.length == 1, "Defined UDF only has one column") - val expr = exprs.head.expr + val expr = exprs.head val rt = returnType.getOrElse { 
assert(expr.resolved, "column should be resolved to use the same type " + - "as input. Try df(name) or df.col(name)") + "as input. Try df(name) or df.col(name)") expr.dataType } - Column(Cast(createScalaUDF(Cast(expr, StringType) :: Nil), rt)) + Cast(toScalaUDF(udf, Cast(expr, StringType) :: Nil), rt) } - override def withName(name: String): TestInternalScalaUDF = { - // "withName" should overridden to return TestInternalScalaUDF. Otherwise, the current object - // is sliced and the overridden "apply" is not invoked. - new TestInternalScalaUDF(name) - } - } - - case class TestScalaUDF(name: String, returnType: Option[DataType] = None) extends TestUDF { - private[IntegratedUDFTestUtils] lazy val udf = new TestInternalScalaUDF(name, returnType) - - def apply(exprs: Column*): Column = udf(exprs: _*) + def apply(exprs: Column*): Column = Column(builder(exprs.map(_.expr))) val prettyName: String = "Scala UDF" } @@ -1611,7 +1603,9 @@ object IntegratedUDFTestUtils extends SQLHelper { case udf: TestPythonUDF => session.udf.registerPython(udf.name, udf.udf) case udf: TestScalarPandasUDF => session.udf.registerPython(udf.name, udf.udf) case udf: TestGroupedAggPandasUDF => session.udf.registerPython(udf.name, udf.udf) - case udf: TestScalaUDF => session.udf.register(udf.name, udf.udf) + case udf: TestScalaUDF => + val registry = session.sessionState.functionRegistry + registry.createOrReplaceTempFunction(udf.name, udf.builder, "scala_udf") case other => throw new RuntimeException(s"Unknown UDF class [${other.getClass}]") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7e940252430f8..36552d5c5487c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -775,7 +775,8 @@ class UDFSuite extends QueryTest with SharedSparkSession { errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", parameters = Map( "objectName" -> "`b`", - "proposal" -> "`a`")) + "proposal" -> "`a`"), + context = ExpectedContext("apply", ".*")) } test("wrong order of input fields for case class") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index 4a748d590feb1..b9f4e82cdd3c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -211,7 +211,7 @@ class QueryExecutionErrorsSuite test("UNSUPPORTED_FEATURE: unsupported types (map and struct) in lit()") { def checkUnsupportedTypeInLiteral(v: Any, literal: String, dataType: String): Unit = { checkError( - exception = intercept[SparkRuntimeException] { lit(v) }, + exception = intercept[SparkRuntimeException] { lit(v).expr }, errorClass = "UNSUPPORTED_FEATURE.LITERAL_TYPE", parameters = Map("value" -> literal, "type" -> dataType), sqlState = "0A000") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala index 052f220d97a50..7bf70695a9854 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeSuite.scala @@ -229,6 +229,10 @@ class ColumnNodeSuite extends SparkFunSuite { } else { Metadata.empty } - Wrapper(AttributeReference(name, 
LongType, metadata = metadata)(exprId = ExprId(id))) + ExpressionColumnNode(AttributeReference( + name, + LongType, + metadata = metadata)( + exprId = ExprId(id))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala index 20232ac12ec0f..0fbfe762df918 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.internal import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.sql.{Dataset, Encoders} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.{analysis, expressions, InternalRow} -import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, AgnosticEncoders} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId} import org.apache.spark.sql.catalyst.parser.ParserInterface @@ -350,7 +350,7 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { InvokeInlineUserDefinedFunction( UserDefinedAggregator( aggregator = int2LongSum, - inputEncoder = AgnosticEncoders.PrimitiveIntEncoder, + inputEncoder = PrimitiveIntEncoder, nullable = false, givenName = Option("int2LongSum")), UnresolvedAttribute("i_col") :: Nil), @@ -368,8 +368,8 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { InvokeInlineUserDefinedFunction( SparkUserDefinedFunction( f = concat, - inputEncoders = None :: Option(encoderFor(StringEncoder)) :: Nil, - outputEncoder = Option(encoderFor(StringEncoder)), + inputEncoders = None :: Option(toAny(StringEncoder)) :: Nil, + outputEncoder = Option(toAny(StringEncoder)), dataType = StringType, nullable = false, deterministic = false), @@ -378,8 +378,8 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { function = concat, dataType = StringType, children = Seq(analysis.UnresolvedAttribute("a"), analysis.UnresolvedAttribute("b")), - inputEncoders = Seq(None, Option(encoderFor(Encoders.STRING))), - outputEncoder = Option(encoderFor(Encoders.STRING)), + inputEncoders = Seq(None, Option(encoderFor(StringEncoder))), + outputEncoder = Option(encoderFor(StringEncoder)), udfName = None, nullable = false, udfDeterministic = false)) @@ -387,7 +387,7 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { test("extension") { testConversion( - Wrapper(analysis.UnresolvedAttribute("bar")), + ExpressionColumnNode(analysis.UnresolvedAttribute("bar")), analysis.UnresolvedAttribute("bar")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 9282c0d0e3034..2767f2dd46b2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -36,9 +36,8 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark.{SparkException, SparkUnsupportedOperationException, TestUtils} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Row, SaveMode} +import 
org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Row, SaveMode} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.catalyst.plans.logical.{CTERelationDef, CTERelationRef, LocalRelation} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes @@ -1002,7 +1001,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } val stream = MemoryStream[Int] - val df = stream.toDF().select(new Column(Uuid())) + val df = stream.toDF().select(uuid()) testStream(df)( AddData(stream, 1), CheckAnswer(collectUuid), @@ -1022,7 +1021,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } val stream = MemoryStream[Int] - val df = stream.toDF().select(new Column(new Rand()), new Column(new Randn())) + val df = stream.toDF().select(rand(), randn()) testStream(df)( AddData(stream, 1), CheckAnswer(collectRand), @@ -1041,7 +1040,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } val stream = MemoryStream[Int] - val df = stream.toDF().select(new Column(new Shuffle(Literal.create[Seq[Int]](0 until 100)))) + val df = stream.toDF().select(shuffle(typedLit[Seq[Int]](0 until 100))) testStream(df)( AddData(stream, 1), CheckAnswer(collectShuffle),