forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Liquan Pei
committed
Aug 1, 2014
1 parent
c475540
commit 8d6befe
Showing
2 changed files
with
393 additions
and
0 deletions.
There are no files selected for viewing
353 changes: 353 additions & 0 deletions
353
mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,353 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.feature | ||
|
||
import scala.util._ | ||
import scala.collection.mutable.ArrayBuffer | ||
import scala.collection.mutable.HashMap | ||
import scala.collection.mutable | ||
|
||
import com.github.fommil.netlib.BLAS.{getInstance => blas} | ||
|
||
import org.apache.spark._ | ||
import org.apache.spark.rdd._ | ||
import org.apache.spark.SparkContext._ | ||
import org.apache.spark.mllib.linalg.Vector | ||
import org.apache.spark.HashPartitioner | ||
|
||
/**
 * One vocabulary entry together with its position in the binary tree built by
 * `Word2Vec.createBinaryTree`.
 *
 * @param word    the token text
 * @param cn      number of occurrences of the token in the training data
 * @param point   inner-node indices on the path from the root to this word
 * @param code    binary branch decisions along that path
 * @param codeLen number of valid entries in `code` (and of nodes visited in `point`)
 */
private case class VocabWord(
    var word: String,
    var cn: Int,
    var point: Array[Int],
    var code: Array[Int],
    var codeLen: Int)
|
||
/**
 * Learns a vector representation for every sufficiently frequent word of a corpus,
 * training a skip-gram model with hierarchical softmax.
 *
 * @param size          dimensionality of each word vector; must be positive
 * @param startingAlpha initial learning rate; it is decayed as training progresses
 * @param window        maximum distance between a word and its context words; the
 *                      effective window is randomly shrunk for every position.
 *                      Must be positive.
 * @param minCount      words occurring fewer times than this are dropped from the vocabulary
 */
