Skip to content

Commit

Permalink
Updated ruff to the latest version (#31926)
Browse files Browse the repository at this point in the history
* Updated ruff version and fixed the required code according to the latest version.

* Updated ruff version and fixed the required code according to the latest version.

* Added noqa directive to ignore 1 error shown by ruff
  • Loading branch information
Sai-Suraj-27 authored Jul 23, 2024
1 parent 9cf4f2a commit d2c687b
Show file tree
Hide file tree
Showing 14 changed files with 17 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def test_finetune_lr_schedulers(self):
with CaptureStdout() as cs:
args = parser.parse_args(args)
assert False, "--help is expected to sys.exit"
assert excinfo.type == SystemExit
assert excinfo.type is SystemExit
expected = lightning_base.arg_to_scheduler_metavar
assert expected in cs.out, "--help is expected to list the supported schedulers"

Expand All @@ -429,7 +429,7 @@ def test_finetune_lr_schedulers(self):
with CaptureStderr() as cs:
args = parser.parse_args(args)
assert False, "invalid argument is expected to sys.exit"
assert excinfo.type == SystemExit
assert excinfo.type is SystemExit
expected = f"invalid choice: '{unsupported_param}'"
assert expected in cs.err, f"should have bailed on invalid choice of scheduler {unsupported_param}"

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@
"rhoknp>=1.1.0,<1.3.1",
"rjieba",
"rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff==0.4.4",
"ruff==0.5.1",
"sacrebleu>=1.4.12,<2.0.0",
"sacremoses",
"safetensors>=0.4.1",
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
"rhoknp": "rhoknp>=1.1.0,<1.3.1",
"rjieba": "rjieba",
"rouge-score": "rouge-score!=0.0.7,!=0.0.8,!=0.1,!=0.1.1",
"ruff": "ruff==0.4.4",
"ruff": "ruff==0.5.1",
"sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
"sacremoses": "sacremoses",
"safetensors": "safetensors>=0.4.1",
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/hf_argparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _parse_dataclass_field(parser: ArgumentParser, field: dataclasses.Field):
)
if type(None) not in field.type.__args__:
# filter `str` in Union
field.type = field.type.__args__[0] if field.type.__args__[1] == str else field.type.__args__[1]
field.type = field.type.__args__[0] if field.type.__args__[1] is str else field.type.__args__[1]
origin_type = getattr(field.type, "__origin__", field.type)
elif bool not in field.type.__args__:
# filter `NoneType` in Union (except for `Union[bool, NoneType]`)
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def dtype_byte_size(dtype):
4
```
"""
if dtype == bool:
if dtype is bool:
return 1 / 8
bit_search = re.search(r"[^\d](\d+)$", dtype.name)
if bit_search is None:
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/distilbert/modeling_distilbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def forward(
if output_attentions:
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
if type(sa_output) != tuple:
if type(sa_output) is not tuple:
raise TypeError(f"sa_output must be a tuple but it is {type(sa_output)} type")

sa_output = sa_output[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def __call__(
if output_attentions:
sa_output, sa_weights = sa_output
else:
assert type(sa_output) == tuple
assert type(sa_output) is tuple
sa_output = sa_output[0]
sa_output = self.sa_layer_norm(sa_output + hidden_states)

Expand Down
4 changes: 2 additions & 2 deletions src/transformers/models/esm/openfold_utils/rigid_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def __getitem__(self, index: Any) -> Rotation:
Returns:
The indexed rotation
"""
if type(index) != tuple:
if type(index) is not tuple:
index = (index,)

if self._rot_mats is not None:
Expand Down Expand Up @@ -827,7 +827,7 @@ def __getitem__(self, index: Any) -> Rigid:
Returns:
The indexed tensor
"""
if type(index) != tuple:
if type(index) is not tuple:
index = (index,)

return Rigid(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_three_from_single(self, html_string):

for element in html_code.descendants:
if isinstance(element, bs4.element.NavigableString):
if type(element.parent) != bs4.element.Tag:
if type(element.parent) is not bs4.element.Tag:
continue

text_in_this_tag = html.unescape(element).strip()
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/musicgen/modeling_musicgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2550,7 +2550,7 @@ def generate(
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())

if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) == tuple:
if model_kwargs.get("encoder_outputs") is not None and type(model_kwargs["encoder_outputs"]) is tuple:
# wrap the unconditional outputs as a BaseModelOutput for compatibility with the rest of generate
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=model_kwargs["encoder_outputs"][0])

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def reissue_pt_warnings(caught_warnings):
# Reissue warnings that are not the SAVE_STATE_WARNING
if len(caught_warnings) > 1:
for w in caught_warnings:
if w.category != UserWarning or w.message != SAVE_STATE_WARNING:
if w.category is not UserWarning or w.message != SAVE_STATE_WARNING:
warnings.warn(w.message, w.category)


Expand Down
2 changes: 1 addition & 1 deletion tests/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def test_react_fails_max_iterations(self):
)
agent.run("What is 2 multiplied by 3.6452?")
assert len(agent.logs) == 7
assert type(agent.logs[-1]["error"]) == AgentMaxIterationsError
assert type(agent.logs[-1]["error"]) is AgentMaxIterationsError

@require_torch
def test_init_agent_with_different_toolsets(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/agents/test_python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def test_evaluate_slicing(self):
def test_access_attributes(self):
code = "integer = 1\nobj_class = integer.__class__\nobj_class"
result = evaluate_python_code(code, {}, state={})
assert result == int
assert result is int

def test_list_comprehension(self):
code = "sentence = 'THESEAGULL43'\nmeaningful_sentence = '-'.join([char.lower() for char in sentence if char.isalpha()])"
Expand Down Expand Up @@ -591,7 +591,7 @@ def test_types_as_objects(self):
code = "type_a = float(2); type_b = str; type_c = int"
state = {}
result = evaluate_python_code(code, {"float": float, "str": str, "int": int}, state=state)
assert result == int
assert result is int

def test_tuple_id(self):
code = """
Expand Down
2 changes: 1 addition & 1 deletion tests/models/roformer/test_tokenization_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_tokenizer(self):
exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)

def test_rust_tokenizer(self):
def test_rust_tokenizer(self): # noqa: F811
tokenizer = self.get_rust_tokenizer()
input_text, output_text = self.get_chinese_input_output_texts()
tokens = tokenizer.tokenize(input_text)
Expand Down

0 comments on commit d2c687b

Please sign in to comment.