[CI] Check test if the GenerationTesterMixin inheritance is correct 🐛 🔫
#36180
Changes from 8 commits
@@ -19,6 +19,7 @@
 import datetime
 import gc
 import inspect
+import random
 import tempfile
 import unittest
 import warnings
@@ -48,7 +49,6 @@
 )
 from transformers.utils import is_ipex_available

-from ..test_modeling_common import floats_tensor, ids_tensor
 from .test_framework_agnostic import GenerationIntegrationTestsMixin
@@ -2753,6 +2753,41 @@ def test_speculative_sampling_target_distribution(self):
         self.assertTrue(last_token_counts[8] > last_token_counts[3])


+global_rng = random.Random()
Review thread on the `global_rng` / `ids_tensor` / `floats_tensor` additions:

- copied from …
- a comment with "copied from" can be added, I think
- are these used in a lot of places in this file, or just inside one method? If so, we can probably avoid circular dependencies by importing them within that (single) method?
- That's a good idea, moving to an internal import to prevent code bloat
- uhmmm, local imports would be needed in many places, will go with …
- another possible approach is not to use … but check the … Up to you. P.S. …
- It's also okay to have a pure copy of the short functions :P It's just a handful of lines, I don't think it's worth the extra work for now -- I will have to refactor these lines when we remove TF (i.e. very soon) 👀
- OK 👍
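As a side note on the local-import idea discussed above, a minimal sketch of what that alternative could have looked like (the test class and values are hypothetical, and this is not what the PR ended up doing):

import unittest


class SomeGenerationTest(unittest.TestCase):  # hypothetical test class, for illustration only
    def test_with_random_inputs(self):
        # Deferring the import into the single method that needs the helpers keeps it out
        # of module import time, which is how a circular dependency between the two test
        # modules would be avoided.
        from ..test_modeling_common import floats_tensor, ids_tensor

        input_ids = ids_tensor((2, 5), vocab_size=10)
        scores = floats_tensor((2, 10))
        self.assertEqual(tuple(input_ids.shape), (2, 5))
        self.assertEqual(tuple(scores.shape), (2, 10))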
+def ids_tensor(shape, vocab_size, rng=None, name=None):
+    # Creates a random int32 tensor of the shape within the vocab size
+    if rng is None:
+        rng = global_rng
+
+    total_dims = 1
+    for dim in shape:
+        total_dims *= dim
+
+    values = []
+    for _ in range(total_dims):
+        values.append(rng.randint(0, vocab_size - 1))
+
+    return torch.tensor(data=values, dtype=torch.long, device=torch_device).view(shape).contiguous()
+
+
+def floats_tensor(shape, scale=1.0, rng=None, name=None):
+    """Creates a random float32 tensor"""
+    if rng is None:
+        rng = global_rng
+
+    total_dims = 1
+    for dim in shape:
+        total_dims *= dim
+
+    values = []
+    for _ in range(total_dims):
+        values.append(rng.random() * scale)
+
+    return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()
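For illustration, the copied helpers are used roughly like this in the tests below (the shapes and vocab size here are arbitrary example values, not taken from the diff):

# illustrative values only
batch_size, seq_len, vocab_size = 2, 5, 99
input_ids = ids_tensor((batch_size, seq_len), vocab_size)    # torch.long tensor on torch_device
scores = floats_tensor((batch_size, vocab_size), scale=1.0)  # torch.float tensor on torch_device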
 @pytest.mark.generate
 @require_torch
 class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMixin):
@@ -451,6 +451,8 @@ class BigBirdModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
         if is_torch_available()
         else ()
     )
+    # Doesn't run generation tests. There are interface mismatches when using `generate` -- TODO @gante
+    all_generative_model_classes = ()
Review thread on lines +454 to +455:

- For my understanding: do we need to have an empty `all_generative_model_classes`?
- If a model inherits `GenerationTesterMixin`, … option 2 is intentionally annoying (we are forced to overwrite a property), so we are very explicit about skipping tests. We don't want skips to happen unless we're very intentional about it.
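To make the skip mechanics above concrete, a rough sketch of the two situations (all class and model names below are placeholders, and the "default" behaviour described is an assumption based on this thread, not taken from the mixin's source):

# Default: a test class that inherits GenerationTesterMixin and lists generative
# model classes runs the shared generation tests against them.
class SomeModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (SomeModel,)                         # hypothetical model class
    all_generative_model_classes = (SomeModelForCausalLM,)   # hypothetical model class


# Explicit skip (the route taken for BigBird above): the attribute is deliberately
# overwritten with an empty tuple plus a comment explaining why, so the skip stays
# visible and intentional instead of happening silently.
class SomeSkippedModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    all_model_classes = (SomeModel,)
    # Doesn't run generation tests -- interface mismatches when using `generate`
    all_generative_model_classes = ()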

     pipeline_model_mapping = (
         {
             "feature-extraction": BigBirdModel,
Review thread on `can_generate()`:

- `can_generate()` is only used in `GenerationMixin`-related code. Let's remove the time series model from this function.
- Is it completely different, or does it use part of `generate()`, like some audio models?
- It's completely different ☠️
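For context, `can_generate()` is a classmethod on the modeling side that reports whether a model class is meant to be driven by `GenerationMixin.generate()`. The snippet below is only a simplified, assumed illustration of that kind of check, not the actual transformers implementation:

from transformers import GenerationMixin, PreTrainedModel


def supports_generate(model_class: type) -> bool:
    # Assumed simplification: a class "can generate" only if it genuinely mixes in
    # GenerationMixin. A model that merely defines its own, unrelated `generate()`
    # method (like the time series model discussed above) would not count.
    return issubclass(model_class, PreTrainedModel) and issubclass(model_class, GenerationMixin)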