Add support for in-place ops with self tensors in dynamo bridge #5309

Merged: 4 commits merged into master on Aug 1, 2023

Conversation

@wonjoolee95 (Collaborator) commented on Jul 14, 2023:

Add support for in-place ops with self tensors in dynamo bridge

In models where there is an in-place op on a self.tensor, the self.tensor is not part of the xla_args. To show an example, consider the following model:

import torch
import torch.nn as nn

import torch_xla.core.xla_model as xm

device = xm.xla_device()
print(f'device={device}')

class TestModel(nn.Module):
    def __init__(self, device=None):
        super().__init__()
        self.b = torch.zeros((3, 3), device=device)

    def forward(self, indexes, buffer_update, input):
        self.b.index_copy_(0, indexes, buffer_update)
        output = input + self.b
        return output
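A minimal sketch of how the model above might be exercised with dynamo. The openxla backend name (which was only introduced later, see the commit list at the end) and the example input shapes are illustrative assumptions, not taken from the PR:

model = TestModel(device=device)
dynamo_model = torch.compile(model, backend='openxla')  # backend name is an assumption

# index_copy_ along dim 0 with a single index needs a (1, 3) update tensor
indexes = torch.tensor([0], device=device)
buffer_update = torch.ones((1, 3), device=device)
input_tensor = torch.rand((3, 3), device=device)

output = dynamo_model(indexes, buffer_update, input_tensor)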

Currently, running this model with dynamo will error out with a Check failed: HasValue() error:

RuntimeError: ./torch_xla/csrc/runtime/pjrt_computation_client.h:166 : Check failed: HasValue() 

Looking at xla_args in xla/torch_xla/core/dynamo_bridge.py in the extract_compiled_graph function, we can see that the xla_args passed from dynamo are empty:

[WONJOO | dynamo_bridge.py] xla_args=()

As a result, the self.tensor is not materialized throughout the dynamo code path, eventually causing the Check failed: HasValue() error later when it's accessed.

As a fix, we manually include the self.tensor as part of the xla_args by calling xla_model.named_buffers(), which returns the self.tensor, as shown here:

Code:
for name, buffer in xla_model.named_buffers():
  print(f'[WONJOO | dynamo_bridge.py] name={name}')
  print(f'[WONJOO | dynamo_bridge.py] buffer={buffer}')

Output:
[WONJOO | dynamo_bridge.py] name=L__self___b
[WONJOO | dynamo_bridge.py] buffer=tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='xla:0')

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-inplace branch 2 times, most recently from 1131e7f to 374d967 on July 22, 2023 at 00:00
@wonjoolee95 (Collaborator Author) commented on Jul 22, 2023:

With these changes, I can see that the model in the PR description now passes with dynamo:

device=xla:0
-----CPU-----
before: cpu_model.b=tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
output=tensor([[0.9724, 1.4463, 1.0837],
        [0.4915, 0.1398, 0.0953],
        [0.1789, 0.6370, 0.4601]])
after: cpu_model.b=tensor([[0.2868, 0.7033, 0.8447],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000]])
-----XLA-----
before: model.b=tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='xla:0')
output_xla=tensor([[0.9724, 1.4463, 1.0837],
        [0.4915, 0.1398, 0.0953],
        [0.1789, 0.6370, 0.4601]], device='xla:0')
after: model.b=tensor([[0.2868, 0.7033, 0.8447],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000]], device='xla:0')

Also added a unit test for in-place ops.

@wonjoolee95 changed the title from "[WIP] Add more support for in-place ops in dynamo bridge" to "[WIP] Add support for in-place ops with self tensors in dynamo bridge" on Jul 22, 2023
  for name, buffer in xla_model.named_buffers():
    if "self" in name:
      self_tensors.append(buffer)
  torch_xla._XLAC._xla_sync_multi(self_tensors, devices=[], wait=True)
@wonjoolee95 (Collaborator Author):

Need a way to explicitly materialize the tensor, hence the _xla_sync_multi call.

Collaborator:

Why do we need this? Didn't we already do a mark_step at the entrance of extract_compiled_graph?

@@ -417,11 +417,19 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:

  # partition the model and exectue to collect inputs
  supported_ops = XlaOperatorSupport()
-  partitioner = CapabilityBasedPartitioner(xla_model, supported_ops)
+  partitioner = CapabilityBasedPartitioner(
+      xla_model, supported_ops, allows_single_node_partition=True)
@wonjoolee95 (Collaborator Author) commented on Jul 24, 2023:

The allows_single_node_partition=True flag ensures that the partitioned module (even when it is a single large partition) goes into our own extract_internal call.

@wonjoolee95 (Collaborator Author):

Some metrics failures with Dynamo, most likely due to the new _xla_sync_multi call. Let me make sure that only the metric tests are failing and that the other correctness tests are still passing.

@wonjoolee95 changed the title from "[WIP] Add support for in-place ops with self tensors in dynamo bridge" to "Add support for in-place ops with self tensors in dynamo bridge" on Jul 24, 2023
@wonjoolee95 wonjoolee95 marked this pull request as ready for review July 24, 2023 18:10
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-inplace branch 3 times, most recently from 69cb83b to e296db0 on July 24, 2023 at 18:15
@wonjoolee95 (Collaborator Author):

> Some metrics failures with Dynamo, most likely due to the new _xla_sync_multi call. Let me make sure that only the metric tests are failing and that the other correctness tests are still passing.

I surrounded the _xla_sync_multi call with an if-statement so it only runs if there is something in named_buffers (i.e., there is a self.tensor). So the metrics shouldn't be impacted anymore; I'll let CI verify.
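A minimal sketch of that guard, assuming self_tensors is the list of buffers collected from xla_model.named_buffers() as in the diff above (not necessarily the exact code that landed):

if len(self_tensors) > 0:
  torch_xla._XLAC._xla_sync_multi(self_tensors, devices=[], wait=True)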

@wonjoolee95 (Collaborator Author):

Seems like allows_single_node_partition=True affected the CompileTime metric only for the CpuFallback-related tests (it became one lower). While I make sense of that, I've updated the tests for now to let CI continue to run.

@wonjoolee95 (Collaborator Author):

Hmm, after putting some debugging lines into the xm.mark_step() call (https://github.com/pytorch/xla/blob/master/torch_xla/core/dynamo_bridge.py#L392), I can see that SyncLiveTensorsGraph actually does seem to materialize all the tensors, including the self.tensor. The following prints were added at https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L374, looping through the tensors:

[WONJOO | xla_graph_executor.cpp] tensors.size()=4
[WONJOO | xla_graph_executor.cpp] tensor->GetUniqueId()=1
[WONJOO | xla_graph_executor.cpp] tensor->shape().get().ToString()=f32[3]
[WONJOO | xla_graph_executor.cpp] tensor->GetUniqueId()=2
[WONJOO | xla_graph_executor.cpp] tensor->shape().get().ToString()=f32[3,3]
[WONJOO | xla_graph_executor.cpp] tensor->GetUniqueId()=3
[WONJOO | xla_graph_executor.cpp] tensor->shape().get().ToString()=s64[3]
[WONJOO | xla_graph_executor.cpp] tensor->GetUniqueId()=5
[WONJOO | xla_graph_executor.cpp] tensor->shape().get().ToString()=f32[5,3]

These are logs from running the newly added unit test DynamoInferenceBasicTest.test_simple_model_with_in_place_ops. As we can see, the f32[5,3] tensor with id 5 is the self.tensor, and it does seem to get materialized.

However, with just this mark_step(), and even with the wait=True flag set, it still fails with the buffer with shape f32[5,3] on device CPU:0 is null error if we remove that explicit _xla_sync_multi call.

@wonjoolee95 (Collaborator Author):

Okay, seems like the comment at #5309 (comment) may not be true. In the initial mark_step() call, I can see the following logs:

[WONJOO | init_python_bindings.cpp] at StepMarker
[WONJOO | xla_graph_executor.cpp] at SyncLiveTensorsGraph
[WONJOO | xla_graph_executor.cpp] tensors.size()=4
[WONJOO | xla_graph_executor.cpp] tensor->GetUniqueId()=1
[WONJOO | xla_graph_executor.cpp] tensor->shape().get().ToString()=f32[3]
[WONJOO | xla_graph_executor.cpp] tensor->GetUniqueId()=2
[WONJOO | xla_graph_executor.cpp] tensor->shape().get().ToString()=f32[3,3]
[WONJOO | xla_graph_executor.cpp] tensor->GetUniqueId()=3
[WONJOO | xla_graph_executor.cpp] tensor->shape().get().ToString()=s64[3]
[WONJOO | xla_graph_executor.cpp] tensor->GetUniqueId()=5
[WONJOO | xla_graph_executor.cpp] tensor->shape().get().ToString()=f32[5,3]

I thought f32[5,3] was the self.tensor. However, when I add back the _xla_sync_multi(self_tensors) call, I can see:

[WONJOO | init_python_bindings.cpp] SyncTensors, wait=0, sync_xla_data=1, warm_up_cache_only=0
[WONJOO | xla_graph_executor.cpp] at SyncTensorsGraph
[WONJOO | xla_graph_executor.cpp] wait=0
[WONJOO | xla_graph_executor.cpp] sync_ltc_data=1
[WONJOO | xla_graph_executor.cpp] warm_up_cache_only=0
[WONJOO | xla_graph_executor.cpp] tensors->size()=1
[WONJOO | xla_graph_executor.cpp] tensor->GetUniqueId()=12
[WONJOO | xla_graph_executor.cpp] tensor->shape().get().ToString()=f32[5,3]

It actually looks like the self_tensor has GetUniqueId()=12. And this tensor with GetUniqueId()=12 is not included as part of the initial mark_step call.

@JackCaoG (Collaborator):

Hmm, let me try to repro.

@JackCaoG (Collaborator):

I can repro it. I added a debug message

  for name, buffer in xla_model.named_buffers():
    if "self" in name:
      self_tensors.append(buffer)
      print(torch_xla._XLAC._get_xla_tensor_debug_info(buffer))

and in GetLiveTensors I added the following:

@@ -87,7 +89,9 @@ std::vector<XLATensorPtr> XLAGraphExecutor::DeviceContextArena::GetLiveTensors(
       auto data =
           std::dynamic_pointer_cast<XLATensor::Data>(uid_wptr.second.lock());
       if (data != nullptr) {
-        tensors.push_back(XLATensor::Create(std::move(data)));
+        auto t = XLATensor::Create(std::move(data));
+        std::cerr << "add tensor with id " << t->GetUniqueId() << "\n";
+        tensors.push_back(t);
       }

I saw

add tensor with id 1
add tensor with id 2
add tensor with id 3
add tensor with id 5
XLATensor {
TensorID: 11
Device: CPU:0
XLA Shape: f32[5,3]
IR: [] aten::index_add_, xla_shape=f32[5,3]{1,0}
XLAData: None
Tensor on host: None
}

add tensor with id 1
add tensor with id 2
add tensor with id 3
add tensor with id 6
add tensor with id 7
add tensor with id 8
add tensor with id 10
add tensor with id 11

What this means is that we do include the self.x in a mark_step somewhere. The next step is to figure out which mark_step. I added a pdb to the mark_step call:

--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -797,6 +797,7 @@ def _run_step_closures():
 
 
 def mark_step(wait=False):
+  import pdb; pdb.set_trace()
   if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
     print(
         'torch_xla.core.xla_model::mark_step\n',

and found that the first mark_step is in

391  	def extract_compiled_graph(xla_model, xla_args):
392  	  # This call is critical to make sure xla_args' tensor id show up in graph_input_tensor_ids
393  ->	  xm.mark_step()

which results in

add tensor with id 1
add tensor with id 2
add tensor with id 3
add tensor with id 5

Then I didn't see a second mark_step before seeing

add tensor with id 1
add tensor with id 2
add tensor with id 3
add tensor with id 6
add tensor with id 7
add tensor with id 8
add tensor with id 10
add tensor with id 11

which suggested that the second execution is not triggered by a mark_step. If I just keep stepping through the Python code, I can see it is actually _clear_pending_irs that triggers the cleanup. So

> /src/pytorch/xla/torch_xla/core/dynamo_bridge.py(444)extract_compiled_graph()
-> torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
(Pdb) 
add tensor with id 1
add tensor with id 2
add tensor with id 3
add tensor with id 6
add tensor with id 7
add tensor with id 8
add tensor with id 10
add tensor with id 12

This makes sense since, if you look at how _clear_pending_irs works, it pretty much just collects all tensors and removes their pending IR. Now the question goes back to why _clear_pending_irs would clean the self.x buffer.

@JackCaoG (Collaborator):

The answer to the above question is clear: self.self_tensor actually has an IR, which is IR: [] aten::index_add_, xla_shape=f32[5,3]{1,0}. I think this is caused by tracing the model.

I think the issue is that FallBackNodeCollector and CapabilityBasedPartitioner perform tracing on self.self_tensor. For all of the xla_args we avoid this issue by

  cloned_xla_args = [
      torch.clone(xla_arg) if isinstance(xla_arg, torch.Tensor) else xla_arg
      for xla_arg in xla_args
  ]

and replacing them after tracing. We need to do the same thing for the named_buffers.
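A minimal sketch of applying the same clone-and-restore trick to the self tensors collected from named_buffers(); the variable names are illustrative assumptions, not necessarily the exact code that landed:

cloned_self_tensors = [torch.clone(t) for t in self_tensors]

# ... FallBackNodeCollector / CapabilityBasedPartitioner tracing runs here ...

# Restore the original buffer values so the in-place ops executed during
# tracing don't leave stale pending IR on the real self tensors.
for self_tensor, cloned in zip(self_tensors, cloned_self_tensors):
  self_tensor.copy_(cloned)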

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-inplace branch 2 times, most recently from 4350e9f to 4de70eb on July 27, 2023 at 18:53
Remove debugging lines

Update unit tests to a model
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-inplace branch from 4de70eb to 77d8869 on July 27, 2023 at 20:25
@JackCaoG (Collaborator):

@wonjoolee95 Is this ready for review?

Surround  in an if-statement

Update metrics for fallback related dynamo tests

Update cloned args logic

Revert "Update metrics for fallback related dynamo tests"

This reverts commit 3855f43.
@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-inplace branch from 77d8869 to b6bf058 on July 27, 2023 at 21:28
@@ -180,21 +217,21 @@ def fn_fallback(t):
     xla_dynamo_res = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
     self.assertEqual(met.metric_data('CompileTime')[0], 4)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 6)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 7)
@wonjoolee95 (Collaborator Author):

Execution numbers for fallback-related tests only changed because of the allows_single_node_partition=True mentioned below.

@wonjoolee95 (Collaborator Author):

> @wonjoolee95 Is this ready for review?

Yep, should be ready to review now. All dynamo unit tests passing locally:

(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ python test/dynamo/test_dynamo.py
/home/wonjoo/miniconda3/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'libc10_cuda.so: cannot open shared object file: No such file or directory'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
...........
----------------------------------------------------------------------
Ran 11 tests in 28.067s

OK

@wonjoolee95 (Collaborator Author):

Hmm, the test that fails in CI seems to succeed locally:

(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ python test/dynamo/test_dynamo.py DynamoInferenceBasicTest.test_simple_model_with_in_place_ops
/home/wonjoo/miniconda3/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'libc10_cuda.so: cannot open shared object file: No such file or directory'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
.
----------------------------------------------------------------------
Ran 1 test in 0.112s

OK

Looking into it.

@JackCaoG (Collaborator):

Seems like the CPU test still failed?

@wonjoolee95 (Collaborator Author) commented on Jul 28, 2023:

Hmm, the behavior of the newly added test test_simple_model_with_in_place_ops is odd. Trying to reproduce locally, I can see it succeeds. But if I keep running it consecutively, I'll eventually see a failure (5 consecutive runs result in 4 successes and 1 failure):

(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ python test/dynamo/test_dynamo.py 
.
----------------------------------------------------------------------
Ran 1 test in 0.120s

OK
(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ python test/dynamo/test_dynamo.py 
.
----------------------------------------------------------------------
Ran 1 test in 0.120s

OK
(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ python test/dynamo/test_dynamo.py 
(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ python test/dynamo/test_dynamo.py 
.
----------------------------------------------------------------------
Ran 1 test in 0.120s

OK
(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ python test/dynamo/test_dynamo.py DynamoInferenceBasicTest.test_simple_model_with_in_place_ops
F
======================================================================
FAIL: test_simple_model_with_in_place_ops (__main__.DynamoInferenceBasicTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/wonjoo/pytorch/xla/test/dynamo/test_dynamo.py", line 112, in test_simple_model_with_in_place_ops
    self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
AssertionError: False is not true

----------------------------------------------------------------------
Ran 1 test in 0.122s

FAILED (failures=1)
(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ 

@JackCaoG (Collaborator):

can you print the res_cpu and res_xla_dynamo?

@wonjoolee95 (Collaborator Author) commented on Jul 28, 2023:

So it seems like sometimes res_xla_dynamo isn't updated properly:

(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ python test/dynamo/test_dynamo.py DynamoInferenceBasicTest.test_simple_model_with_in_place_ops
[WONJOO] res_cpu=tensor([[ 2.,  3.,  4.],
        [ 1.,  1.,  1.],
        [ 8.,  9., 10.],
        [ 1.,  1.,  1.],
        [ 5.,  6.,  7.]])
[WONJOO] res_xla_dynamo.cpu()=tensor([[ 2.,  3.,  4.],
        [ 1.,  1.,  1.],
        [ 8.,  9., 10.],
        [ 1.,  1.,  1.],
        [ 5.,  6.,  7.]])
.
----------------------------------------------------------------------
Ran 1 test in 0.121s

OK
(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ python test/dynamo/test_dynamo.py DynamoInferenceBasicTest.test_simple_model_with_in_place_ops
[WONJOO] res_cpu=tensor([[ 2.,  3.,  4.],
        [ 1.,  1.,  1.],
        [ 8.,  9., 10.],
        [ 1.,  1.,  1.],
        [ 5.,  6.,  7.]])
[WONJOO] res_xla_dynamo.cpu()=tensor([[ 2.,  3.,  4.],
        [ 1.,  1.,  1.],
        [ 8.,  9., 10.],
        [ 1.,  1.,  1.],
        [ 5.,  6.,  7.]])
.
----------------------------------------------------------------------
Ran 1 test in 0.120s

OK
(base) wonjoo@wonjoo-cpu:~/pytorch/xla$ python test/dynamo/test_dynamo.py DynamoInferenceBasicTest.test_simple_model_with_in_place_ops
[WONJOO] res_cpu=tensor([[ 2.,  3.,  4.],
        [ 1.,  1.,  1.],
        [ 8.,  9., 10.],
        [ 1.,  1.,  1.],
        [ 5.,  6.,  7.]])
[WONJOO] res_xla_dynamo.cpu()=tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
F

Given how this fails only once in a while, I thought it might be related to mark_step(), so I tried updating all the calls to mark_step(wait=True), but I'm still seeing the occasional failure.

@JackCaoG (Collaborator):

Let me take a look

@JackCaoG (Collaborator):

I am able to repro the random failure, looking into it.

@JackCaoG (Collaborator):

I noticed something weird. If I try to dump the info about xla_arg and self_arg at the beginning of extract_internal with

def extract_internal(xla_model: torch.fx.GraphModule, self_args):
  xla_args = xla_model.xla_args
  print('xla arg\n')
  for xla_arg in xla_args:
    print(torch_xla._XLAC._get_xla_tensor_debug_info(xla_arg))
  print('self arg\n')
  for self_arg in self_args:
    print(torch_xla._XLAC._get_xla_tensor_debug_info(self_arg))  
the output is:

xla arg

XLATensor {
TensorID: 1
Device: CPU:0
XLA Shape: f32[3]
ShardingSpec: None
IR: None
XLAData: 
  Data Device: CPU:0
  Data Shape: f32[3]
  Data Handle: 140350739976128
Tensor on host: None
}

XLATensor {
TensorID: 13
Device: CPU:0
XLA Shape: f32[5,3]
ShardingSpec: None
IR: None
XLAData: 
  Data Device: CPU:0
  Data Shape: f32[5,3]
  Data Handle: 140350739978224
Tensor on host: None
}

XLATensor {
TensorID: 3
Device: CPU:0
XLA Shape: s64[3]
ShardingSpec: None
IR: None
XLAData: 
  Data Device: CPU:0
  Data Shape: s64[3]
  Data Handle: 140350739977440
Tensor on host: None
}

XLATensor {
TensorID: 2
Device: CPU:0
XLA Shape: f32[3,3]
ShardingSpec: None
IR: None
XLAData: 
  Data Device: CPU:0
  Data Shape: f32[3,3]
  Data Handle: 140350739976688
Tensor on host: None
}

self arg

XLATensor {
TensorID: 13
Device: CPU:0
XLA Shape: f32[5,3]
ShardingSpec: None
IR: None
XLAData: 
  Data Device: CPU:0
  Data Shape: f32[5,3]
  Data Handle: 140350739978224
Tensor on host: None
}

Note that the tensor with TensorID 13 showed up both in xla_arg and self_arg. This kind of suggests we don't need to special-case the self_arg.. Need to look into how much of that is caused by the change in this PR.

@JackCaoG (Collaborator):

OK, I am a bit confused by the additional layer of extract_compiled_graph and extract_internal, and needing to do this arg clone twice is confusing.. need to spend some time reading it again.

@JackCaoG (Collaborator):

Hmm, I am confused. If I print in extract_compiled_graph, I see

(Pdb) print(len(self_args))
1
(Pdb) print(len(xla_args))
3

if I look at extract_internal

(Pdb) len(self_args)
1
(Pdb) len(xla_args)
4

For some reason, after the partitioner, the self_arg becomes part of the xla_args (or xla_model.xla_args).

@JackCaoG (Collaborator):

I dumped the graph by adding

  torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
 + print(torch_xla._XLAC._get_xla_tensors_hlo(args_and_out))

and for the passing one I see

HloModule IrToHlo.18, entry_computation_layout={(f32[3,3]{1,0}, s64[3]{0}, f32[5,3]{1,0}, f32[3]{0})->(f32[5,3]{1,0}, f32[5,3]{1,0})}

%ScatterCombiner.4 (p0.5: f32[], p1.6: f32[]) -> f32[] {
  %p0.5 = f32[] parameter(0)
  ROOT %p1.6 = f32[] parameter(1)
}

ENTRY %IrToHlo.18 (p0.1: f32[3,3], p1.2: s64[3], p2.3: f32[5,3], p3.14: f32[3]) -> (f32[5,3], f32[5,3]) {
  %p2.3 = f32[5,3]{1,0} parameter(2)
  %p1.2 = s64[3]{0} parameter(1)
  %p0.1 = f32[3,3]{1,0} parameter(0)
  %scatter.7 = f32[5,3]{1,0} scatter(f32[5,3]{1,0} %p2.3, s64[3]{0} %p1.2, f32[3,3]{1,0} %p0.1), update_window_dims={1}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%ScatterCombiner.4
  %p3.14 = f32[3]{0} parameter(3)
  %broadcast.15 = f32[5,3]{1,0} broadcast(f32[3]{0} %p3.14), dimensions={1}
  %constant.8 = f32[] constant(1)
  %reshape.9 = f32[1,1]{1,0} reshape(f32[] %constant.8)
  %broadcast.10 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.9), dimensions={0,1}
  %reshape.11 = f32[] reshape(f32[1,1]{1,0} %broadcast.10)
  %broadcast.12 = f32[5,3]{1,0} broadcast(f32[] %reshape.11), dimensions={}
  %multiply.13 = f32[5,3]{1,0} multiply(f32[5,3]{1,0} %scatter.7, f32[5,3]{1,0} %broadcast.12)
  %add.16 = f32[5,3]{1,0} add(f32[5,3]{1,0} %broadcast.15, f32[5,3]{1,0} %multiply.13)
  ROOT %tuple.17 = (f32[5,3]{1,0}, f32[5,3]{1,0}) tuple(f32[5,3]{1,0} %scatter.7, f32[5,3]{1,0} %add.16)
}

and for the failing one I see

HloModule IrToHlo.18, entry_computation_layout={(f32[3,3]{1,0}, s64[3]{0}, f32[5,3]{1,0}, f32[3]{0})->(f32[5,3]{1,0}, f32[5,3]{1,0})}

%ScatterCombiner.4 (p0.5: f32[], p1.6: f32[]) -> f32[] {
  %p0.5 = f32[] parameter(0)
  ROOT %p1.6 = f32[] parameter(1)
}

ENTRY %IrToHlo.18 (p0.1: f32[3,3], p1.2: s64[3], p2.3: f32[5,3], p3.14: f32[3]) -> (f32[5,3], f32[5,3]) {
  %p2.3 = f32[5,3]{1,0} parameter(2)
  %p1.2 = s64[3]{0} parameter(1)
  %p0.1 = f32[3,3]{1,0} parameter(0)
  %scatter.7 = f32[5,3]{1,0} scatter(f32[5,3]{1,0} %p2.3, s64[3]{0} %p1.2, f32[3,3]{1,0} %p0.1), update_window_dims={1}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%ScatterCombiner.4
  %p3.14 = f32[3]{0} parameter(3)
  %broadcast.15 = f32[5,3]{1,0} broadcast(f32[3]{0} %p3.14), dimensions={1}
  %constant.8 = f32[] constant(1)
  %reshape.9 = f32[1,1]{1,0} reshape(f32[] %constant.8)
  %broadcast.10 = f32[1,1]{1,0} broadcast(f32[1,1]{1,0} %reshape.9), dimensions={0,1}
  %reshape.11 = f32[] reshape(f32[1,1]{1,0} %broadcast.10)
  %broadcast.12 = f32[5,3]{1,0} broadcast(f32[] %reshape.11), dimensions={}
  %multiply.13 = f32[5,3]{1,0} multiply(f32[5,3]{1,0} %p2.3, f32[5,3]{1,0} %broadcast.12)
  %add.16 = f32[5,3]{1,0} add(f32[5,3]{1,0} %broadcast.15, f32[5,3]{1,0} %multiply.13)
  ROOT %tuple.17 = (f32[5,3]{1,0}, f32[5,3]{1,0}) tuple(f32[5,3]{1,0} %scatter.7, f32[5,3]{1,0} %add.16)
}

The difference is in this line. The one that produces the correct output uses

  %multiply.13 = f32[5,3]{1,0} multiply(f32[5,3]{1,0} %scatter.7, f32[5,3]{1,0} %broadcast.12)

while the one that gives the incorrect output does

  %multiply.13 = f32[5,3]{1,0} multiply(f32[5,3]{1,0} %p2.3, f32[5,3]{1,0} %broadcast.12)

scatter is pretty much just the result of the self.self_tensor.index_copy_(0, index, copy_tensor). To conclude, the difference in result depends on whether the add happens on the result of the index_copy_ or on the original self.self_tensor. In either case, we can see that scatter.7 is present in the HLO, which means self.self_tensor should be updated.

I can confirm that by

-> self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
(Pdb) compiled_model.self_tensor
tensor([[1., 2., 3.],
        [0., 0., 0.],
        [7., 8., 9.],
        [0., 0., 0.],
        [4., 5., 6.]], device='xla:0')
(Pdb) cpu_model.self_tensor
tensor([[1., 2., 3.],
        [0., 0., 0.],
        [7., 8., 9.],
        [0., 0., 0.],
        [4., 5., 6.]])

@JackCaoG (Collaborator):

OK, I think I found the problem; it is actually not in PyTorch/XLA itself. I enabled TORCH_XLA_DEBUG=1, which dumps the FX graph. In the passing case I see

def forward(self, l__self___self_tensor, l_index_ : torch.Tensor, l_copy_tensor_ : torch.Tensor, l_input_tensor_ : torch.Tensor):
    index_copy_ = l__self___self_tensor.index_copy_(0, l_index_, l_copy_tensor_);  l_index_ = l_copy_tensor_ = None
    add = l_input_tensor_ + l__self___self_tensor;  l_input_tensor_ = l__self___self_tensor = None
    return add

in the failing case I see

def forward(self, l_input_tensor_ : torch.Tensor, l__self___self_tensor, l_index_ : torch.Tensor, l_copy_tensor_ : torch.Tensor):
    add = l_input_tensor_ + l__self___self_tensor;  l_input_tensor_ = None
    index_copy_ = l__self___self_tensor.index_copy_(0, l_index_, l_copy_tensor_);  l__self___self_tensor = l_index_ = l_copy_tensor_ = None
    return add

In the failing case, the add happens before the index_copy_. Trying to figure out whether this is related to the partitioner.

@JackCaoG (Collaborator):

If I dump the FX graph before the partitioner, I see

def forward(self, L_index_ : torch.Tensor, L_copy_tensor_ : torch.Tensor, L_input_tensor_ : torch.Tensor):
    l_index_ = L_index_
    l_copy_tensor_ = L_copy_tensor_
    l_input_tensor_ = L_input_tensor_
    l__self___self_tensor = self.L__self___self_tensor
    index_copy_ = l__self___self_tensor.index_copy_(0, l_index_, l_copy_tensor_);  l_index_ = l_copy_tensor_ = None
    add = l_input_tensor_ + l__self___self_tensor;  l_input_tensor_ = l__self___self_tensor = None
    return (add,)

after the partitioner

def forward(self, l_input_tensor_ : torch.Tensor, l__self___self_tensor, l_index_ : torch.Tensor, l_copy_tensor_ : torch.Tensor):
    add = l_input_tensor_ + l__self___self_tensor;  l_input_tensor_ = None
    index_copy_ = l__self___self_tensor.index_copy_(0, l_index_, l_copy_tensor_);  l__self___self_tensor = l_index_ = l_copy_tensor_ = None
    return add

So the issue is that the partitioner messed up the ordering.

@JackCaoG (Collaborator):

OK I think the issue is here

  partitions = partitioner.propose_partitions()
  print(partitions)

partitions doesn't seem to guarantee the order. Sometimes I see

[{add, index_copy_}]

other times I see

[{index_copy_, add}].   --> this one will produce the correct output
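For context, this ordering problem was later addressed by adding topological sorting to the dynamo partitions (see "Add topological sorting to dynamo partitions (#5472)" in the commit list further down). A minimal sketch of that idea, assuming gm is the traced torch.fx.GraphModule and partition_nodes is one of the node sets returned by propose_partitions():

import torch.fx

def sort_partition_nodes(gm: torch.fx.GraphModule, partition_nodes):
  # Node order in gm.graph.nodes reflects the original trace order, so sorting
  # by that index keeps index_copy_ ahead of the add that reads its result.
  order = {node: idx for idx, node in enumerate(gm.graph.nodes)}
  return sorted(partition_nodes, key=lambda node: order[node])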

@JackCaoG (Collaborator):

If I set allows_single_node_partition back to False, the test seems to pass consistently. @wonjoolee95, what's the reason you want to change that to True?

@wonjoolee95 (Collaborator Author):

Thanks for the investigations, Jack. I saw that when allows_single_node_partition=False, the code path for this test case didn't actually fall through to the extract_internal function; only with this flag set to True did it reach extract_internal. Let me quickly double-check that right now and get back.

@wonjoolee95 wonjoolee95 force-pushed the wonjoo/dynamo-inplace branch from e80b0a6 to c046f24 on July 31, 2023 at 20:42
@wonjoolee95 (Collaborator Author):

Ok, I've set allows_single_node_partition back to False and the test/metrics look correct, and the code path does properly enter the extract_internal function. Now the unit test passes consistently, and I can see that all the dynamo unit tests pass:

/home/wonjoo/miniconda3/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'libc10_cuda.so: cannot open shared object file: No such file or directory'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
.
----------------------------------------------------------------------
Ran 11 tests in 27.713s

OK

This should be ready for review now.

@wonjoolee95 wonjoolee95 requested a review from JackCaoG July 31, 2023 20:46
@JackCaoG (Collaborator) left a review comment:

Let's hold on this PR until tomorrow; this PR also touches dynamo and I don't want it to break tomorrow's whl by accident. Otherwise LGTM.

I am still concerned that the partitioner might return ops in an incorrect order; we should open a GitHub issue and follow up with Sherlock.

@JackCaoG JackCaoG added DO_NOT_MERGE_YET For PRs which cannot be merged, despite tests passing and removed DO_NOT_MERGE_YET For PRs which cannot be merged, despite tests passing labels Aug 1, 2023
@JackCaoG JackCaoG merged commit 67ff7ea into master Aug 1, 2023
will-cromar pushed a commit that referenced this pull request Sep 14, 2023
* Add more support for in-place ops in dynamo bridge

Run linter

* Add check to explicitly sync self tensors

Remove debugging lines

Update unit tests to a model

* Clean up some code

Surround  in an if-statement

Update metrics for fallback related dynamo tests

Update cloned args logic

Revert "Update metrics for fallback related dynamo tests"

This reverts commit 3855f43.

* Update single_node flag back to False
will-cromar added a commit that referenced this pull request Sep 15, 2023
* Sharding should be per output of IR Node, instead of per IR Node (#5330)

* sharding should be per output of IR Node, instead of per IR Node

* Update sharding_hash method

* Add test for sharding on IR with multiple output

* fix cpu test

* Fix a bug in getSharding

* Update Python device API for SPMD (#5129)

* Make python Api to respect the virtual device when SPMD is enabled

* fix typo

* Check out the release branch instead of origin/master in ansible (#5344)

* Also dump output sharding on HLO file (#5339)

* Also dump output sharding on HLO file

* only dump output sharding if dump format is HLO

* add test

* fix typo

* Make all-reduce a no-op when world size is 1 (#5342)

* Make all-reduce a no-op when world size is 1

* Fix torch.distributed test

* add fs linker flag (#5347)

* Add py3.10 whl path to doc, refactor whl table (#5354)

* fix amp dtype setting for GPU (#5337)

* fix amp dtype setting for GPU.

* fix ut

* fix lint.

* minor.

* Add python test for SPMD+Runtime Python API (#5349)

* Add python test for SPMD+Runtime Python API

* replace test name

* Update test_xla_spmd_python_api_interaction.py

* Check the actual device instead of query env var for virtual device (#5352)

* Check the actual device instead of query env var for virtual device

* revert unneeded change

* minor changes

* [BE] use self.assertEquals instead of str equality in test_zero1.py (#5364)

* Revert "[BE] use self.assertEquals instead of str equality in test_zero1.py (#5364)" (#5366)

This reverts commit 8ada333.

* [Dynamo|TPU] Tweak `atol` and `rtol` for `test_dynamo.py` (#5363)

* tweak `atol` and `rtol`

* [Dynamo|TPU] Skip`DynamoTrainingBasicTest.test_resnet18` on TPU (#5362)

*  Skip`DynamoTrainingBasicTest.test_resnet18` on TPU

* Add a script for running stablehlo tests. (#5360)

* Add kokoro presubmit for stablehlo tests

* Don't rewrite index hints in global save planning (#5348)

* [Dynamo|TPU] Skip `DynamoInferenceBasicTest.test_resnet18` on TPU (#5361)


* Skip `DynamoInferenceBasicTest.test_resnet18` on TPU

* [BE] use self.assertEquals instead of str equality in test_zero1.py (#5367)

* [BE] use self.assertEquals instead of str equality in test_zero1.py

* Use our own assertEqual

* Remove print statements

* Fix ReplicateShardedData for int type (#5374)

* Fix ReplicateShardedData for int type

* add test

* Update dynamo.md (#5378)

Update dynamo.md to remove note about fallback ops since they're supported now

* Revert "Fix ReplicateShardedData for int type (#5374)" (#5380)

This reverts commit 7fb7dfe.

* Remove the mention of XRT_TPU_CONFIG in the CONTRIBUTING.md (#5379)

* [Dynamo|TPU] Tweak `atol` and `rtol` for `test_simple_model_with_different_input_shape` on TPU (#5373)

* tweak `atol` and `rtol` for `test_simple_model_with_different_input_shape` on TPU

* Rectify test_zero1.py once optim.load_state_dict doesn't guarantee immutability (#5382)

* [TEST ONLY] print statements for test_zero1.py to debug

* Try fix

* Rectify test_zero1.py to account for state_dict modification

* Fix lint

* Add gpu doc for how to build PyTorch/XLA from source with GPU support. (#5384)

* Add gpu doc for how to build PyTorch/XLA from source with GPU support.

* fix typo

* fix comments

* fix comments

* clear pending ir should also clear the cc op tokens (#5385)

* Port resnet data loading optimizations to SPMD test script (#5386)

* Add support for in-place ops with self tensors in dynamo bridge (#5309)

* Add more support for in-place ops in dynamo bridge

Run linter

* Add check to explicitly sync self tensors

Remove debugging lines

Update unit tests to a model

* Clean up some code

Surround  in an if-statement

Update metrics for fallback related dynamo tests

Update cloned args logic

Revert "Update metrics for fallback related dynamo tests"

This reverts commit 3855f43.

* Update single_node flag back to False

* Add dynamo test in TPU CI (#5381)

Add dynamo test in TPU CI

* Add manual seed in multihost checkpoint (#5392)

* Fix change_id type in coverage uploading (#5394)

* Update dynamo cpu fallback op to aten::_foobar (#5393)

* Run single host multi GPU tests in the CI. (#5387)

* Add gpu doc for how to build PyTorch/XLA from source with GPU support.

* Run single host multi GPU tests.

* fix linter

* fix linter

* fix error

* fix test

* [PJRT] Separate collective ops test from TPU runtime test. (#5396)

* [PJRT] Separate collective ops test from TPU runtime test.

* formatting

* Fix ReplicateShardedData for int type (#5404)

* Update the dynamo backend name to `openxla` (#5402)

* Replace aot backend with openxla

* Update the inference backend except the fallback tests

* handle the fallback tests

* update remaining test

* update doc

* add torch pin

* Delete .torcch_pin

* linter

* [SPMD] Multi-host batch sharded data loading (#5331)

* Refactor to share code between export_torch_model and save_as_stablehlo (#5388)

* Refactor to share code between export_torch_model and save_as_stablehlo

* Fix TPU collective ops test for multi-host TPUs (#5408)

* Fix TPU collective ops test for multi-host TPUs

* formatting

* Partially replicate lower-rank tensors (#5409)

* Partially replicate lower-rank tensors

* Fix unit test

* Remove unnecessary device count check

* Fix unordered partition spec test

* yapf

* Revert "Partially replicate lower-rank tensors (#5409)" (#5412)

This reverts commit 56a6a02.

* SPMD cross slice-replication using partial_replication sharding (#5411)

* Revert "Support unordered sharding spec for partial replication (#5316)"
* Update test_2d_tensor_3d_mesh unit test to surface a bug
* Use partial replication for 2D tensor over 3D mesh sharding

* Fix the incorect clone arg condition in dynamo bridge (#5414)

* [SPMD] named partition spec support (#5415)

[SPMD] named partition spec

* [PJRT|TPU] Update `test_xla_devices_single_process_all_chips` for expected device number (#5421)

Update `test_xla_devices_single_process_all_chips` for expected device number

* Add repo for libcudnn8=8.7.0.84 and CUDA 11.8 (#5425)

* Update fix_includes.sh (#5441)

Without this patch I cannot get torch_xla to build outside of the docker. This should fix it.

* [PJRT] Support `torchrun` with `pjrt://` `init_method` (#5438)

* Support torchrun with `pjrt://` `init_method`

* move import

* fix error

* Fix NameError

* Fix path

* Remove from TPU CI

* Bugfix + add more test for llama (#5439)

Bugfix details:
1. When the graph have mutations the exported graph will have additional
   inputs. For now we are dropping them.
2. We should trace with args instead of final_args.

* Move the C++ test build to CI build job instead of test job (#5442)

* Update gcc to 10. (#5445)

* Update gcc to 10,

And use unversioned clang-format (so it's installation will succeed)
in both debian bullseye and buster

* gcc10 to ansible

* Update the random seed for every dynamo execution (#5444)

* Revert "Update gcc to 10. (#5445)" (#5449)

This reverts commit 454e916.

Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>

* Install gcc-10 (#5450)

* Revert "Install gcc-10 (#5450)" (#5452)

This reverts commit 65b7639.

* parallelize SPMD inputhandler and GetDataShards (#5447)

* parallelize SPMD inputhandler and GetDataShards

* add output handler trace

* Remove base image override from TPU CI build (#5453)

* Update to GCC 10 (#5451)

* Cache sharded placeholder for dynamo execution (#5446)

* Cache the output sharding spec for dynamo

* address review comments

* add test

* remove dead code

* add missing wait deivce ops

* Update xla_graph_executor.cpp

* linter

* Remove Docker image override from dev image (#5456)

* hack: implement (unimplement?) GetDataShard for XRT

* skip flaky test (#5459)

* Neuron import hook (#5429)

* Enable Neuron import hook for calling initialization functions if using AWS Neuron

* removing copy/paste error

* moving aws init call and removing comment

* Add missing includes (#5434)

* Add missing includes

Currently this is included indirectly through PyTorch includes, but when I remove
the include from PyTorch's headers, the xla build fails.

* [TESTING] Pin PyTorch PR

* Retrigger CI after timeout

* Remove .torch_pin

* [GPU]Update README.md with wheel/docker for CUDA12.0 and deprecate CUDA11.7 (#5443)

* [GPU]Update README.md with wheel and docker support CUDA12.0 and deprecate CUDA 11.7

* Update README.md with docker support CUDA 12.0 and python 3.8

* Update README.md

* Update README.md

* update remote cache key in ansible (#5463)

* Fix data type in Pow with Scalar base and Tensor exponent (#5467)

* fix dtype inference

* fix linter

* bump the timeout for CI (#5470)

* Fix the input sharding for dynamo (#5469)

* Enabling sharding device data IR (#5475)

* Allow shard device data IR

* Handle XLATensor that is DeviceData IR and does not have XLAData

* fix typo

* Introduce `torch_xla.runtime.use_spmd()` (#5474)

Introduce torch_xla.runtime.use_spmd() and torch_xla.runtime.is_spmd()

* Enable PJRT C API Client and other changes for Neuron (#5428)

* Enable PJRT C API Client and other changes for Neuron

* keeping quotes consistent

* fixing device type call

* refactoring neuron initialization with spawn

* updating replication setting only for torchrun

* removing set replication in xla backed was added to rendezvous handler

* removing workaround for world_size/master_port for neuron

* fixing linter issues

* Don't move full tensor to device in deferred_init (#4819)

* [SPMD] Fix HybridMesh ordering (#5478)

Summary:
In xs.HybridMesh, it assumes the xr.global_runtime_device_attributes() will return the attributes according to the PyTorch/XLA's logical global ordinals. However, it turns out not to be the case.

To fix this, we pass the logical global ordinal as one of the attributes and xs.HybridMesh will sort the attributes according to this new attribute before using the array.

Test Plan:
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding.py -v -k test_hybrid_mesh

* [SPMD] Properly skip tests on TPU V2 (#5479)

Summary:
Some of the tests only fail on TPU V2 but were skipped for all TPUs.
Let's fix that.

Test Plan:
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding.py

* Add @yeounoh to .github CODEOWNERS (#5482)

* Add Python API to execute StableHLO bytecode (#5476)

* [SPMD] Fix TPU CI after #5478 (#5487)

* [SPMD] Fix TPU CI after #5478

Summary:
Let's fix all TPU CI failures after #5478.

Test Plan:
TPU CI

* Fix linters

* [SPMD] Fix XLA_DUMP_POST_OPTIMIZATIONS test (#5485)

Summary:
XLA_DUMP_POST_OPTIMIZATIONS was set as static which means that the value will be fixed during the whole test run for a particular test suite.

Therefore, let's make a separate file.

Test Plan:
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding.py
PJRT_DEVICE=TPU USE_XLA_SPMD=1 python test/spmd/test_xla_sharding_hlo.py

* [Dist] Refactor ZeRO-1 (#5145)

* refactor

* fix

* fix

* add padding

* more robust save/load

* Update artifacts.auto.tfvars for 2.1 release (#5483)

* Update artifacts.auto.tfvars for 2.1 release

Update artifacts.auto.tfvars for 2.1 release

* Remove cuda version 11.7 and add 12.0 for 2.1 triggers

* Add 3.10 tpu version

* Add ShardingSpec to XLATensor when it is created with a PJRTShardedData (#5489)

* Add ShardingSpec to XLATensor when it is created with a PJRTShardedData

* add test

* Add topological sorting to dynamo partitions (#5472)

* Add topological sorting to dynamo partitions

* Run linter

* Update unit tests to include more in-place ops

* [SPMD] Patch nn.Linear (#5491)

Summary:
This pull request introduces a patched version of torch.nn.functional.linear that uses einsum instead of torch.matmul which will flatten the tensors to 2D and collide the sharded dimensions. The torch.matmul default behavior makes it very hard for XLA compiler to propagate the sharding annotation.

Test Plan:
PJRT_DEVICE=CPU python test/test_operations.py -v -k test_patched_linear

* [original author: mrnikwaws] Neuron operator support (#5471)

* adding glu operator support

* adding glu operator

* fixing yaml

* fixing linter issues

* fixing linter issues

* fixing spacing

* fixing spacing

* fixing spacing

* fixing spacing

* fixing shape helper

* fixing spacing

* [SPMD] Make IR sharding custom sharding op (#5433)

Summary:
This pull request changes the syntax of IR sharding by making it a new node instead of just attaching the sharding spec to the tensor. On the same time, we will still attach a sharding spec to the newly created XLATensor which will hold the new IR node.

This new IR node will be a CustomSharding node and in hlo:
%annotate = f32[6,3] custom-call(%copy), custom_call_target="Sharding", sharding={devices=[2,1]0,1}

Test Plan:
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py -v -k test_mark_sharding_ir
PJRT_DEVICE=TPU XLA_USE_SPMD=1 python test/spmd/test_xla_sharding.py -v -k test_inplace_add_with_sharding

* Support input sharding changed after first dynamo tracing (#5477)

* Support input sharding changed after first dynamo tracing

* fix linter

* Handle the different input for dynamo sharding change

* update counter

* only get sharding specs when spmd is enabled

* add option to skip checking input sharding after x runs

* handle the cpu test

* make XLA_DYNAMO_INPUT_SHARDING_CHECK_THREASHOLD configable

* fix review comments

* Always use ExecuteReplicated with SPMD (#5494)

* Always use ExecuteReplicated with SPMD

* Add unit test

* Skip a couple tests on TPU due to precision issue (#5496)

* Refactor stablehlo API and put them in official location. (#5493)

Changes include:

* make end point in torch_xla/init.py for exposed APIs torch_xla.save_as_stablehlo and torch_xla.save_torch_model_as_stablehlo.
* All tf related integration to its own file.
* Remove args as argument (because it will spear inside of ExportedProgram) but allow user to override it (which we use for now.

* Support tuples in partition spec (#5488)

* Support tuples in partition spec

* Add unit test for partial replication

* yapf

* Support higher-rank tensors over lower-rank mesh

* Fix test & yapf

* Don't use partition_spec when creating group assignment

* Update documentation

* More documentation

* Translate named specs in ShardingSpec

* Add a API to explictly init runtime (#5500)

* Add explict error message when tensor is on CPU for dynamo backend (#5499)

* remove torchvision in stablehlo.py (#5501)

* Fix tupled partition spec test on v3 (#5503)

* Update dynamo doc (#5506)

* Update dynamo.md (#5509)

fixing typo

* Get original_traced_args as example_inputs. (#5511)

Change due to changing name in pytorch/pytorch#107978

* mark_sharding over a replicated tensor is allowed. (#5513)

* [SPMD] Propagate replicated output (#5508)

Summary:
During the LLaMA2 experiements, I disovered that manually marking 1D tensors to be replicated can greatly save a lot of memory. Then I disocvered that explicitly replicated spec will get dropped after mark_step. That is caused by PrepareOutputShardingPropagation where it explicitly clear the sharding spec for replicated output. So, I went ahead and fix that.

Further, I did some experiements of propogating replicated output and that drop the requirements of manually replicating 1D tensors. Hence, I made this change.

I'm still not quite sure why, will follow up later.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py

* Disable cxx abi in ansible when building pt/xla for branch r2.0 (#5332)

* Update pytorch git tag for r2.1 (#5529)

Update more places

Add torch_pin

* Enable megacore_dense by default (#5520) (#5531)

Summary:
This change enables megacore_dense by default to allow asynchorous cc
ops especailly for GSPMD.

Test Plan:
CI

Co-authored-by: Jiewen Tan <jwtan@google.com>

* Add option to unbundle libtpu (#5534) (#5536)

* Add optiona to unbundle libtpu

* Add clarifying note

* Revert 2.1 terraform changes (#5537)

* Fix FSDP for Models with Frozen Weights (#5484) (#5539)

* Fix fsdp not freeing forzen full params

* add test

* formatting

* remove unnecessary env var in test

Co-authored-by: Liyang90 <liyanglu@google.com>

* Update r2.1 wheel to be compatible with PyPI (#5550)

* Update project metadata and remove useless files

* Update README

* Add manylinux platform tag

* formatting

* Add resnet50-weight-quant colab notebook (#5407) (#5556)

* Add resnet50-weight-only-quant colab notebook

* update notebook with llama blog link

Co-authored-by: Siyuan Liu <lsiyuan@google.com>

* hack: add placeholders for `HasSharding` and `GetSharding` to XRT

* formatting

* hack: always return false from `HasSharding`

* Update torch pin to current RC for CI testing

* Cherry pick `pjrt://` init method rename and doc updates (#5562)

* Change `pjrt://` init method to `xla://` (#5560)

* Update PJRT documentation for the 2.1 release (#5557)

* Update PJRT documentation for the 2.1 release

* clarify plugins

* clarify PJRT doc

* Update `pjrt://` to `xla://`

* Use new cache silo and skip test build

* hack: disable missing test

* hack: alter cache silo name

* formatting

---------

Co-authored-by: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Co-authored-by: iefgnoix <isaacwxf23@gmail.com>
Co-authored-by: Siyuan Liu <lsiyuan@google.com>
Co-authored-by: Baole Ai <baoleai01@gmail.com>
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Co-authored-by: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Co-authored-by: qihqi <hanq@google.com>
Co-authored-by: jonb377 <jonbolin@google.com>
Co-authored-by: Wonjoo Lee <wonjoo@google.com>
Co-authored-by: Mohit Khatwani <118776932+khatwanimohit@users.noreply.github.com>
Co-authored-by: Yeounoh Chung <yeounoh@google.com>
Co-authored-by: Mateusz Lewko <mateusz.lewko@gmail.com>
Co-authored-by: Alisson Azzolini <37222419+aazzolini@users.noreply.github.com>
Co-authored-by: aws-kingrj <78175353+aws-kingrj@users.noreply.github.com>
Co-authored-by: peterbell10 <peterbell10@live.co.uk>
Co-authored-by: Zach Zheng <zczheng@amazon.com>
Co-authored-by: Jiewen Tan <jwtan@google.com>
Co-authored-by: Huang, Guangtai <guangtai@amazon.com>
Co-authored-by: Shauheen <shauheen@users.noreply.github.com>
Co-authored-by: Liyang90 <liyanglu@google.com>