From 0c8371e4ef4ea7decfe0d3bf71cf5e10cb677418 Mon Sep 17 00:00:00 2001 From: David Cecchini Date: Tue, 12 Dec 2023 14:52:04 -0300 Subject: [PATCH 1/2] Added BGE Embeddings --- .../annotator/embeddings/bge_embeddings.py | 192 +++++++ python/sparknlp/internal/__init__.py | 3 + .../scala/com/johnsnowlabs/ml/ai/BGE.scala | 247 +++++++++ .../com/johnsnowlabs/ml/util/LinAlg.scala | 8 +- .../nlp/embeddings/BGEEmbeddings.scala | 482 ++++++++++++++++++ .../embeddings/BGEEmbeddingsTestSpec.scala | 116 +++++ 6 files changed, 1047 insertions(+), 1 deletion(-) create mode 100644 python/sparknlp/annotator/embeddings/bge_embeddings.py create mode 100644 src/main/scala/com/johnsnowlabs/ml/ai/BGE.scala create mode 100644 src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala create mode 100644 src/test/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddingsTestSpec.scala diff --git a/python/sparknlp/annotator/embeddings/bge_embeddings.py b/python/sparknlp/annotator/embeddings/bge_embeddings.py new file mode 100644 index 00000000000000..38b39c3132e274 --- /dev/null +++ b/python/sparknlp/annotator/embeddings/bge_embeddings.py @@ -0,0 +1,192 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Contains classes for BGEEmbeddings.""" + +from sparknlp.common import * + + +class BGEEmbeddings(AnnotatorModel, + HasEmbeddingsProperties, + HasCaseSensitiveProperties, + HasStorageRef, + HasBatchedAnnotate, + HasMaxSentenceLengthLimit): + """Sentence embeddings using BGE. + + BGE, or BAAI General Embeddings, a model that can map any text to a low-dimensional dense + vector which can be used for tasks like retrieval, classification, clustering, or semantic search. + + Pretrained models can be loaded with `pretrained` of the companion object: + + >>> embeddings = BGEEmbeddings.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("bge_embeddings") + + + The default model is ``"bge_base"``, if no name is provided. + + For available pretrained models please see the + `Models Hub <https://sparknlp.org/models?q=BGE>`__. + + + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT`` ``SENTENCE_EMBEDDINGS`` + ====================== ====================== + + Parameters + ---------- + batchSize + Size of every batch, by default 8 + dimension + Number of embedding dimensions, by default 768 + caseSensitive + Whether to ignore case in tokens for embeddings matching, by default False + maxSentenceLength + Max sentence length to process, by default 512 + configProtoBytes + ConfigProto from tensorflow, serialized into byte array. + + References + ---------- + `C-Pack: Packaged Resources To Advance General Chinese Embedding <https://arxiv.org/pdf/2309.07597>`__ + `BGE Github Repository <https://github.com/FlagOpen/FlagEmbedding>`__ + + **Paper abstract** + + *We introduce C-Pack, a package of resources that significantly advance the field of general + Chinese embeddings. C-Pack includes three critical resources. + 1) C-MTEB is a comprehensive benchmark for Chinese text embeddings covering 6 tasks and 35 datasets. + 2) C-MTP is a massive text embedding dataset curated from labeled and unlabeled Chinese corpora + for training embedding models.
+ 3) C-TEM is a family of embedding models covering multiple sizes. + Our models outperform all prior Chinese text embeddings on C-MTEB by up to +10% upon the + time of the release. We also integrate and optimize the entire suite of training methods for + C-TEM. Along with our resources on general Chinese embedding, we release our data and models for + English text embeddings. The English models achieve state-of-the-art performance on the MTEB + benchmark; meanwhile, our released English data is 2 times larger than the Chinese data. All + these resources are made publicly available at https://github.com/FlagOpen/FlagEmbedding.* + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = DocumentAssembler() \\ + ... .setInputCol("text") \\ + ... .setOutputCol("document") + >>> embeddings = BGEEmbeddings.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("bge_embeddings") + >>> embeddingsFinisher = EmbeddingsFinisher() \\ + ... .setInputCols(["bge_embeddings"]) \\ + ... .setOutputCols("finished_embeddings") \\ + ... .setOutputAsVector(True) + >>> pipeline = Pipeline().setStages([ + ... documentAssembler, + ... embeddings, + ... embeddingsFinisher + ... ]) + >>> data = spark.createDataFrame([["query: how much protein should a female eat", + ... "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. " + \ + ... "But, as you can see from this chart, you'll need to increase that if you're expecting or training for a " + \ + ... "marathon. Check out the chart below to see how much protein you should be eating each day.", + ...
]]).toDF("text") + >>> result = pipeline.fit(data).transform(data) + >>> result.selectExpr("explode(finished_embeddings) as result").show(5, 80) + +--------------------------------------------------------------------------------+ + | result| + +--------------------------------------------------------------------------------+ + |[[8.0190285E-4, -0.005974853, -0.072875895, 0.007944068, 0.026059335, -0.0080...| + |[[0.050514214, 0.010061974, -0.04340176, -0.020937217, 0.05170225, 0.01157857...| + +--------------------------------------------------------------------------------+ + """ + + name = "BGEEmbeddings" + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT] + + outputAnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS + configProtoBytes = Param(Params._dummy(), + "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()", + TypeConverters.toListInt) + + + def setConfigProtoBytes(self, b): + """Sets configProto from tensorflow, serialized into byte array. + + Parameters + ---------- + b : List[int] + ConfigProto from tensorflow, serialized into byte array + """ + return self._set(configProtoBytes=b) + + @keyword_only + def __init__(self, classname="com.johnsnowlabs.nlp.embeddings.BGEEmbeddings", java_model=None): + super(BGEEmbeddings, self).__init__( + classname=classname, + java_model=java_model + ) + self._setDefault( + dimension=768, + batchSize=8, + maxSentenceLength=512, + caseSensitive=False, + ) + + @staticmethod + def loadSavedModel(folder, spark_session): + """Loads a locally saved model. 
+ + Parameters + ---------- + folder : str + Folder of the saved model + spark_session : pyspark.sql.SparkSession + The current SparkSession + + Returns + ------- + BGEEmbeddings + The restored model + """ + from sparknlp.internal import _BGELoader + jModel = _BGELoader(folder, spark_session._jsparkSession)._java_obj + return BGEEmbeddings(java_model=jModel) + + @staticmethod + def pretrained(name="bge_base", lang="en", remote_loc=None): + """Downloads and loads a pretrained model. + + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default "bge_base" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLPs repositories otherwise. + + Returns + ------- + BGEEmbeddings + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(BGEEmbeddings, name, lang, remote_loc) diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index 80e3749e323875..f49a5e4768deab 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -147,6 +147,9 @@ class _E5Loader(ExtendedJavaWrapper): def __init__(self, path, jspark): super(_E5Loader, self).__init__("com.johnsnowlabs.nlp.embeddings.E5Embeddings.loadSavedModel", path, jspark) +class _BGELoader(ExtendedJavaWrapper): + def __init__(self, path, jspark): + super(_BGELoader, self).__init__("com.johnsnowlabs.nlp.embeddings.BGEEmbeddings.loadSavedModel", path, jspark) class _GPT2Loader(ExtendedJavaWrapper): def __init__(self, path, jspark): diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/BGE.scala b/src/main/scala/com/johnsnowlabs/ml/ai/BGE.scala new file mode 100644 index 00000000000000..fb421b1fd58ebf --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/BGE.scala @@ -0,0 +1,247 @@ +/* + * Copyright 2017 - 2023 John Snow Labs + * + * 
Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.ml.ai + +import ai.onnxruntime.{OnnxTensor, TensorInfo} +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} +import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} +import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} +import com.johnsnowlabs.ml.util.{LinAlg, ONNX, TensorFlow} +import com.johnsnowlabs.nlp.annotators.common._ +import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} + +import scala.collection.JavaConverters._ + +/** BGE Sentence embeddings model + * @param tensorflowWrapper + * tensorflow wrapper + * @param configProtoBytes + * config proto bytes + * @param sentenceStartTokenId + * sentence start token id + * @param sentenceEndTokenId + * sentence end token id + * @param signatures + * signatures + */ +private[johnsnowlabs] class BGE( + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[OnnxWrapper], + configProtoBytes: Option[Array[Byte]] = None, + sentenceStartTokenId: Int, + sentenceEndTokenId: Int, + signatures: Option[Map[String, String]] = None) + extends Serializable { + + private val _tfInstructorSignatures: Map[String, String] = + signatures.getOrElse(ModelSignatureManager.apply()) + private val paddingTokenId = 0 + + val detectedEngine: String = + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name + private 
val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions + + /** Get sentence embeddings for a batch of sentences + * @param batch + * batch of sentences + * @return + * sentence embeddings + */ + private def getSentenceEmbedding(batch: Seq[Array[Int]]): Array[Array[Float]] = { + val maxSentenceLength = batch.map(pieceIds => pieceIds.length).max + val paddedBatch = batch.map(arr => padArrayWithZeros(arr, maxSentenceLength)) + val embeddings = detectedEngine match { + case ONNX.name => + getSentenceEmbeddingFromOnnx(paddedBatch, maxSentenceLength) + case _ => + getSentenceEmbeddingFromTF(paddedBatch, maxSentenceLength) + } + embeddings + } + + private def padArrayWithZeros(arr: Array[Int], maxLength: Int): Array[Int] = { + if (arr.length >= maxLength) { + arr + } else { + arr ++ Array.fill(maxLength - arr.length)(0) + } + } + + private def getSentenceEmbeddingFromTF( + batch: Seq[Array[Int]], + maxSentenceLength: Int): Array[Array[Float]] = { + val batchLength = batch.length + + // encode batch + val tensorEncoder = new TensorResources() + val inputDim = batch.length * maxSentenceLength + + // create buffers + val encoderInputBuffers = tensorEncoder.createIntBuffer(inputDim) + val encoderAttentionMaskBuffers = tensorEncoder.createIntBuffer(inputDim) + + val shape = Array(batch.length.toLong, maxSentenceLength) + + batch.zipWithIndex.foreach { case (tokenIds, idx) => + val offset = idx * maxSentenceLength + val diff = maxSentenceLength - tokenIds.length + + // pad with 0 + val s = tokenIds.take(maxSentenceLength) ++ Array.fill[Int](diff)(this.paddingTokenId) + encoderInputBuffers.offset(offset).write(s) + + // create attention mask + val mask = s.map(x => if (x != this.paddingTokenId) 1 else 0) + encoderAttentionMaskBuffers.offset(offset).write(mask) + + } + + // create tensors + val encoderInputTensors = tensorEncoder.createIntBufferTensor(shape, encoderInputBuffers) + val encoderAttentionMaskTensors = + 
tensorEncoder.createIntBufferTensor(shape, encoderAttentionMaskBuffers) + + // run model + val runner = tensorflowWrapper.get + .getTFSessionWithSignature( + configProtoBytes = configProtoBytes, + initAllTables = false, + savedSignatures = signatures) + .runner + + runner + .feed( + _tfInstructorSignatures.getOrElse( + ModelSignatureConstants.EncoderInputIds.key, + "missing_encoder_input_ids"), + encoderInputTensors) + .feed( + _tfInstructorSignatures.getOrElse( + ModelSignatureConstants.EncoderAttentionMask.key, + "missing_encoder_attention_mask"), + encoderAttentionMaskTensors) + .fetch(_tfInstructorSignatures + .getOrElse(ModelSignatureConstants.LastHiddenState.key, "missing_last_hidden_state")) + + // get embeddings + val sentenceEmbeddings = runner.run().asScala + val sentenceEmbeddingsFloats = TensorResources.extractFloats(sentenceEmbeddings.head) + val dim = sentenceEmbeddingsFloats.length / batchLength + + // group embeddings + val sentenceEmbeddingsFloatsArray = sentenceEmbeddingsFloats.grouped(dim).toArray + + // close buffers + sentenceEmbeddings.foreach(_.close()) + encoderInputTensors.close() + encoderAttentionMaskTensors.close() + tensorEncoder.clearTensors() + tensorEncoder.clearSession(sentenceEmbeddings) + + sentenceEmbeddingsFloatsArray + } + + private def getSentenceEmbeddingFromOnnx( + batch: Seq[Array[Int]], + maxSentenceLength: Int): Array[Array[Float]] = { + + val inputIds = batch.map(x => x.map(x => x.toLong)).toArray + val attentionMask = batch.map(sentence => sentence.map(x => if (x == this.paddingTokenId) 0L else 1L)).toArray + + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) + + val tokenTensors = OnnxTensor.createTensor(env, inputIds) + val maskTensors = OnnxTensor.createTensor(env, attentionMask) + val segmentTensors = + OnnxTensor.createTensor(env, batch.map(x => Array.fill(maxSentenceLength)(0L)).toArray) + val inputs = + Map( + "input_ids" -> tokenTensors, + "attention_mask" -> maskTensors, + "token_type_ids" ->
segmentTensors).asJava + + // TODO: A try without a catch or finally is equivalent to putting its body in a block; no exceptions are handled. + try { + val results = runner.run(inputs) + val lastHiddenState = results.get("last_hidden_state").get() + val info = lastHiddenState.getInfo.asInstanceOf[TensorInfo] + val shape = info.getShape + try { + val embeddings = lastHiddenState + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + segmentTensors.close() + + val dim = shape.last.toInt + // Perform CLS pooling: keep only the first token's vector of each (padded) sequence + val clsPooling = embeddings.grouped(dim * maxSentenceLength).map(_.take(dim)) + val normalizedSentenceEmbeddings = clsPooling.map(LinAlg.lpNormalizeArray(_, 2)) + + normalizedSentenceEmbeddings.toArray + } finally if (results != null) results.close() + } + } + + /** Predict sentence embeddings for a batch of sentences + * @param sentences + * sentences + * @param tokenizedSentences + * tokenized sentences + * @param batchSize + * batch size + * @param maxSentenceLength + * max sentence length + * @return one SENTENCE_EMBEDDINGS annotation per input sentence + */ + def predict( + sentences: Seq[Annotation], + tokenizedSentences: Seq[WordpieceTokenizedSentence], + batchSize: Int, + maxSentenceLength: Int): Seq[Annotation] = { + + tokenizedSentences + .zip(sentences) + .zipWithIndex + .grouped(batchSize) + .toArray + .flatMap { batch => + val tokensBatch = batch.map(x => x._1._1.tokens) + val tokens = tokensBatch.map(x => + Array(sentenceStartTokenId) ++ x + .map(y => y.pieceId) + .take(maxSentenceLength - 2) ++ Array(sentenceEndTokenId)) + + val sentenceEmbeddings = getSentenceEmbedding(tokens) + + batch.zip(sentenceEmbeddings).map { case (sentence, vectors) => + Annotation( + annotatorType = AnnotatorType.SENTENCE_EMBEDDINGS, + begin = sentence._1._2.begin, + end = sentence._1._2.end, + result = sentence._1._2.result, + metadata = sentence._1._2.metadata, + embeddings = vectors) + } + } + } + +} diff --git
a/src/main/scala/com/johnsnowlabs/ml/util/LinAlg.scala b/src/main/scala/com/johnsnowlabs/ml/util/LinAlg.scala index 266bc6a69a46aa..cf23c78a83427a 100644 --- a/src/main/scala/com/johnsnowlabs/ml/util/LinAlg.scala +++ b/src/main/scala/com/johnsnowlabs/ml/util/LinAlg.scala @@ -1,7 +1,7 @@ package com.johnsnowlabs.ml.util import breeze.linalg.{DenseMatrix, tile} -import scala.math.sqrt +import scala.math.{sqrt, pow} object LinAlg { @@ -130,4 +130,10 @@ object LinAlg { array.map(value => if (l2Norm != 0.0f) value / l2Norm else 0.0f) } + def lpNormalizeArray(array: Array[Float], p: Int = 2): Array[Float] = { + val lpNorm: Float = pow(array.map(x => pow(x.abs, p)).sum, 1.0 / p).toFloat + // Normalize each element by the Lp norm (of absolute values); a zero norm yields all zeros instead of NaN + array.map(value => if (lpNorm != 0.0f) value / lpNorm else 0.0f) + } + } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala new file mode 100644 index 00000000000000..bfc76b7efa396c --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala @@ -0,0 +1,482 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.ml.ai.BGE +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.tensorflow._ +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.nlp.annotators.common._ +import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} +import com.johnsnowlabs.nlp.serialization.MapFeature +import com.johnsnowlabs.storage.HasStorageRef +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.slf4j.{Logger, LoggerFactory} + +/** Sentence embeddings using BGE. + * + * BGE, or BAAI General Embeddings, a model that can map any text to a low-dimensional dense + * vector which can be used for tasks like retrieval, classification, clustering, or semantic search. + * + * Pretrained models can be loaded with `pretrained` of the companion object: + * {{{ + * val embeddings = BGEEmbeddings.pretrained() + * .setInputCols("document") + * .setOutputCol("embeddings") + * }}} + * The default model is `"bge_base"`, if no name is provided. + * + * For available pretrained models please see the + * [[https://sparknlp.org/models?q=BGE Models Hub]]. + * + * For extended examples of usage, see + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddingsTestSpec.scala BGEEmbeddingsTestSpec]]. 
+ * + * '''Sources''' : + * + * [[https://arxiv.org/pdf/2309.07597 C-Pack: Packaged Resources To Advance General Chinese Embedding]] + * + * [[https://github.com/FlagOpen/FlagEmbedding BGE Github Repository]] + * + * ''' Paper abstract ''' + * + * ''We introduce C-Pack, a package of resources that significantly advance the field of general + * Chinese embeddings. C-Pack includes three critical resources. + * 1) C-MTEB is a comprehensive benchmark for Chinese text embeddings covering 6 tasks and 35 datasets. + * 2) C-MTP is a massive text embedding dataset curated from labeled and unlabeled Chinese corpora + * for training embedding models. + * 3) C-TEM is a family of embedding models covering multiple sizes. + * Our models outperform all prior Chinese text embeddings on C-MTEB by up to +10% upon the + * time of the release. We also integrate and optimize the entire suite of training methods for + * C-TEM. Along with our resources on general Chinese embedding, we release our data and models for + * English text embeddings. The English models achieve state-of-the-art performance on the MTEB + * benchmark; meanwhile, our released English data is 2 times larger than the Chinese data.
All + * these resources are made publicly available at https://github.com/FlagOpen/FlagEmbedding.'' + * + * ==Example== + * {{{ + * import spark.implicits._ + * import com.johnsnowlabs.nlp.base.DocumentAssembler + * import com.johnsnowlabs.nlp.annotators.Tokenizer + * import com.johnsnowlabs.nlp.embeddings.BGEEmbeddings + * import com.johnsnowlabs.nlp.EmbeddingsFinisher + * import org.apache.spark.ml.Pipeline + * + * val documentAssembler = new DocumentAssembler() + * .setInputCol("text") + * .setOutputCol("document") + * + * val embeddings = BGEEmbeddings.pretrained("bge_base", "en") + * .setInputCols("document") + * .setOutputCol("bge_embeddings") + * + * val embeddingsFinisher = new EmbeddingsFinisher() + * .setInputCols("bge_embeddings") + * .setOutputCols("finished_embeddings") + * .setOutputAsVector(true) + * + * val pipeline = new Pipeline().setStages(Array( + * documentAssembler, + * embeddings, + * embeddingsFinisher + * )) + * + * val data = Seq("query: how much protein should a female eat", + * "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. " + + * "But, as you can see from this chart, you'll need to increase that if you're expecting or training for a " + + * "marathon. Check out the chart below to see how much protein you should be eating each day."
+ * + * ).toDF("text") + * val result = pipeline.fit(data).transform(data) + * + * result.selectExpr("explode(finished_embeddings) as result").show(1, 80) + * +--------------------------------------------------------------------------------+ + * | result| + * +--------------------------------------------------------------------------------+ + * |[[8.0190285E-4, -0.005974853, -0.072875895, 0.007944068, 0.026059335, -0.0080...| + * |[[0.050514214, 0.010061974, -0.04340176, -0.020937217, 0.05170225, 0.01157857...| + * +--------------------------------------------------------------------------------+ + * }}} + * + * @see + * [[https://sparknlp.org/docs/en/annotators Annotators Main Page]] for a list of transformer + * based embeddings + * @param uid + * required uid for storing annotator to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this annotator can take. Users can set and get the + * parameter values through setters and getters, respectively. + */ +class BGEEmbeddings(override val uid: String) + extends AnnotatorModel[BGEEmbeddings] + with HasBatchedAnnotate[BGEEmbeddings] + with WriteTensorflowModel + with WriteOnnxModel + with HasEmbeddingsProperties + with HasStorageRef + with HasCaseSensitiveProperties + with HasEngine { + + /** Annotator reference id.
Used to identify elements in metadata or to refer to this annotator + * type + */ + override val inputAnnotatorTypes: Array[String] = + Array(AnnotatorType.DOCUMENT) + override val outputAnnotatorType: AnnotatorType = AnnotatorType.SENTENCE_EMBEDDINGS + + /** ConfigProto from tensorflow, serialized into byte array. Get with + * `config_proto.SerializeToString()` + * + * @group param + */ + val configProtoBytes = new IntArrayParam( + this, + "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()") + + /** Max sentence length to process (Default: `512`) + * + * @group param + */ + val maxSentenceLength = + new IntParam(this, "maxSentenceLength", "Max sentence length to process") + + def sentenceStartTokenId: Int = { + $$(vocabulary)("[CLS]") + } + + /** @group setParam */ + def sentenceEndTokenId: Int = { + $$(vocabulary)("[SEP]") + } + + /** Vocabulary used to encode the words to ids with WordPieceEncoder + * + * @group param + */ + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() + + /** @group setParam */ + def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value) + + /** It contains TF model signatures for the loaded saved model + * + * @group param + */ + val signatures = + new MapFeature[String, String](model = this, name = "signatures").setProtected() + private var _model: Option[Broadcast[BGE]] = None + + def this() = this(Identifiable.randomUID("BGE_EMBEDDINGS")) + + /** @group setParam */ + def setConfigProtoBytes(bytes: Array[Int]): BGEEmbeddings.this.type = + set(this.configProtoBytes, bytes) + + /** @group setParam */ + def setMaxSentenceLength(value: Int): this.type = { + require( + value <= 512, + "BGE models do not support sequences longer than 512 because of trainable positional embeddings.") + require(value >= 1, "The maxSentenceLength must be at least 1") + set(maxSentenceLength, value) + this + } + + /** @group getParam
*/ + def getMaxSentenceLength: Int = $(maxSentenceLength) + + /** @group setParam */ + def setSignatures(value: Map[String, String]): this.type = { + if (get(signatures).isEmpty) + set(signatures, value) + this + } + + /** @group setParam */ + def setModelIfNotSet( + spark: SparkSession, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper]): BGEEmbeddings = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new BGE( + tensorflowWrapper, + onnxWrapper, + configProtoBytes = getConfigProtoBytes, + sentenceStartTokenId = sentenceStartTokenId, + sentenceEndTokenId = sentenceEndTokenId, + signatures = getSignatures))) + } + + this + } + + /** Set Embeddings dimensions for the BERT model Only possible to set this when the first time + * is saved dimension is not changeable, it comes from BERT config file + * + * @group setParam + */ + override def setDimension(value: Int): this.type = { + if (get(dimension).isEmpty) + set(this.dimension, value) + this + } + + /** Whether to lowercase tokens or not + * + * @group setParam + */ + override def setCaseSensitive(value: Boolean): this.type = { + if (get(caseSensitive).isEmpty) + set(this.caseSensitive, value) + this + } + + setDefault(dimension -> 768, batchSize -> 8, maxSentenceLength -> 512, caseSensitive -> false) + + def tokenize(sentences: Seq[Annotation]): Seq[WordpieceTokenizedSentence] = { + val basicTokenizer = new BasicTokenizer($(caseSensitive)) + val encoder = new WordpieceEncoder($$(vocabulary)) + sentences.map { s => + val sent = Sentence( + content = s.result, + start = s.begin, + end = s.end, + metadata = Some(s.metadata), + index = s.begin) + val tokens = basicTokenizer.tokenize(sent) + val wordpieceTokens = tokens.flatMap(token => encoder.encode(token)) + WordpieceTokenizedSentence(wordpieceTokens) + } + } + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + *
Annotations that correspond to inputAnnotationCols generated by previous annotators if any + * @return + * any number of annotations processed for every input annotation. Not necessary one to one + * relationship + */ + override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { + + val allAnnotations = batchedAnnotations + .filter(_.nonEmpty) + .zipWithIndex + .flatMap { case (annotations, i) => + annotations.filter(_.result.nonEmpty).map(x => (x, i)) + } + + // Tokenize sentences + val tokenizedSentences = tokenize(allAnnotations.map(_._1)) + val processedAnnotations = if (allAnnotations.nonEmpty) { + this.getModelIfNotSet.predict( + sentences = allAnnotations.map(_._1), + tokenizedSentences = tokenizedSentences, + batchSize = $(batchSize), + maxSentenceLength = $(maxSentenceLength)) + } else { + Seq() + } + + // Group resulting annotations by rows. If there are not sentences in a given row, return empty sequence + batchedAnnotations.indices.map(rowIndex => { + val rowAnnotations = processedAnnotations + // zip each annotation with its corresponding row index + .zip(allAnnotations) + // select the sentences belonging to the current row + .filter(_._2._2 == rowIndex) + // leave the annotation only + .map(_._1) + + if (rowAnnotations.nonEmpty) + rowAnnotations + else + Seq.empty[Annotation] + }) + + } + + /** @group getParam */ + def getModelIfNotSet: BGE = _model.get.value + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + val suffix = "_bge" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + BGEEmbeddings.tfFile, + configProtoBytes = getConfigProtoBytes, + savedSignatures = getSignatures) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + BGEEmbeddings.onnxFile) + + case _ => + throw new Exception(notSupportedEngineError) + } + } + + 
/** @group getParam */ + def getConfigProtoBytes: Option[Array[Byte]] = get(this.configProtoBytes).map(_.map(_.toByte)) + + /** @group getParam */ + def getSignatures: Option[Map[String, String]] = get(this.signatures) + + override protected def afterAnnotate(dataset: DataFrame): DataFrame = { + dataset.withColumn( + getOutputCol, + wrapSentenceEmbeddingsMetadata( + dataset.col(getOutputCol), + $(dimension), + Some($(storageRef)))) + } + +} + +trait ReadablePretrainedBGEModel + extends ParamsAndFeaturesReadable[BGEEmbeddings] + with HasPretrained[BGEEmbeddings] { + override val defaultModelName: Some[String] = Some("bge_base") + + /** Java compliant-overrides */ + override def pretrained(): BGEEmbeddings = super.pretrained() + + override def pretrained(name: String): BGEEmbeddings = super.pretrained(name) + + override def pretrained(name: String, lang: String): BGEEmbeddings = + super.pretrained(name, lang) + + override def pretrained(name: String, lang: String, remoteLoc: String): BGEEmbeddings = + super.pretrained(name, lang, remoteLoc) +} + +trait ReadBGEDLModel extends ReadTensorflowModel with ReadOnnxModel { + this: ParamsAndFeaturesReadable[BGEEmbeddings] => + + override val tfFile: String = "bge_tensorflow" + override val onnxFile: String = "bge_onnx" + + def readModel(instance: BGEEmbeddings, path: String, spark: SparkSession): Unit = { + + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = readTensorflowModel(path, spark, "_bge_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tfWrapper), None) + + case ONNX.name => + val onnxWrapper = + readOnnxModel(path, spark, "_bge_onnx", zipped = true, useBundle = false, None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) + + case _ => + throw new Exception(notSupportedEngineError) + } + + } + + addReader(readModel) + + def loadSavedModel(modelPath: String, spark: SparkSession): BGEEmbeddings = { + + val (localModelPath, detectedEngine) = modelSanityCheck(modelPath) 
+ + val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap + + /*Universal parameters for all engines*/ + val annotatorModel = new BGEEmbeddings() + .setVocabulary(vocabs) + + annotatorModel.set(annotatorModel.engine, detectedEngine) + + detectedEngine match { + case TensorFlow.name => + val (wrapper, signatures) = + TensorflowWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + tags = Array("serve"), + initAllTables = false) + + val _signatures = signatures match { + case Some(s) => s + case None => throw new Exception("Cannot load signature definitions from model!") + } + + /** the order of setSignatures is important if we use getSignatures inside + * setModelIfNotSet + */ + annotatorModel + .setSignatures(_signatures) + .setModelIfNotSet(spark, Some(wrapper), None) + + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper)) + + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } +} + +/** This is the companion object of [[BGEEmbeddings]]. Please refer to that class for the + * documentation. + */ +object BGEEmbeddings extends ReadablePretrainedBGEModel with ReadBGEDLModel { + private[BGEEmbeddings] val logger: Logger = + LoggerFactory.getLogger("BGEEmbeddings") +} diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddingsTestSpec.scala new file mode 100644 index 00000000000000..77e92795b8bacd --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddingsTestSpec.scala @@ -0,0 +1,116 @@ +/* + * Copyright 2017-2022 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.embeddings + +import com.johnsnowlabs.nlp.annotators.sentence_detector_dl.SentenceDetectorDLModel +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.SlowTest +import org.apache.spark.ml.Pipeline +import org.apache.spark.sql.functions.{col, size} +import org.scalatest.flatspec.AnyFlatSpec + +class BGEEmbeddingsTestSpec extends AnyFlatSpec { + + "BGE Embeddings" should "correctly embed multiple sentences" taggedAs SlowTest in { + + import ResourceHelper.spark.implicits._ + + val ddd = Seq( + "query: how much protein should a female eat", + "query: summit define", + "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 " + + "grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or" + + " training for a marathon. Check out the chart below to see how much protein you should be eating each day.", + "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of" + + " a mountain. : 2 the highest level. 
: 3 a meeting or series of meetings between the leaders of two or more" + + " governments.") + .toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val embeddings = BGEEmbeddings + .pretrained() + .setInputCols(Array("document")) + .setOutputCol("bge") + + val pipeline = new Pipeline().setStages(Array(document, embeddings)) + + val pipelineDF = pipeline.fit(ddd).transform(ddd) + pipelineDF.select("bge.embeddings").show(truncate = false) + + } + + it should "have embeddings of the same size" taggedAs SlowTest in { + import ResourceHelper.spark.implicits._ + val testDf = Seq( + "I like apples", + "I like bananas \\n and other things \\n like icream \\n and cats", + "I like rockets") + .toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val embeddings = BGEEmbeddings + .pretrained() + .setInputCols(Array("document")) + .setOutputCol("bge") + + val pipeline = new Pipeline().setStages(Array(document, embeddings)) + + val pipelineDF = pipeline.fit(testDf).transform(testDf) + + val embeddingsDF = pipelineDF.withColumn("embeddings", col("bge.embeddings").getItem(0)) + + val sizesArray: Array[Int] = embeddingsDF + .select(size(col("embeddings")).as("size")) + .collect() + .map(row => row.getAs[Int]("size")) + + assert(sizesArray.forall(_ == sizesArray.head)) + } + + it should "work with sentences" taggedAs SlowTest in { + import ResourceHelper.spark.implicits._ + val testData = "I really enjoy my job. 
This is amazing" + val testDf = Seq(testData).toDF("text") + + val document = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("document") + + val sentenceDetectorDL = SentenceDetectorDLModel + .pretrained("sentence_detector_dl", "en") + .setInputCols(Array("document")) + .setOutputCol("sentences") + + val embeddings = BGEEmbeddings + .pretrained() + .setInputCols(Array("sentences")) + .setOutputCol("bge") + + val pipeline = new Pipeline().setStages(Array(document, sentenceDetectorDL, embeddings)) + + val pipelineDF = pipeline.fit(testDf).transform(testDf) + pipelineDF.select("bge.embeddings").show(false) + } + +} From 1750cadcd105a7b4bd9c814d3440bbab4eb49375 Mon Sep 17 00:00:00 2001 From: David Cecchini Date: Tue, 12 Dec 2023 15:07:49 -0300 Subject: [PATCH 2/2] Fixed class names --- .../annotator/embeddings/bge_embeddings.py | 2 +- .../nlp/embeddings/BGEEmbeddings.scala | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/sparknlp/annotator/embeddings/bge_embeddings.py b/python/sparknlp/annotator/embeddings/bge_embeddings.py index 38b39c3132e274..0c0428141a3ec7 100644 --- a/python/sparknlp/annotator/embeddings/bge_embeddings.py +++ b/python/sparknlp/annotator/embeddings/bge_embeddings.py @@ -31,7 +31,7 @@ class BGEEmbeddings(AnnotatorModel, >>> embeddings = BGEEmbeddings.pretrained() \\ ... .setInputCols(["document"]) \\ - ... .setOutputCol("e5_embeddings") + ... .setOutputCol("bge_embeddings") The default model is ``"bge_base"``, if no name is provided. 
diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala index bfc76b7efa396c..8fb3bba9d3ba40 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BGEEmbeddings.scala @@ -89,7 +89,7 @@ import org.slf4j.{Logger, LoggerFactory} * .setInputCol("text") * .setOutputCol("document") * - * val embeddings = BGEEmbeddings.pretrained("e5_small", "en") + * val embeddings = BGEEmbeddings.pretrained("bge_base", "en") * .setInputCols("document") * .setOutputCol("bge_embeddings") * @@ -144,8 +144,8 @@ import org.slf4j.{Logger, LoggerFactory} * parameter values through setters and getters, respectively. */ class BGEEmbeddings(override val uid: String) - extends AnnotatorModel[E5Embeddings] - with HasBatchedAnnotate[E5Embeddings] + extends AnnotatorModel[BGEEmbeddings] + with HasBatchedAnnotate[BGEEmbeddings] with WriteTensorflowModel with WriteOnnxModel with HasEmbeddingsProperties @@ -201,12 +201,12 @@ class BGEEmbeddings(override val uid: String) */ val signatures = new MapFeature[String, String](model = this, name = "signatures").setProtected() - private var _model: Option[Broadcast[E5]] = None + private var _model: Option[Broadcast[BGE]] = None def this() = this(Identifiable.randomUID("BGE_EMBEDDINGS")) /** @group setParam */ - def setConfigProtoBytes(bytes: Array[Int]): E5Embeddings.this.type = + def setConfigProtoBytes(bytes: Array[Int]): BGEEmbeddings.this.type = set(this.configProtoBytes, bytes) /** @group setParam */ @@ -233,11 +233,11 @@ class BGEEmbeddings(override val uid: String) def setModelIfNotSet( spark: SparkSession, tensorflowWrapper: Option[TensorflowWrapper], - onnxWrapper: Option[OnnxWrapper]): E5Embeddings = { + onnxWrapper: Option[OnnxWrapper]): BGEEmbeddings = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( - new E5( + new BGE( tensorflowWrapper, onnxWrapper, 
configProtoBytes = getConfigProtoBytes,