From 994f3dd578005ec6eafff3f8fa11c25e44f5e087 Mon Sep 17 00:00:00 2001 From: Tom Kuson Date: Sun, 8 Dec 2024 14:58:51 +0000 Subject: [PATCH] Add B911: itertools.batched without strict= (#502) --- README.rst | 4 ++++ bugbear.py | 17 +++++++++++++++++ tests/b911_py313.py | 24 ++++++++++++++++++++++++ tests/test_bugbear.py | 16 ++++++++++++++++ 4 files changed, 61 insertions(+) create mode 100644 tests/b911_py313.py diff --git a/README.rst b/README.rst index ad1cc1f..38ea0f5 100644 --- a/README.rst +++ b/README.rst @@ -258,6 +258,10 @@ This is meant to be enabled by developers writing visitors using the ``ast`` mod **B910**: Use Counter() instead of defaultdict(int) to avoid excessive memory use as the default dict will record missing keys with the default value when accessed. +**B911**: ``itertools.batched()`` without an explicit `strict=` parameter set. ``strict=True`` causes the resulting iterator to raise a ``ValueError`` if the final batch is shorter than ``n``. + +The ``strict=`` argument was added in Python 3.13, so don't enable this flag for code that should work on <3.13. + **B950**: Line too long. This is a pragmatic equivalent of ``pycodestyle``'s ``E501``: it considers "max-line-length" but only triggers when the value has been exceeded by **more than 10%**. ``noqa`` and ``type: ignore`` comments are ignored. You will no diff --git a/bugbear.py b/bugbear.py index bbe7eed..04cd1ae 100644 --- a/bugbear.py +++ b/bugbear.py @@ -515,6 +515,7 @@ def visit_Call(self, node) -> None: self.check_for_b039(node) self.check_for_b905(node) self.check_for_b910(node) + self.check_for_b911(node) # no need for copying, if used in nested calls it will be set to None current_b040_caught_exception = self.b040_caught_exception @@ -1757,6 +1758,18 @@ def check_for_b910(self, node: ast.Call) -> None: ): self.errors.append(B910(node.lineno, node.col_offset)) + def check_for_b911(self, node: ast.Call) -> None: + if ( + (isinstance(node.func, ast.Name) and node.func.id == "batched") + or ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "batched" + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "itertools" + ) + ) and not any(kw.arg == "strict" for kw in node.keywords): + self.errors.append(B911(node.lineno, node.col_offset)) + def compose_call_path(node): if isinstance(node, ast.Attribute): @@ -2436,6 +2449,9 @@ def visit_Lambda(self, node) -> None: B910 = Error( message="B910 Use Counter() instead of defaultdict(int) to avoid excessive memory use" ) +B911 = Error( + message="B911 `itertools.batched()` without an explicit `strict=` parameter." +) B950 = Error(message="B950 line too long ({} > {} characters)") @@ -2449,5 +2465,6 @@ def visit_Lambda(self, node) -> None: "B908", "B909", "B910", + "B911", "B950", ] diff --git a/tests/b911_py313.py b/tests/b911_py313.py new file mode 100644 index 0000000..61107c8 --- /dev/null +++ b/tests/b911_py313.py @@ -0,0 +1,24 @@ +import itertools +from itertools import batched + +# Expect B911 +batched(range(3), 2) +batched(range(3), n=2) +batched(iterable=range(3), n=2) +itertools.batched(range(3), 2) +itertools.batched(range(3), n=2) +itertools.batched(iterable=range(3), n=2) + +# OK +batched(range(3), 2, strict=True) +batched(range(3), n=2, strict=True) +batched(iterable=range(3), n=2, strict=True) +batched(range(3), 2, strict=False) +batched(range(3), n=2, strict=False) +batched(iterable=range(3), n=2, strict=False) +itertools.batched(range(3), 2, strict=True) +itertools.batched(range(3), n=2, strict=True) +itertools.batched(iterable=range(3), n=2, strict=True) +itertools.batched(range(3), 2, strict=False) +itertools.batched(range(3), n=2, strict=False) +itertools.batched(iterable=range(3), n=2, strict=False) diff --git a/tests/test_bugbear.py b/tests/test_bugbear.py index f63d721..50cb23a 100644 --- a/tests/test_bugbear.py +++ b/tests/test_bugbear.py @@ -59,6 +59,7 @@ B908, B909, B910, + B911, B950, BugBearChecker, BugBearVisitor, @@ -1074,6 +1075,21 @@ def test_b910(self): ] self.assertEqual(errors, self.errors(*expected)) + @unittest.skipIf(sys.version_info < (3, 13), "requires 3.13+") + def test_b911(self): + filename = Path(__file__).absolute().parent / "b911_py313.py" + bbc = BugBearChecker(filename=str(filename)) + errors = list(bbc.run()) + expected = [ + B911(5, 0), + B911(6, 0), + B911(7, 0), + B911(8, 0), + B911(9, 0), + B911(10, 0), + ] + self.assertEqual(errors, self.errors(*expected)) + class TestFuzz(unittest.TestCase): from hypothesis import HealthCheck, given, settings