Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding ONNX Support to ALBERT Token and Sequence Classification and Question Answering annotators #13956

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 155 additions & 79 deletions src/main/scala/com/johnsnowlabs/ml/ai/AlbertClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@

package com.johnsnowlabs.ml.ai

import ai.onnxruntime.OnnxTensor
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder}
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation}
import org.tensorflow.ndarray.buffer.IntDataBuffer
Expand All @@ -37,7 +41,8 @@ import scala.collection.JavaConverters._
* TF v2 signatures in Spark NLP
*/
private[johnsnowlabs] class AlbertClassification(
val tensorflowWrapper: TensorflowWrapper,
val tensorflowWrapper: Option[TensorflowWrapper],
val onnxWrapper: Option[OnnxWrapper],
val spp: SentencePieceWrapper,
configProtoBytes: Option[Array[Byte]] = None,
tags: Map[String, Int],
Expand All @@ -48,6 +53,10 @@ private[johnsnowlabs] class AlbertClassification(

val _tfAlbertSignatures: Map[String, String] =
signatures.getOrElse(ModelSignatureManager.apply())
val detectedEngine: String =
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

// keys representing the input and output tensors of the ALBERT model
protected val sentencePadTokenId: Int = spp.getSppModel.pieceToId("[pad]")
Expand Down Expand Up @@ -95,59 +104,13 @@ private[johnsnowlabs] class AlbertClassification(
}

def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)

// [nb of encoded sentences , maxSentenceLength]
val shape = Array(batch.length.toLong, maxSentenceLength)

batch.zipWithIndex
.foreach { case (sentence, idx) =>
val offset = idx * maxSentenceLength
tokenBuffers.offset(offset).write(sentence)
maskBuffers
.offset(offset)
.write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1))
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0))
}

val runner = tensorflowWrapper
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers)
val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers)
val segmentTensors = tensors.createIntBufferTensor(shape, segmentBuffers)

runner
.feed(
_tfAlbertSignatures.getOrElse(
ModelSignatureConstants.InputIds.key,
"missing_input_id_key"),
tokenTensors)
.feed(
_tfAlbertSignatures
.getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"),
maskTensors)
.feed(
_tfAlbertSignatures
.getOrElse(ModelSignatureConstants.TokenTypeIds.key, "missing_segment_ids_key"),
segmentTensors)
.fetch(_tfAlbertSignatures
.getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key"))

val outs = runner.run().asScala
val rawScores = TensorResources.extractFloats(outs.head)

outs.foreach(_.close())
tensors.clearSession(outs)
tensors.clearTensors()
val rawScores = detectedEngine match {
case ONNX.name => getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true)
case _ => getRawScoresWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / (batchLength * maxSentenceLength)
val batchScores: Array[Array[Array[Float]]] = rawScores
Expand All @@ -161,17 +124,39 @@ private[johnsnowlabs] class AlbertClassification(
}

/** Scores a batch of encoded sentences for sequence classification.
  *
  * Delegates to the ONNX or TensorFlow backend depending on which wrapper this
  * instance was constructed with, then converts the flat logit array into one
  * probability vector per sentence using the requested activation.
  *
  * @param batch encoded sentences (token ids), one array per sentence
  * @param activation [[ActivationFunction.softmax]] or [[ActivationFunction.sigmoid]];
  *                   anything else falls back to softmax
  * @return one score vector per sentence in the batch
  */
def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
  val batchLength = batch.length
  // The longest encoded sentence defines the padded width of the batch.
  val maxSentenceLength = batch.map(_.length).max

  // Dispatch to whichever engine actually holds the model weights.
  val rawScores =
    if (detectedEngine == ONNX.name)
      getRowScoresWithOnnx(batch, maxSentenceLength, sequence = true)
    else
      getRawScoresWithTF(batch, maxSentenceLength)

  // The flat score array holds batchLength contiguous rows of equal width.
  val scoresPerSentence = rawScores.length / batchLength
  rawScores
    .grouped(scoresPerSentence)
    .map { logits =>
      if (activation == ActivationFunction.sigmoid) calculateSigmoid(logits)
      else calculateSoftmax(logits)
    }
    .toArray
}

