[SPARK-22119][ML] Add cosine distance to KMeans #19340
Conversation
Not a bad idea but I'm not sure about the design direction here.
@@ -260,7 +269,8 @@ class KMeans @Since("1.5.0") (
       maxIter -> 20,
       initMode -> MLlibKMeans.K_MEANS_PARALLEL,
       initSteps -> 2,
-      tol -> 1e-4)
+      tol -> 1e-4,
+      distanceMeasure -> DistanceSuite.EUCLIDEAN)
"DistanceSuite" sounds like a test case, which you can't use here, but, looks like it's an object you added in non-test code. That's confusing.
Do you have any suggestion for an appropriate name? Thanks.
@@ -71,6 +71,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   @Since("1.5.0")
   def getInitMode: String = $(initMode)

+  @Since("2.3.0")
+  final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " +
Interesting question here -- what about supplying a function as an argument, for full generality? but then that doesn't translate to Pyspark I guess, and, probably only 2-3 distance functions that ever make sense here.
This would be hard for two main reasons:
1. as I will explain later, even though theoretically we would only need a function, in practice this is not true for performance reasons;
2. saving and loading a function would be much harder (I'm not sure it would even be feasible).
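The second point is the persistence argument. A rough standalone sketch (no Spark dependency; the object and registry names here are illustrative, not the actual PR code) of why a string-valued parameter is easier to save and load than a function-valued one: the string round-trips trivially, and the function is looked up again by name on load.

```scala
object ParamPersistence {
  // Registry mapping a persistable string name to a distance function.
  val registry: Map[String, (Array[Double], Array[Double]) => Double] = Map(
    "euclidean" -> { (a, b) =>
      math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)
    },
    "cosine" -> { (a, b) =>
      val dot = a.zip(b).map { case (x, y) => x * y }.sum
      val na = math.sqrt(a.map(x => x * x).sum)
      val nb = math.sqrt(b.map(x => x * x).sum)
      1.0 - dot / (na * nb)
    }
  )

  // Saving the parameter is just saving the string; loading resolves
  // the function again from the registry.
  def decodeFromString(name: String): (Array[Double], Array[Double]) => Double =
    registry.getOrElse(name, sys.error(s"Unknown distance measure: $name"))
}
```

A closure, by contrast, has no portable serialized form and cannot be expressed in PySpark, which is the other concern raised above.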
@@ -40,20 +40,29 @@ import org.apache.spark.util.random.XORShiftRandom
  * to it should be cached by the user.
  */
 @Since("0.8.0")
-class KMeans private (
+class KMeans @Since("2.3.0") private (
If it's a private constructor, it shouldn't have API "Since" annotations. You don't need to preserve it for compatibility.
  }

  private[spark] def validateDistanceMeasure(distanceMeasure: String): Boolean = {
    distanceMeasure match {
      case DistanceSuite.EUCLIDEAN => true
You can use two labels in one statement if the result is the same; might be clearer. Match is probably overkill anyway
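The two alternatives suggested here — combining patterns with `|` in one case clause, or dropping the match entirely — could look roughly like this (a sketch with hypothetical names, not the PR's actual code):

```scala
object DistanceValidation {
  val EUCLIDEAN = "euclidean"
  val COSINE = "cosine"

  // "Two labels in one statement": alternative patterns share one case.
  def validateWithMatch(measure: String): Boolean = measure match {
    case EUCLIDEAN | COSINE => true
    case _ => false
  }

  // Without match at all: membership test against the supported set.
  def validateWithSet(measure: String): Boolean =
    Set(EUCLIDEAN, COSINE).contains(measure)
}
```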
I just wanted to be consistent with the similar implementation three lines above. Doing the same thing in two different ways a few lines apart might be very confusing IMHO.
  }

  @Since("2.3.0")
  object DistanceSuite {
Why is this an API (and why so named)? This can all be internal.
About the name, if you have any better suggestion, I'd be happy to change it. Maybe DistanceMeasure?
This is not internal because it contains the definition of the two constants which users might use to set the desired distance measure.
  /**
   * Returns the index of the closest center to the given point, as well as the squared distance.
   */
  def findClosest(
Why would something like this vary with the distance function? finding a closest thing is the same for all definitions of close-ness.
Even though you are right in theory, if you look at the implementation for the Euclidean distance, the current code contains an optimization which doesn't use the real distance measure, for performance reasons. Thus, dropping this method and introducing a more generic one would cause a performance regression for the Euclidean distance, which is something I'd definitely like to avoid.
It seems like this should have a default implementation then that does the obvious thing
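The structure being converged on — a base class whose findClosest has a default implementation written against a single abstract distance method, which subclasses may override with optimized versions — could be sketched roughly like this (class names and signatures here are illustrative, not the PR's actual code):

```scala
// Base class: only `distance` is abstract; `findClosest` gets the
// "obvious" default implementation in terms of it.
abstract class DistanceMeasureSketch extends Serializable {
  def distance(v1: Array[Double], v2: Array[Double]): Double

  // Default: linear scan over all centers, tracking the best so far.
  def findClosest(centers: Seq[Array[Double]], point: Array[Double]): (Int, Double) = {
    var bestIndex = 0
    var bestDistance = Double.PositiveInfinity
    centers.zipWithIndex.foreach { case (center, i) =>
      val d = distance(center, point)
      if (d < bestDistance) {
        bestDistance = d
        bestIndex = i
      }
    }
    (bestIndex, bestDistance)
  }
}

// A new measure then only needs to implement `distance`.
class CosineDistanceSketch extends DistanceMeasureSketch {
  private def dot(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => x * y }.sum
  private def norm(a: Array[Double]): Double = math.sqrt(dot(a, a))

  // Cosine distance = 1 - cosine similarity.
  override def distance(v1: Array[Double], v2: Array[Double]): Double =
    1.0 - dot(v1, v2) / (norm(v1) * norm(v2))
}
```

An optimized Euclidean subclass would additionally override findClosest itself, which is how the existing fast path can be preserved.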
@yanboliang could you please take a look at this when you have time? Thanks.
@@ -246,14 +271,16 @@ class KMeans private (

    val initStartTime = System.nanoTime()

    val distanceSuite = DistanceMeasure.decodeFromString(this.distanceMeasure)
Why is this called "suite"?
  /**
   * Returns whether a center converged or not, given the epsilon parameter.
   */
  def isCenterConverged(
Likewise this always seems to be "distance < epsilon"; does it ever vary?
It's not for the EuclideanDistance, since in that case, for performance reasons, the check is: distance^2 < epsilon^2.
Anyway, as you suggested for findClosest, I refactored the base class to have a default implementation based on the distance method I introduced. Then, for the Euclidean distance, I override all the methods with the more efficient implementations. This looks like the best and cleanest approach to me, since it allows adding more distance measures by implementing only the distance method, as the current CosineDistance implementation does.
Thank you for your comments. Please take a look at the new structure when you have time and let me know if it looks good to you now. Thanks.
Kindly reminding @srowen and @yanboliang: could you take a look at this when you have time? Thanks.
I'm kind of neutral given the complexity of adding this, but maybe it's the least complexity you can get away with. @hhbyyh was adding something related: https://issues.apache.org/jira/browse/SPARK-22195
Thanks for your reply @srowen. I saw it. My feeling is that so far there is no distance metric definition on …
Test build #3945 has finished for PR 19340 at commit
Kindly pinging @yanboliang
@mgaido91 I actually needed something like this recently and I stumbled upon your PR (and the JIRA, which unfortunately I cannot update).
Hi @Kevin-Ferret, thanks for looking at this. Yes, you are right, I have not changed the method for updating the centroids. The current method seems to me the most widely adopted for cosine similarity too. Indeed, the same approach is used in RapidMiner (https://docs.rapidminer.com/latest/studio/operators/modeling/segmentation/k_means.html) and also in this paper (https://s3.amazonaws.com/academia.edu.documents/32952068/pg049_Similarity_Measures_for_Text_Document_Clustering.pdf?AWSAccessKeyId=AKIAIWOWYYGZ2Y53UL3A&Expires=1513706450&Signature=MFPcahadw35IpP2o0v%2F51xW7KOM%3D&response-content-disposition=inline%3B%20filename%3DSimilarity_Measures_for_Text_Document_Cl.pdf).
@srowen this has been stuck for a while now. Nobody so far has been able to provide a "less complex" proposal. I have tried to ping all the people I was aware of who might help. Do you have any suggestion on how to proceed? Thanks.
Ideally @jkbradley could look at this, or @MLnick, as they are closer to this part, but it's looking good to me.
  }

  @Since("2.3.0")
All the "2.3.0" would likely have to change. I don't know if this would get in for 2.3.0.
Yes — any idea which version I should target here?
I'd default to 2.4.0
object DistanceMeasure {

  @Since("2.3.0")
  val EUCLIDEAN = "euclidean"
Ideally we'd use an enum for this but I don't think Scala's enums are encouraged, and probably not worth involving Java enums.
@@ -149,4 +173,38 @@ object KMeansModel extends Loader[KMeansModel] {
    new KMeansModel(localCentroids.sortBy(_.id).map(_.point))
  }
}
object SaveLoadV2_0 {
Nit: blank line before?
  }

  /**
   * Returns the K-means cost of a given point against the given cluster centers.
Nit: might make this a @return tag so that it renders in the docs as the return documentation.
  def pointCost(
      centers: TraversableOnce[VectorWithNorm],
      point: VectorWithNorm): Double =
    findClosest(centers, point)._2
Another nit, might put braces around these one-line functions just for clarity.
      oldCenter: VectorWithNorm,
      newCenter: VectorWithNorm,
      epsilon: Double): Boolean = {
    EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter) <= epsilon * epsilon
Do we need to override the default isCenterConverged here? It seems to me it is equal to the default one.
It is not: here we compare with epsilon * epsilon and use the squared distance, in order to avoid computing the square root, which is an expensive operation.
Add sqrt to both sides? Math.sqrt(EuclideanDistanceMeasure.fastSquaredDistance(newCenter, oldCenter)) <= Math.sqrt(epsilon * epsilon) = epsilon. The left-hand side is just the overridden distance, isn't it?
Using sqrt would introduce a performance regression. This is the reason why I can't use only a function to differentiate the two distance measures: the implementation for the Euclidean distance is highly optimized, and this is one of its optimizations. Avoiding sqrt can be a great performance improvement, since it is an expensive operation.
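The optimization being defended here can be shown side by side in a small standalone sketch (object and method names are illustrative): the two checks are mathematically equivalent, but the fast one never calls sqrt in the inner loop.

```scala
object ConvergenceCheck {
  // Squared Euclidean distance between two points.
  def squaredDistance(a: Array[Double], b: Array[Double]): Double = {
    var sum = 0.0
    var i = 0
    while (i < a.length) {
      val diff = a(i) - b(i)
      sum += diff * diff
      i += 1
    }
    sum
  }

  // Default check: distance <= epsilon (one sqrt per call).
  def isConvergedDefault(oldC: Array[Double], newC: Array[Double], epsilon: Double): Boolean =
    math.sqrt(squaredDistance(oldC, newC)) <= epsilon

  // Optimized Euclidean check: squaredDistance <= epsilon^2.
  // Same boolean result, no sqrt.
  def isConvergedFast(oldC: Array[Double], newC: Array[Double], epsilon: Double): Boolean =
    squaredDistance(oldC, newC) <= epsilon * epsilon
}
```

Since both sides of the comparison are non-negative, squaring preserves the ordering, so the fast check returns exactly the same answer as the default one.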
private[spark] abstract class DistanceMeasure extends Serializable {

  /**
   * @return the index of the closest center to the given point, as well as the squared distance.
It doesn't always return the squared distance now.
Based on some discussion I could find quickly, I am not sure we can support cosine distance just by replacing the distance function:
@srowen thank you for pointing out my style issue. Addressed, thanks.
The link to the Matlab doc explicitly describes how it computes cluster centroids differently for the different supported distance measures. For cosine distance, the centroids are computed from normalized points before taking the mean. In this respect, Matlab's approach seems more comprehensive than RapidMiner's, which only takes the mean of the points without normalization.
I quickly looked at Spark's KMeans implementation; it looks like we also compute the centroids as the mean of the points without normalization. I'm not sure whether this can be an issue in practical usage of KMeans and affect its results or correctness. If we don't want to update centroids differently for different distance measures, I think we should at least clarify it in the documentation to warn users.
Test build #4064 has finished for PR 19340 at commit
/**
 * A vector with its norm for fast distance computation.
 *
 * @see [[org.apache.spark.mllib.clustering.KMeans#fastSquaredDistance]]
This seems to fail the doc build for some reason. You can just remove it.
Jenkins, retest this please
@srowen sorry, I don't know why, but it seems that I cannot start new Jenkins jobs for this PR... Could you whitelist it or trigger a new test, please? Thanks.
I think it may not be responding now for whatever reason. I use https://spark-prs.appspot.com/ to view and trigger tests
Thanks, I didn't know it existed.
Test build #4065 has finished for PR 19340 at commit
Merged to master
@mgaido91 @srowen I have the same concern as @Kevin-Ferret and @viirya. Matlab's doc for KMeans says: "One minus the cosine of the included angle between points (treated as vectors). Each centroid is the mean of the points in that cluster, after normalizing those points to unit Euclidean length." I think RapidMiner's implementation of KMeans with cosine similarity is wrong if it just assigns the new center to the arithmetic mean. Some references:
Scikit-Learn's example: Clustering text documents using k-means
https://stats.stackexchange.com/questions/299013/cosine-distance-as-similarity-measure-in-kmeans
https://www.quora.com/How-can-I-use-cosine-similarity-in-clustering-For-example-K-means-clustering
@Kevin-Ferret pointed out that both the input and the centers should be normalized to unit Euclidean length. Citing you, @zhengruifeng:
Therefore, ensuring convergence means that the input dataset should contain unit-length vectors, but this should be done by the user. I think we can add a comment in the documentation, or add a check and a WARN, but the latter has a performance impact.
I think you could reasonably define it either way; it depends on how much you think the cluster center is always defined as the mean (as in "k-means") regardless of the distance function, or not. However, I'm now more sympathetic to defining the center as the point that minimizes intra-cluster distance, which isn't quite the same thing. In that case, yes, you must normalize the inputs in order for Euclidean distance and cosine distance to match up. You could tell the user that she can basically choose this behavior or not by normalizing or not, but I'd now say that's more potential for surprise than a useful choice. So yeah, I'd also support going back and normalizing the inputs in all cases when cosine distance is used.
The updating of centers should be viewed as the M-step of an EM algorithm, in which some objective is optimized. Since cosine similarity does not take the vector norm into account:
If we want to optimize intra-cluster cosine similarity (like Matlab), then the arithmetic mean of the normalized points should be a better solution than the arithmetic mean of the original points. Suppose two 2D points (x=0, y=1) and (x=100, y=0):
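The 2D example above can be worked through numerically. This is a hypothetical standalone sketch (object and value names are illustrative): for points (0, 1) and (100, 0), compare the plain arithmetic mean with the mean of the unit-normalized points as a cluster center under cosine distance.

```scala
object CentroidExample {
  def cosineDistance(a: Array[Double], b: Array[Double]): Double = {
    val dot = a.zip(b).map { case (x, y) => x * y }.sum
    val na = math.sqrt(a.map(x => x * x).sum)
    val nb = math.sqrt(b.map(x => x * x).sum)
    1.0 - dot / (na * nb)
  }

  def normalize(v: Array[Double]): Array[Double] = {
    val n = math.sqrt(v.map(x => x * x).sum)
    v.map(_ / n)
  }

  val p1 = Array(0.0, 1.0)
  val p2 = Array(100.0, 0.0)

  // Plain mean: (50, 0.5) -- dominated by the large-norm point p2.
  val plainMean = Array((p1(0) + p2(0)) / 2, (p1(1) + p2(1)) / 2)

  // Mean of normalized points: (0.5, 0.5) -- angularly equidistant.
  val sphericalMean = {
    val (n1, n2) = (normalize(p1), normalize(p2))
    Array((n1(0) + n2(0)) / 2, (n1(1) + n2(1)) / 2)
  }
}
```

The plain mean is almost angularly coincident with the large-norm point and nearly orthogonal to the other, while the mean of the normalized points sits at the same cosine distance from both, which is the behavior the Matlab-style update aims for.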
@zhengruifeng I agree with you, but then we should also normalize the center points, since the user can access them and would therefore expect them to be unit-norm vectors. WDYT?
@zhengruifeng yes I understand why the solutions aren't the same, though it depends on whether you think that's what k-means is supposed to do or not. We're not actually maximizing an expectation here, but, this is just semantics and I agree with you.
@mgaido91 what do you think is the right follow-up here? As in your comment just above?
@mgaido91 agree that it is better to normalize the centers
@srowen honestly I don't think we should change the current implementation. RapidMiner, ELKI and nltk work like this. Matlab instead works differently and does what is suggested by @Kevin-Ferret and @zhengruifeng. Anyway, it looks like a majority (@viirya, @Kevin-Ferret, @zhengruifeng) think the other solution is better. So I think that if we change it, we should basically make the change suggested by @zhengruifeng plus the normalization of the centers; otherwise we would end up with a hybrid and unclear solution. I can submit a follow-up PR with this second solution and maybe we can continue the discussion there. What do you think?
## What changes were proposed in this pull request?

In apache#19340 some comments suggested using spherical KMeans when the cosine distance measure is specified, as Matlab does, instead of the implementation based on the behavior of other tools/libraries like RapidMiner, nltk and ELKI, i.e. computing the centroids as the mean of all the points in the clusters. This PR introduces the approach used in spherical KMeans. This behavior has the nice property of minimizing the within-cluster cosine distance.

## How was this patch tested?

existing/improved UTs

Author: Marco Gaido <marcogaido91@gmail.com>
Closes apache#20518 from mgaido91/SPARK-22119_followup.
What changes were proposed in this pull request?
Currently, KMeans assumes that the Euclidean distance is the only distance measure that can be used. This PR adds cosine distance support to the KMeans algorithm.
How was this patch tested?
existing and added UTs.