
[Movie Gen] Performance optimization and other minor changes #816

Merged · 16 commits · Jan 17, 2025
1 change: 1 addition & 0 deletions examples/moviegen/README.md
@@ -233,6 +233,7 @@ Experiments were conducted on Ascend 910* using MindSpore 2.3.1 in Graph mode.

| Model | Cards | Stage | Batch size | Resolution | Jit level | Compile time | Recompute | Gradient Acc | ZeRO | Sequence Parallel | TAE Cache | Time (s/step) | Config |
|:-----:|:-----:|:---------:|:-----------------------:|:-----------------------:|:---------:|:------------:|:-----------------------:|:------------:|:----:|:-----------------:|:---------:|:-------------:|:--------------------------------------------------------------:|
| 30B | 8 | 3 (T2V) | Video: 1 | 256x576x1024 | O1 | 7m | ON | 1 | 3 | 8 shards | Yes | 37.7 | [stage3_t2iv_768px.yaml](configs/train/stage3_t2iv_768px.yaml) |
| 30B | 8 | 2 (T2V) | Video: 1 | 256x256x455 | O1 | 6m | ON | 1 | 3 | 8 shards | Yes | 4.08 | [stage2_t2iv_256px.yaml](configs/train/stage2_t2iv_256px.yaml) |
| 5B | 8 | 1 (T2I) | 10 | 256x455 | O1 | 3m 40s | ON | 1 | No | No | Yes | 1.29 | [stage1_t2i_256px.yaml](configs/train/stage1_t2i_256px.yaml) |
| 5B | 8 | 2 (T2I/V) | Image: 1<br/>Video: 1 | 256x455<br/>256 frames | O1 | 6m | ON<br/>(Every 2 blocks) | 5 | 2 | No | Yes | 5.09 | [stage2_t2iv_256px.yaml](configs/train/stage2_t2iv_256px.yaml) |
2 changes: 1 addition & 1 deletion examples/moviegen/configs/train/stage2_t2iv_256px.yaml
Expand Up @@ -48,7 +48,7 @@ train:

lr_scheduler:
name: constant
lr: 6.0e-5
lr: 1.0e-4
warmup_steps: 1000

lr_reduce_on_plateau:
2 changes: 1 addition & 1 deletion examples/moviegen/configs/train/stage3_t2iv_768px.yaml
Expand Up @@ -48,7 +48,7 @@ train:

lr_scheduler:
name: constant
lr: 6.0e-5
lr: 1.0e-4
warmup_steps: 1000

lr_reduce_on_plateau:
7 changes: 4 additions & 3 deletions examples/moviegen/docs/report.md
Expand Up @@ -178,13 +178,13 @@ and [USP](https://arxiv.org/abs/2405.07719)), we implement model parallelism
using [Ulysses-SP](https://arxiv.org/abs/2309.14509) together with [ZeRO-3](https://arxiv.org/abs/1910.02054), instead of
the approach used in Movie Gen. Ulysses-SP utilizes `All2All` communication for segments of the QKV tensors, drastically
reducing communication costs compared to sequence parallelism implemented
in [Megatron-LM](https://arxiv.org/abs/2405.07719), [DSP](https://arxiv.org/abs/2403.10266), as well as the sequence
in [Megatron-LM](https://arxiv.org/abs/2405.07719), as well as the sequence
parallelism mentioned
in [Movie Gen](https://ai.meta.com/research/publications/movie-gen-a-cast-of-media-foundation-models/). Alongside
ZeRO-3, it achieves similar memory efficiency to [[Megatron-LM](https://arxiv.org/abs/2405.07719)]. Experimental results
ZeRO-3, it achieves similar memory efficiency to [Megatron-LM](https://arxiv.org/abs/2405.07719). Experimental results
show that using Ulysses-SP + ZeRO-3, we can train a model of similar scale to 3D parallelism with over a 2x
training speed boost, corroborating the findings
in [Megatron-LM](https://arxiv.org/abs/2405.07719), [Ulysses-SP](https://arxiv.org/abs/2309.14509),
in [Ulysses-SP](https://arxiv.org/abs/2309.14509), [USP](https://arxiv.org/abs/2405.07719)
and [DSP](https://arxiv.org/abs/2403.10266).
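
To make the layout exchange concrete, below is a minimal single-process sketch (NumPy, illustrative shapes only; the real implementation issues a communication collective across the SP group) of what the `All2All` step achieves: each rank trades its sequence shard for a head shard, so full-sequence attention can run locally over a subset of heads.

```python
import numpy as np

# Illustrative shapes only: batch, sequence, heads, head_dim, SP group size.
B, S, N, D, P = 1, 8, 4, 2, 4

# Before attention, each of the P ranks holds a sequence shard: (B, S/P, N, D).
shards = [np.random.rand(B, S // P, N, D).astype(np.float32) for _ in range(P)]

def all2all_seq_to_heads(shards):
    """Trade sequence shards for head shards: after the exchange,
    rank r holds the full sequence but only N/P attention heads."""
    out = []
    for r in range(P):
        pieces = [s[:, :, r * (N // P):(r + 1) * (N // P), :] for s in shards]
        out.append(np.concatenate(pieces, axis=1))  # (B, S, N/P, D)
    return out

head_shards = all2all_seq_to_heads(shards)
assert head_shards[0].shape == (B, S, N // P, D)
```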

### Training Details
@@ -260,6 +260,7 @@ Experiments were conducted on Ascend 910* using MindSpore 2.3.1 in graph mode.

| Model | Cards | Stage | Batch size | Resolution | Recompute | TAE Cache | Time (s/step) | Recipe |
|:-----:|:-----:|:---------:|:-----------------------:|:-----------------------:|:-----------------------:|:---------:|:-------------:|:-----------------------------------------------------------------:|
| 30B | 8 | 3 (T2V) | Video: 1 | 256x576x1024 | ON | ON | 37.7 | [stage3_t2iv_768px.yaml](../configs/train/stage3_t2iv_768px.yaml) |
| 30B | 8 | 2 (T2V) | Video: 1 | 256x256x455 | ON | ON | 4.08 | [stage2_t2iv_256px.yaml](../configs/train/stage2_t2iv_256px.yaml) |
| 5B | 8 | 1 (T2I) | 10 | 256x455 | ON | ON | 1.29 | [stage1_t2i_256px.yaml](../configs/train/stage1_t2i_256px.yaml) |
| 5B | 8 | 2 (T2I/V) | Image: 1<br/>Video: 1 | 256x455<br/>256 frames | ON<br/>(Every 2 blocks) | ON | 5.09 | [stage2_t2iv_256px.yaml](../configs/train/stage2_t2iv_256px.yaml) |
19 changes: 7 additions & 12 deletions examples/moviegen/mg/dataset/dataset.py
Expand Up @@ -33,7 +33,7 @@ def __init__(
tae_scale_factor: float = 1.5305,
tae_shift_factor: float = 0.0609,
target_size: Optional[Tuple[int, int]] = None,
sample_n_frames: int = 17,
sample_n_frames: int = 16,
sample_stride: int = 1,
frames_mask_generator: Optional[Callable[[int], np.ndarray]] = None,
t_compress_func: Optional[Callable[[int], int]] = None,
@@ -164,18 +164,14 @@ def _get_item(self, idx: int, thw: Optional[Tuple[int, int, int]] = None) -> Tup

if self._tae_latent_folder:
tae_latent_data = np.load(data["tae_latent"])
latent_mean, latent_std = tae_latent_data["latent_mean"], tae_latent_data["latent_std"] # C T H W
if latent_mean.shape[1] < self._min_length: # TODO: add support for images and buckets
latent_mean, latent_std = tae_latent_data["latent_mean"], tae_latent_data["latent_std"] # T C H W
if 1 < len(latent_mean) < self._min_length: # TODO: add support for buckets
raise ValueError(f"Video is too short: {data['video']}")

start_pos = random.randint(0, len(latent_mean) - self._min_length)
batch_index = np.linspace(start_pos, start_pos + self._min_length - 1, num_frames, dtype=int)

batch_index = np.linspace(0, self._min_length - 1, num_frames, dtype=int)
latent_mean, latent_std = latent_mean[batch_index], latent_std[batch_index]
tae_latent = np.random.normal(latent_mean, latent_std).astype(np.float32)
tae_latent = (tae_latent - self._tae_shift_factor) * self._tae_scale_factor
# FIXME: remove unnecessary transpose
data["video"] = np.transpose(tae_latent, (1, 0, 2, 3)) # C T H W -> T C H W
data["video"] = (tae_latent - self._tae_shift_factor) * self._tae_scale_factor

else:
if data["video"].lower().endswith(IMAGE_EXT):
@@ -190,8 +186,7 @@ def _get_item(self, idx: int, thw: Optional[Tuple[int, int, int]] = None) -> Tup
min_length = (num_frames - 1) * self._stride + 1
if len(reader) < min_length:
raise ValueError(f"Video is too short: {data['video']}")
start_pos = random.randint(0, len(reader) - min_length)
data["video"] = reader.fetch_frames(num=num_frames, start_pos=start_pos, step=self._stride)
data["video"] = reader.fetch_frames(num=num_frames, start_pos=0, step=self._stride) # T H W C
data["fps"] = np.array(reader.fps / self._stride, dtype=np.float32)

data["num_frames"] = np.array(num_frames, dtype=np.float32)
@@ -249,7 +244,7 @@ def train_transforms(
ResizeCrop(target_size, interpolation=interpolation),
lambda x: x.astype(np.float32) / 127.5 - 1,
lambda x: x[None, ...] if x.ndim == 3 else x, # if image
lambda x: np.transpose(x, (0, 3, 1, 2)),
lambda x: np.transpose(x, (3, 0, 1, 2)), # T H W C -> C T H W
],
"input_columns": ["video"],
}
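
As a reference for the cached-latent path above, a short sketch (NumPy, hypothetical shapes; the shift/scale values are the defaults from this dataset's signature) of the full sequence of steps: evenly spaced frame indices, a sample drawn from the stored per-frame mean/std, then TAE normalization.

```python
import numpy as np

tae_scale_factor, tae_shift_factor = 1.5305, 0.0609  # defaults from __init__
min_length, num_frames = 32, 8

# Hypothetical cached latents in T C H W order, as written by inference_tae.py.
latent_mean = np.random.rand(64, 16, 32, 32).astype(np.float32)
latent_std = np.full_like(latent_mean, 0.1)

# Deterministic, evenly spaced frame selection (no random temporal crop).
batch_index = np.linspace(0, min_length - 1, num_frames, dtype=int)
mean, std = latent_mean[batch_index], latent_std[batch_index]

# Sample from the stored distribution, then apply the TAE shift/scale.
tae_latent = np.random.normal(mean, std).astype(np.float32)
video = (tae_latent - tae_shift_factor) * tae_scale_factor  # stays T C H W
assert video.shape == (num_frames, 16, 32, 32)
```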
22 changes: 9 additions & 13 deletions examples/moviegen/mg/models/llama/block.py
@@ -158,7 +158,7 @@ def __init__(
)
num_heads = self.num_heads // self.sp_group_size if self.sp_group_size is not None else self.num_heads
self.flash_attention = FlashAttentionScore(
num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BSND"
num_heads, keep_prob=1 - self.attention_dropout, scale_value=self.head_dim**-0.5, input_layout="BNSD"
)

def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None) -> Tensor:
@@ -185,12 +185,8 @@ def construct(self, hidden_states: Tensor, encoder_hidden_states: Optional[Tenso
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

# Reshape to the expected shape and dtype for Flash Attention
query_states = mint.permute(query_states, (0, 2, 1, 3))
key_states = mint.permute(key_states, (0, 2, 1, 3))
value_states = mint.permute(value_states, (0, 2, 1, 3))

_, _, _, attn_output = self.flash_attention(query_states, key_states, value_states, None, None, None, None)
attn_output = mint.permute(attn_output, (0, 2, 1, 3))
attn_output = self.alltoall(attn_output)
attn_output = ops.reshape(attn_output, (bsz, q_len, -1))
attn_output = self.o_proj(attn_output)
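
A quick check (NumPy, illustrative shapes) of why the three permutes could be dropped once the kernel's declared layout was switched to `BNSD`: the in/out permutation pair is a round trip, so once the declared layout matches how the tensors are already stored, it is pure overhead.

```python
import numpy as np

B, N, S, D = 2, 4, 16, 8
x = np.random.rand(B, N, S, D).astype(np.float32)  # stored as (B, N, S, D)

# (0, 2, 1, 3) swaps the sequence and head axes and is its own inverse,
# so permuting into "BSND" and back out again is a no-op.
to_other_layout = lambda t: t.transpose(0, 2, 1, 3)
assert np.array_equal(to_other_layout(to_other_layout(x)), x)
```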
@@ -266,6 +262,7 @@ def __init__(
hidden_size: int,
frequency_embedding_size: int = 256,
hidden_act: str = "silu",
max_period: int = 10000,
dtype: ms.Type = ms.float32,
) -> None:
super().__init__()
@@ -275,23 +272,22 @@ def __init__(
mint.nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype),
)
self.frequency_embedding_size = frequency_embedding_size
half = frequency_embedding_size // 2
self._freqs = Tensor(np.exp(-np.log(max_period) * np.arange(start=0, stop=half, dtype=np.float32) / half)[None])
self._dtype = dtype

@property
def dtype(self):
return self._dtype

@staticmethod
def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000) -> Tensor:
half = dim // 2
freqs = mint.exp(-mint.log(Tensor(max_period)) * mint.arange(start=0, end=half, dtype=ms.float32) / half)
args = ops.unsqueeze(t, 1).to(ms.float32) * ops.unsqueeze(freqs, 0)
def timestep_embedding(self, t: Tensor) -> Tensor:
args = ops.unsqueeze(t, 1).to(ms.float32) * self._freqs
embedding = mint.cat([mint.cos(args), mint.sin(args)], dim=-1)
if dim % 2:
if self.frequency_embedding_size % 2:
embedding = mint.cat([embedding, mint.zeros_like(embedding[:, :1])], dim=-1)
return embedding

def construct(self, t: Tensor) -> Tensor:
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_freq = self.timestep_embedding(t)
t_emb = self.mlp(t_freq.to(self.dtype))
return t_emb
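
The refactor above hoists the frequency table into `__init__` so it is built once instead of on every forward pass; a NumPy sketch of the same sinusoidal embedding (default `max_period=10000` assumed):

```python
import numpy as np

def make_freqs(dim: int, max_period: int = 10000) -> np.ndarray:
    # Computed once at construction time in the refactored module.
    half = dim // 2
    return np.exp(-np.log(max_period) * np.arange(half, dtype=np.float32) / half)[None]

def timestep_embedding(t: np.ndarray, freqs: np.ndarray, dim: int) -> np.ndarray:
    args = t[:, None].astype(np.float32) * freqs                 # (B, dim // 2)
    emb = np.concatenate([np.cos(args), np.sin(args)], axis=-1)
    if dim % 2:  # pad odd dims with a zero column
        emb = np.concatenate([emb, np.zeros_like(emb[:, :1])], axis=-1)
    return emb

emb = timestep_embedding(np.array([0, 10, 100]), make_freqs(256), 256)
assert emb.shape == (3, 256)
```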
8 changes: 3 additions & 5 deletions examples/moviegen/mg/pipelines/train_pipeline.py
@@ -48,14 +48,12 @@ def get_condition_embeddings(self, text_tokens: Tensor) -> Tensor:
return text_emb

def get_latents(self, video_tokens: Tensor) -> Tensor:
if self.video_emb_cached:
if self.video_emb_cached: # (B, T, C, H, W)
return video_tokens
with no_grad():
# (b c f h w) shape is expected. FIXME: remove this redundancy
video_tokens = mint.permute(video_tokens, (0, 2, 1, 3, 4))
with no_grad(): # (B, C, T, H, W)
video_emb = ops.stop_gradient(self.tae.encode(video_tokens)[0]).to(ms.float32)
video_emb = (video_emb - self.tae.shift_factor) * self.tae.scale_factor
video_emb = mint.permute(video_emb, (0, 2, 1, 3, 4)) # FIXME
video_emb = mint.permute(video_emb, (0, 2, 1, 3, 4)) # FIXME: move inside `Encoder`
return video_emb

def set_train(self, mode=True):
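
For clarity, a small sketch (NumPy; the factors are the TAE defaults used elsewhere in this PR) of the shift/scale normalization `get_latents` applies after encoding, together with its inverse for decoding:

```python
import numpy as np

shift_factor, scale_factor = 0.0609, 1.5305  # TAE defaults from the dataset code

def normalize_latent(z: np.ndarray) -> np.ndarray:
    return (z - shift_factor) * scale_factor

def denormalize_latent(z: np.ndarray) -> np.ndarray:
    return z / scale_factor + shift_factor

z = np.random.rand(1, 8, 4, 32, 32).astype(np.float32)  # B T C H W
assert np.allclose(denormalize_latent(normalize_latent(z)), z, atol=1e-6)
```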
4 changes: 4 additions & 0 deletions examples/moviegen/scripts/inference_tae.py
@@ -64,6 +64,10 @@ def encode(args, tae: TemporalAutoencoder, save_dir: Path, rank_id: int, device_
mean, logvar = to_numpy(mean), to_numpy(logvar)
std = np.exp(0.5 * np.clip(logvar, -30.0, 20.0))

# C T H W -> T C H W
mean = np.transpose(mean, (1, 0, 2, 3))
std = np.transpose(std, (1, 0, 2, 3))

for m, s, path in zip(mean, std, samples[1].tolist()):
out_path = save_dir / path
out_path.parent.mkdir(parents=True, exist_ok=True)
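
The save step itself sits below this hunk; a hedged sketch of what it presumably looks like, with `np.savez` keys matching what `dataset.py` reads back (`latent_mean` / `latent_std`):

```python
import numpy as np
from pathlib import Path

def save_latent_stats(out_path: Path, mean: np.ndarray, std: np.ndarray) -> None:
    # Keys mirror the reader side: tae_latent_data["latent_mean"/"latent_std"].
    out_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(out_path, latent_mean=mean, latent_std=std)  # arrays now in T C H W
```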
1 change: 1 addition & 0 deletions examples/moviegen/scripts/moviegen/30B_stage2_train.sh
@@ -25,6 +25,7 @@ python scripts/train.py \
--dataset.tae_latent_folder TAE_LATENT_FOLDER \
--dataset.text_emb_folder.ul2 UL2_FOLDER \
--dataset.text_emb_folder.byt5 BYT5_FOLDER \
--dataset.sample_n_frames 32 \
--dataloader.batch_size 1 \
--train.ema "" \
--train.output_path "$output_dir"
31 changes: 31 additions & 0 deletions examples/moviegen/scripts/moviegen/30B_stage3_train.sh
@@ -0,0 +1,31 @@
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# plot memory usage, feature/model: 1
export MS_MEMORY_STATISTIC=0

# operation/graph fusion for dynamic shape
# export MS_DEV_ENABLE_KERNEL_PACKET=on # TODO: add dynamic shape support

# log level
export GLOG_v=2

output_dir=output/stage3_t2iv_768px/$(date +"%Y.%m.%d-%H.%M.%S")

msrun --bind_core=True --master_port=8200 --worker_num=8 --local_worker_num=8 --log_dir="$output_dir" \
python scripts/train.py \
--config configs/train/stage3_t2iv_768px.yaml \
--env.mode 0 \
--env.jit_level O1 \
--env.max_device_memory 59GB \
--env.distributed True \
--model.name=llama-30B \
--train.settings.zero_stage 3 \
--train.sequence_parallel.shards 8 \
--dataset.csv_path CSV_PATH \
--dataset.video_folder VIDEO_FOLDER \
--dataset.tae_latent_folder TAE_LATENT_FOLDER \
--dataset.text_emb_folder.ul2 UL2_FOLDER \
--dataset.text_emb_folder.byt5 BYT5_FOLDER \
--dataset.sample_n_frames 32 \
--dataloader.batch_size 1 \
--train.ema "" \
--train.output_path "$output_dir"
5 changes: 4 additions & 1 deletion examples/moviegen/scripts/train.py
@@ -160,7 +160,10 @@ def main(args):
# if bucketing is used in Graph mode, activate dynamic inputs
if mode == GRAPH_MODE and isinstance(args.dataloader.batch_size, dict):
bs = Symbol(unique=True)
video = Tensor(shape=[bs, None, args.model.in_channels if tae is None else 3, None, None], dtype=mstype.float32)
if tae is None:
video = Tensor(shape=[bs, None, args.model.in_channels, None, None], dtype=mstype.float32)
else: # FIXME: Align TAE to B T C H W order
video = Tensor(shape=[bs, 3, None, None, None], dtype=mstype.float32)
# FIXME: fix sequence length
ul2_emb = Tensor(shape=[bs, 300, 4096], dtype=mstype.float32)
byt5_emb = Tensor(shape=[bs, 100, 1472], dtype=mstype.float32)
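
As context for the branch above, a minimal sketch of how dynamic inputs are declared for graph mode: one unique `Symbol` shared across inputs ties their batch axes together, while `None` marks a fully dynamic axis (sketch only; the surrounding script is assumed to pass these to `net.set_inputs`).

```python
import mindspore as ms
from mindspore import Symbol, Tensor

bs = Symbol(unique=True)  # shared symbolic batch dimension
# Raw-video branch: B C T H W with C fixed to 3, everything else dynamic.
video = Tensor(shape=[bs, 3, None, None, None], dtype=ms.float32)
ul2_emb = Tensor(shape=[bs, 300, 4096], dtype=ms.float32)
byt5_emb = Tensor(shape=[bs, 100, 1472], dtype=ms.float32)
# net.set_inputs(video, ul2_emb, byt5_emb)  # enables dynamic-shape compilation
```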