From 1a8fb4127b9433945e75beea16fc2d485a249219 Mon Sep 17 00:00:00 2001
From: Liquan Pei
Date: Sun, 3 Aug 2014 16:24:35 -0700
Subject: [PATCH] use weighted sum in combOp

---
 .../org/apache/spark/mllib/feature/Word2Vec.scala | 14 +++++++++++---
 1 file changed, 11 insertions(+), 3 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 3ace0800fb9f8..66429f5af1a46 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -87,7 +87,7 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words:RDD[String]) {
+  private def learnVocab(words:RDD[String]){
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
@@ -110,6 +110,10 @@ class Word2Vec(
     logInfo("trainWordsCount = " + trainWordsCount)
   }
 
+  private def learnVocabPerPartition(words:RDD[String]) {
+
+  }
+
   private def createExpTable(): Array[Double] = {
     val expTable = new Array[Double](EXP_TABLE_SIZE)
     var i = 0
@@ -303,8 +307,12 @@ class Word2Vec(
       combOp = (c1, c2) => (c1, c2) match {
         case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
           val n = syn0_1.length
-          blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
-          blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
+          val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
+          val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
+          blas.dscal(n, weight1, syn0_1, 1)
+          blas.dscal(n, weight1, syn1_1, 1)
+          blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+          blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
           (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
       })
     syn0Global = aggSyn0
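
Note (not part of the patch): the combOp above now forms a word-count-weighted
average of the two partial weight arrays instead of a plain element-wise sum.
Below is a minimal standalone sketch of that combine rule, using plain Scala
arrays in place of blas.dscal/blas.daxpy; the object and method names here are
illustrative and do not appear in Word2Vec.scala.

object WeightedCombineSketch {
  // result(i) = w1 * part1(i) + w2 * part2(i), where w1 and w2 are the
  // fractions of words processed by each partition -- the same effect as the
  // dscal-then-daxpy sequence in the patch.
  def combine(part1: Array[Double], wc1: Long,
              part2: Array[Double], wc2: Long): Array[Double] = {
    val w1 = wc1.toDouble / (wc1 + wc2)
    val w2 = wc2.toDouble / (wc1 + wc2)
    part1.indices.map(i => w1 * part1(i) + w2 * part2(i)).toArray
  }

  def main(args: Array[String]): Unit = {
    val a = Array(1.0, 2.0, 3.0)  // partial result from a partition that saw 100 words
    val b = Array(3.0, 2.0, 1.0)  // partial result from a partition that saw 300 words
    // Prints 2.5, 2.0, 1.5 -- the combined vector leans toward b, which saw more words.
    println(combine(a, 100L, b, 300L).mkString(", "))
  }
}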