Skip to content

Commit

Permalink
Make torch_geometric models compatible with export (#123403)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#123403
Approved by: https://github.com/angelayi

Reviewed By: clee2000

Differential Revision: D55828744

Pulled By: tugsbayasgalan

fbshipit-source-id: 2ae53d8fbeb722c2c44b614bd6806d2b036ec97e
  • Loading branch information
tugsbayasgalan authored and facebook-github-bot committed Apr 6, 2024
1 parent 0044106 commit 37dc5a4
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
6 changes: 3 additions & 3 deletions userbenchmark/dynamo/dynamobench/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,12 +1135,12 @@ def load(cls, model, example_inputs, device):
example_outputs = copy.deepcopy(model)(*example_args, **example_kwargs)
_register_dataclass_output_as_pytree(example_outputs)

# TODO(angelayi): change this to predispatch
gm = torch.export._trace._export_to_torch_ir(
gm = torch.export._trace._export(
model,
example_args,
example_kwargs,
)
pre_dispatch=True,
).module()
with torch.no_grad():
so_path = torch._inductor.aot_compile(
gm, example_args, example_kwargs
Expand Down
22 changes: 22 additions & 0 deletions userbenchmark/dynamo/dynamobench/torchbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,20 @@
torch.backends.cuda.matmul.allow_tf32 = True


def _reassign_parameters(model):
# torch_geometric models register parameter as tensors due to
# https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/dense/linear.py#L158-L168
# Since it is unusual thing to do, we just reassign them to parameters
def state_dict_hook(module, destination, prefix, local_metadata):
for name, param in module.named_parameters():
if isinstance(destination[name], torch.Tensor) and not isinstance(
destination[name], torch.nn.Parameter
):
destination[name] = torch.nn.Parameter(destination[name])

model._register_state_dict_hook(state_dict_hook)


def setup_torchbench_cwd():
original_dir = abspath(os.getcwd())

Expand Down Expand Up @@ -265,6 +279,14 @@ def load_model(
extra_args=extra_args,
)
model, example_inputs = benchmark.get_module()
if model_name in [
"basic_gnn_edgecnn",
"basic_gnn_gcn",
"basic_gnn_sage",
"basic_gnn_gin",
]:
_reassign_parameters(model)

# Models that must be in train mode while training
if is_training and (
not use_eval_mode or model_name in self._config["only_training"]
Expand Down

0 comments on commit 37dc5a4

Please sign in to comment.