Skip to content

Commit

Permalink
SPARKNLP-835: ProtectedParam and ProtectedFeature (#13797)
Browse files Browse the repository at this point in the history
* SPARKNLP-835: Finalize protected Features

* SPARKNLP-835: Remove redundant checks for protected Features

* SPARKNLP-835: Introduce ProtectedParam

* SPARKNLP-835: Resolve encoding/decoding issue for HasProtectedParams

* SPARKNLP-835: Make caseSensitive settable

* SPARKNLP-835: Make maxSentenceLength, batchSize settable

* SPARKNLP-835: Enable protected Params for Annotators
  • Loading branch information
DevinTDHa authored May 25, 2023
1 parent 8ed57e4 commit e586d56
Show file tree
Hide file tree
Showing 61 changed files with 263 additions and 295 deletions.
62 changes: 62 additions & 0 deletions src/main/scala/com/johnsnowlabs/nlp/HasProtectedParams.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package com.johnsnowlabs.nlp

import org.apache.spark.ml.param.{Param, Params}

/** Enables a class to protect a parameter, which means that it can only be set once.
  *
  * This trait enables an implicit conversion from Param to ProtectedParam. In addition, the
  * `set` overload defined here for ProtectedParam checks whether the value was already set. If
  * so, a warning is printed and the value is not set again.
  */
trait HasProtectedParams {
  this: Params =>

  /** Wrapper around an existing Param that carries a protection flag.
    *
    * The wrapper is constructed with the parent, name, doc and validator of `baseParam`.
    *
    * @param baseParam
    *   The Param being wrapped
    * @tparam T
    *   Type of the parameter value
    */
  implicit class ProtectedParam[T](baseParam: Param[T])
      extends Param[T](baseParam.parent, baseParam.name, baseParam.doc, baseParam.isValid) {

    /** Whether this parameter is protected. Starts as `false` and is enabled by `setProtected()`.
      */
    var isProtected: Boolean = false

    /** Sets this parameter to be protected, which means that it can only be set once.
      *
      * Default values do not count as a set value and can be overridden.
      *
      * @return
      *   This object
      */
    def setProtected(): this.type = {
      isProtected = true
      this
    }

    /** Returns this parameter typed as a plain Param.
      *
      * ProtectedParam already extends Param, so a simple upcast is sufficient; no cast is needed.
      * The static type `Param[T]` is what lets callers select the regular `set(Param, value)`
      * overload instead of the protected one.
      *
      * @return
      *   This object, typed as `Param[T]`
      */
    def toParam: Param[T] = this

    // Delegate JSON (de)serialization to the wrapped Param, so that the encoding defined by the
    // individual Param implementation is used rather than the generic one of Param.
    override def jsonEncode(value: T): String = baseParam.jsonEncode(value)
    override def jsonDecode(json: String): T = baseParam.jsonDecode(json)
  }

  /** Sets the value for a protected Param.
    *
    * If the parameter is protected and was already set, it will not be set again and a warning
    * is printed instead. Default values do not count as a set value and can be overridden.
    * Parameters that are not protected are always set.
    *
    * @param param
    *   Protected parameter to set
    * @param value
    *   Value for the parameter
    * @tparam T
    *   Type of the parameter
    * @return
    *   This object
    */
  def set[T](param: ProtectedParam[T], value: T): this.type = {
    if (param.isProtected && get(param).isDefined) {
      println(
        s"Warning: The parameter ${param.name} is protected and can only be set once." +
          " For a pretrained model, this was done during the initialization process." +
          " If you are trying to train your own model, please check the documentation." +
          " If this is intentional, set the parameter directly with set(annotator.param, value).")
    } else {
      // Pass the parameter typed as Param[T] so the regular Params.set overload is chosen.
      set(param.toParam, value)
    }
    this
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ class Wav2Vec2ForCTC(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ class AlbertForQuestionAnswering(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ class AlbertForSequenceClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -265,9 +264,7 @@ class AlbertForSequenceClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ class AlbertForTokenClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -242,9 +241,7 @@ class AlbertForTokenClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,7 @@ class BertForQuestionAnswering(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,7 @@ class BertForSequenceClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -282,9 +281,7 @@ class BertForSequenceClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,7 @@ class BertForTokenClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -255,9 +254,7 @@ class BertForTokenClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ class BertForZeroShotClassification(override val uid: String)

/** @group setParam */
def setVocabulary(value: Map[String, Int]): this.type = {
if (get(vocabulary).isEmpty)
set(vocabulary, value)
set(vocabulary, value)
this
}

Expand Down Expand Up @@ -256,8 +255,7 @@ class BertForZeroShotClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -295,9 +293,7 @@ class BertForZeroShotClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ class CamemBertForQuestionAnswering(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,7 @@ class CamemBertForSequenceClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -265,9 +264,7 @@ class CamemBertForSequenceClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,7 @@ class CamemBertForTokenClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -242,9 +241,7 @@ class CamemBertForTokenClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,7 @@ class DeBertaForQuestionAnswering(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ class DeBertaForSequenceClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -264,9 +263,7 @@ class DeBertaForSequenceClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ class DeBertaForTokenClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -243,9 +242,7 @@ class DeBertaForTokenClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ class DistilBertForQuestionAnswering(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,7 @@ class DistilBertForSequenceClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -278,9 +277,7 @@ class DistilBertForSequenceClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,7 @@ class DistilBertForTokenClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -255,9 +254,7 @@ class DistilBertForTokenClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,7 @@ class DistilBertForZeroShotClassification(override val uid: String)

/** @group setParam */
def setVocabulary(value: Map[String, Int]): this.type = {
if (get(vocabulary).isEmpty)
set(vocabulary, value)
set(vocabulary, value)
this
}

Expand Down Expand Up @@ -253,8 +252,7 @@ class DistilBertForZeroShotClassification(override val uid: String)

/** @group setParam */
def setSignatures(value: Map[String, String]): this.type = {
if (get(signatures).isEmpty)
set(signatures, value)
set(signatures, value)
this
}

Expand Down Expand Up @@ -291,9 +289,7 @@ class DistilBertForZeroShotClassification(override val uid: String)
* @group setParam
*/
override def setCaseSensitive(value: Boolean): this.type = {
if (get(caseSensitive).isEmpty)
set(this.caseSensitive, value)
this
set(this.caseSensitive, value)
}

setDefault(
Expand Down
Loading

0 comments on commit e586d56

Please sign in to comment.