Skip to content

Commit

Permalink
show parameter memory usage, upgrad diffusers to 0.31
Browse files Browse the repository at this point in the history
  • Loading branch information
feifeibear committed Oct 25, 2024
1 parent f0593b8 commit b158d97
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
5 changes: 4 additions & 1 deletion examples/pixartalpha_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def main():
engine_config=engine_config,
torch_dtype=torch.float16,
).to(f"cuda:{local_rank}")
model_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
pipe.prepare_run(input_config)

torch.cuda.reset_peak_memory_stats()
Expand Down Expand Up @@ -62,7 +63,9 @@ def main():
print(img_file)

if get_world_group().rank == get_world_group().world_size - 1:
print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
print(
f"{parallel_info}: epoch time: {elapsed_time:.2f} sec, model memory: {model_memory/1e9:.2f} GB, overall memory: {peak_memory/1e9:.2f} GB"
)
get_runtime_state().destory_distributed_env()


Expand Down
5 changes: 4 additions & 1 deletion examples/sd3_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def main():
engine_config=engine_config,
torch_dtype=torch.float16,
).to(f"cuda:{local_rank}")

model_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

pipe.prepare_run(input_config)

torch.cuda.reset_peak_memory_stats()
Expand Down Expand Up @@ -63,7 +66,7 @@ def main():

if get_world_group().rank == get_world_group().world_size - 1:
print(
f"{parallel_info} epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB"
f"{parallel_info} epoch time: {elapsed_time:.2f} sec, overall memory: {peak_memory/1e9:.2f} GB, model memory: {model_memory/1e9:.2f} GB"
)
get_runtime_state().destory_distributed_env()

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_cuda_version():
install_requires=[
"torch>=2.1.0",
"accelerate>=0.33.0",
"diffusers@git+https://github.com/huggingface/diffusers", # NOTE: diffusers>=0.31.0.dev is necessary for CogVideoX and Flux
"diffusers>=0.31", # NOTE: diffusers>=0.31.0.dev is necessary for CogVideoX and Flux
"transformers>=4.39.1",
"sentencepiece>=0.1.99",
"beautifulsoup4>=4.12.3",
Expand Down

0 comments on commit b158d97

Please sign in to comment.