add threshold param to ALS
hqzizania committed Oct 24, 2016
1 parent dc4f4ba commit d29fd67
Showing 2 changed files with 39 additions and 10 deletions.
39 changes: 30 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -200,10 +200,24 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
/** @group expertGetParam */
def getFinalStorageLevel: String = $(finalStorageLevel)

+/**
+ * Param for the threshold used in the computation of dst factors to decide
+ * whether to stack factors to speed up the computation (>= 1).
+ * Default: 1024
+ * @group expertParam
+ */
+val threshold = new IntParam(this, "threshold", "threshold in the computation of dst factors " +
+  "to decide whether to stack factors to speed up the computation.",
+  ParamValidators.gtEq(1))
+
+/** @group expertGetParam */
+def getThreshold: Int = $(threshold)

setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
-  intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
+  intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
+  threshold -> 1024)

/**
* Validates and transforms the input schema.
@@ -436,6 +450,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)

+/** @group expertSetParam */
+@Since("2.1.0")
+def setThreshold(value: Int): this.type = set(threshold, value)
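
For context, a minimal usage sketch of the new setter alongside the existing builder API (the `ratings` DataFrame, with the default "user", "item", and "rating" columns, is assumed to be built elsewhere):

import org.apache.spark.ml.recommendation.ALS

// Lower the stacking threshold so the stacked code path also triggers for
// smaller factor blocks.
val als = new ALS()
  .setRank(256) // stacking also requires rank > threshold (see doStack below)
  .setMaxIter(10)
  .setRegParam(0.01)
  .setThreshold(128)
val model = als.fit(ratings)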

/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
*
@@ -464,14 +482,15 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
val instrLog = Instrumentation.create(this, ratings)
instrLog.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
userCol, itemCol, ratingCol, predictionCol, maxIter,
-  regParam, nonnegative, checkpointInterval, seed)
+  regParam, nonnegative, threshold, checkpointInterval, seed)
val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
alpha = $(alpha), nonnegative = $(nonnegative),
intermediateRDDStorageLevel = StorageLevel.fromString($(intermediateStorageLevel)),
finalRDDStorageLevel = StorageLevel.fromString($(finalStorageLevel)),
-  checkpointInterval = $(checkpointInterval), seed = $(seed))
+  threshold = $(threshold), checkpointInterval = $(checkpointInterval),
+  seed = $(seed))
val userDF = userFactors.toDF("id", "features")
val itemDF = itemFactors.toDF("id", "features")
val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
@@ -706,6 +725,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
nonnegative: Boolean = false,
intermediateRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
finalRDDStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+threshold: Int = 1024,
checkpointInterval: Int = 10,
seed: Long = 0L)(
implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = {
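
For completeness, a sketch of calling this developer-level trainer directly with the new argument (the ratings RDD setup is assumed, and only a few of the defaulted parameters are overridden):

import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.ml.recommendation.ALS.Rating
import org.apache.spark.rdd.RDD

// `ratings` is an assumed RDD[Rating[Int]]; the implicit Ordering[Int]
// in scope satisfies the trainer's `ord` parameter.
def trainWithStacking(ratings: RDD[Rating[Int]]) = ALS.train(
  ratings,
  rank = 2048,     // a large rank, where stacking is expected to pay off
  threshold = 512, // stack once both block size and rank exceed 512
  seed = 42L)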
@@ -752,7 +772,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
userFactors.setName(s"userFactors-$iter").persist(intermediateRDDStorageLevel)
val previousItemFactors = itemFactors
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
-  userLocalIndexEncoder, implicitPrefs, alpha, solver)
+  userLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
previousItemFactors.unpersist()
itemFactors.setName(s"itemFactors-$iter").persist(intermediateRDDStorageLevel)
// TODO: Generalize PeriodicGraphCheckpointer and use it here.
@@ -762,7 +782,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
}
val previousUserFactors = userFactors
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
-  itemLocalIndexEncoder, implicitPrefs, alpha, solver)
+  itemLocalIndexEncoder, implicitPrefs, alpha, solver, threshold)
if (shouldCheckpoint(iter)) {
ALS.cleanShuffleDependencies(sc, deps)
deletePreviousCheckpointFile()
@@ -773,7 +793,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
} else {
for (iter <- 0 until maxIter) {
itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
-  userLocalIndexEncoder, solver = solver)
+  userLocalIndexEncoder, solver = solver, threshold = threshold)
if (shouldCheckpoint(iter)) {
val deps = itemFactors.dependencies
itemFactors.checkpoint()
@@ -783,7 +803,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
previousCheckpointFile = itemFactors.getCheckpointFile
}
userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
-  itemLocalIndexEncoder, solver = solver)
+  itemLocalIndexEncoder, solver = solver, threshold = threshold)
}
}
val userIdAndFactors = userInBlocks
@@ -1297,7 +1317,8 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
srcEncoder: LocalIndexEncoder,
implicitPrefs: Boolean = false,
alpha: Double = 1.0,
-  solver: LeastSquaresNESolver): RDD[(Int, FactorBlock)] = {
+  solver: LeastSquaresNESolver,
+  threshold: Int): RDD[(Int, FactorBlock)] = {
val numSrcBlocks = srcFactorBlocks.partitions.length
val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
@@ -1325,7 +1346,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging {
var numExplicits = 0
// Stack factors (vectors) into matrices to speed up the computation
// when the number of factors and the rank are large enough.
-  val doStack = srcPtrs(j + 1) - srcPtrs(j) > 1024 && rank > 1024
+  val doStack = srcPtrs(j + 1) - srcPtrs(j) > threshold && rank > threshold
val srcFactorBuffer = mutable.ArrayBuilder.make[Double]
val bBuffer = mutable.ArrayBuilder.make[Double]
while (i < srcPtrs(j + 1)) {
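
To make the threshold test concrete, here is a rough standalone sketch of the idea behind doStack (a hypothetical Breeze helper, not the committed solver code): once a destination entity has gathered more than `threshold` source factors and the rank also exceeds `threshold`, it pays to copy the factors into one dense matrix and form the Gram matrix with a single matrix product instead of one rank-1 update per factor.

import breeze.linalg.DenseMatrix

// Gram matrix (sum of outer products f * f^T) of n stacked rank-dimensional
// factors, stored column-major in `stacked` (column j holds factor j).
def gram(stacked: Array[Double], rank: Int, n: Int, threshold: Int): DenseMatrix[Double] = {
  if (n > threshold && rank > threshold) {
    // Stacked path: one level-3 style product over the rank x n factor matrix.
    val a = new DenseMatrix(rank, n, stacked)
    a * a.t
  } else {
    // Per-vector path: accumulate rank-1 outer products one factor at a time.
    val acc = DenseMatrix.zeros[Double](rank, rank)
    var j = 0
    while (j < n) {
      val col = new DenseMatrix(rank, 1, stacked.slice(j * rank, (j + 1) * rank))
      acc += col * col.t
      j += 1
    }
    acc
  }
}

Both branches compute the same matrix; the stacked one trades extra copying for a single large multiplication, which wins only when both dimensions are big, hence the two-sided comparison against `threshold`.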
10 changes: 9 additions & 1 deletion mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -301,7 +301,8 @@ class ALSSuite
implicitPrefs: Boolean = false,
numUserBlocks: Int = 2,
numItemBlocks: Int = 3,
-  targetRMSE: Double = 0.05): Unit = {
+  targetRMSE: Double = 0.05,
+  threshold: Int = 1024): Unit = {
val spark = this.spark
import spark.implicits._
val als = new ALS()
@@ -311,6 +312,7 @@ class ALSSuite
.setNumUserBlocks(numUserBlocks)
.setNumItemBlocks(numItemBlocks)
.setSeed(0)
+.setThreshold(threshold)
val alpha = als.getAlpha
val model = als.fit(training.toDF())
val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map {
@@ -382,6 +384,12 @@ class ALSSuite
numItemBlocks = 5, numUserBlocks = 5)
}

test("do stacking factors in matrices") {
val (training, test) = genExplicitTestData(numUsers = 200, numItems = 20, rank = 1)
testALS(training, test, maxIter = 1, rank = 129, regParam = 0.01, targetRMSE = 0.02,
threshold = 128)
}
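
A note on the chosen constants (an inference from the doStack condition above, not spelled out in the commit): rank = 129 with threshold = 128 satisfies the rank > threshold half of the test, so any item that gathers more than 128 stacked user factors is solved through the stacked branch, while the targetRMSE of 0.02, tighter than the default 0.05, checks that the optimization does not degrade the fit.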

test("implicit feedback") {
val (training, test) =
genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
