Skip to content

Commit

Permalink
SPARKNLP-962: UAE Embeddings
Browse files: browse the repository at this point in the history
- Added default values
- Serialization tests
  • Loading branch information
DevinTDHa committed Mar 8, 2024
1 parent 84d4269 commit fa41371
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 13 deletions.
1 change: 1 addition & 0 deletions python/sparknlp/annotator/embeddings/uae_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.UAEEmbeddings", ja
batchSize=8,
maxSentenceLength=512,
caseSensitive=False,
poolingStrategy="cls"
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,12 @@ class UAEEmbeddings(override val uid: String)
this
}

setDefault(dimension -> 1024, batchSize -> 8, maxSentenceLength -> 512, caseSensitive -> false)
setDefault(
dimension -> 1024,
batchSize -> 8,
maxSentenceLength -> 512,
caseSensitive -> false,
poolingStrategy -> "cls")

def tokenize(sentences: Seq[Annotation]): Seq[WordpieceTokenizedSentence] = {
val basicTokenizer = new BasicTokenizer($(caseSensitive))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.{DocumentAssembler, EmbeddingsFinisher}
import com.johnsnowlabs.nlp.base.LightPipeline
import com.johnsnowlabs.nlp.{Annotation, DocumentAssembler, EmbeddingsFinisher}
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.tags.SlowTest
import org.apache.spark.ml.Pipeline
Expand All @@ -18,7 +19,8 @@ class UAEEmbeddingsTestSpec extends AnyFlatSpec {
.setInputCols("document")
.setOutputCol("embeddings")

val data = Seq("hello world", "hello moon").toDF("text")
val rawData: Seq[String] = Seq("hello world", "hello moon")
val data = rawData.toDF("text")
val embeddingsFinisher = new EmbeddingsFinisher()
.setInputCols("embeddings")
.setOutputCols("embeddings")
Expand All @@ -30,7 +32,9 @@ class UAEEmbeddingsTestSpec extends AnyFlatSpec {
* @param expected
* The expected embeddings
*/
private def assertEmbeddings(expected: Array[Array[Float]]): Unit = {
private def assertEmbeddings(
expected: Array[Array[Float]],
pipeline: Pipeline = pipeline): Unit = {
val result = pipeline.fit(data).transform(data)

result.selectExpr("explode(embeddings)").show(5, 80)
Expand All @@ -47,7 +51,6 @@ class UAEEmbeddingsTestSpec extends AnyFlatSpec {
}

it should "work with default (cls) pooling" taggedAs SlowTest in {
model.setPoolingStrategy("cls")
val expected: Array[Array[Float]] = Array(
Array(0.50387836f, 0.5861595f, 0.35129607f, -0.76046336f, -0.32446113f, -0.11767582f,
0.49193293f, 0.58396333f, 0.8440052f, 0.3409165f, 0.02228897f, 0.3270517f, -0.3040624f,
Expand All @@ -58,16 +61,17 @@ class UAEEmbeddingsTestSpec extends AnyFlatSpec {
assertEmbeddings(expected)
}

// Reference embedding values for the "cls_avg" pooling strategy, shared by several tests.
// One inner array per input row, holding 16 values each; callers compare them against
// the first 16 components of each produced embedding.
val expected_cls_avg = Array(
Array(0.42190665f, 0.48439154f, 0.37961221f, -0.88345671f, -0.39864743f, -0.10434269f,
0.47246569f, 0.57266355f, 0.90948695f, 0.34240869f, -0.05249403f, 0.20690459f, -0.2502915f,
-0.075280815f, -0.72355306f, 0.37840521f),
Array(0.61534011f, 0.86877286f, 0.30440071f, -0.11193186f, -0.64877027f, 0.03778841f,
0.19575913f, 0.77637982f, 1.0544734f, 0.02276843f, 0.40709749f, 0.48178568f, -0.45722729f,
-0.25922f, -0.75728685f, -0.2886759f))

// Switches the model to "cls_avg" pooling and checks the pipeline output through assertEmbeddings.
it should "work with cls_avg pooling" taggedAs SlowTest in {
model.setPoolingStrategy("cls_avg")
// NOTE(review): this local `expected` holds the same values as the shared expected_cls_avg
// fixture; it looks like the removed side of the diff hunk — confirm before keeping both.
val expected = Array(
Array(0.42190665f, 0.48439154f, 0.37961221f, -0.88345671f, -0.39864743f, -0.10434269f,
0.47246569f, 0.57266355f, 0.90948695f, 0.34240869f, -0.05249403f, 0.20690459f,
-0.2502915f, -0.075280815f, -0.72355306f, 0.37840521f),
Array(0.61534011f, 0.86877286f, 0.30440071f, -0.11193186f, -0.64877027f, 0.03778841f,
0.19575913f, 0.77637982f, 1.0544734f, 0.02276843f, 0.40709749f, 0.48178568f, -0.45722729f,
-0.25922f, -0.75728685f, -0.2886759f))
assertEmbeddings(expected)
assertEmbeddings(expected_cls_avg)
}

it should "work with last pooling" taggedAs SlowTest in {
Expand Down Expand Up @@ -116,4 +120,38 @@ class UAEEmbeddingsTestSpec extends AnyFlatSpec {
assertEmbeddings(expected)
}

// Fits the pipeline, wraps it in a LightPipeline, annotates the raw strings directly and
// compares the first 16 components of every embedding with the cls_avg reference values.
it should "be compatible with LightPipeline" taggedAs SlowTest in {
  model.setPoolingStrategy("cls_avg")
  val lightPipeline = new LightPipeline(pipeline.fit(data))
  val annotated = lightPipeline.fullAnnotate(rawData.toArray)

  // Take the embedding vector of the first "embeddings" annotation of each input row.
  val extractedEmbeddings: Array[Array[Float]] =
    annotated.map(row => row("embeddings").head.asInstanceOf[Annotation].embeddings)

  for {
    (actualVector, expectedVector) <- extractedEmbeddings.zip(expected_cls_avg)
    (actualValue, expectedValue) <- actualVector.take(16).zip(expectedVector)
  } assert(actualValue === expectedValue, "Embedding value not within tolerance")
}

// Saves the fitted UAEEmbeddings stage to disk, loads it back, and verifies that a pipeline
// built around the reloaded model still produces the cls_avg reference values.
it should "be serializable" taggedAs SlowTest in {
  model.setPoolingStrategy("cls_avg")
  val savePath = "./tmp_uae_model"

  // Stage index 1 of the fitted pipeline is cast to the embeddings annotator.
  val fittedEmbeddings = pipeline.fit(data).stages(1).asInstanceOf[UAEEmbeddings]
  fittedEmbeddings.write.overwrite().save(savePath)

  val loadedModel = UAEEmbeddings.load(savePath)
  val reloadedPipeline: Pipeline =
    new Pipeline().setStages(Array(document, loadedModel, embeddingsFinisher))

  assertEmbeddings(expected_cls_avg, reloadedPipeline)
}

}

0 comments on commit fa41371

Please sign in to comment.