Skip to content

Commit

Permalink
[SPARKNLP-856] Adding CamemBertForZeroShotClassification to ResourceD…
Browse files Browse the repository at this point in the history
…ownloader
  • Loading branch information
danilojsl committed Aug 7, 2024
1 parent 79643e9 commit d57d13f
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class CamemBertForZeroShotClassification(AnnotatorModel,
>>> sequenceClassifier = CamemBertForZeroShotClassification.pretrained() \\
... .setInputCols(["token", "document"]) \\
... .setOutputCol("label")
The default model is ``"deberta_base_zero_shot_classifier_mnli_anli_v3"``, if no name is
The default model is ``"camembert_zero_shot_classifier_xnli_onnx"``, if no name is
provided.
For available pretrained models please see the `Models Hub
<https://sparknlp.orgtask=Text+Classification>`__.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ trait ReadPretrainedCamemBertForZeroShotClassification
extends ParamsAndFeaturesReadable[CamemBertForZeroShotClassification]
with HasPretrained[CamemBertForZeroShotClassification] {
override val defaultModelName: Some[String] = Some(
"camembert-zero-shot-classifier-xnli-onnx"
"camembert_zero_shot_classifier_xnli_onnx"
)
override val defaultLang: String = "en"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,8 @@ object PythonResourceDownloader {
"MPNetForQuestionAnswering" -> MPNetForQuestionAnswering,
"LLAMA2Transformer" -> LLAMA2Transformer,
"M2M100Transformer" -> M2M100Transformer,
"UAEEmbeddings" -> UAEEmbeddings)
"UAEEmbeddings" -> UAEEmbeddings,
"CamemBertForZeroShotClassification" -> CamemBertForZeroShotClassification)

// List pairs of types such as the one with key type can load a pretrained model from the value type
val typeMapper: Map[String, String] = Map("ZeroShotNerModel" -> "RoBertaForQuestionAnswering")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,27 +109,6 @@ object ResourceMetadata {
candidates: List[ResourceMetadata],
request: ResourceRequest): Option[ResourceMetadata] = {

val compatibleCandidatesName = candidates
.filter(item =>
item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined
&& item.name == request.name)

val compatibleCandidatesLanguage = candidates
.filter(item =>
item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined
&& item.name == request.name
&& (request.language.isEmpty || item.language.isEmpty || request.language.get == item.language.get)
)

val compatibleCandidatesSparkNLPVersion =
candidates
.filter(item =>
item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined
&& item.name == request.name
&& (request.language.isEmpty || item.language.isEmpty || request.language.get == item.language.get)
&& Version.isCompatible(request.libVersion, item.libVersion))

println("")
val compatibleCandidates = candidates
.filter(item =>
item.readyToUse && item.libVersion.isDefined && item.sparkVersion.isDefined
Expand Down

0 comments on commit d57d13f

Please sign in to comment.