
Commit

Clean up some code
Surround  in an if-statement

Update metrics for fallback related dynamo tests

Update cloned args logic

Revert "Update metrics for fallback related dynamo tests"

This reverts commit 3855f43.
wonjoolee95 committed Jul 27, 2023
1 parent ff88b6e commit 4de70eb
Showing 2 changed files with 29 additions and 25 deletions.
5 changes: 3 additions & 2 deletions test/dynamo/test_dynamo.py
@@ -95,8 +95,9 @@ def forward(self, index, copy_tensor, input_tensor):
     res_cpu = cpu_model.forward(index, copy_tensor, input_tensor)

     xla_model = TestModel(device).to(device)
-    res_xla_dynamo = xla_model.forward(xla_index, xla_copy_tensor,
-                                       xla_input_tensor)
+    compiled_model = torch.compile(xla_model, backend='torchxla_trace_once')
+    res_xla_dynamo = compiled_model.forward(xla_index, xla_copy_tensor,
+                                            xla_input_tensor)

     self.assertIn('xla::index_copy', met.counter_names())
     self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
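
For context, a minimal sketch of the pattern the updated test now exercises: the model is wrapped with `torch.compile` using the `torchxla_trace_once` backend, and the compiled wrapper is called instead of the eager module. `TinyModel`, the tensor shapes, and the assumption that importing `torch_xla` registers the backend are illustrative, not taken from the test suite.

```python
# Illustrative sketch only; TinyModel is a hypothetical stand-in for the
# suite's TestModel, and we assume the torch_xla install registers the
# 'torchxla_trace_once' dynamo backend on import.
import torch
import torch_xla.core.xla_model as xm


class TinyModel(torch.nn.Module):

  def forward(self, x):
    return x * 2 + 1


device = xm.xla_device()
model = TinyModel().to(device)

# Instead of calling model.forward(...) directly, route the call through
# the compiled wrapper, as the updated test does.
compiled_model = torch.compile(model, backend='torchxla_trace_once')
res = compiled_model(torch.ones(3, device=device))
print(res.cpu())
```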
49 changes: 26 additions & 23 deletions torch_xla/core/dynamo_bridge.py
@@ -196,7 +196,7 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:
   return tensor.device.type == "xla"


-def extract_internal(xla_model: torch.fx.GraphModule):
+def extract_internal(xla_model: torch.fx.GraphModule, self_args):
   xla_args = xla_model.xla_args
   assert all(
       map(
@@ -215,9 +215,10 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   # TensorID in `_get_tensors_xla_device_data_node` to create the mapping, the wrong Tensor ID
   # will be returned.
   # TODO(JackCaoG): fix the cloned tensor can't be used to warm up the cache.
-  cloned_xla_args = [
+  all_xla_args = list(xla_args) + self_args
+  cloned_args = [
       torch.clone(xla_arg) if isinstance(xla_arg, torch.Tensor) else xla_arg
-      for xla_arg in xla_args
+      for xla_arg in all_xla_args
   ]

   args_tensor_ids = [
@@ -242,6 +243,8 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   xla_out_ids = {id(x) for x in xla_out}

   # If a arg is being in place updated by model, we need to include arg as part of the graph result.
+  self_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
+      self_args)
   xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
       xla_args)
   xla_args_need_update = []
@@ -291,10 +294,10 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   # in place update will replace the underlying DeviceData of the `xla_args`.
   # Note that this needs to happens before `_clear_pending_irs` otherwise
   # the additional IR generated by `copy_` won't be cleared.
-  if xla_args_need_update_bool:
-    for xla_arg, cloned_xla_arg in zip(xla_args, cloned_xla_args):
-      if isinstance(xla_arg, torch.Tensor):
-        xla_arg.copy_(cloned_xla_arg)
+  if xla_args_need_update_bool or self_args_need_update_bool:
+    for arg, cloned_arg in zip(all_xla_args, cloned_args):
+      if isinstance(arg, torch.Tensor):
+        arg.copy_(cloned_arg)

   # Remove all of the pending IR from all live tensors. The assumptions are
   # 1. With the `xm.mark_step` in the beginning of this call, every XLATensor
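
The guard above now also fires when any of the `self_args` buffers needs materialization, and the copy-back walks `all_xla_args` so buffers are restored alongside the ordinary inputs. A simplified, standalone illustration of this clone-and-restore idea in plain PyTorch (not the bridge's actual API):

```python
import torch


def run_once_and_restore(fn, args):
  # Snapshot every tensor argument, run fn once (it may mutate args in
  # place), then copy the snapshots back so callers see unchanged inputs.
  cloned = [
      torch.clone(a) if isinstance(a, torch.Tensor) else a for a in args
  ]
  out = fn(*args)
  for arg, saved in zip(args, cloned):
    if isinstance(arg, torch.Tensor):
      arg.copy_(saved)
  return out


t = torch.zeros(2)
run_once_and_restore(lambda x: x.add_(1), [t])
print(t)  # tensor([0., 0.]) -- the in-place add was undone
```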
@@ -387,26 +390,29 @@ def call_module(self, target, args, kwargs):
     return super().call_module(target, args, kwargs)


-def extract_compiled_graph(xla_model, xla_args):
+def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
   xm.mark_step()

-  self_tensors = []
+  # If a model's `forward` function has an in-place op that acts on its `self.tensor`, the
+  # `self.tensor` is not included as a part of the `xla_args` and does not get materialized.
+  # This explicitly fetches the `self.tensor`s if they exist.
+  self_args = []
   for name, buffer in xla_model.named_buffers():
     if "self" in name:
-      self_tensors.append(buffer)
-  self_tensors = tuple(self_tensors)
+      self_args.append(buffer)
+  all_xla_args = list(xla_args) + self_args

   # This logic, needed for supporting in-place operations, is a duplicate of
   # the one in the main `extract_internal` function above. We need to do this
   # check for fetching fallback ops as well.
   # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      xla_args + self_tensors)
+  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
+      all_xla_args)

-  cloned_xla_args = [
+  cloned_args = [
       torch.clone(xla_arg) if isinstance(xla_arg, torch.Tensor) else xla_arg
-      for xla_arg in xla_args + self_tensors
+      for xla_arg in all_xla_args
   ]

   # execute model once to collect fallback ops
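
The `named_buffers()` loop above keeps any registered buffer whose name contains `"self"`. A hypothetical module of the kind this targets, with an illustrative buffer name (`self_tensor`) that is mutated in place inside `forward` and therefore never appears in `xla_args`:

```python
import torch
import torch.nn as nn


class InPlaceBufferModel(nn.Module):
  # Hypothetical example: the buffer name contains "self", so the loop in
  # extract_compiled_graph above would collect it into self_args.

  def __init__(self):
    super().__init__()
    self.register_buffer('self_tensor', torch.zeros(4))

  def forward(self, x):
    # In-place update of a buffer that is never passed in as an argument.
    self.self_tensor.add_(x)
    return self.self_tensor


m = InPlaceBufferModel()
print([name for name, _ in m.named_buffers()])  # ['self_tensor']
```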
@@ -429,15 +435,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
   partitioned_graph = partitioner.fuse_partitions(partitions)
   InputCollector(partitioned_graph).run(*xla_args)

-  # TODO (@wonjoo) Add some comment
-  torch_xla._XLAC._xla_sync_multi(self_tensors, devices=[], wait=True)
-
   # Again, same logic in the `extract_internal` above to support in-place operations.
   # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  if xla_args_need_update_bool:
-    for xla_arg, cloned_xla_arg in zip(xla_args, cloned_xla_args):
-      if isinstance(xla_arg, torch.Tensor):
-        xla_arg.copy_(cloned_xla_arg)
+  if args_need_update_bool:
+    for arg, cloned_arg in zip(all_xla_args, cloned_args):
+      if isinstance(arg, torch.Tensor):
+        arg.copy_(cloned_arg)

   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

@@ -448,7 +451,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
       partitioned_graph.delete_submodule(node.target)
       with partitioned_graph.graph.inserting_after(node):
         new_node = partitioned_graph.graph.call_function(
-            extract_internal(fused_module), node.args, None)
+            extract_internal(fused_module, self_args), node.args, None)
       node.replace_all_uses_with(new_node)
       partitioned_graph.graph.erase_node(node)

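The extra `self_args` argument is threaded into `extract_internal` at the point where each fused submodule is rewritten into a `call_function` node. A generic, self-contained sketch of that torch.fx rewrite pattern on a toy graph (the `add_one` replacement is purely illustrative, not what the bridge does):

```python
import torch
import torch.fx as fx


def add_one(x):
  return x + 1


class M(torch.nn.Module):

  def forward(self, x):
    return torch.relu(x)


gm = fx.symbolic_trace(M())

# Replace a node with a call_function node inserted right after it,
# repoint its users, then erase the original node.
for node in list(gm.graph.nodes):
  if node.op == 'call_function' and node.target is torch.relu:
    with gm.graph.inserting_after(node):
      new_node = gm.graph.call_function(add_one, node.args, None)
    node.replace_all_uses_with(new_node)
    gm.graph.erase_node(node)

gm.recompile()
print(gm(torch.tensor([-1.0, 2.0])))  # tensor([0., 3.])
```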
