[BUG] ZeRO-1 hangs in optimizer step when used in Pipeline #1522

Open · stas00 opened this issue Nov 5, 2021 · 0 comments
Labels: bug (Something isn't working)

stas00 commented Nov 5, 2021

Describe the bug

TLDR: Pipeline works w/o ZeRO-1 but hangs w/ ZeRO-1

At BigScience we have a very large embedding layer for the multilingual (ml) training, and with Megatron-DeepSpeed GPT we are trying to give it a whole pipe stage, since sharing a stage with a transformer layer is too much and we get OOM.

To exemplify the problem, let's partition only on the embed layer:

        super().__init__(layers=self.specs,
                         loss_fn=CrossEntropy,
                         topology=topo,
                         activation_checkpoint_interval=interval,
                         partition_method='type:embed')

instead of partition_method='type:embed|transformer', since the latter hides the problem on this small setup.

This is called here: https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/2d9744f23df1a67b4cc1523e3bbdcaca738eb391/megatron/model/gpt_model.py#L279-L283

When ZeRO-1 is used this hangs. W/o ZeRO-1 it works.
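
For context, the only thing toggled between the working and the hanging run is the ZeRO stage in the DeepSpeed config. Roughly (shown as a Python dict for brevity; all fields other than the ZeRO stage are illustrative, not our exact config):

    # illustrative DeepSpeed config fragment; only zero_optimization.stage
    # is what this report toggles
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,   # illustrative
        "zero_optimization": {
            "stage": 1,   # 1 -> hangs in OptimizerStep; 0 -> works
        },
    }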

The hang is here (py-spy dump):

Thread 1084205 (active): "MainThread"
    get_grad_norm_direct (deepspeed/runtime/zero/stage2.py:1512)
    step (deepspeed/runtime/zero/stage2.py:1645)
    _take_model_step (deepspeed/runtime/engine.py:1538)
    _exec_optimizer_step (deepspeed/runtime/pipe/engine.py:1124)
    _exec_schedule (deepspeed/runtime/pipe/engine.py:1335)
    train_batch (deepspeed/runtime/pipe/engine.py:329)
    train_step (megatron/training.py:405)
    train (megatron/training.py:737)
    pretrain (megatron/training.py:165)
    <module> (pretrain_gpt.py:237)

I'm testing on just 2 gpus with 2 transformer layers and a tied embed before and after the transformer layers.

Let's look at partitioning weights:

  • with type:transformer [0, 0, 0, 1, 1, 0, 0, 0, 0]
  • with type:embed|transformer [0, 1, 0, 1, 1, 0, 0, 1, 0]

so the resulting partitioning is identical in both cases: [0, 0, 0, 1], [1, 0, 0, 0, 0]

but with type:embed [0, 1, 0, 0, 0, 0, 0, 1, 0]
it splits into [0, 1, 0, 0, 0, 0, 0], [1, 0]
and so it doesn't know how to handle a stage boundary that is not a transformer layer.
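
As a toy illustration (Python; not DeepSpeed's actual partitioning code, and the layer names other than embed/transformer are made up), this is how a type:<regex> partition_method produces exactly these masks from the layer type names, and where the 2-stage boundary lands:

    import re

    # placeholder layer sequence mirroring the 9-entry masks above;
    # only the embed/transformer positions matter
    layers = ['tied', 'embed', 'dropout', 'transformer', 'transformer',
              'norm', 'final', 'embed', 'loss']

    def layer_mask(pattern):
        # 1 where the layer's type name matches the partitioning regex
        return [1 if re.search(pattern, name) else 0 for name in layers]

    print(layer_mask('transformer'))        # [0, 0, 0, 1, 1, 0, 0, 0, 0]
    print(layer_mask('embed|transformer'))  # [0, 1, 0, 1, 1, 0, 0, 1, 0]
    print(layer_mask('embed'))              # [0, 1, 0, 0, 0, 0, 0, 1, 0]

    # balancing the matched layers over 2 stages for 'embed' puts the
    # boundary right before the second embed: [0,1,0,0,0,0,0] | [1,0]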

If I trace the schedule cmds on the 2 gpus, I get:

0 LoadMicroBatch(buffer_id=0)
1 RecvActivation(buffer_id=0)
0 ForwardPass(buffer_id=0)
0 SendActivation(buffer_id=0)
0 LoadMicroBatch(buffer_id=1)
1 LoadMicroBatch(buffer_id=0)
0 ForwardPass(buffer_id=1)
1 ForwardPass(buffer_id=0)
0 SendActivation(buffer_id=1)
0 RecvGrad(buffer_id=0)
0 BackwardPass(buffer_id=0)
1 BackwardPass(buffer_id=0)
1 RecvActivation(buffer_id=1)
1 SendGrad(buffer_id=0)
1 LoadMicroBatch(buffer_id=1)
1 ForwardPass(buffer_id=1)
1 BackwardPass(buffer_id=1)
1 SendGrad(buffer_id=1)
1 ReduceTiedGrads()
1 ReduceGrads()
1 OptimizerStep()
0 RecvGrad(buffer_id=1)
0 BackwardPass(buffer_id=1)
0 ReduceTiedGrads()
0 ReduceGrads()
0 OptimizerStep()
1 done with schedule

Each line is logged just before the cmd is executed.

So gpu 0 hangs in OptimizerStep.

I checked that with type:transformer, it's the same trace, but it gets gpu 0 to complete:

0 done with schedule
1 done with schedule

I tried to remap the above calls to the actual parallel sequence; does it look more or less correct? This indeed looks like an interleaved schedule:

0 LoadMicroBatch(buffer_id=0)
0 SendActivation(buffer_id=0) -> 1 RecvActivation(buffer_id=0)
0 ForwardPass(buffer_id=0)

0 LoadMicroBatch(buffer_id=1)    1 LoadMicroBatch(buffer_id=0)
0 ForwardPass(buffer_id=1)       1 ForwardPass(buffer_id=0)

0 SendActivation(buffer_id=1) -> 1 RecvActivation(buffer_id=1)
                                 1 BackwardPass(buffer_id=0)
0 RecvGrad(buffer_id=0)       -> 1 SendGrad(buffer_id=0)  

                                 1 LoadMicroBatch(buffer_id=1)
0 BackwardPass(buffer_id=0)      1 ForwardPass(buffer_id=1)

                                 1 BackwardPass(buffer_id=1)
0 RecvGrad(buffer_id=1)       <- 1 SendGrad(buffer_id=1)
                                 1 ReduceTiedGrads()
                                 1 ReduceGrads()
                                 1 OptimizerStep()

0 BackwardPass(buffer_id=1)
0 ReduceTiedGrads()
0 ReduceGrads()
0 OptimizerStep()

0 done with schedule
1 done with schedule

With a bunch of prints and CUDA_LAUNCH_BLOCKING=1 I was able to trace it to this:

self._model_parallel_all_reduce(tensor=total_norm_cuda,
                                op=torch.distributed.ReduceOp.SUM)

so gpu 0 is blocking on syncing with gpu 1, but the latter has already finished its schedule and moved on.

The culprit seems to be this:

for i, group in enumerate(self.bit16_groups):

gpu 0 has 2 items in self.bit16_groups, whereas gpu 1 has only 1 item, and so they fail to sync.

In the type:transformer case, there are 2 items on both gpus and so they sync.
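
The hang is then just the classic mismatched-collective deadlock. Here is a minimal standalone sketch of the same symptom (a hypothetical toy script, not DeepSpeed code): the two ranks loop over a different number of groups and therefore issue a different number of all_reduce calls:

    # toy_hang.py -- run with: torchrun --nproc_per_node=2 toy_hang.py
    # (gloo backend so no GPUs are needed)
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend='gloo')
    rank = dist.get_rank()

    # mimic the type:embed case: rank 0 ends up with 2 bit16 groups,
    # rank 1 with only 1
    num_groups = 2 if rank == 0 else 1

    for i in range(num_groups):
        total_norm = torch.ones(1)
        # rank 0's second all_reduce never finds a partner on rank 1
        # and blocks until the backend timeout
        dist.all_reduce(total_norm, op=dist.ReduceOp.SUM)
        print(f'rank {rank}: reduced group {i}')

    print(f'rank {rank}: done with schedule')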

So it looks like the partitioning didn't take care of ensuring that the ZeRO optimizer ends up with the same optimizer groups on each gpu.

The last pipe partition has a different set of optimizer groups from the first one.

With type:transformer there is at least one transformer layer in each pipe stage, which ensures that all self.bit16_groups have 2 groups.

With type:embed we end up with stage 0 holding all the transformer layers plus one embed, and stage 1 holding only the other (tied) embed and no transformer layers. So the stages are different.

Specifically it hangs here:

torch.distributed.all_reduce(total_norm_cuda,
                             op=torch.distributed.ReduceOp.SUM,
                             group=self.dp_process_group)
self._model_parallel_all_reduce(tensor=total_norm_cuda,
                                op=torch.distributed.ReduceOp.SUM)

since the other gpu isn't there to do this.
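
To confirm the mismatch quickly, here is a hypothetical sanity check (not part of DeepSpeed) that gathers the group count from every rank and asserts they agree, so the job fails fast instead of deadlocking later:

    import torch
    import torch.distributed as dist

    def check_group_counts(bit16_groups, group=None):
        # gather len(bit16_groups) from every rank in `group` and
        # assert they all agree
        backend = dist.get_backend(group)
        device = 'cuda' if backend == 'nccl' else 'cpu'
        n = torch.tensor([len(bit16_groups)], dtype=torch.long, device=device)
        counts = [torch.zeros_like(n) for _ in range(dist.get_world_size(group))]
        dist.all_gather(counts, n, group=group)
        counts = [int(c.item()) for c in counts]
        assert len(set(counts)) == 1, \
            f'mismatched number of bit16 groups across ranks: {counts}'

With the type:embed partitioning above this would report [2, 1] and raise right away.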

As mentioned earlier, turning ZeRO-1 off solves the problem, but then we need more hardware to compensate for the non-sharded optimizer states.

Thank you!

@ShadenSmith, @tjruwase

@stas00 stas00 added the bug Something isn't working label Nov 5, 2021
@stas00 stas00 changed the title [BUG] Pipeline not able to partition not on transformer layer w/ ZeRO-1 [BUG] ZeRO-1 hangs in optimizer step when used in Pipeline Nov 6, 2021