From 65c057081d810e213bc1ca286196d7e2c44e2d64 Mon Sep 17 00:00:00 2001 From: Amy Roberts <22614925+amyeroberts@users.noreply.github.com> Date: Tue, 4 Jun 2024 12:24:51 +0100 Subject: [PATCH] Frameowrk dependant float conversion --- src/transformers/pipelines/text_classification.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/text_classification.py b/src/transformers/pipelines/text_classification.py index bc763c161487..21ca70c2ac50 100644 --- a/src/transformers/pipelines/text_classification.py +++ b/src/transformers/pipelines/text_classification.py @@ -202,7 +202,12 @@ def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=Tr function_to_apply = ClassificationFunction.NONE outputs = model_outputs["logits"][0] - outputs = outputs.float().numpy() + + if self.framework == "pt": + # To enable using fp16 and bf16 + outputs = outputs.float().numpy() + else: + outputs = outputs.numpy() if function_to_apply == ClassificationFunction.SIGMOID: scores = sigmoid(outputs)