Commit

[example] print parameter memory usage

feifeibear committed Oct 14, 2024
1 parent 80d439e commit 1fe2439

Showing 4 changed files with 20 additions and 22 deletions.
6 changes: 5 additions & 1 deletion examples/flux_example.py
@@ -33,6 +33,8 @@ def main():
     else:
         pipe = pipe.to(f"cuda:{local_rank}")
 
+    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
     pipe.prepare_run(input_config, steps=1)
 
     torch.cuda.reset_peak_memory_stats()
@@ -69,7 +71,9 @@ def main():
             print(f"image {i} saved to ./results/{image_name}")
 
     if get_world_group().rank == get_world_group().world_size - 1:
-        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
+        print(
+            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
+        )
     get_runtime_state().destory_distributed_env()
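
All three Python examples in this commit apply the same measurement pattern: read torch.cuda.max_memory_allocated once right after the pipeline lands on the GPU (at that point the high-water mark is dominated by the parameters), reset the peak counter, run inference, and read the counter again. Below is a minimal stand-alone sketch of the idea; load_model and run_inference are hypothetical placeholders, not functions from this repo:

    import torch

    def report_memory(load_model, run_inference, device="cuda:0"):
        # Hypothetical helper; load_model/run_inference stand in for the
        # pipeline construction and the pipe(...) call in the examples.
        model = load_model().to(device)

        # High-water mark right after loading: dominated by model parameters.
        parameter_peak_memory = torch.cuda.max_memory_allocated(device=device)

        # reset_peak_memory_stats resets the peak to the bytes currently
        # allocated, so the parameters remain counted in the next reading.
        torch.cuda.reset_peak_memory_stats(device=device)
        run_inference(model)
        peak_memory = torch.cuda.max_memory_allocated(device=device)

        print(
            f"parameter memory: {parameter_peak_memory/1e9:.2f} GB, "
            f"memory: {peak_memory/1e9:.2f} GB"
        )
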
7 changes: 6 additions & 1 deletion examples/hunyuandit_example.py
@@ -24,6 +24,9 @@ def main():
         engine_config=engine_config,
         torch_dtype=torch.float16,
     ).to(f"cuda:{local_rank}")
+
+    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
     pipe.prepare_run(input_config)
 
     torch.cuda.reset_peak_memory_stats()
@@ -63,7 +66,9 @@ def main():
             )
 
     if get_world_group().rank == get_world_group().world_size - 1:
-        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
+        print(
+            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
+        )
     get_runtime_state().destory_distributed_env()
24 changes: 5 additions & 19 deletions examples/run.sh
@@ -3,14 +3,14 @@
 set -x
 export PYTHONPATH=$PWD:$PYTHONPATH
 
 # Select the model type
-export MODEL_TYPE="Pixart-alpha"
+export MODEL_TYPE="Flux"
 # Configuration for different model types
 # script, model_id, inference_step
 declare -A MODEL_CONFIGS=(
     ["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20"
     ["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20"
     ["Sd3"]="sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
-    ["Flux"]="flux_example.py /cfs/dit/FLUX.1-schnell 4"
+    ["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev 28"
     ["HunyuanDiT"]="hunyuandit_example.py /cfs/dit/HunyuanDiT-v1.2-Diffusers 50"
 )

@@ -27,24 +27,10 @@ mkdir -p ./results
 # task args
 TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
 
-# Flux only supports SP. Do not set the pipefusion degree.
-if [ "$MODEL_TYPE" = "Flux" ]; then
 N_GPUS=8
-PARALLEL_ARGS="--ulysses_degree $N_GPUS"
-CFG_ARGS=""
+PARALLEL_ARGS="--ulysses_degree 1 --ring_degree 1 --pipefusion_parallel_degree 8"
 
-# HunyuanDiT asserts sp_degree == ulysses_degree*ring_degree <= 2, or the output will be incorrect.
-elif [ "$MODEL_TYPE" = "HunyuanDiT" ]; then
-N_GPUS=8
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1"
-CFG_ARGS="--use_cfg_parallel"
-
-else
-# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
-N_GPUS=8
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1"
-CFG_ARGS="--use_cfg_parallel"
-fi
+# CFG_ARGS="--use_cfg_parallel"
 
 # By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
 # PIPEFUSION_ARGS="--num_pipeline_patch 8 "
@@ -65,7 +51,7 @@ $PIPEFUSION_ARGS \
 $OUTPUT_ARGS \
 --num_inference_steps $INFERENCE_STEP \
 --warmup_steps 0 \
---prompt "A small dog" \
+--prompt "brown dog laying on the ground with a metal bowl in front of him." \
 $CFG_ARGS \
 $PARALLLEL_VAE \
 $COMPILE_FLAG
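
The new PARALLEL_ARGS devotes all eight GPUs to PipeFusion (ulysses and ring degrees of 1), with CFG parallelism left commented out. A hedged sketch of the sizing rule these flags presumably satisfy, on the assumption that the product of the parallel degrees, doubled when --use_cfg_parallel is set, must equal the torchrun world size:

    # Sketch only: degree names mirror the CLI flags in run.sh.
    ulysses_degree = 1
    ring_degree = 1
    pipefusion_parallel_degree = 8
    cfg_factor = 1  # 2 if --use_cfg_parallel is passed; commented out here

    world_size = ulysses_degree * ring_degree * pipefusion_parallel_degree * cfg_factor
    assert world_size == 8  # must match N_GPUS handed to torchrun
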
5 changes: 4 additions & 1 deletion examples/sd3_example.py
@@ -24,6 +24,9 @@ def main():
         engine_config=engine_config,
         torch_dtype=torch.float16,
     ).to(f"cuda:{local_rank}")
+
+    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
     pipe.prepare_run(input_config)
 
     torch.cuda.reset_peak_memory_stats()
@@ -63,7 +66,7 @@ def main():
 
     if get_world_group().rank == get_world_group().world_size - 1:
         print(
-            f"{parallel_info} epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB"
+            f"{parallel_info} epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, peak memory: {peak_memory/1e9:.2f} GB"
         )
     get_runtime_state().destory_distributed_env()

