[SPARK-33518][ML] Improve performance of ML ALS recommendForAll by GEMV #30468

Status: Closed (wants to merge 6 commits; changes shown from 4 commits)
mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala (75 changes: 51 additions & 24 deletions)
@@ -27,6 +27,7 @@ import scala.util.{Sorting, Try}
import scala.util.hashing.byteswap64

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.google.common.collect.{Ordering => GuavaOrdering}
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
@@ -47,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
import org.apache.spark.util.random.XORShiftRandom

@@ -456,34 +457,42 @@ class ALSModel private[ml] (
num: Int,
blockSize: Int): DataFrame = {
import srcFactors.sparkSession.implicits._
import ALSModel.TopSelector

val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
.as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])]
.flatMap { case (srcIter, dstIter) =>
val m = srcIter.size
val n = math.min(dstIter.size, num)
val output = new Array[(Int, Int, Float)](m * n)
var i = 0
val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
srcIter.foreach { case (srcId, srcFactor) =>
dstIter.foreach { case (dstId, dstFactor) =>
// We use F2jBLAS which is faster than a call to native BLAS for vector dot product
val score = BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1)
pq += dstId -> score
val partialRecs = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
.as[(Array[Int], Array[Float], Array[Int], Array[Float])]
.mapPartitions { iter =>
var buffer: Array[Float] = null
var selector: TopSelector = null
iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
require(srcMat.length == srcIds.length * rank)
require(dstMat.length == dstIds.length * rank)
val m = srcIds.length
val n = dstIds.length
if (buffer == null || buffer.length < n) {
buffer = Array.ofDim[Float](n)
selector = new TopSelector(buffer)
}
pq.foreach { case (dstId, score) =>
output(i) = (srcId, dstId, score)
i += 1

Iterator.tabulate(m) { i =>
// buffer = i-th vec in srcMat * dstMat
BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
srcMat, i * rank, 1, 0.0F, buffer, 0, 1)
val indices = selector.selectTopKIndices(Iterator.range(0, n), num)
(srcIds(i), indices.map(dstIds), indices.map(buffer))
}
pq.clear()
} ++ {
[Review comment, Member] I'm a little confused by this part - why null these out? The flatMap in which they are declared is done here. Maybe I misread.

[Reply, @zhengruifeng (Contributor Author), Nov 24, 2020] It is to make sure buffer is marked ready for GC, but it doesn't matter; I will remove it.

buffer = null
selector = null
Iterator.empty
}
output.toSeq
}
// We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)

val aggregator = new TopKArrayAggregator(num)
val recs = partialRecs.as[(Int, Array[Int], Array[Float])]
.groupByKey(_._1).agg(aggregator.toColumn)
.toDF("id", "recommendations")

val arrayType = ArrayType(
@@ -499,9 +508,12 @@
*/
private def blockify(
factors: Dataset[(Int, Array[Float])],
blockSize: Int): Dataset[Seq[(Int, Array[Float])]] = {
blockSize: Int): Dataset[(Array[Int], Array[Float])] = {
import factors.sparkSession.implicits._
factors.mapPartitions(_.grouped(blockSize))
factors.mapPartitions { iter =>
iter.grouped(blockSize)
.map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
}
}

}
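For intuition, the new path computes one matrix-vector product per source row (the `sgemv("T", ...)` call, scoring that row against every destination factor at once) and then selects the top-`num` destination indices from the score buffer, instead of one `sdot` call per (src, dst) pair. A minimal NumPy sketch of that logic, with illustrative data only (this is not the Spark API):

```python
import numpy as np

rank, num = 4, 3
rng = np.random.default_rng(0)

# One blockified pair: ids plus a flattened, row-major factor matrix
# (`rank` floats per id), mirroring what blockify produces.
src_ids = np.array([10, 11])
src_mat = rng.standard_normal((len(src_ids), rank)).astype(np.float32)
dst_ids = np.arange(100, 108)
dst_mat = rng.standard_normal((len(dst_ids), rank)).astype(np.float32)

recs = []
for i, src_id in enumerate(src_ids):
    # buffer = dst_mat @ src_vec -- the role of the sgemv call in the diff
    buffer = dst_mat @ src_mat[i]
    # keep the top-`num` indices by score, then order them descending
    # (the role of TopSelector.selectTopKIndices)
    top = np.argpartition(-buffer, num - 1)[:num]
    top = top[np.argsort(-buffer[top])]
    recs.append((src_id, dst_ids[top], buffer[top]))
```

Each output row carries one source id with arrays of recommended destination ids and scores, matching the `(srcIds(i), indices.map(dstIds), indices.map(buffer))` shape in the diff.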
@@ -551,6 +563,21 @@ object ALSModel extends MLReadable[ALSModel] {
model
}
}

/** select top indices based on values. */
private[recommendation] class TopSelector(val values: Array[Float]) {
[Review comment, Member] Does this need to be a class? Looks like this code is called once - may be less code/indirection.

[Reply, Contributor Author] I will consider whether a class is needed.

import scala.collection.JavaConverters._

private val indexOrdering = new GuavaOrdering[Int] {
override def compare(left: Int, right: Int): Int = {
Ordering[Float].compare(values(left), values(right))
}
}

def selectTopKIndices(iterator: Iterator[Int], k: Int): Array[Int] = {
indexOrdering.greatestOf(iterator.asJava, k).asScala.toArray
}
}
}
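TopSelector orders candidate indices by the score stored at each index in a shared values buffer and keeps the k greatest via Guava's `Ordering.greatestOf`, which avoids sorting the whole candidate range. A rough Python equivalent using `heapq` (hypothetical helper name, for illustration only):

```python
import heapq

def select_top_k_indices(values, indices, k):
    # Order candidate indices by the score stored at that index in
    # `values`, keeping the k greatest in descending score order --
    # a partial selection, like Guava's Ordering.greatestOf.
    return heapq.nlargest(k, indices, key=lambda i: values[i])

scores = [0.1, 0.9, 0.4, 0.7]
print(select_top_k_indices(scores, range(4), 2))  # -> [1, 3]
```

Keeping the buffer external to the selector is what lets the diff reuse one allocation across all source rows in a partition.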

@@ -57,3 +57,57 @@ private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag]

override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
}


/**
* Works on rows of the form (SrcId, DstIds, Scores).
* Finds the top `num` DstIds and Scores.
*/
private[recommendation] class TopKArrayAggregator(num: Int)
[Review comment, Member] I don't love the copy-paste - are the other uses of the class above able to use your new idea?

[Reply, Contributor Author] Since I found that aggregating on arrays brings little improvement, I will remove this class.

extends Aggregator[
(Int, Array[Int], Array[Float]),
(Array[Int], Array[Float]),
Array[(Int, Float)]] {

override def zero: (Array[Int], Array[Float]) = {
(Array.emptyIntArray, Array.emptyFloatArray)
}

override def reduce(
b: (Array[Int], Array[Float]),
a: (Int, Array[Int], Array[Float])): (Array[Int], Array[Float]) = {
merge(b, (a._2, a._3))
}

def merge(
b1: (Array[Int], Array[Float]),
b2: (Array[Int], Array[Float])): (Array[Int], Array[Float]) = {
val (ids1, scores1) = b1
val (ids2, socres2) = b2
[Review comment, Member] Nit: typo in socres2.

if (ids1.isEmpty) {
b2
} else if (ids2.isEmpty) {
b1
} else {
val len1 = ids1.length
val len2 = ids2.length
val indices = Array.range(0, len1 + len2)
.sortBy(i => if (i < len1) -scores1(i) else -socres2(i - len1))
.take(num)
(indices.map(i => if (i < len1) ids1(i) else ids2(i - len1)),
indices.map(i => if (i < len1) scores1(i) else socres2(i - len1)))
}
}

override def finish(reduction: (Array[Int], Array[Float])): Array[(Int, Float)] = {
reduction._1.zip(reduction._2)
}

override def bufferEncoder: Encoder[(Array[Int], Array[Float])] = {
Encoders.kryo[(Array[Int], Array[Float])]
}

override def outputEncoder: Encoder[Array[(Int, Float)]] = {
ExpressionEncoder[Array[(Int, Float)]]()
}
}
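The merge step above concatenates the two (ids, scores) buffers and keeps the `num` highest-scoring entries, short-circuiting when either side is empty. A small Python sketch of the same logic (hypothetical function name; the typo flagged by the reviewer is spelled correctly here):

```python
def merge_top_k(b1, b2, num):
    # Mirrors TopKArrayAggregator.merge: concatenate both (ids, scores)
    # buffers and keep the `num` entries with the highest scores.
    ids1, scores1 = b1
    ids2, scores2 = b2
    if not ids1:
        return b2
    if not ids2:
        return b1
    ids = ids1 + ids2
    scores = scores1 + scores2
    # sort candidate positions by descending score, then truncate
    order = sorted(range(len(ids)), key=lambda i: -scores[i])[:num]
    return ([ids[i] for i in order], [scores[i] for i in order])

merge_top_k(([1, 2], [0.9, 0.2]), ([3], [0.5]), 2)
# -> ([1, 3], [0.9, 0.5])
```

Because each partial buffer is already capped at `num` entries, every merge touches at most `2 * num` elements regardless of the total candidate count.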