diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala b/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
index 112de982e463b..ee3e99c0a8a55 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/impl/Utils.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.impl
 
 
-private[ml] object Utils {
+private[spark] object Utils {
 
   lazy val EPSILON = {
     var eps = 1.0
@@ -27,4 +27,58 @@ private[ml] object Utils {
     }
     eps
   }
+
+  /**
+   * Convert an n * (n + 1) / 2 dimension array representing the upper triangular part of a matrix
+   * into an n * n array representing the full symmetric matrix (column major).
+   *
+   * @param n The order of the n by n matrix.
+   * @param triangularValues The upper triangular part of the matrix packed in an array
+   *                         (column major).
+   * @return An array of length n * n holding the full symmetric matrix in column major order.
+   */
+  def unpackUpperTriangular(
+      n: Int,
+      triangularValues: Array[Double]): Array[Double] = {
+    val symmetricValues = new Array[Double](n * n)
+    var r = 0
+    var i = 0
+    while (i < n) {
+      var j = 0
+      while (j <= i) {
+        symmetricValues(i * n + j) = triangularValues(r)
+        symmetricValues(j * n + i) = triangularValues(r)
+        r += 1
+        j += 1
+      }
+      i += 1
+    }
+    symmetricValues
+  }
+
+  /**
+   * Compute the position, within the packed upper triangular array (column major), of the
+   * element (i, j) of the corresponding n * n symmetric matrix. For example:
+   *    val symmetricValues = unpackUpperTriangular(n, triangularValues)
+   *    val matrix = new DenseMatrix(n, n, symmetricValues)
+   *    val index = indexUpperTriangular(n, i, j)
+   *    then: triangularValues(index) == matrix(i, j)
+   *
+   * @param n The order of the n by n matrix.
+   * @param i The row index, 0 <= i < n.
+   * @param j The column index, 0 <= j < n.
+   * @return The index of element (i, j) in the packed upper triangular array.
+   */
+  def indexUpperTriangular(
+      n: Int,
+      i: Int,
+      j: Int): Int = {
+    require(i >= 0 && i < n, s"Expected 0 <= i < $n, got i = $i.")
+    require(j >= 0 && j < n, s"Expected 0 <= j < $n, got j = $j.")
+    // Only entries with row <= col are stored, so mirror (i, j) into the upper triangle.
+    val row = math.min(i, j)
+    val col = math.max(i, j)
+    // Columns 0 until col hold col * (col + 1) / 2 packed entries in total.
+    col * (col + 1) / 2 + row
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index f490faf084d2c..1c4560aa5fdd7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.impl.Utils.EPSILON
+import org.apache.spark.ml.impl.Utils.{unpackUpperTriangular, EPSILON}
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -583,19 +583,7 @@ object GaussianMixture extends DefaultParamsReadable[GaussianMixture] {
   private[clustering] def unpackUpperTriangularMatrix(
       n: Int,
       triangularValues: Array[Double]): DenseMatrix = {
-    val symmetricValues = new Array[Double](n * n)
-    var r = 0
-    var i = 0
-    while (i < n) {
-      var j = 0
-      while (j <= i) {
-        symmetricValues(i * n + j) = triangularValues(r)
-        symmetricValues(j * n + i) = triangularValues(r)
-        r += 1
-        j += 1
-      }
-      i += 1
-    }
+    val symmetricValues = unpackUpperTriangular(n, triangularValues)
     new DenseMatrix(n, n, symmetricValues)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
index 4ca91259772bc..7729e2a0f8219 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/DistanceMeasure.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.impl.Utils.indexUpperTriangular
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
 import org.apache.spark.mllib.util.MLUtils
@@ -35,38 +36,42 @@ private[spark] abstract class DistanceMeasure extends Serializable {
   /**
    * Statistics used in triangle inequality to obtain useful bounds to find closest centers.
    *
-   * @return A symmetric matrix containing statistics, matrix(i)(j) represents:
+   * @return The upper triangular part of a symmetric matrix containing statistics, matrix(i)(j)
+   *         represents:
    *         1, a lower bound r of the center i, if i==j. If distance between point x and center i
    *         is less than f(r), then center i is the closest center to point x.
    *         2, a lower bound r=matrix(i)(j) to help avoiding unnecessary distance computation.
    *         Given point x, let i be current closest center, and d be current best distance,
    *         if d < f(r), then we no longer need to compute the distance to center j.
*/ - def computeStatistics(centers: Array[VectorWithNorm]): Array[Array[Double]] = { + def computeStatistics(centers: Array[VectorWithNorm]): Array[Double] = { val k = centers.length - if (k == 1) return Array(Array(Double.NaN)) + if (k == 1) return Array(Double.NaN) - val stats = Array.ofDim[Double](k, k) + val packedValues = Array.ofDim[Double](k * (k + 1) / 2) + val diagValues = Array.fill(k)(Double.PositiveInfinity) var i = 0 - while (i < k) { - stats(i)(i) = Double.PositiveInfinity - i += 1 - } - i = 0 while (i < k) { var j = i + 1 while (j < k) { val d = distance(centers(i), centers(j)) val s = computeStatistics(d) - stats(i)(j) = s - stats(j)(i) = s - if (s < stats(i)(i)) stats(i)(i) = s - if (s < stats(j)(j)) stats(j)(j) = s + val index = indexUpperTriangular(k, i, j) + packedValues(index) = s + if (s < diagValues(i)) diagValues(i) = s + if (s < diagValues(j)) diagValues(j) = s j += 1 } i += 1 } - stats + + i = 0 + while (i < k) { + val index = indexUpperTriangular(k, i, i) + packedValues(index) = diagValues(i) + i += 1 + } + packedValues } /** @@ -74,12 +79,15 @@ private[spark] abstract class DistanceMeasure extends Serializable { */ def computeStatisticsDistributedly( sc: SparkContext, - bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Array[Double]] = { + bcCenters: Broadcast[Array[VectorWithNorm]]): Array[Double] = { val k = bcCenters.value.length - if (k == 1) return Array(Array(Double.NaN)) + if (k == 1) return Array(Double.NaN) + + val packedValues = Array.ofDim[Double](k * (k + 1) / 2) + val diagValues = Array.fill(k)(Double.PositiveInfinity) val numParts = math.min(k, 1024) - val collected = sc.range(0, numParts, 1, numParts) + sc.range(0, numParts, 1, numParts) .mapPartitionsWithIndex { case (pid, _) => val centers = bcCenters.value Iterator.range(0, k).flatMap { i => @@ -88,32 +96,24 @@ private[spark] abstract class DistanceMeasure extends Serializable { if (hash % numParts == pid) { val d = distance(centers(i), centers(j)) val s = 
computeStatistics(d) - Iterator.single(((i, j), s)) + Iterator.single((i, j, s)) } else Iterator.empty } }.filterNot(_._2 == 0) - }.collectAsMap() + }.foreach { case (i, j, s) => + val index = indexUpperTriangular(k, i, j) + packedValues(index) = s + if (s < diagValues(i)) diagValues(i) = s + if (s < diagValues(j)) diagValues(j) = s + } - val stats = Array.ofDim[Double](k, k) var i = 0 while (i < k) { - stats(i)(i) = Double.PositiveInfinity - i += 1 - } - i = 0 - while (i < k) { - var j = i + 1 - while (j < k) { - val s = collected.getOrElse((i, j), 0.0) - stats(i)(j) = s - stats(j)(i) = s - if (s < stats(i)(i)) stats(i)(i) = s - if (s < stats(j)(j)) stats(j)(j) = s - j += 1 - } + val index = indexUpperTriangular(k, i, i) + packedValues(index) = diagValues(i) i += 1 } - stats + packedValues } /** @@ -121,7 +121,7 @@ private[spark] abstract class DistanceMeasure extends Serializable { */ def findClosest( centers: Array[VectorWithNorm], - statistics: Array[Array[Double]], + statistics: Array[Double], point: VectorWithNorm): (Int, Double) /** @@ -279,28 +279,33 @@ private[spark] class EuclideanDistanceMeasure extends DistanceMeasure { */ override def findClosest( centers: Array[VectorWithNorm], - statistics: Array[Array[Double]], + statistics: Array[Double], point: VectorWithNorm): (Int, Double) = { var bestDistance = EuclideanDistanceMeasure.fastSquaredDistance(centers(0), point) - if (bestDistance < statistics(0)(0)) { + if (bestDistance < statistics(0)) { return (0, bestDistance) } + val k = centers.length var bestIndex = 0 var i = 1 - while (i < centers.length) { + while (i < k) { val center = centers(i) // Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary // distance computation. 
val normDiff = center.norm - point.norm val lowerBound = normDiff * normDiff - if (lowerBound < bestDistance && statistics(i)(bestIndex) < bestDistance) { - val d = EuclideanDistanceMeasure.fastSquaredDistance(center, point) - if (d < statistics(i)(i)) { - return (i, d) - } else if (d < bestDistance) { - bestDistance = d - bestIndex = i + if (lowerBound < bestDistance) { + val index1 = indexUpperTriangular(k, i, bestIndex) + if (statistics(index1) < bestDistance) { + val d = EuclideanDistanceMeasure.fastSquaredDistance(center, point) + val index2 = indexUpperTriangular(k, i, i) + if (d < statistics(index2)) { + return (i, d) + } else if (d < bestDistance) { + bestDistance = d + bestIndex = i + } } } i += 1 @@ -415,20 +420,23 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure { */ def findClosest( centers: Array[VectorWithNorm], - statistics: Array[Array[Double]], + statistics: Array[Double], point: VectorWithNorm): (Int, Double) = { var bestDistance = distance(centers(0), point) - if (bestDistance < statistics(0)(0)) { + if (bestDistance < statistics(0)) { return (0, bestDistance) } + val k = centers.length var bestIndex = 0 var i = 1 - while (i < centers.length) { - if (statistics(i)(bestIndex) < bestDistance) { + while (i < k) { + val index1 = indexUpperTriangular(k, i, bestIndex) + if (statistics(index1) < bestDistance) { val center = centers(i) val d = distance(center, point) - if (d < statistics(i)(i)) { + val index2 = indexUpperTriangular(k, i, i) + if (d < statistics(index2)) { return (i, d) } else if (d < bestDistance) { bestDistance = d