diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index bdd91c9d1541..ac03ee96eeab 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -28,6 +28,8 @@ if is_jinja_available(): import jinja2 + from jinja2.ext import Extension + from jinja2.sandbox import ImmutableSandboxedEnvironment else: jinja2 = None @@ -360,11 +362,14 @@ def _render_with_assistant_indices( @lru_cache def _compile_jinja_template(chat_template): - class AssistantTracker(jinja2.ext.Extension): + if not is_jinja_available(): + raise ImportError("apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`.") + + class AssistantTracker(Extension): # This extension is used to track the indices of assistant-generated tokens in the rendered chat tags = {"generation"} - def __init__(self, environment: jinja2.sandbox.ImmutableSandboxedEnvironment): + def __init__(self, environment: ImmutableSandboxedEnvironment): # The class is only initiated by jinja. super().__init__(environment) environment.extend(activate_tracker=self.activate_tracker) @@ -418,7 +423,7 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False) def strftime_now(format): return datetime.now().strftime(format) - jinja_env = jinja2.sandbox.ImmutableSandboxedEnvironment( + jinja_env = ImmutableSandboxedEnvironment( trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols] ) jinja_env.filters["tojson"] = tojson