Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-22119][ML] Add cosine distance to KMeans #19340

Closed
wants to merge 16 commits into from
22 changes: 19 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -71,6 +71,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
@Since("1.5.0")
def getInitMode: String = $(initMode)

@Since("2.4.0")
final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " +
"Supported options: 'euclidean' and 'cosine'.",
(value: String) => MLlibKMeans.validateDistanceMeasure(value))

/** @group expertGetParam */
@Since("2.4.0")
def getDistanceMeasure: String = $(distanceMeasure)

/**
* Param for the number of steps for the k-means|| initialization mode. This is an advanced
* setting -- the default of 2 is almost always enough. Must be > 0. Default: 2.
Expand Down Expand Up @@ -260,7 +269,8 @@ class KMeans @Since("1.5.0") (
maxIter -> 20,
initMode -> MLlibKMeans.K_MEANS_PARALLEL,
initSteps -> 2,
tol -> 1e-4)
tol -> 1e-4,
distanceMeasure -> DistanceMeasure.EUCLIDEAN)

@Since("1.5.0")
override def copy(extra: ParamMap): KMeans = defaultCopy(extra)
Expand All @@ -284,6 +294,10 @@ class KMeans @Since("1.5.0") (
@Since("1.5.0")
def setInitMode(value: String): this.type = set(initMode, value)

/** @group expertSetParam */
@Since("2.4.0")
def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)

/** @group expertSetParam */
@Since("1.5.0")
def setInitSteps(value: Int): this.type = set(initSteps, value)
Expand Down Expand Up @@ -314,14 +328,16 @@ class KMeans @Since("1.5.0") (
}

val instr = Instrumentation.create(this, instances)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
maxIter, seed, tol)
val algo = new MLlibKMeans()
.setK($(k))
.setInitializationMode($(initMode))
.setInitializationSteps($(initSteps))
.setMaxIterations($(maxIter))
.setSeed($(seed))
.setEpsilon($(tol))
.setDistanceMeasure($(distanceMeasure))
val parentModel = algo.run(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ private object BisectingKMeans extends Serializable {
val newClusterChildren = children.filter(newClusterCenters.contains(_))
if (newClusterChildren.nonEmpty) {
val selected = newClusterChildren.minBy { child =>
KMeans.fastSquaredDistance(newClusterCenters(child), v)
EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v)
}
(selected, v)
} else {
Expand Down Expand Up @@ -379,7 +379,7 @@ private object BisectingKMeans extends Serializable {
val rightIndex = rightChildIndex(rawIndex)
val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
val height = math.sqrt(indexes.map { childIndex =>
KMeans.fastSquaredDistance(center, clusters(childIndex).center)
EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center)
}.max)
val children = indexes.map(buildSubTree(_)).toArray
new ClusteringTreeNode(index, size, center, cost, height, children)
Expand Down Expand Up @@ -449,7 +449,7 @@ private[clustering] class ClusteringTreeNode private[clustering] (
this :: Nil
} else {
val selected = children.minBy { child =>
KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)
EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)
}
selected :: selected.predictPath(pointWithNorm)
}
Expand All @@ -467,7 +467,8 @@ private[clustering] class ClusteringTreeNode private[clustering] (
* Predicts the cluster index and the cost of the input point.
*/
private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = {
predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm))
predict(pointWithNorm,
EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm))
}

/**
Expand All @@ -482,7 +483,7 @@ private[clustering] class ClusteringTreeNode private[clustering] (
(index, cost)
} else {
val (selectedChild, minCost) = children.map { child =>
(child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm))
(child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm))
}.minBy(_._2)
selectedChild.predict(pointWithNorm, minCost)
}
Expand Down
Loading