diff --git a/src/main/scala/com/johnsnowlabs/nlp/HasProtectedParams.scala b/src/main/scala/com/johnsnowlabs/nlp/HasProtectedParams.scala new file mode 100644 index 00000000000000..6a1287ec5fd30d --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/HasProtectedParams.scala @@ -0,0 +1,57 @@ +package com.johnsnowlabs.nlp + +import org.apache.spark.ml.param.{Param, Params} + +/** Enables a class to protect a parameter, which means that it can only be set once. + * + * This trait will enable a implicit conversion from Param to ProtectedParam. In addition, the + * new set for ProtectedParam will then check, whether or not the value was already set. If so, + * then a warning will be output and the value will not be set again. + */ +trait HasProtectedParams { + this: Params => + implicit class ProtectedParam[T](private val param: Param[T]) + extends Param[T](param.parent, param.name, param.doc, param.isValid) { + + var isProtected = false + + /** Sets this parameter to be protected, which means that it can only be set once. + * + * Default values do not count as a set value and can be overridden. + * + * @return + * This object + */ + def setProtected(): this.type = { + isProtected = true + this + } + + def toParam: Param[T] = this.asInstanceOf[Param[T]] + } + + /** Sets the value for a protected Param. + * + * If the parameter was already set, it will not be set again. Default values do not count as a + * set value and can be overridden. + * + * @param param + * Protected parameter to set + * @param value + * Value for the parameter + * @tparam T + * Type of the parameter + * @return + * This object + */ + def set[T](param: ProtectedParam[T], value: T): this.type = { + if (param.isProtected && get(param).isDefined) + println( + s"Warning: The parameter ${param.name} is protected and can only be set once." + + " For a pretrained model, this was done during the initialization process." + + " If you are trying to train your own model, please check the documentation.") + else + set(param.toParam, value) + this + } +} diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala index c2cb03707caa34..4514e681670df8 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/BertSentenceEmbeddings.scala @@ -155,7 +155,8 @@ class BertSentenceEmbeddings(override val uid: String) with HasEmbeddingsProperties with HasStorageRef with HasCaseSensitiveProperties - with HasEngine { + with HasEngine + with HasProtectedParams { def this() = this(Identifiable.randomUID("BERT_SENTENCE_EMBEDDINGS")) @@ -180,7 +181,7 @@ class BertSentenceEmbeddings(override val uid: String) * @group param */ val maxSentenceLength = - new IntParam(this, "maxSentenceLength", "Max sentence length to process") + new IntParam(this, "maxSentenceLength", "Max sentence length to process").setProtected() /** Use Long type instead of Int type for inputs (Default: `false`) * @@ -190,15 +191,14 @@ class BertSentenceEmbeddings(override val uid: String) parent = this, name = "isLong", "Use Long type instead of Int type for inputs buffer - Some Bert models require Long instead of Int.") + .setProtected() /** set isLong * * @group setParam */ def setIsLong(value: Boolean): this.type = { - if (get(isLong).isEmpty) - set(this.isLong, value) - this + set(this.isLong, value) } /** get isLong @@ -223,9 +223,7 @@ class BertSentenceEmbeddings(override val uid: String) * @group setParam */ override def setDimension(value: Int): this.type = { - if (get(dimension).isEmpty) - set(this.dimension, value) - this + set(this.dimension, value) } @@ -234,9 +232,7 @@ class BertSentenceEmbeddings(override val uid: String) * @group setParam */ override def setCaseSensitive(value: Boolean): this.type = { - if (get(caseSensitive).isEmpty) - set(this.caseSensitive, value) - this + set(this.caseSensitive, value) } /** Vocabulary used to encode the words to ids with WordPieceEncoder @@ -262,9 +258,7 @@ class BertSentenceEmbeddings(override val uid: String) value <= 512, "BERT models do not support sequences longer than 512 because of trainable positional embeddings") - if (get(maxSentenceLength).isEmpty) - set(maxSentenceLength, value) - this + set(maxSentenceLength, value) } /** ConfigProto from tensorflow, serialized into byte array. Get with diff --git a/src/test/scala/com/johnsnowlabs/nlp/HasFeaturesTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/HasFeaturesTestSpec.scala index 84cda1316e478e..a7610e2ce5fe9a 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/HasFeaturesTestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/HasFeaturesTestSpec.scala @@ -18,7 +18,7 @@ class HasFeaturesTestSpec extends AnyFlatSpec { behavior of "HasFeatures" - it should "set protected params only once" taggedAs FastTest in { + it should "set protected features only once" taggedAs FastTest in { model.setProtectedMockFeature("first") assert(model.getProtectedMockFeature == "first") diff --git a/src/test/scala/com/johnsnowlabs/nlp/HasProtectedParamsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/HasProtectedParamsTestSpec.scala new file mode 100644 index 00000000000000..a6d558f13ff3e0 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/HasProtectedParamsTestSpec.scala @@ -0,0 +1,42 @@ +package com.johnsnowlabs.nlp + +import com.johnsnowlabs.nlp.serialization.StructFeature +import com.johnsnowlabs.tags.FastTest +import com.johnsnowlabs.util.TestUtils.captureOutput +import org.apache.spark.ml.param.Param +import org.scalatest.flatspec.AnyFlatSpec +class HasProtectedParamsTestSpec extends AnyFlatSpec { + class MockModel extends AnnotatorModel[MockModel] with HasProtectedParams { + override val uid: String = "MockModel" + override val outputAnnotatorType: AnnotatorType = AnnotatorType.DUMMY + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(AnnotatorType.DUMMY) + + val protectedParam = + new Param[String](this, "MockString", "Mock protected Param").setProtected() + def setProtectedParam(value: String): this.type = { + set(protectedParam, value) + } + + def getProtectedParam: String = { + $(protectedParam) + } + + } + + val model = new MockModel + + behavior of "HasProtectedParams" + + it should "set protected params only once" taggedAs FastTest in { + model.setProtectedParam("first") + + assert(model.getProtectedParam == "first") + + val output = captureOutput { + model.setProtectedParam("second") + + } + assert(output.contains("is protected and can only be set once")) + assert(model.getProtectedParam == "first") + } +}