Skip to content

Commit

Permalink
adding onnx support to DeberatForXXX annotators (#14096)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedlone127 authored Dec 27, 2023
1 parent 05dda07 commit 6a3623f
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 133 deletions.
184 changes: 122 additions & 62 deletions src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper}
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
import org.tensorflow.ndarray.buffer.IntDataBuffer
import org.tensorflow.ndarray.buffer
import org.tensorflow.ndarray.buffer.{IntDataBuffer, LongDataBuffer}

import scala.collection.JavaConverters._

Expand All @@ -37,7 +41,8 @@ import scala.collection.JavaConverters._
* TF v2 signatures in Spark NLP
*/
private[johnsnowlabs] class DeBertaClassification(
val tensorflowWrapper: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
val spp: SentencePieceWrapper,
configProtoBytes: Option[Array[Byte]] = None,
tags: Map[String, Int],
Expand All @@ -48,6 +53,11 @@ private[johnsnowlabs] class DeBertaClassification(

val _tfDeBertaSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name
private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions

// keys representing the input and output tensors of the DeBERTa model
protected val sentencePadTokenId: Int = spp.getSppModel.pieceToId("[PAD]")
Expand Down Expand Up @@ -95,59 +105,13 @@ private[johnsnowlabs] class DeBertaClassification(
}

def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)

// [nb of encoded sentences , maxSentenceLength]
val shape = Array(batch.length.toLong, maxSentenceLength)

batch.zipWithIndex
.foreach { case (sentence, idx) =>
val offset = idx * maxSentenceLength
tokenBuffers.offset(offset).write(sentence)
maskBuffers
.offset(offset)
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0))
}

val runner = tensorflowWrapper
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)
val segmentTensors = tensors.createIntBufferTensor(shape, segmentBuffers)

runner
.feed(
_tfDeBertaSignatures.getOrElse(
ModelSignatureConstants.InputIds.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfDeBertaSignatures
.getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfDeBertaSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIds.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(_tfDeBertaSignatures
.getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key"))

val outs = runner.run().asScala
val rawScores = TensorResources.extractFloats(outs.head)

outs.foreach(_.close())
tensors.clearSession(outs)
tensors.clearTensors()
val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch)
case _ => getRawScoresWithTF(batch)
}

val dim = rawScores.length / (batchLength * maxSentenceLength)
val batchScores: Array[Array[Array[Float]]] = rawScores
Expand All @@ -160,7 +124,7 @@ private[johnsnowlabs] class DeBertaClassification(
batchScores
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
private def getRawScoresWithTF(batch: Seq[Array[Int]]): Array[Float] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
Expand All @@ -183,7 +147,7 @@ private[johnsnowlabs] class DeBertaClassification(
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0))
}

val runner = tensorflowWrapper
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

Expand Down Expand Up @@ -215,6 +179,51 @@ private[johnsnowlabs] class DeBertaClassification(
tensors.clearSession(outs)
tensors.clearTensors()

rawScores
}


private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {

// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()

embeddings
} finally if (results != null) results.close()
}
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {

val batchLength = batch.length

val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch)
case _ => getRawScoresWithTF(batch)
}

val dim = rawScores.length / batchLength
val batchScores: Array[Array[Float]] =
rawScores
Expand All @@ -237,6 +246,25 @@ private[johnsnowlabs] class DeBertaClassification(
activation: String): Array[Array[Float]] = ???

def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
val batchLength = batch.length
val (startLogits, endLogits) = detectedEngine match {
case ONNX.name => computeLogitsWithOnnx(batch)
case _ => computeLogitsWithTF(batch)
}

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray

(startScores, endScores)
}


private def computeLogitsWithTF(batch: Seq[Array[Int]]): (Array[Float], Array[Float])={
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
Expand All @@ -257,7 +285,7 @@ private[johnsnowlabs] class DeBertaClassification(
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
}

val runner = tensorflowWrapper
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

Expand Down Expand Up @@ -286,15 +314,47 @@ private[johnsnowlabs] class DeBertaClassification(
tensors.clearSession(outs)
tensors.clearTensors()

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray
(startLogits, endLogits)
}

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray

(startScores, endScores)
private def computeLogitsWithOnnx(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = {
// [nb of encoded sentences]
val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

try {
val output = runner.run(inputs)
try {
val startLogits = output
.get("start_logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

val endLogits = output
.get("end_logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

tokenTensors.close()
maskTensors.close()

(startLogits, endLogits)
} finally if (output != null) output.close()
}
}

def findIndexedToken(
Expand Down
Loading

0 comments on commit 6a3623f

Please sign in to comment.