Skip to content

Commit

Permalink
use guava ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengruifeng committed Nov 23, 2020
1 parent 7861b7b commit 7dd2b91
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import scala.util.{Sorting, Try}
import scala.util.hashing.byteswap64

import com.github.fommil.netlib.BLAS.{getInstance => blas}
import com.google.common.collect.{Ordering => GuavaOrdering}
import org.apache.hadoop.fs.Path
import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
Expand All @@ -47,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
import org.apache.spark.util.random.XORShiftRandom

Expand Down Expand Up @@ -456,37 +457,35 @@ class ALSModel private[ml] (
num: Int,
blockSize: Int): DataFrame = {
import srcFactors.sparkSession.implicits._
import ALSModel.TopSelector

val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
val partialRecs = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
.as[(Array[Int], Array[Float], Array[Int], Array[Float])]
.mapPartitions { iter =>
var buffer: Array[Float] = null
val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
var selector: TopSelector = null
iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
require(srcMat.length == srcIds.length * rank)
require(dstMat.length == dstIds.length * rank)
val m = srcIds.length
val n = dstIds.length
if (buffer == null || buffer.length < n) {
buffer = Array.ofDim[Float](n)
selector = new TopSelector(buffer)
}

Iterator.tabulate(m) { i =>
// buffer = i-th vec in srcMat * dstMat
BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
srcMat, i * rank, 1, 0.0F, buffer, 0, 1)

pq.clear()
var j = 0
while (j < n) { pq += dstIds(j) -> buffer(j); j += 1 }
val (kDstIds, kScores) = pq.toArray.sortBy(-_._2).unzip
(srcIds(i), kDstIds, kScores)
val indices = selector.selectTopKIndices(Iterator.range(0, n), num)
(srcIds(i), indices.map(dstIds), indices.map(buffer))
}
} ++ {
buffer = null
pq.clear()
selector = null
Iterator.empty
}
}
Expand Down Expand Up @@ -564,6 +563,21 @@ object ALSModel extends MLReadable[ALSModel] {
model
}
}

/**
 * Selects the indices of the largest entries of `values`.
 *
 * The array is captured by reference and is not copied, so a caller may overwrite its
 * contents between calls and reuse the same selector; each call ranks indices by whatever
 * `values` holds at that moment.
 *
 * @param values scores addressed by index; every index handed to [[selectTopKIndices]]
 *               must be a valid position in this array
 */
private[recommendation] class TopSelector(val values: Array[Float]) {
  import scala.collection.JavaConverters._

  // Orders indices by the score they point at. `java.lang.Float.compare` is used directly:
  // it is the same total order the implicit `Ordering[Float]` delegates to, without
  // resolving an implicit on every comparison and without the Scala 2.13 deprecation
  // warning attached to the implicit Float ordering.
  private val indexOrdering = new GuavaOrdering[Int] {
    override def compare(left: Int, right: Int): Int =
      java.lang.Float.compare(values(left), values(right))
  }

  /**
   * Returns up to `k` indices drawn from `iterator`, ranked by their entry in `values`.
   *
   * Per Guava's `Ordering.greatestOf` contract the result is ordered from greatest to
   * least, contains fewer than `k` entries when the iterator is shorter than `k`, and
   * leaves the relative order of equal scores unspecified.
   *
   * @param iterator candidate indices into `values`
   * @param k        maximum number of indices to return
   * @return the selected indices, highest score first
   */
  def selectTopKIndices(iterator: Iterator[Int], k: Int): Array[Int] =
    indexOrdering.greatestOf(iterator.asJava, k).asScala.toArray
}
}

/**
Expand Down

0 comments on commit 7dd2b91

Please sign in to comment.