From 56117dcacc8185bb72f92eeb0fb4c49ac6588abd Mon Sep 17 00:00:00 2001 From: Jan Kwakkel Date: Sat, 21 Sep 2024 15:30:18 +0200 Subject: [PATCH] Update to CellCollection.select (#2307) --- mesa/experimental/cell_space/cell.py | 4 +-- .../cell_space/cell_collection.py | 33 ++++++++++++------- tests/test_cell_space.py | 12 +++++++ 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/mesa/experimental/cell_space/cell.py b/mesa/experimental/cell_space/cell.py index 381d0f6cccb..08a37102ebb 100644 --- a/mesa/experimental/cell_space/cell.py +++ b/mesa/experimental/cell_space/cell.py @@ -48,7 +48,7 @@ class Cell: def __init__( self, coordinate: Coordinate, - capacity: float | None = None, + capacity: int | None = None, random: Random | None = None, ) -> None: """Initialise the cell. @@ -65,7 +65,7 @@ def __init__( self.agents: list[ Agent ] = [] # TODO:: change to AgentSet or weakrefs? (neither is very performant, ) - self.capacity = capacity + self.capacity: int = capacity self.properties: dict[Coordinate, object] = {} self.random = random diff --git a/mesa/experimental/cell_space/cell_collection.py b/mesa/experimental/cell_space/cell_collection.py index 14832d511be..a8c36f5198f 100644 --- a/mesa/experimental/cell_space/cell_collection.py +++ b/mesa/experimental/cell_space/cell_collection.py @@ -83,25 +83,36 @@ def select_random_agent(self) -> CellAgent: """ return self.random.choice(list(self.agents)) - def select(self, filter_func: Callable[[T], bool] | None = None, n=0): + def select( + self, + filter_func: Callable[[T], bool] | None = None, + at_most: int | float = float("inf"), + ): """Select cells based on filter function. Args: filter_func: filter function - n: number of cells to select + at_most: The maximum amount of cells to select. Defaults to infinity. + - If an integer, at most the first number of matching cells is selected. + - If a float between 0 and 1, at most that fraction of original number of cells Returns: CellCollection """ - # FIXME: n is not considered - if filter_func is None and n == 0: + if filter_func is None and at_most == float("inf"): return self - return CellCollection( - { - cell: agents - for cell, agents in self._cells.items() - if filter_func is None or filter_func(cell) - } - ) + if at_most <= 1.0 and isinstance(at_most, float): + at_most = int(len(self) * at_most) # Note that it rounds down (floor) + + def cell_generator(filter_func, at_most): + count = 0 + for cell in self: + if count >= at_most: + break + if not filter_func or filter_func(cell): + yield cell + count += 1 + + return CellCollection(cell_generator(filter_func, at_most)) diff --git a/tests/test_cell_space.py b/tests/test_cell_space.py index 5ef4c44e21d..64d275663d2 100644 --- a/tests/test_cell_space.py +++ b/tests/test_cell_space.py @@ -500,3 +500,15 @@ def test_cell_collection(): agents = collection[cells[0]] assert agents == cells[0].agents + + cell = collection.select(at_most=1) + assert len(cell) == 1 + + cells = collection.select(at_most=2) + assert len(cells) == 2 + + cells = collection.select(at_most=0.5) + assert len(cells) == 5 + + cells = collection.select() + assert len(cells) == len(collection)