Merge pull request #14024 from JohnSnowLabs/feature/SPARNLP-929-930-931-Add-ONNX-support-Roberta

Adding ONNX support for RobertaClassification
maziyarpanahi authored Oct 26, 2023
2 parents 9baff1d + 7f66c16 commit d609470
Showing 15 changed files with 9,476 additions and 128 deletions.
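
The heart of this change, visible in RoBertaClassification below, is that each classifier now carries two optional backends, a TensorflowWrapper and an OnnxWrapper, and routes inference to whichever one is defined. The following sketch is a minimal, self-contained illustration of that dispatch pattern, not Spark NLP code: EngineDispatchSketch and Classifier are made-up names, and plain strings stand in for the wrapper classes.

object EngineDispatchSketch {
  val TensorFlowName = "tensorflow"
  val OnnxName = "onnx"

  // Strings stand in for TensorflowWrapper / OnnxWrapper.
  final case class Classifier(
      tensorflowWrapper: Option[String],
      onnxWrapper: Option[String]) {

    // Mirrors detectedEngine in the diff: prefer TF if present, then ONNX, default to TF.
    val detectedEngine: String =
      if (tensorflowWrapper.isDefined) TensorFlowName
      else if (onnxWrapper.isDefined) OnnxName
      else TensorFlowName

    def tag(batch: Seq[Array[Int]]): String = detectedEngine match {
      case OnnxName => s"scoring ${batch.length} sentences with ONNX Runtime"
      case _ => s"scoring ${batch.length} sentences with TensorFlow"
    }
  }

  def main(args: Array[String]): Unit = {
    val onnxOnly = Classifier(tensorflowWrapper = None, onnxWrapper = Some("model.onnx"))
    println(onnxOnly.tag(Seq(Array(0, 100, 2)))) // routes to the ONNX branch
  }
}

The diff applies this same match in tag, tagSequence, and tagSpan, keeping the TensorFlow code paths intact behind the Option.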

Large diffs are not rendered by default (3 files).

src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala
@@ -21,7 +21,6 @@ import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
@@ -342,7 +341,7 @@ private[johnsnowlabs] class AlbertClassification(
tensors.clearSession(outs)
tensors.clearTensors()

(endLogits, startLogits)
(startLogits, endLogits)
}

private def computeLogitsWithOnnx(
177 changes: 123 additions & 54 deletions src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala
@@ -16,9 +16,11 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
@@ -41,7 +43,8 @@ import scala.collection.JavaConverters._
* TF v2 signatures in Spark NLP
*/
private[johnsnowlabs] class RoBertaClassification(
val tensorflowWrapper: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
val sentenceStartTokenId: Int,
val sentenceEndTokenId: Int,
val sentencePadTokenId: Int,
@@ -56,6 +59,10 @@ private[johnsnowlabs] class RoBertaClassification(

val _tfRoBertaSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

protected val sigmoidThreshold: Float = threshold

@@ -129,51 +136,13 @@
}

def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)

// [nb of encoded sentences , maxSentenceLength]
val shape = Array(batch.length.toLong, maxSentenceLength)

batch.zipWithIndex
.foreach { case (sentence, idx) =>
val offset = idx * maxSentenceLength
tokenBuffers.offset(offset).write(sentence)
maskBuffers
.offset(offset)
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
}

val runner = tensorflowWrapper
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)

runner
.feed(
_tfRoBertaSignatures
.getOrElse(ModelSignatureConstants.InputIds.key, "missing_input_id_key"),
tokenTensors)
.feed(
_tfRoBertaSignatures
.getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"),
maskTensors)
.fetch(_tfRoBertaSignatures
.getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key"))

val outs = runner.run().asScala
val rawScores = TensorResources.extractFloats(outs.head)

outs.foreach(_.close())
tensors.clearSession(outs)
tensors.clearTensors()
val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch)
case _ => getRawScoresWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / (batchLength * maxSentenceLength)
val batchScores: Array[Array[Array[Float]]] = rawScores
@@ -186,10 +155,9 @@ private[johnsnowlabs] class RoBertaClassification(
batchScores
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
private def getRawScoresWithTF(batch: Seq[Array[Int]], maxSentenceLength: Int): Array[Float] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
@@ -207,7 +175,7 @@
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
}

val session = tensorflowWrapper.getTFSessionWithSignature(
val session = tensorflowWrapper.get.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
@@ -235,6 +203,50 @@
tensors.clearSession(outs)
tensors.clearTensors()

rawScores
}

private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {

// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()

embeddings
} finally if (results != null) results.close()
}
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max

val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch)
case _ => getRawScoresWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / batchLength
val batchScores: Array[Array[Float]] =
rawScores
@@ -284,15 +296,14 @@ private[johnsnowlabs] class RoBertaClassification(
.toArray)
}

val session = tensorflowWrapper.getTFSessionWithSignature(
val session = tensorflowWrapper.get.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
val runner = session.runner

val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)
val segmentTensors = tensors.createIntBufferTensor(shape, segmentBuffers)

runner
.feed(
@@ -321,10 +332,29 @@
}

def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
val tensors = new TensorResources()

val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val (startLogits, endLogits) = detectedEngine match {
case ONNX.name => computeLogitsWithOnnx(batch)
case _ => computeLogitsWithTF(batch, maxSentenceLength)
}

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray

(startScores, endScores)
}

private def computeLogitsWithTF(
batch: Seq[Array[Int]],
maxSentenceLength: Int): (Array[Float], Array[Float]) = {
val batchLength = batch.length
val tensors = new TensorResources()

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
@@ -341,7 +371,7 @@
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
}

val session = tensorflowWrapper.getTFSessionWithSignature(
val session = tensorflowWrapper.get.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
@@ -371,7 +401,7 @@
outs.foreach(_.close())
tensors.clearSession(outs)
tensors.clearTensors()

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] = endLogits.grouped(endDim).toArray

@@ -381,6 +411,45 @@
(startScores, endScores)
}

private def computeLogitsWithOnnx(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = {
// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

try {
val output = runner.run(inputs)
try {
val startLogits = output
.get("start_logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

val endLogits = output
.get("end_logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

tokenTensors.close()
maskTensors.close()

(startLogits.slice(1, startLogits.length), endLogits.slice(1, endLogits.length))
} finally if (output != null) output.close()
}
}

def findIndexedToken(
tokenizedSentences: Seq[TokenizedSentence],
sentence: (WordpieceTokenizedSentence, Int),
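
For readers new to the ai.onnxruntime API used in getRowScoresWithOnnx and computeLogitsWithOnnx above, the pattern is: build input tensors, run the session with named inputs, read a named output, and close everything. Below is a stripped-down, standalone sketch of that call sequence under stated assumptions: the model path is a placeholder, and the "input_ids", "attention_mask", and "logits" names are the export conventions the diff relies on.

import ai.onnxruntime.{OnnxTensor, OrtEnvironment}
import scala.collection.JavaConverters._

object OnnxLogitsSketch {
  def main(args: Array[String]): Unit = {
    val env = OrtEnvironment.getEnvironment
    val session = env.createSession("roberta_classification.onnx") // placeholder path

    // One tokenized sentence; the mask marks every position as a real token.
    val batch: Array[Array[Long]] = Array(Array(0L, 100L, 2L))
    val tokenTensors = OnnxTensor.createTensor(env, batch)
    val maskTensors = OnnxTensor.createTensor(env, batch.map(_.map(_ => 1L)))

    val inputs =
      Map[String, OnnxTensor]("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava
    val results = session.run(inputs)
    try {
      val logits =
        results.get("logits").get().asInstanceOf[OnnxTensor].getFloatBuffer.array()
      println(s"received ${logits.length} logits")
    } finally {
      results.close() // OrtSession.Result is AutoCloseable and frees the native outputs
      tokenTensors.close()
      maskTensors.close()
    }
  }
}

The try/finally mirrors the diff above; note that the diff closes the input tensors before its finally block, while this sketch closes them alongside the results.
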
src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala
@@ -16,11 +16,13 @@

package com.johnsnowlabs.ml.ai

import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.TensorflowWrapper
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

private[johnsnowlabs] class ZeroShotNerClassification(
override val tensorflowWrapper: TensorflowWrapper,
override val tensorflowWrapper: Option[TensorflowWrapper],
override val onnxWrapper: Option[OnnxWrapper],
override val sentenceStartTokenId: Int,
override val sentenceEndTokenId: Int,
override val sentencePadTokenId: Int,
@@ -32,6 +34,7 @@ private[johnsnowlabs] class ZeroShotNerClassification(
vocabulary: Map[String, Int])
extends RoBertaClassification(
tensorflowWrapper,
onnxWrapper,
sentenceStartTokenId,
sentenceEndTokenId,
sentencePadTokenId,
@@ -31,7 +31,6 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.embeddings.BertEmbeddings
import com.johnsnowlabs.nlp.serialization.MapFeature
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.{IntArrayParam, IntParam}
src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala
@@ -17,6 +17,7 @@
package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.ml.ai.{MergeTokenStrategy, RoBertaClassification}
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadTextAsset,
@@ -220,12 +221,14 @@ class LongformerForQuestionAnswering(override val uid: String)
/** @group setParam */
def setModelIfNotSet(
spark: SparkSession,
tensorflowWrapper: TensorflowWrapper): LongformerForQuestionAnswering = {
tensorflowWrapper: Option[TensorflowWrapper],
onnxWrapper: Option[OnnxWrapper]): LongformerForQuestionAnswering = {
if (_model.isEmpty) {
_model = Some(
spark.sparkContext.broadcast(
new RoBertaClassification(
tensorflowWrapper,
onnxWrapper,
sentenceStartTokenId,
sentenceEndTokenId,
padTokenId,
@@ -282,7 +285,7 @@ class LongformerForQuestionAnswering(override val uid: String)
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflowWrapper,
getModelIfNotSet.tensorflowWrapper.get,
"_longformer_classification",
LongformerForQuestionAnswering.tfFile,
configProtoBytes = getConfigProtoBytes)
@@ -321,9 +324,9 @@ trait ReadLongformerForQuestionAnsweringDLModel extends ReadTensorflowModel {
path: String,
spark: SparkSession): Unit = {

val tf =
val tfWrapper =
readTensorflowModel(path, spark, "_longformer_classification_tf", initAllTables = false)
instance.setModelIfNotSet(spark, tf)
instance.setModelIfNotSet(spark, Some(tfWrapper), None)
}

addReader(readModel)
@@ -350,7 +353,7 @@

detectedEngine match {
case TensorFlow.name =>
val (wrapper, signatures) =
val (tfWrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

val _signatures = signatures match {
@@ -363,7 +366,7 @@
*/
annotatorModel
.setSignatures(_signatures)
.setModelIfNotSet(spark, wrapper)
.setModelIfNotSet(spark, Some(tfWrapper), None)

case _ =>
throw new Exception(notSupportedEngineError)
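
The changes to this annotator follow one mechanical rule: every call site that previously took or passed a bare TensorflowWrapper now wraps it in Some(...) and supplies None for the ONNX slot, and writeTensorflowModelV2 unwraps with .get. A toy illustration of that API migration, using a hypothetical Model class rather than the Spark NLP one:

object OptionalBackendMigration {
  // Strings stand in for the wrapper types.
  final class Model {
    private var backends: Option[(Option[String], Option[String])] = None

    // Both backends are now Options, set together and only once.
    def setModelIfNotSet(tf: Option[String], onnx: Option[String]): Model = {
      if (backends.isEmpty) backends = Some((tf, onnx))
      this
    }
    override def toString: String = s"Model(backends=$backends)"
  }

  def main(args: Array[String]): Unit = {
    // TensorFlow read path, as in readModel above: Some(tfWrapper), None.
    println(new Model().setModelIfNotSet(Some("tf_graph"), None))
    // An ONNX import would pass None, Some(onnxWrapper) instead.
    println(new Model().setModelIfNotSet(None, Some("model.onnx")))
  }
}

Note that loadSavedModel for Longformer still matches only TensorFlow.name and throws notSupportedEngineError for anything else; this commit widens the constructor plumbing but does not wire an ONNX import path for this annotator.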