Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ADT for ModelEngine #13862

Merged
merged 1 commit into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings
import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch}
import com.johnsnowlabs.ml.util.{ModelArch, ModelEngine, ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}

Expand Down Expand Up @@ -63,9 +63,9 @@ private[johnsnowlabs] class Bert(

val _tfBertSignatures: Map[String, String] = signatures.getOrElse(ModelSignatureManager.apply())
val detectedEngine: String =
if (tensorflowWrapper.isDefined) ModelEngine.tensorflow
else if (onnxWrapper.isDefined) ModelEngine.onnx
else ModelEngine.tensorflow
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

private def sessionWarmup(): Unit = {
val dummyInput =
Expand All @@ -88,7 +88,7 @@ private[johnsnowlabs] class Bert(

val embeddings = detectedEngine match {

case ModelEngine.onnx =>
case ONNX.name =>
// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

Expand Down Expand Up @@ -191,7 +191,7 @@ private[johnsnowlabs] class Bert(
val batchLength = batch.length

val embeddings = detectedEngine match {
case ModelEngine.onnx =>
case ONNX.name =>
// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import com.johnsnowlabs.ml.onnx.OnnxWrapper
import com.johnsnowlabs.ml.tensorflow.sentencepiece._
import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager}
import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.{ONNX, TensorFlow}
import com.johnsnowlabs.nlp.annotators.common._

import scala.collection.JavaConverters._
Expand All @@ -49,9 +49,9 @@ class DeBerta(
signatures.getOrElse(ModelSignatureManager.apply())

val detectedEngine: String =
if (tensorflowWrapper.isDefined) ModelEngine.tensorflow
else if (onnxWrapper.isDefined) ModelEngine.onnx
else ModelEngine.tensorflow
if (tensorflowWrapper.isDefined) TensorFlow.name
else if (onnxWrapper.isDefined) ONNX.name
else TensorFlow.name

// keys representing the input and output tensors of the DeBERTa model
private val SentenceStartTokenId = spp.getSppModel.pieceToId("[CLS]")
Expand All @@ -66,7 +66,7 @@ class DeBerta(

val embeddings = detectedEngine match {

case ModelEngine.onnx =>
case ONNX.name =>
// [nb of encoded sentences , maxSentenceLength]
val (runner, env) = onnxWrapper.get.getSession()

Expand Down
14 changes: 7 additions & 7 deletions src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,19 @@ object LoadExternalModel {
}

def isTensorFlowModel(modelPath: String): Boolean = {
val tfSavedModel = new File(modelPath, ModelEngine.tensorflowModelName)
val tfSavedModel = new File(modelPath, TensorFlow.modelName)
tfSavedModel.exists()

}

def isOnnxModel(modelPath: String, isEncoderDecoder: Boolean = false): Boolean = {

if (isEncoderDecoder) {
val onnxEncoderModel = new File(modelPath, ModelEngine.onnxEncoderModel)
val onnxDecoderModel = new File(modelPath, ModelEngine.onnxDecoderModel)
val onnxEncoderModel = new File(modelPath, ONNX.encoderModel)
val onnxDecoderModel = new File(modelPath, ONNX.decoderModel)
onnxEncoderModel.exists() && onnxDecoderModel.exists()
} else {
val onnxModel = new File(modelPath, ModelEngine.onnxModelName)
val onnxModel = new File(modelPath, ONNX.modelName)
onnxModel.exists()
}

Expand All @@ -94,12 +94,12 @@ object LoadExternalModel {
val onnxModelExist = isOnnxModel(modelPath, isEncoderDecoder)

if (tfSavedModelExist) {
ModelEngine.tensorflow
TensorFlow.name
} else if (onnxModelExist) {
ModelEngine.onnx
ONNX.name
} else {
require(tfSavedModelExist || onnxModelExist, notSupportedEngineError)
ModelEngine.unk
Unknown.name
}

}
Expand Down
31 changes: 21 additions & 10 deletions src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2017-2022 John Snow Labs
* Copyright 2017-2023 John Snow Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,24 @@

package com.johnsnowlabs.ml.util

object ModelEngine {
val tensorflow = "tensorflow"
val tensorflowModelName = "saved_model.pb"
val onnx = "onnx"
val onnxModelName = "model.onnx"
val onnxEncoderModel = "encoder_model.onnx"
val onnxDecoderModel = "decoder_model.onnx"
val onnxDecoderWithPastModel = "decoder_with_past_model.onnx"
val unk = "unk"
sealed trait ModelEngine

final case object TensorFlow extends ModelEngine {
val name = "tensorflow"
val modelName = "saved_model.pb"
}
final case object PyTorch extends ModelEngine {
val name = "pytorch"
}

final case object ONNX extends ModelEngine {
val name = "onnx"
val modelName = "model.onnx"
val encoderModel = "encoder_model.onnx"
val decoderModel = "decoder_model.onnx"
val decoderWithPastModel = "decoder_with_past_model.onnx"
}

final case object Unknown extends ModelEngine {
val name = "unk"
}
4 changes: 2 additions & 2 deletions src/main/scala/com/johnsnowlabs/nlp/HasEngine.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.johnsnowlabs.nlp

import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.TensorFlow
import org.apache.spark.ml.param.Param

trait HasEngine extends ParamsAndFeaturesWritable {
Expand All @@ -27,7 +27,7 @@ trait HasEngine extends ParamsAndFeaturesWritable {
*/
val engine = new Param[String](this, "engine", "Deep Learning engine used for this model")

setDefault(engine, ModelEngine.tensorflow)
setDefault(engine, TensorFlow.name)

/** @group getParam */
def getEngine: String = $(engine)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.audio.feature_extractor.Preprocessor
import org.apache.spark.ml.util.Identifiable
Expand Down Expand Up @@ -213,7 +213,7 @@ trait ReadHubertForAudioDLModel extends ReadTensorflowModel {
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp.AnnotatorType.{AUDIO, DOCUMENT}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.audio.feature_extractor.Preprocessor
Expand Down Expand Up @@ -340,7 +340,7 @@ trait ReadWav2Vec2ForAudioDLModel extends ReadTensorflowModel {
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.serialization.MapFeature
import org.apache.spark.broadcast.Broadcast
Expand Down Expand Up @@ -317,7 +317,7 @@ trait ReadAlbertForQuestionAnsweringDLModel
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.MapFeature
Expand Down Expand Up @@ -372,7 +372,7 @@ trait ReadAlbertForSequenceDLModel extends ReadTensorflowModel with ReadSentence
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.MapFeature
Expand Down Expand Up @@ -343,7 +343,7 @@ trait ReadAlbertForTokenDLModel extends ReadTensorflowModel with ReadSentencePie
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.serialization.MapFeature
import org.apache.spark.broadcast.Broadcast
Expand Down Expand Up @@ -325,7 +325,7 @@ trait ReadBertForQuestionAnsweringDLModel extends ReadTensorflowModel {
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow}
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.MapFeature
Expand Down Expand Up @@ -383,7 +383,7 @@ trait ReadBertForSequenceDLModel extends ReadTensorflowModel {
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.MapFeature
Expand Down Expand Up @@ -346,7 +346,7 @@ trait ReadBertForTokenDLModel extends ReadTensorflowModel {
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,19 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl
import com.johnsnowlabs.ml.ai.BertClassification
import com.johnsnowlabs.ml.tensorflow._
import com.johnsnowlabs.ml.util.LoadExternalModel.{
loadSentencePieceAsset,
loadTextAsset,
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.MapFeature
import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, IntParam}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.SparkSession

import java.io.File

/** BertForZeroShotClassification using a `ModelForSequenceClassification` trained on NLI (natural
* language inference) tasks. Equivalent of `BertForSequenceClassification` models, but these
* models don't require a hardcoded number of potential classes, they can be chosen at runtime.
Expand Down Expand Up @@ -421,7 +417,7 @@ trait ReadBertForZeroShotDLModel extends ReadTensorflowModel {
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.serialization.MapFeature
import org.apache.spark.broadcast.Broadcast
Expand Down Expand Up @@ -323,7 +323,7 @@ trait ReadCamemBertForQADLModel extends ReadTensorflowModel with ReadSentencePie
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{
modelSanityCheck,
notSupportedEngineError
}
import com.johnsnowlabs.ml.util.ModelEngine
import com.johnsnowlabs.ml.util.TensorFlow
import com.johnsnowlabs.nlp._
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.serialization.MapFeature
Expand Down Expand Up @@ -378,7 +378,7 @@ trait ReadCamemBertForSequenceDLModel extends ReadTensorflowModel with ReadSente
annotatorModel.set(annotatorModel.engine, detectedEngine)

detectedEngine match {
case ModelEngine.tensorflow =>
case TensorFlow.name =>
val (wrapper, signatures) =
TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)

Expand Down
Loading