Skip to content

Commit

Permalink
added validation method + doc strings + update call method
Browse files Browse the repository at this point in the history
  • Loading branch information
sambhavnoobcoder committed Jan 25, 2025
1 parent fc269f7 commit 5e45062
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions src/transformers/pipelines/zero_shot_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,26 @@ class ZeroShotClassificationArgumentHandler(ArgumentHandler):
"""
Handles arguments for zero-shot for text classification by turning each possible label into an NLI
premise/hypothesis pair.
Warning:
For security reasons, the hypothesis_template only supports a single, simple {} placeholder for label insertion.
Complex Python string formatting patterns (such as {:>10}, {!r} or {0}) are not supported.
"""

def _validate_hypothesis_template(self, template: str) -> None:
"""Validates that the hypothesis template only contains simple {} placeholders."""
import re

# Check for exactly one simple {} placeholder
placeholder_pattern = r'\{[^{}]*\}'
matches = re.findall(placeholder_pattern, template)

if len(matches) != 1 or matches[0] != '{}':
raise ValueError(
'hypothesis_template must contain exactly one simple {} placeholder for security reasons. '
'Complex Python string formatting patterns (like {:>10}, {!r}, {0}) are not supported.'
)

def _parse_labels(self, labels):
if isinstance(labels, str):
labels = [label.strip() for label in labels.split(",") if label.strip()]
Expand All @@ -25,14 +43,12 @@ def _parse_labels(self, labels):
def __call__(self, sequences, labels, hypothesis_template):
if len(labels) == 0 or len(sequences) == 0:
raise ValueError("You must include at least one label and at least one sequence.")
if hypothesis_template.format(labels[0]) == hypothesis_template:
raise ValueError(
(
'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
"Make sure the passed template includes formatting syntax such as {{}} where the label should go."
).format(hypothesis_template)
)


# Validate the template before using it
self._validate_hypothesis_template(hypothesis_template)

labels = self._parse_labels(labels)

if isinstance(sequences, str):
sequences = [sequences]

Expand Down

0 comments on commit 5e45062

Please sign in to comment.