Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPARKNLP-884 Enabling getVectors method #13957

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/sparknlp/annotator/embeddings/doc2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,9 @@ def pretrained(name="doc2vec_gigaword_300", lang="en", remote_loc=None):
from sparknlp.pretrained import ResourceDownloader
return ResourceDownloader.downloadModel(Doc2VecModel, name, lang, remote_loc)

def getVectors(self):
    """Returns the vector representation of the words as a dataframe
    with two fields, word and vector.

    Returns
    -------
    :class:`pyspark.sql.DataFrame`
        DataFrame holding one row per vocabulary entry, with columns
        ``word`` and ``vector``.
    """
    # Delegates to the Scala-side Doc2VecModel.getVectors via the JVM bridge.
    return self._call_java("getVectors")
6 changes: 6 additions & 0 deletions python/sparknlp/annotator/embeddings/word2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,9 @@ def pretrained(name="word2vec_gigaword_300", lang="en", remote_loc=None):
from sparknlp.pretrained import ResourceDownloader
return ResourceDownloader.downloadModel(Word2VecModel, name, lang, remote_loc)

def getVectors(self):
    """Returns the vector representation of the words as a dataframe
    with two fields, word and vector.

    Returns
    -------
    :class:`pyspark.sql.DataFrame`
        DataFrame holding one row per vocabulary entry, with columns
        ``word`` and ``vector``.
    """
    # Delegates to the Scala-side Word2VecModel.getVectors via the JVM bridge.
    return self._call_java("getVectors")
27 changes: 24 additions & 3 deletions src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ import com.johnsnowlabs.nlp._
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, FloatType, StringType, StructField, StructType}

/** Word2Vec model that creates vector representations of words in a text corpus.
*
Expand Down Expand Up @@ -166,6 +167,21 @@ class Doc2VecModel(override val uid: String)
/** @group setParam */
def setWordVectors(value: Map[String, Array[Float]]): this.type = set(wordVectors, value)

/** Spark session captured by [[beforeAnnotate]]; needed to materialize the vocabulary as a
  * DataFrame. Empty until the model has been run in a pipeline.
  */
private var sparkSession: Option[SparkSession] = None

/** Returns the vector representation of the words as a DataFrame with two fields, `word`
  * (StringType) and `vector` (ArrayType of FloatType).
  *
  * @throws UnsupportedOperationException
  *   if the model has not been run in a pipeline yet, since no Spark session is available
  */
def getVectors: DataFrame = {
  // Fail fast before touching the (potentially large) word-vector map.
  val spark = sparkSession.getOrElse {
    throw new UnsupportedOperationException(
      "Vector representation empty. Please run Doc2VecModel in some pipeline before accessing vector vocabulary.")
  }
  val schema = StructType(
    StructField("word", StringType, nullable = false) ::
      StructField("vector", ArrayType(FloatType), nullable = false) :: Nil)
  val rows = $$(wordVectors).toSeq.map { case (word, vector) => Row(word, vector) }
  spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
}

setDefault(inputCols -> Array(TOKEN), outputCol -> "doc2vec", vectorSize -> 100)

private def calculateSentenceEmbeddings(matrix: Seq[Array[Float]]): Array[Float] = {
Expand All @@ -180,6 +196,11 @@ class Doc2VecModel(override val uid: String)
res
}

override def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = {
  // Remember the active Spark session so getVectors can later build a DataFrame from it.
  val session = dataset.sparkSession
  sparkSession = Some(session)
  dataset
}

