From bf6d21edfa8a81b6e373357d702122caac0e757e Mon Sep 17 00:00:00 2001 From: Devin Ha <33089471+DevinTDHa@users.noreply.github.com> Date: Fri, 5 Apr 2024 17:37:07 +0200 Subject: [PATCH] SPARKNLP-962: UAEEmbeddings (#14199) * SPARKNLP-962: UAE Embeddings - added Scala side * SPARKNLP-962: UAE Embeddings - added Python Side * SPARKNLP-962: UAE Embeddings - Added default values - Serialization tests * Bugfix: Can't serialize models without onnx_data file - onnxModelPath is not set for models without an .onnx_data file, so it will be None - None.get will throw an error, this checks for it first * SPARKNLP-962: UAE Embeddings - Documentation * SPARKNLP-962: UAE Embeddings - make tests lazy --- docs/en/annotators.md | 1 + docs/en/transformer_entries/UAEEmbeddings.md | 157 +++++ .../HuggingFace_ONNX_in_Spark_NLP_E5.ipynb | 3 +- .../sparknlp/annotator/embeddings/__init__.py | 1 + .../annotator/embeddings/uae_embeddings.py | 211 +++++++ python/sparknlp/internal/__init__.py | 13 +- .../pretrained/resource_downloader.py | 5 +- .../embeddings/uae_embeddings_test.py | 48 ++ .../scala/com/johnsnowlabs/ml/ai/UAE.scala | 300 ++++++++++ .../ml/onnx/OnnxSerializeModel.scala | 9 +- .../com/johnsnowlabs/ml/util/LinAlg.scala | 121 +++- .../com/johnsnowlabs/nlp/annotator.scala | 2 + .../nlp/embeddings/UAEEmbeddings.scala | 536 ++++++++++++++++++ .../nlp/pretrained/ResourceDownloader.scala | 3 +- .../com/johnsnowlabs/ml/util/LinAlgTest.scala | 82 +++ .../embeddings/UAEEmbeddingsTestSpec.scala | 157 +++++ 16 files changed, 1635 insertions(+), 14 deletions(-) create mode 100644 docs/en/transformer_entries/UAEEmbeddings.md create mode 100644 python/sparknlp/annotator/embeddings/uae_embeddings.py create mode 100644 python/test/annotator/embeddings/uae_embeddings_test.py create mode 100644 src/main/scala/com/johnsnowlabs/ml/ai/UAE.scala create mode 100644 src/main/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddings.scala create mode 100644 src/test/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddingsTestSpec.scala diff --git a/docs/en/annotators.md b/docs/en/annotators.md index 135641dd0e8266..858a07d0a06336 100644 --- a/docs/en/annotators.md +++ b/docs/en/annotators.md @@ -161,6 +161,7 @@ Additionally, these transformers are available. {% include templates/anno_table_entry.md path="./transformers" name="SwinForImageClassification" summary="SwinImageClassification is an image classifier based on Swin."%} {% include templates/anno_table_entry.md path="./transformers" name="T5Transformer" summary="T5 reconsiders all NLP tasks into a unified text-to-text-format where the input and output are always text strings, in contrast to BERT-style models that can only output either a class label or a span of the input."%} {% include templates/anno_table_entry.md path="./transformers" name="TapasForQuestionAnswering" summary="TapasForQuestionAnswering is an implementation of TaPas - a BERT-based model specifically designed for answering questions about tabular data."%} +{% include templates/anno_table_entry.md path="./transformers" name="UAEEmbeddings" summary="Sentence embeddings using Universal AnglE Embedding (UAE)."%} {% include templates/anno_table_entry.md path="./transformers" name="UniversalSentenceEncoder" summary="The Universal Sentence Encoder encodes text into high dimensional vectors that can be used for text classification, semantic similarity, clustering and other natural language tasks."%} {% include templates/anno_table_entry.md path="./transformers" name="VisionEncoderDecoderForImageCaptioning" summary="VisionEncoderDecoder model that converts images into text captions."%} {% include templates/anno_table_entry.md path="./transformers" name="ViTForImageClassification" summary="Vision Transformer (ViT) for image classification."%} diff --git a/docs/en/transformer_entries/UAEEmbeddings.md b/docs/en/transformer_entries/UAEEmbeddings.md new file mode 100644 index 00000000000000..c1d5b6c9c517f5 --- /dev/null +++ b/docs/en/transformer_entries/UAEEmbeddings.md @@ -0,0 +1,157 @@ +{%- capture title -%} +UAEEmbeddings +{%- endcapture -%} + +{%- capture description -%} +Sentence embeddings using Universal AnglE Embedding (UAE). + +UAE is a novel angle-optimized text embedding model, designed to improve semantic textual +similarity tasks, which are crucial for Large Language Model (LLM) applications. By +introducing angle optimization in a complex space, AnglE effectively mitigates saturation of +the cosine similarity function. + +Pretrained models can be loaded with `pretrained` of the companion object: + +```scala +val embeddings = UAEEmbeddings.pretrained() + .setInputCols("document") + .setOutputCol("UAE_embeddings") +``` + +The default model is `"uae_large_v1"`, if no name is provided. + +For available pretrained models please see the +[Models Hub](https://sparknlp.org/models?q=UAE). + +For extended examples of usage, see +[UAEEmbeddingsTestSpec](https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddingsTestSpec.scala). + +**Sources** : + +[AnglE-optimized Text Embeddings](https://arxiv.org/abs/2309.12871) + +[UAE Github Repository](https://github.com/baochi0212/uae-embedding) + +**Paper abstract** + +*High-quality text embedding is pivotal in improving semantic textual similarity (STS) tasks, +which are crucial components in Large Language Model (LLM) applications. However, a common +challenge existing text embedding models face is the problem of vanishing gradients, primarily +due to their reliance on the cosine function in the optimization objective, which has +saturation zones. To address this issue, this paper proposes a novel angle-optimized text +embedding model called AnglE. The core idea of AnglE is to introduce angle optimization in a +complex space. This novel approach effectively mitigates the adverse effects of the saturation +zone in the cosine function, which can impede gradient and hinder optimization processes. To +set up a comprehensive STS evaluation, we experimented on existing short-text STS datasets and +a newly collected long-text STS dataset from GitHub Issues. Furthermore, we examine +domain-specific STS scenarios with limited labeled data and explore how AnglE works with +LLM-annotated data. Extensive experiments were conducted on various tasks including short-text +STS, long-text STS, and domain-specific STS tasks. The results show that AnglE outperforms the +state-of-the-art (SOTA) STS models that ignore the cosine saturation zone. These findings +demonstrate the ability of AnglE to generate high-quality text embeddings and the usefulness +of angle optimization in STS.* +{%- endcapture -%} + +{%- capture input_anno -%} +DOCUMENT +{%- endcapture -%} + +{%- capture output_anno -%} +SENTENCE_EMBEDDINGS +{%- endcapture -%} + +{%- capture python_example -%} +import sparknlp +from sparknlp.base import * +from sparknlp.annotator import * +from pyspark.ml import Pipeline +documentAssembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("document") +embeddings = UAEEmbeddings.pretrained() \ + .setInputCols(["document"]) \ + .setOutputCol("embeddings") +embeddingsFinisher = EmbeddingsFinisher() \ + .setInputCols("embeddings") \ + .setOutputCols("finished_embeddings") \ + .setOutputAsVector(True) +pipeline = Pipeline().setStages([ + documentAssembler, + embeddings, + embeddingsFinisher +]) + +data = spark.createDataFrame([["hello world", "hello moon"]]).toDF("text") +result = pipeline.fit(data).transform(data) +result.selectExpr("explode(finished_embeddings) as result").show(5, 80) ++--------------------------------------------------------------------------------+ +| result| ++--------------------------------------------------------------------------------+ +|[0.50387806, 0.5861606, 0.35129607, -0.76046336, -0.32446072, -0.117674336, 0...| +|[0.6660665, 0.961762, 0.24854276, -0.1018044, -0.6569202, 0.027635604, 0.1915...| ++--------------------------------------------------------------------------------+ +{%- endcapture -%} + +{%- capture scala_example -%} +import spark.implicits._ +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.annotators.Tokenizer +import com.johnsnowlabs.nlp.embeddings.UAEEmbeddings +import com.johnsnowlabs.nlp.EmbeddingsFinisher +import org.apache.spark.ml.Pipeline + +val documentAssembler = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + +val embeddings = UAEEmbeddings.pretrained() + .setInputCols("document") + .setOutputCol("UAE_embeddings") + +val embeddingsFinisher = new EmbeddingsFinisher() + .setInputCols("UAE_embeddings") + .setOutputCols("finished_embeddings") + .setOutputAsVector(true) + +val pipeline = new Pipeline().setStages(Array( + documentAssembler, + embeddings, + embeddingsFinisher +)) + +val data = Seq("hello world", "hello moon").toDF("text") +val result = pipeline.fit(data).transform(data) + +result.selectExpr("explode(finished_embeddings) as result").show(5, 80) ++--------------------------------------------------------------------------------+ +| result| ++--------------------------------------------------------------------------------+ +|[0.50387806, 0.5861606, 0.35129607, -0.76046336, -0.32446072, -0.117674336, 0...| +|[0.6660665, 0.961762, 0.24854276, -0.1018044, -0.6569202, 0.027635604, 0.1915...| ++--------------------------------------------------------------------------------+ + +{%- endcapture -%} + +{%- capture api_link -%} +[UAEEmbeddings](/api/com/johnsnowlabs/nlp/embeddings/UAEEmbeddings) +{%- endcapture -%} + +{%- capture python_api_link -%} +[UAEEmbeddings](/api/python/reference/autosummary/sparknlp/annotator/embeddings/uae_embeddings/index.html#sparknlp.annotator.embeddings.uae_embeddings.UAEEmbeddings) +{%- endcapture -%} + +{%- capture source_link -%} +[UAEEmbeddings](https://github.com/JohnSnowLabs/spark-nlp/tree/master/src/main/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddings.scala) +{%- endcapture -%} + +{% include templates/anno_template.md +title=title +description=description +input_anno=input_anno +output_anno=output_anno +python_example=python_example +scala_example=scala_example +api_link=api_link +python_api_link=python_api_link +source_link=source_link +%} \ No newline at end of file diff --git a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_E5.ipynb b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_E5.ipynb index 8e2b4ccfabb266..e63c891e83a887 100644 --- a/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_E5.ipynb +++ b/examples/python/transformers/onnx/HuggingFace_ONNX_in_Spark_NLP_E5.ipynb @@ -389,8 +389,7 @@ "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" + "pygments_lexer": "ipython3" } }, "nbformat": 4, diff --git a/python/sparknlp/annotator/embeddings/__init__.py b/python/sparknlp/annotator/embeddings/__init__.py index 1ddf7952558df7..f07049bfcc3caa 100644 --- a/python/sparknlp/annotator/embeddings/__init__.py +++ b/python/sparknlp/annotator/embeddings/__init__.py @@ -36,3 +36,4 @@ from sparknlp.annotator.embeddings.xlm_roberta_sentence_embeddings import * from sparknlp.annotator.embeddings.xlnet_embeddings import * from sparknlp.annotator.embeddings.bge_embeddings import * +from sparknlp.annotator.embeddings.uae_embeddings import * diff --git a/python/sparknlp/annotator/embeddings/uae_embeddings.py b/python/sparknlp/annotator/embeddings/uae_embeddings.py new file mode 100644 index 00000000000000..1a25d63ded20c8 --- /dev/null +++ b/python/sparknlp/annotator/embeddings/uae_embeddings.py @@ -0,0 +1,211 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains classes for UAEEmbeddings.""" + +from sparknlp.common import * + + +class UAEEmbeddings(AnnotatorModel, + HasEmbeddingsProperties, + HasCaseSensitiveProperties, + HasStorageRef, + HasBatchedAnnotate, + HasMaxSentenceLengthLimit): + """Sentence embeddings using Universal AnglE Embedding (UAE). + + UAE is a novel angle-optimized text embedding model, designed to improve semantic textual + similarity tasks, which are crucial for Large Language Model (LLM) applications. By + introducing angle optimization in a complex space, AnglE effectively mitigates saturation of + the cosine similarity function. + + Pretrained models can be loaded with :meth:`.pretrained` of the companion + object: + + >>> embeddings = UAEEmbeddings.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("UAE_embeddings") + + + The default model is ``"uae_large_v1"``, if no name is provided. + + For available pretrained models please see the + `Models Hub `__. + + + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT`` ``SENTENCE_EMBEDDINGS`` + ====================== ====================== + + Parameters + ---------- + batchSize + Size of every batch , by default 8 + dimension + Number of embedding dimensions, by default 768 + caseSensitive + Whether to ignore case in tokens for embeddings matching, by default False + maxSentenceLength + Max sentence length to process, by default 512 + configProtoBytes + ConfigProto from tensorflow, serialized into byte array. + + References + ---------- + + `AnglE-optimized Text Embeddings `__ + `UAE Github Repository `__ + + **Paper abstract** + + *High-quality text embedding is pivotal in improving semantic textual similarity (STS) tasks, + which are crucial components in Large Language Model (LLM) applications. However, a common + challenge existing text embedding models face is the problem of vanishing gradients, primarily + due to their reliance on the cosine function in the optimization objective, which has + saturation zones. To address this issue, this paper proposes a novel angle-optimized text + embedding model called AnglE. The core idea of AnglE is to introduce angle optimization in a + complex space. This novel approach effectively mitigates the adverse effects of the saturation + zone in the cosine function, which can impede gradient and hinder optimization processes. To + set up a comprehensive STS evaluation, we experimented on existing short-text STS datasets and + a newly collected long-text STS dataset from GitHub Issues. Furthermore, we examine + domain-specific STS scenarios with limited labeled data and explore how AnglE works with + LLM-annotated data. Extensive experiments were conducted on various tasks including short-text + STS, long-text STS, and domain-specific STS tasks. The results show that AnglE outperforms the + state-of-the-art (SOTA) STS models that ignore the cosine saturation zone. These findings + demonstrate the ability of AnglE to generate high-quality text embeddings and the usefulness + of angle optimization in STS.* + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = DocumentAssembler() \\ + ... .setInputCol("text") \\ + ... .setOutputCol("document") + >>> embeddings = UAEEmbeddings.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("embeddings") + >>> embeddingsFinisher = EmbeddingsFinisher() \\ + ... .setInputCols("embeddings") \\ + ... .setOutputCols("finished_embeddings") \\ + ... .setOutputAsVector(True) + >>> pipeline = Pipeline().setStages([ + ... documentAssembler, + ... embeddings, + ... embeddingsFinisher + ... ]) + >>> data = spark.createDataFrame([["hello world", "hello moon"]]).toDF("text") + >>> result = pipeline.fit(data).transform(data) + >>> result.selectExpr("explode(finished_embeddings) as result").show(5, 80) + +--------------------------------------------------------------------------------+ + | result| + +--------------------------------------------------------------------------------+ + |[0.50387806, 0.5861606, 0.35129607, -0.76046336, -0.32446072, -0.117674336, 0...| + |[0.6660665, 0.961762, 0.24854276, -0.1018044, -0.6569202, 0.027635604, 0.1915...| + +--------------------------------------------------------------------------------+ + """ + + name = "UAEEmbeddings" + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT] + + outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS + poolingStrategy = Param(Params._dummy(), + "poolingStrategy", + "Pooling strategy to use for sentence embeddings", + TypeConverters.toString) + + def setPoolingStrategy(self, value): + """Pooling strategy to use for sentence embeddings. + + Available pooling strategies for sentence embeddings are: + - `"cls"`: leading `[CLS]` token + - `"cls_avg"`: leading `[CLS]` token + mean of all other tokens + - `"last"`: embeddings of the last token in the sequence + - `"avg"`: mean of all tokens + - `"max"`: max of all embedding features of the entire token sequence + - `"int"`: An integer number, which represents the index of the token to use as the + embedding + + Parameters + ---------- + value : str + Pooling strategy to use for sentence embeddings + """ + + valid_strategies = {"cls", "cls_avg", "last", "avg", "max"} + if value in valid_strategies or value.isdigit(): + return self._set(poolingStrategy=value) + else: + raise ValueError(f"Invalid pooling strategy: {value}. " + f"Valid strategies are: {', '.join(self.valid_strategies)} or an integer.") + + @keyword_only + def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.UAEEmbeddings", java_model=None): + super(UAEEmbeddings, self).__init__( + classname=classname, + java_model=java_model + ) + self._setDefault( + dimension=1024, + batchSize=8, + maxSentenceLength=512, + caseSensitive=False, + poolingStrategy="cls" + ) + + @staticmethod + def loadSavedModel(folder, spark_session): + """Loads a locally saved model. + + Parameters + ---------- + folder : str + Folder of the saved model + spark_session : pyspark.sql.SparkSession + The current SparkSession + + Returns + ------- + UAEEmbeddings + The restored model + """ + from sparknlp.internal import _UAEEmbeddingsLoader + jModel = _UAEEmbeddingsLoader(folder, spark_session._jsparkSession)._java_obj + return UAEEmbeddings(java_model=jModel) + + @staticmethod + def pretrained(name="uae_large_v1", lang="en", remote_loc=None): + """Downloads and loads a pretrained model. + + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default "UAE_small" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLPs repositories otherwise. + + Returns + ------- + UAEEmbeddings + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(UAEEmbeddings, name, lang, remote_loc) diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index c1aabeeb36aec0..54180480bdce63 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -158,11 +158,13 @@ def __init__(self, path, jspark): super(_GPT2Loader, self).__init__( "com.johnsnowlabs.nlp.annotators.seq2seq.GPT2Transformer.loadSavedModel", path, jspark) + class _LLAMA2Loader(ExtendedJavaWrapper): def __init__(self, path, jspark): super(_LLAMA2Loader, self).__init__( "com.johnsnowlabs.nlp.annotators.seq2seq.LLAMA2Transformer.loadSavedModel", path, jspark) + class _LongformerLoader(ExtendedJavaWrapper): def __init__(self, path, jspark): super(_LongformerLoader, self).__init__("com.johnsnowlabs.nlp.embeddings.LongformerEmbeddings.loadSavedModel", @@ -601,8 +603,8 @@ def __init__(self, path, jspark): super(_DeBertaForZeroShotClassification, self).__init__( "com.johnsnowlabs.nlp.annotators.classifier.dl.DeBertaForZeroShotClassification.loadSavedModel", path, jspark) - - + + class _MPNetForSequenceClassificationLoader(ExtendedJavaWrapper): def __init__(self, path, jspark): super(_MPNetForSequenceClassificationLoader, self).__init__( @@ -615,3 +617,10 @@ def __init__(self, path, jspark): super(_MPNetForQuestionAnsweringLoader, self).__init__( "com.johnsnowlabs.nlp.annotators.classifier.dl.MPNetForQuestionAnswering.loadSavedModel", path, jspark) + + +class _UAEEmbeddingsLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark): + super(_UAEEmbeddingsLoader, self).__init__( + "com.johnsnowlabs.nlp.embeddings.UAEEmbeddings.loadSavedModel", path, + jspark) diff --git a/python/sparknlp/pretrained/resource_downloader.py b/python/sparknlp/pretrained/resource_downloader.py index 7755b9d0e5878c..00ffd0848a275b 100644 --- a/python/sparknlp/pretrained/resource_downloader.py +++ b/python/sparknlp/pretrained/resource_downloader.py @@ -58,7 +58,6 @@ class ResourceDownloader(object): """ - @staticmethod def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResourceDownloader'): """Downloads and loads a model with the default downloader. Usually this method @@ -67,8 +66,8 @@ def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResource Parameters ---------- - reader : str - Name of the class to read the model for + reader : obj + Class to read the model for name : str Name of the pretrained model language : str diff --git a/python/test/annotator/embeddings/uae_embeddings_test.py b/python/test/annotator/embeddings/uae_embeddings_test.py new file mode 100644 index 00000000000000..d36083dc7a0883 --- /dev/null +++ b/python/test/annotator/embeddings/uae_embeddings_test.py @@ -0,0 +1,48 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.util import SparkContextForTest + + +@pytest.mark.slow +class UAEEmbeddingsTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + self.tested_annotator = UAEEmbeddings \ + .loadSavedModel("/home/ducha/Workspace/JSL/spark-nlp-dev-things/hf_exports/UAE/exported_onnx", + SparkContextForTest.spark) \ + .setInputCols(["documents"]) \ + .setOutputCol("embeddings") \ + .setPoolingStrategy("cls_avg") + + def test_run(self): + data = self.spark.createDataFrame([["hello world"], ["hello moon"]]).toDF("text") + + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("documents") + + embeddings_finisher = EmbeddingsFinisher().setInputCols("embeddings").setOutputCols("embeddings") + + uae = self.tested_annotator + + pipeline = Pipeline().setStages([document_assembler, uae, embeddings_finisher]) + results = pipeline.fit(data).transform(data) + + results.selectExpr("explode(embeddings) as result").show(truncate=False) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/UAE.scala b/src/main/scala/com/johnsnowlabs/ml/ai/UAE.scala new file mode 100644 index 00000000000000..34400f17d835e1 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/UAE.scala @@ -0,0 +1,300 @@ +/* + * Copyright 2017 - 2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.ml.ai + +import ai.onnxruntime.{OnnxTensor, TensorInfo} +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} +import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} +import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} +import com.johnsnowlabs.ml.util.{LinAlg, ONNX, TensorFlow} +import com.johnsnowlabs.nlp.annotators.common._ +import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} + +import scala.collection.JavaConverters._ +import scala.util.Try + +/** UAE Sentence embeddings model + * @param tensorflowWrapper + * tensorflow wrapper + * @param configProtoBytes + * config proto bytes + * @param sentenceStartTokenId + * sentence start token id + * @param sentenceEndTokenId + * sentence end token id + * @param signatures + * signatures + */ +private[johnsnowlabs] class UAE( + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[OnnxWrapper], + configProtoBytes: Option[Array[Byte]] = None, + sentenceStartTokenId: Int, + sentenceEndTokenId: Int, + signatures: Option[Map[String, String]] = None) { + + private val _tfInstructorSignatures: Map[String, String] = + signatures.getOrElse(ModelSignatureManager.apply()) + private val paddingTokenId = 0 + + val detectedEngine: String = + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions + + /** Get sentence embeddings for a batch of sentences + * + * @param batch + * batch of sentences + * @return + * sentence embeddings + */ + private def getSentenceEmbedding( + batch: Seq[Array[Int]], + poolingStrategy: String): Array[Array[Float]] = { + val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max + val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength)) + val sentenceEmbeddings: Array[Array[Float]] = detectedEngine match { + case ONNX.name => + getSentenceEmbeddingFromOnnx(paddedBatch, maxSentenceLength, poolingStrategy) + case _ => // TF Case + getSentenceEmbeddingFromTF(paddedBatch, maxSentenceLength, poolingStrategy) + } + + sentenceEmbeddings + } + + /** Pools word embeddings to sentence embeddings given a strategy. + * + * @param embeddings + * A 3D array of Floats representing the embeddings. The dimensions are [batch_size, + * sequence_length, embedding_dim]. + * @param attentionMask + * A 2D array of Longs representing the attention mask. The dimensions are [batch_size, + * sequence_length]. + * @param poolingStrategy + * A String representing the pooling strategy to be applied. The following strategies are + * supported: + * + * - `"cls"`: leading `[CLS]` token + * - `"cls_avg"`: leading `[CLS]` token + mean of all other tokens + * - `"last"`: embeddings of the last token in the sequence + * - `"avg"`: mean of all tokens + * - `"max"`: max of all embedding values for the token sequence + * - `"int"`: An integer number, which represents the index of the token to use as the + * embedding + * @return + * A 2D array of Floats representing the pooled embeddings. The dimensions are [batch_size, + * embedding_dim]. + */ + private def pool( + embeddings: Array[Array[Array[Float]]], + attentionMask: Array[Array[Long]], + poolingStrategy: String): Array[Array[Float]] = { + poolingStrategy match { + case "cls" => LinAlg.clsPooling(embeddings, attentionMask) + case "cls_avg" => LinAlg.clsAvgPooling(embeddings, attentionMask) + case "last" => LinAlg.lastPooling(embeddings, attentionMask) + case "avg" => + val shape: Array[Long] = + Array(embeddings.length, embeddings.head.length, embeddings.head.head.length) + val avgPooled = LinAlg.avgPooling(embeddings.flatten.flatten, attentionMask, shape) + avgPooled.t.toArray.grouped(avgPooled.cols).toArray + case "max" => LinAlg.maxPooling(embeddings, attentionMask) + case index if Try(index.toInt).isSuccess => LinAlg.tokenPooling(embeddings, index.toInt) + case _ => + throw new IllegalArgumentException(s"Pooling strategy $poolingStrategy not supported.") + } + } + + private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = { + if (arr.length >= maxLength) { + arr + } else { + arr ++ Array.fill(maxLength - arr.length)(0) + } + } + + private def getSentenceEmbeddingFromTF( + batch: Seq[Array[Int]], + maxSentenceLength: Int, + poolingStrategy: String): Array[Array[Float]] = { + val batchLength = batch.length + + // encode batch + val tensorEncoder = new TensorResources() + val inputDim = batch.length * maxSentenceLength + + // create buffers + val encoderInputBuffers = tensorEncoder.createIntBuffer(inputDim) + val encoderAttentionMaskBuffers = tensorEncoder.createIntBuffer(inputDim) + + val shape = Array(batch.length.toLong, maxSentenceLength) + + batch.zipWithIndex.foreach { case (tokenIds, idx) => + val offset = idx * maxSentenceLength + val diff = maxSentenceLength - tokenIds.length + + // pad with 0 + val s = tokenIds.take(maxSentenceLength) ++ Array.fill[Int](diff)(this.paddingTokenId) + encoderInputBuffers.offset(offset).write(s) + + // create attention mask + val mask = s.map(x => if (x != this.paddingTokenId) 1 else 0) + encoderAttentionMaskBuffers.offset(offset).write(mask) + } + + // create tensors + val encoderInputTensors = tensorEncoder.createIntBufferTensor(shape, encoderInputBuffers) + val encoderAttentionMaskTensors = + tensorEncoder.createIntBufferTensor(shape, encoderAttentionMaskBuffers) + + // run model + val runner = tensorflowWrapper.get + .getTFSessionWithSignature( + configProtoBytes = configProtoBytes, + initAllTables = false, + savedSignatures = signatures) + .runner + + runner + .feed( + _tfInstructorSignatures.getOrElse( + ModelSignatureConstants.EncoderInputIds.key, + "missing_encoder_input_ids"), + encoderInputTensors) + .feed( + _tfInstructorSignatures.getOrElse( + ModelSignatureConstants.EncoderAttentionMask.key, + "missing_encoder_attention_mask"), + encoderAttentionMaskTensors) + .fetch(_tfInstructorSignatures + .getOrElse(ModelSignatureConstants.LastHiddenState.key, "missing_last_hidden_state")) + + // get embeddings + val sentenceEmbeddings = runner.run().asScala + val sentenceEmbeddingsFloats = TensorResources.extractFloats(sentenceEmbeddings.head) + val embeddingDim = sentenceEmbeddingsFloats.length / maxSentenceLength / batchLength + + // group embeddings + val sentenceEmbeddingsFloatsArray = + sentenceEmbeddingsFloats.grouped(embeddingDim).toArray.grouped(maxSentenceLength).toArray + + val attentionMask: Array[Array[Long]] = + TensorResources.extractLongs(encoderAttentionMaskTensors).grouped(maxSentenceLength).toArray + + // close buffers + sentenceEmbeddings.foreach(_.close()) + encoderInputTensors.close() + encoderAttentionMaskTensors.close() + tensorEncoder.clearTensors() + tensorEncoder.clearSession(sentenceEmbeddings) + + pool(sentenceEmbeddingsFloatsArray, attentionMask, poolingStrategy) + } + + private def getSentenceEmbeddingFromOnnx( + batch: Seq[Array[Int]], + maxSentenceLength: Int, + poolingStrategy: String): Array[Array[Float]] = { + + val inputIds = batch.map(x => x.map(x => x.toLong)).toArray + val attentionMask = batch.map(sentence => sentence.map(x => if (x < 0L) 0L else 1L)).toArray + + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) + + val tokenTensors = OnnxTensor.createTensor(env, inputIds) + val maskTensors = OnnxTensor.createTensor(env, attentionMask) + val segmentTensors = + OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray) + val inputs = + Map( + "input_ids" -> tokenTensors, + "attention_mask" -> maskTensors, + "token_type_ids" -> segmentTensors).asJava + + // TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled. + val embeddings = + try { + val results = runner.run(inputs) + val lastHiddenState = results.get("last_hidden_state").get() + val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo] + val shape = info.getShape.map(_.toInt) + val Array(_, sequenceLength, embeddingDim) = shape + try { + val flattenEmbeddings = lastHiddenState + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + + flattenEmbeddings.grouped(embeddingDim).toArray.grouped(sequenceLength).toArray + } finally if (results != null) results.close() + } + + pool(embeddings, attentionMask, poolingStrategy) + } + + /** Predict sentence embeddings for a batch of sentences + * + * @param sentences + * sentences + * @param tokenizedSentences + * tokenized sentences + * @param batchSize + * batch size + * @param maxSentenceLength + * max sentence length + * @return + */ + def predict( + sentences: Seq[Annotation], + tokenizedSentences: Seq[WordpieceTokenizedSentence], + batchSize: Int, + maxSentenceLength: Int, + poolingStrategy: String): Seq[Annotation] = { + + tokenizedSentences + .zip(sentences) + .zipWithIndex + .grouped(batchSize) + .toArray + .flatMap { batch => + val tokensBatch = batch.map(x => x._1._1.tokens) + val tokens = tokensBatch.map(x => + Array(sentenceStartTokenId) ++ x + .map(y => y.pieceId) + .take(maxSentenceLength - 2) ++ Array(sentenceEndTokenId)) + + val sentenceEmbeddings = getSentenceEmbedding(tokens, poolingStrategy) + + batch.zip(sentenceEmbeddings).map { case (sentence, vectors) => + Annotation( + annotatorType = AnnotatorType.SENTENCE_EMBEDDINGS, + begin = sentence._1._2.begin, + end = sentence._1._2.end, + result = sentence._1._2.result, + metadata = sentence._1._2.metadata, + embeddings = vectors) + } + } + } + +} diff --git a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala index 2735626930016d..5c9156539d4cd0 100644 --- a/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/onnx/OnnxSerializeModel.scala @@ -54,10 +54,11 @@ trait WriteOnnxModel { fs.copyFromLocalFile(new Path(onnxFile), new Path(path)) // 4. check if there is a onnx_data file - - val onnxDataFile = new Path(onnxWrapper.onnxModelPath.get + dataFileSuffix) - if (fs.exists(onnxDataFile)) { - fs.copyFromLocalFile(onnxDataFile, new Path(path)) + if (onnxWrapper.onnxModelPath.isDefined) { + val onnxDataFile = new Path(onnxWrapper.onnxModelPath.get + dataFileSuffix) + if (fs.exists(onnxDataFile)) { + fs.copyFromLocalFile(onnxDataFile, new Path(path)) + } } } diff --git a/src/main/scala/com/johnsnowlabs/ml/util/LinAlg.scala b/src/main/scala/com/johnsnowlabs/ml/util/LinAlg.scala index 97a9c36a4861b4..a8f34e87464c3d 100644 --- a/src/main/scala/com/johnsnowlabs/ml/util/LinAlg.scala +++ b/src/main/scala/com/johnsnowlabs/ml/util/LinAlg.scala @@ -1,7 +1,8 @@ package com.johnsnowlabs.ml.util -import breeze.linalg.{DenseMatrix, norm, sum, tile, *} -import scala.math.{sqrt, pow} +import breeze.linalg.{*, DenseMatrix, DenseVector, max, norm, sum, tile} + +import scala.math.pow object LinAlg { @@ -277,4 +278,120 @@ object LinAlg { array.map(value => if (lpNorm != 0.0f) value / lpNorm else 0.0f) } + /** Creates pooled embeddings by selecting the token at the index position. + * + * @param embeddings + * Embeddings in shape (batchSize, sequenceLength, embeddingDim) + * @param indexes + * Array of Index Positions to select for each sequence in the batch + * @return + * A 2D array representing the pooled embeddings + */ + def tokenPooling( + embeddings: Array[Array[Array[Float]]], + indexes: Array[Int]): Array[Array[Float]] = { + val batchSize = embeddings.length + require(indexes.length == batchSize, "Indexes length should be equal to batch size") + + embeddings.zip(indexes).map { case (tokens: Array[Array[Float]], index: Int) => + tokens(index) + } + } + + /** Creates pooled embeddings by selecting the token at the index position. + * + * @param embeddings + * Embeddings in shape (batchSize, sequenceLength, embeddingDim) + * @param index + * Index Position to select for each sequence in the batch + * @return + * A 2D array representing the pooled embeddings + */ + def tokenPooling(embeddings: Array[Array[Array[Float]]], index: Int): Array[Array[Float]] = + tokenPooling(embeddings, Array.fill(embeddings.length)(index)) + + /** Creates pooled embeddings by taking the maximum of the embedding features along the + * sequence. + * + * @param embeddings + * Embeddings in shape (batchSize, sequenceLength, embeddingDim) + * @return + * A 2D array representing the pooled embeddings + */ + def maxPooling( + embeddings: Array[Array[Array[Float]]], + attentionMask: Array[Array[Long]]): Array[Array[Float]] = { + val embeddingsMatrix = embeddings.map(embedding => DenseMatrix(embedding: _*)) + + val maskedEmbeddings: Array[DenseMatrix[Float]] = + embeddingsMatrix.zip(attentionMask).map { + case (embedding: DenseMatrix[Float], mask: Array[Long]) => + val maskVector: DenseVector[Float] = new DenseVector(mask.map(_.toFloat)) + embedding(::, *) *:* maskVector + } + + maskedEmbeddings.map { seqEmbeddings: DenseMatrix[Float] => + max(seqEmbeddings(::, *)).t.toArray + } + } + + /** Creates pooled embeddings by using the CLS token as the representative embedding of the + * sequence. + * + * @param embeddings + * Embeddings in shape (batchSize, sequenceLength, embeddingDim) + * @param attentionMask + * Attention mask in shape (batchSize, sequenceLength) + * @return + * The pooled embeddings in shape (batchSize, embeddingDim) + */ + def clsPooling( + embeddings: Array[Array[Array[Float]]], + attentionMask: Array[Array[Long]]): Array[Array[Float]] = { + tokenPooling(embeddings, 0) // CLS embedding is at the front of each sequence + } + + /** Creates pooled embeddings by averaging the embeddings of the CLS token and the average + * embedding the sequence. + * + * @param embeddings + * Embeddings in shape (batchSize, sequenceLength, embeddingDim) + * @param attentionMask + * Attention mask in shape (batchSize, sequenceLength) + * @return + * The pooled embeddings in shape (batchSize, embeddingDim) + */ + def clsAvgPooling( + embeddings: Array[Array[Array[Float]]], + attentionMask: Array[Array[Long]]): Array[Array[Float]] = { + val clsEmbeddings = DenseMatrix(clsPooling(embeddings, attentionMask): _*) + val shape: Array[Long] = + Array(embeddings.length, embeddings.head.length, embeddings.head.head.length) + + val flatEmbeddings: Array[Float] = embeddings.flatten.flatten + val meanEmbeddings = avgPooling(flatEmbeddings, attentionMask, shape) + + val clsAvgEmbeddings = (clsEmbeddings +:+ meanEmbeddings) / 2.0f + clsAvgEmbeddings.t.toArray // Breeze uses column-major order + .grouped(meanEmbeddings.cols) + .toArray + } + + /** Creates pooled embeddings by taking the last token embedding of the sequence. Assumes right + * padding. + * + * @param embeddings + * Embeddings in shape (batchSize, sequenceLength, embeddingDim) + * @param attentionMask + * Attention mask in shape (batchSize, sequenceLength) + * @return + * The pooled embeddings in shape (batchSize, embeddingDim) + */ + def lastPooling( + embeddings: Array[Array[Array[Float]]], + attentionMask: Array[Array[Long]]): Array[Array[Float]] = { + val lastTokenIndexes = attentionMask.map(_.sum.toInt - 1) + + tokenPooling(embeddings, lastTokenIndexes) + } } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala index ac50e18b5f18cc..373d87342b4203 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala @@ -794,4 +794,6 @@ package object annotator { extends ReadablePretrainedM2M100TransformerModel with ReadM2M100TransformerDLModel + type UAEEmbeddings = com.johnsnowlabs.nlp.embeddings.UAEEmbeddings + object UAEEmbeddings extends ReadablePretrainedUAEModel with ReadUAEDLModel } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddings.scala new file mode 100644 index 00000000000000..f82fc3f2e994c1 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddings.scala @@ -0,0 +1,536 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.ml.ai.UAE +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.tensorflow._ +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.nlp.annotators.common._ +import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} +import com.johnsnowlabs.nlp.serialization.MapFeature +import com.johnsnowlabs.storage.HasStorageRef +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.slf4j.{Logger, LoggerFactory} + +import scala.util.Try + +/** Sentence embeddings using Universal AnglE Embedding (UAE). + * + * UAE is a novel angle-optimized text embedding model, designed to improve semantic textual + * similarity tasks, which are crucial for Large Language Model (LLM) applications. By + * introducing angle optimization in a complex space, AnglE effectively mitigates saturation of + * the cosine similarity function. + * + * Pretrained models can be loaded with `pretrained` of the companion object: + * {{{ + * val embeddings = UAEEmbeddings.pretrained() + * .setInputCols("document") + * .setOutputCol("UAE_embeddings") + * }}} + * The default model is `"uae_large_v1"`, if no name is provided. + * + * For available pretrained models please see the + * [[https://sparknlp.org/models?q=UAE Models Hub]]. + * + * For extended examples of usage, see + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddingsTestSpec.scala UAEEmbeddingsTestSpec]]. + * + * '''Sources''' : + * + * [[https://arxiv.org/abs/2309.12871 AnglE-optimized Text Embeddings]] + * + * [[https://github.com/baochi0212/uae-embedding UAE Github Repository]] + * + * ''' Paper abstract ''' + * + * ''High-quality text embedding is pivotal in improving semantic textual similarity (STS) tasks, + * which are crucial components in Large Language Model (LLM) applications. However, a common + * challenge existing text embedding models face is the problem of vanishing gradients, primarily + * due to their reliance on the cosine function in the optimization objective, which has + * saturation zones. To address this issue, this paper proposes a novel angle-optimized text + * embedding model called AnglE. The core idea of AnglE is to introduce angle optimization in a + * complex space. This novel approach effectively mitigates the adverse effects of the saturation + * zone in the cosine function, which can impede gradient and hinder optimization processes. To + * set up a comprehensive STS evaluation, we experimented on existing short-text STS datasets and + * a newly collected long-text STS dataset from GitHub Issues. Furthermore, we examine + * domain-specific STS scenarios with limited labeled data and explore how AnglE works with + * LLM-annotated data. Extensive experiments were conducted on various tasks including short-text + * STS, long-text STS, and domain-specific STS tasks. The results show that AnglE outperforms the + * state-of-the-art (SOTA) STS models that ignore the cosine saturation zone. These findings + * demonstrate the ability of AnglE to generate high-quality text embeddings and the usefulness + * of angle optimization in STS. '' + * + * ==Example== + * {{{ + * import spark.implicits._ + * import com.johnsnowlabs.nlp.base.DocumentAssembler + * import com.johnsnowlabs.nlp.annotators.Tokenizer + * import com.johnsnowlabs.nlp.embeddings.UAEEmbeddings + * import com.johnsnowlabs.nlp.EmbeddingsFinisher + * import org.apache.spark.ml.Pipeline + * + * val documentAssembler = new DocumentAssembler() + * .setInputCol("text") + * .setOutputCol("document") + * + * val embeddings = UAEEmbeddings.pretrained() + * .setInputCols("document") + * .setOutputCol("UAE_embeddings") + * + * val embeddingsFinisher = new EmbeddingsFinisher() + * .setInputCols("UAE_embeddings") + * .setOutputCols("finished_embeddings") + * .setOutputAsVector(true) + * + * val pipeline = new Pipeline().setStages(Array( + * documentAssembler, + * embeddings, + * embeddingsFinisher + * )) + * + * val data = Seq("hello world", "hello moon").toDF("text") + * val result = pipeline.fit(data).transform(data) + * + * result.selectExpr("explode(finished_embeddings) as result").show(5, 80) + * +--------------------------------------------------------------------------------+ + * | result| + * +--------------------------------------------------------------------------------+ + * |[0.50387806, 0.5861606, 0.35129607, -0.76046336, -0.32446072, -0.117674336, 0...| + * |[0.6660665, 0.961762, 0.24854276, -0.1018044, -0.6569202, 0.027635604, 0.1915...| + * +--------------------------------------------------------------------------------+ + * }}} + * + * @see + * [[https://sparknlp.org/docs/en/annotators Annotators Main Page]] for a list of transformer + * based embeddings + * @param uid + * required uid for storing annotator to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this annotator can take. Users can set and get the + * parameter values through setters and getters, respectively. + */ +class UAEEmbeddings(override val uid: String) + extends AnnotatorModel[UAEEmbeddings] + with HasBatchedAnnotate[UAEEmbeddings] + with WriteTensorflowModel + with WriteOnnxModel + with HasEmbeddingsProperties + with HasStorageRef + with HasCaseSensitiveProperties + with HasEngine { + + /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator + * type + */ + override val inputAnnotatorTypes: Array[String] = + Array(AnnotatorType.DOCUMENT) + override val outputAnnotatorType: AnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS + + /** ConfigProto from tensorflow, serialized into byte array. Get with + * `config_proto.SerializeToString()` + * + * @group param + */ + val configProtoBytes = new IntArrayParam( + this, + "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()") + + /** Max sentence length to process (Default: `128`) + * + * @group param + */ + val maxSentenceLength = + new IntParam(this, "maxSentenceLength", "Max sentence length to process") + + def sentenceStartTokenId: Int = { + $$(vocabulary)("[CLS]") + } + + /** @group setParam */ + def sentenceEndTokenId: Int = { + $$(vocabulary)("[SEP]") + } + + /** Vocabulary used to encode the words to ids with WordPieceEncoder + * + * @group param + */ + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() + + /** @group setParam */ + def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value) + + /** It contains TF model signatures for the laded saved model + * + * @group param + */ + val signatures = + new MapFeature[String, String](model = this, name = "signatures").setProtected() + private var _model: Option[Broadcast[UAE]] = None + + def this() = this(Identifiable.randomUID("UAE_EMBEDDINGS")) + + /** @group setParam */ + def setConfigProtoBytes(bytes: Array[Int]): UAEEmbeddings.this.type = + set(this.configProtoBytes, bytes) + + /** @group setParam */ + def setMaxSentenceLength(value: Int): this.type = { + require( + value <= 512, + "UAE models do not support sequences longer than 512 because of trainable positional embeddings.") + require(value >= 1, "The maxSentenceLength must be at least 1") + set(maxSentenceLength, value) + this + } + + /** @group getParam */ + def getMaxSentenceLength: Int = $(maxSentenceLength) + + /** @group setParam */ + def setSignatures(value: Map[String, String]): this.type = { + if (get(signatures).isEmpty) + set(signatures, value) + this + } + + /** Pooling strategy to use for sentence embeddings. + * + * Available pooling strategies for sentence embeddings are: + * + * - `"cls"`: leading `[CLS]` token + * - `"cls_avg"`: leading `[CLS]` token + mean of all other tokens + * - `"last"`: embeddings of the last token in the sequence + * - `"avg"`: mean of all tokens + * - `"max"`: max of all embedding values for the token sequence + * - `"all"`: return all token embeddings + * - `"int"`: An integer number, which represents the index of the token to use as the + * embedding + * + * @group param + */ + val poolingStrategy = + new Param[String](this, "poolingStrategy", "Pooling strategy to use for sentence embeddings") + + def getPoolingStrategy: String = $(poolingStrategy) + + /** Pooling strategy to use for sentence embeddings. + * + * Available pooling strategies for sentence embeddings are: + * + * - `"cls"`: leading `[CLS]` token + * - `"cls_avg"`: leading `[CLS]` token + mean of all other tokens + * - `"last"`: embeddings of the last token in the sequence + * - `"avg"`: mean of all tokens + * - `"max"`: max of all embedding features of the entire token sequence + * - `"int"`: An integer number, which represents the index of the token to use as the + * embedding + * + * @group setParam + */ + def setPoolingStrategy(value: String): this.type = { + val validStrategies = Set("cls", "cls_avg", "last", "avg", "max") + + if (validStrategies.contains(value) || Try(value.toInt).isSuccess) { + set(poolingStrategy, value) + } else { + throw new IllegalArgumentException( + s"Invalid pooling strategy: $value. " + + s"Valid strategies are: ${validStrategies.mkString(", ")} or an integer.") + } + } + + /** @group setParam */ + def setModelIfNotSet( + spark: SparkSession, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper]): UAEEmbeddings = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new UAE( + tensorflowWrapper, + onnxWrapper, + configProtoBytes = getConfigProtoBytes, + sentenceStartTokenId = sentenceStartTokenId, + sentenceEndTokenId = sentenceEndTokenId, + signatures = getSignatures))) + } + + this + } + + /** Set Embeddings dimensions for the BERT model Only possible to set this when the first time + * is saved dimension is not changeable, it comes from BERT config file + * + * @group setParam + */ + override def setDimension(value: Int): this.type = { + if (get(dimension).isEmpty) + set(this.dimension, value) + this + } + + /** Whether to lowercase tokens or not + * + * @group setParam + */ + override def setCaseSensitive(value: Boolean): this.type = { + if (get(caseSensitive).isEmpty) + set(this.caseSensitive, value) + this + } + + setDefault( + dimension -> 1024, + batchSize -> 8, + maxSentenceLength -> 512, + caseSensitive -> false, + poolingStrategy -> "cls") + + def tokenize(sentences: Seq[Annotation]): Seq[WordpieceTokenizedSentence] = { + val basicTokenizer = new BasicTokenizer($(caseSensitive)) + val encoder = new WordpieceEncoder($$(vocabulary)) + sentences.map { s => + val sent = Sentence( + content = s.result, + start = s.begin, + end = s.end, + metadata = Some(s.metadata), + index = s.begin) + val tokens = basicTokenizer.tokenize(sent) + val wordpieceTokens = tokens.flatMap(token => encoder.encode(token)) + WordpieceTokenizedSentence(wordpieceTokens) + } + } + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations that correspond to inputAnnotationCols generated by previous annotators if any + * @return + * any number of annotations processed for every input annotation. Not necessary one to one + * relationship + */ + override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { + + val allAnnotations = batchedAnnotations + .filter(_.nonEmpty) + .zipWithIndex + .flatMap { case (annotations, i) => + annotations.filter(_.result.nonEmpty).map(x => (x, i)) + } + + // Tokenize sentences + val tokenizedSentences = tokenize(allAnnotations.map(_._1)) + val processedAnnotations = if (allAnnotations.nonEmpty) { + this.getModelIfNotSet.predict( + sentences = allAnnotations.map(_._1), + tokenizedSentences = tokenizedSentences, + batchSize = $(batchSize), + maxSentenceLength = $(maxSentenceLength), + poolingStrategy = getPoolingStrategy) + } else { + Seq() + } + + // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence + batchedAnnotations.indices.map(rowIndex => { + val rowAnnotations = processedAnnotations + // zip each annotation with its corresponding row index + .zip(allAnnotations) + // select the sentences belonging to the current row + .filter(_._2._2 == rowIndex) + // leave the annotation only + .map(_._1) + + if (rowAnnotations.nonEmpty) + rowAnnotations + else + Seq.empty[Annotation] + }) + + } + + /** @group getParam */ + def getModelIfNotSet: UAE = _model.get.value + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + val suffix = "_UAE" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + UAEEmbeddings.tfFile, + configProtoBytes = getConfigProtoBytes, + savedSignatures = getSignatures) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + UAEEmbeddings.onnxFile) + + case _ => + throw new Exception(notSupportedEngineError) + } + } + + /** @group getParam */ + def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte)) + + /** @group getParam */ + def getSignatures: Option[Map[String, String]] = get(this.signatures) + + override protected def afterAnnotate(dataset: DataFrame): DataFrame = { + dataset.withColumn( + getOutputCol, + wrapSentenceEmbeddingsMetadata( + dataset.col(getOutputCol), + $(dimension), + Some($(storageRef)))) + } + +} + +trait ReadablePretrainedUAEModel + extends ParamsAndFeaturesReadable[UAEEmbeddings] + with HasPretrained[UAEEmbeddings] { + override val defaultModelName: Some[String] = Some("uae_large_v1") + + /** Java compliant-overrides */ + override def pretrained(): UAEEmbeddings = super.pretrained() + + override def pretrained(name: String): UAEEmbeddings = super.pretrained(name) + + override def pretrained(name: String, lang: String): UAEEmbeddings = + super.pretrained(name, lang) + + override def pretrained(name: String, lang: String, remoteLoc: String): UAEEmbeddings = + super.pretrained(name, lang, remoteLoc) +} + +trait ReadUAEDLModel extends ReadTensorflowModel with ReadOnnxModel { + this: ParamsAndFeaturesReadable[UAEEmbeddings] => + + override val tfFile: String = "UAE_tensorflow" + override val onnxFile: String = "UAE_onnx" + + def readModel(instance: UAEEmbeddings, path: String, spark: SparkSession): Unit = { + + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = readTensorflowModel(path, spark, "_UAE_tf") + instance.setModelIfNotSet(spark, Some(tfWrapper), None) + + case ONNX.name => + val onnxWrapper = + readOnnxModel(path, spark, "_UAE_onnx", zipped = true, useBundle = false, None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) + + case _ => + throw new Exception(notSupportedEngineError) + } + + } + + addReader(readModel) + + def loadSavedModel(modelPath: String, spark: SparkSession): UAEEmbeddings = { + + val (localModelPath, detectedEngine) = modelSanityCheck(modelPath) + + val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap + + /*Universal parameters for all engines*/ + val annotatorModel = new UAEEmbeddings() + .setVocabulary(vocabs) + + annotatorModel.set(annotatorModel.engine, detectedEngine) + + detectedEngine match { + case TensorFlow.name => + val (wrapper, signatures) = + TensorflowWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + tags = Array("serve")) + + val _signatures = signatures match { + case Some(s) => s + case None => throw new Exception("Cannot load signature definitions from model!") + } + + /** the order of setSignatures is important if we use getSignatures inside + * setModelIfNotSet + */ + annotatorModel + .setSignatures(_signatures) + .setModelIfNotSet(spark, Some(wrapper), None) + + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper)) + + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } +} + +/** This is the companion object of [[UAEEmbeddings]]. Please refer to that class for the + * documentation. + */ +object UAEEmbeddings extends ReadablePretrainedUAEModel with ReadUAEDLModel { + private[UAEEmbeddings] val logger: Logger = + LoggerFactory.getLogger("UAEEmbeddings") +} diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala index 2864975aebbb0c..e8f797e56e3238 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala @@ -689,7 +689,8 @@ object PythonResourceDownloader { "MPNetForSequenceClassification" -> MPNetForSequenceClassification, "MPNetForQuestionAnswering" -> MPNetForQuestionAnswering, "LLAMA2Transformer" -> LLAMA2Transformer, - "M2M100Transformer" -> M2M100Transformer) + "M2M100Transformer" -> M2M100Transformer, + "UAEEmbeddings" -> UAEEmbeddings) // List pairs of types such as the one with key type can load a pretrained model from the value type val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering") diff --git a/src/test/scala/com/johnsnowlabs/ml/util/LinAlgTest.scala b/src/test/scala/com/johnsnowlabs/ml/util/LinAlgTest.scala index e5b975b2d82ffc..96169bb094f618 100644 --- a/src/test/scala/com/johnsnowlabs/ml/util/LinAlgTest.scala +++ b/src/test/scala/com/johnsnowlabs/ml/util/LinAlgTest.scala @@ -141,4 +141,86 @@ class LinAlgTest extends AnyFlatSpec with Matchers { assert(array === Array(Array(-1.0f, 0.0f), Array(3.0f, -4.0f))) } + "tokenPooling" should "correctly pool tokens for one index" in { + val tokens: Array[Float] = Array(1, 2, 3, 4, 1, 2, 3, 4) + val shape = Array(2, 4, 1) + val embeddings = tokens.grouped(shape(2)).toArray.grouped(shape(1)).toArray + val pooled = LinAlg.tokenPooling(embeddings, index = 0) + + assert(pooled.flatten sameElements Array(1.0f, 1.0f)) + } + + it should "correctly pool tokens for one index per sequence" in { + val tokens: Array[Float] = Array(1, 2, 3, 4, 1, 2, 3, 4) + val shape = Array(2, 4, 1) + val embeddings = tokens.grouped(shape(2)).toArray.grouped(shape(1)).toArray + val pooled = + LinAlg.tokenPooling(embeddings, indexes = Array(0, 3)) + + assert(pooled.flatten sameElements Array(1.0f, 4.0f)) + } + + "maxPooling" should "correctly pool tokens for one index" in { + val embeddings: Array[Array[Array[Float]]] = + Array( + Array(Array(1.0f, 1.0f), Array(2.0f, 2.0f)), + Array(Array(3.0f, 3.0f), Array(4.0f, 4.0f))) + val attentionMask = Array(Array(1L, 1L), Array(1L, 1L)) + val pooled: Array[Array[Float]] = LinAlg.maxPooling(embeddings, attentionMask) + + assert(pooled(0) sameElements Array(2.0f, 2.0f)) + assert(pooled(1) sameElements Array(4.0f, 4.0f)) + } + + it should "consider attention mask" in { + val embeddings: Array[Array[Array[Float]]] = + Array( + Array(Array(1.0f, 1.0f), Array(2.0f, 2.0f)), + Array(Array(3.0f, 3.0f), Array(4.0f, 4.0f))) + val attentionMask = Array(Array(1L, 0L), Array(1L, 0L)) + val pooled: Array[Array[Float]] = LinAlg.maxPooling(embeddings, attentionMask) + + assert(pooled(0) sameElements Array(1.0f, 1.0f)) + assert(pooled(1) sameElements Array(3.0f, 3.0f)) + } + + "clsAvgPooling" should "correctly pool tokens" in { + val embeddings: Array[Array[Array[Float]]] = + Array( + Array(Array(1.0f, 2.0f), Array(1.0f, 2.0f)), + Array(Array(3.0f, 4.0f), Array(3.0f, 4.0f))) + val attentionMask = Array(Array(1L, 1L), Array(1L, 1L)) + val pooled: Array[Array[Float]] = + LinAlg.clsAvgPooling(embeddings, attentionMask) + + assert(pooled(0) sameElements Array(1f, 2f)) + assert(pooled(1) sameElements Array(3f, 4f)) + } + + "lastPooling" should "correctly pool tokens" in { + val embeddings: Array[Array[Array[Float]]] = + Array( + Array(Array(1.0f, 1.0f), Array(2.0f, 2.0f)), + Array(Array(3.0f, 3.0f), Array(4.0f, 4.0f))) + val attentionMask = Array(Array(1L, 1L), Array(1L, 1L)) + val pooled: Array[Array[Float]] = + LinAlg.lastPooling(embeddings, attentionMask) + + assert(pooled(0) sameElements Array(2f, 2f)) + assert(pooled(1) sameElements Array(4f, 4f)) + } + + it should "correctly pool with padded sequences" in { + val embeddings: Array[Array[Array[Float]]] = + Array( + Array(Array(1.0f, 1.0f), Array(2.0f, 2.0f), Array(-1f, -1f)), + Array(Array(3.0f, 3.0f), Array(4.0f, 4.0f), Array(4.0f, 4.0f))) + val attentionMask = Array(Array(1L, 1L, 0L), Array(1L, 1L, 1L)) + val pooled: Array[Array[Float]] = + LinAlg.lastPooling(embeddings, attentionMask) + + assert(pooled(0) sameElements Array(2f, 2f)) + assert(pooled(1) sameElements Array(4f, 4f)) + } + } diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddingsTestSpec.scala new file mode 100644 index 00000000000000..16461bf860800b --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/UAEEmbeddingsTestSpec.scala @@ -0,0 +1,157 @@ +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.nlp.base.LightPipeline +import com.johnsnowlabs.nlp.{Annotation, DocumentAssembler, EmbeddingsFinisher} +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.SlowTest +import org.apache.spark.ml.Pipeline +import org.scalatest.flatspec.AnyFlatSpec +import com.johnsnowlabs.util.TestUtils.tolerantFloatEq + +class UAEEmbeddingsTestSpec extends AnyFlatSpec { + lazy val spark = ResourceHelper.spark + import spark.implicits._ + behavior of "UAEEmbeddings" + + lazy val document = new DocumentAssembler().setInputCol("text").setOutputCol("document") + lazy val model = UAEEmbeddings + .pretrained() + .setInputCols("document") + .setOutputCol("embeddings") + + lazy val rawData: Seq[String] = Seq("hello world", "hello moon") + lazy val data = rawData.toDF("text") + lazy val embeddingsFinisher = new EmbeddingsFinisher() + .setInputCols("embeddings") + .setOutputCols("embeddings") + + lazy val pipeline = new Pipeline().setStages(Array(document, model, embeddingsFinisher)) + + /** Asserts the first 16 values of the embeddings are within tolerance. + * + * @param expected + * The expected embeddings + */ + private def assertEmbeddings( + expected: Array[Array[Float]], + pipeline: Pipeline = pipeline): Unit = { + val result = pipeline.fit(data).transform(data) + + result.selectExpr("explode(embeddings)").show(5, 80) + + val extractedEmbeddings = + result.selectExpr("explode(embeddings)").as[Array[Float]].collect() + extractedEmbeddings + .zip(expected) + .foreach { case (embeddings, expected) => + embeddings.take(16).zip(expected).foreach { case (e, exp) => + assert(e === exp, "Embedding value not within tolerance") + } + } + } + + it should "work with default (cls) pooling" taggedAs SlowTest in { + val expected: Array[Array[Float]] = Array( + Array(0.50387836f, 0.5861595f, 0.35129607f, -0.76046336f, -0.32446113f, -0.11767582f, + 0.49193293f, 0.58396333f, 0.8440052f, 0.3409165f, 0.02228897f, 0.3270517f, -0.3040624f, + 0.0651551f, -0.7069445f, 0.39551276f), + Array(0.66606593f, 0.9617606f, 0.24854378f, -0.10180531f, -0.6569206f, 0.02763455f, + 0.19156311f, 0.7743124f, 1.0966388f, -0.03704539f, 0.43159822f, 0.48135376f, -0.47491387f, + -0.22510622f, -0.7761906f, -0.29289678f)) + assertEmbeddings(expected) + } + + val expected_cls_avg = Array( + Array(0.42190665f, 0.48439154f, 0.37961221f, -0.88345671f, -0.39864743f, -0.10434269f, + 0.47246569f, 0.57266355f, 0.90948695f, 0.34240869f, -0.05249403f, 0.20690459f, -0.2502915f, + -0.075280815f, -0.72355306f, 0.37840521f), + Array(0.61534011f, 0.86877286f, 0.30440071f, -0.11193186f, -0.64877027f, 0.03778841f, + 0.19575913f, 0.77637982f, 1.0544734f, 0.02276843f, 0.40709749f, 0.48178568f, -0.45722729f, + -0.25922f, -0.75728685f, -0.2886759f)) + + it should "work with cls_avg pooling" taggedAs SlowTest in { + model.setPoolingStrategy("cls_avg") + assertEmbeddings(expected_cls_avg) + } + + it should "work with last pooling" taggedAs SlowTest in { + model.setPoolingStrategy("last") + val expected = Array( + Array(0.32610807f, 0.40477207f, 0.5753994f, -1.0180508f, -0.15669955f, -0.26589864f, + 0.57111073f, 0.59625691f, 0.98112649f, 0.31161842f, -0.088124298f, -0.23382883f, + -0.10615025f, -0.4932569f, -0.92297047f, 0.64136416f), + Array(0.42494205f, 0.91639936f, 0.47431907f, -0.11696267f, -0.78128248f, -0.044441216f, + 0.34416255f, 0.91160774f, 1.0371225f, 0.28027025f, 0.49664021f, 0.60586137f, -0.52690864f, + -0.49278158f, -1.0315861f, -0.10492325f)) + assertEmbeddings(expected) + } + + it should "work with avg pooling" taggedAs SlowTest in { + model.setPoolingStrategy("avg") + val expected = Array( + Array(0.33993506f, 0.38262373f, 0.40792847f, -1.0064504f, -0.47283337f, -0.091009863f, + 0.45299777f, 0.5613634f, 0.97496814f, 0.34390116f, -0.12727717f, 0.086757362f, + -0.19652022f, -0.21571696f, -0.740161f, 0.36129794f), + Array(0.5646143f, 0.77578509f, 0.36025763f, -0.12205841f, -0.64061993f, 0.047942273f, + 0.19995515f, 0.77844721f, 1.0123079f, 0.08258225f, 0.38259676f, 0.48221761f, -0.43954074f, + -0.2933338f, -0.73838311f, -0.28445506f)) + assertEmbeddings(expected) + } + it should "work with max pooling" taggedAs SlowTest in { + model.setPoolingStrategy("max") + val expected = Array( + Array(0.50387824f, 0.58615935f, 0.5753994f, -0.76046306f, -0.15669955f, 0.070831679f, + 0.57988632f, 0.63754135f, 1.0989035f, 0.36707285f, 0.022289103f, 0.32705182f, + -0.094303429f, 0.065155327f, -0.59403443f, 0.64136416f), + Array(0.69840318f, 0.96176058f, 0.47431907f, -0.053866591f, -0.49888393f, 0.36105314f, + 0.34416255f, 0.91160774f, 1.192958f, 0.28027025f, 0.49664021f, 0.60586137f, -0.31200063f, + -0.21072304f, -0.46940672f, -0.10492325f)) + assertEmbeddings(expected) + } + it should "work with integer pooling" taggedAs SlowTest in { + model.setPoolingStrategy("2") + val expected = Array( + Array(0.13630758f, 0.26152137f, 0.13758762f, -1.2564588f, -0.8082003f, 0.070831679f, + 0.16906039f, 0.42769182f, 1.0989035f, 0.36707285f, -0.056193497f, 0.27165085f, + -0.094303429f, -0.38840955f, -0.73669398f, 0.21443801f), + Array(0.69840318f, 0.54228693f, 0.29342332f, -0.21559906f, -0.49888393f, 0.36105314f, + 0.14411977f, 0.52433759f, 0.72251248f, 0.039639104f, 0.37450147f, 0.59273022f, + -0.31200063f, -0.21072304f, -0.46940672f, -0.45673683f)) + assertEmbeddings(expected) + } + + it should "be compatible with LightPipeline" taggedAs SlowTest in { + model.setPoolingStrategy("cls_avg") + val pipelineModel = pipeline.fit(data) + val lightPipeline = new LightPipeline(pipelineModel) + val result = lightPipeline.fullAnnotate(rawData.toArray) + + val extractedEmbeddings: Array[Array[Float]] = + result.map(_("embeddings").head.asInstanceOf[Annotation].embeddings) + extractedEmbeddings + .zip(expected_cls_avg) + .foreach { case (embeddings, expected) => + embeddings.take(16).zip(expected).foreach { case (e, exp) => + assert(e === exp, "Embedding value not within tolerance") + } + } + } + + it should "be serializable" taggedAs SlowTest in { + model.setPoolingStrategy("cls_avg") + val pipelineModel = pipeline.fit(data) + pipelineModel + .stages(1) + .asInstanceOf[UAEEmbeddings] + .write + .overwrite() + .save("./tmp_uae_model") + + val loadedModel = UAEEmbeddings.load("./tmp_uae_model") + val newPipeline: Pipeline = + new Pipeline().setStages(Array(document, loadedModel, embeddingsFinisher)) + + assertEmbeddings(expected_cls_avg, newPipeline) + } + +}