Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-36553][ML] KMeans avoid compute auxiliary statistics for large K #35457

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,24 @@ private[spark] abstract class DistanceMeasure extends Serializable {
packedValues
}

/**
* @param centers the clustering centers
* @param statistics optional statistics to accelerate the computation, which should not
* change the result.
* @param point given point
* @return the index of the closest center to the given point, as well as the cost.
*/
def findClosest(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this overload used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is used in both training and prediction. statistics is optional in it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK this is a new method but I don't see it called, maybe I'm missing something

centers: Array[VectorWithNorm],
statistics: Option[Array[Double]],
point: VectorWithNorm): (Int, Double) = {
if (statistics.nonEmpty) {
findClosest(centers, statistics.get, point)
} else {
findClosest(centers, point)
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we add the function description like the other existing def findClosest functions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point. I will update this PR


/**
* @return the index of the closest center to the given point, as well as the cost.
*/
Expand Down Expand Up @@ -253,6 +271,11 @@ object DistanceMeasure {
case _ => false
}
}

private[clustering] def shouldComputeStatistics(k: Int): Boolean = k < 1000

private[clustering] def shouldComputeStatisticsLocally(k: Int, numFeatures: Int): Boolean =
k.toLong * k * numFeatures < 1000000
}

private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,22 @@ class KMeans private (

instr.foreach(_.logNumFeatures(numFeatures))

val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L
val shouldComputeStats =
DistanceMeasure.shouldComputeStatistics(centers.length)
val shouldComputeStatsLocally =
DistanceMeasure.shouldComputeStatisticsLocally(centers.length, numFeatures)

// Execute iterations of Lloyd's algorithm until converged
while (iteration < maxIterations && !converged) {
val bcCenters = sc.broadcast(centers)
val stats = if (shouldDistributed) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

previous stats is a Array[Double]

distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters)
val stats = if (shouldComputeStats) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now, it is a Option[Array[Double]]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the following val (bestCenter, cost) = distanceMeasureInstance.findClosest(centers, stats, point) will call this new method, without code change in the call sites.

if (shouldComputeStatsLocally) {
Some(distanceMeasureInstance.computeStatistics(centers))
} else {
Some(distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters))
}
} else {
distanceMeasureInstance.computeStatistics(centers)
None
}
val bcStats = sc.broadcast(stats)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,16 @@ class KMeansModel (@Since("1.0.0") val clusterCenters: Array[Vector],

// TODO: computation of statistics may take seconds, so save it to KMeansModel in training
@transient private lazy val statistics = if (clusterCenters == null) {
null
None
} else {
distanceMeasureInstance.computeStatistics(clusterCentersWithNorm)
val k = clusterCenters.length
val numFeatures = clusterCenters.head.size
if (DistanceMeasure.shouldComputeStatistics(k) &&
DistanceMeasure.shouldComputeStatisticsLocally(k, numFeatures)) {
Some(distanceMeasureInstance.computeStatistics(clusterCentersWithNorm))
} else {
None
}
}

@Since("2.4.0")
Expand Down