Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding onnx support to DeberatForXXX annotators #14096

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 122 additions & 62 deletions src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper}
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
import org.tensorflow.ndarray.buffer.IntDataBuffer
import org.tensorflow.ndarray.buffer
import org.tensorflow.ndarray.buffer.{IntDataBuffer, LongDataBuffer}

import scala.collection.JavaConverters._

Expand All @@ -37,7 +41,8 @@ import scala.collection.JavaConverters._
* TF v2 signatures in Spark NLP
*/
private[johnsnowlabs] class DeBertaClassification(
val tensorflowWrapper: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
val spp: SentencePieceWrapper,
configProtoBytes: Option[Array[Byte]] = None,
tags: Map[String, Int],
Expand All @@ -48,6 +53,11 @@ private[johnsnowlabs] class DeBertaClassification(

val _tfDeBertaSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name
private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions

// keys representing the input and output tensors of the DeBERTa model
protected val sentencePadTokenId: Int = spp.getSppModel.pieceToId("[PAD]")
Expand Down Expand Up @@ -95,59 +105,13 @@ private[johnsnowlabs] class DeBertaClassification(
}

def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)

// [nb of encoded sentences , maxSentenceLength]
val shape = Array(batch.length.toLong, maxSentenceLength)

batch.zipWithIndex
.foreach { case (sentence, idx) =>
val offset = idx * maxSentenceLength
tokenBuffers.offset(offset).write(sentence)
maskBuffers
.offset(offset)
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0))
}

val runner = tensorflowWrapper
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)
val segmentTensors = tensors.createIntBufferTensor(shape, segmentBuffers)

runner
.feed(
_tfDeBertaSignatures.getOrElse(
ModelSignatureConstants.InputIds.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfDeBertaSignatures
.getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfDeBertaSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIds.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(_tfDeBertaSignatures
.getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key"))

val outs = runner.run().asScala
val rawScores = TensorResources.extractFloats(outs.head)

outs.foreach(_.close())
tensors.clearSession(outs)
tensors.clearTensors()
val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch)
case _ => getRawScoresWithTF(batch)
}

val dim = rawScores.length / (batchLength * maxSentenceLength)
val batchScores: Array[Array[Array[Float]]] = rawScores
Expand All @@ -160,7 +124,7 @@ private[johnsnowlabs] class DeBertaClassification(
batchScores
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
private def getRawScoresWithTF(batch: Seq[Array[Int]]): Array[Float] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
Expand All @@ -183,7 +147,7 @@ private[johnsnowlabs] class DeBertaClassification(
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0))
}

val runner = tensorflowWrapper
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

Expand Down Expand Up @@ -215,6 +179,51 @@ private[johnsnowlabs] class DeBertaClassification(
tensors.clearSession(outs)
tensors.clearTensors()

rawScores
}


private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = {

// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()

embeddings
} finally if (results != null) results.close()
}
}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {

val batchLength = batch.length

val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch)
case _ => getRawScoresWithTF(batch)
}

val dim = rawScores.length / batchLength
val batchScores: Array[Array[Float]] =
rawScores
Expand All @@ -237,6 +246,25 @@ private[johnsnowlabs] class DeBertaClassification(
activation: String): Array[Array[Float]] = ???

def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
val batchLength = batch.length
val (startLogits, endLogits) = detectedEngine match {
case ONNX.name => computeLogitsWithOnnx(batch)
case _ => computeLogitsWithTF(batch)
}

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray

(startScores, endScores)
}


private def computeLogitsWithTF(batch: Seq[Array[Int]]): (Array[Float], Array[Float])={
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
Expand All @@ -257,7 +285,7 @@ private[johnsnowlabs] class DeBertaClassification(
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
}

val runner = tensorflowWrapper
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

Expand Down Expand Up @@ -286,15 +314,47 @@ private[johnsnowlabs] class DeBertaClassification(
tensors.clearSession(outs)
tensors.clearTensors()

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray
(startLogits, endLogits)
}

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray

(startScores, endScores)
private def computeLogitsWithOnnx(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = {
// [nb of encoded sentences]
val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val inputs =
Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava

try {
val output = runner.run(inputs)
try {
val startLogits = output
.get("start_logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

val endLogits = output
.get("end_logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()

tokenTensors.close()
maskTensors.close()

(startLogits, endLogits)
} finally if (output != null) output.close()
}
}

def findIndexedToken(
Expand Down
Loading
Loading