From 206a8689e02a8ca06daa293d7bdd29a5239f021c Mon Sep 17 00:00:00 2001 From: Umar Butler Date: Wed, 18 Sep 2024 18:16:15 +1000 Subject: [PATCH 1/3] Added support for bfloat16 to zero-shot classification pipeline --- src/transformers/pipelines/zero_shot_classification.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 9a600bc8ad0f..520327ac60ec 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -1,6 +1,7 @@ import inspect from typing import List, Union +import torch import numpy as np from ..tokenization_utils import TruncationStrategy @@ -239,7 +240,7 @@ def _forward(self, inputs): def postprocess(self, model_outputs, multi_label=False): candidate_labels = [outputs["candidate_label"] for outputs in model_outputs] sequences = [outputs["sequence"] for outputs in model_outputs] - logits = np.concatenate([output["logits"].numpy() for output in model_outputs]) + logits = np.concatenate([output["logits"].astype(torch.float32).numpy() for output in model_outputs]) N = logits.shape[0] n = len(candidate_labels) num_sequences = N // n From e9eb9075f0f5387a5a3397be66c6fe704fa69900 Mon Sep 17 00:00:00 2001 From: Umar Butler Date: Thu, 19 Sep 2024 00:29:46 +1000 Subject: [PATCH 2/3] Ensure support for TF. Co-authored-by: Matt --- src/transformers/pipelines/zero_shot_classification.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 520327ac60ec..ef2018ed245b 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -240,7 +240,10 @@ def _forward(self, inputs): def postprocess(self, model_outputs, multi_label=False): candidate_labels = [outputs["candidate_label"] for outputs in model_outputs] sequences = [outputs["sequence"] for outputs in model_outputs] - logits = np.concatenate([output["logits"].astype(torch.float32).numpy() for output in model_outputs]) + if self.framework == "pt": + logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs]) + else: + logits = np.concatenate([output["logits"].numpy() for output in model_outputs]) N = logits.shape[0] n = len(candidate_labels) num_sequences = N // n From 782ceb747b5d7710e3fe99b2453522f9581c29f9 Mon Sep 17 00:00:00 2001 From: Umar Butler Date: Thu, 19 Sep 2024 00:30:02 +1000 Subject: [PATCH 3/3] Remove dependency on `torch`. Co-authored-by: Matt --- src/transformers/pipelines/zero_shot_classification.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index ef2018ed245b..f4aee3341e30 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -1,7 +1,6 @@ import inspect from typing import List, Union -import torch import numpy as np from ..tokenization_utils import TruncationStrategy