class Word2Vec(
    val size: Int,
    val startingAlpha: Double,
    val window: Int,
    val minCount: Int)
  extends Serializable with Logging {

  // Fail fast on the driver instead of deep inside a task (Random.nextInt(window)
  // rejects non-positive bounds, and a non-positive size yields no usable vectors).
  require(size > 0, "size must be positive but got " + size)
  require(window > 0, "window must be positive but got " + window)

  private val EXP_TABLE_SIZE = 1000
  private val MAX_EXP = 6
  private val MAX_CODE_LENGTH = 40
  private val MAX_SENTENCE_LENGTH = 1000
  private val layer1Size = size

  private var trainWordsCount = 0
  private var vocabSize = 0
  private var vocab: Array[VocabWord] = null
  private var vocabHash = mutable.HashMap.empty[String, Int]
  private var alpha = startingAlpha

  /**
   * Builds the vocabulary from space-delimited tokens: counts every token, keeps those
   * seen at least `minCount` times, and orders them by descending frequency. Also
   * rebuilds the word -> index map and the total number of training words.
   */
  private def learnVocab(dataset: RDD[String]): Unit = {
    vocab = dataset.flatMap(line => line.split(" "))
      .map(w => (w, 1))
      .reduceByKey(_ + _)
      // Filter on the raw counts first so no code/point arrays are allocated for rare words.
      .filter(_._2 >= minCount)
      .map { case (word, cnt) =>
        VocabWord(word, cnt, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0)
      }
      .collect()
      .sortWith((a, b) => a.cn > b.cn)

    vocabSize = vocab.length
    // Start from a clean slate so that calling fit() more than once does not
    // accumulate indices and counts left over from a previous run.
    vocabHash = mutable.HashMap.empty[String, Int]
    trainWordsCount = 0
    var a = 0
    while (a < vocabSize) {
      vocabHash += vocab(a).word -> a
      trainWordsCount += vocab(a).cn
      a += 1
    }
    logInfo("trainWordsCount = " + trainWordsCount)
  }

  /**
   * Precomputes the logistic function exp(x) / (exp(x) + 1) at `EXP_TABLE_SIZE`
   * evenly spaced points covering [-MAX_EXP, MAX_EXP).
   */
  private def createExpTable(): Array[Double] = {
    val expTable = new Array[Double](EXP_TABLE_SIZE)
    var i = 0
    while (i < EXP_TABLE_SIZE) {
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      expTable(i) = tmp / (tmp + 1)
      i += 1
    }
    expTable
  }

  /**
   * Builds a binary tree over the vocabulary by repeatedly merging the two least
   * frequent nodes, then stores in every [[VocabWord]] its branch code and the inner
   * nodes on its path. Relies on `vocab` being sorted by descending count.
   */
  private def createBinaryTree(): Unit = {
    val count = new Array[Long](vocabSize * 2 + 1)
    val binary = new Array[Int](vocabSize * 2 + 1)
    val parentNode = new Array[Int](vocabSize * 2 + 1)
    val code = new Array[Int](MAX_CODE_LENGTH)
    val point = new Array[Int](MAX_CODE_LENGTH)
    var a = 0
    while (a < vocabSize) {
      count(a) = vocab(a).cn
      a += 1
    }
    // Inner nodes not created yet get a huge count so they are never picked prematurely.
    while (a < 2 * vocabSize) {
      count(a) = 1e9.toInt
      a += 1
    }
    // pos1 walks the leaves from the rarest word towards the most frequent one,
    // pos2 walks the inner nodes in creation order.
    var pos1 = vocabSize - 1
    var pos2 = vocabSize

    // Removes and returns the index of the smallest remaining node.
    def takeSmallest(): Int = {
      if (pos1 >= 0 && count(pos1) < count(pos2)) {
        pos1 -= 1
        pos1 + 1
      } else {
        pos2 += 1
        pos2 - 1
      }
    }

    a = 0
    while (a < vocabSize - 1) {
      val min1i = takeSmallest()
      val min2i = takeSmallest()
      count(vocabSize + a) = count(min1i) + count(min2i)
      parentNode(min1i) = vocabSize + a
      parentNode(min2i) = vocabSize + a
      binary(min2i) = 1
      a += 1
    }
    // Now assign binary code to each vocabulary word
    var i = 0
    a = 0
    while (a < vocabSize) {
      var b = a
      i = 0
      // Climb from the leaf up to the root (node 2 * vocabSize - 2), recording the path.
      while (b != vocabSize * 2 - 2) {
        code(i) = binary(b)
        point(i) = b
        i += 1
        b = parentNode(b)
      }
      vocab(a).codeLen = i
      vocab(a).point(0) = vocabSize - 2
      // The path was collected leaf -> root; store it root -> leaf.
      b = 0
      while (b < i) {
        vocab(a).code(i - b - 1) = code(b)
        vocab(a).point(i - b) = point(b) - vocabSize
        b += 1
      }
      a += 1
    }
  }

  /**
   * Computes the vector representation of each word in
   * vocabulary
   * @param dataset an RDD of strings
   * @return a model mapping every vocabulary word to its learned vector
   */
  def fit(dataset: RDD[String]): Word2VecModel = {

    // Restart the learning-rate schedule for this run.
    alpha = startingAlpha

    learnVocab(dataset)

    createBinaryTree()

    val sc = dataset.context

    val expTable = sc.broadcast(createExpTable())
    val vocabBc = sc.broadcast(vocab)
    val vocabHashBc = sc.broadcast(vocabHash)

    // Turn the token stream into chunks of at most MAX_SENTENCE_LENGTH vocabulary
    // indices; out-of-vocabulary tokens are skipped.
    val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions { iter =>
      new Iterator[Array[Int]] {
        def hasNext: Boolean = iter.hasNext

        def next(): Array[Int] = {
          val sentence = new ArrayBuffer[Int]
          while (iter.hasNext && sentence.length < MAX_SENTENCE_LENGTH) {
            vocabHashBc.value.get(iter.next()).foreach(index => sentence += index)
          }
          sentence.toArray
        }
      }
    }

    val newSentences = sentences.repartition(1).cache()
    val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
    // NOTE(review): the zero value already carries the initial weights, and the merge
    // step below adds it to the partition result once more - confirm this is intended.
    val (aggSyn0, _, _, _) =
      // TODO: broadcast temp instead of serializing it directly or initialize the model in each executor
      newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
        seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
          var lwc = lastWordCount
          var wc = wordCount
          // Decay the learning rate roughly every 10000 processed words.
          if (wordCount - lastWordCount > 10000) {
            lwc = wordCount
            alpha = startingAlpha * (1 - wordCount.toDouble / (trainWordsCount + 1))
            if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
            logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
          }
          wc += sentence.size
          var pos = 0
          while (pos < sentence.size) {
            val word = sentence(pos)
            // Read the tree path from the broadcast copy rather than from this.vocab.
            val target = vocabBc.value(word)
            // TODO: fix random seed
            val b = Random.nextInt(window)
            // Train Skip-gram
            var a = b
            while (a < window * 2 + 1 - b) {
              if (a != window) {
                val c = pos - window + a
                if (c >= 0 && c < sentence.size) {
                  val lastWord = sentence(c)
                  val l1 = lastWord * layer1Size
                  val neu1e = new Array[Double](layer1Size)
                  // HS
                  var d = 0
                  while (d < target.codeLen) {
                    val l2 = target.point(d) * layer1Size
                    // Propagate hidden -> output
                    var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                    if (f > -MAX_EXP && f < MAX_EXP) {
                      val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                      f = expTable.value(ind)
                      val g = (1 - target.code(d) - f) * alpha
                      // Accumulate the error for the input vector, then update the output weights.
                      blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
                      blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                    }
                    d += 1
                  }
                  // Apply the accumulated error to the input vector.
                  blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
                }
              }
              a += 1
            }
            pos += 1
          }
          (syn0, syn1, lwc, wc)
        },
        combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
          val n = syn0_1.length
          blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
          blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
          // Return the merged output weights (syn1_1), not the other side's input weights.
          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
        })

    // Slice the flat weight array into one vector per vocabulary word.
    val wordMap = Array.tabulate[(String, Array[Double])](vocabSize) { i =>
      (vocab(i).word, aggSyn0.slice(i * layer1Size, (i + 1) * layer1Size))
    }
    val modelRDD = sc.parallelize(wordMap, 100).partitionBy(new HashPartitioner(100))
    new Word2VecModel(modelRDD)
  }
}
|
||
/**
 * Word vectors produced by [[Word2Vec]], with lookup and nearest-neighbour queries.
 *
 * @param _model pairs of (word, vector)
 */
