From 9b4d0e01387ef6ccbf4a79fcfd976be77df97b24 Mon Sep 17 00:00:00 2001 From: Umar Butler Date: Thu, 19 Sep 2024 00:41:50 +1000 Subject: [PATCH] Added support for bfloat16 to zero-shot classification pipeline (#33554) * Added support for bfloat16 to zero-shot classification pipeline * Ensure support for TF. Co-authored-by: Matt * Remove dependency on `torch`. Co-authored-by: Matt --------- Co-authored-by: Matt --- src/transformers/pipelines/zero_shot_classification.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 9a600bc8ad0f..f4aee3341e30 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -239,7 +239,10 @@ def _forward(self, inputs): def postprocess(self, model_outputs, multi_label=False): candidate_labels = [outputs["candidate_label"] for outputs in model_outputs] sequences = [outputs["sequence"] for outputs in model_outputs] - logits = np.concatenate([output["logits"].numpy() for output in model_outputs]) + if self.framework == "pt": + logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs]) + else: + logits = np.concatenate([output["logits"].numpy() for output in model_outputs]) N = logits.shape[0] n = len(candidate_labels) num_sequences = N // n