Skip to content

Commit

Permalink
[SPARKNLP-994] Adding changes required to load ONNX ZeroShotNerClassification and fix predictions issue (#14186)
Browse files Browse the repository at this point in the history
  • Loading branch information
danilojsl authored Mar 3, 2024
1 parent 8877454 commit 901c884
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ val ner_converter = new NerConverter()
.setInputCols(Array("sentence", "token", "zero_shot_ner"))
.setOutputCol("ner_chunk")

val pipeline = new .setStages(Array(
val pipeline = new Pipeline().setStages(Array(
documentAssembler,
sentenceDetector,
tokenizer,
zero_shot_ner,
ner_converter))

val data = Seq(Array("Hellen works in London, Paris and Berlin. My name is Clara, I live in New York and Hellen lives in Paris.",
"John is a man who works in London, London and London.")toDS().toDF("text")
val data = Seq("Hellen works in London, Paris and Berlin. My name is Clara, I live in New York and Hellen lives in Paris.",
"John is a man who works in London, London and London.").toDS.toDF("text")

val result = pipeline.fit(data).transform(data)
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,11 @@ private[johnsnowlabs] class RoBertaClassification(

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
endLogits.grouped(endDim).toArray
endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
startLogits.grouped(startDim).toArray
startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray

(startScores, endScores)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private[johnsnowlabs] class ZeroShotNerClassification(

val allTokenPieces =
wordPieceTokenizedQuestion.head.tokens ++ wordPieceTokenizedContext.flatMap(x => x.tokens)
val decodedAnswer = allTokenPieces.slice(startIndex._2 - 2, endIndex._2 - 1)
val decodedAnswer = allTokenPieces.slice(startIndex._2 - 3, endIndex._2 - 2)
// Check if the answer span starts at the CLS symbol 0 - if so return empty string
val content =
if (startIndex._2 > 0)
Expand Down Expand Up @@ -141,4 +141,5 @@ private[johnsnowlabs] class ZeroShotNerClassification(
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package com.johnsnowlabs.nlp.annotators.ner.dl

import com.johnsnowlabs.ml.ai.{RoBertaClassification, ZeroShotNerClassification}
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel}
import com.johnsnowlabs.ml.tensorflow.{ReadTensorflowModel, TensorflowWrapper}
import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN}
import com.johnsnowlabs.nlp.annotator.RoBertaForQuestionAnswering
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
Expand Down Expand Up @@ -448,19 +450,34 @@ trait ReadablePretrainedZeroShotNer
}
}

trait ReadZeroShotNerDLModel extends ReadTensorflowModel {
trait ReadZeroShotNerDLModel extends ReadTensorflowModel with ReadOnnxModel {
this: ParamsAndFeaturesReadable[ZeroShotNerModel] =>

override val tfFile: String = "roberta_classification_tensorflow"
override val onnxFile: String = "roberta_classification_onnx"

def readTensorflow(instance: ZeroShotNerModel, path: String, spark: SparkSession): Unit = {

val tfWrapper =
readTensorflowModel(path, spark, "_roberta_classification_tf", initAllTables = false)
instance.setModelIfNotSet(spark, Some(tfWrapper), None)
/** Loads the serialized model that matches the engine reported by `instance` and hands it to
  * `instance.setModelIfNotSet`.
  *
  *   - TensorFlow engine: reads via `readTensorflowModel` using the
  *     `_roberta_classification_tf` suffix and sets it as the TensorFlow wrapper.
  *   - ONNX engine: reads via `readOnnxModel` using the `_roberta_classification_onnx`
  *     suffix and sets it as the ONNX wrapper.
  *
  * @param instance
  *   model instance whose engine decides which reader is used and which receives the wrapper
  * @param path
  *   location the serialized model is read from
  * @param spark
  *   active Spark session, forwarded to the reader and to `setModelIfNotSet`
  * @throws Exception
  *   with `notSupportedEngineError` when the engine is neither TensorFlow nor ONNX
  */
def readModel(instance: ZeroShotNerModel, path: String, spark: SparkSession): Unit = {
  val engine = instance.getEngine
  engine match {
    case TensorFlow.name =>
      val tensorflowModel =
        readTensorflowModel(path, spark, "_roberta_classification_tf", initAllTables = false)
      // Only the TensorFlow slot is populated; the ONNX slot stays empty.
      instance.setModelIfNotSet(spark, Some(tensorflowModel), None)

    case ONNX.name =>
      val onnxModel = readOnnxModel(
        path,
        spark,
        "_roberta_classification_onnx",
        zipped = true,
        useBundle = false,
        None)
      // Only the ONNX slot is populated; the TensorFlow slot stays empty.
      instance.setModelIfNotSet(spark, None, Some(onnxModel))

    case _ =>
      throw new Exception(notSupportedEngineError)
  }
}

addReader(readTensorflow)
addReader(readModel)
}
object ZeroShotNerModel extends ReadablePretrainedZeroShotNer with ReadZeroShotNerDLModel {

Expand All @@ -487,7 +504,12 @@ object ZeroShotNerModel extends ReadablePretrainedZeroShotNer with ReadZeroShotN
newModel.setSignatures(
model.signatures.get.getOrElse(throw new RuntimeException("Signatures not set")))

newModel.setModelIfNotSet(spark, model.getModelIfNotSet.tensorflowWrapper, None)
model.getEngine match {
case TensorFlow.name =>
newModel.setModelIfNotSet(spark, model.getModelIfNotSet.tensorflowWrapper, None)
case ONNX.name =>
newModel.setModelIfNotSet(spark, None, model.getModelIfNotSet.onnxWrapper)
}

model
.extractParamMap()
Expand Down

0 comments on commit 901c884

Please sign in to comment.