/** takes a document and annotations and produces new annotations of this annotator's annotation
* type
*
Expand All @@ -204,8 +225,8 @@ class Doc2VecModel(override val uid: String)
.filter(_.nonEmpty)

val oovVector = Array.fill($(vectorSize))(0.0f)
val vectors = tokens.map { tokne =>
$$(wordVectors).getOrElse(tokne, oovVector)
val vectors = tokens.map { token =>
$$(wordVectors).getOrElse(token, oovVector)
}

val sentEmbeddings = calculateSentenceEmbeddings(vectors)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ import com.johnsnowlabs.nlp._
import com.johnsnowlabs.storage.HasStorageRef
import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{ArrayType, FloatType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

/** Word2Vec model that creates vector representations of words in a text corpus.
*
Expand Down Expand Up @@ -167,8 +168,28 @@ class Word2VecModel(override val uid: String)
/** @group setParam */
def setWordVectors(value: Map[String, Array[Float]]): this.type = set(wordVectors, value)

/** Spark session captured by [[beforeAnnotate]]; needed to materialize the vocabulary as a
  * DataFrame. Empty until the model has been run in a pipeline.
  */
private var sparkSession: Option[SparkSession] = None

/** Returns the vector representation of the words as a DataFrame with two fields, `word`
  * (StringType) and `vector` (ArrayType of FloatType).
  *
  * @throws UnsupportedOperationException
  *   if the model has not been run in a pipeline yet, since no Spark session is available
  */
def getVectors: DataFrame = {
  // Fail fast before touching the (potentially large) word-vector map.
  val spark = sparkSession.getOrElse {
    throw new UnsupportedOperationException(
      "Vector representation empty. Please run Word2VecModel in some pipeline before accessing vector vocabulary.")
  }
  val schema = StructType(
    StructField("word", StringType, nullable = false) ::
      StructField("vector", ArrayType(FloatType), nullable = false) :: Nil)
  val rows = $$(wordVectors).toSeq.map { case (word, vector) => Row(word, vector) }
  spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
}

setDefault(inputCols -> Array(TOKEN), outputCol -> "word2vec", vectorSize -> 100)

override def beforeAnnotate(dataset: Dataset[_]): Dataset[_] = {
  // Remember the active Spark session so getVectors can later build a DataFrame from it.
  val session = dataset.sparkSession
  sparkSession = Some(session)
  dataset
}

