From 234e813db5ce564e844848a9d61fb6c5b7c4eaf0 Mon Sep 17 00:00:00 2001 From: Siddharth Narayanan Date: Tue, 14 Jan 2025 16:42:50 -0800 Subject: [PATCH 1/2] remove_nonterminal_branches --- ldp/data_structures.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/ldp/data_structures.py b/ldp/data_structures.py index 2c626a6..8f5d7c4 100644 --- a/ldp/data_structures.py +++ b/ldp/data_structures.py @@ -192,7 +192,7 @@ def add_transition( """Add a transition to the tree. Args: - step_id: A unique identifier for the root node of the tree. + step_id: A unique identifier for this node in the tree. The expected form of the step ID is "{parent step ID}:{step index}". step: The transition to add. weight: Weight of the transition. Defaults to 1.0. @@ -336,6 +336,25 @@ def compute_advantages(self) -> None: # See docstring for explanation. # step.metadata["advantage"] = step.value - state_values[parent_id] + def remove_nonterminal_branches(self) -> TransitionTree: + """Creates a new tree with only branches that end in terminal states.""" + new_tree = TransitionTree(self.root_id) + for trajectory in self.get_trajectories(): + if not trajectory.done: + continue + + traj_id_parts = cast(str, trajectory.traj_id).split(":") + + for step in trajectory.steps: + step_id = ":".join(traj_id_parts[: step.timestep + 2]) + new_tree.add_transition( + step_id=step_id, + step=step, + weight=self.get_weight(step_id), + ) + + return new_tree + def merge_identical_nodes( self, agent_state_hash_fn: Callable[[Any], Hashable], From 3771815feb115981553cf687718d4b168f31f7e7 Mon Sep 17 00:00:00 2001 From: Siddharth Narayanan Date: Tue, 14 Jan 2025 16:57:15 -0800 Subject: [PATCH 2/2] update comment --- ldp/data_structures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldp/data_structures.py b/ldp/data_structures.py index 8f5d7c4..661146e 100644 --- a/ldp/data_structures.py +++ b/ldp/data_structures.py @@ -337,7 +337,7 @@ def compute_advantages(self) -> None: # step.metadata["advantage"] = step.value - state_values[parent_id] def remove_nonterminal_branches(self) -> TransitionTree: - """Creates a new tree with only branches that end in terminal states.""" + """Creates a new tree with only branches that end in terminal states (done=True).""" new_tree = TransitionTree(self.root_id) for trajectory in self.get_trajectories(): if not trajectory.done: