From f28c536584504a5b359e1acc68c317a7f9ebb480 Mon Sep 17 00:00:00 2001 From: ahmedlone127 Date: Tue, 16 Jul 2024 18:49:55 +0500 Subject: [PATCH 1/2] implementing SnowFlake --- .../sparknlp/annotator/embeddings/__init__.py | 1 + .../embeddings/snowflake_embeddings.py | 202 +++++++ python/sparknlp/internal/__init__.py | 7 + .../embeddings/snowflake_embeddings_test.py | 50 ++ .../embeddings/uae_embeddings_test.py | 3 +- .../com/johnsnowlabs/ml/ai/SnowFlake.scala | 299 ++++++++++ .../com/johnsnowlabs/nlp/annotator.scala | 5 + .../nlp/embeddings/SnowFlakeEmbeddings.scala | 528 ++++++++++++++++++ .../nlp/pretrained/ResourceDownloader.scala | 3 +- .../SnowFlakeEmbeddingsTestSpec.scala | 152 +++++ 10 files changed, 1247 insertions(+), 3 deletions(-) create mode 100644 python/sparknlp/annotator/embeddings/snowflake_embeddings.py create mode 100644 python/test/annotator/embeddings/snowflake_embeddings_test.py create mode 100644 src/main/scala/com/johnsnowlabs/ml/ai/SnowFlake.scala create mode 100644 src/main/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddings.scala create mode 100644 src/test/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddingsTestSpec.scala diff --git a/python/sparknlp/annotator/embeddings/__init__.py b/python/sparknlp/annotator/embeddings/__init__.py index f07049bfcc3caa..dcd5d59693df6a 100644 --- a/python/sparknlp/annotator/embeddings/__init__.py +++ b/python/sparknlp/annotator/embeddings/__init__.py @@ -37,3 +37,4 @@ from sparknlp.annotator.embeddings.xlnet_embeddings import * from sparknlp.annotator.embeddings.bge_embeddings import * from sparknlp.annotator.embeddings.uae_embeddings import * +from sparknlp.annotator.embeddings.snowflake_embeddings import * \ No newline at end of file diff --git a/python/sparknlp/annotator/embeddings/snowflake_embeddings.py b/python/sparknlp/annotator/embeddings/snowflake_embeddings.py new file mode 100644 index 00000000000000..6e9a3d33e85aa0 --- /dev/null +++ b/python/sparknlp/annotator/embeddings/snowflake_embeddings.py @@ -0,0 +1,202 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains classes for SnowFlakeEmbeddings.""" + +from sparknlp.common import * + + +class SnowFlakeEmbeddings(AnnotatorModel, + HasEmbeddingsProperties, + HasCaseSensitiveProperties, + HasStorageRef, + HasBatchedAnnotate, + HasMaxSentenceLengthLimit): + """Sentence embeddings using SnowFlake. + + snowflake-arctic-embed is a suite of text embedding models that focuses on creating + high-quality retrieval models optimized for performance. + + Pretrained models can be loaded with :meth:`.pretrained` of the companion + object: + + >>> embeddings = SnowFlakeEmbeddings.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("SnowFlake_embeddings") + + + The default model is ``"snowflake_artic_m"``, if no name is provided. + + For available pretrained models please see the + `Models Hub `__. + + + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT`` ``SENTENCE_EMBEDDINGS`` + ====================== ====================== + + Parameters + ---------- + batchSize + Size of every batch , by default 8 + dimension + Number of embedding dimensions, by default 768 + caseSensitive + Whether to ignore case in tokens for embeddings matching, by default False + maxSentenceLength + Max sentence length to process, by default 512 + configProtoBytes + ConfigProto from tensorflow, serialized into byte array. + + References + ---------- + + `Arctic-Embed: Scalable, Efficient, and Accurate Text Embedding Models `__ + `Snowflake Arctic-Embed Models `__ + + **Paper abstract** + + *The models are trained by leveraging existing open-source text representation models, such + as bert-base-uncased, and are trained in a multi-stage pipeline to optimize their retrieval + performance. First, the models are trained with large batches of query-document pairs where + negatives are derived in-batch—pretraining leverages about 400m samples of a mix of public + datasets and proprietary web search data. Following pretraining models are further optimized + with long training on a smaller dataset (about 1m samples) of triplets of query, positive + document, and negative document derived from hard harmful mining. Mining of the negatives and + data curation is crucial to retrieval accuracy. A detailed technical report will be available + shortly. * + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = DocumentAssembler() \\ + ... .setInputCol("text") \\ + ... .setOutputCol("document") + >>> embeddings = SnowFlakeEmbeddings.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("embeddings") + >>> embeddingsFinisher = EmbeddingsFinisher() \\ + ... .setInputCols("embeddings") \\ + ... .setOutputCols("finished_embeddings") \\ + ... .setOutputAsVector(True) + >>> pipeline = Pipeline().setStages([ + ... documentAssembler, + ... embeddings, + ... embeddingsFinisher + ... ]) + >>> data = spark.createDataFrame([["hello world", "hello moon"]]).toDF("text") + >>> result = pipeline.fit(data).transform(data) + >>> result.selectExpr("explode(finished_embeddings) as result").show(5, 80) + +--------------------------------------------------------------------------------+ + | result| + +--------------------------------------------------------------------------------+ + |[0.50387806, 0.5861606, 0.35129607, -0.76046336, -0.32446072, -0.117674336, 0...| + |[0.6660665, 0.961762, 0.24854276, -0.1018044, -0.6569202, 0.027635604, 0.1915...| + +--------------------------------------------------------------------------------+ + """ + + name = "SnowFlakeEmbeddings" + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT] + + outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS + poolingStrategy = Param(Params._dummy(), + "poolingStrategy", + "Pooling strategy to use for sentence embeddings", + TypeConverters.toString) + + def setPoolingStrategy(self, value): + """Pooling strategy to use for sentence embeddings. + + Available pooling strategies for sentence embeddings are: + - `"cls"`: leading `[CLS]` token + - `"cls_avg"`: leading `[CLS]` token + mean of all other tokens + - `"last"`: embeddings of the last token in the sequence + - `"avg"`: mean of all tokens + - `"max"`: max of all embedding features of the entire token sequence + - `"int"`: An integer number, which represents the index of the token to use as the + embedding + + Parameters + ---------- + value : str + Pooling strategy to use for sentence embeddings + """ + + valid_strategies = {"cls", "cls_avg", "last", "avg", "max"} + if value in valid_strategies or value.isdigit(): + return self._set(poolingStrategy=value) + else: + raise ValueError(f"Invalid pooling strategy: {value}. " + f"Valid strategies are: {', '.join(self.valid_strategies)} or an integer.") + + @keyword_only + def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.SnowFlakeEmbeddings", java_model=None): + super(SnowFlakeEmbeddings, self).__init__( + classname=classname, + java_model=java_model + ) + self._setDefault( + dimension=1024, + batchSize=8, + maxSentenceLength=512, + caseSensitive=False, + poolingStrategy="cls" + ) + + @staticmethod + def loadSavedModel(folder, spark_session): + """Loads a locally saved model. + + Parameters + ---------- + folder : str + Folder of the saved model + spark_session : pyspark.sql.SparkSession + The current SparkSession + + Returns + ------- + SnowFlakeEmbeddings + The restored model + """ + from sparknlp.internal import _SnowFlakeEmbeddingsLoader + jModel = _SnowFlakeEmbeddingsLoader(folder, spark_session._jsparkSession)._java_obj + return SnowFlakeEmbeddings(java_model=jModel) + + @staticmethod + def pretrained(name="snowflake_artic_m", lang="en", remote_loc=None): + """Downloads and loads a pretrained model. + + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default "snowflake_artic_m" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLPs repositories otherwise. + + Returns + ------- + SnowFlakeEmbeddings + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(SnowFlakeEmbeddings, name, lang, remote_loc) diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index deeff9c5189f52..e68730326d2df3 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -910,3 +910,10 @@ def __init__(self, path, jspark): super(_UAEEmbeddingsLoader, self).__init__( "com.johnsnowlabs.nlp.embeddings.UAEEmbeddings.loadSavedModel", path, jspark ) + + +class _SnowFlakeEmbeddingsLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark): + super(_SnowFlakeEmbeddingsLoader, self).__init__( + "com.johnsnowlabs.nlp.embeddings.SnowFlakeEmbeddings.loadSavedModel", path, jspark + ) diff --git a/python/test/annotator/embeddings/snowflake_embeddings_test.py b/python/test/annotator/embeddings/snowflake_embeddings_test.py new file mode 100644 index 00000000000000..7b6271a7579646 --- /dev/null +++ b/python/test/annotator/embeddings/snowflake_embeddings_test.py @@ -0,0 +1,50 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import pytest +import os + +from sparknlp.annotator import * +from sparknlp.base import * +from test.util import SparkContextForTest + + +@pytest.mark.slow +class SnowFlakeEmbeddingsTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + self.tested_annotator = SnowFlakeEmbeddings \ + .loadSavedModel("1", + SparkContextForTest.spark) \ + .setInputCols(["documents"]) \ + .setOutputCol("embeddings") \ + .setPoolingStrategy("cls_avg") + + def test_run(self): + data = SparkContextForTest.spark.read.option("header", "true") \ + .csv(path="file:///" + os.getcwd() + "/../src/test/resources/embeddings/sentence_embeddings.csv") + + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("documents") + + embeddings_finisher = EmbeddingsFinisher().setInputCols("embeddings").setOutputCols("embeddings") + + snowflake = self.tested_annotator + + pipeline = Pipeline().setStages([document_assembler, snowflake, embeddings_finisher]) + results = pipeline.fit(data).transform(data) + + results.selectExpr("explode(embeddings) as result").show(truncate=False) diff --git a/python/test/annotator/embeddings/uae_embeddings_test.py b/python/test/annotator/embeddings/uae_embeddings_test.py index d36083dc7a0883..814341df5b434f 100644 --- a/python/test/annotator/embeddings/uae_embeddings_test.py +++ b/python/test/annotator/embeddings/uae_embeddings_test.py @@ -25,8 +25,7 @@ class UAEEmbeddingsTestSpec(unittest.TestCase): def setUp(self): self.spark = SparkContextForTest.spark self.tested_annotator = UAEEmbeddings \ - .loadSavedModel("/home/ducha/Workspace/JSL/spark-nlp-dev-things/hf_exports/UAE/exported_onnx", - SparkContextForTest.spark) \ + .pretrained() \ .setInputCols(["documents"]) \ .setOutputCol("embeddings") \ .setPoolingStrategy("cls_avg") diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/SnowFlake.scala b/src/main/scala/com/johnsnowlabs/ml/ai/SnowFlake.scala new file mode 100644 index 00000000000000..971c5a1fc79378 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/SnowFlake.scala @@ -0,0 +1,299 @@ +/* + * Copyright 2017 - 2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.ml.ai + +import ai.onnxruntime.{OnnxTensor, TensorInfo} +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} +import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} +import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} +import com.johnsnowlabs.ml.util.{LinAlg, ONNX, TensorFlow} +import com.johnsnowlabs.nlp.annotators.common._ +import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} + +import scala.collection.JavaConverters._ +import scala.util.Try + +/** SnowFlake Sentence embeddings model + * @param tensorflowWrapper + * tensorflow wrapper + * @param configProtoBytes + * config proto bytes + * @param sentenceStartTokenId + * sentence start token id + * @param sentenceEndTokenId + * sentence end token id + * @param signatures + * signatures + */ +private[johnsnowlabs] class SnowFlake( + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[OnnxWrapper], + configProtoBytes: Option[Array[Byte]] = None, + sentenceStartTokenId: Int, + sentenceEndTokenId: Int, + signatures: Option[Map[String, String]] = None) { + + private val _tfInstructorSignatures: Map[String, String] = + signatures.getOrElse(ModelSignatureManager.apply()) + private val paddingTokenId = 0 + + val detectedEngine: String = + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions + + /** Get sentence embeddings for a batch of sentences + * + * @param batch + * batch of sentences + * @return + * sentence embeddings + */ + private def getSentenceEmbedding( + batch: Seq[Array[Int]], + poolingStrategy: String): Array[Array[Float]] = { + val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max + val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength)) + val sentenceEmbeddings: Array[Array[Float]] = detectedEngine match { + case ONNX.name => + getSentenceEmbeddingFromOnnx(paddedBatch, maxSentenceLength, poolingStrategy) + case _ => // TF Case + getSentenceEmbeddingFromTF(paddedBatch, maxSentenceLength, poolingStrategy) + } + + sentenceEmbeddings + } + + /** Pools word embeddings to sentence embeddings given a strategy. + * + * @param embeddings + * A 3D array of Floats representing the embeddings. The dimensions are [batch_size, + * sequence_length, embedding_dim]. + * @param attentionMask + * A 2D array of Longs representing the attention mask. The dimensions are [batch_size, + * sequence_length]. + * @param poolingStrategy + * A String representing the pooling strategy to be applied. The following strategies are + * supported: + * + * - `"cls"`: leading `[CLS]` token + * - `"cls_avg"`: leading `[CLS]` token + mean of all other tokens + * - `"last"`: embeddings of the last token in the sequence + * - `"avg"`: mean of all tokens + * - `"max"`: max of all embedding values for the token sequence + * - `"int"`: An integer number, which represents the index of the token to use as the + * embedding + * @return + * A 2D array of Floats representing the pooled embeddings. The dimensions are [batch_size, + * embedding_dim]. + */ + private def pool( + embeddings: Array[Array[Array[Float]]], + attentionMask: Array[Array[Long]], + poolingStrategy: String): Array[Array[Float]] = { + poolingStrategy match { + case "cls" => LinAlg.clsPooling(embeddings, attentionMask) + case "cls_avg" => LinAlg.clsAvgPooling(embeddings, attentionMask) + case "last" => LinAlg.lastPooling(embeddings, attentionMask) + case "avg" => + val shape: Array[Long] = + Array(embeddings.length, embeddings.head.length, embeddings.head.head.length) + val avgPooled = LinAlg.avgPooling(embeddings.flatten.flatten, attentionMask, shape) + avgPooled.t.toArray.grouped(avgPooled.cols).toArray + case "max" => LinAlg.maxPooling(embeddings, attentionMask) + case index if Try(index.toInt).isSuccess => LinAlg.tokenPooling(embeddings, index.toInt) + case _ => + throw new IllegalArgumentException(s"Pooling strategy $poolingStrategy not supported.") + } + } + + private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = { + if (arr.length >= maxLength) { + arr + } else { + arr ++ Array.fill(maxLength - arr.length)(0) + } + } + + private def getSentenceEmbeddingFromTF( + batch: Seq[Array[Int]], + maxSentenceLength: Int, + poolingStrategy: String): Array[Array[Float]] = { + val batchLength = batch.length + + // encode batch + val tensorEncoder = new TensorResources() + val inputDim = batch.length * maxSentenceLength + + // create buffers + val encoderInputBuffers = tensorEncoder.createIntBuffer(inputDim) + val encoderAttentionMaskBuffers = tensorEncoder.createIntBuffer(inputDim) + + val shape = Array(batch.length.toLong, maxSentenceLength) + + batch.zipWithIndex.foreach { case (tokenIds, idx) => + val offset = idx * maxSentenceLength + val diff = maxSentenceLength - tokenIds.length + + // pad with 0 + val s = tokenIds.take(maxSentenceLength) ++ Array.fill[Int](diff)(this.paddingTokenId) + encoderInputBuffers.offset(offset).write(s) + + // create attention mask + val mask = s.map(x => if (x != this.paddingTokenId) 1 else 0) + encoderAttentionMaskBuffers.offset(offset).write(mask) + } + + // create tensors + val encoderInputTensors = tensorEncoder.createIntBufferTensor(shape, encoderInputBuffers) + val encoderAttentionMaskTensors = + tensorEncoder.createIntBufferTensor(shape, encoderAttentionMaskBuffers) + + // run model + val runner = tensorflowWrapper.get + .getTFSessionWithSignature( + configProtoBytes = configProtoBytes, + initAllTables = false, + savedSignatures = signatures) + .runner + + runner + .feed( + _tfInstructorSignatures.getOrElse( + ModelSignatureConstants.EncoderInputIds.key, + "missing_encoder_input_ids"), + encoderInputTensors) + .feed( + _tfInstructorSignatures.getOrElse( + ModelSignatureConstants.EncoderAttentionMask.key, + "missing_encoder_attention_mask"), + encoderAttentionMaskTensors) + .fetch(_tfInstructorSignatures + .getOrElse(ModelSignatureConstants.LastHiddenState.key, "missing_last_hidden_state")) + + // get embeddings + val sentenceEmbeddings = runner.run().asScala + val sentenceEmbeddingsFloats = TensorResources.extractFloats(sentenceEmbeddings.head) + val embeddingDim = sentenceEmbeddingsFloats.length / maxSentenceLength / batchLength + + // group embeddings + val sentenceEmbeddingsFloatsArray = + sentenceEmbeddingsFloats.grouped(embeddingDim).toArray.grouped(maxSentenceLength).toArray + + val attentionMask: Array[Array[Long]] = + TensorResources.extractLongs(encoderAttentionMaskTensors).grouped(maxSentenceLength).toArray + + // close buffers + sentenceEmbeddings.foreach(_.close()) + encoderInputTensors.close() + encoderAttentionMaskTensors.close() + tensorEncoder.clearTensors() + tensorEncoder.clearSession(sentenceEmbeddings) + + pool(sentenceEmbeddingsFloatsArray, attentionMask, poolingStrategy) + } + + private def getSentenceEmbeddingFromOnnx( + batch: Seq[Array[Int]], + maxSentenceLength: Int, + poolingStrategy: String): Array[Array[Float]] = { + + val inputIds = batch.map(x => x.map(x => x.toLong)).toArray + val attentionMask = batch.map(sentence => sentence.map(x => if (x < 0L) 0L else 1L)).toArray + + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) + + val tokenTensors = OnnxTensor.createTensor(env, inputIds) + val maskTensors = OnnxTensor.createTensor(env, attentionMask) + val segmentTensors = + OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray) + val inputs = + Map( + "input_ids" -> tokenTensors, + "attention_mask" -> maskTensors, + "token_type_ids" -> segmentTensors).asJava + + val embeddings = + try { + val results = runner.run(inputs) + val lastHiddenState = results.get("last_hidden_state").get() + val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo] + val shape = info.getShape.map(_.toInt) + val Array(_, sequenceLength, embeddingDim) = shape + try { + val flattenEmbeddings = lastHiddenState + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + + flattenEmbeddings.grouped(embeddingDim).toArray.grouped(sequenceLength).toArray + } finally if (results != null) results.close() + } + + pool(embeddings, attentionMask, poolingStrategy) + } + + /** Predict sentence embeddings for a batch of sentences + * + * @param sentences + * sentences + * @param tokenizedSentences + * tokenized sentences + * @param batchSize + * batch size + * @param maxSentenceLength + * max sentence length + * @return + */ + def predict( + sentences: Seq[Annotation], + tokenizedSentences: Seq[WordpieceTokenizedSentence], + batchSize: Int, + maxSentenceLength: Int, + poolingStrategy: String): Seq[Annotation] = { + + tokenizedSentences + .zip(sentences) + .zipWithIndex + .grouped(batchSize) + .toArray + .flatMap { batch => + val tokensBatch = batch.map(x => x._1._1.tokens) + val tokens = tokensBatch.map(x => + Array(sentenceStartTokenId) ++ x + .map(y => y.pieceId) + .take(maxSentenceLength - 2) ++ Array(sentenceEndTokenId)) + + val sentenceEmbeddings = getSentenceEmbedding(tokens, poolingStrategy) + + batch.zip(sentenceEmbeddings).map { case (sentence, vectors) => + Annotation( + annotatorType = AnnotatorType.SENTENCE_EMBEDDINGS, + begin = sentence._1._2.begin, + end = sentence._1._2.end, + result = sentence._1._2.result, + metadata = sentence._1._2.metadata, + embeddings = vectors) + } + } + } + +} diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala index 60655bda2809ec..9239f43831eeaa 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotator.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotator.scala @@ -804,4 +804,9 @@ package object annotator { type UAEEmbeddings = com.johnsnowlabs.nlp.embeddings.UAEEmbeddings object UAEEmbeddings extends ReadablePretrainedUAEModel with ReadUAEDLModel + + type SnowFlakeEmbeddings = + com.johnsnowlabs.nlp.embeddings.SnowFlakeEmbeddings + + object SnowFlakeEmbeddings extends ReadablePretrainedSnowFlakeModel with ReadSnowFlakeDLModel } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddings.scala new file mode 100644 index 00000000000000..ad62a5a0a2faa7 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddings.scala @@ -0,0 +1,528 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.ml.ai.SnowFlake +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.tensorflow._ +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.nlp.annotators.common._ +import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} +import com.johnsnowlabs.nlp.serialization.MapFeature +import com.johnsnowlabs.storage.HasStorageRef +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.slf4j.{Logger, LoggerFactory} + +import scala.util.Try + +/** Sentence embeddings using SnowFlake. + * + * snowflake-arctic-embed is a suite of text embedding models that focuses on creating + * high-quality retrieval models optimized for performance. + * + * Pretrained models can be loaded with `pretrained` of the companion object: + * {{{ + * val embeddings = SnowFlakeEmbeddings.pretrained() + * .setInputCols("document") + * .setOutputCol("snowflake_embeddings") + * }}} + * The default model is `"snowflake_artic_m"`, if no name is provided. + * + * For available pretrained models please see the + * [[https://sparknlp.org/models?q=snowflake Models Hub]]. + * + * For extended examples of usage, see + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddingsTestSpec.scala SnowFlakeEmbeddingsTestSpec]]. + * + * '''Sources''' : + * + * [[https://arxiv.org/abs/2405.05374 Arctic-Embed: Scalable, Efficient, and Accurate Text Embedding Models]] + * + * [[https://github.com/Snowflake-Labs/arctic-embed Snowflake Arctic-Embed Models]] + * + * ''' Paper abstract ''' + * + * ''The models are trained by leveraging existing open-source text representation models, such + * as bert-base-uncased, and are trained in a multi-stage pipeline to optimize their retrieval + * performance. First, the models are trained with large batches of query-document pairs where + * negatives are derived in-batch—pretraining leverages about 400m samples of a mix of public + * datasets and proprietary web search data. Following pretraining models are further optimized + * with long training on a smaller dataset (about 1m samples) of triplets of query, positive + * document, and negative document derived from hard harmful mining. Mining of the negatives and + * data curation is crucial to retrieval accuracy. A detailed technical report will be available + * shortly. '' + * + * ==Example== + * {{{ + * import spark.implicits._ + * import com.johnsnowlabs.nlp.base.DocumentAssembler + * import com.johnsnowlabs.nlp.annotators.Tokenizer + * import com.johnsnowlabs.nlp.embeddings.SnowFlakeEmbeddings + * import com.johnsnowlabs.nlp.EmbeddingsFinisher + * import org.apache.spark.ml.Pipeline + * + * val documentAssembler = new DocumentAssembler() + * .setInputCol("text") + * .setOutputCol("document") + * + * val embeddings = SnowFlakeEmbeddings.pretrained() + * .setInputCols("document") + * .setOutputCol("snowflake_embeddings") + * + * val embeddingsFinisher = new EmbeddingsFinisher() + * .setInputCols("snowflake_embeddings") + * .setOutputCols("finished_embeddings") + * .setOutputAsVector(true) + * + * val pipeline = new Pipeline().setStages(Array( + * documentAssembler, + * embeddings, + * embeddingsFinisher + * )) + * + * val data = Seq("hello world", "hello moon").toDF("text") + * val result = pipeline.fit(data).transform(data) + * + * result.selectExpr("explode(finished_embeddings) as result").show(5, 80) + * --------------------+ + * finished_embeddings| + * --------------------+ + * [[-0.45763275, 0....| + * [[-0.43076283, 0....| + * --------------------+ + * }}} + * + * @see + * [[https://sparknlp.org/docs/en/annotators Annotators Main Page]] for a list of transformer + * based embeddings + * @param uid + * required uid for storing annotator to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this annotator can take. Users can set and get the + * parameter values through setters and getters, respectively. + */ +class SnowFlakeEmbeddings(override val uid: String) + extends AnnotatorModel[SnowFlakeEmbeddings] + with HasBatchedAnnotate[SnowFlakeEmbeddings] + with WriteTensorflowModel + with WriteOnnxModel + with HasEmbeddingsProperties + with HasStorageRef + with HasCaseSensitiveProperties + with HasEngine { + + /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator + * type + */ + override val inputAnnotatorTypes: Array[String] = + Array(AnnotatorType.DOCUMENT) + override val outputAnnotatorType: AnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS + + /** ConfigProto from tensorflow, serialized into byte array. Get with + * `config_proto.SerializeToString()` + * + * @group param + */ + val configProtoBytes = new IntArrayParam( + this, + "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()") + + /** Max sentence length to process (Default: `128`) + * + * @group param + */ + val maxSentenceLength = + new IntParam(this, "maxSentenceLength", "Max sentence length to process") + + def sentenceStartTokenId: Int = { + $$(vocabulary)("[CLS]") + } + + /** @group setParam */ + def sentenceEndTokenId: Int = { + $$(vocabulary)("[SEP]") + } + + /** Vocabulary used to encode the words to ids with WordPieceEncoder + * + * @group param + */ + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() + + /** @group setParam */ + def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value) + + /** It contains TF model signatures for the laded saved model + * + * @group param + */ + val signatures = + new MapFeature[String, String](model = this, name = "signatures").setProtected() + private var _model: Option[Broadcast[SnowFlake]] = None + + def this() = this(Identifiable.randomUID("SnowFlake_EMBEDDINGS")) + + /** @group setParam */ + def setConfigProtoBytes(bytes: Array[Int]): SnowFlakeEmbeddings.this.type = + set(this.configProtoBytes, bytes) + + /** @group setParam */ + def setMaxSentenceLength(value: Int): this.type = { + require( + value <= 512, + "SnowFlake models do not support sequences longer than 512 because of trainable positional embeddings.") + require(value >= 1, "The maxSentenceLength must be at least 1") + set(maxSentenceLength, value) + this + } + + /** @group getParam */ + def getMaxSentenceLength: Int = $(maxSentenceLength) + + /** @group setParam */ + def setSignatures(value: Map[String, String]): this.type = { + if (get(signatures).isEmpty) + set(signatures, value) + this + } + + /** Pooling strategy to use for sentence embeddings. + * + * Available pooling strategies for sentence embeddings are: + * + * - `"cls"`: leading `[CLS]` token + * - `"cls_avg"`: leading `[CLS]` token + mean of all other tokens + * - `"last"`: embeddings of the last token in the sequence + * - `"avg"`: mean of all tokens + * - `"max"`: max of all embedding values for the token sequence + * - `"all"`: return all token embeddings + * - `"int"`: An integer number, which represents the index of the token to use as the + * embedding + * + * @group param + */ + val poolingStrategy = + new Param[String](this, "poolingStrategy", "Pooling strategy to use for sentence embeddings") + + def getPoolingStrategy: String = $(poolingStrategy) + + /** Pooling strategy to use for sentence embeddings. + * + * Available pooling strategies for sentence embeddings are: + * + * - `"cls"`: leading `[CLS]` token + * - `"cls_avg"`: leading `[CLS]` token + mean of all other tokens + * - `"last"`: embeddings of the last token in the sequence + * - `"avg"`: mean of all tokens + * - `"max"`: max of all embedding features of the entire token sequence + * - `"int"`: An integer number, which represents the index of the token to use as the + * embedding + * + * @group setParam + */ + def setPoolingStrategy(value: String): this.type = { + val validStrategies = Set("cls", "cls_avg", "last", "avg", "max") + + if (validStrategies.contains(value) || Try(value.toInt).isSuccess) { + set(poolingStrategy, value) + } else { + throw new IllegalArgumentException( + s"Invalid pooling strategy: $value. " + + s"Valid strategies are: ${validStrategies.mkString(", ")} or an integer.") + } + } + + /** @group setParam */ + def setModelIfNotSet( + spark: SparkSession, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper]): SnowFlakeEmbeddings = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new SnowFlake( + tensorflowWrapper, + onnxWrapper, + configProtoBytes = getConfigProtoBytes, + sentenceStartTokenId = sentenceStartTokenId, + sentenceEndTokenId = sentenceEndTokenId, + signatures = getSignatures))) + } + + this + } + + /** Set Embeddings dimensions for the BERT model Only possible to set this when the first time + * is saved dimension is not changeable, it comes from BERT config file + * + * @group setParam + */ + override def setDimension(value: Int): this.type = { + if (get(dimension).isEmpty) + set(this.dimension, value) + this + } + + /** Whether to lowercase tokens or not + * + * @group setParam + */ + override def setCaseSensitive(value: Boolean): this.type = { + if (get(caseSensitive).isEmpty) + set(this.caseSensitive, value) + this + } + + setDefault( + dimension -> 1024, + batchSize -> 8, + maxSentenceLength -> 512, + caseSensitive -> false, + poolingStrategy -> "cls") + + def tokenize(sentences: Seq[Annotation]): Seq[WordpieceTokenizedSentence] = { + val basicTokenizer = new BasicTokenizer($(caseSensitive)) + val encoder = new WordpieceEncoder($$(vocabulary)) + sentences.map { s => + val sent = Sentence( + content = s.result, + start = s.begin, + end = s.end, + metadata = Some(s.metadata), + index = s.begin) + val tokens = basicTokenizer.tokenize(sent) + val wordpieceTokens = tokens.flatMap(token => encoder.encode(token)) + WordpieceTokenizedSentence(wordpieceTokens) + } + } + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations that correspond to inputAnnotationCols generated by previous annotators if any + * @return + * any number of annotations processed for every input annotation. Not necessary one to one + * relationship + */ + override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { + + val allAnnotations = batchedAnnotations + .filter(_.nonEmpty) + .zipWithIndex + .flatMap { case (annotations, i) => + annotations.filter(_.result.nonEmpty).map(x => (x, i)) + } + + // Tokenize sentences + val tokenizedSentences = tokenize(allAnnotations.map(_._1)) + val processedAnnotations = if (allAnnotations.nonEmpty) { + this.getModelIfNotSet.predict( + sentences = allAnnotations.map(_._1), + tokenizedSentences = tokenizedSentences, + batchSize = $(batchSize), + maxSentenceLength = $(maxSentenceLength), + poolingStrategy = getPoolingStrategy) + } else { + Seq() + } + + // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence + batchedAnnotations.indices.map(rowIndex => { + val rowAnnotations = processedAnnotations + // zip each annotation with its corresponding row index + .zip(allAnnotations) + // select the sentences belonging to the current row + .filter(_._2._2 == rowIndex) + // leave the annotation only + .map(_._1) + + if (rowAnnotations.nonEmpty) + rowAnnotations + else + Seq.empty[Annotation] + }) + + } + + /** @group getParam */ + def getModelIfNotSet: SnowFlake = _model.get.value + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + val suffix = "_SnowFlake" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + SnowFlakeEmbeddings.tfFile, + configProtoBytes = getConfigProtoBytes, + savedSignatures = getSignatures) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + SnowFlakeEmbeddings.onnxFile) + + case _ => + throw new Exception(notSupportedEngineError) + } + } + + /** @group getParam */ + def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte)) + + /** @group getParam */ + def getSignatures: Option[Map[String, String]] = get(this.signatures) + + override protected def afterAnnotate(dataset: DataFrame): DataFrame = { + dataset.withColumn( + getOutputCol, + wrapSentenceEmbeddingsMetadata( + dataset.col(getOutputCol), + $(dimension), + Some($(storageRef)))) + } + +} + +trait ReadablePretrainedSnowFlakeModel + extends ParamsAndFeaturesReadable[SnowFlakeEmbeddings] + with HasPretrained[SnowFlakeEmbeddings] { + override val defaultModelName: Some[String] = Some("snowflake_artic_m") + + /** Java compliant-overrides */ + override def pretrained(): SnowFlakeEmbeddings = super.pretrained() + + override def pretrained(name: String): SnowFlakeEmbeddings = super.pretrained(name) + + override def pretrained(name: String, lang: String): SnowFlakeEmbeddings = + super.pretrained(name, lang) + + override def pretrained(name: String, lang: String, remoteLoc: String): SnowFlakeEmbeddings = + super.pretrained(name, lang, remoteLoc) +} + +trait ReadSnowFlakeDLModel extends ReadTensorflowModel with ReadOnnxModel { + this: ParamsAndFeaturesReadable[SnowFlakeEmbeddings] => + + override val tfFile: String = "SnowFlake_tensorflow" + override val onnxFile: String = "SnowFlake_onnx" + + def readModel(instance: SnowFlakeEmbeddings, path: String, spark: SparkSession): Unit = { + + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = readTensorflowModel(path, spark, "_SnowFlake_tf") + instance.setModelIfNotSet(spark, Some(tfWrapper), None) + + case ONNX.name => + val onnxWrapper = + readOnnxModel(path, spark, "_SnowFlake_onnx", zipped = true, useBundle = false, None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) + + case _ => + throw new Exception(notSupportedEngineError) + } + + } + + addReader(readModel) + + def loadSavedModel(modelPath: String, spark: SparkSession): SnowFlakeEmbeddings = { + + val (localModelPath, detectedEngine) = modelSanityCheck(modelPath) + + val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap + + /*Universal parameters for all engines*/ + val annotatorModel = new SnowFlakeEmbeddings() + .setVocabulary(vocabs) + + annotatorModel.set(annotatorModel.engine, detectedEngine) + + detectedEngine match { + case TensorFlow.name => + val (wrapper, signatures) = + TensorflowWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + tags = Array("serve")) + + val _signatures = signatures match { + case Some(s) => s + case None => throw new Exception("Cannot load signature definitions from model!") + } + + /** the order of setSignatures is important if we use getSignatures inside + * setModelIfNotSet + */ + annotatorModel + .setSignatures(_signatures) + .setModelIfNotSet(spark, Some(wrapper), None) + + case ONNX.name => + val onnxWrapper = + OnnxWrapper.read(spark, localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper)) + + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } +} + +/** This is the companion object of [[SnowFlakeEmbeddings]]. Please refer to that class for the + * documentation. + */ +object SnowFlakeEmbeddings extends ReadablePretrainedSnowFlakeModel with ReadSnowFlakeDLModel { + private[SnowFlakeEmbeddings] val logger: Logger = + LoggerFactory.getLogger("SnowFlakeEmbeddings") +} diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala index e8f797e56e3238..f1c05f6d966786 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala @@ -690,7 +690,8 @@ object PythonResourceDownloader { "MPNetForQuestionAnswering" -> MPNetForQuestionAnswering, "LLAMA2Transformer" -> LLAMA2Transformer, "M2M100Transformer" -> M2M100Transformer, - "UAEEmbeddings" -> UAEEmbeddings) + "UAEEmbeddings" -> UAEEmbeddings, + "SnowFlakeEmbeddings" -> SnowFlakeEmbeddings) // List pairs of types such as the one with key type can load a pretrained model from the value type val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering") diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddingsTestSpec.scala new file mode 100644 index 00000000000000..21145670e70c8d --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddingsTestSpec.scala @@ -0,0 +1,152 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.SlowTest +import org.apache.spark.ml.Pipeline +import org.apache.spark.sql.functions.{col, size} +import org.scalatest.flatspec.AnyFlatSpec + +class SnowFlakeEmbeddingsTestSpec extends AnyFlatSpec { + + "SnowFlake Embeddings" should "correctly embed multiple sentences" taggedAs SlowTest in { + + import ResourceHelper.spark.implicits._ + + val ddd = Seq("i love to cook food and eat it") + .toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val embeddings = SnowFlakeEmbeddings + .loadSavedModel("1", ResourceHelper.spark) + .setInputCols(Array("document")) + .setOutputCol("snowflake") + + val pipeline = new Pipeline().setStages(Array(document, embeddings)) + + val pipelineDF = pipeline.fit(ddd).transform(ddd) + pipelineDF.select("snowflake.embeddings").show(truncate = false) + + } + + it should "have embeddings of the same size" taggedAs SlowTest in { + import ResourceHelper.spark.implicits._ + val testDf = Seq( + "I like apples", + "I like bananas \\n and other things \\n like icream \\n and cats", + "I like rockets") + .toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val embeddings = SnowFlakeEmbeddings + .pretrained() + .setInputCols(Array("document")) + .setOutputCol("snowflake") + + val pipeline = new Pipeline().setStages(Array(document, embeddings)) + + val pipelineDF = pipeline.fit(testDf).transform(testDf) + + val embeddingsDF = pipelineDF.withColumn("embeddings", col("snowflake.embeddings").getItem(0)) + + val sizesArray: Array[Int] = embeddingsDF + .select(size(col("embeddings")).as("size")) + .collect() + .map(row => row.getAs[Int]("size")) + + assert(sizesArray.forall(_ == sizesArray.head)) + } + + it should "work with sentences" taggedAs SlowTest in { + import ResourceHelper.spark.implicits._ + val testData = "I really enjoy my job. This is amazing" + val testDf = Seq(testData).toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val sentenceDetectorDL = SentenceDetectorDLModel + .pretrained("sentence_detector_dl", "en") + .setInputCols(Array("document")) + .setOutputCol("sentences") + + val embeddings = SnowFlakeEmbeddings + .pretrained() + .setInputCols(Array("sentences")) + .setOutputCol("snowflake") + + val pipeline = new Pipeline().setStages(Array(document, sentenceDetectorDL, embeddings)) + + val pipelineDF = pipeline.fit(testDf).transform(testDf) + pipelineDF.select("snowflake.embeddings").show(false) + } + + it should "not return empty embeddings" taggedAs SlowTest in { + import ResourceHelper.spark.implicits._ + val interests = Seq( + "I like music", + "I like movies", + "I like books", + "I like sports", + "I like travel", + "I like food", + "I like games", + "I like art", + "I like nature", + "I like science", + "I like technology", + "I like history", + "I like fashion", + "I like cars", + "I like animals", + "I like gardening") + val testDf = interests.toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val embeddings = SnowFlakeEmbeddings + .pretrained() + .setInputCols(Array("document")) + .setOutputCol("snowflake") + + val pipeline = new Pipeline().setStages(Array(document, embeddings)) + + val pipelineDF = pipeline.fit(testDf).transform(testDf) + + val embeddingsDF = pipelineDF.withColumn("embeddings", col("snowflake.embeddings").getItem(0)) + + val sizesArray: Array[Int] = embeddingsDF + .select(size(col("embeddings")).as("size")) + .collect() + .map(row => row.getAs[Int]("size")) + + assert(sizesArray.forall(_ > 0)) + } + +} From 43203d2080182b57ecab9bb5063c73af95ce7ebf Mon Sep 17 00:00:00 2001 From: ahmedlone127 Date: Tue, 16 Jul 2024 19:59:02 +0500 Subject: [PATCH 2/2] typo fix --- .../nlp/embeddings/SnowFlakeEmbeddingsTestSpec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddingsTestSpec.scala index 21145670e70c8d..da2d249bab2a9c 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddingsTestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/SnowFlakeEmbeddingsTestSpec.scala @@ -30,7 +30,7 @@ class SnowFlakeEmbeddingsTestSpec extends AnyFlatSpec { import ResourceHelper.spark.implicits._ - val ddd = Seq("i love to cook food and eat it") + val ddd = Seq("This is an example sentence", "Each sentence is converted") .toDF("text") val document = new DocumentAssembler() @@ -38,7 +38,7 @@ class SnowFlakeEmbeddingsTestSpec extends AnyFlatSpec { .setOutputCol("document") val embeddings = SnowFlakeEmbeddings - .loadSavedModel("1", ResourceHelper.spark) + .pretrained() .setInputCols(Array("document")) .setOutputCol("snowflake")