Update single_node flag back to False
wonjoolee95 committed Jul 31, 2023
1 parent b6bf058 commit c046f24
Showing 2 changed files with 7 additions and 10 deletions.
6 changes: 3 additions & 3 deletions test/dynamo/test_dynamo.py
@@ -217,21 +217,21 @@ def fn_fallback(t):
xla_dynamo_res = dynamo_fn(t_xla)
self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
self.assertEqual(met.metric_data('CompileTime')[0], 4)
- self.assertEqual(met.metric_data('ExecuteTime')[0], 7)
+ self.assertEqual(met.metric_data('ExecuteTime')[0], 6)

# Second tracing
met.clear_counters()
xla_dynamo_res_2 = dynamo_fn(t_xla)
self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res_2.cpu()))
self.assertEqual(met.metric_data('CompileTime')[0], 4)
- self.assertEqual(met.metric_data('ExecuteTime')[0], 9)
+ self.assertEqual(met.metric_data('ExecuteTime')[0], 8)

# Verify that dynamo can handle different inputs
xla_dynamo_res_3 = dynamo_fn(t_xla * 3)
cpu_res_3 = fn_fallback(t * 3)
self.assertTrue(torch.allclose(cpu_res_3, xla_dynamo_res_3.cpu()))
self.assertEqual(met.metric_data('CompileTime')[0], 5)
- self.assertEqual(met.metric_data('ExecuteTime')[0], 12)
+ self.assertEqual(met.metric_data('ExecuteTime')[0], 10)


class DynamoTrainingBasicTest(unittest.TestCase):
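
The expected 'ExecuteTime' counts drop from 7, 9, and 12 to 6, 8, and 10, consistent with fewer XLA graph executions once single-node partitions are no longer produced. Below is a minimal sketch, not part of this commit, of how these counters are typically read (it assumes a working torch_xla install and an available XLA device):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
t = torch.randn(4, device=device)
res = (t + 1).cpu()  # materializing the result forces a compile and an execution

# metric_data(name) returns a tuple whose first element is the event count,
# which is what the assertions above compare against.
print(met.metric_data('CompileTime')[0])
print(met.metric_data('ExecuteTime')[0])
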
11 changes: 4 additions & 7 deletions torch_xla/core/dynamo_bridge.py
@@ -196,7 +196,7 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:
return tensor.device.type == "xla"


- def extract_internal(xla_model: torch.fx.GraphModule, self_args):
+ def extract_internal(xla_model: torch.fx.GraphModule):
xla_args = xla_model.xla_args
assert all(
map(
@@ -242,8 +242,6 @@ def extract_internal(xla_model: torch.fx.GraphModule, self_args):
xla_out_ids = {id(x) for x in xla_out}

# If a arg is being in place updated by model, we need to include arg as part of the graph result.
- self_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-     self_args)
xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
xla_args)
xla_args_need_update = []
@@ -293,7 +291,7 @@ def extract_internal(xla_model: torch.fx.GraphModule, self_args):
# in place update will replace the underlying DeviceData of the `xla_args`.
# Note that this needs to happens before `_clear_pending_irs` otherwise
# the additional IR generated by `copy_` won't be cleared.
- if xla_args_need_update_bool or xla_args_need_update_bool:
+ if xla_args_need_update_bool:
for xla_arg, cloned_xla_arg in zip(xla_args, cloned_xla_args):
if isinstance(xla_arg, torch.Tensor):
xla_arg.copy_(cloned_xla_arg)
@@ -428,8 +426,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:

# partition the model and exectue to collect inputs
supported_ops = XlaOperatorSupport()
- partitioner = CapabilityBasedPartitioner(
-     xla_model, supported_ops, allows_single_node_partition=True)
+ partitioner = CapabilityBasedPartitioner(xla_model, supported_ops)
partitions = partitioner.propose_partitions()
partitioned_graph = partitioner.fuse_partitions(partitions)
InputCollector(partitioned_graph).run(*xla_args)
@@ -450,7 +447,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
partitioned_graph.delete_submodule(node.target)
with partitioned_graph.graph.inserting_after(node):
new_node = partitioned_graph.graph.call_function(
-     extract_internal(fused_module, self_args), node.args, None)
+     extract_internal(fused_module), node.args, None)
node.replace_all_uses_with(new_node)
partitioned_graph.graph.erase_node(node)

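The change above drops allows_single_node_partition=True, so the partitioner falls back to CapabilityBasedPartitioner's default of False, under which partitions that contain only a single compute node are not kept. A rough, self-contained sketch of the torch.fx partitioner API that dynamo_bridge.py builds on (AddOnlySupport and fn are illustrative names, not from the commit):

import torch
import torch.fx
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport


class AddOnlySupport(OperatorSupport):
  # Illustrative support policy: only torch.add calls count as supported.
  def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
    return node.op == "call_function" and node.target is torch.add


def fn(x):
  y = torch.add(x, 1)
  z = torch.mul(y, 2)  # unsupported op, splits the supported regions apart
  return torch.add(z, 3)


gm = torch.fx.symbolic_trace(fn)
partitioner = CapabilityBasedPartitioner(gm, AddOnlySupport())
partitions = partitioner.propose_partitions()
fused = partitioner.fuse_partitions(partitions)
# With the default allows_single_node_partition=False, partitions holding only
# one compute node are expected to be dropped rather than fused into submodules.
print(len(partitions))
print(fused.graph)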
