[SPARK-22119][FOLLOWUP][ML] Use spherical KMeans with cosine distance #20518
@@ -310,19 +310,17 @@ class KMeans private (
         points.foreach { point =>
           val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
           costAccum.add(cost)
-          val sum = sums(bestCenter)
-          axpy(1.0, point.vector, sum)
+          distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
           counts(bestCenter) += 1
         }

         counts.indices.filter(counts(_) > 0).map(j => (j, (sums(j), counts(j)))).iterator
       }.reduceByKey { case ((sum1, count1), (sum2, count2)) =>
         axpy(1.0, sum2, sum1)
         (sum1, count1 + count2)
-      }.mapValues { case (sum, count) =>
-        scal(1.0 / count, sum)
-        new VectorWithNorm(sum)
-      }.collectAsMap()
+      }.collectAsMap().mapValues { case (sum, count) =>
+        distanceMeasureInstance.centroid(sum, count)
+      }

       bcCenters.destroy(blocking = false)
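The hunk above delegates both the per-point accumulation and the final centroid computation to the `DistanceMeasure` instance, so the main loop no longer hard-codes the Euclidean mean. Below is a minimal, self-contained sketch of that pattern in plain Scala; the `Measure` trait, `newCenters`, and `assign` are hypothetical stand-ins for illustration, not Spark's internal API.

```scala
object ClusterUpdateSketch {
  // Stand-in for DistanceMeasure: only the two hooks the hunk introduces.
  trait Measure {
    def updateClusterSum(point: Array[Double], sum: Array[Double]): Unit
    def centroid(sum: Array[Double], count: Long): Array[Double]
  }

  // Default (Euclidean) behaviour: accumulate raw vectors, divide by count.
  object Euclidean extends Measure {
    def updateClusterSum(point: Array[Double], sum: Array[Double]): Unit =
      for (i <- point.indices) sum(i) += point(i)
    def centroid(sum: Array[Double], count: Long): Array[Double] =
      sum.map(_ / count)
  }

  // Mirrors the loop in the hunk: per-point work goes through the measure,
  // and the centroid is computed once per non-empty cluster at the end.
  def newCenters(
      points: Seq[Array[Double]],
      assign: Array[Double] => Int,
      k: Int,
      dim: Int,
      measure: Measure): Map[Int, Array[Double]] = {
    val sums = Array.fill(k)(new Array[Double](dim))
    val counts = new Array[Long](k)
    points.foreach { point =>
      val best = assign(point)
      measure.updateClusterSum(point, sums(best))
      counts(best) += 1
    }
    counts.indices.filter(counts(_) > 0)
      .map(j => j -> measure.centroid(sums(j), counts(j)))
      .toMap
  }
}
```

With this split, supporting the cosine distance only requires overriding the two hooks, which is exactly what the `CosineDistanceMeasure` changes further down do.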
@@ -657,6 +655,26 @@ private[spark] abstract class DistanceMeasure extends Serializable {
       v1: VectorWithNorm,
       v2: VectorWithNorm): Double

+  /**
+   * Updates the value of `sum` adding the `point` vector.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0, point.vector, sum)
+  }
+
+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   *
+   * @param sum the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    new VectorWithNorm(sum)
+  }
 }

 @Since("2.4.0")
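The default implementations preserve the existing Euclidean behaviour: add the raw vector to the running sum, then divide by the number of points. A tiny worked example, assuming exactly that behaviour (plain arrays instead of `Vector`, for illustration only):

```scala
object EuclideanCentroidExample extends App {
  val points = Seq(Array(1.0, 2.0), Array(3.0, 4.0))

  val sum = new Array[Double](2)
  points.foreach(p => for (i <- p.indices) sum(i) += p(i)) // default updateClusterSum: axpy(1.0, point, sum)
  val centroid = sum.map(_ / points.size)                  // default centroid: scal(1.0 / count, sum)

  println(centroid.mkString(", ")) // 2.0, 3.0: the plain arithmetic mean
}
```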
@@ -743,6 +761,30 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
    * @return the cosine distance between the two input vectors
    */
   override def distance(v1: VectorWithNorm, v2: VectorWithNorm): Double = {
+    assert(v1.norm > 0 && v2.norm > 0, "Cosine distance is not defined for zero-length vectors.")
     1 - dot(v1.vector, v2.vector) / v1.norm / v2.norm
   }

+  /**
+   * Updates the value of `sum` adding the `point` vector.
+   * @param point a `VectorWithNorm` to be added to `sum` of a cluster
+   * @param sum the `sum` for a cluster to be updated
+   */
+  override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
+    axpy(1.0 / point.norm, point.vector, sum)
+  }
+
+  /**
+   * Returns a centroid for a cluster given its `sum` vector and its `count` of points.
+   *
+   * @param sum the `sum` for a cluster
+   * @param count the number of points in the cluster
+   * @return the centroid of the cluster
+   */
+  override def centroid(sum: Vector, count: Long): VectorWithNorm = {
+    scal(1.0 / count, sum)
+    val norm = Vectors.norm(sum, 2)
+    scal(1.0 / norm, sum)
+    new VectorWithNorm(sum, 1)
+  }
 }

Inline review thread on `centroid`:

Comment: Rather than scale …
Reply: do you think that the performance improvement would be significant since we are doing it only on …
Reply: I don't feel strongly about it, yeah. It won't matter much either way.
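In the cosine overrides, each point is scaled by `1.0 / point.norm` before being added to the sum, and the averaged sum is renormalized to unit length; this is what makes the update a spherical k-means step. A self-contained sketch of that arithmetic (plain arrays and hypothetical helper names, not Spark internals):

```scala
object CosineCentroidSketch extends App {
  def norm(v: Array[Double]): Double = math.sqrt(v.map(x => x * x).sum)

  // Mirrors CosineDistanceMeasure.updateClusterSum: axpy(1.0 / point.norm, point, sum).
  // The guard plays the role of the assertion added in `distance`.
  def updateClusterSum(point: Array[Double], sum: Array[Double]): Unit = {
    val n = norm(point)
    require(n > 0, "Cosine distance is not defined for zero-length vectors.")
    for (i <- point.indices) sum(i) += point(i) / n
  }

  // Mirrors CosineDistanceMeasure.centroid: divide by count, then renormalize,
  // so the returned centroid always has unit L2 norm (VectorWithNorm(sum, 1)).
  def centroid(sum: Array[Double], count: Long): Array[Double] = {
    val mean = sum.map(_ / count)
    val n = norm(mean)
    mean.map(_ / n)
  }

  val sum = new Array[Double](2)
  Seq(Array(3.0, 0.0), Array(0.0, 4.0)).foreach(updateClusterSum(_, sum))
  // ~ (0.7071, 0.7071): the cluster's direction, independent of the input magnitudes.
  println(centroid(sum, 2).mkString(", "))
}
```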
Review thread on zero-length vectors:

Comment: do we need to ignore zero points here?
Reply: The cosine similarity/distance is not defined for zero points: if there were zero points, we would have had earlier failures while computing any cosine distance involving them.
Comment: In Scala, `1.0 / 0.0` generates `Infinity`; what about directly throwing an exception instead?
Reply: Thanks, I agree. I added an assertion before computing the cosine distance and a test case for this situation. Thank you for your comment.
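Below is a hedged sketch of the kind of end-to-end check this discussion points at, using only the public ML API. It is not the test added in the PR; the data, column name, and local-mode setup are illustrative assumptions. The idea is that fitting k-means with the cosine distance on data containing a zero vector should fail fast rather than silently propagate `Infinity`/`NaN`.

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object ZeroVectorCosineCheck extends App {
  val spark = SparkSession.builder()
    .master("local[1]")
    .appName("zero-vector-cosine-check")
    .getOrCreate()
  import spark.implicits._

  // One zero-length vector among the points: cosine distance is undefined for it.
  val df = Seq(
    Vectors.dense(0.0, 0.0),
    Vectors.dense(1.0, 1.0),
    Vectors.dense(2.0, 0.5)
  ).map(Tuple1.apply).toDF("features")

  val kmeans = new KMeans().setK(2).setSeed(1L).setDistanceMeasure("cosine")

  // With the assertion in CosineDistanceMeasure.distance, training is expected
  // to fail as soon as a distance involving the zero vector is computed.
  try {
    kmeans.fit(df)
    println("Training unexpectedly succeeded")
  } catch {
    case e: Throwable => println(s"Training failed as expected: ${e.getMessage}")
  }

  spark.stop()
}
```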