[SPARK-42746][SQL] Implement LISTAGG function
### What changes were proposed in this pull request?

Implement new aggregation function `listagg([ALL | DISTINCT] expr[, sep]) [WITHIN GROUP (ORDER BY key [ASC | DESC] [,...])]`
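
A minimal usage sketch of the new syntax (the table `t` and its columns are hypothetical):

```scala
spark.sql("SELECT listagg(name) FROM t")                 // concatenate non-null values
spark.sql("SELECT listagg(name, ', ') FROM t")           // custom separator
spark.sql("SELECT listagg(DISTINCT name, ', ') FROM t")  // distinct values only
spark.sql("SELECT listagg(name, ', ') WITHIN GROUP (ORDER BY name DESC) FROM t") // ordered
```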

### Why are the changes needed?

Listagg is a popular function implemented by many other vendors. For now, users have to use workarounds like [this](https://kb.databricks.com/sql/recreate-listagg-functionality-with-spark-sql). This PR closes the gap.

### Does this PR introduce _any_ user-facing change?

Yes, the new `listagg` function. BigQuery and PostgreSQL offer the same function under the name `string_agg`, so I added that name as an alias.
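
A quick sketch of the alias (hypothetical table `t`):

```scala
// Both queries are expected to return the same result.
spark.sql("SELECT listagg(name, ', ') FROM t")
spark.sql("SELECT string_agg(name, ', ') FROM t")
```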

### How was this patch tested?

With new unit tests

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: GitHub Copilot

Closes #48748 from mikhailnik-db/SPARK-42746-add-listagg.

Lead-authored-by: Mikhail Nikoliukin <mikhail.nikoliukin@databricks.com>
Co-authored-by: Jia Fan <fanjiaeminem@qq.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
2 people authored and cloud-fan committed Nov 29, 2024
1 parent d5b534d commit 4b97e11
Showing 29 changed files with 1,655 additions and 80 deletions.
@@ -135,27 +135,57 @@ public static byte[] subStringSQL(byte[] bytes, int pos, int len) {
return Arrays.copyOfRange(bytes, start, end);
}

/**
* Concatenate multiple byte arrays into one.
* If one of the inputs is null then null will be returned.
*
* @param inputs byte arrays to concatenate
* @return the concatenated byte array or null if one of the arguments is null
*/
public static byte[] concat(byte[]... inputs) {
return concatWS(EMPTY_BYTE, inputs);
}

/**
* Concatenate multiple byte arrays with a given delimiter.
* If the delimiter or one of the inputs is null then null will be returned.
*
* @param delimiter byte array to be placed between each input
* @param inputs byte arrays to concatenate
* @return the concatenated byte array or null if one of the arguments is null
*/
public static byte[] concatWS(byte[] delimiter, byte[]... inputs) {
if (delimiter == null) {
return null;
}
// Compute the total length of the result
long totalLength = 0;
for (byte[] input : inputs) {
if (input != null) {
totalLength += input.length;
totalLength += input.length + delimiter.length;
} else {
return null;
}
}

// Every non-null input contributed one delimiter; drop the one counted after the last input.
if (totalLength > 0) totalLength -= delimiter.length;
// Allocate a new byte array, and copy the inputs one by one into it
final byte[] result = new byte[Ints.checkedCast(totalLength)];
int offset = 0;
for (byte[] input : inputs) {
for (int i = 0; i < inputs.length; i++) {
byte[] input = inputs[i];
int len = input.length;
Platform.copyMemory(
input, Platform.BYTE_ARRAY_OFFSET,
result, Platform.BYTE_ARRAY_OFFSET + offset,
len);
offset += len;
if (delimiter.length > 0 && i < inputs.length - 1) {
Platform.copyMemory(
delimiter, Platform.BYTE_ARRAY_OFFSET,
result, Platform.BYTE_ARRAY_OFFSET + offset,
delimiter.length);
offset += delimiter.length;
}
}
return result;
}
@@ -67,4 +67,59 @@ public void testCompareBinary() {
byte[] y4 = new byte[]{(byte) 100, (byte) 200};
Assertions.assertEquals(0, ByteArray.compareBinary(x4, y4));
}

@Test
public void testConcat() {
byte[] x1 = new byte[]{(byte) 1, (byte) 2, (byte) 3};
byte[] y1 = new byte[]{(byte) 4, (byte) 5, (byte) 6};
byte[] result1 = ByteArray.concat(x1, y1);
byte[] expected1 = new byte[]{(byte) 1, (byte) 2, (byte) 3, (byte) 4, (byte) 5, (byte) 6};
Assertions.assertArrayEquals(expected1, result1);

byte[] x2 = new byte[]{(byte) 1, (byte) 2, (byte) 3};
byte[] y2 = new byte[0];
byte[] result2 = ByteArray.concat(x2, y2);
byte[] expected2 = new byte[]{(byte) 1, (byte) 2, (byte) 3};
Assertions.assertArrayEquals(expected2, result2);

byte[] x3 = new byte[0];
byte[] y3 = new byte[]{(byte) 4, (byte) 5, (byte) 6};
byte[] result3 = ByteArray.concat(x3, y3);
byte[] expected3 = new byte[]{(byte) 4, (byte) 5, (byte) 6};
Assertions.assertArrayEquals(expected3, result3);

byte[] x4 = new byte[]{(byte) 1, (byte) 2, (byte) 3};
byte[] y4 = null;
byte[] result4 = ByteArray.concat(x4, y4);
Assertions.assertArrayEquals(null, result4);
}

@Test
public void testConcatWS() {
byte[] separator = new byte[]{(byte) 42};

byte[] x1 = new byte[]{(byte) 1, (byte) 2, (byte) 3};
byte[] y1 = new byte[]{(byte) 4, (byte) 5, (byte) 6};
byte[] result1 = ByteArray.concatWS(separator, x1, y1);
byte[] expected1 = new byte[]{(byte) 1, (byte) 2, (byte) 3, (byte) 42,
(byte) 4, (byte) 5, (byte) 6};
Assertions.assertArrayEquals(expected1, result1);

byte[] x2 = new byte[]{(byte) 1, (byte) 2, (byte) 3};
byte[] y2 = new byte[0];
byte[] result2 = ByteArray.concatWS(separator, x2, y2);
byte[] expected2 = new byte[]{(byte) 1, (byte) 2, (byte) 3, (byte) 42};
Assertions.assertArrayEquals(expected2, result2);

byte[] x3 = new byte[0];
byte[] y3 = new byte[]{(byte) 4, (byte) 5, (byte) 6};
byte[] result3 = ByteArray.concatWS(separator, x3, y3);
byte[] expected3 = new byte[]{(byte) 42, (byte) 4, (byte) 5, (byte) 6};
Assertions.assertArrayEquals(expected3, result3);

byte[] x4 = new byte[]{(byte) 1, (byte) 2, (byte) 3};
byte[] y4 = null;
byte[] result4 = ByteArray.concatWS(separator, x4, y4);
Assertions.assertArrayEquals(null, result4);
}
}
51 changes: 28 additions & 23 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -2627,29 +2627,6 @@
],
"sqlState" : "22006"
},
"INVALID_INVERSE_DISTRIBUTION_FUNCTION" : {
"message" : [
"Invalid inverse distribution function <funcName>."
],
"subClass" : {
"DISTINCT_UNSUPPORTED" : {
"message" : [
"Cannot use DISTINCT with WITHIN GROUP."
]
},
"WITHIN_GROUP_MISSING" : {
"message" : [
"WITHIN GROUP is required for inverse distribution function."
]
},
"WRONG_NUM_ORDERINGS" : {
"message" : [
"Requires <expectedNum> orderings in WITHIN GROUP but got <actualNum>."
]
}
},
"sqlState" : "42K0K"
},
"INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME" : {
"message" : [
"<fieldName> is not a valid identifier of Java and cannot be used as field name",
@@ -3364,6 +3341,34 @@
],
"sqlState" : "42601"
},
"INVALID_WITHIN_GROUP_EXPRESSION" : {
"message" : [
"Invalid function <funcName> with WITHIN GROUP."
],
"subClass" : {
"DISTINCT_UNSUPPORTED" : {
"message" : [
"The function does not support DISTINCT with WITHIN GROUP."
]
},
"MISMATCH_WITH_DISTINCT_INPUT" : {
"message" : [
"The function is invoked with DISTINCT and WITHIN GROUP but expressions <funcArg> and <orderingExpr> do not match. The WITHIN GROUP ordering expression must be picked from the function inputs."
]
},
"WITHIN_GROUP_MISSING" : {
"message" : [
"WITHIN GROUP is required for the function."
]
},
"WRONG_NUM_ORDERINGS" : {
"message" : [
"The function requires <expectedNum> orderings in WITHIN GROUP but got <actualNum>."
]
}
},
"sqlState" : "42K0K"
},
"INVALID_WRITER_COMMIT_MESSAGE" : {
"message" : [
"The data source writer has generated an invalid number of commit messages. Expected exactly one writer commit message from each task, but received <detail>."
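
Regarding the new `INVALID_WITHIN_GROUP_EXPRESSION` conditions above, a hedged sketch of the `MISMATCH_WITH_DISTINCT_INPUT` case (table `t` and columns `a`, `b` are hypothetical):

```scala
// Fine: with DISTINCT, the ordering expression is one of the function inputs.
spark.sql("SELECT listagg(DISTINCT a, ',') WITHIN GROUP (ORDER BY a) FROM t")
// Expected to fail with MISMATCH_WITH_DISTINCT_INPUT: `b` is not a function input.
spark.sql("SELECT listagg(DISTINCT a, ',') WITHIN GROUP (ORDER BY b) FROM t")
```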
8 changes: 7 additions & 1 deletion python/pyspark/sql/tests/test_functions.py
@@ -83,7 +83,13 @@ def test_function_parity(self):
missing_in_py = jvm_fn_set.difference(py_fn_set)

# Functions that we expect to be missing in python until they are added to pyspark
expected_missing_in_py = set()
expected_missing_in_py = {
# TODO(SPARK-50220): listagg functions will soon be added to PySpark and removed from this list
"listagg_distinct",
"listagg",
"string_agg",
"string_agg_distinct",
}

self.assertEqual(
expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected"
71 changes: 71 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1147,6 +1147,77 @@
*/
def sum_distinct(e: Column): Column = Column.fn("sum", isDistinct = true, e)

/**
* Aggregate function: returns the concatenation of non-null input values.
*
* @group agg_funcs
* @since 4.0.0
*/
def listagg(e: Column): Column = Column.fn("listagg", e)

/**
* Aggregate function: returns the concatenation of non-null input values, separated by the
* delimiter.
*
* @group agg_funcs
* @since 4.0.0
*/
def listagg(e: Column, delimiter: Column): Column = Column.fn("listagg", e, delimiter)

/**
* Aggregate function: returns the concatenation of distinct non-null input values.
*
* @group agg_funcs
* @since 4.0.0
*/
def listagg_distinct(e: Column): Column = Column.fn("listagg", isDistinct = true, e)

/**
* Aggregate function: returns the concatenation of distinct non-null input values, separated by
* the delimiter.
*
* @group agg_funcs
* @since 4.0.0
*/
def listagg_distinct(e: Column, delimiter: Column): Column =
Column.fn("listagg", isDistinct = true, e, delimiter)

/**
* Aggregate function: returns the concatenation of non-null input values. Alias for `listagg`.
*
* @group agg_funcs
* @since 4.0.0
*/
def string_agg(e: Column): Column = Column.fn("string_agg", e)

/**
* Aggregate function: returns the concatenation of non-null input values, separated by the
* delimiter. Alias for `listagg`.
*
* @group agg_funcs
* @since 4.0.0
*/
def string_agg(e: Column, delimiter: Column): Column = Column.fn("string_agg", e, delimiter)

/**
* Aggregate function: returns the concatenation of distinct non-null input values. Alias for
* `listagg`.
*
* @group agg_funcs
* @since 4.0.0
*/
def string_agg_distinct(e: Column): Column = Column.fn("string_agg", isDistinct = true, e)

/**
* Aggregate function: returns the concatenation of distinct non-null input values, separated by
* the delimiter. Alias for `listagg`.
*
* @group agg_funcs
* @since 4.0.0
*/
def string_agg_distinct(e: Column, delimiter: Column): Column =
Column.fn("string_agg", isDistinct = true, e, delimiter)

/**
* Aggregate function: alias for `var_samp`.
*
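
A minimal sketch of the new DataFrame API entry points added above (assumes a DataFrame `df` with a string column `name`; both names are hypothetical):

```scala
import org.apache.spark.sql.functions._

val result = df.agg(
  listagg(col("name")),                     // concatenation, no separator
  listagg(col("name"), lit(", ")),          // explicit separator
  listagg_distinct(col("name"), lit(", ")), // distinct values
  string_agg(col("name"), lit(", "))        // alias of listagg
)
```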
@@ -2772,6 +2772,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ne
case e: Expression if e.foldable =>
e // No need to create an attribute reference if it will be evaluated as a Literal.
case e: SortOrder =>
// For SortOrder, just recursively extract from the child expression.
e.copy(child = extractExpr(e.child))
case e: NamedArgumentExpression =>
// For NamedArgumentExpression, we extract the value and replace it with
// an AttributeReference (with an internal column name, e.g. "_w0").
@@ -24,7 +24,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Median, PercentileCont, PercentileDisc}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ListAgg, Median, PercentileCont, PercentileDisc}
import org.apache.spark.sql.catalyst.optimizer.{BooleanSimplification, DecorrelateInnerQuery, InlineCTE}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -423,10 +423,23 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
"funcName" -> toSQLExpr(wf),
"windowExpr" -> toSQLExpr(w)))

case agg @ AggregateExpression(listAgg: ListAgg, _, _, _, _)
if agg.isDistinct && listAgg.needSaveOrderValue =>
throw QueryCompilationErrors.functionAndOrderExpressionMismatchError(
listAgg.prettyName, listAgg.child, listAgg.orderExpressions)

case w: WindowExpression =>
// Only allow window functions with an aggregate expression or an offset window
// function or a Pandas window UDF.
w.windowFunction match {
case agg @ AggregateExpression(fun: ListAgg, _, _, _, _)
// listagg(...) WITHIN GROUP (ORDER BY ...) OVER (ORDER BY ...) is unsupported
if fun.orderingFilled && (w.windowSpec.orderSpec.nonEmpty ||
w.windowSpec.frameSpecification !=
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)) =>
agg.failAnalysis(
errorClass = "INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC",
messageParameters = Map("aggFunc" -> toSQLExpr(agg.aggregateFunction)))
case agg @ AggregateExpression(
_: PercentileCont | _: PercentileDisc | _: Median, _, _, _, _)
if w.windowSpec.orderSpec.nonEmpty || w.windowSpec.frameSpecification !=
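
A hedged sketch of the window restriction checked above (table `t` and columns are hypothetical):

```scala
// Expected to be allowed: default (unordered, unbounded) window frame.
spark.sql("SELECT listagg(a, ',') WITHIN GROUP (ORDER BY a) OVER (PARTITION BY g) FROM t")
// Expected to fail with INVALID_WINDOW_SPEC_FOR_AGGREGATION_FUNC:
// WITHIN GROUP ordering combined with an ORDER BY in the OVER clause.
spark.sql("SELECT listagg(a, ',') WITHIN GROUP (ORDER BY a) OVER (ORDER BY b) FROM t")
```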
@@ -506,6 +506,8 @@ object FunctionRegistry {
expression[CollectList]("collect_list"),
expression[CollectList]("array_agg", true, Some("3.3.0")),
expression[CollectSet]("collect_set"),
expression[ListAgg]("listagg"),
expression[ListAgg]("string_agg", setAlias = true),
expressionBuilder("count_min_sketch", CountMinSketchAggExpressionBuilder),
expression[BoolAnd]("every", true),
expression[BoolAnd]("bool_and"),
@@ -128,18 +128,15 @@ class FunctionResolution(
numArgs: Int,
u: UnresolvedFunction): Expression = {
func match {
case owg: SupportsOrderingWithinGroup if u.isDistinct =>
throw QueryCompilationErrors.distinctInverseDistributionFunctionUnsupportedError(
owg.prettyName
)
case owg: SupportsOrderingWithinGroup if !owg.isDistinctSupported && u.isDistinct =>
throw QueryCompilationErrors.distinctWithOrderingFunctionUnsupportedError(owg.prettyName)
case owg: SupportsOrderingWithinGroup
if !owg.orderingFilled && u.orderingWithinGroup.isEmpty =>
throw QueryCompilationErrors.inverseDistributionFunctionMissingWithinGroupError(
owg.prettyName
)
if owg.isOrderingMandatory && !owg.orderingFilled && u.orderingWithinGroup.isEmpty =>
throw QueryCompilationErrors.functionMissingWithinGroupError(owg.prettyName)
case owg: SupportsOrderingWithinGroup
if owg.orderingFilled && u.orderingWithinGroup.nonEmpty =>
throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError(
// e.g. mode(expr1) within group (order by expr2) is not supported
throw QueryCompilationErrors.wrongNumOrderingsForFunctionError(
owg.prettyName,
0,
u.orderingWithinGroup.length
@@ -198,7 +195,7 @@
case agg: AggregateFunction =>
// Note: PythonUDAF does not support these advanced clauses.
if (agg.isInstanceOf[PythonUDAF]) checkUnsupportedAggregateClause(agg, u)
// After parse, the inverse distribution functions not set the ordering within group yet.
// After parsing, the functions have not set the ordering within group yet.
val newAgg = agg match {
case owg: SupportsOrderingWithinGroup
if !owg.orderingFilled && u.orderingWithinGroup.nonEmpty =>
@@ -183,14 +183,16 @@ case class Mode(
}

override def orderingFilled: Boolean = child != UnresolvedWithinGroup
override def isOrderingMandatory: Boolean = true
override def isDistinctSupported: Boolean = false

assert(orderingFilled || (!orderingFilled && reverseOpt.isEmpty))

override def withOrderingWithinGroup(orderingWithinGroup: Seq[SortOrder]): AggregateFunction = {
child match {
case UnresolvedWithinGroup =>
if (orderingWithinGroup.length != 1) {
throw QueryCompilationErrors.wrongNumOrderingsForInverseDistributionFunctionError(
throw QueryCompilationErrors.wrongNumOrderingsForFunctionError(
nodeName, 1, orderingWithinGroup.length)
}
orderingWithinGroup.head match {