From 04d6d284c6806e28c20c66fec24df7bc3e3a27b3 Mon Sep 17 00:00:00 2001 From: ahmedlone127 Date: Sun, 17 Dec 2023 22:55:12 +0500 Subject: [PATCH] adding onnx support to DebertaForXXX annotators --- .../ml/ai/DeBertaClassification.scala | 184 ++++++++++++------ .../dl/DeBertaForQuestionAnswering.scala | 80 +++++--- .../dl/DeBertaForSequenceClassification.scala | 73 +++++-- .../dl/DeBertaForTokenClassification.scala | 90 ++++++--- 4 files changed, 294 insertions(+), 133 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala index 5022105f47d588..32e2397e2c74e2 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DeBertaClassification.scala @@ -16,12 +16,16 @@ package com.johnsnowlabs.ml.ai +import ai.onnxruntime.OnnxTensor +import com.johnsnowlabs.ml.onnx.{OnnxSession, OnnxWrapper} import com.johnsnowlabs.ml.tensorflow.sentencepiece.{SentencePieceWrapper, SentencepieceEncoder} import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{ActivationFunction, Annotation} -import org.tensorflow.ndarray.buffer.IntDataBuffer +import org.tensorflow.ndarray.buffer +import org.tensorflow.ndarray.buffer.{IntDataBuffer, LongDataBuffer} import scala.collection.JavaConverters._ @@ -37,7 +41,8 @@ import scala.collection.JavaConverters._ * TF v2 signatures in Spark NLP */ private[johnsnowlabs] class DeBertaClassification( - val tensorflowWrapper: TensorflowWrapper, + val tensorflowWrapper: Option[TensorflowWrapper], + val onnxWrapper: Option[OnnxWrapper], val spp: SentencePieceWrapper, configProtoBytes: Option[Array[Byte]] = None, tags: Map[String, Int], @@ -48,6 +53,11 @@
private[johnsnowlabs] class DeBertaClassification( val _tfDeBertaSignatures: Map[String, String] = signatures.getOrElse(ModelSignatureManager.apply()) + val detectedEngine: String = + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions // keys representing the input and output tensors of the DeBERTa model protected val sentencePadTokenId: Int = spp.getSppModel.pieceToId("[PAD]") @@ -95,59 +105,13 @@ private[johnsnowlabs] class DeBertaClassification( } def tag(batch: Seq[Array[Int]]): Seq[Array[Array[Float]]] = { - val tensors = new TensorResources() - val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max val batchLength = batch.length - val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) - val maskBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) - val segmentBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength) - - // [nb of encoded sentences , maxSentenceLength] - val shape = Array(batch.length.toLong, maxSentenceLength) - - batch.zipWithIndex - .foreach { case (sentence, idx) => - val offset = idx * maxSentenceLength - tokenBuffers.offset(offset).write(sentence) - maskBuffers - .offset(offset) - .write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1)) - segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0)) - } - - val runner = tensorflowWrapper - .getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false) - .runner - - val tokenTensors = tensors.createIntBufferTensor(shape, tokenBuffers) - val maskTensors = tensors.createIntBufferTensor(shape, maskBuffers) - val segmentTensors = tensors.createIntBufferTensor(shape, segmentBuffers) - - runner - .feed( - _tfDeBertaSignatures.getOrElse( - ModelSignatureConstants.InputIds.key, - 
"missing_input_id_key"), - tokenTensors) - .feed( - _tfDeBertaSignatures - .getOrElse(ModelSignatureConstants.AttentionMask.key, "missing_input_mask_key"), - maskTensors) - .feed( - _tfDeBertaSignatures - .getOrElse(ModelSignatureConstants.TokenTypeIds.key, "missing_segment_ids_key"), - segmentTensors) - .fetch(_tfDeBertaSignatures - .getOrElse(ModelSignatureConstants.LogitsOutput.key, "missing_logits_key")) - - val outs = runner.run().asScala - val rawScores = TensorResources.extractFloats(outs.head) - - outs.foreach(_.close()) - tensors.clearSession(outs) - tensors.clearTensors() + val rawScores = detectedEngine match { + case ONNX.name => getRowScoresWithOnnx(batch) + case _ => getRawScoresWithTF(batch) + } val dim = rawScores.length / (batchLength * maxSentenceLength) val batchScores: Array[Array[Array[Float]]] = rawScores @@ -160,7 +124,7 @@ private[johnsnowlabs] class DeBertaClassification( batchScores } - def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + private def getRawScoresWithTF(batch: Seq[Array[Int]]): Array[Float] = { val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -183,7 +147,7 @@ private[johnsnowlabs] class DeBertaClassification( segmentBuffers.offset(offset).write(Array.fill(maxSentenceLength)(0)) } - val runner = tensorflowWrapper + val runner = tensorflowWrapper.get .getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false) .runner @@ -215,6 +179,51 @@ private[johnsnowlabs] class DeBertaClassification( tensors.clearSession(outs) tensors.clearTensors() + rawScores + } + + + private def getRowScoresWithOnnx(batch: Seq[Array[Int]]): Array[Float] = { + + // [nb of encoded sentences , maxSentenceLength] + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) + + val tokenTensors = + OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) + val maskTensors = + OnnxTensor.createTensor( 
+ env, + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + + val inputs = + Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava + + try { + val results = runner.run(inputs) + try { + val embeddings = results + .get("logits") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + tokenTensors.close() + maskTensors.close() + + embeddings + } finally if (results != null) results.close() + } + } + + def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = { + + val batchLength = batch.length + + val rawScores = detectedEngine match { + case ONNX.name => getRowScoresWithOnnx(batch) + case _ => getRawScoresWithTF(batch) + } + val dim = rawScores.length / batchLength val batchScores: Array[Array[Float]] = rawScores @@ -237,6 +246,25 @@ private[johnsnowlabs] class DeBertaClassification( activation: String): Array[Array[Float]] = ??? def tagSpan(batch: Seq[Array[Int]]): (Array[Array[Float]], Array[Array[Float]]) = { + val batchLength = batch.length + val (startLogits, endLogits) = detectedEngine match { + case ONNX.name => computeLogitsWithOnnx(batch) + case _ => computeLogitsWithTF(batch) + } + + val endDim = endLogits.length / batchLength + val endScores: Array[Array[Float]] = + endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray + + val startDim = startLogits.length / batchLength + val startScores: Array[Array[Float]] = + startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray + + (startScores, endScores) + } + + + private def computeLogitsWithTF(batch: Seq[Array[Int]]): (Array[Float], Array[Float])={ val tensors = new TensorResources() val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max @@ -257,7 +285,7 @@ private[johnsnowlabs] class DeBertaClassification( .write(sentence.map(x => if (x == sentencePadTokenId) 0 else 1)) } - val runner = tensorflowWrapper + val runner = tensorflowWrapper.get 
.getTFSessionWithSignature(configProtoBytes = configProtoBytes, initAllTables = false) .runner @@ -286,15 +314,47 @@ private[johnsnowlabs] class DeBertaClassification( tensors.clearSession(outs) tensors.clearTensors() - val endDim = endLogits.length / batchLength - val endScores: Array[Array[Float]] = - endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray + (startLogits, endLogits) + } - val startDim = startLogits.length / batchLength - val startScores: Array[Array[Float]] = - startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray - (startScores, endScores) + private def computeLogitsWithOnnx(batch: Seq[Array[Int]]): (Array[Float], Array[Float]) = { + // [nb of encoded sentences] + val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions) + + val tokenTensors = + OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray) + val maskTensors = + OnnxTensor.createTensor( + env, + batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray) + + val inputs = + Map("input_ids" -> tokenTensors, "attention_mask" -> maskTensors).asJava + + try { + val output = runner.run(inputs) + try { + val startLogits = output + .get("start_logits") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + + val endLogits = output + .get("end_logits") + .get() + .asInstanceOf[OnnxTensor] + .getFloatBuffer + .array() + + tokenTensors.close() + maskTensors.close() + + (startLogits, endLogits) + } finally if (output != null) output.close() + } } def findIndexedToken( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala index 600b85da999a6d..a9a12a21aa7506 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala +++ 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala @@ -17,18 +17,11 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.{DeBertaClassification, MergeTokenStrategy} +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ -import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ - ReadSentencePieceModel, - SentencePieceWrapper, - WriteSentencePieceModel -} -import com.johnsnowlabs.ml.util.LoadExternalModel.{ - loadSentencePieceAsset, - modelSanityCheck, - notSupportedEngineError -} -import com.johnsnowlabs.ml.util.TensorFlow +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ReadSentencePieceModel, SentencePieceWrapper, WriteSentencePieceModel} +import com.johnsnowlabs.ml.util.LoadExternalModel.{loadSentencePieceAsset, modelSanityCheck, notSupportedEngineError} +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -116,6 +109,7 @@ class DeBertaForQuestionAnswering(override val uid: String) extends AnnotatorModel[DeBertaForQuestionAnswering] with HasBatchedAnnotate[DeBertaForQuestionAnswering] with WriteTensorflowModel + with WriteOnnxModel with WriteSentencePieceModel with HasCaseSensitiveProperties with HasEngine { @@ -196,13 +190,15 @@ class DeBertaForQuestionAnswering(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], spp: SentencePieceWrapper): DeBertaForQuestionAnswering = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new DeBertaClassification( tensorflowWrapper, + onnxWrapper, spp, configProtoBytes = getConfigProtoBytes, tags = Map.empty[String, Int], @@ -253,13 +249,26 @@ class 
DeBertaForQuestionAnswering(override val uid: String) override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_deberta_classification", - DeBertaForQuestionAnswering.tfFile, - configProtoBytes = getConfigProtoBytes) + val suffix = "_deberta_classification" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + DeBertaForQuestionAnswering.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + DeBertaForQuestionAnswering.onnxFile) + } + writeSentencePieceModel( path, spark, @@ -292,21 +301,38 @@ trait ReadablePretrainedDeBertaForQAModel trait ReadDeBertaForQuestionAnsweringDLModel extends ReadTensorflowModel + with ReadOnnxModel with ReadSentencePieceModel { this: ParamsAndFeaturesReadable[DeBertaForQuestionAnswering] => override val tfFile: String = "deberta_classification_tensorflow" + override val onnxFile: String = "deberta_classification_onnx" override val sppFile: String = "deberta_spp" def readModel( instance: DeBertaForQuestionAnswering, path: String, spark: SparkSession): Unit = { - - val tf = - readTensorflowModel(path, spark, "_deberta_classification_tf", initAllTables = false) val spp = readSentencePieceModel(path, spark, "_deberta_spp", sppFile) - instance.setModelIfNotSet(spark, tf, spp) + + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = + readTensorflowModel(path, spark, "_deberta_classification_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp) + case ONNX.name => + val onnxWrapper = + readOnnxModel( + path, + spark, + "_deberta_classification_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) + case _ => + throw new
Exception(notSupportedEngineError) + } } addReader(readModel) @@ -324,7 +350,7 @@ trait ReadDeBertaForQuestionAnsweringDLModel detectedEngine match { case TensorFlow.name => - val (wrapper, signatures) = + val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) val _signatures = signatures match { @@ -337,7 +363,11 @@ trait ReadDeBertaForQuestionAnsweringDLModel */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper, spModel) + .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => throw new Exception(notSupportedEngineError) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala index 0f025ebca7c367..328bc5447edeca 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala @@ -17,6 +17,7 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.DeBertaClassification +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ ReadSentencePieceModel, @@ -29,7 +30,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.TensorFlow +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -122,7 +123,7 @@ import org.apache.spark.sql.SparkSession */ 
class DeBertaForSequenceClassification(override val uid: String) extends AnnotatorModel[DeBertaForSequenceClassification] - with HasBatchedAnnotate[DeBertaForSequenceClassification] + with HasBatchedAnnotate[DeBertaForSequenceClassification] with WriteOnnxModel with WriteTensorflowModel with WriteSentencePieceModel with HasCaseSensitiveProperties @@ -238,13 +239,15 @@ class DeBertaForSequenceClassification(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper, + tensorflowWrapper: Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], spp: SentencePieceWrapper): DeBertaForSequenceClassification = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new DeBertaClassification( tensorflowWrapper, + onnxWrapper, spp, configProtoBytes = getConfigProtoBytes, tags = $$(labels), @@ -305,13 +308,26 @@ class DeBertaForSequenceClassification(override val uid: String) override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_deberta_classification", - DeBertaForSequenceClassification.tfFile, - configProtoBytes = getConfigProtoBytes) + val suffix = "_deberta_classification" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + DeBertaForSequenceClassification.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + DeBertaForSequenceClassification.onnxFile) + } + writeSentencePieceModel( path, spark, @@ -342,21 +358,40 @@ trait ReadablePretrainedDeBertaForSequenceModel super.pretrained(name, lang, remoteLoc) } -trait ReadDeBertaForSequenceDLModel extends ReadTensorflowModel with ReadSentencePieceModel { +trait ReadDeBertaForSequenceDLModel + extends ReadTensorflowModel + with 
ReadOnnxModel + with ReadSentencePieceModel { this: ParamsAndFeaturesReadable[DeBertaForSequenceClassification] => override val tfFile: String = "deberta_classification_tensorflow" + override val onnxFile: String = "deberta_classification_onnx" override val sppFile: String = "deberta_spp" def readModel( instance: DeBertaForSequenceClassification, path: String, spark: SparkSession): Unit = { - - val tf = - readTensorflowModel(path, spark, "_deberta_classification_tf", initAllTables = false) val spp = readSentencePieceModel(path, spark, "_deberta_spp", sppFile) - instance.setModelIfNotSet(spark, tf, spp) + + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = + readTensorflowModel(path, spark, "_deberta_classification_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp) + case ONNX.name => + val onnxWrapper = + readOnnxModel( + path, + spark, + "_deberta_classification_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) + case _ => + throw new Exception(notSupportedEngineError) + } } addReader(readModel) @@ -375,7 +410,7 @@ trait ReadDeBertaForSequenceDLModel extends ReadTensorflowModel with ReadSentenc detectedEngine match { case TensorFlow.name => - val (wrapper, signatures) = + val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) val _signatures = signatures match { @@ -388,8 +423,12 @@ trait ReadDeBertaForSequenceDLModel extends ReadTensorflowModel with ReadSentenc */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper, spModel) + .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => throw new Exception(notSupportedEngineError) } diff --git 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala index 81b3fdff7def4b..43f10690e104ea 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala @@ -17,19 +17,11 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.DeBertaClassification +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} import com.johnsnowlabs.ml.tensorflow._ -import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ - ReadSentencePieceModel, - SentencePieceWrapper, - WriteSentencePieceModel -} -import com.johnsnowlabs.ml.util.LoadExternalModel.{ - loadSentencePieceAsset, - loadTextAsset, - modelSanityCheck, - notSupportedEngineError -} -import com.johnsnowlabs.ml.util.TensorFlow +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ReadSentencePieceModel, SentencePieceWrapper, WriteSentencePieceModel} +import com.johnsnowlabs.ml.util.LoadExternalModel.{loadSentencePieceAsset, loadTextAsset, modelSanityCheck, notSupportedEngineError} +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -124,6 +116,7 @@ class DeBertaForTokenClassification(override val uid: String) extends AnnotatorModel[DeBertaForTokenClassification] with HasBatchedAnnotate[DeBertaForTokenClassification] with WriteTensorflowModel + with WriteOnnxModel with WriteSentencePieceModel with HasCaseSensitiveProperties with HasEngine { @@ -218,13 +211,15 @@ class DeBertaForTokenClassification(override val uid: String) /** @group setParam */ def setModelIfNotSet( spark: SparkSession, - tensorflowWrapper: TensorflowWrapper, + tensorflowWrapper: 
Option[TensorflowWrapper], + onnxWrapper: Option[OnnxWrapper], spp: SentencePieceWrapper): DeBertaForTokenClassification = { if (_model.isEmpty) { _model = Some( spark.sparkContext.broadcast( new DeBertaClassification( tensorflowWrapper, + onnxWrapper, spp, configProtoBytes = getConfigProtoBytes, tags = $$(labels), @@ -277,13 +272,26 @@ class DeBertaForTokenClassification(override val uid: String) override def onWrite(path: String, spark: SparkSession): Unit = { super.onWrite(path, spark) - writeTensorflowModelV2( - path, - spark, - getModelIfNotSet.tensorflowWrapper, - "_deberta_classification", - DeBertaForTokenClassification.tfFile, - configProtoBytes = getConfigProtoBytes) + val suffix = "_deberta_classification" + + getEngine match { + case TensorFlow.name => + writeTensorflowModelV2( + path, + spark, + getModelIfNotSet.tensorflowWrapper.get, + suffix, + DeBertaForTokenClassification.tfFile, + configProtoBytes = getConfigProtoBytes) + case ONNX.name => + writeOnnxModel( + path, + spark, + getModelIfNotSet.onnxWrapper.get, + suffix, + DeBertaForTokenClassification.onnxFile) + } + writeSentencePieceModel( path, spark, @@ -313,20 +321,40 @@ trait ReadablePretrainedDeBertaForTokenModel remoteLoc: String): DeBertaForTokenClassification = super.pretrained(name, lang, remoteLoc) } -trait ReadDeBertaForTokenDLModel extends ReadTensorflowModel with ReadSentencePieceModel { +trait ReadDeBertaForTokenDLModel + extends ReadTensorflowModel + with ReadOnnxModel + with ReadSentencePieceModel { this: ParamsAndFeaturesReadable[DeBertaForTokenClassification] => override val tfFile: String = "deberta_classification_tensorflow" + override val onnxFile: String = "deberta_classification_onnx" override val sppFile: String = "deberta_spp" def readModel( - instance: DeBertaForTokenClassification, - path: String, - spark: SparkSession): Unit = { - - val tf = readTensorflowModel(path, spark, "_deberta_classification_tf", initAllTables = false) + instance: DeBertaForTokenClassification, 
+ path: String, + spark: SparkSession): Unit = { val spp = readSentencePieceModel(path, spark, "_deberta_spp", sppFile) - instance.setModelIfNotSet(spark, tf, spp) + + instance.getEngine match { + case TensorFlow.name => + val tfWrapper = + readTensorflowModel(path, spark, "_deberta_classification_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp) + case ONNX.name => + val onnxWrapper = + readOnnxModel( + path, + spark, + "_deberta_classification_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) + case _ => + throw new Exception(notSupportedEngineError) + } } addReader(readModel) @@ -344,7 +372,7 @@ trait ReadDeBertaForTokenDLModel extends ReadTensorflowModel with ReadSentencePi detectedEngine match { case TensorFlow.name => - val (wrapper, signatures) = + val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) val _signatures = signatures match { @@ -357,7 +385,11 @@ trait ReadDeBertaForTokenDLModel extends ReadTensorflowModel with ReadSentencePi */ annotatorModel .setSignatures(_signatures) - .setModelIfNotSet(spark, wrapper, spModel) + .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) + case ONNX.name => + val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) + annotatorModel + .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) case _ => throw new Exception(notSupportedEngineError)