[SPARK-36553][ML] KMeans avoid compute auxiliary statistics for large K #35457
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -117,6 +117,24 @@ private[spark] abstract class DistanceMeasure extends Serializable { | |
packedValues | ||
} | ||
|
||
/** | ||
* @param centers the clustering centers | ||
* @param statistics optional statistics to accelerate the computation, which should not | ||
* change the result. | ||
* @param point given point | ||
* @return the index of the closest center to the given point, as well as the cost. | ||
*/ | ||
def findClosest( | ||
centers: Array[VectorWithNorm], | ||
statistics: Option[Array[Double]], | ||
point: VectorWithNorm): (Int, Double) = { | ||
if (statistics.nonEmpty) { | ||
findClosest(centers, statistics.get, point) | ||
} else { | ||
findClosest(centers, point) | ||
} | ||
} | ||
Review comment: Shall we add the function description like the other existing
Reply: Good point. I will update this PR.
||
|
||
/** | ||
* @return the index of the closest center to the given point, as well as the cost. | ||
*/ | ||
|
@@ -253,6 +271,11 @@ object DistanceMeasure { | |
case _ => false | ||
} | ||
} | ||
|
||
private[clustering] def shouldComputeStatistics(k: Int): Boolean = k < 1000 | ||
|
||
private[clustering] def shouldComputeStatisticsLocally(k: Int, numFeatures: Int): Boolean = | ||
k.toLong * k * numFeatures < 1000000 | ||
} | ||
|
||
private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { | ||
|
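To make the two thresholds above concrete, here is a small standalone sketch. The two predicates are copied from the diff. `StatisticsThresholds` and `oldShouldDistribute` are hypothetical names used only for illustration, and the latter merely mirrors the shape of the `shouldDistributed` expression that appears in KMeans.scala. It also illustrates a detail of Scala arithmetic: in an expression of the form `k * k * numFeatures.toLong`, the `k * k` product is evaluated as an Int before widening and can wrap around for large k, whereas `k.toLong * k * numFeatures` is evaluated in Long throughout.

```scala
object StatisticsThresholds {
  // Copied from the diff: auxiliary statistics are skipped once k reaches 1000.
  def shouldComputeStatistics(k: Int): Boolean = k < 1000

  // Copied from the diff: k is widened to Long before any multiplication.
  def shouldComputeStatisticsLocally(k: Int, numFeatures: Int): Boolean =
    k.toLong * k * numFeatures < 1000000

  // Hypothetical helper mirroring the shape of the shouldDistributed check:
  // k * k is an Int * Int product, so it can overflow before the widening.
  def oldShouldDistribute(k: Int, numFeatures: Int): Boolean =
    k * k * numFeatures.toLong > 1000000L
}
```

With k = 50000 and 10 features, the Int product `k * k` wraps to a negative value, so `oldShouldDistribute` returns false even though the true product is far above the limit. Under the new predicates that k never reaches the size check, because `shouldComputeStatistics(50000)` is already false.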
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -269,15 +269,22 @@ class KMeans private ( | |
|
||
instr.foreach(_.logNumFeatures(numFeatures)) | ||
|
||
val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L | ||
val shouldComputeStats = | ||
DistanceMeasure.shouldComputeStatistics(centers.length) | ||
val shouldComputeStatsLocally = | ||
DistanceMeasure.shouldComputeStatisticsLocally(centers.length, numFeatures) | ||
|
||
// Execute iterations of Lloyd's algorithm until converged | ||
while (iteration < maxIterations && !converged) { | ||
val bcCenters = sc.broadcast(centers) | ||
val stats = if (shouldDistributed) { | ||
Review comment: Previously, stats was an Array[Double].
||
distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters) | ||
val stats = if (shouldComputeStats) { | ||
Review comment: Now it is an Option[Array[Double]].
Review comment: So the following
||
if (shouldComputeStatsLocally) { | ||
Some(distanceMeasureInstance.computeStatistics(centers)) | ||
} else { | ||
Some(distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters)) | ||
} | ||
} else { | ||
distanceMeasureInstance.computeStatistics(centers) | ||
None | ||
} | ||
val bcStats = sc.broadcast(stats) | ||
|
||
|
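The nested `if` in this hunk amounts to a three-way choice per run: no statistics, statistics computed on the driver, or statistics computed in a distributed job. A minimal sketch of that decision as a pure function follows. `StatsStrategy`, its three cases, and `StatsStrategyChooser` are hypothetical names that do not appear in the PR; only the two thresholds are taken from the diff.

```scala
// Hypothetical ADT naming the three outcomes of the nested if in the hunk above.
sealed trait StatsStrategy
case object NoStats extends StatsStrategy          // stats = None
case object LocalStats extends StatsStrategy       // Some(computeStatistics(centers))
case object DistributedStats extends StatsStrategy // Some(computeStatisticsDistributedly(...))

object StatsStrategyChooser {
  def choose(k: Int, numFeatures: Int): StatsStrategy =
    if (k >= 1000) {
      // Mirrors !shouldComputeStatistics(k): skip statistics for large k.
      NoStats
    } else if (k.toLong * k * numFeatures < 1000000) {
      // Mirrors shouldComputeStatisticsLocally(k, numFeatures).
      LocalStats
    } else {
      DistributedStats
    }
}
```

For instance, k = 10 with 100 features gives a product of 10,000 and selects `LocalStats`, k = 500 with 100 features gives 25,000,000 and selects `DistributedStats`, and any k of 1000 or more selects `NoStats`.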
Review comment: Is this overload used?
Reply: It is used in both training and prediction. statistics is optional in it.
Review comment: OK, this is a new method, but I don't see it called. Maybe I'm missing something.
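For readers following this thread, here is a self-contained sketch of how an Option-taking `findClosest` overload can dispatch to the other two, and why passing statistics "should not change the result". It is not Spark's implementation: plain `Array[Double]` stands in for `VectorWithNorm`, and the statistic used (the squared half-distance from each center to its nearest other center) is a simplified assumption chosen so that the triangle-inequality shortcut is easy to verify. `FindClosestSketch`, `Vec`, and `sqDist` are hypothetical names.

```scala
object FindClosestSketch {
  type Vec = Array[Double]

  def sqDist(a: Vec, b: Vec): Double = {
    var s = 0.0
    var i = 0
    while (i < a.length) { val d = a(i) - b(i); s += d * d; i += 1 }
    s
  }

  // Assumed statistic: stats(i) = squared half-distance from center i
  // to its nearest other center (+Infinity if there is only one center).
  def computeStatistics(centers: Array[Vec]): Array[Double] =
    Array.tabulate(centers.length) { i =>
      centers.indices.filter(_ != i)
        .map(j => sqDist(centers(i), centers(j)) / 4.0)
        .foldLeft(Double.PositiveInfinity)((a, b) => math.min(a, b))
    }

  // Plain linear scan over all centers.
  def findClosest(centers: Array[Vec], point: Vec): (Int, Double) = {
    var best = 0
    var bestCost = Double.PositiveInfinity
    var i = 0
    while (i < centers.length) {
      val c = sqDist(centers(i), point)
      if (c < bestCost) { bestCost = c; best = i }
      i += 1
    }
    (best, bestCost)
  }

  // Same scan, but stops early when the point lies within half the distance
  // from center i to its nearest neighbour: no other center can be closer.
  def findClosest(centers: Array[Vec], stats: Array[Double], point: Vec): (Int, Double) = {
    var best = 0
    var bestCost = Double.PositiveInfinity
    var i = 0
    while (i < centers.length) {
      val c = sqDist(centers(i), point)
      if (c < stats(i)) return (i, c)
      if (c < bestCost) { bestCost = c; best = i }
      i += 1
    }
    (best, bestCost)
  }

  // Option-taking overload: one call site serves both cases.
  def findClosest(
      centers: Array[Vec],
      statistics: Option[Array[Double]],
      point: Vec): (Int, Double) =
    statistics match {
      case Some(stats) => findClosest(centers, stats, point)
      case None => findClosest(centers, point)
    }
}
```

A caller holding an `Option[Array[Double]]` (for example, the value of a broadcast variable) can then invoke the three-argument overload without branching on whether statistics were computed, and gets the same `(index, cost)` pair either way.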