Skip to content

Commit

Permalink
Added the first version of collectNumberOrderedElements
Browse files Browse the repository at this point in the history
  • Loading branch information
Rahamim, Ben committed Oct 9, 2024
1 parent ef44b6e commit 17e59dd
Showing 1 changed file with 53 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.aggregate.{Collect, DeclarativeAggregate, ImperativeAggregate}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, ExpectsInputTypes, Expression, GreaterThan, If, IsNull, LessThan, Literal}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, Concat, CreateArray, ExpectsInputTypes, Expression, GreaterThan, If, IsNull, LessThan, Literal, Size, Slice, SortArray}
import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, DataType}
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, DataType}

import scala.collection.generic.Growable
import scala.collection.mutable
Expand All @@ -39,10 +40,16 @@ In order to support Spark 3.1.x as well as Spark 3.2.0 and up, the methods withN
object SparkOverwriteUDAFs {
def minValueByKey(key: Column, value: Column): Column =
Column(MinValueByKey(key.expr, value.expr).toAggregateExpression(false))

def maxValueByKey(key: Column, value: Column): Column =
Column(MaxValueByKey(key.expr, value.expr).toAggregateExpression(false))

def collectLimitedList(e: Column, maxSize: Int): Column =
Column(CollectLimitedList(e.expr, howMuchToTake = maxSize).toAggregateExpression(false))

def collectNumberOrderedElements(col: Column, howManyToTake: Int, ascending: Boolean = false) =
Column(CollectNumberOrderedElements(col.expr, lit(howManyToTake).expr, lit(ascending).expr).toAggregateExpression(false))

}

case class MinValueByKey(child1: Expression, child2: Expression)
Expand Down Expand Up @@ -86,7 +93,7 @@ abstract class ExtramumValueByKey(
private lazy val data = AttributeReference("data", child2.dataType)()

override lazy val aggBufferAttributes
: Seq[AttributeReference] = minmax :: data :: Nil
: Seq[AttributeReference] = minmax :: data :: Nil

override lazy val initialValues: Seq[Expression] = Seq(
Literal.create(null, child1.dataType),
Expand Down Expand Up @@ -174,3 +181,46 @@ abstract class LimitedCollect[T <: Growable[Any] with Iterable[Any]](howMuchToTa
}
}
}

case class CollectNumberOrderedElements(child: Expression, howManyToTake: Expression, ascending: Expression) extends DeclarativeAggregate with ExpectsInputTypes {

override def children: Seq[Expression] = Seq(child)

override def nullable: Boolean = true

// Return data type.
override def dataType: DataType = ArrayType(child.dataType, containsNull = false)

// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForOrderingExpr(child.dataType, "function TakeFirstValues")

private lazy val data = AttributeReference("data", ArrayType(child.dataType, containsNull = false))()

override lazy val aggBufferAttributes: Seq[AttributeReference] = data :: Nil

override lazy val initialValues: Seq[Expression] = Seq(
Literal.create(Array(), ArrayType(child.dataType, containsNull = false))
)

// Change to array_append after Spark 3.4.0
override lazy val updateExpressions: Seq[Expression] = sortAndSliceArray(data, CreateArray(Seq(child)))

override lazy val mergeExpressions: Seq[Expression] = sortAndSliceArray(data.right, data.left)

private def sortAndSliceArray(firstArray: Expression, secondArray: Expression) = {
val unifiedArray = Concat(Seq(firstArray, secondArray))
Seq(
If(GreaterThan(Size(unifiedArray), howManyToTake),
Slice(SortArray(unifiedArray, ascending), Literal(1), howManyToTake),
unifiedArray))
}

override lazy val evaluateExpression: AttributeReference = data

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = {
copy(child = newChildren.head)
}
}

0 comments on commit 17e59dd

Please sign in to comment.