class Word2VecModel (val _model: RDD[(String, Array[Double])]) extends Serializable {

  val model = _model

  /**
   * Cosine similarity of two equally sized vectors (larger means more similar).
   * Returns 0.0 when either vector has zero norm.
   */
  private def distance(v1: Array[Double], v2: Array[Double]): Double = {
    require(v1.length == v2.length, "Vectors should have the same length")
    val n = v1.length
    val norm1 = blas.dnrm2(n, v1, 1)
    val norm2 = blas.dnrm2(n, v2, 1)
    if (norm1 == 0 || norm2 == 0) 0.0
    else blas.ddot(n, v1, 1, v2, 1) / norm1 / norm2
  }

  /**
   * Looks up the vector of a single word.
   *
   * @param word the word to look up
   * @return the word's vector, or an empty array if the word is not in the model
   */
  def transform(word: String): Array[Double] = {
    val result = model.lookup(word)
    if (result.isEmpty) Array[Double]()
    else result(0)
  }

  /**
   * Looks up the vector of every word in `dataset`.
   *
   * The model is collected and broadcast because an RDD action such as `lookup`
   * cannot be invoked from inside another RDD's transformation.
   *
   * @param dataset an RDD of words
   * @return one vector per input word, in the same order; unknown words map to an
   *         empty array
   */
  def transform(dataset: RDD[String]): RDD[Array[Double]] = {
    val vectors = dataset.context.broadcast(model.collect().toMap)
    dataset.map(word => vectors.value.getOrElse(word, Array[Double]()))
  }

