Skip to content

Commit

Permalink
SPARKNLP-835: Introduce ProtectedParam
Browse files Browse the repository at this point in the history
  • Loading branch information
DevinTDHa committed May 11, 2023
1 parent eb772ae commit 0fdab1d
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 15 deletions.
57 changes: 57 additions & 0 deletions src/main/scala/com/johnsnowlabs/nlp/HasProtectedParams.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package com.johnsnowlabs.nlp

import org.apache.spark.ml.param.{Param, Params}

/** Enables a class to protect a parameter, which means that it can only be set once.
*
* This trait will enable a implicit conversion from Param to ProtectedParam. In addition, the
* new set for ProtectedParam will then check, whether or not the value was already set. If so,
* then a warning will be output and the value will not be set again.
*/
trait HasProtectedParams {
this: Params =>
implicit class ProtectedParam[T](private val param: Param[T])
extends Param[T](param.parent, param.name, param.doc, param.isValid) {

var isProtected = false

/** Sets this parameter to be protected, which means that it can only be set once.
*
* Default values do not count as a set value and can be overridden.
*
* @return
* This object
*/
def setProtected(): this.type = {
isProtected = true
this
}

def toParam: Param[T] = this.asInstanceOf[Param[T]]
}

/** Sets the value for a protected Param.
*
* If the parameter was already set, it will not be set again. Default values do not count as a
* set value and can be overridden.
*
* @param param
* Protected parameter to set
* @param value
* Value for the parameter
* @tparam T
* Type of the parameter
* @return
* This object
*/
def set[T](param: ProtectedParam[T], value: T): this.type = {
if (param.isProtected && get(param).isDefined)
println(
s"Warning: The parameter ${param.name} is protected and can only be set once." +
" For a pretrained model, this was done during the initialization process." +
" If you are trying to train your own model, please check the documentation.")
else
set(param.toParam, value)
this
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ class BertSentenceEmbeddings(override val uid: String)
with HasEmbeddingsProperties
with HasStorageRef
with HasCaseSensitiveProperties
with HasEngine {
with HasEngine
with HasProtectedParams {

def this() = this(Identifiable.randomUID("BERT_SENTENCE_EMBEDDINGS"))

Expand All @@ -180,7 +181,7 @@ class BertSentenceEmbeddings(override val uid: String)
* @group param
*/
val maxSentenceLength =
new IntParam(this, "maxSentenceLength", "Max sentence length to process")
new IntParam(this, "maxSentenceLength", "Max sentence length to process").setProtected()

/** Use Long type instead of Int type for inputs (Default: `false`)
*
Expand All @@ -190,15 +191,14 @@ class BertSentenceEmbeddings(override val uid: String)
parent = this,
name = "isLong",
"Use Long type instead of Int type for inputs buffer - Some Bert models require Long instead of Int.")
.setProtected()

/** set isLong
*
* @group setParam
*/
def setIsLong(value: Boolean): this.type = {
if (get(isLong).isEmpty)
set(this.isLong, value)
this
set(this.isLong, value)
}

/** get isLong
Expand All @@ -223,9 +223,7 @@ class BertSentenceEmbeddings(override val uid: String)
* @group setParam
*/
override def setDimension(value: Int): this.type = {
if (get(dimension).isEmpty)
set(this.dimension, value)
this
set(this.dimension, value)

}

Expand All @@ -234,9 +232,7 @@ class BertSentenceEmbeddings(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

/** Vocabulary used to encode the words to ids with WordPieceEncoder
Expand All @@ -262,9 +258,7 @@ class BertSentenceEmbeddings(override val uid: String)
value <= 512,
"BERT models do not support sequences longer than 512 because of trainable positional embeddings")

if (get(maxSentenceLength).isEmpty)
set(maxSentenceLength, value)
this
set(maxSentenceLength, value)
}

/** ConfigProto from tensorflow, serialized into byte array. Get with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class HasFeaturesTestSpec extends AnyFlatSpec {

behavior of "HasFeatures"

it should "set protected params only once" taggedAs FastTest in {
it should "set protected features only once" taggedAs FastTest in {
model.setProtectedMockFeature("first")
assert(model.getProtectedMockFeature == "first")

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.serialization.StructFeature
import com.johnsnowlabs.tags.FastTest
import com.johnsnowlabs.util.TestUtils.captureOutput
import org.apache.spark.ml.param.Param
import org.scalatest.flatspec.AnyFlatSpec
class HasProtectedParamsTestSpec extends AnyFlatSpec {
class MockModel extends AnnotatorModel[MockModel] with HasProtectedParams {
override val uid: String = "MockModel"
override val outputAnnotatorType: AnnotatorType = AnnotatorType.DUMMY
override val inputAnnotatorTypes: Array[AnnotatorType] = Array(AnnotatorType.DUMMY)

val protectedParam =
new Param[String](this, "MockString", "Mock protected Param").setProtected()
def setProtectedParam(value: String): this.type = {
set(protectedParam, value)
}

def getProtectedParam: String = {
$(protectedParam)
}

}

val model = new MockModel

behavior of "HasProtectedParams"

it should "set protected params only once" taggedAs FastTest in {
model.setProtectedParam("first")

assert(model.getProtectedParam == "first")

val output = captureOutput {
model.setProtectedParam("second")

}
assert(output.contains("is protected and can only be set once"))
assert(model.getProtectedParam == "first")
}
}

0 comments on commit 0fdab1d

Please sign in to comment.