Skip to content

Commit

Permalink
[SPARK-46966][PYTHON] Add UDTF API for 'analyze' method to indicate s…
Browse files Browse the repository at this point in the history
…ubset of input table columns to select

### What changes were proposed in this pull request?

This PR adds a UDTF API for the 'analyze' method to indicate subset of input table columns to select.

For example, this UDTF populates this 'select' list to indicate that Spark should only return two input columns from the input table: 'input' and 'partition_col':

```
from pyspark.sql.functions import AnalyzeResult, OrderingColumn, PartitioningColumn, SelectedColumn
from pyspark.sql.types import IntegerType, Row, StructType
class Udtf:
    def __init__(self):
        self._partition_col = None
        self._count = 0
        self._sum = 0
        self._last = None

    staticmethod
    def analyze(row: Row):
        return AnalyzeResult(
            schema=StructType()
                .add("user_id", IntegerType())
                .add("count", IntegerType())
                .add("total", IntegerType())
                .add("last", IntegerType()),
            partitionBy=[
                PartitioningColumn("user_id")
            ],
            orderBy=[
                OrderingColumn("timestamp")
            ],
            select=[
                SelectedColumn("input"),
                SelectedColumn("partition_col")
            ])

    def eval(self, row: Row):
        self._partition_col = row["partition_col"]
        self._count += 1
        self._last = row["input"]
        self._sum += row["input"]

    def terminate(self):
        yield self._partition_col, self._count, self._sum, self._last
```

### Why are the changes needed?

This can reduce the amount of data sent between the JVM and Python interpreter, improving performance.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR adds test coverage.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#45007 from dtenedor/udtf-select-cols.

Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
  • Loading branch information
