Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding import notebook + changing default model + adding onnx support #14158

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class BertForZeroShotClassification(AnnotatorModel,
... .setInputCols(["token", "document"]) \\
... .setOutputCol("label")

The default model is ``"bert_base_cased_zero_shot_classifier_xnli"``, if no name is
The default model is ``"bert_zero_shot_classifier_mnli"``, if no name is
provided.

For available pretrained models please see the `Models Hub
Expand Down Expand Up @@ -189,14 +189,14 @@ def loadSavedModel(folder, spark_session):
return BertForZeroShotClassification(java_model=jModel)

@staticmethod
def pretrained(name="bert_base_cased_zero_shot_classifier_xnli", lang="en", remote_loc=None):
def pretrained(name="bert_zero_shot_classifier_mnli", lang="xx", remote_loc=None):
"""Downloads and loads a pretrained model.

Parameters
----------
name : str, optional
Name of the pretrained model, by default
"bert_base_cased_zero_shot_classifier_xnli"
"bert_zero_shot_classifier_mnli"
lang : str, optional
Language of the pretrained model, by default "en"
remote_loc : str, optional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def setUp(self):
.toDF("text")

self.tested_annotator = BertForZeroShotClassification \
.pretrained("bert_base_cased_zero_shot_classifier_xnli") \
.pretrained() \
.setInputCols(["document", "token"]) \
.setOutputCol("class") \
.setCandidateLabels(["urgent", "mobile", "travel", "movie", "music", "sport", "weather", "technology"])
Expand Down
78 changes: 72 additions & 6 deletions src/main/scala/com/johnsnowlabs/ml/ai/BertClassification.scala
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ private[johnsnowlabs] class BertClassification(
embeddings
} finally if (results != null) results.close()
}

}

def tagSequence(batch: Seq[Array[Int]], activation: String): Array[Array[Float]] = {
Expand Down Expand Up @@ -284,14 +285,62 @@ private[johnsnowlabs] class BertClassification(
batchScores
}

def tagZeroShotSequence(
def computeZeroShotLogitsWithONNX(
batch: Seq[Array[Int]],
entailmentId: Int,
contradictionId: Int,
activation: String): Array[Array[Float]] = {
val tensors = new TensorResources()
maxSentenceLength: Int): Array[Float] = {

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val (runner, env) = onnxWrapper.get.getSession(onnxSessionOptions)

val tokenTensors =
OnnxTensor.createTensor(env, batch.map(x => x.map(x => x.toLong)).toArray)
val maskTensors =
OnnxTensor.createTensor(
env,
batch.map(sentence => sentence.map(x => if (x == 0L) 0L else 1L)).toArray)

val segmentTensors =
OnnxTensor.createTensor(
env,
batch
.map(sentence =>
sentence.indices
.map(i =>
if (i < sentence.indexOf(sentenceEndTokenId)) 0L
else if (i == sentence.indexOf(sentenceEndTokenId)) 1L
else 1L)
.toArray)
.toArray)

val inputs =
Map(
"input_ids" -> tokenTensors,
"attention_mask" -> maskTensors,
"token_type_ids" -> segmentTensors).asJava

try {
val results = runner.run(inputs)
try {
val embeddings = results
.get("logits")
.get()
.asInstanceOf[OnnxTensor]
.getFloatBuffer
.array()
tokenTensors.close()
maskTensors.close()
segmentTensors.close()

embeddings
} finally if (results != null) results.close()
}

}

def computeZeroShotLogitsWithTF(
batch: Seq[Array[Int]],
maxSentenceLength: Int): Array[Float] = {

val tensors = new TensorResources()
val batchLength = batch.length

val tokenBuffers: IntDataBuffer = tensors.createIntBuffer(batchLength * maxSentenceLength)
Expand Down Expand Up @@ -350,6 +399,23 @@ private[johnsnowlabs] class BertClassification(
tensors.clearSession(outs)
tensors.clearTensors()

rawScores
}

def tagZeroShotSequence(
batch: Seq[Array[Int]],
entailmentId: Int,
contradictionId: Int,
activation: String): Array[Array[Float]] = {

val maxSentenceLength = batch.map(encodedSentence => encodedSentence.length).max
val batchLength = batch.length

val rawScores = detectedEngine match {
case ONNX.name => computeZeroShotLogitsWithONNX(batch, maxSentenceLength)
case _ => computeZeroShotLogitsWithTF(batch, maxSentenceLength)
}

val dim = rawScores.length / batchLength
rawScores
.grouped(dim)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
package com.johnsnowlabs.nlp.annotators.classifier.dl

import com.johnsnowlabs.ml.ai.BertClassification
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel}
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadTextAsset,
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.MapFeature
Expand All @@ -50,7 +50,7 @@ import org.apache.spark.sql.SparkSession
* .setInputCols("token", "document")
* .setOutputCol("label")
* }}}
* The default model is `"bert_base_cased_zero_shot_classifier_xnli"`, if no name is provided.
* The default model is `"bert_zero_shot_classifier_mnli"`, if no name is provided.
*
* For available pretrained models please see the
* [[https://sparknlp.org/models?task=Text+Classification Models Hub]].
Expand Down Expand Up @@ -124,6 +124,7 @@ class BertForZeroShotClassification(override val uid: String)
extends AnnotatorModel[BertForZeroShotClassification]
with HasBatchedAnnotate[BertForZeroShotClassification]
with WriteTensorflowModel
with WriteOnnxModel
with HasCaseSensitiveProperties
with HasClassifierActivationProperties
with HasEngine
Expand Down Expand Up @@ -338,21 +339,34 @@ class BertForZeroShotClassification(override val uid: String)

override def onWrite(path: String, spark: SparkSession): Unit = {
super.onWrite(path, spark)
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflowWrapper.get,
"_bert_classification",
BertForZeroShotClassification.tfFile,
configProtoBytes = getConfigProtoBytes)
val suffix = "_bert_classification"

getEngine match {
case TensorFlow.name =>
writeTensorflowModelV2(
path,
spark,
getModelIfNotSet.tensorflowWrapper.get,
suffix,
BertForZeroShotClassification.tfFile,
configProtoBytes = getConfigProtoBytes)
case ONNX.name =>
writeOnnxModel(
path,
spark,
getModelIfNotSet.onnxWrapper.get,
suffix,
BertForZeroShotClassification.onnxFile)
}
}

}

trait ReadablePretrainedBertForZeroShotModel
extends ParamsAndFeaturesReadable[BertForZeroShotClassification]
with HasPretrained[BertForZeroShotClassification] {
override val defaultModelName: Some[String] = Some("bert_base_cased_zero_shot_classifier_xnli")
override val defaultModelName: Some[String] = Some("bert_zero_shot_classifier_mnli")
override val defaultLang: String = "xx"

/** Java compliant-overrides */
override def pretrained(): BertForZeroShotClassification = super.pretrained()
Expand All @@ -368,19 +382,29 @@ trait ReadablePretrainedBertForZeroShotModel
remoteLoc: String): BertForZeroShotClassification = super.pretrained(name, lang, remoteLoc)
}

trait ReadBertForZeroShotDLModel extends ReadTensorflowModel {
trait ReadBertForZeroShotDLModel extends ReadTensorflowModel with ReadOnnxModel {
this: ParamsAndFeaturesReadable[BertForZeroShotClassification] =>

override val tfFile: String = "bert_classification_tensorflow"
override val onnxFile: String = "bert_classification_onnx"

def readModel(
instance: BertForZeroShotClassification,
path: String,
spark: SparkSession): Unit = {

val tensorFlow =
readTensorflowModel(path, spark, "_bert_classification_tf", initAllTables = false)
instance.setModelIfNotSet(spark, Some(tensorFlow), None)
instance.getEngine match {
case TensorFlow.name =>
val tensorFlow =
readTensorflowModel(path, spark, "_bert_classification_tf", initAllTables = false)
instance.setModelIfNotSet(spark, Some(tensorFlow), None)
case ONNX.name =>
val onnxWrapper =
readOnnxModel(path, spark, "_bert_classification_onnx")
instance.setModelIfNotSet(spark, None, Some(onnxWrapper))
case _ =>
throw new Exception(notSupportedEngineError)
}
}

addReader(readModel)
Expand Down Expand Up @@ -437,6 +461,11 @@ trait ReadBertForZeroShotDLModel extends ReadTensorflowModel {
.setSignatures(_signatures)
.setModelIfNotSet(spark, Some(wrapper), None)

case ONNX.name =>
val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true)
annotatorModel
.setModelIfNotSet(spark, None, Some(onnxWrapper))

case _ =>
throw new Exception(notSupportedEngineError)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class BertForZeroShotClassificationTestSpec extends AnyFlatSpec {
val candidateLabels =
Array("urgent", "mobile", "travel", "movie", "music", "sport", "weather", "technology")

"BertForSBertForZeroShotClassification" should "correctly load custom model with extracted signatures" taggedAs SlowTest in {
"BertForZeroShotClassification" should "correctly load custom model with extracted signatures" taggedAs SlowTest in {

val ddd = Seq(
"I have a problem with my iphone that needs to be resolved asap!!",
Expand Down
Loading