[SPARK-49022] Use Column Node API in Column
### What changes were proposed in this pull request?
This PR makes org.apache.spark.sql.Column and friends use the recently introduced ColumnNode API. This is a stepping stone towards making the Column API implementation-agnostic.

Most of the changes are fairly mechanical; they are largely a consequence of removing the Column(Expression) constructor.
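
For orientation, here is a self-contained toy sketch of the idea — not Spark's actual classes; the node, column, and converter names below are made up for illustration. The public Column wraps an implementation-agnostic node tree, and each backend lowers that tree into its own representation (Catalyst expressions for Classic, proto messages for Connect):

```scala
// Toy model only: illustrates the ColumnNode idea, not Spark's real API.
sealed trait ColumnNode
final case class UnresolvedAttribute(name: String) extends ColumnNode
final case class Literal(value: Any) extends ColumnNode
final case class UnresolvedFunction(name: String, args: Seq[ColumnNode]) extends ColumnNode

// The public Column holds a backend-agnostic node instead of a Catalyst Expression.
final class Column(val node: ColumnNode) {
  def +(other: Column): Column =
    new Column(UnresolvedFunction("+", Seq(node, other.node)))
}

// A backend would lower the node tree into its own representation; here we just pretty-print it.
object ToyConverter {
  def show(node: ColumnNode): String = node match {
    case UnresolvedAttribute(n)      => n
    case Literal(v)                  => v.toString
    case UnresolvedFunction(f, args) => args.map(show).mkString(s"$f(", ", ", ")")
  }
}

object ColumnNodeDemo extends App {
  val c = new Column(UnresolvedAttribute("a")) + new Column(Literal(1))
  println(ToyConverter.show(c.node)) // prints: +(a, 1)
}
```
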

### Why are the changes needed?
We want to create a unified Scala interface for Classic and Connect. A language-agnostic Column API implementation is part of this.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing tests.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#47688 from hvanhovell/SPARK-49022.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
hvanhovell authored and attilapiros committed Oct 4, 2024
1 parent fa9f2c9 commit b039741
Showing 39 changed files with 473 additions and 508 deletions.
75 changes: 30 additions & 45 deletions R/pkg/R/functions.R
@@ -3965,19 +3965,11 @@ setMethod("row_number",
#' yields unresolved \code{a.b.c}
#' @return Column object wrapping JVM UnresolvedNamedLambdaVariable
#' @keywords internal
unresolved_named_lambda_var <- function(...) {
jc <- newJObject(
"org.apache.spark.sql.Column",
newJObject(
"org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable",
lapply(list(...), function(x) {
handledCallJStatic(
"org.apache.spark.sql.catalyst.expressions.UnresolvedNamedLambdaVariable",
"freshVarName",
x)
})
)
)
unresolved_named_lambda_var <- function(name) {
jc <- handledCallJStatic(
"org.apache.spark.sql.api.python.PythonSQLUtils",
"unresolvedNamedLambdaVariable",
name)
column(jc)
}

@@ -3990,7 +3982,6 @@ unresolved_named_lambda_var <- function(...) {
#' @return JVM \code{LambdaFunction} object
#' @keywords internal
create_lambda <- function(fun) {
as_jexpr <- function(x) callJMethod(x@jc, "expr")

# Process function arguments
parameters <- formals(fun)
@@ -4011,22 +4002,18 @@ create_lambda <- function(fun) {
stopifnot(class(result) == "Column")

# Convert both Columns to Scala expressions
jexpr <- as_jexpr(result)

jargs <- handledCallJStatic(
"org.apache.spark.api.python.PythonUtils",
"toSeq",
handledCallJStatic(
"java.util.Arrays", "asList", lapply(args, as_jexpr)
)
handledCallJStatic("java.util.Arrays", "asList", lapply(args, function(x) { x@jc }))
)

# Create Scala LambdaFunction
newJObject(
"org.apache.spark.sql.catalyst.expressions.LambdaFunction",
jexpr,
jargs,
FALSE
handledCallJStatic(
"org.apache.spark.sql.api.python.PythonSQLUtils",
"lambdaFunction",
result@jc,
jargs
)
}

@@ -4039,20 +4026,18 @@ create_lambda <- function(fun) {
#' @return a \code{Column} representing name applied to cols with funs
#' @keywords internal
invoke_higher_order_function <- function(name, cols, funs) {
as_jexpr <- function(x) {
as_col <- function(x) {
if (class(x) == "character") {
x <- column(x)
}
callJMethod(x@jc, "expr")
x@jc
}

jexpr <- do.call(newJObject, c(
paste("org.apache.spark.sql.catalyst.expressions", name, sep = "."),
lapply(cols, as_jexpr),
lapply(funs, create_lambda)
))

column(newJObject("org.apache.spark.sql.Column", jexpr))
jcol <- handledCallJStatic(
"org.apache.spark.sql.api.python.PythonSQLUtils",
"fn",
name,
c(lapply(cols, as_col), lapply(funs, create_lambda))) # check varargs invocation
column(jcol)
}

#' @details
@@ -4068,7 +4053,7 @@ setMethod("array_aggregate",
signature(x = "characterOrColumn", initialValue = "Column", merge = "function"),
function(x, initialValue, merge, finish = NULL) {
invoke_higher_order_function(
"ArrayAggregate",
"aggregate",
cols = list(x, initialValue),
funs = if (is.null(finish)) {
list(merge)
@@ -4129,7 +4114,7 @@ setMethod("array_exists",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"ArrayExists",
"exists",
cols = list(x),
funs = list(f)
)
@@ -4145,7 +4130,7 @@ setMethod("array_filter",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"ArrayFilter",
"filter",
cols = list(x),
funs = list(f)
)
@@ -4161,7 +4146,7 @@ setMethod("array_forall",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"ArrayForAll",
"forall",
cols = list(x),
funs = list(f)
)
@@ -4291,7 +4276,7 @@ setMethod("array_sort",
column(callJStatic("org.apache.spark.sql.functions", "array_sort", x@jc))
} else {
invoke_higher_order_function(
"ArraySort",
"array_sort",
cols = list(x),
funs = list(comparator)
)
@@ -4309,7 +4294,7 @@ setMethod("array_transform",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"ArrayTransform",
"transform",
cols = list(x),
funs = list(f)
)
@@ -4374,7 +4359,7 @@ setMethod("arrays_zip_with",
signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"),
function(x, y, f) {
invoke_higher_order_function(
"ZipWith",
"zip_with",
cols = list(x, y),
funs = list(f)
)
@@ -4447,7 +4432,7 @@ setMethod("map_filter",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"MapFilter",
"map_filter",
cols = list(x),
funs = list(f))
})
@@ -4504,7 +4489,7 @@ setMethod("transform_keys",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"TransformKeys",
"transform_keys",
cols = list(x),
funs = list(f)
)
@@ -4521,7 +4506,7 @@ setMethod("transform_values",
signature(x = "characterOrColumn", f = "function"),
function(x, f) {
invoke_higher_order_function(
"TransformValues",
"transform_values",
cols = list(x),
funs = list(f)
)
@@ -4552,7 +4537,7 @@ setMethod("map_zip_with",
signature(x = "characterOrColumn", y = "characterOrColumn", f = "function"),
function(x, y, f) {
invoke_higher_order_function(
"MapZipWith",
"map_zip_with",
cols = list(x, y),
funs = list(f)
)
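
For comparison, the Scala functions API exposes these higher-order functions directly, which is effectively what the reworked SparkR helper now reaches by SQL function name instead of by Catalyst expression class name. A minimal sketch, assuming a local SparkSession; the object name, data, and column names are illustrative:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, transform}

object HigherOrderDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("hof-demo").getOrCreate()
  import spark.implicits._

  val df = Seq(Seq(1, 2, 3)).toDF("xs")
  // Apply a lambda to every array element via the "transform" higher-order function.
  df.select(transform(col("xs"), x => x + 1).as("xs_plus_one")).show()

  spark.stop()
}
```
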
@@ -41,7 +41,7 @@ object functions {
def from_avro(
data: Column,
jsonFormatSchema: String): Column = {
new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty))
Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, Map.empty))
}

/**
@@ -62,7 +62,7 @@ object functions {
data: Column,
jsonFormatSchema: String,
options: java.util.Map[String, String]): Column = {
new Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap))
Column(AvroDataToCatalyst(data.expr, jsonFormatSchema, options.asScala.toMap))
}

/**
@@ -74,7 +74,7 @@ object functions {
*/
@Experimental
def to_avro(data: Column): Column = {
new Column(CatalystDataToAvro(data.expr, None))
Column(CatalystDataToAvro(data.expr, None))
}

/**
@@ -87,6 +87,6 @@ object functions {
*/
@Experimental
def to_avro(data: Column, jsonFormatSchema: String): Column = {
new Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema)))
Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema)))
}
}
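
The user-facing from_avro/to_avro signatures are untouched here; only the internal Column construction changes. A minimal round-trip sketch, assuming the spark-avro module is on the classpath; the session setup, schema string, and column names are illustrative:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.avro.functions.{from_avro, to_avro}

object AvroRoundTripDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("avro-demo").getOrCreate()
  import spark.implicits._

  val jsonFormatSchema = """{"type": "long"}"""
  val df = Seq(1L, 2L).toDF("id")

  // Encode the long column to Avro binary, then decode it back with the reader schema.
  df.select(from_avro(to_avro($"id"), jsonFormatSchema).as("decoded")).show()

  spark.stop()
}
```
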
@@ -300,6 +300,12 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.register"),

// Typed Column
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.*"),
ProblemFilters.exclude[IncompatibleResultTypeProblem](
"org.apache.spark.sql.TypedColumn.expr"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.TypedColumn$"),

// Datasource V2 partition transforms
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.PartitionTransform$"),
@@ -70,7 +70,7 @@ object functions {
messageName: String,
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]): Column = {
new Column(
Column(
ProtobufDataToCatalyst(
data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap
)
@@ -93,7 +93,7 @@ object functions {
@Experimental
def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
new Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent)))
Column(ProtobufDataToCatalyst(data.expr, messageName, Some(fileContent)))
}

/**
@@ -112,7 +112,7 @@ object functions {
@Experimental
def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
new Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet)))
Column(ProtobufDataToCatalyst(data.expr, messageName, Some(binaryFileDescriptorSet)))
}

/**
@@ -132,7 +132,7 @@ object functions {
*/
@Experimental
def from_protobuf(data: Column, messageClassName: String): Column = {
new Column(ProtobufDataToCatalyst(data.expr, messageClassName))
Column(ProtobufDataToCatalyst(data.expr, messageClassName))
}

/**
@@ -156,7 +156,7 @@ object functions {
data: Column,
messageClassName: String,
options: java.util.Map[String, String]): Column = {
new Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap))
Column(ProtobufDataToCatalyst(data.expr, messageClassName, None, options.asScala.toMap))
}

/**
@@ -194,7 +194,7 @@ object functions {
@Experimental
def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte])
: Column = {
new Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet)))
Column(CatalystDataToProtobuf(data.expr, messageName, Some(binaryFileDescriptorSet)))
}
/**
* Converts a column into binary of protobuf format. The Protobuf definition is provided
@@ -216,7 +216,7 @@ object functions {
descFilePath: String,
options: java.util.Map[String, String]): Column = {
val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath)
new Column(
Column(
CatalystDataToProtobuf(data.expr, messageName, Some(fileContent), options.asScala.toMap)
)
}
@@ -242,7 +242,7 @@ object functions {
binaryFileDescriptorSet: Array[Byte],
options: java.util.Map[String, String]
): Column = {
new Column(
Column(
CatalystDataToProtobuf(
data.expr, messageName, Some(binaryFileDescriptorSet), options.asScala.toMap
)
@@ -266,7 +266,7 @@ object functions {
*/
@Experimental
def to_protobuf(data: Column, messageClassName: String): Column = {
new Column(CatalystDataToProtobuf(data.expr, messageClassName))
Column(CatalystDataToProtobuf(data.expr, messageClassName))
}

/**
@@ -288,6 +288,6 @@ object functions {
@Experimental
def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String])
: Column = {
new Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap))
Column(CatalystDataToProtobuf(data.expr, messageClassName, None, options.asScala.toMap))
}
}
@@ -28,7 +28,6 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, Row}
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -195,7 +194,7 @@ class StringIndexer @Since("1.4.0") (
} else {
// We don't count for NaN values. Because `StringIndexerAggregator` only processes strings,
// we replace NaNs with null in advance.
new Column(If(col.isNaN.expr, Literal(null), col.expr)).cast(StringType)
when(!isnan(col), col).cast(StringType)
}
}
}
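
The StringIndexer hunk above replaces a hand-built Catalyst If expression with the public functions API. A small standalone sketch of the equivalent NaN-to-null behavior, assuming a local SparkSession; the object name and data are illustrative:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, isnan, when}
import org.apache.spark.sql.types.StringType

object NanToNullDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("nan-demo").getOrCreate()
  import spark.implicits._

  val df = Seq(1.0, Double.NaN, 3.0).toDF("x")

  // Keep the value when it is not NaN, otherwise produce null, then cast to string —
  // the same effect as the removed If(col.isNaN, Literal(null), col) expression.
  df.select(when(!isnan(col("x")), col("x")).cast(StringType).as("x_str")).show()

  spark.stop()
}
```
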
@@ -257,7 +257,7 @@ private[ml] class SummaryBuilderImpl(
mutableAggBufferOffset = 0,
inputAggBufferOffset = 0)

new Column(agg.toAggregateExpression())
Column(agg.toAggregateExpression())
}
}

5 changes: 4 additions & 1 deletion project/MimaExcludes.scala
@@ -107,7 +107,10 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.JobWaiter.cancel"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.cancel"),
// SPARK-48901: Add clusterBy() to DataStreamWriter.
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy")
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.DataStreamWriter.clusterBy"),
// SPARK-49022: Use Column API
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.TypedColumn.this"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.WindowSpec.this")
)

// Default exclude rules
6 changes: 2 additions & 4 deletions python/pyspark/pandas/internal.py
@@ -915,10 +915,8 @@ def attach_distributed_column(sdf: PySparkDataFrame, column_name: str) -> PySpar
if is_remote():
return sdf.select(F.monotonically_increasing_id().alias(column_name), *scols)
jvm = sdf.sparkSession._jvm
tag = jvm.org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FUNC_ALIAS()
jexpr = F.monotonically_increasing_id()._jc.expr()
jexpr.setTagValue(tag, "distributed_index")
return sdf.select(PySparkColumn(jvm.Column(jexpr)).alias(column_name), *scols)
jcol = jvm.PythonSQLUtils.distributedIndex()
return sdf.select(PySparkColumn(jcol).alias(column_name), *scols)

@staticmethod
def attach_distributed_sequence_column(