Skip to content

Commit

Permalink
[pipeline] missing import regarding assisted generation (huggingface#…
Browse files Browse the repository at this point in the history
…35752)

missing import
  • Loading branch information
gante authored and elvircrn committed Feb 13, 2025
1 parent a43acb0 commit 96f7d49
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/transformers/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import torch
from torch.utils.data import DataLoader, Dataset

from ..modeling_utils import PreTrainedModel
from ..models.auto.modeling_auto import AutoModel

# Re-export for backward compatibility
Expand Down Expand Up @@ -447,7 +448,7 @@ def load_assistant_model(
if not model.can_generate() or assistant_model is None:
return None, None

if not isinstance(model, PreTrainedModel):
if getattr(model, "framework") != "pt" or not isinstance(model, PreTrainedModel):
raise ValueError(
"Assisted generation, triggered by the `assistant_model` argument, is only available for "
"`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
Expand Down

0 comments on commit 96f7d49

Please sign in to comment.