[SPARK-22119][ML] Add cosine distance to KMeans
## What changes were proposed in this pull request?

Currently, KMeans assumes the only possible distance measure to be used is the Euclidean. This PR aims to add the cosine distance support to the KMeans algorithm.
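
For illustration only, a minimal sketch of how the new option is meant to be used from the DataFrame-based API once this patch is in (the toy dataset, the "features" column name, and the `spark` session are assumptions, not part of the change):

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors

// Toy data: one vector per row in a "features" column; `spark` is an existing SparkSession.
val df = spark.createDataFrame(Seq(
  Vectors.dense(1.0, 1.0),
  Vectors.dense(9.0, 8.0),
  Vectors.dense(2.0, 1.5),
  Vectors.dense(8.0, 9.5)
).map(Tuple1.apply)).toDF("features")

// Select the cosine distance measure instead of the default euclidean one.
val kmeans = new KMeans()
  .setK(2)
  .setSeed(1L)
  .setDistanceMeasure("cosine")

val model = kmeans.fit(df)
model.clusterCenters.foreach(println)
```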

## How was this patch tested?

Existing and added unit tests.

Author: Marco Gaido <marcogaido91@gmail.com>
Author: Marco Gaido <mgaido@hortonworks.com>

Closes #19340 from mgaido91/SPARK-22119.
mgaido91 authored and srowen committed Jan 21, 2018
1 parent 121dc96 commit 4f43d27
Showing 7 changed files with 315 additions and 66 deletions.
22 changes: 19 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.clustering.{DistanceMeasure, KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
@@ -71,6 +71,15 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   @Since("1.5.0")
   def getInitMode: String = $(initMode)
 
+  @Since("2.4.0")
+  final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " +
+    "Supported options: 'euclidean' and 'cosine'.",
+    (value: String) => MLlibKMeans.validateDistanceMeasure(value))
+
+  /** @group expertGetParam */
+  @Since("2.4.0")
+  def getDistanceMeasure: String = $(distanceMeasure)
+
   /**
    * Param for the number of steps for the k-means|| initialization mode. This is an advanced
    * setting -- the default of 2 is almost always enough. Must be &gt; 0. Default: 2.
@@ -260,7 +269,8 @@ class KMeans @Since("1.5.0") (
     maxIter -> 20,
     initMode -> MLlibKMeans.K_MEANS_PARALLEL,
     initSteps -> 2,
-    tol -> 1e-4)
+    tol -> 1e-4,
+    distanceMeasure -> DistanceMeasure.EUCLIDEAN)
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeans = defaultCopy(extra)
@@ -284,6 +294,10 @@ class KMeans @Since("1.5.0") (
   @Since("1.5.0")
   def setInitMode(value: String): this.type = set(initMode, value)
 
+  /** @group expertSetParam */
+  @Since("2.4.0")
+  def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
   /** @group expertSetParam */
   @Since("1.5.0")
   def setInitSteps(value: Int): this.type = set(initSteps, value)
@@ -314,14 +328,16 @@ class KMeans @Since("1.5.0") (
     }
 
     val instr = Instrumentation.create(this, instances)
-    instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
+    instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, distanceMeasure,
+      maxIter, seed, tol)
     val algo = new MLlibKMeans()
       .setK($(k))
       .setInitializationMode($(initMode))
       .setInitializationSteps($(initSteps))
       .setMaxIterations($(maxIter))
       .setSeed($(seed))
       .setEpsilon($(tol))
+      .setDistanceMeasure($(distanceMeasure))
     val parentModel = algo.run(instances, Option(instr))
     val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
     val summary = new KMeansSummary(
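
The ml.clustering.KMeans estimator above only validates and forwards the new `distanceMeasure` param; the actual distance computations live in the mllib layer (the `DistanceMeasure` classes among the changed files, not shown in this excerpt). Conceptually, the cosine distance between two points is one minus their cosine similarity. A rough, self-contained sketch, illustrative only and not the patch's implementation:

```scala
import org.apache.spark.ml.linalg.Vector

// Illustrative only: cosine distance = 1 - (a . b) / (||a|| * ||b||).
// The patch's real implementation lives in org.apache.spark.mllib.clustering.
def cosineDistance(a: Vector, b: Vector): Double = {
  require(a.size == b.size, "Vectors must have the same dimension")
  val dot = (0 until a.size).map(i => a(i) * b(i)).sum
  val normA = math.sqrt((0 until a.size).map(i => a(i) * a(i)).sum)
  val normB = math.sqrt((0 until b.size).map(i => b(i) * b(i)).sum)
  require(normA > 0.0 && normB > 0.0, "Cosine distance is undefined for zero vectors")
  1.0 - dot / (normA * normB)
}
```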
mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -350,7 +350,7 @@ private object BisectingKMeans extends Serializable {
         val newClusterChildren = children.filter(newClusterCenters.contains(_))
         if (newClusterChildren.nonEmpty) {
           val selected = newClusterChildren.minBy { child =>
-            KMeans.fastSquaredDistance(newClusterCenters(child), v)
+            EuclideanDistanceMeasure.fastSquaredDistance(newClusterCenters(child), v)
           }
           (selected, v)
         } else {
@@ -387,7 +387,7 @@
       val rightIndex = rightChildIndex(rawIndex)
       val indexes = Seq(leftIndex, rightIndex).filter(clusters.contains(_))
       val height = math.sqrt(indexes.map { childIndex =>
-        KMeans.fastSquaredDistance(center, clusters(childIndex).center)
+        EuclideanDistanceMeasure.fastSquaredDistance(center, clusters(childIndex).center)
       }.max)
       val children = indexes.map(buildSubTree(_)).toArray
       new ClusteringTreeNode(index, size, center, cost, height, children)
@@ -457,7 +457,7 @@ private[clustering] class ClusteringTreeNode private[clustering] (
       this :: Nil
     } else {
       val selected = children.minBy { child =>
-        KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)
+        EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm)
       }
       selected :: selected.predictPath(pointWithNorm)
     }
@@ -475,7 +475,8 @@ private[clustering] class ClusteringTreeNode private[clustering] (
    * Predicts the cluster index and the cost of the input point.
    */
   private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = {
-    predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm))
+    predict(pointWithNorm,
+      EuclideanDistanceMeasure.fastSquaredDistance(centerWithNorm, pointWithNorm))
   }
 
   /**
@@ -490,7 +491,7 @@ private[clustering] class ClusteringTreeNode private[clustering] (
       (index, cost)
     } else {
       val (selectedChild, minCost) = children.map { child =>
-        (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm))
+        (child, EuclideanDistanceMeasure.fastSquaredDistance(child.centerWithNorm, pointWithNorm))
       }.minBy(_._2)
       selectedChild.predict(pointWithNorm, minCost)
     }
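
These BisectingKMeans call sites change only because the Euclidean-specific helper moves from the `KMeans` object into an `EuclideanDistanceMeasure` object, leaving the core algorithm free to dispatch on the selected distance measure; the behaviour is unchanged. The "fast" squared distance rests on the usual expansion of the squared norm, so with norms precomputed per point (as `VectorWithNorm` does) only a dot product is needed per pair. A sketch of that identity with illustrative names (the real helper additionally falls back to an exact computation when numerical precision could suffer):

```scala
// Sketch of the identity behind a fast squared Euclidean distance:
//   ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * (a . b)
// Names are illustrative; Spark's helper operates on VectorWithNorm and guards precision.
def fastSquaredDistanceSketch(
    a: Array[Double], normA: Double,
    b: Array[Double], normB: Double): Double = {
  val dot = a.zip(b).map { case (x, y) => x * y }.sum
  normA * normA + normB * normB - 2.0 * dot
}
```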