diff --git a/src/transformers/pipelines/zero_shot_classification.py b/src/transformers/pipelines/zero_shot_classification.py index 9a600bc8ad0f..f4aee3341e30 100644 --- a/src/transformers/pipelines/zero_shot_classification.py +++ b/src/transformers/pipelines/zero_shot_classification.py @@ -239,7 +239,10 @@ def _forward(self, inputs): def postprocess(self, model_outputs, multi_label=False): candidate_labels = [outputs["candidate_label"] for outputs in model_outputs] sequences = [outputs["sequence"] for outputs in model_outputs] - logits = np.concatenate([output["logits"].numpy() for output in model_outputs]) + if self.framework == "pt": + logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs]) + else: + logits = np.concatenate([output["logits"].numpy() for output in model_outputs]) N = logits.shape[0] n = len(candidate_labels) num_sequences = N // n