dtenedor authored and ueshin committed Feb 7, 2024
1 parent 5789316 commit 31f85e5
Show file tree
Hide file tree
Showing 13 changed files with 367 additions and 13 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -3456,6 +3456,12 @@
],
"sqlState" : "42802"
},
"UDTF_INVALID_REQUESTED_SELECTED_EXPRESSION_FROM_ANALYZE_METHOD_REQUIRES_ALIAS" : {
"message" : [
"Failed to evaluate the user-defined table function because its 'analyze' method returned a requested 'select' expression (<expression>) that does not include a corresponding alias; please update the UDTF to specify an alias there and then try the query again."
],
"sqlState" : "42802"
},
"UNABLE_TO_ACQUIRE_MEMORY" : {
"message" : [
"Unable to acquire <requestedBytes> bytes of memory, got <receivedBytes>."
Expand Down
6 changes: 6 additions & 0 deletions docs/sql-error-conditions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2215,6 +2215,12 @@ Please ensure that the number of aliases provided matches the number of columns

Failed to evaluate the user-defined table function because its 'analyze' method returned a requested OrderingColumn whose column name expression included an unnecessary alias `<aliasName>`; please remove this alias and then try the query again.

### UDTF_INVALID_REQUESTED_SELECTED_EXPRESSION_FROM_ANALYZE_METHOD_REQUIRES_ALIAS

[SQLSTATE: 42802](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

Failed to evaluate the user-defined table function because its 'analyze' method returned a requested 'select' expression (`<expression>`) that does not include a corresponding alias; please update the UDTF to specify an alias there and then try the query again.

### UNABLE_TO_ACQUIRE_MEMORY

[SQLSTATE: 53200](sql-error-conditions-sqlstates.html#class-53-insufficient-resources)
Expand Down
2 changes: 1 addition & 1 deletion python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
from pyspark.sql.udf import UserDefinedFunction, _create_py_udf # noqa: F401
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult # noqa: F401
from pyspark.sql.udtf import OrderingColumn, PartitioningColumn # noqa: F401
from pyspark.sql.udtf import OrderingColumn, PartitioningColumn, SelectedColumn # noqa: F401
from pyspark.sql.udtf import SkipRestOfInputTableException # noqa: F401
from pyspark.sql.udtf import UserDefinedTableFunction, _create_py_udtf

Expand Down
32 changes: 29 additions & 3 deletions python/pyspark/sql/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"AnalyzeResult",
"PartitioningColumn",
"OrderingColumn",
"SelectedColumn",
"SkipRestOfInputTableException",
"UDTFRegistration",
]
Expand Down Expand Up @@ -108,6 +109,26 @@ class OrderingColumn:
overrideNullsFirst: Optional[bool] = None


@dataclass(frozen=True)
class SelectedColumn:
"""
Represents an expression that the UDTF is specifying for Catalyst to evaluate against the
columns in the input TABLE argument. The UDTF then receives one input column for each expression
in the list, in the order they are listed.
Parameters
----------
name : str
The contents of the selected column name or expression represented as a SQL string.
alias : str, default ''
If non-empty, this is the alias for the column or expression as visible from the UDTF's
'eval' method. This is required if the expression is not a simple column reference.
"""

name: str
alias: str = ""


# Note: this class is a "dataclass" for purposes of convenience, but it is not marked "frozen"
# because the intention is that users may create subclasses of it for purposes of returning custom
# information from the "analyze" method.
Expand All @@ -118,13 +139,13 @@ class AnalyzeResult:
Parameters
----------
schema : :class:`StructType`
schema: :class:`StructType`
The schema that the Python UDTF will return.
withSinglePartition : bool
withSinglePartition: bool
If true, the UDTF is specifying for Catalyst to repartition all rows of the input TABLE
argument to one collection for consumption by exactly one instance of the correpsonding
UDTF class.
partitionBy : sequence of :class:`PartitioningColumn`
partitionBy: sequence of :class:`PartitioningColumn`
If non-empty, this is a sequence of expressions that the UDTF is specifying for Catalyst to
partition the input TABLE argument by. In this case, calls to the UDTF may not include any
explicit PARTITION BY clause, in which case Catalyst will return an error. This option is
Expand All @@ -133,12 +154,17 @@ class AnalyzeResult:
If non-empty, this is a sequence of expressions that the UDTF is specifying for Catalyst to
sort the input TABLE argument by. Note that the 'partitionBy' list must also be non-empty
in this case.
select: sequence of :class:`SelectedColumn`
If non-empty, this is a sequence of expressions that the UDTF is specifying for Catalyst to
evaluate against the columns in the input TABLE argument. The UDTF then receives one input
attribute for each name in the list, in the order they are listed.
"""

schema: StructType
withSinglePartition: bool = False
partitionBy: Sequence[PartitioningColumn] = field(default_factory=tuple)
orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
select: Sequence[SelectedColumn] = field(default_factory=tuple)


class SkipRestOfInputTableException(Exception):
Expand Down
20 changes: 19 additions & 1 deletion python/pyspark/sql/worker/analyze_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
write_with_length,
SpecialLengths,
)
from pyspark.sql.functions import PartitioningColumn
from pyspark.sql.functions import PartitioningColumn, SelectedColumn
from pyspark.sql.types import _parse_datatype_json_string, StructType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
from pyspark.util import handle_worker_exception
Expand Down Expand Up @@ -203,6 +203,19 @@ def format_error(msg: str) -> str:
and then try the query again."""
)
)
elif isinstance(result.select, (list, tuple)) and (
len(result.select) > 0
and not all([isinstance(val, SelectedColumn) for val in result.select])
):
raise PySparkValueError(
format_error(
f"""
{error_prefix} because the static 'analyze' method returned an
'AnalyzeResult' object with the 'select' field set to a value besides a
list or tuple of 'SelectedColumn' objects. Please update the table function
and then try the query again."""
)
)

# Return the analyzed schema.
write_with_length(result.schema.json().encode("utf-8"), outfile)
Expand All @@ -225,6 +238,11 @@ def format_error(msg: str) -> str:
write_int(1, outfile)
else:
write_int(2, outfile)
# Return the requested selected input table columns, if specified.
write_int(len(result.select), outfile)
for col in result.select:
write_with_length(col.name.encode("utf-8"), outfile)
write_with_length(col.alias.encode("utf-8"), outfile)

except BaseException as e:
handle_worker_exception(e, outfile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Project, Repartition, RepartitionByExpression, Sort}
import org.apache.spark.sql.catalyst.trees.TreePattern.{FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION, TreePattern}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.DataType

/**
Expand Down Expand Up @@ -58,14 +59,19 @@ import org.apache.spark.sql.types.DataType
* @param orderByExpressions if non-empty, the TABLE argument included the ORDER BY clause to
* indicate that the rows within each partition of the table function are
* to arrive in the provided order.
* @param selectedInputExpressions If non-empty, this is a sequence of expressions that the UDTF is
* specifying for Catalyst to evaluate against the columns in the
* input TABLE argument. The UDTF then receives one input attribute
* for each name in the list, in the order they are listed.
*/
case class FunctionTableSubqueryArgumentExpression(
plan: LogicalPlan,
outerAttrs: Seq[Expression] = Seq.empty,
exprId: ExprId = NamedExpression.newExprId,
partitionByExpressions: Seq[Expression] = Seq.empty,
withSinglePartition: Boolean = false,
orderByExpressions: Seq[SortOrder] = Seq.empty)
orderByExpressions: Seq[SortOrder] = Seq.empty,
selectedInputExpressions: Seq[PythonUDTFSelectedExpression] = Seq.empty)
extends SubqueryExpression(plan, outerAttrs, exprId, Seq.empty, None) with Unevaluable {

assert(!(withSinglePartition && partitionByExpressions.nonEmpty),
Expand Down Expand Up @@ -134,6 +140,19 @@ case class FunctionTableSubqueryArgumentExpression(
child = subquery)
}
}
// If instructed, add a projection to compute the specified input expressions.
if (selectedInputExpressions.nonEmpty) {
val projectList: Seq[NamedExpression] = selectedInputExpressions.map {
case PythonUDTFSelectedExpression(expression: Expression, Some(alias: String)) =>
Alias(expression, alias)()
case PythonUDTFSelectedExpression(a: Attribute, None) =>
a
case PythonUDTFSelectedExpression(other: Expression, None) =>
throw QueryCompilationErrors
.invalidUDTFSelectExpressionFromAnalyzeMethodNeedsAlias(other.sql)
} ++ extraProjectedPartitioningExpressions
subquery = Project(projectList, subquery)
}
Project(Seq(Alias(CreateStruct(subquery.output), "c")()), subquery)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ case class UnresolvedPolymorphicPythonUDTF(
* @param orderByExpressions if non-empty, this contains the list of ordering items that the
* 'analyze' method explicitly indicated that the UDTF call should consume
* the input table rows by
* @param selectedInputExpressions If non-empty, this is a list of expressions that the UDTF is
* specifying for Catalyst to evaluate against the columns in the
* input TABLE argument. In this case, Catalyst will insert a
* projection to evaluate these expressions and return the result to
* the UDTF. The UDTF then receives one input column for each
* expression in the list, in the order they are listed.
* @param pickledAnalyzeResult this is the pickled 'AnalyzeResult' instance from the UDTF, which
* contains all metadata returned by the Python UDTF 'analyze' method
* including the result schema of the function call as well as optional
Expand All @@ -256,6 +262,7 @@ case class PythonUDTFAnalyzeResult(
withSinglePartition: Boolean,
partitionByExpressions: Seq[Expression],
orderByExpressions: Seq[SortOrder],
selectedInputExpressions: Seq[PythonUDTFSelectedExpression],
pickledAnalyzeResult: Array[Byte]) {
/**
* Applies the requested properties from this analysis result to the target TABLE argument
Expand Down Expand Up @@ -291,6 +298,7 @@ case class PythonUDTFAnalyzeResult(
var newWithSinglePartition = t.withSinglePartition
var newPartitionByExpressions = t.partitionByExpressions
var newOrderByExpressions = t.orderByExpressions
var newSelectedInputExpressions = t.selectedInputExpressions
if (withSinglePartition) {
newWithSinglePartition = true
}
Expand All @@ -300,13 +308,30 @@ case class PythonUDTFAnalyzeResult(
if (orderByExpressions.nonEmpty) {
newOrderByExpressions = orderByExpressions
}
if (selectedInputExpressions.nonEmpty) {
newSelectedInputExpressions = selectedInputExpressions
}
t.copy(
withSinglePartition = newWithSinglePartition,
partitionByExpressions = newPartitionByExpressions,
orderByExpressions = newOrderByExpressions)
orderByExpressions = newOrderByExpressions,
selectedInputExpressions = newSelectedInputExpressions)
}
}

/**
* Represents an expression that the UDTF is specifying for Catalyst to evaluate against the
* columns in the input TABLE argument. The UDTF then receives one input column for each expression
* in the list, in the order they are listed.
*
* @param expression the expression that the UDTF is specifying for Catalyst to evaluate against the
* columns in the input TABLE argument
* @param alias If present, this is the alias for the column or expression as visible from the
* UDTF's 'eval' method. This is required if the expression is not a simple column
* reference.
*/
case class PythonUDTFSelectedExpression(expression: Expression, alias: Option[String])

/**
* A place holder used when printing expressions without debugging information such as the
* result id.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,12 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("aliasName" -> aliasName))
}

def invalidUDTFSelectExpressionFromAnalyzeMethodNeedsAlias(expression: String): Throwable = {
new AnalysisException(
errorClass = "UDTF_INVALID_REQUESTED_SELECTED_EXPRESSION_FROM_ANALYZE_METHOD_REQUIRES_ALIAS",
messageParameters = Map("expression" -> expression))
}

def windowAggregateFunctionWithFilterNotSupportedError(): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1030",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import net.razorvine.pickle.Pickler

import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorkerUtils, SpecialLengths}
import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, NullsFirst, NullsLast, PythonUDAF, PythonUDF, PythonUDTF, PythonUDTFAnalyzeResult, SortOrder, UnresolvedPolymorphicPythonUDTF}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, NullsFirst, NullsLast, PythonUDAF, PythonUDF, PythonUDTF, PythonUDTFAnalyzeResult, PythonUDTFSelectedExpression, SortOrder, UnresolvedPolymorphicPythonUDTF}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, NamedParametersSupport, OneRowRelation}
import org.apache.spark.sql.errors.QueryCompilationErrors
Expand Down Expand Up @@ -276,11 +276,24 @@ class UserDefinedPythonTableFunctionAnalyzeRunner(
case 2 => orderBy.append(SortOrder(parsed, direction, NullsLast, Seq.empty))
}
}
// Receive the list of requested input columns to select, if specified.
val numSelectedInputExpressions = dataIn.readInt()
val selectedInputExpressions = ArrayBuffer.empty[PythonUDTFSelectedExpression]
for (_ <- 0 until numSelectedInputExpressions) {
val expressionSql: String = PythonWorkerUtils.readUTF(dataIn)
val parsed: Expression = parser.parseExpression(expressionSql)
val alias: String = PythonWorkerUtils.readUTF(dataIn)
selectedInputExpressions.append(
PythonUDTFSelectedExpression(
parsed,
if (alias.nonEmpty) Some(alias) else None))
}
PythonUDTFAnalyzeResult(
schema = schema,
withSinglePartition = withSinglePartition,
partitionByExpressions = partitionByExpressions.toSeq,
orderByExpressions = orderBy.toSeq,
selectedInputExpressions = selectedInputExpressions.toSeq,
pickledAnalyzeResult = pickledAnalyzeResult)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,85 @@ SELECT * FROM UDTFPartitionByOrderByComplexExpr(TABLE(t2))
[Analyzer test output redacted due to nondeterminism]


-- !query
SELECT * FROM UDTFPartitionByOrderBySelectExpr(TABLE(t2))
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
SELECT * FROM UDTFPartitionByOrderBySelectComplexExpr(TABLE(t2))
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
SELECT * FROM UDTFPartitionByOrderBySelectExprOnlyPartitionColumn(TABLE(t2))
-- !query analysis
[Analyzer test output redacted due to nondeterminism]


-- !query
SELECT * FROM UDTFInvalidSelectExprParseError(TABLE(t2))
-- !query analysis
org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
"errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
"sqlState" : "42703",
"messageParameters" : {
"objectName" : "`unparsable`",
"proposal" : "`t2`.`input`, `partition_by_0`, `t2`.`partition_col`"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 1,
"stopIndex" : 10,
"fragment" : "unparsable"
} ]
}


-- !query
SELECT * FROM UDTFInvalidSelectExprStringValue(TABLE(t2))
-- !query analysis
org.apache.spark.sql.AnalysisException
{
"errorClass" : "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON",
"sqlState" : "38000",
"messageParameters" : {
"msg" : "Failed to evaluate the user-defined table function 'UDTFInvalidSelectExprStringValue' because the static 'analyze' method returned an 'AnalyzeResult' object with the 'select' field set to a value besides a list or tuple of 'SelectedColumn' objects. Please update the table function and then try the query again."
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 15,
"stopIndex" : 57,
"fragment" : "UDTFInvalidSelectExprStringValue(TABLE(t2))"
} ]
}


-- !query
SELECT * FROM UDTFInvalidComplexSelectExprMissingAlias(TABLE(t2))
-- !query analysis
org.apache.spark.sql.AnalysisException
{
"errorClass" : "UDTF_INVALID_REQUESTED_SELECTED_EXPRESSION_FROM_ANALYZE_METHOD_REQUIRES_ALIAS",
"sqlState" : "42802",
"messageParameters" : {
"expression" : "(input + 1)"
},
"queryContext" : [ {
"objectType" : "",
"objectName" : "",
"startIndex" : 15,
"stopIndex" : 65,
"fragment" : "UDTFInvalidComplexSelectExprMissingAlias(TABLE(t2))"
} ]
}


-- !query
SELECT * FROM UDTFInvalidOrderByAscKeyword(TABLE(t2))
-- !query analysis
Expand Down
Loading

0 comments on commit 31f85e5

Please sign in to comment.