Skip to content

Commit

Permalink
fix bf16 issue in text classification pipeline (#30996)
Browse files Browse the repository at this point in the history
* fix logits dtype

* Add bf16/fp16 tests for text_classification pipeline

* Update test_pipelines_text_classification.py

* fix

* fix
  • Loading branch information
chujiezheng authored Jun 4, 2024
1 parent de460e2 commit 6b22a8f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/transformers/pipelines/text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def postprocess(self, model_outputs, function_to_apply=None, top_k=1, _legacy=Tr
function_to_apply = ClassificationFunction.NONE

outputs = model_outputs["logits"][0]
outputs = outputs.numpy()
outputs = outputs.float().numpy()

if function_to_apply == ClassificationFunction.SIGMOID:
scores = sigmoid(outputs)
Expand Down
39 changes: 38 additions & 1 deletion tests/pipelines/test_pipelines_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,24 @@

import unittest

import torch

from transformers import (
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TextClassificationPipeline,
pipeline,
)
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow, torch_device
from transformers.testing_utils import (
is_pipeline_test,
nested_simplify,
require_tf,
require_torch,
require_torch_bf16,
require_torch_fp16,
slow,
torch_device,
)

from .test_pipelines_common import ANY

Expand Down Expand Up @@ -106,6 +117,32 @@ def test_accepts_torch_device(self):
outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])

@require_torch_fp16
def test_accepts_torch_fp16(self):
text_classifier = pipeline(
task="text-classification",
model="hf-internal-testing/tiny-random-distilbert",
framework="pt",
device=torch_device,
torch_dtype=torch.float16,
)

outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])

@require_torch_bf16
def test_accepts_torch_bf16(self):
text_classifier = pipeline(
task="text-classification",
model="hf-internal-testing/tiny-random-distilbert",
framework="pt",
device=torch_device,
torch_dtype=torch.bfloat16,
)

outputs = text_classifier("This is great !")
self.assertEqual(nested_simplify(outputs), [{"label": "LABEL_0", "score": 0.504}])

@require_tf
def test_small_model_tf(self):
text_classifier = pipeline(
Expand Down

0 comments on commit 6b22a8f

Please sign in to comment.