Skip to content

Commit

Permalink
remove_nonterminal_branches (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
sidnarayanan authored Jan 15, 2025
1 parent e748ba2 commit 67a55cc
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion ldp/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def add_transition(
"""Add a transition to the tree.
Args:
step_id: A unique identifier for the root node of the tree.
step_id: A unique identifier for this node in the tree.
The expected form of the step ID is "{parent step ID}:{step index}".
step: The transition to add.
weight: Weight of the transition. Defaults to 1.0.
Expand Down Expand Up @@ -336,6 +336,25 @@ def compute_advantages(self) -> None:
# See docstring for explanation.
# step.metadata["advantage"] = step.value - state_values[parent_id]

def remove_nonterminal_branches(self) -> TransitionTree:
"""Creates a new tree with only branches that end in terminal states (done=True)."""
new_tree = TransitionTree(self.root_id)
for trajectory in self.get_trajectories():
if not trajectory.done:
continue

traj_id_parts = cast(str, trajectory.traj_id).split(":")

for step in trajectory.steps:
step_id = ":".join(traj_id_parts[: step.timestep + 2])
new_tree.add_transition(
step_id=step_id,
step=step,
weight=self.get_weight(step_id),
)

return new_tree

def merge_identical_nodes(
self,
agent_state_hash_fn: Callable[[Any], Hashable],
Expand Down

0 comments on commit 67a55cc

Please sign in to comment.