From 05a26c7e0057191c61e45db252905bdff018d0e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malthe=20J=C3=B8rgensen?= Date: Sat, 4 Feb 2023 00:51:45 +0100 Subject: [PATCH] Add B031: Warn when using `groupby()` result multiple times --- README.rst | 3 ++ bugbear.py | 65 +++++++++++++++++++++++++++++++++++++++++++ tests/b031.py | 64 ++++++++++++++++++++++++++++++++++++++++++ tests/test_bugbear.py | 12 ++++++++ 4 files changed, 144 insertions(+) create mode 100644 tests/b031.py diff --git a/README.rst b/README.rst index 8fe5020..99d60c6 100644 --- a/README.rst +++ b/README.rst @@ -181,6 +181,9 @@ It is therefore recommended to use a stacklevel of 2 or greater to provide more **B030**: Except handlers should only be exception classes or tuples of exception classes. +**B031**: Using the generator returned from `itertools.groupby()` more than once will do nothing on the +second usage. Save the result to a list if the result is needed multiple times. + Opinionated warnings ~~~~~~~~~~~~~~~~~~~~ diff --git a/bugbear.py b/bugbear.py index 7f9f80b..503f0cc 100644 --- a/bugbear.py +++ b/bugbear.py @@ -265,6 +265,11 @@ def children_in_scope(node): yield from children_in_scope(child) +def walk_list(nodes): + for node in nodes: + yield from ast.walk(node) + + def _typesafe_issubclass(cls, class_or_tuple): try: return issubclass(cls, class_or_tuple) @@ -401,6 +406,7 @@ def visit_For(self, node): self.check_for_b007(node) self.check_for_b020(node) self.check_for_b023(node) + self.check_for_b031(node) self.generic_visit(node) def visit_AsyncFor(self, node): @@ -793,6 +799,56 @@ def check_for_b026(self, call: ast.Call): ): self.errors.append(B026(starred.lineno, starred.col_offset)) + def check_for_b031(self, loop_node): # noqa: C901 + """Check that `itertools.groupby` isn't iterated over more than once. + + We emit a warning when the generator returned by `groupby()` is used + more than once inside a loop body or when it's used in a nested loop. + """ + # for in : ... + if isinstance(loop_node.iter, ast.Call): + node = loop_node.iter + if (isinstance(node.func, ast.Name) and node.func.id in ("groupby",)) or ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "groupby" + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "itertools" + ): + # We have an invocation of groupby which is a simple unpacking + if isinstance(loop_node.target, ast.Tuple) and isinstance( + loop_node.target.elts[1], ast.Name + ): + group_name = loop_node.target.elts[1].id + else: + # Ignore any `groupby()` invocation that isn't unpacked + return + + num_usages = 0 + for node in walk_list(loop_node.body): + # Handled nested loops + if isinstance(node, ast.For): + for nested_node in walk_list(node.body): + assert nested_node != node + if ( + isinstance(nested_node, ast.Name) + and nested_node.id == group_name + ): + self.errors.append( + B031( + nested_node.lineno, + nested_node.col_offset, + vars=(nested_node.id,), + ) + ) + + # Handle multiple uses + if isinstance(node, ast.Name) and node.id == group_name: + num_usages += 1 + if num_usages > 1: + self.errors.append( + B031(node.lineno, node.col_offset, vars=(node.id,)) + ) + def _get_assigned_names(self, loop_node): loop_targets = (ast.For, ast.AsyncFor, ast.comprehension) for node in children_in_scope(loop_node): @@ -1558,8 +1614,17 @@ def visit_Lambda(self, node): "anything. Add exceptions to handle." ) ) + B030 = Error(message="B030 Except handlers should only be names of exception classes") +B031 = Error( + message=( + "B031 Using the generator returned from `itertools.groupby()` more than once" + " will do nothing on the second usage. Save the result to a list, if the" + " result is needed multiple times." + ) +) + # Warnings disabled by default. B901 = Error( message=( diff --git a/tests/b031.py b/tests/b031.py new file mode 100644 index 0000000..1c090dc --- /dev/null +++ b/tests/b031.py @@ -0,0 +1,64 @@ +""" +Should emit: +B030 - on lines 29, 33, 43 +""" +import itertools +from itertools import groupby + +shoppers = ["Jane", "Joe", "Sarah"] +items = [ + ("lettuce", "greens"), + ("tomatoes", "greens"), + ("cucumber", "greens"), + ("chicken breast", "meats & fish"), + ("salmon", "meats & fish"), + ("ice cream", "frozen items"), +] + +carts = {shopper: [] for shopper in shoppers} + + +def collect_shop_items(shopper, items): + # Imagine this an expensive database query or calculation that is + # advantageous to batch. + carts[shopper] += items + + +# Group by shopping section +for _section, section_items in groupby(items, key=lambda p: p[1]): + for shopper in shoppers: + collect_shop_items(shopper, section_items) + +for _section, section_items in groupby(items, key=lambda p: p[1]): + collect_shop_items("Jane", section_items) + collect_shop_items("Joe", section_items) + + +for _section, section_items in groupby(items, key=lambda p: p[1]): + # This is ok + collect_shop_items("Jane", section_items) + +for _section, section_items in itertools.groupby(items, key=lambda p: p[1]): + for shopper in shoppers: + collect_shop_items(shopper, section_items) + +for group in groupby(items, key=lambda p: p[1]): + # This is bad, but not detected currently + collect_shop_items("Jane", group[1]) + collect_shop_items("Joe", group[1]) + + +# Make sure we ignore - but don't fail on more complicated invocations +for _key, (_value1, _value2) in groupby( + [("a", (1, 2)), ("b", (3, 4)), ("a", (5, 6))], key=lambda p: p[1] +): + collect_shop_items("Jane", group[1]) + collect_shop_items("Joe", group[1]) + +# Make sure we ignore - but don't fail on more complicated invocations +for (_key1, _key2), (_value1, _value2) in groupby( + [(("a", "a"), (1, 2)), (("b", "b"), (3, 4)), (("a", "a"), (5, 6))], + key=lambda p: p[1], +): + collect_shop_items("Jane", group[1]) + collect_shop_items("Joe", group[1]) diff --git a/tests/test_bugbear.py b/tests/test_bugbear.py index 4751e73..731f008 100644 --- a/tests/test_bugbear.py +++ b/tests/test_bugbear.py @@ -42,6 +42,7 @@ B028, B029, B030, + B031, B901, B902, B903, @@ -459,6 +460,17 @@ def test_b030(self): ) self.assertEqual(errors, expected) + def test_b031(self): + filename = Path(__file__).absolute().parent / "b031.py" + bbc = BugBearChecker(filename=str(filename)) + errors = list(bbc.run()) + expected = self.errors( + B031(30, 36, vars=("section_items",)), + B031(34, 30, vars=("section_items",)), + B031(43, 36, vars=("section_items",)), + ) + self.assertEqual(errors, expected) + @unittest.skipIf(sys.version_info < (3, 8), "not implemented for <3.8") def test_b907(self): filename = Path(__file__).absolute().parent / "b907.py"