def remove_nonterminal_branches(self) -> TransitionTree:
    """Create a new tree with only branches that end in terminal states (done=True).

    Trajectories returned by get_trajectories() whose `done` flag is falsy are
    skipped. For each remaining trajectory, every step is added to the new
    tree under the ID formed from the leading components of the trajectory
    ID, with the weight this tree reports for that ID.

    Returns:
        A new TransitionTree constructed with this tree's root ID.
    """
    new_tree = TransitionTree(self.root_id)
    # Terminal branches that fork from a common ancestor each replay the
    # shared prefix steps. A step ID names a single node in the tree, so
    # every node is added only the first time it is encountered.
    added_ids: set[str] = set()

    for trajectory in self.get_trajectories():
        if not trajectory.done:
            continue

        traj_id_parts = cast(str, trajectory.traj_id).split(":")

        for step in trajectory.steps:
            # Step IDs have the form "{parent step ID}:{step index}", so the
            # node for a step is named by the first `timestep + 2` components
            # (root ID plus one index per step; assumes timestep is 0-based).
            step_id = ":".join(traj_id_parts[: step.timestep + 2])
            if step_id in added_ids:
                continue
            added_ids.add(step_id)
            new_tree.add_transition(
                step_id=step_id,
                step=step,
                weight=self.get_weight(step_id),
            )

    return new_tree