[SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance #20518

Closed · wants to merge 2 commits
@@ -310,19 +310,17 @@ class KMeans private (
         points.foreach { point =>
           val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
           costAccum.add(cost)
-          val sum = sums(bestCenter)
-          axpy(1.0, point.vector, sum)
+          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
           counts(bestCenter) += 1
         }
 
         counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
       }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
         axpy(1.0, sum2, sum1)
         (sum1, count1 + count2)
-      }.mapValues { case (sum, count) =>
-        scal(1.0 / count, sum)
-        new VectorWithNorm(sum)
-      }.collectAsMap()
+      }.collectAsMap().mapValues { case (sum, count) =>
+        distanceMeasureInstance.centroid(sum, count)
+      }
 
       bcCenters.destroy(blocking = false)
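
The hunk above delegates both the per-point accumulation and the final center computation to two new DistanceMeasure hooks, updateClusterSum and centroid, which are defined in the next hunk. As a rough illustration of the resulting flow, here is a minimal, self-contained Scala sketch; it is not Spark code, and the names (SphericalAggregationSketch, Distance, Euclidean, Cosine) are invented for the example and only mirror the shape of the real API.

object SphericalAggregationSketch {

  private def norm(v: Array[Double]): Double = math.sqrt(v.map(x => x * x).sum)

  trait Distance {
    // Euclidean default: fold the raw point into the running cluster sum.
    def updateClusterSum(point: Array[Double], sum: Array[Double]): Unit =
      for (i <- sum.indices) sum(i) += point(i)

    // Euclidean default: the center is the arithmetic mean of the cluster.
    def centroid(sum: Array[Double], count: Long): Array[Double] =
      sum.map(_ / count)
  }

  object Euclidean extends Distance

  object Cosine extends Distance {
    // Spherical k-means: accumulate unit-normalized points (mirrors axpy(1.0 / point.norm, ...)).
    override def updateClusterSum(point: Array[Double], sum: Array[Double]): Unit = {
      val n = norm(point)
      require(n > 0, "Cosine distance is not defined for zero-length vectors.")
      for (i <- sum.indices) sum(i) += point(i) / n
    }

    // ...and re-normalize the mean so the returned center lies on the unit sphere.
    override def centroid(sum: Array[Double], count: Long): Array[Double] = {
      val mean = sum.map(_ / count)
      val n = norm(mean)
      mean.map(_ / n)
    }
  }

  def main(args: Array[String]): Unit = {
    val points = Seq(Array(1.0, 0.0), Array(10.0, 10.0), Array(0.0, 2.0))
    val sum = Array(0.0, 0.0)
    points.foreach(p => Cosine.updateClusterSum(p, sum))
    println(Cosine.centroid(sum, points.size).mkString("[", ", ", "]"))  // a unit-length center
  }
}

Running it prints a center whose 2-norm is 1, which is the invariant the new clusterCenters check in KMeansSuite exercises.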

@@ -657,6 +655,26 @@ private[spark] abstract class DistanceMeasure extends Serializable {
       v1: VectorWithNorm,
       v2: VectorWithNorm): Double
 
+  /**
+   * Updates the value of `sum` adding the `point` vector.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0, point.vector, sum)
+  }
+
+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   *
+   * @param sum the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    new VectorWithNorm(sum)
+  }
 }
 
 @Since("2.4.0")
@@ -743,6 +761,30 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
    * @return the cosine distance between the two input vectors
    */
   override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+    assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
     1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
   }
+
+  /**
+   * Updates the value of `sum` adding the `point` vector.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0 / point.norm, point.vector, sum)

Review thread on the line above:

Contributor: do we need to ignore zero points here?

Contributor Author: the cosine similarity/distance is not defined for zero points: if there were zero-length points we would have hit earlier failures while computing any cosine distance involving them.

Contributor: In Scala, 1.0 / 0.0 generates Infinity; what about directly throwing an exception instead?

Contributor Author: Thanks, I agree. I added an assertion before computing the cosine distance and a test case for this situation. Thank you for your comment.

+  }
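
To make the zero-vector concern from the thread above concrete, here is a small standalone Scala sketch (not Spark code; the helpers are written out only for the example). IEEE-754 division by a zero norm does not throw: it silently yields Infinity or NaN, which is why the assertion added to distance above is the point where the problem becomes a clear, immediate failure.

object ZeroVectorSketch extends App {
  def dot(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => x * y }.sum
  def norm(a: Array[Double]): Double = math.sqrt(dot(a, a))

  val zero = Array(0.0, 0.0)
  val v    = Array(1.0, 0.5)

  // Scaling a zero point by 1.0 / point.norm would not throw; it produces Infinity:
  println(1.0 / norm(zero))                          // Infinity

  // The cosine distance itself degenerates to NaN (0.0 / 0.0), silently corrupting costs:
  println(1 - dot(zero, v) / norm(zero) / norm(v))   // NaN

  // The assertion added in CosineDistanceMeasure.distance surfaces this immediately instead:
  try {
    assert(norm(zero) > 0 && norm(v) > 0, "Cosine distance is not defined for zero-length vectors.")
  } catch {
    case e: AssertionError => println(e.getMessage)  // assertion failed: Cosine distance is not defined ...
  }
}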

+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   *
+   * @param sum the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  override def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    val norm = Vectors.norm(sum, 2)

Review thread on the line above:

Member: Rather than scale sum twice, can you just compute its norm and then scale by 1 / (norm * count * count)?

Contributor Author: do you think the performance improvement would be significant, since we are doing it only on k vectors per run? I think the code is clearer this way, do you agree?

Member: I don't feel strongly about it, yeah. It won't matter much either way.

+    scal(1.0 / norm, sum)
+    new VectorWithNorm(sum, 1)
+  }
 }
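
On the scale-once-versus-twice exchange above: both formulations end at the same unit-length center, because dividing by count only changes a magnitude that the final normalization removes anyway (sum / ||sum|| == (sum / count) / ||sum / count||). A quick standalone Scala check, with illustrative values rather than Spark code:

object CentroidScalingSketch extends App {
  def norm(v: Array[Double]): Double = math.sqrt(v.map(x => x * x).sum)

  val sum   = Array(3.0, 4.0, 12.0)   // a hypothetical cluster sum, with ||sum|| = 13
  val count = 5L

  // Two-step version, as in CosineDistanceMeasure.centroid: divide by count, then by the mean's norm.
  val mean     = sum.map(_ / count)
  val twoSteps = mean.map(_ / norm(mean))

  // Single scaling by the norm of the raw sum lands on the same point of the unit sphere.
  val oneStep = sum.map(_ / norm(sum))

  println(twoSteps.mkString(", "))   // 0.2307..., 0.3076..., 0.9230...
  println(oneStep.mkString(", "))    // the same values, up to floating-point rounding
  println(norm(twoSteps))            // 1.0 (up to rounding)
}
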
@@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering

 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -179,6 +179,19 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(predictionsMap(Vectors.dense(-1.0, 1.0)) ==
       predictionsMap(Vectors.dense(-100.0, 90.0)))
 
+    model.clusterCenters.forall(Vectors.norm(_, 2) == 1.0)
   }
 
+  test("KMeans with cosine distance is not supported for 0-length vectors") {
+    val model = new KMeans().setDistanceMeasure(DistanceMeasure.COSINE).setK(2)
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(Array(
+      Vectors.dense(0.0, 0.0),
+      Vectors.dense(10.0, 10.0),
+      Vectors.dense(1.0, 0.5)
+    )).map(v => TestRow(v)))
+    val e = intercept[SparkException](model.fit(df))
+    assert(e.getCause.isInstanceOf[AssertionError])
+    assert(e.getCause.getMessage.contains("Cosine distance is not defined"))
+  }

test("read/write") {