From f6b406fd6018c765c36445b18890f037c517ecb6 Mon Sep 17 00:00:00 2001 From: Danilo Burbano Date: Mon, 19 Feb 2024 16:36:04 -0500 Subject: [PATCH] [SPARKNLP-994] Adding changes required to load ONNX ZeroShotNerClassification and fix predictions issue --- .../2023-02-08-zero_shot_ner_roberta_en.md | 6 +-- .../ml/ai/RoBertaClassification.scala | 4 +- .../ml/ai/ZeroShotNerClassification.scala | 3 +- .../annotators/ner/dl/ZeroShotNerModel.scala | 40 ++++++++++++++----- 4 files changed, 38 insertions(+), 15 deletions(-) diff --git a/docs/_posts/maziyarpanahi/2023-02-08-zero_shot_ner_roberta_en.md b/docs/_posts/maziyarpanahi/2023-02-08-zero_shot_ner_roberta_en.md index 5a07d61ac6ad47..ce1e1995b037f8 100644 --- a/docs/_posts/maziyarpanahi/2023-02-08-zero_shot_ner_roberta_en.md +++ b/docs/_posts/maziyarpanahi/2023-02-08-zero_shot_ner_roberta_en.md @@ -101,15 +101,15 @@ val ner_converter = new NerConverter() .setInputCols(Array("sentence", "token", "zero_shot_ner")) .setOutputCol("ner_chunk") -val pipeline = new .setStages(Array( +val pipeline = new Pipeline().setStages(Array( documentAssembler, sentenceDetector, tokenizer, zero_shot_ner, ner_converter)) -val data = Seq(Array("Hellen works in London, Paris and Berlin. My name is Clara, I live in New York and Hellen lives in Paris.", - "John is a man who works in London, London and London.")toDS().toDF("text") +val data = Seq("Hellen works in London, Paris and Berlin. My name is Clara, I live in New York and Hellen lives in Paris.", + "John is a man who works in London, London and London.").toDS.toDF("text") val result = pipeline.fit(data).transform(data) ``` diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala index 4296c8bcf5a542..fda62f52bee6e1 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/RoBertaClassification.scala @@ -352,11 +352,11 @@ private[johnsnowlabs] class RoBertaClassification( val endDim = endLogits.length / batchLength val endScores: Array[Array[Float]] = - endLogits.grouped(endDim).toArray + endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray val startDim = startLogits.length / batchLength val startScores: Array[Array[Float]] = - startLogits.grouped(startDim).toArray + startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray (startScores, endScores) } diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala b/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala index f1cc6cb7b9a8c1..638138223176d1 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/ZeroShotNerClassification.scala @@ -88,7 +88,7 @@ private[johnsnowlabs] class ZeroShotNerClassification( val allTokenPieces = wordPieceTokenizedQuestion.head.tokens ++ wordPieceTokenizedContext.flatMap(x => x.tokens) - val decodedAnswer = allTokenPieces.slice(startIndex._2 - 2, endIndex._2 - 1) + val decodedAnswer = allTokenPieces.slice(startIndex._2 - 3, endIndex._2 - 2) // Check if the answer span starts at the CLS symbol 0 - if so return empty string val content = if (startIndex._2 > 0) @@ -141,4 +141,5 @@ private[johnsnowlabs] class ZeroShotNerClassification( } } + } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/ZeroShotNerModel.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/ZeroShotNerModel.scala index 5c26a2615fbb37..51878181758069 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/ZeroShotNerModel.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/dl/ZeroShotNerModel.scala @@ -17,8 +17,10 @@ package com.johnsnowlabs.nlp.annotators.ner.dl import com.johnsnowlabs.ml.ai.{RoBertaClassification, ZeroShotNerClassification} -import com.johnsnowlabs.ml.onnx.OnnxWrapper +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel} import com.johnsnowlabs.ml.tensorflow.{ReadTensorflowModel, TensorflowWrapper} +import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN} import com.johnsnowlabs.nlp.annotator.RoBertaForQuestionAnswering import com.johnsnowlabs.nlp.pretrained.ResourceDownloader @@ -448,19 +450,34 @@ trait ReadablePretrainedZeroShotNer } } -trait ReadZeroShotNerDLModel extends ReadTensorflowModel { +trait ReadZeroShotNerDLModel extends ReadTensorflowModel with ReadOnnxModel { this: ParamsAndFeaturesReadable[ZeroShotNerModel] => override val tfFile: String = "roberta_classification_tensorflow" + override val onnxFile: String = "roberta_classification_onnx" - def readTensorflow(instance: ZeroShotNerModel, path: String, spark: SparkSession): Unit = { - - val tfWrapper = - readTensorflowModel(path, spark, "_roberta_classification_tf", initAllTables = false) - instance.setModelIfNotSet(spark, Some(tfWrapper), None) + def readModel(instance: ZeroShotNerModel, path: String, spark: SparkSession): Unit = { + instance.getEngine match { + case TensorFlow.name => { + val tfWrapper = readTensorflowModel(path, spark, "_roberta_classification_tf", initAllTables = false) + instance.setModelIfNotSet(spark, Some(tfWrapper), None) + } + case ONNX.name => { + val onnxWrapper = readOnnxModel( + path, + spark, + "_roberta_classification_onnx", + zipped = true, + useBundle = false, + None) + instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) + } + case _ => + throw new Exception(notSupportedEngineError) + } } - addReader(readTensorflow) + addReader(readModel) } object ZeroShotNerModel extends ReadablePretrainedZeroShotNer with ReadZeroShotNerDLModel { @@ -487,7 +504,12 @@ object ZeroShotNerModel extends ReadablePretrainedZeroShotNer with ReadZeroShotN newModel.setSignatures( model.signatures.get.getOrElse(throw new RuntimeException("Signatures not set"))) - newModel.setModelIfNotSet(spark, model.getModelIfNotSet.tensorflowWrapper, None) + model.getEngine match { + case TensorFlow.name => + newModel.setModelIfNotSet(spark, model.getModelIfNotSet.tensorflowWrapper, None) + case ONNX.name => + newModel.setModelIfNotSet(spark, None, model.getModelIfNotSet.onnxWrapper) + } model .extractParamMap()