Skip to content

Commit

Permalink
adding import notebook + changing default model + adding onnx support (
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedlone127 authored Feb 6, 2024
1 parent db55524 commit 37c4df2
Show file tree
Hide file tree
Showing 6 changed files with 2,653 additions and 26 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class BertForZeroShotClassification(AnnotatorModel,
... .setInputCols(["token", "document"]) \\
... .setOutputCol("label")
The default model is ``"bert_base_cased_zero_shot_classifier_xnli"``, if no name is
The default model is ``"bert_zero_shot_classifier_mnli"``, if no name is
provided.
For available pretrained models please see the `Models Hub
Expand Down Expand Up @@ -189,14 +189,14 @@ def loadSavedModel(folder, spark_session):
return BertForZeroShotClassification(java_model=jModel)

@staticmethod
def pretrained(name="bert_base_cased_zero_shot_classifier_xnli", lang="en", remote_loc=None):
def pretrained(name="bert_zero_shot_classifier_mnli", lang="xx", remote_loc=None):
"""Downloads and loads a pretrained model.
Parameters
----------
name : str, optional
Name of the pretrained model, by default
"bert_base_cased_zero_shot_classifier_xnli"
"bert_zero_shot_classifier_mnli"
lang : str, optional
Language of the pretrained model, by default "xx"
remote_loc : str, optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def setUp(self):
.toDF("text")

self.tested_annotator = BertForZeroShotClassification \
.pretrained("bert_base_cased_zero_shot_classifier_xnli") \
.pretrained() \
.setInputCols(["document", "token"]) \
.setOutputCol("class") \
.setCandidateLabels(["urgent", "mobile", "travel", "movie", "music", "sport", "weather", "technology"])
Expand Down
78 changes: 72 additions & 6 deletions src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ private[johnsnowlabs] class BertClassification(
embeddings
} finally if (results != null) results.close()
}

}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
Expand Down Expand Up @@ -284,14 +285,62 @@ private[johnsnowlabs] class BertClassification(
batchScores
}

def tagZeroShotSequence(
def computeZeroShotLogitsWithONNX(
batch: Seq[Array[Int]],
entailmentId: Int,
contradictionId: Int,
activation: String): Array[Array[Float]] = {
val tensors = new TensorResources()
maxSentenceLength: Int): Array[Float] = {

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val segmentTensors =
OnnxTensor.createTensor(
env,
batch
.map(sentence =>
sentence.indices
.map(i =>
if (i < sentence.indexOf(sentenceEndTokenId)) 0L
else if (i == sentence.indexOf(sentenceEndTokenId)) 1L
else 1L)
.toArray)
.toArray)

val inputs =
Map(
"input_ids" -> tokenTensors,
"attention_mask" -> maskTensors,
"token_type_ids" -> segmentTensors).asJava

try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()

embeddings
} finally if (results != null) results.close()
}

}

def computeZeroShotLogitsWithTF(
batch: Seq[Array[Int]],
maxSentenceLength: Int): Array[Float] = {

val tensors = new TensorResources()
val batchLength = batch.length

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
Expand Down Expand Up @@ -350,6 +399,23 @@ private[johnsnowlabs] class BertClassification(
tensors.clearSession(outs)
tensors.clearTensors()

rawScores
}

def tagZeroShotSequence(
batch: Seq[Array[Int]],
entailmentId: Int,
contradictionId: Int,
activation: String): Array[Array[Float]] = {

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length

val rawScores = detectedEngine match {
case ONNX.name => computeZeroShotLogitsWithONNX(batch, maxSentenceLength)
case _ => computeZeroShotLogitsWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / batchLength
rawScores
.grouped(dim)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.ml.ai.BertClassification
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadTextAsset,
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.MapFeature
Expand All @@ -50,7 +50,7 @@ import org.apache.spark.sql.SparkSession
* .setInputCols("token", "document")
* .setOutputCol("label")
* }}}
* The default model is `"bert_base_cased_zero_shot_classifier_xnli"`, if no name is provided.
* The default model is `"bert_zero_shot_classifier_mnli"`, if no name is provided.
*
* For available pretrained models please see the
* [[https://sparknlp.org/models?task=Text+Classification Models Hub]].
Expand Down Expand Up @@ -124,6 +124,7 @@ class BertForZeroShotClassification(override val uid: String)
extends AnnotatorModel[BertForZeroShotClassification]
with HasBatchedAnnotate[BertForZeroShotClassification]
with WriteTensorflowModel
with WriteOnnxModel
with HasCaseSensitiveProperties
with HasClassifierActivationProperties
with HasEngine
Expand Down Expand Up @@ -338,21 +339,34 @@ class BertForZeroShotClassification(override val uid: String)

override def onWrite(path: String, spark: SparkSession): Unit = {
super.onWrite(path, spark)
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflowWrapper.get,
"_bert_classification",
BertForZeroShotClassification.tfFile,
configProtoBytes = getConfigProtoBytes)
val suffix = "_bert_classification"

getEngine match {
case TensorFlow.name =>
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflowWrapper.get,
suffix,
BertForZeroShotClassification.tfFile,
configProtoBytes = getConfigProtoBytes)
case ONNX.name =>
writeOnnxModel(
path,
spark,
getModelIfNotSet.onnxWrapper.get,
suffix,
BertForZeroShotClassification.onnxFile)
}
}

}

trait ReadablePretrainedBertForZeroShotModel
extends ParamsAndFeaturesReadable[BertForZeroShotClassification]
with HasPretrained[BertForZeroShotClassification] {
override val defaultModelName: Some[String] = Some("bert_base_cased_zero_shot_classifier_xnli")
override val defaultModelName: Some[String] = Some("bert_zero_shot_classifier_mnli")
override val defaultLang: String = "xx"

/** Java compliant-overrides */
override def pretrained(): BertForZeroShotClassification = super.pretrained()
Expand All @@ -368,19 +382,29 @@ trait ReadablePretrainedBertForZeroShotModel
remoteLoc: String): BertForZeroShotClassification = super.pretrained(name, lang, remoteLoc)
}

trait ReadBertForZeroShotDLModel extends ReadTensorflowModel {
trait ReadBertForZeroShotDLModel extends ReadTensorflowModel with ReadOnnxModel {
this: ParamsAndFeaturesReadable[BertForZeroShotClassification] =>

override val tfFile: String = "bert_classification_tensorflow"
override val onnxFile: String = "bert_classification_onnx"

def readModel(
instance: BertForZeroShotClassification,
path: String,
spark: SparkSession): Unit = {

val tensorFlow =
readTensorflowModel(path, spark, "_bert_classification_tf", initAllTables = false)
instance.setModelIfNotSet(spark, Some(tensorFlow), None)
instance.getEngine match {
case TensorFlow.name =>
val tensorFlow =
readTensorflowModel(path, spark, "_bert_classification_tf", initAllTables = false)
instance.setModelIfNotSet(spark, Some(tensorFlow), None)
case ONNX.name =>
val onnxWrapper =
readOnnxModel(path, spark, "_bert_classification_onnx")
instance.setModelIfNotSet(spark, None, Some(onnxWrapper))
case _ =>
throw new Exception(notSupportedEngineError)
}
}

addReader(readModel)
Expand Down Expand Up @@ -437,6 +461,11 @@ trait ReadBertForZeroShotDLModel extends ReadTensorflowModel {
.setSignatures(_signatures)
.setModelIfNotSet(spark, Some(wrapper), None)

case ONNX.name =>
val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true)
annotatorModel
.setModelIfNotSet(spark, None, Some(onnxWrapper))

case _ =>
throw new Exception(notSupportedEngineError)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class BertForZeroShotClassificationTestSpec extends AnyFlatSpec {
val candidateLabels =
Array("urgent", "mobile", "travel", "movie", "music", "sport", "weather", "technology")

"BertForSBertForZeroShotClassification" should "correctly load custom model with extracted signatures" taggedAs SlowTest in {
"BertForZeroShotClassification" should "correctly load custom model with extracted signatures" taggedAs SlowTest in {

val ddd = Seq(
"I have a problem with my iphone that needs to be resolved asap!!",
Expand Down

0 comments on commit 37c4df2

Please sign in to comment.