[SPARK-33518][ML] Improve performance of ML ALS recommendForAll by GEMV #30468
Changes from 4 commits
mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -27,6 +27,7 @@ import scala.util.{Sorting, Try}
 import scala.util.hashing.byteswap64

 import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import com.google.common.collect.{Ordering => GuavaOrdering}
 import org.apache.hadoop.fs.Path
 import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
@@ -47,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
 import org.apache.spark.util.random.XORShiftRandom

@@ -456,34 +457,42 @@ class ALSModel private[ml] (
       num: Int,
       blockSize: Int): DataFrame = {
     import srcFactors.sparkSession.implicits._
+    import ALSModel.TopSelector

     val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
     val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
-    val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
-      .as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])]
-      .flatMap { case (srcIter, dstIter) =>
-        val m = srcIter.size
-        val n = math.min(dstIter.size, num)
-        val output = new Array[(Int, Int, Float)](m * n)
-        var i = 0
-        val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
-        srcIter.foreach { case (srcId, srcFactor) =>
-          dstIter.foreach { case (dstId, dstFactor) =>
-            // We use F2jBLAS which is faster than a call to native BLAS for vector dot product
-            val score = BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1)
-            pq += dstId -> score
-          }
-          pq.foreach { case (dstId, score) =>
-            output(i) = (srcId, dstId, score)
-            i += 1
-          }
-          pq.clear()
-        }
-        output.toSeq
-      }
-    // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
-    val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
-    val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
+    val partialRecs = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
+      .as[(Array[Int], Array[Float], Array[Int], Array[Float])]
+      .mapPartitions { iter =>
+        var buffer: Array[Float] = null
+        var selector: TopSelector = null
+        iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
+          require(srcMat.length == srcIds.length * rank)
+          require(dstMat.length == dstIds.length * rank)
+          val m = srcIds.length
+          val n = dstIds.length
+          if (buffer == null || buffer.length < n) {
+            buffer = Array.ofDim[Float](n)
+            selector = new TopSelector(buffer)
+          }
+
+          Iterator.tabulate(m) { i =>
+            // buffer = i-th vec in srcMat * dstMat
+            BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
+              srcMat, i * rank, 1, 0.0F, buffer, 0, 1)
+            val indices = selector.selectTopKIndices(Iterator.range(0, n), num)
+            (srcIds(i), indices.map(dstIds), indices.map(buffer))
+          }
+        } ++ {
+          buffer = null
+          selector = null
+          Iterator.empty
+        }
+      }
+
+    val aggregator = new TopKArrayAggregator(num)
+    val recs = partialRecs.as[(Int, Array[Int], Array[Float])]
+      .groupByKey(_._1).agg(aggregator.toColumn)
       .toDF("id", "recommendations")

     val arrayType = ArrayType(
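For illustration, here is a minimal, self-contained sketch (not part of the PR) of what the sgemv call above computes. It instantiates com.github.fommil.netlib.F2jBLAS directly, since Spark's BLAS.f2jBLAS wrapper is private; one GEMV scores a single source factor against a whole block of destination factors, replacing n separate sdot calls:

import com.github.fommil.netlib.F2jBLAS

object GemvScoringSketch {
  def main(args: Array[String]): Unit = {
    val blas = new F2jBLAS
    val rank = 2
    // dstMat holds two destination factors back to back, i.e. a
    // rank x n matrix in column-major layout (each column is one item).
    val dstMat = Array(1.0f, 0.0f,   // factor of item 0
                       0.5f, 0.5f)   // factor of item 1
    val n = dstMat.length / rank
    val srcVec = Array(2.0f, 4.0f)   // one source (user) factor
    val scores = new Array[Float](n)
    // scores = dstMat^T * srcVec: the "T" (transpose) form turns the
    // rank x n matrix into n dot products computed in a single call.
    blas.sgemv("T", rank, n, 1.0f, dstMat, 0, rank, srcVec, 0, 1, 0.0f, scores, 0, 1)
    println(scores.mkString(", "))   // prints "2.0, 3.0"
  }
}

The offset arguments (0 and i * rank in the PR) let each source row of the flattened srcMat feed the same call without copying, which is where the buffer reuse in mapPartitions pays off.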
@@ -499,9 +508,12 @@ class ALSModel private[ml] (
    */
   private def blockify(
       factors: Dataset[(Int, Array[Float])],
-      blockSize: Int): Dataset[Seq[(Int, Array[Float])]] = {
+      blockSize: Int): Dataset[(Array[Int], Array[Float])] = {
     import factors.sparkSession.implicits._
-    factors.mapPartitions(_.grouped(blockSize))
+    factors.mapPartitions { iter =>
+      iter.grouped(blockSize)
+        .map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
+    }
   }

 }
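To see what the new blockify layout produces, here is a pure-Scala sketch of the same grouping logic on hypothetical toy data (no SparkSession needed; not from the PR):

// Each block of (id, factor) rows becomes a pair of parallel arrays:
// the ids, and the factors flattened into one contiguous float array,
// which is exactly the layout a single GEMV call can consume.
val factors = Seq(1 -> Array(1f, 2f), 2 -> Array(3f, 4f), 3 -> Array(5f, 6f))
val blockSize = 2
val blocked = factors.grouped(blockSize)
  .map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
  .toSeq
// blocked(0) == (Array(1, 2), Array(1f, 2f, 3f, 4f))
// blocked(1) == (Array(3),    Array(5f, 6f))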
@@ -551,6 +563,21 @@ object ALSModel extends MLReadable[ALSModel] {
       model
     }
   }
+
+  /** select top indices based on values. */
+  private[recommendation] class TopSelector(val values: Array[Float]) {
+    import scala.collection.JavaConverters._
+
+    private val indexOrdering = new GuavaOrdering[Int] {
+      override def compare(left: Int, right: Int): Int = {
+        Ordering[Float].compare(values(left), values(right))
+      }
+    }
+
+    def selectTopKIndices(iterator: Iterator[Int], k: Int): Array[Int] = {
+      indexOrdering.greatestOf(iterator.asJava, k).asScala.toArray
+    }
+  }
 }

 /**

Review comment (on the TopSelector class): Does this need to be a class? It looks like this code is called once; it may be less code/indirection without one.
Author reply: I will consider whether a class is needed.
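A standalone illustration (not from the PR) of the Guava Ordering.greatestOf call that TopSelector relies on; it avoids sorting all n candidate scores (Guava documents it as O(n + k log k) time, keeping only k candidates in memory):

import scala.collection.JavaConverters._
import com.google.common.collect.{Ordering => GuavaOrdering}

val values = Array(0.1f, 0.9f, 0.4f, 0.7f)   // toy scores
// Compare candidate indices by the score they point at.
val byScore = new GuavaOrdering[Int] {
  override def compare(left: Int, right: Int): Int =
    Ordering[Float].compare(values(left), values(right))
}
// greatestOf returns the k greatest elements, best first.
val top2 = byScore.greatestOf(Iterator.range(0, values.length).asJava, 2).asScala
// top2 contains indices 1 and 3: the positions of 0.9f and 0.7f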
mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala
@@ -57,3 +57,57 @@ private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: Ty

   override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
 }
+
+
+/**
+ * Works on rows of the form (SrcId, DstIds, Scores).
+ * Finds the top `num` DstIds and Scores.
+ */
+private[recommendation] class TopKArrayAggregator(num: Int)
+  extends Aggregator[
+    (Int, Array[Int], Array[Float]),
+    (Array[Int], Array[Float]),
+    Array[(Int, Float)]] {
+
+  override def zero: (Array[Int], Array[Float]) = {
+    (Array.emptyIntArray, Array.emptyFloatArray)
+  }
+
+  override def reduce(
+      b: (Array[Int], Array[Float]),
+      a: (Int, Array[Int], Array[Float])): (Array[Int], Array[Float]) = {
+    merge(b, (a._2, a._3))
+  }
+
+  def merge(
+      b1: (Array[Int], Array[Float]),
+      b2: (Array[Int], Array[Float])): (Array[Int], Array[Float]) = {
+    val (ids1, scores1) = b1
+    val (ids2, socres2) = b2
+    if (ids1.isEmpty) {
+      b2
+    } else if (ids2.isEmpty) {
+      b1
+    } else {
+      val len1 = ids1.length
+      val len2 = ids2.length
+      val indices = Array.range(0, len1 + len2)
+        .sortBy(i => if (i < len1) -scores1(i) else -socres2(i - len1))
+        .take(num)
+      (indices.map(i => if (i < len1) ids1(i) else ids2(i - len1)),
+        indices.map(i => if (i < len1) scores1(i) else socres2(i - len1)))
+    }
+  }
+
+  override def finish(reduction: (Array[Int], Array[Float])): Array[(Int, Float)] = {
+    reduction._1.zip(reduction._2)
+  }
+
+  override def bufferEncoder: Encoder[(Array[Int], Array[Float])] = {
+    Encoders.kryo[(Array[Int], Array[Float])]
+  }
+
+  override def outputEncoder: Encoder[Array[(Int, Float)]] = {
+    ExpressionEncoder[Array[(Int, Float)]]()
+  }
+}

Review comment (on TopKArrayAggregator): I don't love the copy-paste. Are the other uses of the class above able to use your new idea?
Author reply: Since I find that aggregating on arrays brings little improvement, I will remove this class.

Review comment (on the merge method): Nit: typo in socres2.
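A worked trace (toy data, not from the PR) of the merge step above, which indexes the concatenation of two partial buffers, ranks by descending score, and keeps the best num:

val num = 2
// Two partial top-k buffers arriving from different blocks.
val (ids1, scores1) = (Array(10, 11), Array(0.9f, 0.2f))
val (ids2, scores2) = (Array(20, 21), Array(0.5f, 0.8f))
val len1 = ids1.length
// Index the virtual concatenation [b1 ++ b2], sort by negated score
// (i.e. descending), and keep the first `num` positions.
val indices = Array.range(0, len1 + ids2.length)
  .sortBy(i => if (i < len1) -scores1(i) else -scores2(i - len1))
  .take(num)
val mergedIds = indices.map(i => if (i < len1) ids1(i) else ids2(i - len1))
val mergedScores = indices.map(i => if (i < len1) scores1(i) else scores2(i - len1))
// mergedIds == Array(10, 21), mergedScores == Array(0.9f, 0.8f)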
Review comment (on the `buffer = null` / `selector = null` lines in ALS.scala above): I'm a little confused by this part. Why null these out? The flatMap in which they are declared is done here. Maybe I misread.
Author reply: It is to make sure `buffer` is marked ready for GC, but it doesn't matter; I will remove it.
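For context on the pattern being discussed, a small sketch (not from the PR) of why appending a block with `++` can run cleanup code at all: Iterator.++ takes its right operand by name, so the block executes only once the left iterator is exhausted:

// The right-hand side of ++ is evaluated lazily, after the left
// iterator is fully consumed, so it can release captured references.
var big: Array[Int] = Array.fill(3)(1)
val it = Iterator(1, 2, 3).map(_ + big(0)) ++ {
  big = null   // runs once the first iterator is exhausted
  Iterator.empty
}
println(it.sum)        // 9; forces full traversal
println(big == null)   // true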