[SPARK-31007][ML] KMeans optimization based on triangle-inequality #27758

Closed · wants to merge 8 commits
53 changes: 52 additions & 1 deletion mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
@@ -18,7 +18,7 @@
package org.apache.spark.ml.impl


private[ml] object Utils {
private[spark] object Utils {

lazy val EPSILON = {
var eps = 1.0
@@ -27,4 +27,55 @@ private[ml] object Utils {
}
eps
}

/**
* Convert an array of length n * (n + 1) / 2, holding the upper triangular part of a matrix
* packed in column major order, into an array of length n * n representing the full
* symmetric matrix (column major).
*
* @param n The order of the n by n matrix.
* @param triangularValues The upper triangular part of the matrix packed in an array
* (column major).
* @return An array of length n * n representing the full symmetric matrix (column major).
*/
def unpackUpperTriangular(
n: Int,
triangularValues: Array[Double]): Array[Double] = {
val symmetricValues = new Array[Double](n * n)
var r = 0
var i = 0
while (i < n) {
var j = 0
while (j <= i) {
symmetricValues(i * n + j) = triangularValues(r)
symmetricValues(j * n + i) = triangularValues(r)
r += 1
j += 1
}
i += 1
}
symmetricValues
}

/**
* Compute the index into an array storing the packed upper triangular part of a symmetric
* matrix (column major) that corresponds to entry (i, j) of the full n * n matrix. That is,
* given
* val symmetricValues = unpackUpperTriangular(n, triangularValues)
* val matrix = new DenseMatrix(n, n, symmetricValues)
* val index = indexUpperTriangular(n, i, j)
* then: triangularValues(index) == matrix(i, j)
*
* @param n The order of the n by n matrix.
* @return The index into the packed upper triangular array.
*/
def indexUpperTriangular(
n: Int,
i: Int,
j: Int): Int = {
require(i >= 0 && i < n, s"Expected 0 <= i < $n, got i = $i.")
require(j >= 0 && j < n, s"Expected 0 <= j < $n, got j = $j.")
if (i <= j) {
j * (j + 1) / 2 + i
} else {
i * (i + 1) / 2 + j
}
}
}
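To make the two helpers above concrete, here is a small consistency sketch (hypothetical usage, not part of the patch; it assumes access to the private[spark] object, e.g. from code inside an org.apache.spark package):

import org.apache.spark.ml.impl.Utils.{indexUpperTriangular, unpackUpperTriangular}

val n = 3
// Upper triangular part packed column by column: (0,0), (0,1), (1,1), (0,2), (1,2), (2,2).
val packed = Array(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)
val full = unpackUpperTriangular(n, packed)

// The full matrix is column major, so entry (i, j) lives at full(j * n + i); it must
// match the packed value addressed by indexUpperTriangular for every (i, j).
for (i <- 0 until n; j <- 0 until n) {
  assert(full(j * n + i) == packed(indexUpperTriangular(n, i, j)))
}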
@@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.impl.Utils.EPSILON
import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -583,19 +583,7 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
private[clustering] def unpackUpperTriangularMatrix(
n: Int,
triangularValues: Array[Double]): DenseMatrix = {
val symmetricValues = new Array[Double](n * n)
var r = 0
var i = 0
while (i < n) {
var j = 0
while (j <= i) {
symmetricValues(i * n + j) = triangularValues(r)
symmetricValues(j * n + i) = triangularValues(r)
r += 1
j += 1
}
i += 1
}
val symmetricValues = unpackUpperTriangular(n, triangularValues)
new DenseMatrix(n, n, symmetricValues)
}

@@ -17,23 +17,125 @@

package org.apache.spark.mllib.clustering

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.impl.Utils.indexUpperTriangular
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
import org.apache.spark.mllib.util.MLUtils

private[spark] abstract class DistanceMeasure extends Serializable {

/**
* Statistics used with the triangle inequality to obtain useful bounds for finding the
* closest center.
* @param distance distance between two centers
*/
def computeStatistics(distance: Double): Double

/**
* Statistics used with the triangle inequality to obtain useful bounds for finding the
* closest center.
*
* @return The packed upper triangular part of a symmetric matrix of statistics, where
* matrix(i,j) represents:
* 1, if i != j: a bound r = matrix(i,j) used to avoid an unnecessary distance
* computation. Given a point x, let i be the current closest center and d the current
* best distance; if d < f(r), we no longer need to compute the distance to center j;
* 2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k) | k != i}. If the distance
* between point x and center i is less than f(r), then center i is the closest center
* to point x.
*/
def computeStatistics(centers: Array[VectorWithNorm]): Array[Double] = {
val k = centers.length
if (k == 1) return Array(Double.NaN)

val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
val diagValues = Array.fill(k)(Double.PositiveInfinity)
var i = 0
while (i < k) {
var j = i + 1
while (j < k) {
val d = distance(centers(i), centers(j))
val s = computeStatistics(d)
val index = indexUpperTriangular(k, i, j)
packedValues(index) = s
if (s < diagValues(i)) diagValues(i) = s
if (s < diagValues(j)) diagValues(j) = s
j += 1
}
i += 1
}

i = 0
while (i < k) {
val index = indexUpperTriangular(k, i, i)
packedValues(index) = diagValues(i)
i += 1
}
packedValues
}
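For concreteness, here is a small usage sketch of the packed statistics layout (hypothetical, not part of the patch), using the Euclidean rule s = 0.25 * d^2 defined further down; it assumes the package-private classes are accessible, e.g. from a suite in org.apache.spark.mllib.clustering:

import org.apache.spark.ml.impl.Utils.indexUpperTriangular
import org.apache.spark.mllib.linalg.Vectors

val measure = new EuclideanDistanceMeasure
val centers = Array(
  new VectorWithNorm(Vectors.dense(0.0, 0.0)),
  new VectorWithNorm(Vectors.dense(4.0, 0.0)),
  new VectorWithNorm(Vectors.dense(0.0, 3.0)))
val stats = measure.computeStatistics(centers)
val k = centers.length

// Off-diagonal entry (0, 1): centers 0 and 1 are 4 apart, so s = 0.25 * 4^2 = 4.0.
assert(math.abs(stats(indexUpperTriangular(k, 0, 1)) - 4.0) < 1e-8)
// Diagonal entry (0, 0) is the minimum over row 0's off-diagonal entries:
// min(0.25 * 4^2, 0.25 * 3^2) = 2.25.
assert(math.abs(stats(indexUpperTriangular(k, 0, 0)) - 2.25) < 1e-8)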

/**
* Compute the statistics matrix for the given centers in a distributed way.
*/
def computeStatisticsDistributedly(
sc: SparkContext,
bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Double] = {
val k = bcCenters.value.length
if (k == 1) return Array(Double.NaN)

val packedValues = Array.ofDim[Double](k * (k + 1) / 2)
val diagValues = Array.fill(k)(Double.PositiveInfinity)

val numParts = math.min(k, 1024)
sc.range(0, numParts, 1, numParts)
.mapPartitionsWithIndex { case (pid, _) =>
val centers = bcCenters.value
Iterator.range(0, k).flatMap { i =>
Iterator.range(i + 1, k).flatMap { j =>
val hash = (i, j).hashCode.abs
if (hash % numParts == pid) {
val d = distance(centers(i), centers(j))
val s = computeStatistics(d)
Iterator.single((i, j, s))
} else Iterator.empty
}
}
}.collect.foreach { case (i, j, s) =>
val index = indexUpperTriangular(k, i, j)
packedValues(index) = s
if (s < diagValues(i)) diagValues(i) = s
if (s < diagValues(j)) diagValues(j) = s
}

var i = 0
while (i < k) {
val index = indexUpperTriangular(k, i, i)
packedValues(index) = diagValues(i)
i += 1
}
packedValues
}
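The distributed variant spreads the (i, j) pairs over tasks by hashing each pair, so every pairwise distance is computed exactly once on exactly one partition. A minimal standalone sketch of that assignment (hypothetical, not part of the patch):

val k = 5
val numParts = math.min(k, 1024)
// Each pair (i, j) with i < j is claimed by the single partition whose id matches the
// pair's hash modulo numParts.
val assignment =
  for (i <- 0 until k; j <- (i + 1) until k)
    yield ((i, j), (i, j).hashCode.abs % numParts)
// Every pair appears exactly once and maps to a partition id in [0, numParts).
assert(assignment.size == k * (k - 1) / 2)
assert(assignment.forall { case (_, pid) => pid >= 0 && pid < numParts })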

/**
* @return the index of the closest center to the given point, as well as the cost.
*/
def findClosest(
centers: TraversableOnce[VectorWithNorm],
centers: Array[VectorWithNorm],
statistics: Array[Double],
point: VectorWithNorm): (Int, Double)

/**
* @return the index of the closest center to the given point, as well as the cost.
*/
def findClosest(
centers: Array[VectorWithNorm],
point: VectorWithNorm): (Int, Double) = {
var bestDistance = Double.PositiveInfinity
var bestIndex = 0
var i = 0
centers.foreach { center =>
while (i < centers.length) {
val center = centers(i)
val currentDistance = distance(center, point)
if (currentDistance < bestDistance) {
bestDistance = currentDistance
@@ -48,7 +150,7 @@ private[spark] abstract class DistanceMeasure extends Serializable {
* @return the K-means cost of a given point against the given cluster centers.
*/
def pointCost(
centers: TraversableOnce[VectorWithNorm],
centers: Array[VectorWithNorm],
point: VectorWithNorm): Double = {
findClosest(centers, point)._2
}
@@ -154,22 +256,79 @@ object DistanceMeasure {
}

private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {

/**
* Statistics used with the triangle inequality to obtain useful bounds for finding the
* closest center.
* @see <a href="https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf">Charles Elkan,
* Using the Triangle Inequality to Accelerate k-Means</a>
*
* @return One element of the statistics matrix, so that matrix(i,j) represents:
* 1, if i != j: a bound r = matrix(i,j) used to avoid an unnecessary distance
* computation. Given a point x, let i be the current closest center and d the current
* best squared distance; if d < r, we no longer need to compute the distance to center
* j. matrix(i,j) equals the square of half the Euclidean distance between centers i
* and j;
* 2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k) | k != i}. If the squared
* distance between point x and center i is less than r, then center i is the closest
* center to point x.
*/
override def computeStatistics(distance: Double): Double = {
0.25 * distance * distance
}
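A short sketch of why the half-distance bound is safe (a restatement of Elkan's argument with made-up numbers, not code from the patch):

// If ||x - c_i||^2 < 0.25 * d(c_i, c_j)^2, then ||x - c_i|| < d(c_i, c_j) / 2, and by the
// triangle inequality d(c_i, c_j) <= ||x - c_i|| + ||x - c_j||, so
// ||x - c_j|| >= d(c_i, c_j) - ||x - c_i|| > d(c_i, c_j) / 2 > ||x - c_i||.
// Center j can therefore never beat center i, and its distance is skipped.
val d = 6.0                   // distance between centers i and j
val s = 0.25 * d * d          // stored bound: 9.0
val distToI = 2.5             // ||x - c_i||; its square 6.25 is below the bound
assert(distToI * distToI < s)
assert(d - distToI > distToI) // hence ||x - c_j|| > ||x - c_i||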

/**
* @return the index of the closest center to the given point, as well as the cost.
*/
override def findClosest(
centers: Array[VectorWithNorm],
statistics: Array[Double],
point: VectorWithNorm): (Int, Double) = {
var bestDistance = EuclideanDistanceMeasure.fastSquaredDistance(centers(0), point)
if (bestDistance < statistics(0)) return (0, bestDistance)

val k = centers.length
var bestIndex = 0
var i = 1
while (i < k) {
val center = centers(i)
// Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
// distance computation.
val normDiff = center.norm - point.norm
val lowerBound = normDiff * normDiff
if (lowerBound < bestDistance) {
val index1 = indexUpperTriangular(k, i, bestIndex)
if (statistics(index1) < bestDistance) {
val d = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
val index2 = indexUpperTriangular(k, i, i)
if (d < statistics(index2)) return (i, d)
if (d < bestDistance) {
bestDistance = d
bestIndex = i
}
}
}
i += 1
}
(bestIndex, bestDistance)
}
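As a sanity check (hypothetical, assuming the same imports and package-private access as the earlier sketch), the statistics-accelerated findClosest should agree with the plain O(k) scan shown below, only doing less work:

val measure = new EuclideanDistanceMeasure
val centers = Array(
  new VectorWithNorm(Vectors.dense(0.0, 0.0)),
  new VectorWithNorm(Vectors.dense(4.0, 0.0)),
  new VectorWithNorm(Vectors.dense(0.0, 3.0)))
val stats = measure.computeStatistics(centers)
val point = new VectorWithNorm(Vectors.dense(0.5, 2.0))

// Both versions must pick the same center and report the same squared distance.
val (fastIndex, fastCost) = measure.findClosest(centers, stats, point)
val (plainIndex, plainCost) = measure.findClosest(centers, point)
assert(fastIndex == plainIndex && math.abs(fastCost - plainCost) < 1e-8)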

/**
* @return the index of the closest center to the given point, as well as the squared distance.
*/
override def findClosest(
centers: TraversableOnce[VectorWithNorm],
centers: Array[VectorWithNorm],
point: VectorWithNorm): (Int, Double) = {
var bestDistance = Double.PositiveInfinity
var bestIndex = 0
var i = 0
centers.foreach { center =>
while (i < centers.length) {
val center = centers(i)
// Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary
// distance computation.
var lowerBoundOfSqDist = center.norm - point.norm
lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist
if (lowerBoundOfSqDist < bestDistance) {
val distance: Double = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
val distance = EuclideanDistanceMeasure.fastSquaredDistance(center, point)
if (distance < bestDistance) {
bestDistance = distance
bestIndex = i
@@ -234,6 +393,58 @@ private[spark] object EuclideanDistanceMeasure {
}

private[spark] class CosineDistanceMeasure extends DistanceMeasure {

/**
* Statistics used with the triangle inequality to obtain useful bounds for finding the
* closest center.
*
* @return One element of the statistics matrix, so that matrix(i,j) represents:
* 1, if i != j: a bound r = matrix(i,j) used to avoid an unnecessary distance
* computation. Given a point x, let i be the current closest center and d the current
* best distance; if d < r, we no longer need to compute the distance to center j.
* The bound is analogous to the Euclidean case, but it is computed from the angle
* rather than from the cosine distance itself: for centers i and j, compute the angle
* between them, halve it, and convert it back to a cosine distance;
* 2, if i == j: a bound r = matrix(i,i) = min_k{matrix(i,k) | k != i}. If the cosine
* distance between point x and center i is less than r, then center i is the closest
* center to point x.
*/
override def computeStatistics(distance: Double): Double = {
// d = 1 - cos(x)
// r = 1 - cos(x/2) = 1 - sqrt((cos(x) + 1) / 2) = 1 - sqrt(1 - d/2)
1 - math.sqrt(1 - distance / 2)
}
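A quick numeric check of the half-angle conversion used above (illustrative only, not part of the patch):

// For cosine distance d = 1 - cos(x) with x in [0, pi], the stored bound is
// r = 1 - cos(x / 2), rewritten via cos(x / 2) = sqrt((1 + cos(x)) / 2) as
// r = 1 - sqrt(1 - d / 2).
val x = math.Pi / 3                // 60 degrees between two centers
val d = 1 - math.cos(x)            // cosine distance = 0.5
val r = 1 - math.sqrt(1 - d / 2)   // bound stored in the statistics matrix
assert(math.abs(r - (1 - math.cos(x / 2))) < 1e-12)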

/**
* @return the index of the closest center to the given point, as well as the cost.
*/
def findClosest(
Review comment (Member): Is there any clean way to avoid duplicating most of this code? Maybe not. It looks almost identical to the above, though.

centers: Array[VectorWithNorm],
statistics: Array[Double],
point: VectorWithNorm): (Int, Double) = {
var bestDistance = distance(centers(0), point)
if (bestDistance < statistics(0)) return (0, bestDistance)

val k = centers.length
var bestIndex = 0
var i = 1
while (i < k) {
val index1 = indexUpperTriangular(k, i, bestIndex)
if (statistics(index1) < bestDistance) {
val center = centers(i)
val d = distance(center, point)
val index2 = indexUpperTriangular(k, i, i)
if (d < statistics(index2)) return (i, d)
if (d < bestDistance) {
bestDistance = d
bestIndex = i
}
}
i += 1
}
(bestIndex, bestDistance)
}

/**
* @param v1: first vector
* @param v2: second vector