private def getRawScoresWithTF(batch: Seq[Array[Int]], maxSentenceLength: Int): Array[Float] = {
val tensors = new TensorResources()

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)

// [nb of encoded sentences , maxSentenceLength]
val shape = Array(batch.length.toLong, maxSentenceLength)
val shape = Array(batchLength.toLong, maxSentenceLength)

batch.zipWithIndex
.foreach { case (sentence, idx) =>
Expand All @@ -183,7 +168,7 @@ private[johnsnowlabs] class AlbertClassification(
segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0))
}

val runner = tensorflowWrapper
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

Expand Down Expand Up @@ -215,19 +200,51 @@ private[johnsnowlabs] class AlbertClassification(
tensors.clearSession(outs)
tensors.clearTensors()

val dim = rawScores.length / batchLength
val batchScores: Array[Array[Float]] =
rawScores
.grouped(dim)
.map(scores =>
activation match {
case ActivationFunction.softmax => calculateSoftmax(scores)
case ActivationFunction.sigmoid => calculateSigmoid(scores)
case _ => calculateSoftmax(scores)
})
.toArray
rawScores
}

batchScores
/** Runs the ONNX session over a batch and returns the flat float scores.
  *
  * @param batch encoded sentences (token ids), one array per sentence; all are
  *              expected to be padded to the same length
  * @param maxSentenceLength padded sentence width, used for the token-type row
  * @param sequence when true reads the "logits" output (classification heads);
  *                 otherwise reads "last_hidden_state"
  * @return the flattened float contents of the selected output tensor
  */
private def getRowScoresWithOnnx(
    batch: Seq[Array[Int]],
    maxSentenceLength: Int,
    sequence: Boolean): Array[Float] = {

  val output = if (sequence) "logits" else "last_hidden_state"

  // [nb of encoded sentences, maxSentenceLength]
  val (runner, env) = onnxWrapper.get.getSession()

  val tokenTensors =
    OnnxTensor.createTensor(env, batch.map(sentence => sentence.map(_.toLong)).toArray)
  // NOTE(review): padding is detected as token id 0 here, while the TF path masks on
  // sentencePadTokenId — confirm the exported ONNX model pads with 0.
  val maskTensors =
    OnnxTensor.createTensor(
      env,
      batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)
  // ALBERT uses a single segment, so token type ids are all zeros.
  val segmentTensors =
    OnnxTensor.createTensor(env, batch.map(_ => Array.fill(maxSentenceLength)(0L)).toArray)

  val inputs =
    Map(
      "input_ids" -> tokenTensors,
      "attention_mask" -> maskTensors,
      "token_type_ids" -> segmentTensors).asJava

  try {
    val results = runner.run(inputs)
    try {
      results
        .get(output)
        .get()
        .asInstanceOf[OnnxTensor]
        .getFloatBuffer
        .array()
    } finally if (results != null) results.close()
  } finally {
    // Close the input tensors on every path, not only on success; they hold
    // native memory that would otherwise leak when run() or extraction throws.
    tokenTensors.close()
    maskTensors.close()
    segmentTensors.close()
  }
}

def tagZeroShotSequence(
Expand All @@ -237,10 +254,29 @@ private[johnsnowlabs] class AlbertClassification(
activation: String): Array[Array[Float]] = ???

def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = {
val tensors = new TensorResources()

val batchLength = batch.length
val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val (startLogits, endLogits) = detectedEngine match {
case ONNX.name => computeLogitsWithOnnx(batch, maxSentenceLength)
case _ => computeLogitsWithTF(batch, maxSentenceLength)
}

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray

(startScores, endScores)
}

private def computeLogitsWithTF(
batch: Seq[Array[Int]],
maxSentenceLength: Int): (Array[Float], Array[Float]) = {
val batchLength = batch.length
val tensors = new TensorResources()

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
Expand Down Expand Up @@ -271,7 +307,7 @@ private[johnsnowlabs] class AlbertClassification(
})
}

val runner = tensorflowWrapper
val runner = tensorflowWrapper.get
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false)
.runner

Expand Down Expand Up @@ -306,15 +342,55 @@ private[johnsnowlabs] class AlbertClassification(
tensors.clearSession(outs)
tensors.clearTensors()

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray
(endLogits, startLogits)
}