  /**
   * Finds the words most similar to `word`.
   *
   * @param word a word
   * @param num  number of synonyms to return; must be positive
   * @return up to `num` (word, cosine similarity) pairs, most similar first, or an
   *         empty array if `word` is not in the model
   */
  def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
    val vector = transform(word)
    if (vector.isEmpty) Array[(String, Double)]()
    else findSynonyms(vector, num)
  }

  /**
   * Finds the words whose vectors are most similar to `vector`.
   *
   * The single best match is always discarded, on the assumption that it is the
   * query word itself.
   *
   * @param vector the query vector; must have the model's dimensionality
   * @param num    number of synonyms to return; must be positive
   * @return up to `num` (word, cosine similarity) pairs, most similar first
   */
  def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should > 0")
    model
      .map { case (w, vec) => (distance(vector, vec), w) }
      .sortByKey(ascending = false)
      .take(num + 1)
      .drop(1)
      .map { case (dist, w) => (w, dist) }
  }
}
|
||
/** Entry points for training a [[Word2VecModel]]. */
object Word2Vec extends Serializable with Logging {

  /**
   * Trains a model on `input` with the given hyper-parameters.
   *
   * @param input         an RDD of space-delimited text
   * @param size          dimensionality of the word vectors
   * @param startingAlpha initial learning rate
   * @param window        maximum context window
   * @param minCount      minimum number of occurrences for a word to be kept
   * @return the trained model
   */
  def train(
      input: RDD[String],
      size: Int,
      startingAlpha: Double,
      window: Int,
      minCount: Int): Word2VecModel = {
    new Word2Vec(size, startingAlpha, window, minCount).fit(input)
  }

  /**
   * Command-line driver: trains a model on a text file and logs the `num` nearest
   * neighbours of the word "china".
   *
   * Arguments: input size startingAlpha window minCount num
   */
  def main(args: Array[String]): Unit = {
    if (args.length < 6) {
      println("Usage: word2vec input size startingAlpha window minCount num")
      sys.exit(1)
    }
    val conf = new SparkConf()
      .setAppName("word2vec")

    val sc = new SparkContext(conf)
    // Always release the context, even when parsing, training or the query fails.
    try {
      val input = sc.textFile(args(0))
      val size = args(1).toInt
      val startingAlpha = args(2).toDouble
      val window = args(3).toInt
      val minCount = args(4).toInt
      val num = args(5).toInt
      val model = train(input, size, startingAlpha, window, minCount)
      val vec = model.findSynonyms("china", num)
      for ((w, dist) <- vec) logInfo(w.toString + " " + dist.toString)
    } finally {
      sc.stop()
    }
  }
}
40 changes: 40 additions & 0 deletions
40
mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.mllib.feature | ||
|
||
import org.scalatest.FunSuite | ||
import org.apache.spark.SparkContext._ | ||
import org.apache.spark.mllib.util.LocalSparkContext | ||
|
||
/** Checks nearest-neighbour queries of [[Word2VecModel]] on a tiny hand-made model. */
class Word2VecSuite extends FunSuite with LocalSparkContext {
  test("word2vec") {
    val num = 2
    // Four 4-dimensional vectors; "taiwan" and "japan" are the closest to "china"
    // by cosine similarity, in that order.
    val wordVectors = Seq(
      ("china", Array(0.50, 0.50, 0.50, 0.50)),
      ("japan", Array(0.40, 0.50, 0.50, 0.50)),
      ("taiwan", Array(0.60, 0.50, 0.50, 0.50)),
      ("korea", Array(0.45, 0.60, 0.60, 0.60)))
    val vectorsRDD = sc.parallelize(wordVectors, 2)
    val w2vModel = new Word2VecModel(vectorsRDD)

    val neighbours = w2vModel.findSynonyms("china", num)
    assert(neighbours.length == num)

    val names = neighbours.map(_._1)
    assert(names(0) == "taiwan")
    assert(names(1) == "japan")
  }
}