Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ONNX support for RobertaClassification #14024

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
Expand Down Expand Up @@ -342,7 +341,7 @@ private[johnsnowlabs] class AlbertClassification(
tensors.clearSession(outs)
tensors.clearTensors()

(endLogits, startLogits)
(startLogits, endLogits)
}

private def computeLogitsWithOnnx(
Expand Down
177 changes: 123 additions & 54 deletions src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer
import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder}
Expand All @@ -41,7 +43,8 @@ import scala.collection.JavaConverters._
* TF v2 signatures in Spark NLP
*/
private[johnsnowlabs] class RoBertaClassification(
val tensorflowWrapper: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
val sentenceStartTokenId: Int,
val sentenceEndTokenId: Int,
val sentencePadTokenId: Int,
Expand All @@ -56,6 +59,10 @@ private[johnsnowlabs] class RoBertaClassification(

val _tfRoBertaSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

protected val sigmoidThreshold: Float = threshold

Expand Down Expand Up @@ -129,51 +136,13 @@ private[johnsnowlabs] class RoBertaClassification(
}

def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)

// [nb of encoded sentences , maxSentenceLength]
val shape = Array(batch.length.toLong, maxSentenceLength)

batch.zipWithIndex
.foreach { case (sentence, idx) =>
val offset = idx * maxSentenceLength
tokenBuffers.offset(offset).write(sentence)
maskBuffers
.offset(offset)
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
}

val runner = tensorflowWrapper
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)

runner
.feed(
_tfRoBertaSignatures
.getOrElse(ModelSignatureConstants.InputIds.key, "missing_input_id_key"),
tokenTensors)
.feed(
_tfRoBertaSignatures
.getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"),
maskTensors)
.fetch(_tfRoBertaSignatures
.getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key"))

val outs = runner.run().asScala
val rawScores = TensorResources.extractFloats(outs.head)

outs.foreach(_.close())
tensors.clearSession(outs)
tensors.clearTensors()
val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch)
case _ => getRawScoresWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / (batchLength * maxSentenceLength)
val batchScores: Array[Array[Array[Float]]] = rawScores
Expand All @@ -186,10 +155,9 @@ private[johnsnowlabs] class RoBertaClassification(
batchScores
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
private def getRawScoresWithTF(batch: Seq[Array[Int]], maxSentenceLength: Int): Array[Float] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
Expand All @@ -207,7 +175,7 @@ private[johnsnowlabs] class RoBertaClassification(
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
}

val session = tensorflowWrapper.getTFSessionWithSignature(
val session = tensorflowWrapper.get.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
Expand Down Expand Up @@ -235,6 +203,50 @@ private[johnsnowlabs] class RoBertaClassification(
tensors.clearSession(outs)
tensors.clearTensors()

rawScores
}

private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {

// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()

embeddings
} finally if (results != null) results.close()
}
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max

val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch)
case _ => getRawScoresWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / batchLength
val batchScores: Array[Array[Float]] =
rawScores
Expand Down Expand Up @@ -284,15 +296,14 @@ private[johnsnowlabs] class RoBertaClassification(
.toArray)
}

val session = tensorflowWrapper.getTFSessionWithSignature(
val session = tensorflowWrapper.get.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
val runner = session.runner

val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)
val segmentTensors = tensors.createIntBufferTensor(shape, segmentBuffers)

runner
.feed(
Expand Down Expand Up @@ -321,10 +332,29 @@ private[johnsnowlabs] class RoBertaClassification(
}

def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
val tensors = new TensorResources()

val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val (startLogits, endLogits) = detectedEngine match {
case ONNX.name => computeLogitsWithOnnx(batch)
case _ => computeLogitsWithTF(batch, maxSentenceLength)
}

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray

(startScores, endScores)
}

