Skip to content

Commit

Permalink
[MLLIB] minor update to word2vec
Browse files Browse the repository at this point in the history
very minor update Ishiihara

Author: Xiangrui Meng <meng@databricks.com>

Closes #2043 from mengxr/minor-w2v and squashes the following commits:

be649fd [Xiangrui Meng] remove map because we only need append
eccefcc [Xiangrui Meng] minor updates to word2vec
  • Loading branch information
mengxr committed Aug 20, 2014
1 parent 8b9dc99 commit 1870dba
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@ import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
import org.apache.spark.util.collection.PrimitiveKeyOpenHashMap

/**
* Entry in vocabulary
Expand Down Expand Up @@ -285,9 +283,9 @@ class Word2Vec extends Serializable with Logging {

val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
var syn0Global =
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
var syn1Global = new Array[Float](vocabSize * vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)
var alpha = startingAlpha
for (k <- 1 to numIterations) {
val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
Expand Down Expand Up @@ -349,21 +347,21 @@ class Word2Vec extends Serializable with Logging {
}
val syn0Local = model._1
val syn1Local = model._2
val synOut = new PrimitiveKeyOpenHashMap[Int, Array[Float]](vocabSize * 2)
val synOut = mutable.ListBuffer.empty[(Int, Array[Float])]
var index = 0
while(index < vocabSize) {
if (syn0Modify(index) != 0) {
synOut.update(index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize))
synOut += ((index, syn0Local.slice(index * vectorSize, (index + 1) * vectorSize)))
}
if (syn1Modify(index) != 0) {
synOut.update(index + vocabSize,
syn1Local.slice(index * vectorSize, (index + 1) * vectorSize))
synOut += ((index + vocabSize,
syn1Local.slice(index * vectorSize, (index + 1) * vectorSize)))
}
index += 1
}
Iterator(synOut)
synOut.toIterator
}
val synAgg = partial.flatMap(x => x).reduceByKey { case (v1, v2) =>
val synAgg = partial.reduceByKey { case (v1, v2) =>
blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
v1
}.collect()
Expand Down

0 comments on commit 1870dba

Please sign in to comment.