/** takes a document and annotations and produces new annotations of this annotator's annotation
* type
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.annotator._
import com.johnsnowlabs.nlp.annotators.SparkSessionTest
import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.nlp.training.CoNLL
import com.johnsnowlabs.nlp.util.io.ResourceHelper
Expand All @@ -27,7 +28,7 @@ import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, Multiclas
import org.apache.spark.sql.functions.{explode, when}
import org.scalatest.flatspec.AnyFlatSpec

class Doc2VecTestSpec extends AnyFlatSpec {
class Doc2VecTestSpec extends AnyFlatSpec with SparkSessionTest {

"Doc2VecApproach" should "train, save, and load back the saved model" taggedAs FastTest in {

Expand All @@ -43,18 +44,6 @@ class Doc2VecTestSpec extends AnyFlatSpec {
" ",
" ").toDF("text")

val document = new DocumentAssembler()
.setInputCol("text")
.setOutputCol("document")

val setence = new SentenceDetector()
.setInputCols("document")
.setOutputCol("sentence")

val tokenizer = new Tokenizer()
.setInputCols(Array("sentence"))
.setOutputCol("token")

val stops = new StopWordsCleaner()
.setInputCols("token")
.setOutputCol("cleanedToken")
Expand All @@ -67,7 +56,7 @@ class Doc2VecTestSpec extends AnyFlatSpec {
.setStorageRef("my_awesome_doc2vec")
.setEnableCaching(true)

val pipeline = new Pipeline().setStages(Array(document, setence, tokenizer, stops, doc2Vec))
val pipeline = new Pipeline().setStages(Array(documentAssembler, sentenceDetector, tokenizerWithSentence, stops, doc2Vec))

val pipelineModel = pipeline.fit(ddd)
val pipelineDF = pipelineModel.transform(ddd)
Expand All @@ -87,7 +76,7 @@ class Doc2VecTestSpec extends AnyFlatSpec {
.setOutputCol("sentence_embeddings")

val loadedPipeline =
new Pipeline().setStages(Array(document, setence, tokenizer, loadedDoc2Vec))
new Pipeline().setStages(Array(documentAssembler, sentenceDetector, tokenizerWithSentence, loadedDoc2Vec))

loadedPipeline.fit(ddd).transform(ddd).select("sentence_embeddings").show()

Expand All @@ -105,10 +94,6 @@ class Doc2VecTestSpec extends AnyFlatSpec {
"carbon emissions have come down without impinging on our growth .\\u2009.\\u2009.",
"the ").toDF("text")

val document = new DocumentAssembler()
.setInputCol("text")
.setOutputCol("document")

val setence = new SentenceDetector()
.setInputCols("document")
.setOutputCol("sentence")
Expand All @@ -135,7 +120,7 @@ class Doc2VecTestSpec extends AnyFlatSpec {

val pipeline = new Pipeline().setStages(
Array(
document,
documentAssembler,
setence,
tokenizerDocument,
tokenizerSentence,
Expand Down Expand Up @@ -332,4 +317,38 @@ class Doc2VecTestSpec extends AnyFlatSpec {
println("Area under ROC = " + auROC)

}

it should "get word vectors as spark dataframe" taggedAs SlowTest in {

  import ResourceHelper.spark.implicits._

  // Model is declared first; it must pass through a fitted pipeline before
  // getVectors has a Spark session to build the DataFrame with.
  val embeddingsModel = Doc2VecModel
    .pretrained()
    .setInputCols("token")
    .setOutputCol("embeddings")

  val testDataset = Seq(
    "Rare Hendrix song draft sells for almost $17,000. This is my second sentenece! The third one here!")
    .toDF("text")

  val pipeline =
    new Pipeline().setStages(Array(documentAssembler, tokenizer, embeddingsModel))

  pipeline.fit(testDataset).transform(testDataset).show()

  // Vocabulary is now accessible as a (word, vector) DataFrame.
  embeddingsModel.getVectors.show()
}

it should "raise an error when trying to retrieve empty word vectors" taggedAs SlowTest in {
  // Renamed from the copy-pasted `word2Vec`: this spec exercises Doc2VecModel.
  // The model is never run through a pipeline, so no Spark session is captured
  // and getVectors must fail.
  val doc2Vec = Doc2VecModel
    .pretrained()
    .setInputCols("token")
    .setOutputCol("embeddings")

  intercept[UnsupportedOperationException] {
    doc2Vec.getVectors
  }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,37 @@ class Word2VecTestSpec extends AnyFlatSpec with SparkSessionTest {

}

it should "get word vectors as spark dataframe" taggedAs SlowTest in {

  import ResourceHelper.spark.implicits._

  // Model is declared first; it must pass through a fitted pipeline before
  // getVectors has a Spark session to build the DataFrame with.
  val embeddingsModel = Word2VecModel
    .pretrained()
    .setInputCols("token")
    .setOutputCol("embeddings")

  val testDataset = Seq(
    "Rare Hendrix song draft sells for almost $17,000. This is my second sentenece! The third one here!")
    .toDF("text")

  val pipeline =
    new Pipeline().setStages(Array(documentAssembler, tokenizer, embeddingsModel))

  pipeline.fit(testDataset).transform(testDataset).show()

  // Vocabulary is now accessible as a (word, vector) DataFrame.
  embeddingsModel.getVectors.show()
}

it should "raise an error when trying to retrieve empty word vectors" taggedAs SlowTest in {
  // The model is never run through a pipeline, so no Spark session is captured
  // and getVectors must fail with UnsupportedOperationException.
  val model = Word2VecModel
    .pretrained()
    .setInputCols("token")
    .setOutputCol("embeddings")

  intercept[UnsupportedOperationException] {
    model.getVectors
  }
}

}