diff --git a/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
index d57a6ef9..fb1a26cb 100644
--- a/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
+++ b/datafu-spark/src/main/scala/spark/utils/overwrites/SparkOverwriteUDAFs.scala
@@ -22,9 +22,10 @@ import org.apache.spark.sql.Column
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, DeclarativeAggregate, ImperativeAggregate}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, ExpectsInputTypes, Expression, GreaterThan, If, IsNull, LessThan, Literal}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, Concat, CreateArray, ExpectsInputTypes, Expression, GreaterThan, If, IsNull, LessThan, Literal, Size, Slice, SortArray}
 import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
-import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, DataType}
 
 import scala.collection.generic.Growable
 import scala.collection.mutable
@@ -39,10 +40,16 @@ In order to support Spark 3.1.x as well as Spark 3.2.0 and up, the methods withN
 object SparkOverwriteUDAFs {
   def minValueByKey(key: Column, value: Column): Column =
     Column(MinValueByKey(key.expr, value.expr).toAggregateExpression(false))
+
   def maxValueByKey(key: Column, value: Column): Column =
     Column(MaxValueByKey(key.expr, value.expr).toAggregateExpression(false))
+
   def collectLimitedList(e: Column, maxSize: Int): Column =
     Column(CollectLimitedList(e.expr, howMuchToTake = maxSize).toAggregateExpression(false))
+
+  def collectNumberOrderedElements(col: Column, howManyToTake: Int, ascending: Boolean = false) =
+    Column(CollectNumberOrderedElements(col.expr, lit(howManyToTake).expr, lit(ascending).expr).toAggregateExpression(false))
+
 }
 
 case class MinValueByKey(child1: Expression, child2: Expression)
@@ -86,7 +93,7 @@ abstract class ExtramumValueByKey(
   private lazy val data = AttributeReference("data", child2.dataType)()
 
   override lazy val aggBufferAttributes
-    : Seq[AttributeReference] = minmax :: data :: Nil
+      : Seq[AttributeReference] = minmax :: data :: Nil
 
   override lazy val initialValues: Seq[Expression] = Seq(
     Literal.create(null, child1.dataType),
@@ -174,3 +181,46 @@ abstract class LimitedCollect[T <: Growable[Any] with Iterable[Any]](howMuchToTa
     }
   }
 }
+
+case class CollectNumberOrderedElements(child: Expression, howManyToTake: Expression, ascending: Expression) extends DeclarativeAggregate with ExpectsInputTypes {
+
+  override def children: Seq[Expression] = Seq(child)
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = ArrayType(child.dataType, containsNull = false)
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForOrderingExpr(child.dataType, "function TakeFirstValues")
+
+  private lazy val data = AttributeReference("data", ArrayType(child.dataType, containsNull = false))()
+
+  override lazy val aggBufferAttributes: Seq[AttributeReference] = data :: Nil
+
+  override lazy val initialValues: Seq[Expression] = Seq(
+    Literal.create(Array(), ArrayType(child.dataType, containsNull = false))
+  )
+
+  // Change to array_append after Spark 3.4.0
+  override lazy val updateExpressions: Seq[Expression] = sortAndSliceArray(data, CreateArray(Seq(child)))
+
+  override lazy val mergeExpressions: Seq[Expression] = sortAndSliceArray(data.right, data.left)
+
+  private def sortAndSliceArray(firstArray: Expression, secondArray: Expression) = {
+    val unifiedArray = Concat(Seq(firstArray, secondArray))
+    Seq(
+      If(GreaterThan(Size(unifiedArray), howManyToTake),
+        Slice(SortArray(unifiedArray, ascending), Literal(1), howManyToTake),
+        unifiedArray))
+  }
+
+  override lazy val evaluateExpression: AttributeReference = data
+
+  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
+    copy(child = newChildren.head)
+  }
+}
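
A minimal usage sketch of the new collectNumberOrderedElements aggregate follows. The import path is an assumption (adjust it to the package actually declared in SparkOverwriteUDAFs.scala), and the SparkSession setup and DataFrame are hypothetical examples, not part of the patch:

  import org.apache.spark.sql.SparkSession
  // Assumed import path; match it to the package declared in SparkOverwriteUDAFs.scala.
  import spark.utils.overwrites.SparkOverwriteUDAFs.collectNumberOrderedElements

  val session = SparkSession.builder().master("local[*]").getOrCreate()
  import session.implicits._

  // Hypothetical input: (key, score) pairs.
  val df = Seq(("a", 3), ("a", 1), ("a", 2), ("b", 5)).toDF("key", "score")

  // With the default ascending = false, sortAndSliceArray sorts the combined
  // buffer descending and slices it to the first two elements, so each key
  // keeps its two largest scores.
  df.groupBy("key")
    .agg(collectNumberOrderedElements(df("score"), 2).as("top_scores"))
    .show()

Because both updateExpressions and mergeExpressions go through sortAndSliceArray, the aggregation buffer is trimmed back to howManyToTake elements whenever it grows past that limit, which bounds its size on both the partial- and final-aggregation sides.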