Skip to content

Commit

Permalink
compute stats distributedly
Browse files Browse the repository at this point in the history
compute stats distributedly

nit

nit
  • Loading branch information
zhengruifeng committed Apr 14, 2020
1 parent ac83e54 commit b4fabb1
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.apache.spark.mllib.clustering

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
import org.apache.spark.mllib.util.MLUtils
Expand All @@ -26,8 +28,93 @@ private[spark] abstract class DistanceMeasure extends Serializable {

/**
* Statistics used in triangle inequality to obtain useful bounds to find closest centers.
* @param distance distance between two centers
*/
def computeStatistics(centers: Array[VectorWithNorm]): Array[Array[Double]]
def computeStatistics(distance: Double): Double

/**
 * Builds the k-by-k table of statistics that the triangle inequality uses to bound the
 * search for the closest center.
 *
 * For every pair i < j, the per-distance `computeStatistics` is applied to the distance
 * between center i and center j and the result is written to both (i)(j) and (j)(i).
 * Each diagonal entry starts at positive infinity and is lowered to any off-diagonal
 * value of its row that compares strictly smaller.
 *
 * @param centers the current cluster centers
 * @return A symmetric matrix of statistics, where matrix(i)(j) represents:
 *         1, if i==j, a lower bound r of center i: if the distance between point x and
 *         center i is less than f(r), then center i is the closest center to point x.
 *         2, if i!=j, a lower bound r that helps to skip distance computations: given
 *         point x with current closest center i and current best distance d, if
 *         d < f(r) then the distance to center j does not need to be computed.
 *         When there is exactly one center, a 1x1 matrix holding NaN is returned.
 */
def computeStatistics(centers: Array[VectorWithNorm]): Array[Array[Double]] = {
  val numCenters = centers.length
  if (numCenters == 1) {
    Array(Array(Double.NaN))
  } else {
    val matrix = Array.ofDim[Double](numCenters, numCenters)
    for (c <- 0 until numCenters) {
      matrix(c)(c) = Double.PositiveInfinity
    }
    for {
      row <- 0 until numCenters
      col <- (row + 1) until numCenters
    } {
      val bound = computeStatistics(distance(centers(row), centers(col)))
      matrix(row)(col) = bound
      matrix(col)(row) = bound
      // Plain `<` comparisons (not math.min) so that a NaN bound never replaces a diagonal.
      if (bound < matrix(row)(row)) matrix(row)(row) = bound
      if (bound < matrix(col)(col)) matrix(col)(col) = bound
    }
    matrix
  }
}

/**
 * Computes the same symmetric statistics matrix as `computeStatistics(centers)`, but
 * evaluates the pairwise center distances in a Spark job instead of on the driver.
 *
 * Every pair (i, j) with i < j is assigned to exactly one of `numParts` partitions by the
 * hash of the pair. Each partition computes the statistics of its own pairs, entries equal
 * to zero are dropped before collecting, and the driver treats every missing pair as 0.0
 * when it assembles the matrix.
 *
 * @param sc the SparkContext used to launch the job
 * @param bcCenters broadcast of the current cluster centers
 * @return the k-by-k symmetric statistics matrix; a 1x1 matrix holding NaN when there is
 *         exactly one center
 */
def computeStatisticsDistributedly(
    sc: SparkContext,
    bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Array[Double]] = {
  val k = bcCenters.value.length
  if (k == 1) {
    Array(Array(Double.NaN))
  } else {
    val numParts = math.min(k, 1024)
    val collected = sc.range(0, numParts, 1, numParts)
      .mapPartitionsWithIndex { case (pid, _) =>
        val centers = bcCenters.value
        Iterator.range(0, k).flatMap { i =>
          Iterator.range(i + 1, k).flatMap { j =>
            // floorMod always yields a value in [0, numParts). The previous
            // `hashCode.abs % numParts` is negative when hashCode is Int.MinValue
            // (abs overflows), so such a pair matched no partition and was lost.
            val owner = Math.floorMod((i, j).hashCode, numParts)
            if (owner == pid) {
              val d = distance(centers(i), centers(j))
              val s = computeStatistics(d)
              Iterator.single(((i, j), s))
            } else Iterator.empty
          }
        }.filterNot(_._2 == 0) // zeros are restored by getOrElse below
      }.collectAsMap()

    val stats = Array.ofDim[Double](k, k)
    var i = 0
    while (i < k) {
      stats(i)(i) = Double.PositiveInfinity
      i += 1
    }
    i = 0
    while (i < k) {
      var j = i + 1
      while (j < k) {
        // Pairs absent from the collected map are the filtered-out zeros.
        val s = collected.getOrElse((i, j), 0.0)
        stats(i)(j) = s
        stats(j)(i) = s
        if (s < stats(i)(i)) stats(i)(i) = s
        if (s < stats(j)(j)) stats(j)(j) = s
        j += 1
      }
      i += 1
    }
    stats
  }
}

/**
* @return the index of the closest center to the given point, as well as the cost.
Expand Down Expand Up @@ -174,7 +261,7 @@ private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
* @see <a href="https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf">Charles Elkan,
* Using the Triangle Inequality to Accelerate k-Means</a>
*
* @return A symmetric matrix containing statistics, matrix(i)(j) represents:
* @return One element of the statistics matrix, such that matrix(i)(j) represents:
* 1, squared radii of the center i, if i==j. If distance between point x and center i
* is less than the radius of center i, then center i is the closest center to point x.
* For Euclidean distance, radius of center i is half of the distance between center i
Expand All @@ -183,33 +270,8 @@ private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
* Given point x, let i be current closest center, and d be current best squared
* distance, if d < r, then we no longer need to compute the distance to center j.
*/
override def computeStatistics(centers: Array[VectorWithNorm]): Array[Array[Double]] = {
val k = centers.length
if (k == 1) {
Array(Array(Double.NaN))
} else {
val matrix = Array.ofDim[Double](k, k)
var i = 0
while (i < k) {
matrix(i)(i) = Double.PositiveInfinity
i += 1
}
i = 0
while (i < k) {
var j = i + 1
while (j < k) {
val d = distance(centers(i), centers(j))
val r = 0.25 * d * d
matrix(i)(j) = r
matrix(j)(i) = r
if (r < matrix(i)(i)) matrix(i)(i) = r
if (r < matrix(j)(j)) matrix(j)(j) = r
j += 1
}
i += 1
}
matrix
}
// Maps a center-to-center distance to (distance / 2)^2, evaluated as 0.25 * distance^2.
override def computeStatistics(distance: Double): Double =
  0.25 * distance * distance

/**
Expand Down Expand Up @@ -330,7 +392,8 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {

/**
* Statistics used in triangle inequality to obtain useful bounds to find closest centers.
* @return A symmetric matrix containing statistics, matrix(i)(j) represents:
*
* @return One element of the statistics matrix, such that matrix(i)(j) represents:
* 1, squared radii of the center i, if i==j. If distance between point x and center i
* is less than the radius of center i, then center i is the closest center to point x.
* For Cosine distance, it is similar to Euclidean distance. However, here radian/angle
Expand All @@ -341,35 +404,10 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
* Given point x, let i be current closest center, and d be current best squared
* distance, if d < r, then we no longer need to compute the distance to center j.
*/
override def computeStatistics(centers: Array[VectorWithNorm]): Array[Array[Double]] = {
val k = centers.length
if (k == 1) {
Array(Array(Double.NaN))
} else {
val matrix = Array.ofDim[Double](k, k)
var i = 0
while (i < k) {
matrix(i)(i) = Double.PositiveInfinity
i += 1
}
i = 0
while (i < k) {
var j = i + 1
while (j < k) {
// d = 1 - cos(x)
// r = 1 - cos(x/2) = 1 - sqrt((cos(x) + 1) / 2) = 1 - sqrt(1 - d/2)
val d = distance(centers(i), centers(j))
val r = 1 - math.sqrt(1 - d / 2)
matrix(i)(j) = r
matrix(j)(i) = r
if (r < matrix(i)(i)) matrix(i)(i) = r
if (r < matrix(j)(j)) matrix(j)(j) = r
j += 1
}
i += 1
}
matrix
}
override def computeStatistics(distance: Double): Double = {
  // d = 1 - cos(x)
  // r = 1 - cos(x/2) = 1 - sqrt((cos(x) + 1) / 2) = 1 - sqrt(1 - d/2)
  // Rounding can push d marginally above 2, which would make (1 - d/2) negative and the
  // square root NaN; clamp the radicand at 0 so that case yields r = 1 instead.
  1 - math.sqrt(math.max(0.0, 1 - distance / 2))
}

/**
Expand Down
28 changes: 18 additions & 10 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,7 @@ class KMeans private (
*/
@Since("0.8.0")
def run(data: RDD[Vector]): KMeansModel = {
val instances: RDD[(Vector, Double)] = data.map {
case (point) => (point, 1.0)
}
val instances = data.map(point => (point, 1.0))
runWithWeight(instances, None)
}

Expand Down Expand Up @@ -260,6 +258,7 @@ class KMeans private (
initKMeansParallel(data, distanceMeasureInstance)
}
}
val numFeatures = centers.head.vector.size
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(f"Initialization with $initializationMode took $initTimeInSeconds%.3f seconds.")

Expand All @@ -269,18 +268,26 @@ class KMeans private (

val iterationStartTime = System.nanoTime()

instr.foreach(_.logNumFeatures(centers.head.vector.size))
instr.foreach(_.logNumFeatures(numFeatures))

val shouldDistributed = centers.length * centers.length * numFeatures.toLong > 1000000L

// Execute iterations of Lloyd's algorithm until converged
while (iteration < maxIterations && !converged) {
val statistics = distanceMeasureInstance.computeStatistics(centers)
val bcCenters = sc.broadcast(centers)
val stats = if (shouldDistributed) {
distanceMeasureInstance.computeStatisticsDistributedly(sc, bcCenters)
} else {
distanceMeasureInstance.computeStatistics(centers)
}
val bcStats = sc.broadcast(stats)

val costAccum = sc.doubleAccumulator
val bcCenters = sc.broadcast((centers, statistics))

// Find the new centers
val collected = data.mapPartitions { points =>
val (centers, radii) = bcCenters.value
val centers = bcCenters.value
val stats = bcStats.value
val dims = centers.head.vector.size

val sums = Array.fill(centers.length)(Vectors.zeros(dims))
Expand All @@ -291,14 +298,14 @@ class KMeans private (
val clusterWeightSum = Array.ofDim[Double](centers.length)

points.foreach { point =>
val (bestCenter, cost) = distanceMeasureInstance.findClosest(centers, radii, point)
val (bestCenter, cost) = distanceMeasureInstance.findClosest(centers, stats, point)
costAccum.add(cost * point.weight)
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
clusterWeightSum(bestCenter) += point.weight
}

clusterWeightSum.indices.filter(clusterWeightSum(_) > 0)
.map(j => (j, (sums(j), clusterWeightSum(j)))).iterator
Iterator.tabulate(centers.length)(j => (j, (sums(j), clusterWeightSum(j))))
.filter(_._2._2 > 0)
}.reduceByKey { (sumweight1, sumweight2) =>
axpy(1.0, sumweight2._1, sumweight1._1)
(sumweight1._1, sumweight1._2 + sumweight2._2)
Expand All @@ -310,6 +317,7 @@ class KMeans private (
}

bcCenters.destroy()
bcStats.destroy()

// Update the cluster centers and costs
converged = true
Expand Down

0 comments on commit b4fabb1

Please sign in to comment.