From 0a9c9e03aa71e49d0b04713a6a9d88086a9b9356 Mon Sep 17 00:00:00 2001
From: Jiewen Tan
Date: Wed, 2 Aug 2023 10:05:37 -0700
Subject: [PATCH] Guard 2D sharding for activations and inputs (#18)

Summary:
This pull request fixes a bug in
https://github.com/pytorch-tpu/transformers/pull/17, which forgot to
guard 2D sharding for activations and inputs.

Test Plan:
N/A.
---
 .../models/llama/modeling_llama.py | 62 ++++++++++---------
 1 file changed, 32 insertions(+), 30 deletions(-)

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index ca01a0e3e8b..bf3f6ed3400 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -372,21 +372,22 @@ def forward(
         if not output_attentions:
             attn_weights = None
 
-        # Apply 2D sharding:
-        # activation (data,, None, model)
-        import torch_xla.core.xla_model as xm
-        import torch_xla.experimental.xla_sharding as xs
-        import torch_xla.runtime as xr
-        import torch_xla
-        num_devices = xr.global_runtime_device_count()
-        device_ids = torch.arange(num_devices)
-        print('> Sharding activations', attn_output.shape)
-        model = self.spmd_2d_sharding
-        data = num_devices // model
-        assert model * data == num_devices
-        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
-        xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
-        print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
+        if self.spmd_2d_sharding > 0:
+            # Apply 2D sharding:
+            # activation (data,, None, model)
+            import torch_xla.core.xla_model as xm
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla
+            num_devices = xr.global_runtime_device_count()
+            device_ids = torch.arange(num_devices)
+            print('> Sharding activations', attn_output.shape)
+            model = self.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+            xs.mark_sharding(attn_output, data_model_mesh, (0, 1, 2))
+            print(torch_xla._XLAC._get_xla_sharding_spec(attn_output))
 
         return attn_output, attn_weights, past_key_value
 
@@ -681,21 +682,22 @@ def forward(
         # Is this the input to the model?
         hidden_states = inputs_embeds
 
-        # Apply 2D sharding:
-        # input (data,, None, model)
-        import torch_xla.core.xla_model as xm
-        import torch_xla.experimental.xla_sharding as xs
-        import torch_xla.runtime as xr
-        import torch_xla
-        num_devices = xr.global_runtime_device_count()
-        device_ids = torch.arange(num_devices)
-        print('> Sharding hidden_states', hidden_states.shape)
-        model = self.spmd_2d_sharding
-        data = num_devices // model
-        assert model * data == num_devices
-        data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
-        xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
-        print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
+        if self.spmd_2d_sharding > 0:
+            # Apply 2D sharding:
+            # input (data,, None, model)
+            import torch_xla.core.xla_model as xm
+            import torch_xla.experimental.xla_sharding as xs
+            import torch_xla.runtime as xr
+            import torch_xla
+            num_devices = xr.global_runtime_device_count()
+            device_ids = torch.arange(num_devices)
+            print('> Sharding hidden_states', hidden_states.shape)
+            model = self.spmd_2d_sharding
+            data = num_devices // model
+            assert model * data == num_devices
+            data_model_mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
+            xs.mark_sharding(hidden_states, data_model_mesh, (0, 1, 2))
+            print(torch_xla._XLAC._get_xla_sharding_spec(hidden_states))
 
         if self.gradient_checkpointing and self.training:
             if use_cache:
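
For reference, a minimal standalone sketch of the guarded sharding pattern that both hunks apply (to attn_output and hidden_states respectively). It reuses only names and calls that appear in the patch (spmd_2d_sharding, xs.HybridMesh, xs.mark_sharding); the helper name maybe_shard_2d is hypothetical, and running it assumes a torch_xla SPMD-enabled runtime such as a TPU host.

import torch
import torch_xla.experimental.xla_sharding as xs
import torch_xla.runtime as xr

def maybe_shard_2d(tensor: torch.Tensor, spmd_2d_sharding: int) -> None:
    """Shard a (batch, seq, hidden) tensor as (data, None, model), if enabled."""
    if spmd_2d_sharding <= 0:
        # Guard added by this patch: 2D sharding disabled, leave the tensor unsharded.
        return
    num_devices = xr.global_runtime_device_count()
    model = spmd_2d_sharding            # size of the model-parallel axis
    data = num_devices // model         # remaining devices form the data axis
    assert model * data == num_devices  # device count must factor evenly
    # Mesh of shape (data, 1, model): axis 0 shards the batch dimension,
    # axis 2 shards the hidden dimension, and the sequence axis is replicated.
    mesh = xs.HybridMesh(ici_mesh_shape=(data, 1, model))
    xs.mark_sharding(tensor, mesh, (0, 1, 2))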