
[SPARK-33518][ML] Improve performance of ML ALS recommendForAll by GEMV #30468

Closed
wants to merge 6 commits into master from als_rec_opt

Conversation

zhengruifeng
Contributor

@zhengruifeng zhengruifeng commented Nov 23, 2020

What changes were proposed in this pull request?

There has been a lot of work on improving ALS's recommendForAll.

For now, I found that it can be further optimized by:

1, using GEMV and sharing a pre-allocated buffer per task (see the standalone sketch below);

2, using guava.ordering instead of BoundedPriorityQueue;
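
For illustration, a minimal standalone sketch of the per-row GEMV call (toy data, calling netlib-java's F2jBLAS directly; the actual PR code runs this on blocked factors inside mapPartitions and reuses the output buffer across rows):

import com.github.fommil.netlib.F2jBLAS

val rank = 2
val n = 3
// one src (user) factor of length `rank`
val srcVec = Array(1.0f, 2.0f)
// n dst (item) factors, flattened column by column: (1,0), (0,1), (1,1)
val dstMat = Array(1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f)
// pre-allocated output buffer of size n, reusable for every src row in a task
val scores = new Array[Float](n)

// scores = dstMat^T * srcVec, i.e. scores(j) = dot(srcVec, j-th dst factor)
new F2jBLAS().sgemv("T", rank, n, 1.0f, dstMat, 0, rank,
  srcVec, 0, 1, 0.0f, scores, 0, 1)
// scores == Array(1.0f, 2.0f, 3.0f)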

Why are the changes needed?

In my test, using f2jBLAS.sgemv is about 2.3X faster than the existing impl.

Impl | Master | GEMM | GEMV | GEMV + array aggregator | GEMV + guava ordering + array aggregator | GEMV + guava ordering
---- | ------ | ---- | ---- | ----------------------- | ---------------------------------------- | ---------------------
Duration (ms) | 341229 | 363741 | 191201 | 189790 | 148417 | 147222

Does this PR introduce any user-facing change?

No

How was this patch tested?

existing testsuites

@github-actions github-actions bot added the ML label Nov 23, 2020
@zhengruifeng
Contributor Author

dataset: ml-latest/ratings.csv
number of users: 283228
number of items: 53889
number of ratings: 27753444

env: Ubuntu 20.04
blas: f2jBLAS
cmd: bin/spark-shell --driver-memory=64G --conf spark.driver.maxResultSize=10g

train (in 2.4.7):

import org.apache.spark.ml.recommendation._
sc.setLogLevel("OFF")

val df = spark.read.option("header", true).option("inferSchema", "true").csv("/d1/Datasets/ml-latest/ratings.csv")

df.select(countDistinct("userId"), countDistinct("movieId"), count("rating")).head
org.apache.spark.sql.Row = [283228,53889,27753444]

val als = new ALS().setMaxIter(1).setUserCol("userId").setItemCol("movieId").setRatingCol("rating")

val model = als.fit(df)

model.save("/d0/tmp/ml-latest/als-model")

test code:

import org.apache.spark.ml.recommendation._
sc.setLogLevel("OFF")

val model = ALSModel.load("/d0/tmp/ml-latest/als-model")


val start = System.currentTimeMillis;
model.recommendForAllUsers(10).count
model.recommendForAllItems(10).count
val end = System.currentTimeMillis;
end - start

@zhengruifeng
Contributor Author

master:
[screenshots: als_master_jobs_2020_11_23_17_37_03, als_master_exe_2020_11_23_17_37_46]

@zhengruifeng
Contributor Author

zhengruifeng commented Nov 23, 2020

GEMM:

    val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
      .as[(Array[Int], Array[Float], Array[Int], Array[Float])]
      .mapPartitions { iter =>
        var buffer: Array[Float] = null
        val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
        iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
          require(srcMat.length == srcIds.length * rank)
          require(dstMat.length == dstIds.length * rank)
          val m = srcIds.length
          val n = dstIds.length
          if (buffer == null || buffer.length < m * n) {
            buffer = Array.ofDim[Float](m * n)
          }

          BLAS.f2jBLAS.sgemm("T", "N", m, n, rank, 1.0F,
            srcMat, rank, dstMat, rank, 0.0F, buffer, m)
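          // buffer is the m*n score matrix in column-major order:
          // buffer(i + j * m) = dot(i-th src factor, j-th dst factor)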

          Iterator.range(0, m).flatMap { i =>
            val srcId = srcIds(i)
            pq.clear()
            var j = 0
            while (j < n) { pq += dstIds(j) -> buffer(i + j * m); j += 1 }
            pq.iterator.map { case (dstId, value) => (srcId, dstId, value) }
          }
        } ++ {
          buffer = null
          pq.clear()
          Iterator.empty
        }
      }

Different from previous impls (like #18624) which used flatMap, this one uses mapPartitions so the buffer (of size m*n) can be reused within a task.

[screenshots: als_gemm_jobs_2020_11_23_17_08_49, als_gemm_exe_2020_11_23_17_09_43]

However, it is still slower than the existing impl.

@zhengruifeng
Contributor Author

zhengruifeng commented Nov 23, 2020

GEMV:

val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
      .as[(Array[Int], Array[Float], Array[Int], Array[Float])]
      .mapPartitions { iter =>
        var buffer: Array[Float] = null
        val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
        iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
          require(srcMat.length == srcIds.length * rank)
          require(dstMat.length == dstIds.length * rank)
          val m = srcIds.length
          val n = dstIds.length
          if (buffer == null || buffer.length < n) {
            buffer = Array.ofDim[Float](n)
          }

          Iterator.range(0, m).flatMap { i =>
            BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
              srcMat, i * rank, 1, 0.0F, buffer, 0, 1)
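            // buffer(j) = dot(i-th src factor, j-th dst factor), for j in [0, n)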

            pq.clear()
            var j = 0
            while (j < n) { pq += dstIds(j) -> buffer(j); j += 1 }
            val srcId = srcIds(i)
            pq.iterator.map { case (dstId, value) => (srcId, dstId, value) }
          }
        } ++ {
          buffer = null
          pq.clear()
          Iterator.empty
        }
      }

Then I switched to GEMV, which brings a significant speedup. The size of the buffer is reduced to n.

[screenshots: als_gemv_jobs_2020_11_23_16_55_10, als_gemv_exe_2020_11_23_16_57_13]

@zhengruifeng
Contributor Author

zhengruifeng commented Nov 23, 2020

Then I tried to directly aggregate on the arrays (containing the topK elements) obtained from each GEMV, but it does not bring a visible improvement.

[screenshots: als_gemv_array_pq_jobs_2020_11_23_18_46_30, als_gemv_array_pq_exe_2020_11_23_18_48_18]

@zhengruifeng
Contributor Author

zhengruifeng commented Nov 23, 2020

Finally, I found that Guava.Ordering seems much more efficient than BoundedPriorityQueue (see "Selecting top k items from a list efficiently in Java / Groovy").

  import com.google.common.collect.{Ordering => GuavaOrdering}

  /** select top indices based on values. */
  private[recommendation] class TopSelector(val values: Array[Float]) {
    import scala.collection.JavaConverters._

    private val indexOrdering = new GuavaOrdering[Int] {
      override def compare(left: Int, right: Int): Int = {
        Ordering[Float].compare(values(left), values(right))
      }
    }

    def selectTopKIndices(iterator: Iterator[Int], k: Int): Array[Int] = {
      indexOrdering.greatestOf(iterator.asJava, k).asScala.toArray
    }
  }

Compared to BoundedPriorityQueue, we do not need to create many object references like Tuple2 here.
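
For illustration, a toy usage of the selector above (hypothetical values, not from the patch):

val scores = Array(0.3f, 0.9f, 0.1f, 0.7f)
val selector = new TopSelector(scores)
// only plain Int indices are ordered by their scores; no (Int, Float) tuples are allocated
val topIdx = selector.selectTopKIndices(Iterator.range(0, scores.length), 2)
// topIdx == Array(1, 3): indices of the two largest scores, greatest first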

[screenshots: als_gemv_array_jobs_2020_11_23_16_44_39, als_gemv_array_exe_2020_11_23_16_44_15]

@zhengruifeng
Contributor Author

zhengruifeng commented Nov 23, 2020

It looks like:
1, GEMM is only about 7% slower than master; I guess it could be further accelerated via a native BLAS impl. But it needs a big buffer (m*n), which I think is somewhat dangerous; maybe we could split a block (whose size is optimized for crossJoin) into sub-blocks (whose size is optimized for GEMM) to reduce this buffer, but I think that would be too convoluted;
2, Compared with DOT-based impls, GEMV should be a nice choice: it is much faster (even with f2jBLAS), and the buffer size is relatively small (n);
3, Guava.Ordering is much faster than BoundedPriorityQueue. With Guava.Ordering, we do not need to create Tuple2 objects.

The above tests were done locally, since I do not have a clean cluster for now.
Only f2jBLAS is used, since after upgrading to Ubuntu 20.04 I have so far failed to link netlib-java to native impls.

friendly ping @srowen @MLnick @mpjlu @jkbradley @mengxr @WeichenXu123, because of your comments in previous PRs (#17742, #17845, #18624)

@SparkQA

SparkQA commented Nov 23, 2020

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/36149/

@SparkQA

SparkQA commented Nov 23, 2020

Test build #131546 has finished for PR 30468 at commit 7dd2b91.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Nov 23, 2020

Kubernetes integration test status success
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/36149/

b1: (Array[Int], Array[Float]),
b2: (Array[Int], Array[Float])): (Array[Int], Array[Float]) = {
val (ids1, scores1) = b1
val (ids2, socres2) = b2
Member

Nit: typo in socres2

}
pq.clear()
} ++ {
Member

I'm a little confused by this part - why null these out? the flatMap in which they are declared is done here. Maybe I misread

Contributor Author

@zhengruifeng zhengruifeng Nov 24, 2020

It is to make sure the buffer is marked as ready for GC, but it doesn't matter; I will remove it.

@@ -551,6 +563,21 @@ object ALSModel extends MLReadable[ALSModel] {
model
}
}

/** select top indices based on values. */
private[recommendation] class TopSelector(val values: Array[Float]) {
Member

Does this need to be a class? looks like this code is called once. May be less code/indirection

Contributor Author

I will consider whether a class is needed.

* Works on rows of the form (SrcId, DstIds, Scores).
* Finds the top `num` DstIds and Scores.
*/
private[recommendation] class TopKArrayAggregator(num: Int)
Member

I don't love the copy-paste - are the other uses of the class above able to use your new idea?

Contributor Author

Since I just found that aggregating on arrays brings little improvement, I will remove this class.

@SparkQA

SparkQA commented Nov 24, 2020

Test build #131585 has finished for PR 30468 at commit 7dd2b91.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

back to TopByKeyAggregator
@zhengruifeng
Contributor Author

commit 543a41f:

scala> val start = System.currentTimeMillis;
start: Long = 1606185191500

scala> model.recommendForAllUsers(10).count
res1: Long = 283228

scala> model.recommendForAllItems(10).count
res2: Long = 53889

scala> val end = System.currentTimeMillis;
end: Long = 1606185338722

scala> end - start
res3: Long = 147222

I also tried using BoundedPriorityQueue[Int] instead of BoundedPriorityQueue[(Int, Float)]; it is faster than commit b645968, but still slower than using GuavaOrdering.

val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
      .as[(Array[Int], Array[Float], Array[Int], Array[Float])]
      .mapPartitions { iter =>
        var buffer: Array[Float] = null
        var pq: BoundedPriorityQueue[Int] = null
        iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
          require(srcMat.length == srcIds.length * rank)
          require(dstMat.length == dstIds.length * rank)
          val m = srcIds.length
          val n = dstIds.length
          if (buffer == null || buffer.length < n) {
            buffer = Array.ofDim[Float](n)
            pq = new BoundedPriorityQueue[Int](num)(Ordering.by(buffer.apply))
          }

          Iterator.range(0, m).flatMap { i =>
            // buffer = i-th vec in srcMat * dstMat
            BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
              srcMat, i * rank, 1, 0.0F, buffer, 0, 1)

            pq.clear()
            pq ++= Iterator.range(0, n)
            val srcId = srcIds(i)
            pq.iterator.map { j => (srcId, dstIds(j), buffer(j)) }
          }
        }
      }

scala> val start = System.currentTimeMillis;
start: Long = 1606187052784

scala> model.recommendForAllUsers(10).count
res1: Long = 283228

scala> model.recommendForAllItems(10).count
res2: Long = 53889

scala> val end = System.currentTimeMillis;
end: Long = 1606187220213

scala> end - start
res3: Long = 167429

@SparkQA

SparkQA commented Nov 24, 2020

Test build #131623 has finished for PR 30468 at commit 543a41f.

  • This patch fails SparkR unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng
Contributor Author

retest this please

@SparkQA

SparkQA commented Nov 24, 2020

Test build #131626 has finished for PR 30468 at commit 543a41f.

  • This patch fails due to an unknown error code, -9.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Nov 24, 2020

Test build #131629 has finished for PR 30468 at commit 8ca7d56.

  • This patch fails SparkR unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Member

@srowen srowen left a comment

Looks OK pending tests

@srowen
Member

srowen commented Nov 24, 2020

Jenkins retest this please

@SparkQA

SparkQA commented Nov 24, 2020

Test build #131684 has finished for PR 30468 at commit 8ca7d56.

  • This patch fails SparkR unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@zhengruifeng
Contributor Author

retest this please

@SparkQA

SparkQA commented Nov 25, 2020

Test build #131711 has finished for PR 30468 at commit 8ca7d56.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen
Member

srowen commented Nov 29, 2020

I think it's OK to merge when you're ready

@zhengruifeng zhengruifeng changed the title [SPARK-33518][ML][WIP] Improve performance of ML ALS recommendForAll by GEMV [SPARK-33518][ML] Improve performance of ML ALS recommendForAll by GEMV Nov 30, 2020
@zhengruifeng
Contributor Author

Performance test between GuavaOrdering and BoundedPriorityQueue:

test("BoundedPriorityQueue vs GuavaOrdering") {
    import scala.collection.JavaConverters._
    import scala.util.Random
    import com.google.common.collect.{Ordering => GuavaOrdering}
    import org.apache.spark.util.BoundedPriorityQueue

    for (n <- Seq(512, 1024, 2048, 4096, 8192); k <- Seq(5, 10, 20, 40, 80)) {
      val rng = new Random(123)
      val indices = Array.range(0, n)
      val values = Array.fill(n)(rng.nextFloat)
      val zipped = indices.zip(values)

      val pq0 = new BoundedPriorityQueue[(Int, Float)](k)(Ordering.by(_._2))
      val pq1 = new BoundedPriorityQueue[Int](k)(Ordering.by(values.apply))
      val ord0 = GuavaOrdering.from(Ordering[(Int, Float)])
      val ord1 = new GuavaOrdering[Int] {
        override def compare(left: Int, right: Int): Int = {
          Ordering[Float].compare(values(left), values(right))
        }
      }

      val tic0 = System.currentTimeMillis()
      (0 until 100000).foreach { i =>
        pq0.clear()
        var j = 0
        while (j < n) { pq0 += indices(j) -> values(j); j += 1 }
        val res0 = pq0.iterator.size
      }
      val toc0 = System.currentTimeMillis()

      val tic1 = System.currentTimeMillis()
      (0 until 100000).foreach { i =>
        pq1.clear()
        pq1 ++= Iterator.range(0, n)
        val res1 = pq1.iterator.map(j => (indices(j), values(j))).size
      }
      val toc1 = System.currentTimeMillis()

      val tic2 = System.currentTimeMillis()
      (0 until 100000).foreach { i =>
        val res2 = ord0.greatestOf(zipped.iterator.asJava, k).asScala.iterator.size
      }
      val toc2 = System.currentTimeMillis()

      val tic3 = System.currentTimeMillis()
      (0 until 100000).foreach { i =>
        val res3 = ord1.greatestOf(Iterator.range(0, n).asJava, k).asScala
          .iterator.map(j => (indices(j), values(j))).size
      }
      val toc3 = System.currentTimeMillis()

      println(s"n=$n, k=$k:" +
        s" pq0=${toc0 - tic0}," +
        s" pq1=${toc1 - tic1}," +
        s" ord0=${toc2 - tic2}," +
        s" ord1=${toc3 - tic3}")
    }
  }

results:

n=512, k=5: pq0=1040, pq1=823, ord0=1236, ord1=726
n=512, k=10: pq0=1929, pq1=1076, ord0=2548, ord1=1066
n=512, k=20: pq0=1506, pq1=1366, ord0=1763, ord1=968
n=512, k=40: pq0=2257, pq1=1940, ord0=1658, ord1=1261
n=512, k=80: pq0=3583, pq1=3161, ord0=1673, ord1=2087
n=1024, k=5: pq0=1573, pq1=1545, ord0=3019, ord1=1110
n=1024, k=10: pq0=1787, pq1=1714, ord0=3248, ord1=1298
n=1024, k=20: pq0=2369, pq1=2282, ord0=3547, ord1=1580
n=1024, k=40: pq0=3480, pq1=2978, ord0=3230, ord1=1961
n=1024, k=80: pq0=5526, pq1=4695, ord0=3292, ord1=3305
n=2048, k=5: pq0=2962, pq1=2983, ord0=5968, ord1=2098
n=2048, k=10: pq0=3195, pq1=3129, ord0=6568, ord1=2317
n=2048, k=20: pq0=4016, pq1=3879, ord0=6590, ord1=2743
n=2048, k=40: pq0=5231, pq1=4729, ord0=6524, ord1=3148
n=2048, k=80: pq0=7875, pq1=6909, ord0=6375, ord1=4417
n=4096, k=5: pq0=5819, pq1=5868, ord0=11899, ord1=4077
n=4096, k=10: pq0=6066, pq1=6051, ord0=12956, ord1=4336
n=4096, k=20: pq0=6958, pq1=6777, ord0=13209, ord1=4755
n=4096, k=40: pq0=8632, pq1=8163, ord0=13133, ord1=5439
n=4096, k=80: pq0=12020, pq1=10876, ord0=12785, ord1=7215
n=8192, k=5: pq0=11553, pq1=11491, ord0=23747, ord1=8105
n=8192, k=10: pq0=11761, pq1=11926, ord0=25972, ord1=8291
n=8192, k=20: pq0=12752, pq1=12602, ord0=26304, ord1=8766
n=8192, k=40: pq0=14854, pq1=14401, ord0=25708, ord1=9817
n=8192, k=80: pq0=18953, pq1=17644, ord0=25601, ord1=11764

It seems that ord1 (using guava.ordering to select the topK by indices; greatestOf uses a quickselect-style partial selection, documented as taking O(n + k log k) time) is consistently faster than the existing impl pq0, by about 40%~100%.

@srowen
Member

srowen commented Nov 30, 2020

Seems fine

@zhengruifeng
Contributor Author

I will leave this PR open for now, and will merge it to 3.2.0 if there are no more comments.

@srowen
Member

srowen commented Dec 17, 2020

I think it's fine to merge - should I?

@zhengruifeng
Contributor Author

@srowen I think it is ready to merge.

@srowen
Member

srowen commented Dec 18, 2020

Jenkins retest this please

@SparkQA

SparkQA commented Dec 18, 2020

Kubernetes integration test starting
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37580/

@SparkQA

SparkQA commented Dec 18, 2020

Test build #132979 has finished for PR 30468 at commit 8ca7d56.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Dec 18, 2020

Kubernetes integration test status failure
URL: https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder-K8s/37580/

@srowen
Member

srowen commented Dec 19, 2020

Merged to master (3.2.0)

@zhengruifeng zhengruifeng deleted the als_rec_opt branch December 21, 2020 01:30
zhengruifeng added a commit that referenced this pull request Jan 27, 2021
### What changes were proposed in this pull request?
1, update related doc;
2, MatrixFactorizationModel uses GEMV;

### Why are the changes needed?
see performance gain in #30468

### Does this PR introduce _any_ user-facing change?
NO

### How was this patch tested?
existing testsuites

Closes #31279 from zhengruifeng/als_follow_up.

Authored-by: Ruifeng Zheng <ruifengz@foxmail.com>
Signed-off-by: Ruifeng Zheng <ruifengz@foxmail.com>
skestle pushed a commit to skestle/spark that referenced this pull request Feb 3, 2021