(startScores, endScores)
/** Runs the ONNX question-answering head over a batch and extracts span logits.
  *
  * @param batch encoded sentences (token ids), one array per sentence; all are
  *              expected to be padded to the same length
  * @param maxSentenceLength padded sentence width, used for the token-type row
  * @return (startLogits, endLogits), each with the first position dropped
  */
private def computeLogitsWithOnnx(
    batch: Seq[Array[Int]],
    maxSentenceLength: Int): (Array[Float], Array[Float]) = {
  // [nb of encoded sentences, maxSentenceLength]
  val (runner, env) = onnxWrapper.get.getSession()

  val tokenTensors =
    OnnxTensor.createTensor(env, batch.map(sentence => sentence.map(_.toLong)).toArray)
  // NOTE(review): padding is detected as token id 0 here, while the TF path masks on
  // sentencePadTokenId — confirm the exported ONNX model pads with 0.
  val maskTensors =
    OnnxTensor.createTensor(
      env,
      batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)
  // ALBERT uses a single segment, so token type ids are all zeros.
  val segmentTensors =
    OnnxTensor.createTensor(env, batch.map(_ => Array.fill(maxSentenceLength)(0L)).toArray)

  val inputs =
    Map(
      "input_ids" -> tokenTensors,
      "attention_mask" -> maskTensors,
      "token_type_ids" -> segmentTensors).asJava

  try {
    val output = runner.run(inputs)
    try {
      val startLogits = output
        .get("start_logits")
        .get()
        .asInstanceOf[OnnxTensor]
        .getFloatBuffer
        .array()

      val endLogits = output
        .get("end_logits")
        .get()
        .asInstanceOf[OnnxTensor]
        .getFloatBuffer
        .array()

      // Drop the logit of the leading special token; span decoding downstream
      // compensates with engine-specific index offsets.
      (startLogits.slice(1, startLogits.length), endLogits.slice(1, endLogits.length))
    } finally if (output != null) output.close()
  } finally {
    // Close the input tensors on every path, not only on success; they hold
    // native memory that would otherwise leak when run() or extraction throws.
    tokenTensors.close()
    maskTensors.close()
    segmentTensors.close()
  }
}

def findIndexedToken(
Expand Down
10 changes: 8 additions & 2 deletions src/main/scala/com/johnsnowlabs/ml/ai/XXXForClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.johnsnowlabs.ml.ai

import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{ActivationFunction, Annotation, AnnotatorType}

Expand Down Expand Up @@ -244,7 +245,8 @@ private[johnsnowlabs] trait XXXForClassification {
documents: Seq[Annotation],
maxSentenceLength: Int,
caseSensitive: Boolean,
mergeTokenStrategy: String = MergeTokenStrategy.vocab): Seq[Annotation] = {
mergeTokenStrategy: String = MergeTokenStrategy.vocab,
engine: String = TensorFlow.name): Seq[Annotation] = {

val questionAnnot = Seq(documents.head)
val contextAnnot = documents.drop(1)
Expand All @@ -264,9 +266,13 @@ private[johnsnowlabs] trait XXXForClassification {
val startIndex = startScores.zipWithIndex.maxBy(_._1)
val endIndex = endScores.zipWithIndex.maxBy(_._1)

val offsetStartIndex = if (engine == TensorFlow.name) 2 else 1
val offsetEndIndex = if (engine == TensorFlow.name) 1 else 0

val allTokenPieces =
wordPieceTokenizedQuestion.head.tokens ++ wordPieceTokenizedContext.flatMap(x => x.tokens)
val decodedAnswer = allTokenPieces.slice(startIndex._2 - 2, endIndex._2 - 1)
val decodedAnswer =
allTokenPieces.slice(startIndex._2 - offsetStartIndex, endIndex._2 - offsetEndIndex)
val content =
mergeTokenStrategy match {
case MergeTokenStrategy.vocab =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ private[johnsnowlabs] class ZeroShotNerClassification(
documents: Seq[Annotation],
maxSentenceLength: Int,
caseSensitive: Boolean,
mergeTokenStrategy: String): Seq[Annotation] = {
mergeTokenStrategy: String,
engine: String): Seq[Annotation] = {
val questionAnnot = Seq(documents.head)
val contextAnnot = documents.drop(1)

Expand Down
Loading