Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated ruff to the latest version #31926

Merged
merged 4 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def test_finetune_lr_schedulers(self):
with CaptureStdout() as cs:
args = parser.parse_args(args)
assert False, "--help is expected to sys.exit"
assert excinfo.type == SystemExit
assert excinfo.type is SystemExit
expected = lightning_base.arg_to_scheduler_metavar
assert expected in cs.out, "--help is expected to list the supported schedulers"

Expand All @@ -429,7 +429,7 @@ def test_finetune_lr_schedulers(self):
with CaptureStderr() as cs:
args = parser.parse_args(args)
assert False, "invalid argument is expected to sys.exit"
assert excinfo.type == SystemExit
assert excinfo.type is SystemExit
expected = f"invalid choice: '{unsupported_param}'"
assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@
"rhoknp>=1.1.0,<1.3.1",
"rjieba",
"rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff==0.4.4",
"ruff==0.5.1",
"sacrebleu>=1.4.12,<2.0.0",
"sacremoses",
"safetensors>=0.4.1",
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"rhoknp": "rhoknp>=1.1.0,<1.3.1",
"rjieba": "rjieba",
"rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff": "ruff==0.4.4",
"ruff": "ruff==0.5.1",
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
"sacremoses": "sacremoses",
"safetensors": "safetensors>=0.4.1",
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
)
if type(None) not in field.type.__args__:
# filter `str` in Union
field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1]
origin_type = getattr(field.type, "__origin__", field.type)
elif bool not in field.type.__args__:
# filter `NoneType` in Union (except for `Union[bool, NoneType]`)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def dtype_byte_size(dtype):
4
```
"""
if dtype == bool:
if dtype is bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", dtype.name)
if bit_search is None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/distilbert/modeling_distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def forward(
if output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
if type(sa_output) != tuple:
if type(sa_output) is not tuple:
raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type")

sa_output = sa_output[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def __call__(
if output_attentions:
sa_output, sa_weights = sa_output
else:
assert type(sa_output) == tuple
assert type(sa_output) is tuple
sa_output = sa_output[0]
sa_output = self.sa_layer_norm(sa_output + hidden_states)

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/esm/openfold_utils/rigid_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def __getitem__(self, index: Any) -> Rotation:
Returns:
The indexed rotation
"""
if type(index) != tuple:
if type(index) is not tuple:
index = (index,)

if self._rot_mats is not None:
Expand Down Expand Up @@ -827,7 +827,7 @@ def __getitem__(self, index: Any) -> Rigid:
Returns:
The indexed tensor
"""
if type(index) != tuple:
if type(index) is not tuple:
index = (index,)

return Rigid(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_three_from_single(self, html_string):

for element in html_code.descendants:
if isinstance(element, bs4.element.NavigableString):
if type(element.parent) != bs4.element.Tag:
if type(element.parent) is not bs4.element.Tag:
continue

text_in_this_tag = html.unescape(element).strip()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/musicgen/modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2605,7 +2605,7 @@ def generate(
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())

if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) == tuple:
if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) is tuple:
# wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0])

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def reissue_pt_warnings(caught_warnings):
# Reissue warnings that are not the SAVE_STATE_WARNING
if len(caught_warnings) > 1:
for w in caught_warnings:
if w.category != UserWarning or w.message != SAVE_STATE_WARNING:
if w.category is not UserWarning or w.message != SAVE_STATE_WARNING:
warnings.warn(w.message, w.category)


Expand Down
2 changes: 1 addition & 1 deletion tests/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_react_fails_max_iterations(self):
)
agent.run("What is 2 multiplied by 3.6452?")
assert len(agent.logs) == 7
assert type(agent.logs[-1]["error"]) == AgentMaxIterationsError
assert type(agent.logs[-1]["error"]) is AgentMaxIterationsError

@require_torch
def test_init_agent_with_different_toolsets(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/agents/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def test_evaluate_slicing(self):
def test_access_attributes(self):
code = "integer = 1\nobj_class = integer.__class__\nobj_class"
result = evaluate_python_code(code, {}, state={})
assert result == int
assert result is int

def test_list_comprehension(self):
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
Expand Down Expand Up @@ -529,7 +529,7 @@ def test_types_as_objects(self):
code = "type_a = float(2); type_b = str; type_c = int"
state = {}
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
assert result == int
assert result is int

def test_tuple_id(self):
code = """
Expand Down
2 changes: 1 addition & 1 deletion tests/models/roformer/test_tokenization_roformer.py
Copy link
Contributor Author

@Sai-Suraj-27 Sai-Suraj-27 Jul 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this file, the name of this method test_rust_tokenizer is already defined as an attribute here:


So, the rule F811 of ruff was failing. I am not quite sure what to do here so, i just added the noqa comment as a temporary fix.

Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_tokenizer(self):
exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)

def test_rust_tokenizer(self):
def test_rust_tokenizer(self): # noqa: F811
tokenizer = self.get_rust_tokenizer()
input_text, output_text = self.get_chinese_input_output_texts()
tokens = tokenizer.tokenize(input_text)
Expand Down
Loading