Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARKNLP-994] Solves ZeroShotNerClassification Issue #14186

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ val ner_converter = new NerConverter()
.setInputCols(Array("sentence", "token", "zero_shot_ner"))
.setOutputCol("ner_chunk")

- val pipeline = new .setStages(Array(
+ val pipeline = new Pipeline().setStages(Array(
documentAssembler,
sentenceDetector,
tokenizer,
zero_shot_ner,
ner_converter))

- val data = Seq(Array("Hellen works in London, Paris and Berlin. My name is Clara, I live in New York and Hellen lives in Paris.",
- "John is a man who works in London, London and London.")toDS().toDF("text")
+ val data = Seq("Hellen works in London, Paris and Berlin. My name is Clara, I live in New York and Hellen lives in Paris.",
+ "John is a man who works in London, London and London.").toDS.toDF("text")

val result = pipeline.fit(data).transform(data)
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,11 +352,11 @@ private[johnsnowlabs] class RoBertaClassification(

val endDim = endLogits.length / batchLength
val endScores: Array[Array[Float]] =
- endLogits.grouped(endDim).toArray
+ endLogits.grouped(endDim).map(scores => calculateSoftmax(scores)).toArray

val startDim = startLogits.length / batchLength
val startScores: Array[Array[Float]] =
- startLogits.grouped(startDim).toArray
+ startLogits.grouped(startDim).map(scores => calculateSoftmax(scores)).toArray

(startScores, endScores)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private[johnsnowlabs] class ZeroShotNerClassification(

val allTokenPieces =
wordPieceTokenizedQuestion.head.tokens ++ wordPieceTokenizedContext.flatMap(x => x.tokens)
- val decodedAnswer = allTokenPieces.slice(startIndex._2 - 2, endIndex._2 - 1)
+ val decodedAnswer = allTokenPieces.slice(startIndex._2 - 3, endIndex._2 - 2)
// Check if the answer span starts at the CLS symbol 0 - if so return empty string
val content =
if (startIndex._2 > 0)
Expand Down Expand Up @@ -141,4 +141,5 @@ private[johnsnowlabs] class ZeroShotNerClassification(
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package com.johnsnowlabs.nlp.annotators.ner.dl

import com.johnsnowlabs.ml.ai.{RoBertaClassification, ZeroShotNerClassification}
- import com.johnsnowlabs.ml.onnx.OnnxWrapper
+ import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel}
import com.johnsnowlabs.ml.tensorflow.{ReadTensorflowModel, TensorflowWrapper}
+ import com.johnsnowlabs.ml.util.LoadExternalModel.notSupportedEngineError
+ import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, NAMED_ENTITY, TOKEN}
import com.johnsnowlabs.nlp.annotator.RoBertaForQuestionAnswering
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
Expand Down Expand Up @@ -448,19 +450,34 @@ trait ReadablePretrainedZeroShotNer
}
}

- trait ReadZeroShotNerDLModel extends ReadTensorflowModel {
+ trait ReadZeroShotNerDLModel extends ReadTensorflowModel with ReadOnnxModel {
this: ParamsAndFeaturesReadable[ZeroShotNerModel] =>

override val tfFile: String = "roberta_classification_tensorflow"
+ override val onnxFile: String = "roberta_classification_onnx"

- def readTensorflow(instance: ZeroShotNerModel, path: String, spark: SparkSession): Unit = {
-
- val tfWrapper =
- readTensorflowModel(path, spark, "_roberta_classification_tf", initAllTables = false)
- instance.setModelIfNotSet(spark, Some(tfWrapper), None)
+ def readModel(instance: ZeroShotNerModel, path: String, spark: SparkSession): Unit = {
+ instance.getEngine match {
+ case TensorFlow.name => {
+ val tfWrapper = readTensorflowModel(path, spark, "_roberta_classification_tf", initAllTables = false)
+ instance.setModelIfNotSet(spark, Some(tfWrapper), None)
+ }
+ case ONNX.name => {
+ val onnxWrapper = readOnnxModel(
+ path,
+ spark,
+ "_roberta_classification_onnx",
+ zipped = true,
+ useBundle = false,
+ None)
+ instance.setModelIfNotSet(spark, None, Some(onnxWrapper))
+ }
+ case _ =>
+ throw new Exception(notSupportedEngineError)
+ }
}

- addReader(readTensorflow)
+ addReader(readModel)
}
object ZeroShotNerModel extends ReadablePretrainedZeroShotNer with ReadZeroShotNerDLModel {

Expand All @@ -487,7 +504,12 @@ object ZeroShotNerModel extends ReadablePretrainedZeroShotNer with ReadZeroShotN
newModel.setSignatures(
model.signatures.get.getOrElse(throw new RuntimeException("Signatures not set")))

- newModel.setModelIfNotSet(spark, model.getModelIfNotSet.tensorflowWrapper, None)
+ model.getEngine match {
+ case TensorFlow.name =>
+ newModel.setModelIfNotSet(spark, model.getModelIfNotSet.tensorflowWrapper, None)
+ case ONNX.name =>
+ newModel.setModelIfNotSet(spark, None, model.getModelIfNotSet.onnxWrapper)
+ }

model
.extractParamMap()
Expand Down
Loading