Skip to content

Commit

Permalink
back to TopByKeyAggregator
Browse files Browse the repository at this point in the history
back to TopByKeyAggregator
  • Loading branch information
zhengruifeng committed Nov 24, 2020
1 parent 7dd2b91 commit 543a41f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 66 deletions.
21 changes: 9 additions & 12 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ class ALSModel private[ml] (

val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
val partialRecs = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
.as[(Array[Int], Array[Float], Array[Int], Array[Float])]
.mapPartitions { iter =>
var buffer: Array[Float] = null
Expand All @@ -476,23 +476,20 @@ class ALSModel private[ml] (
selector = new TopSelector(buffer)
}

Iterator.tabulate(m) { i =>
Iterator.range(0, m).flatMap { i =>
// buffer = i-th vec in srcMat * dstMat
BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
srcMat, i * rank, 1, 0.0F, buffer, 0, 1)
val indices = selector.selectTopKIndices(Iterator.range(0, n), num)
(srcIds(i), indices.map(dstIds), indices.map(buffer))

val srcId = srcIds(i)
selector.selectTopKIndices(Iterator.range(0, n), num)
.iterator.map { j => (srcId, dstIds(j), buffer(j)) }
}
} ++ {
buffer = null
selector = null
Iterator.empty
}
}

val aggregator = new TopKArrayAggregator(num)
val recs = partialRecs.as[(Int, Array[Int], Array[Float])]
.groupByKey(_._1).agg(aggregator.toColumn)
// We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
.toDF("id", "recommendations")

val arrayType = ArrayType(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,57 +57,3 @@ private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: Ty

override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
}


/**
* Works on rows of the form (ScrId, DstIds, Scores).
* Finds the top `num` DstIds and Scores.
*/
private[recommendation] class TopKArrayAggregator(num: Int)
extends Aggregator[
(Int, Array[Int], Array[Float]),
(Array[Int], Array[Float]),
Array[(Int, Float)]] {

override def zero: (Array[Int], Array[Float]) = {
(Array.emptyIntArray, Array.emptyFloatArray)
}

override def reduce(
b: (Array[Int], Array[Float]),
a: (Int, Array[Int], Array[Float])): (Array[Int], Array[Float]) = {
merge(b, (a._2, a._3))
}

def merge(
b1: (Array[Int], Array[Float]),
b2: (Array[Int], Array[Float])): (Array[Int], Array[Float]) = {
val (ids1, scores1) = b1
val (ids2, socres2) = b2
if (ids1.isEmpty) {
b2
} else if (ids2.isEmpty) {
b1
} else {
val len1 = ids1.length
val len2 = ids2.length
val indices = Array.range(0, len1 + len2)
.sortBy(i => if (i < len1) -scores1(i) else -socres2(i - len1))
.take(num)
(indices.map(i => if (i < len1) ids1(i) else ids2(i - len1)),
indices.map(i => if (i < len1) scores1(i) else socres2(i - len1)))
}
}

override def finish(reduction: (Array[Int], Array[Float])): Array[(Int, Float)] = {
reduction._1.zip(reduction._2)
}

override def bufferEncoder: Encoder[(Array[Int], Array[Float])] = {
Encoders.kryo[(Array[Int], Array[Float])]
}

override def outputEncoder: Encoder[Array[(Int, Float)]] = {
ExpressionEncoder[Array[(Int, Float)]]()
}
}

0 comments on commit 543a41f

Please sign in to comment.