diff --git a/src/python_testing/TC_DeviceBasicComposition.py b/src/python_testing/TC_DeviceBasicComposition.py index 4dfb3c8914ced8..40ae82be4bad70 100644 --- a/src/python_testing/TC_DeviceBasicComposition.py +++ b/src/python_testing/TC_DeviceBasicComposition.py @@ -173,17 +173,17 @@ def separate_endpoint_types(endpoint_dict: dict[int, Any]) -> tuple[list[int], l return (flat, tree) -def get_all_children(endpoint_id, endpoint_dict: dict[int, Any]): +def get_all_children(endpoint_id, endpoint_dict: dict[int, Any]) -> set[int]: """Returns all the children (include subchildren) of the given endpoint This assumes we've already checked that there are no cycles, so we can do the dumb things and just trace the tree """ - children = [] + children = set() def add_children(endpoint_id, children): immediate_children = endpoint_dict[endpoint_id][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList] - if len(immediate_children) == 0: + if not immediate_children: return - children.extend(immediate_children) + children.update(set(immediate_children)) for child in immediate_children: add_children(child, children) @@ -193,7 +193,7 @@ def add_children(endpoint_id, children): def find_tree_roots(tree_endpoints: list[int], endpoint_dict: dict[int, Any]) -> set[int]: """Returns a set of all the endpoints in tree_endpoints that are roots for a tree (not include singletons)""" - tree_roots = [] + tree_roots = set() def find_tree_root(current_id): for endpoint_id, endpoint in endpoint_dict.items(): @@ -207,8 +207,8 @@ def find_tree_root(current_id): for endpoint_id in tree_endpoints: root = find_tree_root(endpoint_id) if root != endpoint_id: - tree_roots.append(root) - return set(tree_roots) + tree_roots.add(root) + return tree_roots def parts_list_cycles(tree_endpoints: list[int], endpoint_dict: dict[int, Any]) -> list[int]: @@ -552,7 +552,7 @@ def test_TC_SM_1_2(self): # ensure that every sub-id in the parts list is included in the parent sub_children = [] for child in self.endpoints[endpoint_id][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]: - sub_children.extend(get_all_children(child)) + sub_children.update(get_all_children(child)) if not all(item in sub_children for item in self.endpoints[endpoint_id][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]): location = AttributePathLocation(endpoint_id=endpoint_id, cluster_id=cluster_id, attribute_id=attribute_id) self.record_error(self.get_test_name(), location=location, diff --git a/src/python_testing/TestMatterTestingSupport.py b/src/python_testing/TestMatterTestingSupport.py index 010ccf3d834b15..9e63cb9ae58206 100644 --- a/src/python_testing/TestMatterTestingSupport.py +++ b/src/python_testing/TestMatterTestingSupport.py @@ -302,10 +302,10 @@ def test_cycle_detection_and_splitting(self): def test_get_all_children(self): endpoints = self.create_example_topology() - asserts.assert_equal(get_all_children(2, endpoints), [1, 3, 4, 5, 9], "Child list for ep2 is incorrect") - asserts.assert_equal(get_all_children(6, endpoints), [7, 8], "Child list for ep6 is incorrect") - asserts.assert_equal(get_all_children(13, endpoints), [12, 14, 15, 16], "Child list for ep13 is incorrect") - asserts.assert_equal(get_all_children(17, endpoints), [18, 19], "Child list for ep17 is incorrect") + asserts.assert_equal(get_all_children(2, endpoints), {1, 3, 4, 5, 9}, "Child list for ep2 is incorrect") + asserts.assert_equal(get_all_children(6, endpoints), {7, 8}, "Child list for ep6 is incorrect") + asserts.assert_equal(get_all_children(13, endpoints), {12, 14, 15, 16}, "Child list for ep13 is incorrect") + asserts.assert_equal(get_all_children(17, endpoints), {18, 19}, "Child list for ep17 is incorrect") def test_get_tree_roots(self): endpoints = self.create_example_topology()