Skip to content

Commit

Permalink
fix forward mask
Browse files Browse the repository at this point in the history
  • Loading branch information
younik committed Feb 25, 2025
1 parent 80b2a9e commit c24297d
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,20 +1046,20 @@ def forward_masks(self) -> TensorDict:
same_graph_mask = (arange_nodes >= self.tensor["batch_ptr"][:-1, None]) & (
arange_nodes < self.tensor["batch_ptr"][1:, None]
)
# edge_index = torch.where(
# self.tensor["edge_index"][..., None] == self.tensor["node_index"]
# )[2].reshape(self.tensor["edge_index"].shape)
# i, j = edge_index[..., 0], edge_index[..., 1]
edge_index = torch.where(
self.tensor["edge_index"][..., None] == self.tensor["node_index"]
)[2].reshape(self.tensor["edge_index"].shape)
i, j = edge_index[..., 0], edge_index[..., 1]

# for _ in range(len(self.batch_shape)):
# (i, j) = ei1.unsqueeze(0), ei2.unsqueeze(0)
for _ in range(len(self.batch_shape)):
(i, j) = i.unsqueeze(0), j.unsqueeze(0)

# First allow nodes in the same graph to connect, then disable nodes with existing edges
forward_masks["edge_index"][
same_graph_mask[:, :, None] & same_graph_mask[:, None, :]
] = True
torch.diagonal(forward_masks["edge_index"], dim1=-2, dim2=-1).fill_(False)
forward_masks["edge_index"][arange[..., None], ei1, ei2] = False
forward_masks["edge_index"][arange[..., None], i, j] = False
forward_masks["action_type"][..., GraphActionType.ADD_EDGE] &= torch.any(
forward_masks["edge_index"], dim=(-1, -2)
)
Expand Down

0 comments on commit c24297d

Please sign in to comment.