diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala b/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala index 5718f8b29de950..4dec8c80cabb4a 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/Bert.scala @@ -21,7 +21,7 @@ import com.johnsnowlabs.ml.ai.util.PrepareEmbeddings import com.johnsnowlabs.ml.onnx.OnnxWrapper import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ModelEngine, ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.{Annotation, AnnotatorType} @@ -63,9 +63,9 @@ private[johnsnowlabs] class Bert( val _tfBertSignatures: Map[String, String] = signatures.getOrElse(ModelSignatureManager.apply()) val detectedEngine: String = - if (tensorflowWrapper.isDefined) ModelEngine.tensorflow - else if (onnxWrapper.isDefined) ModelEngine.onnx - else ModelEngine.tensorflow + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name private def sessionWarmup(): Unit = { val dummyInput = @@ -88,7 +88,7 @@ private[johnsnowlabs] class Bert( val embeddings = detectedEngine match { - case ModelEngine.onnx => + case ONNX.name => // [nb of encoded sentences , maxSentenceLength] val (runner, env) = onnxWrapper.get.getSession() @@ -191,7 +191,7 @@ private[johnsnowlabs] class Bert( val batchLength = batch.length val embeddings = detectedEngine match { - case ModelEngine.onnx => + case ONNX.name => // [nb of encoded sentences , maxSentenceLength] val (runner, env) = onnxWrapper.get.getSession() diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala b/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala index 20450d9df3777a..bbf4ac83b1862b 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/DeBerta.scala @@ -22,7 +22,7 @@ import com.johnsnowlabs.ml.onnx.OnnxWrapper import com.johnsnowlabs.ml.tensorflow.sentencepiece._ import com.johnsnowlabs.ml.tensorflow.sign.{ModelSignatureConstants, ModelSignatureManager} import com.johnsnowlabs.ml.tensorflow.{TensorResources, TensorflowWrapper} -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ONNX, TensorFlow} import com.johnsnowlabs.nlp.annotators.common._ import scala.collection.JavaConverters._ @@ -49,9 +49,9 @@ class DeBerta( signatures.getOrElse(ModelSignatureManager.apply()) val detectedEngine: String = - if (tensorflowWrapper.isDefined) ModelEngine.tensorflow - else if (onnxWrapper.isDefined) ModelEngine.onnx - else ModelEngine.tensorflow + if (tensorflowWrapper.isDefined) TensorFlow.name + else if (onnxWrapper.isDefined) ONNX.name + else TensorFlow.name // keys representing the input and output tensors of the DeBERTa model private val SentenceStartTokenId = spp.getSppModel.pieceToId("[CLS]") @@ -66,7 +66,7 @@ class DeBerta( val embeddings = detectedEngine match { - case ModelEngine.onnx => + case ONNX.name => // [nb of encoded sentences , maxSentenceLength] val (runner, env) = onnxWrapper.get.getSession() diff --git a/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala b/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala index 17cc7e8505b5a2..58aff6825f0408 100644 --- a/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala +++ b/src/main/scala/com/johnsnowlabs/ml/util/LoadExternalModel.scala @@ -56,7 +56,7 @@ object LoadExternalModel { } def isTensorFlowModel(modelPath: String): Boolean = { - val tfSavedModel = new File(modelPath, ModelEngine.tensorflowModelName) + val tfSavedModel = new File(modelPath, TensorFlow.modelName) tfSavedModel.exists() } @@ -64,11 +64,11 @@ object LoadExternalModel { def isOnnxModel(modelPath: String, isEncoderDecoder: Boolean = false): Boolean = { if (isEncoderDecoder) { - val onnxEncoderModel = new File(modelPath, ModelEngine.onnxEncoderModel) - val onnxDecoderModel = new File(modelPath, ModelEngine.onnxDecoderModel) + val onnxEncoderModel = new File(modelPath, ONNX.encoderModel) + val onnxDecoderModel = new File(modelPath, ONNX.decoderModel) onnxEncoderModel.exists() && onnxDecoderModel.exists() } else { - val onnxModel = new File(modelPath, ModelEngine.onnxModelName) + val onnxModel = new File(modelPath, ONNX.modelName) onnxModel.exists() } @@ -94,12 +94,12 @@ object LoadExternalModel { val onnxModelExist = isOnnxModel(modelPath, isEncoderDecoder) if (tfSavedModelExist) { - ModelEngine.tensorflow + TensorFlow.name } else if (onnxModelExist) { - ModelEngine.onnx + ONNX.name } else { require(tfSavedModelExist || onnxModelExist, notSupportedEngineError) - ModelEngine.unk + Unknown.name } } diff --git a/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala b/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala index 9e8b93e9991219..061a42e7caa930 100644 --- a/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala +++ b/src/main/scala/com/johnsnowlabs/ml/util/ModelEngine.scala @@ -1,5 +1,5 @@ /* - * Copyright 2017-2022 John Snow Labs + * Copyright 2017-2023 John Snow Labs * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,24 @@ package com.johnsnowlabs.ml.util -object ModelEngine { - val tensorflow = "tensorflow" - val tensorflowModelName = "saved_model.pb" - val onnx = "onnx" - val onnxModelName = "model.onnx" - val onnxEncoderModel = "encoder_model.onnx" - val onnxDecoderModel = "decoder_model.onnx" - val onnxDecoderWithPastModel = "decoder_with_past_model.onnx" - val unk = "unk" +sealed trait ModelEngine + +final case object TensorFlow extends ModelEngine { + val name = "tensorflow" + val modelName = "saved_model.pb" +} +final case object PyTorch extends ModelEngine { + val name = "pytorch" +} + +final case object ONNX extends ModelEngine { + val name = "onnx" + val modelName = "model.onnx" + val encoderModel = "encoder_model.onnx" + val decoderModel = "decoder_model.onnx" + val decoderWithPastModel = "decoder_with_past_model.onnx" +} + +final case object Unknown extends ModelEngine { + val name = "unk" } diff --git a/src/main/scala/com/johnsnowlabs/nlp/HasEngine.scala b/src/main/scala/com/johnsnowlabs/nlp/HasEngine.scala index 541d50b34afee5..39870b3073ce12 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/HasEngine.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/HasEngine.scala @@ -16,7 +16,7 @@ package com.johnsnowlabs.nlp -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import org.apache.spark.ml.param.Param trait HasEngine extends ParamsAndFeaturesWritable { @@ -27,7 +27,7 @@ trait HasEngine extends ParamsAndFeaturesWritable { */ val engine = new Param[String](this, "engine", "Deep Learning engine used for this model") - setDefault(engine, ModelEngine.tensorflow) + setDefault(engine, TensorFlow.name) /** @group getParam */ def getEngine: String = $(engine) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/HubertForCTC.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/HubertForCTC.scala index 989b3ee0634d30..520ffd0bda69ab 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/HubertForCTC.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/HubertForCTC.scala @@ -22,7 +22,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.audio.feature_extractor.Preprocessor import org.apache.spark.ml.util.Identifiable @@ -213,7 +213,7 @@ trait ReadHubertForAudioDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala index 4e51a3812f1a25..5927de36c587e9 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala @@ -27,7 +27,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.{AUDIO, DOCUMENT} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.audio.feature_extractor.Preprocessor @@ -340,7 +340,7 @@ trait ReadWav2Vec2ForAudioDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala index 1d2026e7f1bd3a..217fbc6ca25947 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -317,7 +317,7 @@ trait ReadAlbertForQuestionAnsweringDLModel annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala index 8e110c8460ec5a..f0d61bcaade650 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -372,7 +372,7 @@ trait ReadAlbertForSequenceDLModel extends ReadTensorflowModel with ReadSentence annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala index 4abbb18a6307f1..89e61223d63097 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -343,7 +343,7 @@ trait ReadAlbertForTokenDLModel extends ReadTensorflowModel with ReadSentencePie annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala index e8d17348c1b968..d48b40dcb65c08 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -325,7 +325,7 @@ trait ReadBertForQuestionAnsweringDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala index d873915c1e412e..ff0bb3aeb4676a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala @@ -24,7 +24,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -383,7 +383,7 @@ trait ReadBertForSequenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala index c9062cd3d99b83..0c287de7d2cd64 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -346,7 +346,7 @@ trait ReadBertForTokenDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala index b0149ab660d9c8..6c6ddc35140d1a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala @@ -19,23 +19,19 @@ package com.johnsnowlabs.nlp.annotators.classifier.dl import com.johnsnowlabs.ml.ai.BertClassification import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.util.LoadExternalModel.{ - loadSentencePieceAsset, loadTextAsset, modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature -import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper} import org.apache.spark.broadcast.Broadcast import org.apache.spark.ml.param.{BooleanParam, IntArrayParam, IntParam} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.SparkSession -import java.io.File - /** BertForZeroShotClassification using a `ModelForSequenceClassification` trained on NLI (natural * language inference) tasks. Equivalent of `BertForSequenceClassification` models, but these * models don't require a hardcoded number of potential classes, they can be chosen at runtime. @@ -421,7 +417,7 @@ trait ReadBertForZeroShotDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala index 784003488a1a83..e55e6adf4b6cb5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -323,7 +323,7 @@ trait ReadCamemBertForQADLModel extends ReadTensorflowModel with ReadSentencePie annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala index d96d8e59318e1e..9519af01f8a7ac 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -378,7 +378,7 @@ trait ReadCamemBertForSequenceDLModel extends ReadTensorflowModel with ReadSente annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala index 7b440341739223..275cd4bba61238 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -345,7 +345,7 @@ trait ReadCamemBertForTokenDLModel extends ReadTensorflowModel with ReadSentence annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala index 06e9c955d1f0f6..600b85da999a6d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -323,7 +323,7 @@ trait ReadDeBertaForQuestionAnsweringDLModel annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala index dae903e43e14df..0f025ebca7c367 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -374,7 +374,7 @@ trait ReadDeBertaForSequenceDLModel extends ReadTensorflowModel with ReadSentenc annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala index b09d9d5298bca9..81b3fdff7def4b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -343,7 +343,7 @@ trait ReadDeBertaForTokenDLModel extends ReadTensorflowModel with ReadSentencePi annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala index e950099b9e82fc..be3709d19b6279 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -328,7 +328,7 @@ trait ReadDistilBertForQuestionAnsweringDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala index 4c2699cf848e28..aee25f66d01640 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -383,7 +383,7 @@ trait ReadDistilBertForSequenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala index 53690d311104e2..20616a8303e7fc 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -351,7 +351,7 @@ trait ReadDistilBertForTokenDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala index 27f34509c867aa..b1afba431726d2 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -423,7 +423,7 @@ trait ReadDistilBertForZeroShotDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala index f9cdbeaf323127..453b8ac7e2cb17 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -349,7 +349,7 @@ trait ReadLongformerForQuestionAnsweringDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala index e6c91330eaf371..6dd293f033515c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -403,7 +403,7 @@ trait ReadLongformerForSequenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala index 1957ccb20b00cb..176fea3d1e19f2 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -371,7 +371,7 @@ trait ReadLongformerForTokenDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala index 27212384881b1b..35bd006fc4a3e5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -347,7 +347,7 @@ trait ReadRoBertaForQuestionAnsweringDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala index f3c76c0b88f915..5e4b268af48f0d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -398,7 +398,7 @@ trait ReadRoBertaForSequenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala index 65c14a953c5b05..742306621bd376 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -369,7 +369,7 @@ trait ReadRoBertaForTokenDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala index ff24acd94a7894..60041627854e15 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -441,7 +441,7 @@ trait ReadRoBertaForZeroShotDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/TapasForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/TapasForQuestionAnswering.scala index b9a1e253525054..22b3760cb8fa69 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/TapasForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/TapasForQuestionAnswering.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.base.TableAssembler import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, HasPretrained, ParamsAndFeaturesReadable} import org.apache.spark.broadcast.Broadcast @@ -265,7 +265,7 @@ trait ReadTapasForQuestionAnsweringDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala index a42fef9c880aea..01920477d5a672 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -323,7 +323,7 @@ trait ReadXlmRoBertaForQuestionAnsweringDLModel annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala index eada6953d7bca1..add55d9270b8be 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -375,7 +375,7 @@ trait ReadXlmRoBertaForSequenceDLModel extends ReadTensorflowModel with ReadSent annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala index 38a379d9ae529a..ded252b097d481 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -345,7 +345,7 @@ trait ReadXlmRoBertaForTokenDLModel extends ReadTensorflowModel with ReadSentenc annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala index 593e9d51e37a6f..b9e786c4a869fb 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -371,7 +371,7 @@ trait ReadXlnetForSequenceDLModel extends ReadTensorflowModel with ReadSentenceP annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala index 3f9e9f54df57e8..43b1e4dcd46103 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -343,7 +343,7 @@ trait ReadXlnetForTokenDLModel extends ReadTensorflowModel with ReadSentencePiec annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala index 1b097a76813d03..bb48cbc9d7c575 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala @@ -26,7 +26,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} @@ -449,7 +449,7 @@ trait ReadSpanBertCorefTensorflowModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ConvNextForImageClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ConvNextForImageClassification.scala index b87de63bae8e58..3af9710b12ff9b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ConvNextForImageClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ConvNextForImageClassification.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor import org.apache.spark.broadcast.Broadcast @@ -353,7 +353,7 @@ trait ReadConvNextForImageDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/SwinForImageClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/SwinForImageClassification.scala index 4341fa23cd4bd6..344200c0c2e501 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/SwinForImageClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/SwinForImageClassification.scala @@ -22,7 +22,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor import org.apache.spark.ml.param.{BooleanParam, DoubleParam} @@ -334,7 +334,7 @@ trait ReadSwinForImageDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala index e786739b6ac718..985fbc041251a5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala @@ -27,7 +27,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.{CATEGORY, IMAGE} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.cv.feature_extractor.Preprocessor @@ -39,8 +39,6 @@ import org.apache.spark.sql.SparkSession import org.json4s._ import org.json4s.jackson.JsonMethods._ -import java.io.File - /** Vision Transformer (ViT) for image classification. * * ViT is a transformer based alternative to the convolutional neural networks usually used for @@ -384,7 +382,7 @@ trait ReadViTForImageDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala index 79a14a4b6098e5..5e05d06d6c2bec 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala @@ -22,7 +22,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -341,7 +341,7 @@ trait ReadLanguageDetectorDLTensorflowModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, _) = TensorflowWrapper.read( localModelPath, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala index 66d97181a86e38..aa72c0e4738fbe 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala @@ -27,7 +27,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -596,7 +596,7 @@ trait ReadBartTransformerDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read( localModelPath, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/GPT2Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/GPT2Transformer.scala index 246c6e8a6f10ac..29d76fcd0dea17 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/GPT2Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/GPT2Transformer.scala @@ -27,7 +27,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{BpeTokenizer, Gpt2Tokenizer} @@ -544,7 +544,7 @@ trait ReadGPT2TransformerDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, _) = TensorflowWrapper.read( localModelPath, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala index e15729d3f05a42..ce18cf3ad4f8bd 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature import org.apache.spark.broadcast.Broadcast @@ -458,7 +458,7 @@ trait ReadMarianMTDLModel extends ReadTensorflowModel with ReadSentencePieceMode annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read( localModelPath, zipped = false, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala index ac70e675f3df97..edc691a236191c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala @@ -32,7 +32,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -537,7 +537,7 @@ trait ReadT5TransformerDLModel extends ReadTensorflowModel with ReadSentencePiec val spModel = loadSentencePieceAsset(localModelPath, "spiece.model") detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read( localModelPath, zipped = false, diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala index 1f4c0e8a2923b3..c8da89256f2b4c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -396,7 +396,7 @@ trait ReadAlbertDLModel extends ReadTensorflowModel with ReadSentencePieceModel annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala index 80049bb6a28952..3717c09bf3b9bc 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala @@ -24,7 +24,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} @@ -378,7 +378,7 @@ class BertEmbeddings(override val uid: String) val suffix = "_bert" getEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => writeTensorflowModelV2( path, spark, @@ -386,7 +386,7 @@ class BertEmbeddings(override val uid: String) suffix, BertEmbeddings.tfFile, configProtoBytes = getConfigProtoBytes) - case ModelEngine.onnx => + case ONNX.name => writeOnnxModel( path, spark, @@ -427,11 +427,11 @@ trait ReadBertDLModel extends ReadTensorflowModel with ReadOnnxModel { def readModel(instance: BertEmbeddings, path: String, spark: SparkSession): Unit = { instance.getEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val tfWrapper = readTensorflowModel(path, spark, "_bert_tf", initAllTables = false) instance.setModelIfNotSet(spark, Some(tfWrapper), None) - case ModelEngine.onnx => { + case ONNX.name => { val onnxWrapper = readOnnxModel(path, spark, "_bert_onnx", zipped = true, useBundle = false, None) instance.setModelIfNotSet(spark, None, Some(onnxWrapper)) @@ -456,7 +456,7 @@ trait ReadBertDLModel extends ReadTensorflowModel with ReadOnnxModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) @@ -472,7 +472,7 @@ trait ReadBertDLModel extends ReadTensorflowModel with ReadOnnxModel { .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None) - case ModelEngine.onnx => + case ONNX.name => val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala index 67fee6db982101..4dac9e9e02e52f 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala @@ -24,7 +24,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} @@ -399,7 +399,7 @@ class BertSentenceEmbeddings(override val uid: String) super.onWrite(path, spark) getEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => writeTensorflowModelV2( path, spark, @@ -407,7 +407,7 @@ class BertSentenceEmbeddings(override val uid: String) "_bert_sentence", BertEmbeddings.tfFile, configProtoBytes = getConfigProtoBytes) - case ModelEngine.onnx => + case ONNX.name => writeOnnxModel( path, spark, @@ -449,12 +449,12 @@ trait ReadBertSentenceDLModel extends ReadTensorflowModel with ReadOnnxModel { def readModel(instance: BertSentenceEmbeddings, path: String, spark: SparkSession): Unit = { instance.getEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val tfWrapper = readTensorflowModel(path, spark, "_bert_sentence_tf", initAllTables = false) instance.setModelIfNotSet(spark, Some(tfWrapper), None) - case ModelEngine.onnx => { + case ONNX.name => { val onnxWrapper = readOnnxModel( path, @@ -485,7 +485,7 @@ trait ReadBertSentenceDLModel extends ReadTensorflowModel with ReadOnnxModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) @@ -501,7 +501,7 @@ trait ReadBertSentenceDLModel extends ReadTensorflowModel with ReadOnnxModel { .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None) - case ModelEngine.onnx => + case ONNX.name => val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper)) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala index 12c2b4d1edaef0..914d9b87b91449 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala @@ -12,7 +12,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -362,7 +362,7 @@ trait ReadCamemBertDLModel extends ReadTensorflowModel with ReadSentencePieceMod annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala index 99b094108c5fdc..3c2ea07792ba74 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala @@ -29,7 +29,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, ONNX, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -323,7 +323,7 @@ class DeBertaEmbeddings(override val uid: String) val suffix = "_deberta" getEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => writeTensorflowModelV2( path, spark, @@ -331,7 +331,7 @@ class DeBertaEmbeddings(override val uid: String) suffix, DeBertaEmbeddings.tfFile, configProtoBytes = getConfigProtoBytes) - case ModelEngine.onnx => + case ONNX.name => writeOnnxModel( path, spark, @@ -381,12 +381,12 @@ trait ReadDeBertaDLModel val spp = readSentencePieceModel(path, spark, "_deberta_spp", sppFile) instance.getEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val tfWrapper = readTensorflowModel(path, spark, "_deberta_tf", initAllTables = false) instance.setModelIfNotSet(spark, Some(tfWrapper), None, spp) - case ModelEngine.onnx => { + case ONNX.name => { val onnxWrapper = readOnnxModel(path, spark, "_deberta_onnx", zipped = true, useBundle = false, None) instance.setModelIfNotSet(spark, None, Some(onnxWrapper), spp) @@ -410,7 +410,7 @@ trait ReadDeBertaDLModel annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (tfWrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) @@ -426,7 +426,7 @@ trait ReadDeBertaDLModel .setSignatures(_signatures) .setModelIfNotSet(spark, Some(tfWrapper), None, spModel) - case ModelEngine.onnx => + case ONNX.name => val onnxWrapper = OnnxWrapper.read(localModelPath, zipped = false, useBundle = true) annotatorModel .setModelIfNotSet(spark, None, Some(onnxWrapper), spModel) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala index 8bcfbd578a2343..a2c08450829750 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.wordpiece.{BasicTokenizer, WordpieceEncoder} @@ -426,7 +426,7 @@ trait ReadDistilBertDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala index 647061442198c1..7f12ffc4c89a4d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala @@ -19,7 +19,7 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.ml.ai.Elmo import com.johnsnowlabs.ml.tensorflow._ import com.johnsnowlabs.ml.util.LoadExternalModel.{modelSanityCheck, notSupportedEngineError} -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.storage.HasStorageRef @@ -363,7 +363,7 @@ trait ReadElmoDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, _) = TensorflowWrapper.read( localModelPath, zipped = false, diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala index 7984096e2ca163..ae5b102983a47f 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer @@ -440,7 +440,7 @@ trait ReadLongformerDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala index dae1369d440a2e..c0f9efc2ddfff7 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer @@ -453,7 +453,7 @@ trait ReadRobertaDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala index 9cff4ea74fbab5..83247df4e21e48 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.BpeTokenizer @@ -429,7 +429,7 @@ trait ReadRobertaSentenceDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala index 2cc0712a615e1d..89bce984f4c388 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala @@ -23,7 +23,7 @@ import com.johnsnowlabs.ml.tensorflow.{ WriteTensorflowModel } import com.johnsnowlabs.ml.util.LoadExternalModel.{modelSanityCheck, notSupportedEngineError} -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.TensorFlow import com.johnsnowlabs.nlp.AnnotatorType.{DOCUMENT, SENTENCE_EMBEDDINGS} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common.SentenceSplit @@ -349,7 +349,7 @@ trait ReadUSEDLModel extends ReadTensorflowModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val wrapper = TensorflowWrapper.readWithSP( localModelPath, diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala index 76cbc656235e4e..107da32535a946 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -394,7 +394,7 @@ trait ReadXlmRobertaDLModel extends ReadTensorflowModel with ReadSentencePieceMo annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala index f81836a86e1a75..07df2844768290 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.{ModelEngine, ModelArch} +import com.johnsnowlabs.ml.util.{ModelArch, ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -383,7 +383,7 @@ trait ReadXlmRobertaSentenceDLModel extends ReadTensorflowModel with ReadSentenc annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala index a5bfac1d55159b..86d2c8b3e1cf97 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala @@ -28,7 +28,7 @@ import com.johnsnowlabs.ml.util.LoadExternalModel.{ modelSanityCheck, notSupportedEngineError } -import com.johnsnowlabs.ml.util.ModelEngine +import com.johnsnowlabs.ml.util.{ModelEngine, TensorFlow} import com.johnsnowlabs.nlp._ import com.johnsnowlabs.nlp.annotators.common._ import com.johnsnowlabs.nlp.serialization.MapFeature @@ -393,7 +393,7 @@ trait ReadXlnetDLModel extends ReadTensorflowModel with ReadSentencePieceModel { annotatorModel.set(annotatorModel.engine, detectedEngine) detectedEngine match { - case ModelEngine.tensorflow => + case TensorFlow.name => val (wrapper, signatures) = TensorflowWrapper.read(localModelPath, zipped = false, useBundle = true)