stas00 changed the title from "[BUG] Pipeline not able to partition not on transformer layer w/ ZeRO-1" to "[BUG] ZeRO-1 hangs in optimizer step when used in Pipeline" on Nov 6, 2021
Describe the bug
TLDR: Pipeline works w/o ZeRO-1 but hangs w/ ZeRO-1
At BigScience we have a very large embed layer for the ML training, and using Megatron-DeepSpeed GPT we are trying to give it a whole pipe stage of its own, since sharing a stage with a transformer layer is too much and we get OOM.
To exemplify the problem let's just slice on the embed layer:

`partition_method='type:embed'`

instead of

`partition_method='type:embed|transformer'`

as the latter hides the problem on a small setup.

This is called here: https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/2d9744f23df1a67b4cc1523e3bbdcaca738eb391/megatron/model/gpt_model.py#L279-L283
When ZeRO-1 is used this hangs. Without ZeRO-1 it works.
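For context, here is a minimal sketch of the config toggle involved. The keys are standard DeepSpeed config keys, but the values are placeholders, not our actual training config:

```python
# Hypothetical minimal DeepSpeed config used with the pipeline engine;
# flipping "stage" between 1 and 0 is the difference between hang / no hang.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "gradient_accumulation_steps": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 1,   # 1 -> hangs with type:embed partitioning, 0 -> works
    },
}
# passed along the usual route, e.g. deepspeed.initialize(model=model, config=ds_config, ...)
```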
The hang shows up in a `py-spy dump`.

I'm testing on just 2 gpus, with 2 transformer layers and a tied embed before and after the transformer layers.
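To make that setup concrete, here is a rough sketch of a 2-stage pipeline with a tied embed on both ends and `partition_method='type:embed'`, assuming DeepSpeed's `PipelineModule`/`LayerSpec`/`TiedLayerSpec` API. The class names and sizes are made up and the real model has more layers; this is not the Megatron-DeepSpeed code:

```python
# Toy sketch of the layout above; run under a launcher, e.g. deepspeed --num_gpus 2 repro.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec


class EmbedPipe(nn.Module):
    """Stand-in for the shared input/output embedding (tied weight)."""
    def __init__(self, vocab, hidden):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(vocab, hidden))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, tokens):
        return F.embedding(tokens, self.weight)


class TransformerPipe(nn.Module):
    """Stand-in for one transformer layer."""
    def __init__(self, hidden):
        super().__init__()
        self.block = nn.Linear(hidden, hidden)

    def forward(self, hidden_states):
        return self.block(hidden_states)


def _logits(embed_module, hidden_states):
    # output side of the tied embedding: project hidden states back to vocab
    return F.linear(hidden_states, embed_module.weight)


layers = [
    # the same 'embed' key ties the first and last layer's .weight
    TiedLayerSpec('embed', EmbedPipe, 50257, 1024),
    LayerSpec(TransformerPipe, 1024),
    LayerSpec(TransformerPipe, 1024),
    TiedLayerSpec('embed', EmbedPipe, 50257, 1024, forward_fn=_logits),
]

# 'type:embed' places the stage boundary based on the embed layers only,
# instead of the usual 'type:transformer' / 'type:embed|transformer'.
model = PipelineModule(layers=layers, num_stages=2, partition_method='type:embed')
```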
Let's look at the partitioning weights:

- `type:transformer`: `[0, 0, 0, 1, 1, 0, 0, 0, 0]`
- `type:embed|transformer`: `[0, 1, 0, 1, 1, 0, 0, 1, 0]`

so the resulting partitioning is identical: `[0, 0, 0, 1]`, `[1, 0, 0, 0, 0]`

but with `type:embed`: `[0, 1, 0, 0, 0, 0, 0, 1, 0]`

it splits into `[0, 1, 0, 0, 0, 0, 0]`, `[1, 0]`

and so it doesn't know how to handle a boundary that is not a transformer layer.
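To make the arithmetic above reproducible, here is a simplified sketch of the `type:<regex>` weighting followed by a greedy balanced split. It is not DeepSpeed's actual `partition_balanced` implementation, and the layer names are illustrative, but it yields the same boundaries as listed above:

```python
import math
import re


def type_weights(layer_class_names, pattern):
    """1 for layers whose class name matches the type regex, else 0."""
    regex = re.compile(pattern, re.IGNORECASE)
    return [1 if regex.search(name) else 0 for name in layer_class_names]


def greedy_balanced_split(weights, num_parts):
    """Greedy stand-in for the balanced partitioner: fill each stage until
    adding the next layer would exceed the per-stage weight budget."""
    budget = math.ceil(sum(weights) / num_parts) or 1
    bounds, acc = [0], 0
    for idx, w in enumerate(weights):
        if acc + w > budget and len(bounds) < num_parts:
            bounds.append(idx)
            acc = 0
        acc += w
    bounds.append(len(weights))
    return bounds


# 9 toy "layers": tied embed at both ends, 2 transformer layers, plus glue
# layers; names are illustrative, not the real Megatron-DeepSpeed classes.
names = ['PreProcess', 'EmbedPipe', 'Dropout', 'TransformerPipe', 'TransformerPipe',
         'Norm', 'PostProcess', 'EmbedPipe', 'Loss']

for method in ('transformer', 'embed|transformer', 'embed'):
    w = type_weights(names, method)
    b = greedy_balanced_split(w, 2)
    print(f"type:{method:<19} weights={w} -> stage0={w[b[0]:b[1]]} stage1={w[b[1]:b[2]]}")
```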
If I trace the schedule cmds on the 2 gpus (the log line is printed before each cmd is executed), gpu 0 hangs in `OptimizerStep`. I checked that with `type:transformer` the trace is the same, except that gpu 0 gets to complete.

I tried to remap the above calls to the actual parallel sequence, does it look more or less correct? This indeed looks like an interleaved schedule.
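For reference, here is a sketch of one way to dump the same per-stage command sequence offline, assuming DeepSpeed's `deepspeed.runtime.pipe.schedule` module (the trace above came from prints added around the engine's command execution instead):

```python
# Print the pipeline command sequence each of the 2 stages would execute for a
# few micro-batches; OptimizerStep appears at the end of every stage's schedule.
from deepspeed.runtime.pipe import schedule

MICRO_BATCHES = 4
STAGES = 2

for stage_id in range(STAGES):
    sched = schedule.TrainSchedule(micro_batches=MICRO_BATCHES,
                                   stages=STAGES,
                                   stage_id=stage_id)
    print(f"--- stage {stage_id} ---")
    for step_id, cmds in enumerate(sched):
        print(step_id, cmds)
```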
With a bunch of prints and `CUDA_LAUNCH_BLOCKING=1` I was able to trace it to this:

`DeepSpeed/deepspeed/runtime/zero/stage2.py`, lines 1509 to 1510 at `85ce85d`

so gpu 0 is blocking on syncing with gpu 1, but the latter has gone already.

The culprit seems to be this:

`DeepSpeed/deepspeed/runtime/zero/stage2.py`, line 1636 at `85ce85d`
gpu 0 has 2 items in `self.bit16_groups`, whereas gpu 1 has only 1 item, and so they fail to sync. In the `type:transformer` case there are 2 items on both gpus and so they sync.

So it looks like the partitioning didn't take care of ensuring that the ZeRO optimizer has the same optimizer groups on each gpu: the last pipe partition has a different set of optimizer groups from the first one.
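One way to confirm the mismatch without letting the run deadlock is to gather the group count from every rank and compare. This is a debugging sketch, not a fix, and `check_group_counts` is a hypothetical helper, not a DeepSpeed API; in a real run you'd pass the same process group the hanging collective uses:

```python
# Debugging sketch: gather the number of ZeRO bit16 groups from every rank and
# fail loudly if they disagree, instead of hanging in the collective later.
import torch.distributed as dist


def check_group_counts(zero_optimizer, process_group=None):
    # the attribute discussed above; older DeepSpeed builds call it fp16_groups
    groups = getattr(zero_optimizer, 'bit16_groups',
                     getattr(zero_optimizer, 'fp16_groups', []))
    world = dist.get_world_size(group=process_group)
    counts = [None] * world
    dist.all_gather_object(counts, len(groups), group=process_group)
    if len(set(counts)) != 1:
        raise RuntimeError(
            f"rank {dist.get_rank()}: ranks disagree on the number of "
            f"optimizer groups: {counts}")
```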
With `type:transformer` there is at least one transformer layer in each pipe stage, which ensures that all `self.bit16_groups` have 2 groups. With `type:embed` we end up with stage 0 holding all the transformer layers and one embed, and the last stage holding one embed and no transformer layers. So the stages are different.

Specifically it hangs here:
`DeepSpeed/deepspeed/runtime/zero/stage2.py`, lines 1505 to 1510 at `85ce85d`

since the other gpu isn't there to do this.
As mentioned earlier, turning ZeRO-1 off solves the problem, but then we need more hardware to compensate for the non-sharded optimizer states.
Thank you!
@ShadenSmith, @tjruwase