private def computeLogitsWithTF(
batch: Seq[Array[Int]],
maxSentenceLength: Int): (Array[Float], Array[Float]) = {
val batchLength = batch.length
val tensors = new TensorResources()

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
Expand All @@ -341,7 +371,7 @@ private[johnsnowlabs] class RoBertaClassification(
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
}

val session = tensorflowWrapper.getTFSessionWithSignature(
val session = tensorflowWrapper.get.getTFSessionWithSignature(
configProtoBytes = configProtoBytes,
savedSignatures = signatures,
initAllTables = false)
Expand Down Expand Up @@ -371,7 +401,7 @@ private[johnsnowlabs] class RoBertaClassification(
outs.foreach(_.close())
tensors.clearSession(outs)
tensors.clearTensors()

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] = endLogits.grouped(endDim).toArray

Expand All @@ -381,6 +411,45 @@ private[johnsnowlabs] class RoBertaClassification(
(startScores, endScores)
}

private def computeLogitsWithOnnx(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = {
// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

try {
val output = runner.run(inputs)
try {
val startLogits = output
.get("start_logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

val endLogits = output
.get("end_logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

tokenTensors.close()
maskTensors.close()

(startLogits.slice(1, startLogits.length), endLogits.slice(1, endLogits.length))
} finally if (output != null) output.close()
}
}

def findIndexedToken(
tokenizedSentences: Seq[TokenizedSentence],
sentence: (WordpieceTokenizedSentence, Int),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

package com.johnsnowlabs.ml.ai

import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.TensorflowWrapper
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

private[johnsnowlabs] class ZeroShotNerClassification(
override val tensorflowWrapper: TensorflowWrapper,
override val tensorflowWrapper: Option[TensorflowWrapper],
override val onnxWrapper: Option[OnnxWrapper],
override val sentenceStartTokenId: Int,
override val sentenceEndTokenId: Int,
override val sentencePadTokenId: Int,
Expand All @@ -32,6 +34,7 @@ private[johnsnowlabs] class ZeroShotNerClassification(
vocabulary: Map[String, Int])
extends RoBertaClassification(
tensorflowWrapper,
onnxWrapper,
sentenceStartTokenId,
sentenceEndTokenId,
sentencePadTokenId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.embeddings.BertEmbeddings
import com.johnsnowlabs.nlp.serialization.MapFeature
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.{IntArrayParam, IntParam}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.ml.ai.{MergeTokenStrategy, RoBertaClassification}
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadTextAsset,
Expand Down Expand Up @@ -220,12 +221,14 @@ class LongformerForQuestionAnswering(override val uid: String)
/** @group setParam */
def setModelIfNotSet(
spark: SparkSession,
tensorflowWrapper: TensorflowWrapper): LongformerForQuestionAnswering = {
tensorflowWrapper: Option[TensorflowWrapper],
onnxWrapper: Option[OnnxWrapper]): LongformerForQuestionAnswering = {
if (_model.isEmpty) {
_model = Some(
spark.sparkContext.broadcast(
new RoBertaClassification(
tensorflowWrapper,
onnxWrapper,
sentenceStartTokenId,
sentenceEndTokenId,
padTokenId,
Expand Down Expand Up @@ -282,7 +285,7 @@ class LongformerForQuestionAnswering(override val uid: String)
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflowWrapper,
getModelIfNotSet.tensorflowWrapper.get,
"_longformer_classification",
LongformerForQuestionAnswering.tfFile,
configProtoBytes = getConfigProtoBytes)
Expand Down Expand Up @@ -321,9 +324,9 @@ trait ReadLongformerForQuestionAnsweringDLModel extends ReadTensorflowModel {
path: String,
spark: SparkSession): Unit = {

val tf =
val tfWrapper =
readTensorflowModel(path, spark, "_longformer_classification_tf", initAllTables = false)
instance.setModelIfNotSet(spark, tf)
instance.setModelIfNotSet(spark, Some(tfWrapper), None)
}

addReader(readModel)
Expand All @@ -350,7 +353,7 @@ trait ReadLongformerForQuestionAnsweringDLModel extends ReadTensorflowModel {

detectedEngine match {
case TensorFlow.name =>
val (wrapper, signatures) =
val (tfWrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

val _signatures = signatures match {
Expand All @@ -363,7 +366,7 @@ trait ReadLongformerForQuestionAnsweringDLModel extends ReadTensorflowModel {
*/
annotatorModel
.setSignatures(_signatures)
.setModelIfNotSet(spark, wrapper)
.setModelIfNotSet(spark, Some(tfWrapper), None)

case _ =>
throw new Exception(notSupportedEngineError)
Expand Down
Loading
Loading