From 4de70eb7044dd1e7bdb91a6eb1bab0fd3e9923a0 Mon Sep 17 00:00:00 2001
From: Wonjoo Lee
Date: Sat, 22 Jul 2023 00:23:01 +0000
Subject: [PATCH] Clean up some code

Surround in an if-statement

Update metrics for fallback related dynamo tests

Update cloned args logic

Revert "Update metrics for fallback related dynamo tests"

This reverts commit 3855f4374d479d53e2f431cc4b95afd7e291cdcc.
---
 test/dynamo/test_dynamo.py      |  5 ++--
 torch_xla/core/dynamo_bridge.py | 49 +++++++++++++++++----------------
 2 files changed, 29 insertions(+), 25 deletions(-)

diff --git a/test/dynamo/test_dynamo.py b/test/dynamo/test_dynamo.py
index 1a0146c6a7a8..e09fba0815b7 100644
--- a/test/dynamo/test_dynamo.py
+++ b/test/dynamo/test_dynamo.py
@@ -95,8 +95,9 @@ def forward(self, index, copy_tensor, input_tensor):
     res_cpu = cpu_model.forward(index, copy_tensor, input_tensor)
 
     xla_model = TestModel(device).to(device)
-    res_xla_dynamo = xla_model.forward(xla_index, xla_copy_tensor,
-                                       xla_input_tensor)
+    compiled_model = torch.compile(xla_model, backend='torchxla_trace_once')
+    res_xla_dynamo = compiled_model.forward(xla_index, xla_copy_tensor,
+                                            xla_input_tensor)
     self.assertIn('xla::index_copy', met.counter_names())
     self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
 
diff --git a/torch_xla/core/dynamo_bridge.py b/torch_xla/core/dynamo_bridge.py
index eec3101b58ee..0da5e9ef6bd7 100644
--- a/torch_xla/core/dynamo_bridge.py
+++ b/torch_xla/core/dynamo_bridge.py
@@ -196,7 +196,7 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:
   return tensor.device.type == "xla"
 
 
-def extract_internal(xla_model: torch.fx.GraphModule):
+def extract_internal(xla_model: torch.fx.GraphModule, self_args):
   xla_args = xla_model.xla_args
   assert all(
       map(
@@ -215,9 +215,10 @@
   # TensorID in `_get_tensors_xla_device_data_node` to create the mapping, the wrong Tensor ID
   # will be returned.
   # TODO(JackCaoG): fix the cloned tensor can't be used to warm up the cache.
-  cloned_xla_args = [
+  all_xla_args = list(xla_args) + self_args
+  cloned_args = [
       torch.clone(xla_arg) if isinstance(xla_arg, torch.Tensor) else xla_arg
-      for xla_arg in xla_args
+      for xla_arg in all_xla_args
   ]
 
   args_tensor_ids = [
@@ -242,6 +243,8 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   xla_out_ids = {id(x) for x in xla_out}
 
   # If a arg is being in place updated by model, we need to include arg as part of the graph result.
+  self_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
+      self_args)
   xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
       xla_args)
   xla_args_need_update = []
@@ -291,10 +294,10 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   # in place update will replace the underlying DeviceData of the `xla_args`.
   # Note that this needs to happens before `_clear_pending_irs` otherwise
   # the additional IR generated by `copy_` won't be cleared.
-  if xla_args_need_update_bool:
-    for xla_arg, cloned_xla_arg in zip(xla_args, cloned_xla_args):
-      if isinstance(xla_arg, torch.Tensor):
-        xla_arg.copy_(cloned_xla_arg)
+  if xla_args_need_update_bool or self_args_need_update_bool:
+    for arg, cloned_arg in zip(all_xla_args, cloned_args):
+      if isinstance(arg, torch.Tensor):
+        arg.copy_(cloned_arg)
 
   # Remove all of the pending IR from all live tensors. The assumptions are
   # 1. With the `xm.mark_step` in the beginning of this call, every XLATensor
@@ -387,26 +390,29 @@ def call_module(self, target, args, kwargs):
     return super().call_module(target, args, kwargs)
 
 
-def extract_compiled_graph(xla_model, xla_args):
+def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
   xm.mark_step()
 
-  self_tensors = []
+  # If a model's `forward` function has an in-place op that acts on its `self.tensor`, the
+  # `self.tensor` is not included as a part of the `xla_args` and does not get materialized.
+  # This explicitly fetches the `self.tensor`s if they exist.
+  self_args = []
   for name, buffer in xla_model.named_buffers():
     if "self" in name:
-      self_tensors.append(buffer)
-  self_tensors = tuple(self_tensors)
+      self_args.append(buffer)
+  all_xla_args = list(xla_args) + self_args
 
   # This logic, needed for supporting in-place operations, is a duplicate of
   # the one in the main `extract_internal` function above. We need to do this
   # check for fetching fallback ops as well.
   # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      xla_args + self_tensors)
+  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
+      all_xla_args)
 
-  cloned_xla_args = [
+  cloned_args = [
       torch.clone(xla_arg) if isinstance(xla_arg, torch.Tensor) else xla_arg
-      for xla_arg in xla_args + self_tensors
+      for xla_arg in all_xla_args
   ]
 
   # execute model once to collect fallback ops
@@ -429,15 +435,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
   partitioned_graph = partitioner.fuse_partitions(partitions)
   InputCollector(partitioned_graph).run(*xla_args)
 
-  # TODO (@wonjoo) Add some comment
-  torch_xla._XLAC._xla_sync_multi(self_tensors, devices=[], wait=True)
-
   # Again, same logic in the `extract_internal` above to support in-place operations.
   # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  if xla_args_need_update_bool:
-    for xla_arg, cloned_xla_arg in zip(xla_args, cloned_xla_args):
-      if isinstance(xla_arg, torch.Tensor):
-        xla_arg.copy_(cloned_xla_arg)
+  if args_need_update_bool:
+    for arg, cloned_arg in zip(all_xla_args, cloned_args):
+      if isinstance(arg, torch.Tensor):
+        arg.copy_(cloned_arg)
 
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
 
@@ -448,7 +451,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
       partitioned_graph.delete_submodule(node.target)
      with partitioned_graph.graph.inserting_after(node):
         new_node = partitioned_graph.graph.call_function(
-            extract_internal(fused_module), node.args, None)
+            extract_internal(fused_module, self_args), node.args, None)
       node.replace_all_uses_with(new_node)
       partitioned_graph.graph.erase_node(node)
 
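
For context, below is a minimal sketch (not part of the patch) of the situation the dynamo_bridge change targets: a module whose forward() in-place updates a tensor stored on `self`, compiled with the `torchxla_trace_once` backend exercised in the updated test. The module name `SelfTensorModel`, the buffer name, and the tensor shapes are illustrative assumptions; only the backend string and the `"self" in name` buffer check come from the patch itself.

import torch
import torch_xla.core.xla_model as xm


class SelfTensorModel(torch.nn.Module):

  def __init__(self, device):
    super().__init__()
    # Registered buffer whose name contains "self", so it is picked up by
    # the `named_buffers()` scan in `extract_compiled_graph`.
    self.register_buffer("self_tensor", torch.zeros(3, 5, device=device))

  def forward(self, index, copy_tensor):
    # In-place update on a tensor hanging off `self`; it is not one of the
    # graph inputs (`xla_args`), which is why the bridge collects `self_args`.
    self.self_tensor.index_copy_(0, index, copy_tensor)
    return self.self_tensor


device = xm.xla_device()
model = SelfTensorModel(device)
compiled = torch.compile(model, backend='torchxla_trace_once')
index = torch.tensor([0, 2], device=device)
copy_tensor = torch.ones(2, 5, device=device)
out = compiled(index, copy_tensor)

With a buffer like this, `extract_compiled_graph` appends the tensor to `self_args`, and the shared cloned-args/materialization logic in the patch ensures the in-place update on the buffer is reflected after the compiled graph runs.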