Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cecille committed Aug 25, 2023
1 parent e028ee4 commit 0decaaa
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
16 changes: 8 additions & 8 deletions src/python_testing/TC_DeviceBasicComposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,17 +173,17 @@ def separate_endpoint_types(endpoint_dict: dict[int, Any]) -> tuple[list[int], l
return (flat, tree)


def get_all_children(endpoint_id, endpoint_dict: dict[int, Any]):
def get_all_children(endpoint_id, endpoint_dict: dict[int, Any]) -> set[int]:
"""Returns all the children (include subchildren) of the given endpoint
This assumes we've already checked that there are no cycles, so we can do the dumb things and just trace the tree
"""
children = []
children = set()

def add_children(endpoint_id, children):
immediate_children = endpoint_dict[endpoint_id][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]
if len(immediate_children) == 0:
if not immediate_children:
return
children.extend(immediate_children)
children.update(set(immediate_children))
for child in immediate_children:
add_children(child, children)

Expand All @@ -193,7 +193,7 @@ def add_children(endpoint_id, children):

def find_tree_roots(tree_endpoints: list[int], endpoint_dict: dict[int, Any]) -> set[int]:
"""Returns a set of all the endpoints in tree_endpoints that are roots for a tree (not include singletons)"""
tree_roots = []
tree_roots = set()

def find_tree_root(current_id):
for endpoint_id, endpoint in endpoint_dict.items():
Expand All @@ -207,8 +207,8 @@ def find_tree_root(current_id):
for endpoint_id in tree_endpoints:
root = find_tree_root(endpoint_id)
if root != endpoint_id:
tree_roots.append(root)
return set(tree_roots)
tree_roots.add(root)
return tree_roots


def parts_list_cycles(tree_endpoints: list[int], endpoint_dict: dict[int, Any]) -> list[int]:
Expand Down Expand Up @@ -552,7 +552,7 @@ def test_TC_SM_1_2(self):
# ensure that every sub-id in the parts list is included in the parent
sub_children = []
for child in self.endpoints[endpoint_id][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]:
sub_children.extend(get_all_children(child))
sub_children.update(get_all_children(child))
if not all(item in sub_children for item in self.endpoints[endpoint_id][Clusters.Descriptor][Clusters.Descriptor.Attributes.PartsList]):
location = AttributePathLocation(endpoint_id=endpoint_id, cluster_id=cluster_id, attribute_id=attribute_id)
self.record_error(self.get_test_name(), location=location,
Expand Down
8 changes: 4 additions & 4 deletions src/python_testing/TestMatterTestingSupport.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,10 @@ def test_cycle_detection_and_splitting(self):

def test_get_all_children(self):
endpoints = self.create_example_topology()
asserts.assert_equal(get_all_children(2, endpoints), [1, 3, 4, 5, 9], "Child list for ep2 is incorrect")
asserts.assert_equal(get_all_children(6, endpoints), [7, 8], "Child list for ep6 is incorrect")
asserts.assert_equal(get_all_children(13, endpoints), [12, 14, 15, 16], "Child list for ep13 is incorrect")
asserts.assert_equal(get_all_children(17, endpoints), [18, 19], "Child list for ep17 is incorrect")
asserts.assert_equal(get_all_children(2, endpoints), {1, 3, 4, 5, 9}, "Child list for ep2 is incorrect")
asserts.assert_equal(get_all_children(6, endpoints), {7, 8}, "Child list for ep6 is incorrect")
asserts.assert_equal(get_all_children(13, endpoints), {12, 14, 15, 16}, "Child list for ep13 is incorrect")
asserts.assert_equal(get_all_children(17, endpoints), {18, 19}, "Child list for ep17 is incorrect")

def test_get_tree_roots(self):
endpoints = self.create_example_topology()
Expand Down

0 comments on commit 0decaaa

Please sign in to comment.