Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparknlp 888 Add ONNX support to MPNet embeddings #13955

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 68 additions & 5 deletions src/main/scala/com/johnsnowlabs/ml/ai/MPNet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

import scala.collection.JavaConverters._

/** MPNET Sentence embeddings model
*
* @param tensorflow
* @param tensorflowWrapper
* tensorflow wrapper
* @param configProtoBytes
* config proto bytes
Expand All @@ -37,7 +40,8 @@ import scala.collection.JavaConverters._
* signatures
*/
private[johnsnowlabs] class MPNet(
val tensorflow: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
configProtoBytes: Option[Array[Byte]] = None,
sentenceStartTokenId: Int,
sentenceEndTokenId: Int,
Expand All @@ -47,8 +51,11 @@ private[johnsnowlabs] class MPNet(
private val _tfInstructorSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
private val paddingTokenId = 1
private val bosTokenId = 0
private val eosTokenId = 2

val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

/** Get sentence embeddings for a batch of sentences
* @param batch
Expand All @@ -57,6 +64,22 @@ private[johnsnowlabs] class MPNet(
* sentence embeddings
*/
/** Computes sentence embeddings for a batch of token-id sequences, dispatching to
  * the backend selected by `detectedEngine` (ONNX when an ONNX wrapper was loaded,
  * TensorFlow otherwise).
  *
  * @param batch
  *   batch of sentences, each encoded as an array of token ids
  * @return
  *   one embedding vector (floats) per input sentence
  */
private def getSentenceEmbedding(batch: Seq[Array[Int]]): Array[Array[Float]] =
  if (detectedEngine == ONNX.name) getSentenceEmbeddingFromOnnx(batch)
  else getSentenceEmbeddingFromTF(batch)

/** Get sentence embeddings for a batch of sentences
* @param batch
* batch of sentences
* @return
* sentence embeddings
*/
private def getSentenceEmbeddingFromTF(batch: Seq[Array[Int]]): Array[Array[Float]] = {
// get max sentence length
val sequencesLength = batch.map(x => x.length).toArray
val maxSentenceLength = sequencesLength.max
Expand Down Expand Up @@ -92,7 +115,7 @@ private[johnsnowlabs] class MPNet(
tensorEncoder.createIntBufferTensor(shape, encoderAttentionMaskBuffers)

// run model
val runner = tensorflow
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
initAllTables = false,
Expand Down Expand Up @@ -131,6 +154,46 @@ private[johnsnowlabs] class MPNet(
sentenceEmbeddingsFloatsArray
}

/** Computes embeddings for a batch using the ONNX Runtime backend.
  *
  * Feeds `input_ids` and `attention_mask` to the session and flattens the
  * `last_hidden_state` output into one float array per sentence
  * (`embeddings.length / batchLength` floats each).
  *
  * All native `OnnxTensor` inputs and the `OrtSession.Result` are released in
  * `finally` blocks so native memory is reclaimed even when inference throws.
  * (The previous version also allocated an unused `segmentTensors` input and
  * only closed tensors on the success path.)
  *
  * @param batch
  *   batch of sentences, each encoded as an array of token ids; assumes all
  *   rows are padded to the same length (padding id 0) — TODO confirm with caller
  * @return
  *   one flattened embedding array per sentence
  */
private def getSentenceEmbeddingFromOnnx(batch: Seq[Array[Int]]): Array[Array[Float]] = {
  val batchLength = batch.length

  val (runner, env) = onnxWrapper.get.getSession()

  // Token ids as Long, one row per sentence.
  val tokenTensors =
    OnnxTensor.createTensor(env, batch.map(_.map(_.toLong)).toArray)
  // Attention mask: 1 for real tokens, 0 for padding (padding token id is 0).
  val maskTensors =
    OnnxTensor.createTensor(
      env,
      batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

  val inputs =
    Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

  try {
    val results = runner.run(inputs)
    try {
      val embeddings = results
        .get("last_hidden_state")
        .get()
        .asInstanceOf[OnnxTensor]
        .getFloatBuffer
        .array()

      // Group the flat float buffer into one row per sentence.
      val dim = embeddings.length / batchLength
      embeddings.grouped(dim).toArray
    } finally if (results != null) results.close()
  } finally {
    // Always release native input tensors, success or failure.
    tokenTensors.close()
    maskTensors.close()
  }
}

/** Predict sentence embeddings for a batch of sentences
* @param sentences
* sentences
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.ml.ai.MPNet
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadTextAsset,
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
Expand Down Expand Up @@ -145,6 +146,7 @@ class MPNetEmbeddings(override val uid: String)
extends AnnotatorModel[MPNetEmbeddings]
with HasBatchedAnnotate[MPNetEmbeddings]
with WriteTensorflowModel
with WriteOnnxModel
with HasEmbeddingsProperties
with HasStorageRef
with HasCaseSensitiveProperties
Expand Down Expand Up @@ -229,12 +231,14 @@ class MPNetEmbeddings(override val uid: String)
/** @group setParam */
def setModelIfNotSet(
spark: SparkSession,
tensorflowWrapper: TensorflowWrapper): MPNetEmbeddings = {
tensorflowWrapper: Option[TensorflowWrapper],
onnxWrapper: Option[OnnxWrapper]): MPNetEmbeddings = {
if (_model.isEmpty) {
_model = Some(
spark.sparkContext.broadcast(
new MPNet(
tensorflowWrapper,
onnxWrapper,
configProtoBytes = getConfigProtoBytes,
sentenceStartTokenId = sentenceStartTokenId,
sentenceEndTokenId = sentenceEndTokenId,
Expand Down Expand Up @@ -336,14 +340,29 @@ class MPNetEmbeddings(override val uid: String)

/** Serializes the loaded model alongside the annotator's metadata.
  *
  * Dispatches on the engine this instance was loaded with: TensorFlow models go
  * through `writeTensorflowModelV2`, ONNX models through `writeOnnxModel`.
  * Throws for any other engine value. (The scraped diff had left the superseded
  * TensorFlow-only body interleaved here; this is the merged post-diff version.)
  */
override def onWrite(path: String, spark: SparkSession): Unit = {
  super.onWrite(path, spark)
  val suffix = "_mpnet"

  getEngine match {
    case TensorFlow.name =>
      writeTensorflowModelV2(
        path,
        spark,
        getModelIfNotSet.tensorflowWrapper.get,
        suffix,
        MPNetEmbeddings.tfFile,
        configProtoBytes = getConfigProtoBytes,
        savedSignatures = getSignatures)
    case ONNX.name =>
      writeOnnxModel(
        path,
        spark,
        getModelIfNotSet.onnxWrapper.get,
        suffix,
        MPNetEmbeddings.onnxFile)

    case _ =>
      throw new Exception(notSupportedEngineError)
  }
}

/** @group getParam */
Expand All @@ -366,7 +385,7 @@ class MPNetEmbeddings(override val uid: String)
trait ReadablePretrainedMPNetModel
extends ParamsAndFeaturesReadable[MPNetEmbeddings]
with HasPretrained[MPNetEmbeddings] {
override val defaultModelName: Some[String] = Some("mpnet_small")
override val defaultModelName: Some[String] = Some("all_mpnet_base_v2")

/** Java compliant-overrides */
override def pretrained(): MPNetEmbeddings = super.pretrained()
Expand All @@ -380,19 +399,26 @@ trait ReadablePretrainedMPNetModel
super.pretrained(name, lang, remoteLoc)
}

trait ReadMPNetDLModel extends ReadTensorflowModel {
trait ReadMPNetDLModel extends ReadTensorflowModel with ReadOnnxModel {
this: ParamsAndFeaturesReadable[MPNetEmbeddings] =>

override val tfFile: String = "mpnet_tensorflow"
override val onnxFile: String = "mpnet_onnx"
/** Restores the serialized model into `instance`, choosing the reader by the
  * engine recorded on the annotator: `_mpnet_tf` for TensorFlow, `_mpnet_onnx`
  * (zipped) for ONNX. Throws for any other engine value. (The scraped diff had
  * left the superseded TensorFlow-only body interleaved here; this is the
  * merged post-diff version.)
  */
def readModel(instance: MPNetEmbeddings, path: String, spark: SparkSession): Unit = {

  instance.getEngine match {
    case TensorFlow.name =>
      val tfWrapper = readTensorflowModel(path, spark, "_mpnet_tf", initAllTables = false)
      instance.setModelIfNotSet(spark, Some(tfWrapper), None)

    case ONNX.name =>
      val onnxWrapper =
        readOnnxModel(path, spark, "_mpnet_onnx", zipped = true, useBundle = false, None)
      instance.setModelIfNotSet(spark, None, Some(onnxWrapper))

    case _ =>
      throw new Exception(notSupportedEngineError)
  }
}

addReader(readModel)
Expand Down Expand Up @@ -424,7 +450,12 @@ trait ReadMPNetDLModel extends ReadTensorflowModel {
*/
annotatorModel
.setSignatures(_signatures)
.setModelIfNotSet(spark, wrapper)
.setModelIfNotSet(spark, Some(wrapper), None)

case ONNX.name =>
val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true)
annotatorModel
.setModelIfNotSet(spark, None, Some(onnxWrapper))

case _ =>
throw new Exception(notSupportedEngineError)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.scalatest.flatspec.AnyFlatSpec

class MPNetEmbeddingsTestSpec extends AnyFlatSpec {

"E5 Embeddings" should "correctly embed multiple sentences" taggedAs SlowTest in {
"Mpnet Embeddings" should "correctly embed multiple sentences" taggedAs SlowTest in {

import ResourceHelper.spark.implicits._

Expand All @@ -38,12 +38,13 @@ class MPNetEmbeddingsTestSpec extends AnyFlatSpec {
val embeddings = MPNetEmbeddings
.pretrained()
.setInputCols(Array("document"))
.setOutputCol("e5")
.setOutputCol("mpnet")

val pipeline = new Pipeline().setStages(Array(document, embeddings))

val pipelineDF = pipeline.fit(ddd).transform(ddd)
pipelineDF.select("e5.embeddings").show(truncate = false)
pipelineDF.select("mpnet.embeddings").show(truncate = false)

}

}