[SPARK-31454][ML] An optimized K-Means based on DenseMatrix and GEMM #28229

Closed
wants to merge 10 commits into from
321 changes: 300 additions & 21 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -23,12 +23,13 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model, PipelineStage}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.feature.{Instance, InstanceBlock}
import org.apache.spark.ml.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel, VectorWithNorm}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
@@ -42,7 +43,8 @@ import org.apache.spark.util.VersionUtils.majorVersion
* Common params for KMeans and KMeansModel
*/
private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure with HasWeightCol {
with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure with HasWeightCol
with HasBlockSize {

/**
* The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -276,7 +278,9 @@ class KMeans @Since("1.5.0") (
initMode -> MLlibKMeans.K_MEANS_PARALLEL,
initSteps -> 2,
tol -> 1e-4,
distanceMeasure -> DistanceMeasure.EUCLIDEAN)
distanceMeasure -> DistanceMeasure.EUCLIDEAN,
blockSize -> 1
)

@Since("1.5.0")
override def copy(extra: ParamMap): KMeans = defaultCopy(extra)
@@ -330,6 +334,25 @@ class KMeans @Since("1.5.0") (
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

/**
* Sets the block size for stacking input data in matrices.
* If blockSize == 1, stacking is skipped and each vector is treated individually;
* if blockSize > 1, vectors are stacked into blocks and high-level BLAS routines are
* used where possible (for example, GEMV instead of DOT, GEMM instead of GEMV).
* The recommended size is between 10 and 1000. An appropriate block size depends on the
* sparsity and dimension of the input dataset, the underlying BLAS implementation (for
* example, f2jBLAS, OpenBLAS, Intel MKL) and its configuration (for example, the number
* of threads).
* Note that existing BLAS implementations are mainly optimized for dense matrices; if the
* input dataset is sparse, stacking may bring no performance gain and may even cause a
* performance regression. A usage sketch follows the setter below.
* Default is 1.
*
* @group expertSetParam
*/
@Since("3.1.0")
def setBlockSize(value: Int): this.type = set(blockSize, value)
setDefault(blockSize -> 1)
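// A minimal usage sketch for the block-based path (illustration only; the column name
// "features" and the parameter values are assumptions, not part of this file):
//
//   val kmeans = new KMeans()
//     .setK(10)
//     .setFeaturesCol("features")
//     .setBlockSize(64)   // stack 64 rows per block so GEMM can be used
//   val model = kmeans.fit(dataset)
//
// With the default blockSize == 1 the existing row-based MLlib implementation is used.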

@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = instrumented { instr =>
transformSchema(dataset.schema, logging = true)
@@ -341,9 +364,9 @@
lit(1.0)
}

val instances: RDD[(OldVector, Double)] = dataset
val instances: RDD[(Vector, Double)] = dataset
.select(DatasetUtils.columnToVector(dataset, getFeaturesCol), w).rdd.map {
case Row(point: Vector, weight: Double) => (OldVectors.fromML(point), weight)
case Row(point: Vector, weight: Double) => (point, weight)
}

if (handlePersistence) {
@@ -353,33 +376,289 @@
instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
maxIter, seed, tol, weightCol)
val algo = new MLlibKMeans()
.setK($(k))
.setInitializationMode($(initMode))
.setInitializationSteps($(initSteps))
.setMaxIterations($(maxIter))
.setSeed($(seed))
.setEpsilon($(tol))
.setDistanceMeasure($(distanceMeasure))
val parentModel = algo.runWithWeight(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
maxIter, seed, tol, weightCol, blockSize)

val model = if ($(blockSize) == 1) {
trainOnRows(instances)
} else {
trainOnBlocks(instances)
}

val summary = new KMeansSummary(
model.transform(dataset),
$(predictionCol),
$(featuresCol),
$(k),
parentModel.numIter,
parentModel.trainingCost)

model.parentModel.numIter,
model.parentModel.trainingCost)
model.setSummary(Some(summary))
instr.logNamedValue("clusterSizes", summary.clusterSizes)

instr.logNamedValue("clusterSizes", model.summary.clusterSizes)
if (handlePersistence) {
instances.unpersist()
}
model
}

private def trainOnRows(instances: RDD[(Vector, Double)]): KMeansModel =
instrumented { instr =>
val oldVectorInstances = instances.map {
case (point: Vector, weight: Double) => (OldVectors.fromML(point), weight)
}
val algo = new MLlibKMeans()
.setK($(k))
.setInitializationMode($(initMode))
.setInitializationSteps($(initSteps))
.setMaxIterations($(maxIter))
.setSeed($(seed))
.setEpsilon($(tol))
.setDistanceMeasure($(distanceMeasure))
val parentModel = algo.runWithWeight(oldVectorInstances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
model
}

private def trainOnBlocks(instances: RDD[(Vector, Double)]): KMeansModel =
instrumented { instr =>
val instanceRDD: RDD[Instance] = instances.map {
case (point: Vector, weight: Double) => Instance(0.0, weight, point)
}

val blocks = InstanceBlock.blokify(instanceRDD, $(blockSize))
.persist(StorageLevel.MEMORY_AND_DISK)
.setName(s"training dataset (blockSize=${$(blockSize)})")

val sc = instances.sparkContext

val initStartTime = System.nanoTime()

val distanceMeasureInstance = DistanceMeasure.decodeFromString($(distanceMeasure))

// Use MLlibKMeans to initialize centers
val mllibKMeans = new MLlibKMeans()
.setK($(k))
.setInitializationMode($(initMode))
.setInitializationSteps($(initSteps))
.setMaxIterations($(maxIter))
.setSeed($(seed))
.setEpsilon($(tol))
.setDistanceMeasure($(distanceMeasure))
val centers = if ($(initMode) == MLlibKMeans.RANDOM) {
mllibKMeans.initBlocksRandom(blocks)
} else {
mllibKMeans.initBlocksKMeansParallel(blocks, distanceMeasureInstance)
}

val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
logInfo(f"Initialization with $initMode took $initTimeInSeconds%.3f seconds.")

var converged = false
var cost = 0.0
var iteration = 0

val iterationStartTime = System.nanoTime()

instr.logNumFeatures(centers.head.size)

// Execute iterations of Lloyd's algorithm until converged
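// Each iteration works on the stacked blocks:
//   1. broadcast the current centers as a dense (centers_num x centers_dim) matrix;
//   2. for every block, build the points/centers squared-norm matrices and combine them
//      with a single GEMM (-2 * points * centers^T) into all pairwise squared distances;
//   3. pick the closest center per row, accumulate the weighted cost, and reduce the
//      per-cluster sums and counts that drive the center update.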
while (iteration < $(maxIter) && !converged) {
// Convert center vectors to dense matrix
val centers_matrix = Matrices.fromVectors(centers).toDense

val costAccum = sc.doubleAccumulator
val bcCenters = sc.broadcast(centers_matrix)

val centers_num = centers_matrix.numRows
val centers_dim = centers_matrix.numCols

// Compute squared sums for points
val data_square_sums: RDD[DenseMatrix] = blocks.mapPartitions { p =>
p.map { block => computePointsSquareSum(block.matrix.toDense, centers_num) }
}

// Find the new centers
val collected = blocks.zip(data_square_sums).flatMap {
case (block, points_square_sums) =>
val centers_matrix = bcCenters.value
val points_num = block.matrix.numRows

val sums = Array.fill(centers_num)(Vectors.zeros(centers_dim))
val counts = Array.fill(centers_num)(0L)

// Compute squared sums for centers
val centers_square_sums = computeCentersSquareSum(centers_matrix, points_num)

// Compute squared distances
val distances = computeSquaredDistances(
block.matrix.toDense, points_square_sums,
centers_matrix, centers_square_sums)

val (bestCenters, weightedCosts) = findClosest(
distances, block.weights)

weightedCosts.foreach(c => costAccum.add(c))

// sums points around best center
// for ((row, index) <- block.matrix.rowIter.zipWithIndex) {
// val bestCenter = bestCenters(index)
// if (block.weights.nonEmpty) {
// BLAS.axpy(block.weights(index), row, sums(bestCenter))
// } else {
// BLAS.axpy(1, row, sums(bestCenter))
// }
// counts(bestCenter) += 1
// }

// sums points around best center, adding values directly without copying array
if (block.weights.nonEmpty) {
for (rowIndex <- 0 until block.matrix.numRows) {
val bestCenter = bestCenters(rowIndex)
for (i <- 0 until centers_dim)
sums(bestCenter).toArray(i) += block.weights(rowIndex) * block.matrix(rowIndex, i)
counts(bestCenter) += 1
}
} else {
for (rowIndex <- 0 until block.matrix.numRows) {
val bestCenter = bestCenters(rowIndex)
for (i <- 0 until centers_dim)
sums(bestCenter).toArray(i) += block.matrix(rowIndex, i)
counts(bestCenter) += 1
}
}


counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
}.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
BLAS.axpy(1.0, sum2, sum1)
(sum1, count1 + count2)
}.collectAsMap()

if (iteration == 0) {
instr.logNumExamples(collected.values.map(_._2).sum)
}

val newCenters = collected.mapValues { case (sum, count) =>
distanceMeasureInstance.centroid(sum, count)
}

bcCenters.destroy()

// Update the cluster centers and costs
converged = true
newCenters.foreach { case (j, newCenter) =>
if (converged &&
!distanceMeasureInstance.isCenterConverged(
new VectorWithNorm(centers(j)), newCenter, $(tol))) {
converged = false
}
centers(j) = newCenter.vector
}

cost = costAccum.value
iteration += 1
}

val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) / 1e9
logInfo(f"Iterations took $iterationTimeInSeconds%.3f seconds.")

if (iteration == $(maxIter)) {
logInfo(s"KMeans reached the max number of iterations: ${$(maxIter)}.")
} else {
logInfo(s"KMeans converged in $iteration iterations.")
}

logInfo(s"The cost is $cost.")

val parentModel = new MLlibKMeansModel(centers.map(OldVectors.fromML(_)),
$(distanceMeasure), cost, iteration)

val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
model
}

private def computeSquaredDistances(points_matrix: DenseMatrix,
points_square_sums: DenseMatrix,
centers_matrix: DenseMatrix,
centers_square_sums: DenseMatrix): DenseMatrix = {
// (x - y)^2 = x^2 + y^2 - 2 * x * y
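// With P the (points_num x dim) block and C the (centers_num x dim) centers matrix,
// D(i, j) = ||p_i||^2 + ||c_j||^2 - 2 * p_i . c_j, i.e. D = points_square_sums +
// centers_square_sums - 2 * P * C^T, where the two square-sum matrices already hold the
// norms replicated to shape (points_num x centers_num).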

// Add up squared sums of points and centers (x^2 + y^2)
val ret: DenseMatrix = computeMatrixSum(points_square_sums, centers_square_sums)

// use GEMM to compute squared distances, (2*x*y) can be decomposed to matrix multiply
val alpha = -2.0
val beta = 1.0
BLAS.gemm(alpha, points_matrix, centers_matrix.transpose, beta, ret)

ret
}

private def computePointsSquareSum(points_matrix: DenseMatrix,
centers_num: Int): DenseMatrix = {
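// Returns a (points_num x centers_num) matrix whose row i is filled with ||p_i||^2,
// so it can be added element-wise to the pairwise distance matrix.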
val points_num = points_matrix.numRows
val ret = DenseMatrix.zeros(points_num, centers_num)
for ((row, index) <- points_matrix.rowIter.zipWithIndex) {
val square = BLAS.dot(row, row)
for (i <- 0 until centers_num)
ret(index, i) = square
}
ret
}

private def computeCentersSquareSum(centers_matrix: DenseMatrix,
points_num: Int): DenseMatrix = {
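// Returns a (points_num x centers_num) matrix whose column j is filled with ||c_j||^2.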
val centers_num = centers_matrix.numRows
val ret = DenseMatrix.zeros(points_num, centers_num)
for ((row, index) <- centers_matrix.rowIter.zipWithIndex) {
val square = BLAS.dot(row, row)
for (i <- 0 until points_num)
ret(i, index) = square
}
ret
}

// use GEMM to compute matrix sum
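// BLAS.gemm(alpha, A, B, beta, C) computes C := alpha * A * B + beta * C, so with B = I
// and alpha = beta = 1 this accumulates matrix1 into matrix2 in place (matrix2 is mutated
// and returned).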
private def computeMatrixSum(matrix1: DenseMatrix,
matrix2: DenseMatrix): DenseMatrix = {
val column_num = matrix1.numCols
val eye = DenseMatrix.eye(column_num)
val alpha = 1.0
val beta = 1.0
BLAS.gemm(alpha, matrix1, eye, beta, matrix2)
matrix2
}

private def findClosest(distances: DenseMatrix,
weights: Array[Double]): (Array[Int], Array[Double]) = {
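// For each row of the distance matrix, returns the index of the smallest entry and the
// (optionally weighted) distance that the point contributes to the total cost.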
val points_num = distances.numRows
val ret_closest = new Array[Int](points_num)
val ret_cost = new Array[Double](points_num)

for ((row, index) <- distances.rowIter.zipWithIndex) {
var closest = 0
var cost = row(0)
for (i <- 1 until row.size) {
if (row(i) < cost) {
closest = i
cost = row(i)
}
}
ret_closest(index) = closest

// use weighted squared distance as cost
if (weights.nonEmpty) {
ret_cost(index) = cost * weights(index)
} else {
ret_cost(index) = cost
}
}

(ret_closest, ret_cost)

}

@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)