Changes to support KMeans with large feature space #10739

Closed
mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Since
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS.{axpy, scal}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
@@ -45,7 +45,9 @@ class KMeans private (
     private var initializationMode: String,
     private var initializationSteps: Int,
     private var epsilon: Double,
-    private var seed: Long) extends Serializable with Logging {
+    private var seed: Long,
+    private var vectorFactory: VectorFactory = DenseVectorFactory.instance
+  ) extends Serializable with Logging {
 
   /**
    * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,
@@ -176,6 +178,13 @@ class KMeans private (
     this
   }
 
+  def getVectorFactory: VectorFactory = vectorFactory
+
+  def setVectorFactory(vectorFactory: VectorFactory): this.type = {
+    this.vectorFactory = vectorFactory
+    this
+  }
+
   // Initial cluster centers can be provided as a KMeansModel object rather than using the
   // random or k-means|| initializationMode
   private var initialModel: Option[KMeansModel] = None
@@ -282,7 +291,8 @@ class KMeans private (
     val k = thisActiveCenters(0).length
     val dims = thisActiveCenters(0)(0).vector.size
 
-    val sums = Array.fill(runs, k)(Vectors.zeros(dims))
+    // val sums = Array.fill(runs, k)(Vectors.zeros(dims))
+    val sums = Array.fill(runs, k)(vectorFactory.zeros(dims))
     val counts = Array.fill(runs, k)(0L)
 
     points.foreach { point =>
@@ -376,7 +386,8 @@
     // Initialize each run's first center to a random point.
     val seed = new XORShiftRandom(this.seed).nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
-    val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+    // val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+    val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).compact(vectorFactory)))
 
     /** Merges new centers to centers. */
     def mergeNewCenters(): Unit = {
@@ -436,7 +447,8 @@
       }.collect()
       mergeNewCenters()
       chosen.foreach { case (p, rs) =>
-        rs.foreach(newCenters(_) += p.toDense)
+        // rs.foreach(newCenters(_) += p.toDense)
+        rs.foreach(newCenters(_) += p)
       }
       step += 1
     }
@@ -459,7 +471,7 @@
     val finalCenters = (0 until runs).par.map { r =>
       val myCenters = centers(r).toArray
       val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
-      LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
+      LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30, vectorFactory)
     }
 
     finalCenters.toArray
@@ -488,6 +500,7 @@ object KMeans {
   * @param runs number of parallel runs, defaults to 1. The best model is returned.
   * @param initializationMode initialization mode, either "random" or "k-means||" (default).
   * @param seed random seed value for cluster initialization
+  * @param vectorFactory factory used to create the vectors that represent the centroids
   */
  @Since("1.3.0")
  def train(
@@ -496,12 +509,14 @@
       maxIterations: Int,
       runs: Int,
       initializationMode: String,
-      seed: Long): KMeansModel = {
+      seed: Long,
+      vectorFactory: VectorFactory = DenseVectorFactory.instance): KMeansModel = {
     new KMeans().setK(k)
       .setMaxIterations(maxIterations)
       .setRuns(runs)
       .setInitializationMode(initializationMode)
       .setSeed(seed)
+      .setVectorFactory(vectorFactory)
       .run(data)
   }
 
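For orientation, a minimal sketch of how the extended train overload would be called, assuming the API added in this diff; the SparkContext sc and the million-dimensional hashed features are illustrative, not part of the change:

import org.apache.spark.mllib.clustering.{KMeans, SmartVectorFactory}
import org.apache.spark.mllib.linalg.Vectors

// Very wide, mostly-empty feature vectors, e.g. hashed text features.
val data = sc.parallelize(Seq(
  Vectors.sparse(1000000, Array(3, 70000), Array(1.0, 2.0)),
  Vectors.sparse(1000000, Array(3, 250000), Array(1.0, 4.0))))

// With SmartVectorFactory each per-run, per-cluster sum vector starts out
// sparse instead of allocating a dense array of a million doubles.
val model = KMeans.train(data, k = 2, maxIterations = 10, runs = 1,
  initializationMode = KMeans.K_MEANS_PARALLEL, seed = 42L,
  vectorFactory = SmartVectorFactory.instance)
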
@@ -617,5 +632,33 @@ class VectorWithNorm(val vector: Vector, val norm: Double) extends Serializable
   def this(array: Array[Double]) = this(Vectors.dense(array))
 
-  /** Converts the vector to a dense vector. */
-  def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
+  // def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
+
+  /** Compacts the vector's representation using the given factory. */
+  def compact(fact: VectorFactory): VectorWithNorm = new VectorWithNorm(fact.compact(vector), norm)
 }

+trait VectorFactory extends Serializable {
+  def zeros(size: Int): Vector
+
+  def compact(vec: Vector): Vector
+}
+
+class DenseVectorFactory private() extends VectorFactory {
+  override def zeros(size: Int): Vector = Vectors.zeros(size)
+
+  override def compact(vec: Vector): Vector = vec.toDense
+}
+
+object DenseVectorFactory {
+  val instance = new DenseVectorFactory
+}
+
+class SmartVectorFactory private() extends VectorFactory {
+  override def zeros(size: Int): Vector = new SparseVector(size, Array.empty, Array.empty)
+
+  override def compact(vec: Vector): Vector = vec.compressed
+}
+
+object SmartVectorFactory {
+  val instance = new SmartVectorFactory
+}
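
A quick sketch of what the two factories produce, assuming the definitions above; Vector.toDense always materializes a full array, while Vector.compressed keeps whichever of the dense and sparse encodings is smaller:

import org.apache.spark.mllib.linalg.Vectors

val v = Vectors.sparse(10000, Array(1, 42), Array(0.5, 2.0))

// DenseVectorFactory densifies unconditionally: a 10000-element array.
val asDense = DenseVectorFactory.instance.compact(v)

// SmartVectorFactory hands out empty sparse vectors as accumulators and
// keeps a compacted vector sparse while that is the cheaper encoding.
val zeros = SmartVectorFactory.instance.zeros(10000) // sparse, no entries
val still = SmartVectorFactory.instance.compact(v)   // stays sparse: 2 of 10000 set
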
mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala

@@ -38,14 +38,15 @@ private[mllib] object LocalKMeans extends Logging {
       points: Array[VectorWithNorm],
       weights: Array[Double],
       k: Int,
-      maxIterations: Int
+      maxIterations: Int,
+      vectorFactory: VectorFactory
   ): Array[VectorWithNorm] = {
     val rand = new Random(seed)
     val dimensions = points(0).vector.size
     val centers = new Array[VectorWithNorm](k)
 
     // Initialize centers by sampling using the k-means++ procedure.
-    centers(0) = pickWeighted(rand, points, weights).toDense
+    centers(0) = pickWeighted(rand, points, weights).compact(vectorFactory)
     for (i <- 1 until k) {
       // Pick the next center with a probability proportional to cost under current centers
       val curCenters = centers.view.take(i)
@@ -62,9 +63,9 @@
         if (j == 0) {
           logWarning("kMeansPlusPlus initialization ran out of distinct points for centers." +
             s" Using duplicate point for center k = $i.")
-          centers(i) = points(0).toDense
+          centers(i) = points(0).compact(vectorFactory)
         } else {
-          centers(i) = points(j - 1).toDense
+          centers(i) = points(j - 1).compact(vectorFactory)
         }
       }
 
@@ -93,7 +94,7 @@
       while (j < k) {
         if (counts(j) == 0.0) {
           // Assign center to a random point
-          centers(j) = points(rand.nextInt(points.length)).toDense
+          centers(j) = points(rand.nextInt(points.length)).compact(vectorFactory)
         } else {
           scal(1.0 / counts(j), sums(j))
           centers(j) = new VectorWithNorm(sums(j))
mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala (78 additions, 0 deletions)

@@ -54,6 +54,16 @@ private[spark] object BLAS extends Serializable with Logging {
             throw new UnsupportedOperationException(
               s"axpy doesn't support x type ${x.getClass}.")
         }
+      case sy: SparseVector =>
+        x match {
+          case sx: SparseVector =>
+            axpy(a, sx, sy)
+          case dx: DenseVector =>
+            axpy(a, dx, sy)
+          case _ =>
+            throw new UnsupportedOperationException(
+              s"axpy doesn't support x type ${x.getClass}.")
+        }
       case _ =>
         throw new IllegalArgumentException(
-          s"axpy only supports adding to a dense vector but got type ${y.getClass}.")
+          s"axpy only supports adding to a dense or sparse vector but got type ${y.getClass}.")
@@ -92,6 +102,74 @@
     }
   }
 
+  /**
+   * y += a * x
+   */
+  private def axpy(a: Double, x: DenseVector, y: SparseVector): Unit = {
+    require(x.size == y.size)
+
+    // Sparsify x, then reuse the sparse-sparse merge below.
+    val xIndices = (0 until x.size).filter(i => x(i) != 0.0).toArray
+    val xValues = xIndices.map(i => x(i))
+
+    axpy(a, Vectors.sparse(x.size, xIndices, xValues), y)
+  }
+
+  /**
+   * y += a * x
+   */
+  private def axpy(a: Double, x: SparseVector, y: SparseVector): Unit = {
+    val xSortedIndices = x.indices
+    val xValues = x.values
+
+    val ySortedIndices = y.indices
+    val yValues = y.values
+
+    // Worst case: the index sets are disjoint and every entry survives.
+    val newIndices = new Array[Int](xSortedIndices.length + ySortedIndices.length)
+    val newValues = new Array[Double](xValues.length + yValues.length)
+
+    assert(newIndices.length == newValues.length)
+
+    var xj = 0
+    var yj = 0
+    var j = 0
+    var previ = Int.MinValue
+
+    def getAt(indices: Array[Int], j: Int): Int =
+      if (j < indices.length) indices(j) else Int.MaxValue
+
+    // Merge the two sorted index lists in a single pass.
+    while (xj < xSortedIndices.length || yj < ySortedIndices.length) {
+      val xi = getAt(xSortedIndices, xj)
+      val yi = getAt(ySortedIndices, yj)
+
+      val (i, value) = if (xi <= yi) {
+        val vv = a * xValues(xj)
+        xj += 1
+        (xi, vv)
+      } else {
+        val vv = yValues(yj)
+        yj += 1
+        (yi, vv)
+      }
+
+      assert(i >= previ)
+
+      if (previ != i) {
+        newIndices(j) = i
+        newValues(j) = value
+        j += 1
+      } else {
+        // Same index seen in both inputs: accumulate into the last slot.
+        assert(newIndices(j - 1) == i)
+        newValues(j - 1) += value
+      }
+
+      previ = i
+    }
+
+    y.reassign(newIndices.slice(0, j), newValues.slice(0, j))
+  }
+
   /** Y += a * x */
   private[spark] def axpy(a: Double, X: DenseMatrix, Y: DenseMatrix): Unit = {
     require(X.numRows == Y.numRows && X.numCols == Y.numCols, "Dimension mismatch: " +
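
To make the merge concrete, a hand-checked sketch of what the new sparse-sparse path computes; BLAS is private[spark], so this only compiles from code inside the org.apache.spark package:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.BLAS

val x = Vectors.sparse(5, Array(0, 2), Array(1.0, 3.0))
val y = Vectors.sparse(5, Array(2, 4), Array(1.0, 5.0))

// y += 2.0 * x, merging the two sorted index lists in one pass:
//   index 0: 2.0 * 1.0       = 2.0  (only in x)
//   index 2: 1.0 + 2.0 * 3.0 = 7.0  (in both, accumulated in place)
//   index 4: 5.0                    (only in y)
BLAS.axpy(2.0, x, y)
println(y) // (5,[0,2,4],[2.0,7.0,5.0])
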
mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala (25 additions, 9 deletions)

@@ -700,21 +700,37 @@ object DenseVector {
  * A sparse vector represented by an index array and a value array.
  *
  * @param size size of the vector.
- * @param indices index array, assume to be strictly increasing.
- * @param values value array, must have the same length as the index array.
+ * @param sortedIndices index array, assumed to be strictly increasing.
+ * @param sortedValues value array, must have the same length as the index array.
  */
 @Since("1.0.0")
 @SQLUserDefinedType(udt = classOf[VectorUDT])
 class SparseVector @Since("1.0.0") (
     @Since("1.0.0") override val size: Int,
-    @Since("1.0.0") val indices: Array[Int],
-    @Since("1.0.0") val values: Array[Double]) extends Vector {
+    @Since("1.0.0") private var sortedIndices: Array[Int],
+    @Since("1.0.0") private var sortedValues: Array[Double]) extends Vector {
+
+  require(allRequirements())
+
+  def allRequirements(): Boolean = {
+    require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
+      s" indices match the dimension of the values. You provided ${indices.length} indices and " +
+      s" ${values.length} values.")
+    require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
+      s"which exceeds the specified vector size ${size}.")
+
+    true
+  }
+
+  def reassign(newSortedIndices: Array[Int], newValues: Array[Double]): Unit = {
+    sortedIndices = newSortedIndices
+    sortedValues = newValues
+    require(allRequirements())
+  }
+
+  def indices: Array[Int] = sortedIndices
 
-  require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
-    s" indices match the dimension of the values. You provided ${indices.length} indices and " +
-    s" ${values.length} values.")
-  require(indices.length <= size, s"You provided ${indices.length} indices and values, " +
-    s"which exceeds the specified vector size ${size}.")
+  def values: Array[Double] = sortedValues
 
   override def toString: String =
     s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
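
Finally, a small sketch of the in-place mutation that the BLAS change relies on, assuming the reassign method above; because indices and values are now defs backed by private vars, existing callers compile unchanged:

import org.apache.spark.mllib.linalg.{SparseVector, Vectors}

val v = Vectors.sparse(10, Array(1), Array(2.0)).asInstanceOf[SparseVector]

// Swap in new (sorted) index and value arrays; allRequirements() re-checks
// the SparseVector invariants against the replacement arrays.
v.reassign(Array(1, 7), Array(2.0, 3.0))
println(v) // (10,[1,7],[2.0,3.0])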