/**
 * Initialize random centers, requiring only the number of dimensions.
 *
 * Every coordinate of every center is drawn from a standard Gaussian using a
 * generator that is private to this call, so the same `seed` always produces
 * the same centers and the shared `Utils.random` instance is left untouched.
 *
 * @param dim Number of dimensions of each center
 * @param seed Random seed used to generate the centers; when omitted, one is
 *             drawn from `Utils.random`
 * @return this instance, with its model replaced by `k` random centers whose
 *         cluster counts are all zero
 */
def setRandomCenters(dim: Int, seed: Long = Utils.random.nextLong): this.type = {
  // Use a dedicated generator rather than reseeding the shared Utils.random:
  // calling setSeed on the shared instance would change the numbers seen by
  // every other user of it, and concurrent draws from it would make the
  // centers irreproducible even for a fixed seed.
  val random = new java.util.Random(seed)

  // Centers are generated one after another, coordinate by coordinate.
  val initialCenters = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))

  // A freshly allocated Array[Long] is zero-filled: no points assigned yet.
  val clusterCounts = new Array[Long](k)

  this.model = new StreamingKMeansModel(initialCenters, clusterCounts)
  this
}