diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index c3375ed44fd99..fc1444705364a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -347,19 +347,20 @@ class Word2Vec extends Serializable with Logging { } val syn0Local = model._1 val syn1Local = model._2 - val synOut = mutable.ListBuffer.empty[(Int, Array[Float])] - var index = 0 - while(index < vocabSize) { - if (syn0Modify(index) != 0) { - synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))) + // Only output modified vectors. + Iterator.tabulate(vocabSize) { index => + if (syn0Modify(index) > 0) { + Some((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))) + } else { + None } - if (syn1Modify(index) != 0) { - synOut += ((index + vocabSize, - syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))) + }.flatten ++ Iterator.tabulate(vocabSize) { index => + if (syn1Modify(index) > 0) { + Some((index + vocabSize, syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))) + } else { + None } - index += 1 - } - synOut.toIterator + }.flatten } val synAgg = partial.reduceByKey { case (v1, v2) => blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)