Guard 2D sharding for activations and inputs (#18)
Summary:
This pull request fixes a bug from #17, which forgot to guard the 2D sharding of activations and inputs behind the spmd_2d_sharding flag.

Test Plan:
N/A.
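For reference, here is a minimal sketch of the guarded pattern the diff below applies at both call sites, pulled out into a standalone helper. The helper name maybe_shard_2d is hypothetical; the mesh construction simply mirrors the torch_xla calls already used in the diff (HybridMesh, mark_sharding) and assumes the global device count is divisible by the model-parallel degree.

import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

def maybe_shard_2d(tensor, spmd_2d_sharding):
    # Guard: leave the tensor unsharded when 2D sharding is disabled (<= 0).
    if spmd_2d_sharding <= 0:
        return tensor
    num_devices = xr.global_runtime_device_count()
    model = spmd_2d_sharding           # model-parallel degree
    data = num_devices // model        # data-parallel degree
    assert model * data == num_devices
    # (data, 1, model) mesh over the tensor's (batch, sequence, hidden) dims.
    mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
    xs.mark_sharding(tensor, mesh, (0, 1, 2))
    return tensor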
alanwaketan authored Aug 2, 2023
1 parent 813af25 commit 0a9c9e0
1 changed file: src/transformers/models/llama/modeling_llama.py (32 additions, 30 deletions)
@@ -372,21 +372,22 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        # Apply 2D sharding:
-        # activation (data,, None, model)
-        import torch_xla.core.xla_model as xm
-        import torch_xla.experimental.xla_sharding as xs
-        import torch_xla.runtime as xr
-        import torch_xla
-        num_devices = xr.global_runtime_device_count()
-        device_ids = torch.arange(num_devices)
-        print('> Sharding activations', attn_output.shape)
-        model = self.spmd_2d_sharding
-        data = num_devices // model
-        assert model * data == num_devices
-        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
-        xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
-        print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
+        if self.spmd_2d_sharding > 0:
+            # Apply 2D sharding:
+            # activation (data,, None, model)
+            import torch_xla.core.xla_model as xm
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla
+            num_devices = xr.global_runtime_device_count()
+            device_ids = torch.arange(num_devices)
+            print('> Sharding activations', attn_output.shape)
+            model = self.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+            xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
+            print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
 
         return attn_output, attn_weights, past_key_value

@@ -681,21 +682,22 @@ def forward(
 
         # Is this the input to the model?
         hidden_states = inputs_embeds
-        # Apply 2D sharding:
-        # input (data,, None, model)
-        import torch_xla.core.xla_model as xm
-        import torch_xla.experimental.xla_sharding as xs
-        import torch_xla.runtime as xr
-        import torch_xla
-        num_devices = xr.global_runtime_device_count()
-        device_ids = torch.arange(num_devices)
-        print('> Sharding hidden_states', hidden_states.shape)
-        model = self.spmd_2d_sharding
-        data = num_devices // model
-        assert model * data == num_devices
-        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
-        xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
-        print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
+        if self.spmd_2d_sharding > 0:
+            # Apply 2D sharding:
+            # input (data,, None, model)
+            import torch_xla.core.xla_model as xm
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla
+            num_devices = xr.global_runtime_device_count()
+            device_ids = torch.arange(num_devices)
+            print('> Sharding hidden_states', hidden_states.shape)
+            model = self.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+            xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
+            print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
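As a follow-up illustration, a hypothetical usage of the sketch above at the two sites this diff touches, the attention output and the input embeddings (not code from the repository):

attn_output = maybe_shard_2d(attn_output, self.spmd_2d_sharding)        # decoder attention output
hidden_states = maybe_shard_2d(inputs_embeds, self.spmd_2d_sharding)    # model input embeddings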
