Skip to content

Commit

Permalink
Frameowrk dependant float conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
amyeroberts committed Jun 4, 2024
1 parent 489c63d commit 65c0570
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/transformers/pipelines/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,12 @@ def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=Tr
function_to_apply = ClassificationFunction.NONE

outputs = model_outputs["logits"][0]
outputs = outputs.float().numpy()

if self.framework == "pt":
# To enable using fp16 and bf16
outputs = outputs.float().numpy()
else:
outputs = outputs.numpy()

if function_to_apply == ClassificationFunction.SIGMOID:
scores = sigmoid(outputs)
Expand Down

0 comments on commit 65c0570

Please sign in to comment.