From e586d5647be4e574643d00f70fb0c3da194f5330 Mon Sep 17 00:00:00 2001 From: Devin Ha <33089471+DevinTDHa@users.noreply.github.com> Date: Thu, 25 May 2023 10:08:33 +0200 Subject: [PATCH] SPARKNLP-835: ProtectedParam and ProtectedFeature (#13797) * SPARKNLP-835: Finalize protected Features * SPARKNLP-835: Remove redundant checks for protected Features * SPARKNLP-835: Introduce ProtectedParam * SPARKNLP-835: Resolve encoding/decoding issue for HasProtectedParams * SPARKNLP-835: Make caseSensitive settable * SPARKNLP-835: Make maxSentenceLength, batchSize settable * SPARKNLP-835: Enable protected Params for Annotators --- .../johnsnowlabs/nlp/HasProtectedParams.scala | 62 +++++++++++++++++++ .../nlp/annotators/audio/Wav2Vec2ForCTC.scala | 3 +- .../dl/AlbertForQuestionAnswering.scala | 3 +- .../dl/AlbertForSequenceClassification.scala | 7 +-- .../dl/AlbertForTokenClassification.scala | 7 +-- .../dl/BertForQuestionAnswering.scala | 3 +- .../dl/BertForSequenceClassification.scala | 7 +-- .../dl/BertForTokenClassification.scala | 7 +-- .../dl/BertForZeroShotClassification.scala | 10 +-- .../dl/CamemBertForQuestionAnswering.scala | 3 +- .../CamemBertForSequenceClassification.scala | 7 +-- .../dl/CamemBertForTokenClassification.scala | 7 +-- .../dl/DeBertaForQuestionAnswering.scala | 3 +- .../dl/DeBertaForSequenceClassification.scala | 7 +-- .../dl/DeBertaForTokenClassification.scala | 7 +-- .../dl/DistilBertForQuestionAnswering.scala | 3 +- .../DistilBertForSequenceClassification.scala | 7 +-- .../dl/DistilBertForTokenClassification.scala | 7 +-- .../DistilBertForZeroShotClassification.scala | 10 +-- .../dl/LongformerForQuestionAnswering.scala | 3 +- .../LongformerForSequenceClassification.scala | 7 +-- .../dl/LongformerForTokenClassification.scala | 7 +-- .../dl/RoBertaForQuestionAnswering.scala | 3 +- .../dl/RoBertaForSequenceClassification.scala | 7 +-- .../dl/RoBertaForTokenClassification.scala | 7 +-- .../dl/RoBertaForZeroShotClassification.scala | 15 ++--- 
.../dl/XlmRoBertaForQuestionAnswering.scala | 3 +- .../XlmRoBertaForSequenceClassification.scala | 7 +-- .../dl/XlmRoBertaForTokenClassification.scala | 7 +-- .../dl/XlnetForSequenceClassification.scala | 7 +-- .../dl/XlnetForTokenClassification.scala | 7 +-- .../annotators/coref/SpanBertCorefModel.scala | 3 +- .../cv/ViTForImageClassification.scala | 3 +- .../annotators/ld/dl/LanguageDetectorDL.scala | 10 ++- .../annotators/seq2seq/BartTransformer.scala | 3 +- .../seq2seq/MarianTransformer.scala | 12 ++-- .../annotators/seq2seq/T5Transformer.scala | 3 +- .../nlp/embeddings/AlbertEmbeddings.scala | 7 +-- .../nlp/embeddings/BertEmbeddings.scala | 11 +--- .../embeddings/BertSentenceEmbeddings.scala | 24 +++---- .../nlp/embeddings/CamemBertEmbeddings.scala | 11 +--- .../nlp/embeddings/DeBertaEmbeddings.scala | 7 +-- .../nlp/embeddings/DistilBertEmbeddings.scala | 11 +--- .../nlp/embeddings/Doc2VecApproach.scala | 7 ++- .../nlp/embeddings/Doc2VecModel.scala | 6 +- .../nlp/embeddings/ElmoEmbeddings.scala | 8 +-- .../embeddings/HasEmbeddingsProperties.scala | 6 +- .../nlp/embeddings/LongformerEmbeddings.scala | 11 +--- .../nlp/embeddings/RoBertaEmbeddings.scala | 11 +--- .../RoBertaSentenceEmbeddings.scala | 11 +--- .../embeddings/UniversalSentenceEncoder.scala | 5 +- .../nlp/embeddings/Word2VecApproach.scala | 6 +- .../nlp/embeddings/Word2VecModel.scala | 6 +- .../nlp/embeddings/XlmRoBertaEmbeddings.scala | 11 +--- .../XlmRoBertaSentenceEmbeddings.scala | 11 +--- .../nlp/embeddings/XlnetEmbeddings.scala | 7 +-- .../nlp/serialization/Feature.scala | 20 +++--- .../nlp/HasFeaturesTestSpec.scala | 17 ++--- .../nlp/HasProtectedParamsTestSpec.scala | 42 +++++++++++++ .../ResourceDownloaderMetaSpec.scala | 9 +-- .../com/johnsnowlabs/util/TestUtils.scala | 9 +++ 61 files changed, 263 insertions(+), 295 deletions(-) create mode 100644 src/main/scala/com/johnsnowlabs/nlp/HasProtectedParams.scala create mode 100644 
src/test/scala/com/johnsnowlabs/nlp/HasProtectedParamsTestSpec.scala diff --git a/src/main/scala/com/johnsnowlabs/nlp/HasProtectedParams.scala b/src/main/scala/com/johnsnowlabs/nlp/HasProtectedParams.scala new file mode 100644 index 00000000000000..43fe692548eece --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/HasProtectedParams.scala @@ -0,0 +1,62 @@ +package com.johnsnowlabs.nlp + +import org.apache.spark.ml.param.{Param, Params} + +/** Enables a class to protect a parameter, which means that it can only be set once. + * + * This trait will enable an implicit conversion from Param to ProtectedParam. In addition, the + * new set for ProtectedParam will then check whether the value was already set. If so, + * then a warning will be output and the value will not be set again. + */ +trait HasProtectedParams { + this: Params => + implicit class ProtectedParam[T](baseParam: Param[T]) + extends Param[T](baseParam.parent, baseParam.name, baseParam.doc, baseParam.isValid) { + + var isProtected = false + + /** Sets this parameter to be protected, which means that it can only be set once. + * + * Default values do not count as a set value and can be overridden. + * + * @return + * This object + */ + def setProtected(): this.type = { + isProtected = true + this + } + + def toParam: Param[T] = this.asInstanceOf[Param[T]] + + // Overrides needed for individual Param implementation + override def jsonEncode(value: T): String = baseParam.jsonEncode(value) + override def jsonDecode(json: String): T = baseParam.jsonDecode(json) + } + + /** Sets the value for a protected Param. + * + * If the parameter was already set, it will not be set again. Default values do not count as a + * set value and can be overridden. 
+ * + * @param param + * Protected parameter to set + * @param value + * Value for the parameter + * @tparam T + * Type of the parameter + * @return + * This object + */ + def set[T](param: ProtectedParam[T], value: T): this.type = { + if (param.isProtected && get(param).isDefined) + println( + s"Warning: The parameter ${param.name} is protected and can only be set once." + + " For a pretrained model, this was done during the initialization process." + + " If you are trying to train your own model, please check the documentation." + + " If this is intentional, set the parameter directly with set(annotator.param, value).") + else + set(param.toParam, value) + this + } +} diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala index e83889d7e3aa07..4e51a3812f1a25 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala @@ -189,8 +189,7 @@ class Wav2Vec2ForCTC(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala index 55efa9711d06eb..1d2026e7f1bd3a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala @@ -184,8 +184,7 @@ class AlbertForQuestionAnswering(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala index d579bc89dc89e8..8e110c8460ec5a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala @@ -227,8 +227,7 @@ class AlbertForSequenceClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -265,9 +264,7 @@ class AlbertForSequenceClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala index fa42694afa3dfa..4abbb18a6307f1 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala @@ -205,8 +205,7 @@ class AlbertForTokenClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -242,9 +241,7 @@ class AlbertForTokenClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(batchSize 
-> 8, maxSentenceLength -> 128, caseSensitive -> false) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala index 72c8f52d4bcf4d..e8d17348c1b968 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala @@ -198,8 +198,7 @@ class BertForQuestionAnswering(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala index 0ec2ca07f1a21c..d873915c1e412e 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala @@ -243,8 +243,7 @@ class BertForSequenceClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -282,9 +281,7 @@ class BertForSequenceClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala index 4b2f57d92fe9bb..c9062cd3d99b83 100644 --- 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala @@ -217,8 +217,7 @@ class BertForTokenClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -255,9 +254,7 @@ class BertForTokenClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala index 0f6d3ecc05acac..8a71a452336154 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala @@ -167,8 +167,7 @@ class BertForZeroShotClassification(override val uid: String) /** @group setParam */ def setVocabulary(value: Map[String, Int]): this.type = { - if (get(vocabulary).isEmpty) - set(vocabulary, value) + set(vocabulary, value) this } @@ -256,8 +255,7 @@ class BertForZeroShotClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -295,9 +293,7 @@ class BertForZeroShotClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + 
set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala index 58e652f5ffa701..784003488a1a83 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala @@ -184,8 +184,7 @@ class CamemBertForQuestionAnswering(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala index c2067a1a0e5285..d96d8e59318e1e 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForSequenceClassification.scala @@ -227,8 +227,7 @@ class CamemBertForSequenceClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -265,9 +264,7 @@ class CamemBertForSequenceClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala 
index 2313b8d4b01c96..7b440341739223 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForTokenClassification.scala @@ -205,8 +205,7 @@ class CamemBertForTokenClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -242,9 +241,7 @@ class CamemBertForTokenClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala index 59bd7ba04219a6..06e9c955d1f0f6 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForQuestionAnswering.scala @@ -184,8 +184,7 @@ class DeBertaForQuestionAnswering(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala index ea39e687727c20..dae903e43e14df 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala +++ 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForSequenceClassification.scala @@ -226,8 +226,7 @@ class DeBertaForSequenceClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -264,9 +263,7 @@ class DeBertaForSequenceClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala index ba332344ce2e53..b09d9d5298bca9 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DeBertaForTokenClassification.scala @@ -206,8 +206,7 @@ class DeBertaForTokenClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -243,9 +242,7 @@ class DeBertaForTokenClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala index d3a070da7cfb07..e950099b9e82fc 100644 --- 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForQuestionAnswering.scala @@ -197,8 +197,7 @@ class DistilBertForQuestionAnswering(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala index 58461288f18e49..4c2699cf848e28 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForSequenceClassification.scala @@ -239,8 +239,7 @@ class DistilBertForSequenceClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -278,9 +277,7 @@ class DistilBertForSequenceClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala index c2de49344c35f8..53690d311104e2 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForTokenClassification.scala @@ -217,8 +217,7 @@ class 
DistilBertForTokenClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -255,9 +254,7 @@ class DistilBertForTokenClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala index 9896667d9651bf..646ed8a33ecfa7 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/DistilBertForZeroShotClassification.scala @@ -164,8 +164,7 @@ class DistilBertForZeroShotClassification(override val uid: String) /** @group setParam */ def setVocabulary(value: Map[String, Int]): this.type = { - if (get(vocabulary).isEmpty) - set(vocabulary, value) + set(vocabulary, value) this } @@ -253,8 +252,7 @@ class DistilBertForZeroShotClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -291,9 +289,7 @@ class DistilBertForZeroShotClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala index 7789aa9507cc53..f9cdbeaf323127 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForQuestionAnswering.scala @@ -208,8 +208,7 @@ class LongformerForQuestionAnswering(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala index ed4831808cf18d..e6c91330eaf371 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForSequenceClassification.scala @@ -250,8 +250,7 @@ class LongformerForSequenceClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -291,9 +290,7 @@ class LongformerForSequenceClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala index 052be393347406..1957ccb20b00cb 100644 --- 
a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/LongformerForTokenClassification.scala @@ -228,8 +228,7 @@ class LongformerForTokenClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -268,9 +267,7 @@ class LongformerForTokenClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala index 65b0527982cc00..27212384881b1b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForQuestionAnswering.scala @@ -208,8 +208,7 @@ class RoBertaForQuestionAnswering(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala index 6f267146ec077a..f3c76c0b88f915 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForSequenceClassification.scala @@ 
-250,8 +250,7 @@ class RoBertaForSequenceClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -291,9 +290,7 @@ class RoBertaForSequenceClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala index 40f2ff589386f5..65c14a953c5b05 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForTokenClassification.scala @@ -228,8 +228,7 @@ class RoBertaForTokenClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -268,9 +267,7 @@ class RoBertaForTokenClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala index 21494a964a6708..e5fd985e073a74 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala +++ 
b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/RoBertaForZeroShotClassification.scala @@ -163,12 +163,11 @@ class RoBertaForZeroShotClassification(override val uid: String) * * @group param */ - val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary") + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() /** @group setParam */ def setVocabulary(value: Map[String, Int]): this.type = { - if (get(vocabulary).isEmpty) - set(vocabulary, value) + set(vocabulary, value) this } @@ -260,12 +259,12 @@ class RoBertaForZeroShotClassification(override val uid: String) * * @group param */ - val signatures = new MapFeature[String, String](model = this, name = "signatures") + val signatures = + new MapFeature[String, String](model = this, name = "signatures").setProtected() /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -304,9 +303,7 @@ class RoBertaForZeroShotClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala index 98e9d5df10a14b..a42fef9c880aea 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForQuestionAnswering.scala @@ -184,8 +184,7 @@ class XlmRoBertaForQuestionAnswering(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + 
set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala index 50902d32e1c326..eada6953d7bca1 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForSequenceClassification.scala @@ -226,8 +226,7 @@ class XlmRoBertaForSequenceClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -264,9 +263,7 @@ class XlmRoBertaForSequenceClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala index 6c498f4ffeae02..38a379d9ae529a 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlmRoBertaForTokenClassification.scala @@ -205,8 +205,7 @@ class XlmRoBertaForTokenClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -242,9 +241,7 @@ class XlmRoBertaForTokenClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - 
set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala index 545618721c71b1..593e9d51e37a6f 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForSequenceClassification.scala @@ -227,8 +227,7 @@ class XlnetForSequenceClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -265,9 +264,7 @@ class XlnetForSequenceClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault( diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala index 9ed313bd0d7a6d..3f9e9f54df57e8 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/XlnetForTokenClassification.scala @@ -205,8 +205,7 @@ class XlnetForTokenClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -242,9 +241,7 @@ class XlnetForTokenClassification(override val uid: String) * @group setParam */ override def setCaseSensitive(value: 
Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala index 4f3ada09859a88..1b097a76813d03 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/coref/SpanBertCorefModel.scala @@ -206,8 +206,7 @@ class SpanBertCorefModel(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala index 4d8be970c0eee5..e786739b6ac718 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/cv/ViTForImageClassification.scala @@ -214,8 +214,7 @@ class ViTForImageClassification(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala index 444dad7b4fec7a..79a14a4b6098e5 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/ld/dl/LanguageDetectorDL.scala @@ -123,13 +123,13 @@ class LanguageDetectorDL(override val uid: String) * * @group param */ - val alphabet: MapFeature[String, Int] 
= new MapFeature(this, "alphabet") + val alphabet: MapFeature[String, Int] = new MapFeature(this, "alphabet").setProtected() /** Language used to map prediction to ISO 639-1 language codes * * @group param */ - val language: MapFeature[String, Int] = new MapFeature(this, "language") + val language: MapFeature[String, Int] = new MapFeature(this, "language").setProtected() /** The minimum threshold for the final result, otherwise it will be either `"unk"` or the value * set in `thresholdLabel` (Default: `0.1f`). Value is between 0.0 to 1.0. Try to set this @@ -179,15 +179,13 @@ class LanguageDetectorDL(override val uid: String) /** @group setParam */ def setLanguage(value: Map[String, Int]): this.type = { - if (get(language).isEmpty) - set(this.language, value) + set(this.language, value) this } /** @group setParam */ def setAlphabet(value: Map[String, Int]): this.type = { - if (get(language).isEmpty) - set(alphabet, value) + set(alphabet, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala index 751d2ed01f1d08..66d97181a86e38 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/BartTransformer.scala @@ -401,8 +401,7 @@ class BartTransformer(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala index cbd7da632c5fff..e15729d3f05a42 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/MarianTransformer.scala @@ -150,7 +150,8 @@ class 
MarianTransformer(override val uid: String) with HasBatchedAnnotate[MarianTransformer] with WriteTensorflowModel with WriteSentencePieceModel - with HasEngine { + with HasEngine + with HasProtectedParams { /** Annotator reference id. Used to identify elements in metadata or to refer to this annotator * type @@ -177,13 +178,11 @@ class MarianTransformer(override val uid: String) val vocabulary = new StringArrayParam( this, "vocabulary", - "Vocabulary used to encode and decode piece words generated by SentencePiece") + "Vocabulary used to encode and decode piece words generated by SentencePiece").setProtected() /** @group setParam */ def setVocabulary(value: Array[String]): this.type = { - if (get(vocabulary).isEmpty) - set(vocabulary, value) - this + set(vocabulary, value) } /** Controls the maximum length for encoder inputs (source language texts) (Default: `40`) @@ -286,8 +285,7 @@ class MarianTransformer(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala index 30242f72a791fa..ac70e675f3df97 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/T5Transformer.scala @@ -387,8 +387,7 @@ class T5Transformer(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala index 86666830941918..1f4c0e8a2923b3 100644 --- 
a/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/AlbertEmbeddings.scala @@ -242,9 +242,7 @@ class AlbertEmbeddings(override val uid: String) /** @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } /** It contains TF model signatures for the laded saved model @@ -256,8 +254,7 @@ class AlbertEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala index 61d13aeb63199f..9e9307a44fe0ea 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertEmbeddings.scala @@ -229,8 +229,7 @@ class BertEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -266,9 +265,7 @@ class BertEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } /** Whether to lowercase tokens or not @@ -276,9 +273,7 @@ class BertEmbeddings(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(dimension -> 768, batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> false) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala 
b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala index c3c533c267289f..ffce05ef5d155f 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala @@ -155,7 +155,8 @@ class BertSentenceEmbeddings(override val uid: String) with HasEmbeddingsProperties with HasStorageRef with HasCaseSensitiveProperties - with HasEngine { + with HasEngine + with HasProtectedParams { def this() = this(Identifiable.randomUID("BERT_SENTENCE_EMBEDDINGS")) @@ -190,15 +191,14 @@ class BertSentenceEmbeddings(override val uid: String) parent = this, name = "isLong", "Use Long type instead of Int type for inputs buffer - Some Bert models require Long instead of Int.") + .setProtected() /** set isLong * * @group setParam */ def setIsLong(value: Boolean): this.type = { - if (get(isLong).isEmpty) - set(this.isLong, value) - this + set(this.isLong, value) } /** get isLong @@ -223,10 +223,7 @@ class BertSentenceEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this - + set(this.dimension, value) } /** Whether to lowercase tokens or not @@ -234,9 +231,7 @@ class BertSentenceEmbeddings(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } /** Vocabulary used to encode the words to ids with WordPieceEncoder @@ -262,9 +257,7 @@ class BertSentenceEmbeddings(override val uid: String) value <= 512, "BERT models do not support sequences longer than 512 because of trainable positional embeddings") - if (get(maxSentenceLength).isEmpty) - set(maxSentenceLength, value) - this + set(maxSentenceLength, value) } /** ConfigProto from tensorflow, serialized into byte array. 
Get with @@ -296,8 +289,7 @@ class BertSentenceEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala index b1fd6c3461c7ff..12c2b4d1edaef0 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/CamemBertEmbeddings.scala @@ -192,8 +192,7 @@ class CamemBertEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -228,9 +227,7 @@ class CamemBertEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } /** Whether to lowercase tokens or not @@ -238,9 +235,7 @@ class CamemBertEmbeddings(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(batchSize -> 8, dimension -> 768, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala index f40aabbf80bb3f..e502484f1f1521 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DeBertaEmbeddings.scala @@ -223,9 +223,7 @@ class DeBertaEmbeddings(override val uid: String) /** @group setParam */ override def setDimension(value: Int): this.type 
= { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } /** It contains TF model signatures for the laded saved model @@ -237,8 +235,7 @@ class DeBertaEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala index 4bfb0a4027a87c..b6ec6216d2c1b1 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/DistilBertEmbeddings.scala @@ -234,8 +234,7 @@ class DistilBertEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -271,9 +270,7 @@ class DistilBertEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } /** Whether to lowercase tokens or not @@ -281,9 +278,7 @@ class DistilBertEmbeddings(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(dimension -> 768, batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> false) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecApproach.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecApproach.scala index 7cdf870ece470b..5f66d62336775c 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecApproach.scala +++ 
b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecApproach.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.nlp.AnnotatorType.{SENTENCE_EMBEDDINGS, TOKEN} -import com.johnsnowlabs.nlp.{AnnotatorApproach, HasEnableCachingProperties} +import com.johnsnowlabs.nlp.{AnnotatorApproach, HasEnableCachingProperties, HasProtectedParams} import com.johnsnowlabs.storage.HasStorageRef import org.apache.spark.ml.PipelineModel import org.apache.spark.ml.param.{DoubleParam, IntParam} @@ -98,7 +98,8 @@ import org.apache.spark.sql.{Dataset, SparkSession} class Doc2VecApproach(override val uid: String) extends AnnotatorApproach[Doc2VecModel] with HasStorageRef - with HasEnableCachingProperties { + with HasEnableCachingProperties + with HasProtectedParams { def this() = this(Identifiable.randomUID("Doc2VecApproach")) @@ -123,12 +124,12 @@ class Doc2VecApproach(override val uid: String) */ val vectorSize = new IntParam(this, "vectorSize", "the dimension of codes after transforming from words (> 0)") + .setProtected() /** @group setParam */ def setVectorSize(value: Int): this.type = { require(value > 0, s"vector size must be positive but got $value") set(vectorSize, value) - this } /** @group getParam */ diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecModel.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecModel.scala index 57784dfe9884e9..6b2d6f86664a50 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecModel.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Doc2VecModel.scala @@ -147,16 +147,14 @@ class Doc2VecModel(override val uid: String) this, "vectorSize", "the dimension of codes after transforming from words (> 0)", - ParamValidators.gt(0)) + ParamValidators.gt(0)).setProtected() /** @group getParam */ def getVectorSize: Int = $(vectorSize) /** @group setParam */ def setVectorSize(value: Int): this.type = { - if (get(vectorSize).isEmpty) - set(vectorSize, value) - this + 
set(vectorSize, value) } /** Dictionary of words with their vectors diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala index 5e1d2a1a323d6c..647061442198c1 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/ElmoEmbeddings.scala @@ -223,9 +223,7 @@ class ElmoEmbeddings(override val uid: String) * @group setParam */ def setBatchSize(size: Int): this.type = { - if (get(batchSize).isEmpty) - set(batchSize, size) - this + set(batchSize, size) } /** Set Dimension of pooling layer. This is meta for the annotation and will not affect the @@ -234,9 +232,7 @@ class ElmoEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/HasEmbeddingsProperties.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/HasEmbeddingsProperties.scala index 476230dfaf72d6..595e82b450f6d7 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/HasEmbeddingsProperties.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/HasEmbeddingsProperties.scala @@ -16,18 +16,18 @@ package com.johnsnowlabs.nlp.embeddings -import com.johnsnowlabs.nlp.AnnotatorType +import com.johnsnowlabs.nlp.{AnnotatorType, HasProtectedParams} import org.apache.spark.ml.param.{IntParam, Params} import org.apache.spark.sql.Column import org.apache.spark.sql.types.MetadataBuilder -trait HasEmbeddingsProperties extends Params { +trait HasEmbeddingsProperties extends Params with HasProtectedParams { /** Number of embedding dimensions (Default depends on model) * * @group param */ - val dimension = new IntParam(this, "dimension", "Number of embedding dimensions") + val dimension = new IntParam(this, "dimension", "Number of embedding 
dimensions").setProtected() /** @group setParam */ def setDimension(value: Int): this.type = set(this.dimension, value) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala index 345e7f8b9787ed..daeca631328740 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/LongformerEmbeddings.scala @@ -235,8 +235,7 @@ class LongformerEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -273,9 +272,7 @@ class LongformerEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } /** Whether to lowercase tokens or not @@ -283,9 +280,7 @@ class LongformerEmbeddings(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(dimension -> 768, batchSize -> 4, maxSentenceLength -> 1024, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala index 4cfb0210278bd5..2a614003027248 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaEmbeddings.scala @@ -248,8 +248,7 @@ class RoBertaEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -286,9 +285,7 @@ class 
RoBertaEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } /** Whether to lowercase tokens or not @@ -296,9 +293,7 @@ class RoBertaEmbeddings(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(dimension -> 768, batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala index ab55c5f0026c8e..0ef98f067c48b6 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/RoBertaSentenceEmbeddings.scala @@ -245,8 +245,7 @@ class RoBertaSentenceEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -283,9 +282,7 @@ class RoBertaSentenceEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } /** Whether to lowercase tokens or not @@ -293,9 +290,7 @@ class RoBertaSentenceEmbeddings(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(dimension -> 768, batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala 
b/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala index 2cd55a147d5be2..2cc0712a615e1d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/UniversalSentenceEncoder.scala @@ -190,15 +190,14 @@ class UniversalSentenceEncoder(override val uid: String) "loadSP", "Whether to load SentencePiece ops file which is required only by multi-lingual models. " + "This is not changeable after it's set with a pretrained model nor it is compatible with Windows.") + .setProtected() /** Whether to load SentencePiece ops file which is required only by multi-lingual models. * * @group setParam */ def setLoadSP(value: Boolean): this.type = { - if (get(loadSP).isEmpty) - set(this.loadSP, value) - this + set(this.loadSP, value) } /** Whether to load SentencePiece ops file which is required only by multi-lingual models. diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecApproach.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecApproach.scala index 5bcde15cc9ebf5..555816ec2a7d70 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecApproach.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecApproach.scala @@ -17,7 +17,7 @@ package com.johnsnowlabs.nlp.embeddings import com.johnsnowlabs.nlp.AnnotatorType.{TOKEN, WORD_EMBEDDINGS} -import com.johnsnowlabs.nlp.{AnnotatorApproach, HasEnableCachingProperties} +import com.johnsnowlabs.nlp.{AnnotatorApproach, HasEnableCachingProperties, HasProtectedParams} import com.johnsnowlabs.storage.HasStorageRef import org.apache.spark.ml.PipelineModel import org.apache.spark.ml.param.{DoubleParam, IntParam} @@ -98,7 +98,8 @@ import org.apache.spark.sql.{Dataset, SparkSession} class Word2VecApproach(override val uid: String) extends AnnotatorApproach[Word2VecModel] with HasStorageRef - with HasEnableCachingProperties { + with HasEnableCachingProperties + with HasProtectedParams { 
def this() = this(Identifiable.randomUID("Word2VecApproach")) @@ -123,6 +124,7 @@ class Word2VecApproach(override val uid: String) */ val vectorSize = new IntParam(this, "vectorSize", "the dimension of codes after transforming from words (> 0)") + .setProtected() /** @group setParam */ def setVectorSize(value: Int): this.type = { diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecModel.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecModel.scala index 33aac55349cbdd..5ddf760450df4d 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecModel.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/Word2VecModel.scala @@ -148,16 +148,14 @@ class Word2VecModel(override val uid: String) this, "vectorSize", "the dimension of codes after transforming from words (> 0)", - ParamValidators.gt(0)) + ParamValidators.gt(0)).setProtected() /** @group getParam */ def getVectorSize: Int = $(vectorSize) /** @group setParam */ def setVectorSize(value: Int): this.type = { - if (get(vectorSize).isEmpty) - set(vectorSize, value) - this + set(vectorSize, value) } /** Dictionary of words with their vectors diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala index e429a9efc942ca..c969cdd310577b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaEmbeddings.scala @@ -223,8 +223,7 @@ class XlmRoBertaEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -261,9 +260,7 @@ class XlmRoBertaEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, 
value) } /** Whether to lowercase tokens or not @@ -271,9 +268,7 @@ class XlmRoBertaEmbeddings(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(dimension -> 768, batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala index d30cfdb93e819c..19f8a502300417 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlmRoBertaSentenceEmbeddings.scala @@ -220,8 +220,7 @@ class XlmRoBertaSentenceEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } @@ -258,9 +257,7 @@ class XlmRoBertaSentenceEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } /** Whether to lowercase tokens or not @@ -268,9 +265,7 @@ class XlmRoBertaSentenceEmbeddings(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } setDefault(dimension -> 768, batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true) diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala index 5faea3cb0d20d9..a5bfac1d55159b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala +++ 
b/src/main/scala/com/johnsnowlabs/nlp/embeddings/XlnetEmbeddings.scala @@ -242,9 +242,7 @@ class XlnetEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } /** It contains TF model signatures for the laded saved model @@ -256,8 +254,7 @@ class XlnetEmbeddings(override val uid: String) /** @group setParam */ def setSignatures(value: Map[String, String]): this.type = { - if (get(signatures).isEmpty) - set(signatures, value) + set(signatures, value) this } diff --git a/src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala b/src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala index 85332fec330ab4..9bfc20ae97a918 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala @@ -118,14 +118,20 @@ abstract class Feature[Serializable1, Serializable2, TComplete: ClassTag]( } final def setValue(value: Option[Any]): HasFeatures = { - // TODO: make sure we log if there is any protected param is being set - // if (isProtected && isSet) - if (useBroadcast) { - if (isSet) broadcastValue.get.destroy() - broadcastValue = - value.map(v => spark.sparkContext.broadcast[TComplete](v.asInstanceOf[TComplete])) + if (isProtected && isSet) { + val warnString = + s"Warning: The parameter ${this.name} is protected and can only be set once." + + " For a pretrained model, this was done during the initialization process." + + " If you are trying to train your own model, please check the documentation." 
+ println(warnString) } else { - rawValue = value.map(_.asInstanceOf[TComplete]) + if (useBroadcast) { + if (isSet) broadcastValue.get.destroy() + broadcastValue = + value.map(v => spark.sparkContext.broadcast[TComplete](v.asInstanceOf[TComplete])) + } else { + rawValue = value.map(_.asInstanceOf[TComplete]) + } } model } diff --git a/src/test/scala/com/johnsnowlabs/nlp/HasFeaturesTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/HasFeaturesTestSpec.scala index 866903b4096c2d..a7610e2ce5fe9a 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/HasFeaturesTestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/HasFeaturesTestSpec.scala @@ -2,6 +2,7 @@ package com.johnsnowlabs.nlp import com.johnsnowlabs.nlp.serialization.StructFeature import com.johnsnowlabs.tags.FastTest +import com.johnsnowlabs.util.TestUtils.captureOutput import org.scalatest.flatspec.AnyFlatSpec class HasFeaturesTestSpec extends AnyFlatSpec { @@ -15,19 +16,19 @@ class HasFeaturesTestSpec extends AnyFlatSpec { val model = new MockModel - behavior of "HasFeaturesModels" + behavior of "HasFeatures" - it should "set protected params only once" taggedAs FastTest in { + it should "set protected features only once" taggedAs FastTest in { model.setProtectedMockFeature("first") assert(model.getProtectedMockFeature == "first") -// assertThrows[IllegalArgumentException] { -// model.setProtectedMockFeature("second") -// } - model.setProtectedMockFeature("second") + val output = captureOutput { + model.setProtectedMockFeature("second") + } + assert(output.contains("is protected and can only be set once")) + // should stay the same as the first value - // TODO: this should be first -// assert(model.getProtectedMockFeature == "first") + assert(model.getProtectedMockFeature == "first") } } diff --git a/src/test/scala/com/johnsnowlabs/nlp/HasProtectedParamsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/HasProtectedParamsTestSpec.scala new file mode 100644 index 00000000000000..a6d558f13ff3e0 --- /dev/null 
+++ b/src/test/scala/com/johnsnowlabs/nlp/HasProtectedParamsTestSpec.scala @@ -0,0 +1,42 @@ +package com.johnsnowlabs.nlp + +import com.johnsnowlabs.nlp.serialization.StructFeature +import com.johnsnowlabs.tags.FastTest +import com.johnsnowlabs.util.TestUtils.captureOutput +import org.apache.spark.ml.param.Param +import org.scalatest.flatspec.AnyFlatSpec +class HasProtectedParamsTestSpec extends AnyFlatSpec { + class MockModel extends AnnotatorModel[MockModel] with HasProtectedParams { + override val uid: String = "MockModel" + override val outputAnnotatorType: AnnotatorType = AnnotatorType.DUMMY + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(AnnotatorType.DUMMY) + + val protectedParam = + new Param[String](this, "MockString", "Mock protected Param").setProtected() + def setProtectedParam(value: String): this.type = { + set(protectedParam, value) + } + + def getProtectedParam: String = { + $(protectedParam) + } + + } + + val model = new MockModel + + behavior of "HasProtectedParams" + + it should "set protected params only once" taggedAs FastTest in { + model.setProtectedParam("first") + + assert(model.getProtectedParam == "first") + + val output = captureOutput { + model.setProtectedParam("second") + + } + assert(output.contains("is protected and can only be set once")) + assert(model.getProtectedParam == "first") + } +} diff --git a/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloaderMetaSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloaderMetaSpec.scala index 1d17ffb596699a..0af8283a589726 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloaderMetaSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloaderMetaSpec.scala @@ -17,6 +17,7 @@ package com.johnsnowlabs.nlp.pretrained import com.johnsnowlabs.tags.{FastTest, SlowTest} +import com.johnsnowlabs.util.TestUtils.captureOutput import com.johnsnowlabs.util.Version import org.scalatest.flatspec.AnyFlatSpec 
import org.scalatest.BeforeAndAfter @@ -42,14 +43,6 @@ class ResourceDownloaderMetaSpec extends AnyFlatSpec with BeforeAndAfter { ResourceDownloader.communityDownloader = realCommunityDownloader } - def captureOutput(thunk: => Unit): String = { - val stream = new java.io.ByteArrayOutputStream() - Console.withOut(stream) { - thunk - } - stream.toString - } - def extractTableContent(string: String): Array[String] = { val split = string.split("\n") split.slice(3, split.length - 1) diff --git a/src/test/scala/com/johnsnowlabs/util/TestUtils.scala b/src/test/scala/com/johnsnowlabs/util/TestUtils.scala index d935896a8e715d..4b87982ff74953 100644 --- a/src/test/scala/com/johnsnowlabs/util/TestUtils.scala +++ b/src/test/scala/com/johnsnowlabs/util/TestUtils.scala @@ -43,4 +43,13 @@ private[johnsnowlabs] object TestUtils { } } } + + def captureOutput(thunk: => Unit): String = { + val stream = new java.io.ByteArrayOutputStream() + Console.withOut(stream) { + thunk + } + stream.toString + } + }