polish cogvideo example (xdit-project#266)
feifeibear committed Oct 25, 2024
1 parent 334eea1 commit 701645c
Showing 4 changed files with 33 additions and 8 deletions.
23 changes: 21 additions & 2 deletions examples/cogvideox_example.py
@@ -18,6 +18,14 @@ def main():
     parser = FlexibleArgumentParser(description="xFuser Arguments")
     args = xFuserArgs.add_cli_args(parser).parse_args()
     engine_args = xFuserArgs.from_cli_args(args)
+
+    # Check if ulysses_degree is valid
+    num_heads = 30
+    if engine_args.ulysses_degree > 0 and num_heads % engine_args.ulysses_degree != 0:
+        raise ValueError(
+            f"ulysses_degree ({engine_args.ulysses_degree}) must be a divisor of the number of heads ({num_heads})"
+        )
+
     engine_config, input_config = engine_args.create_config()
     local_rank = get_world_group().local_rank
 
@@ -30,7 +38,8 @@ def main():
         pipe.enable_model_cpu_offload(gpu_id=local_rank)
         pipe.vae.enable_tiling()
     else:
-        pipe = pipe.to(f"cuda:{local_rank}")
+        device = torch.device(f"cuda:{local_rank}")
+        pipe = pipe.to(device)
 
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
@@ -49,8 +58,18 @@ def main():
     elapsed_time = end_time - start_time
     peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
 
+    parallel_info = (
+        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
+        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
+        f"tp{engine_args.tensor_parallel_degree}_"
+        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
+    )
     if is_dp_last_group():
-        export_to_video(output, "results/output.mp4", fps=8)
+        world_size = get_data_parallel_world_size()
+        resolution = f"{input_config.width}x{input_config.height}"
+        output_filename = f"results/cogvideox_{parallel_info}_{resolution}.mp4"
+        export_to_video(output, output_filename, fps=8)
+        print(f"output saved to {output_filename}")
 
     if get_world_group().rank == get_world_group().world_size - 1:
         print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
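A note on the new guard in this example: Ulysses-style sequence parallelism shards attention heads across ranks, and the example hardcodes num_heads = 30 for CogVideoX, so ulysses_degree must divide 30 evenly or some rank would get a fractional share of heads. A minimal sketch of that partitioning (split_heads_across_ranks is a hypothetical helper for illustration, not xDiT API):

# Minimal sketch of why ulysses_degree must divide num_heads.
# split_heads_across_ranks is a hypothetical helper, not xDiT API.
def split_heads_across_ranks(num_heads, ulysses_degree):
    if num_heads % ulysses_degree != 0:
        raise ValueError(
            f"ulysses_degree ({ulysses_degree}) must divide num_heads ({num_heads})"
        )
    per_rank = num_heads // ulysses_degree
    # Rank i owns the contiguous head slice [i * per_rank, (i + 1) * per_rank).
    return [range(i * per_rank, (i + 1) * per_rank) for i in range(ulysses_degree)]

print(split_heads_across_ranks(30, 5))  # 5 ranks, 6 heads each
# split_heads_across_ranks(30, 4) raises: 30 % 4 != 0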
15 changes: 10 additions & 5 deletions setup.py
@@ -1,14 +1,18 @@
 from setuptools import find_packages, setup
 import subprocess
 
+
 def get_cuda_version():
     try:
         nvcc_version = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
-        version_line = [line for line in nvcc_version.split('\n') if "release" in line][0]
-        cuda_version = version_line.split(' ')[-2].replace(',', '')
-        return 'cu' + cuda_version.replace('.', '')
+        version_line = [line for line in nvcc_version.split("\n") if "release" in line][
+            0
+        ]
+        cuda_version = version_line.split(" ")[-2].replace(",", "")
+        return "cu" + cuda_version.replace(".", "")
     except Exception as e:
-        return 'no_cuda'
+        return "no_cuda"
 
+
 if __name__ == "__main__":
     with open("README.md", "r") as f:
@@ -29,9 +33,10 @@ def get_cuda_version():
             "sentencepiece>=0.1.99",
             "beautifulsoup4>=4.12.3",
             "distvae",
-            "yunchang==0.3",
+            "yunchang>=0.3.0",
             "pytest",
             "flask",
+            "opencv-python",
         ],
         extras_require={
             "[flash_attn]": [
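For reference, get_cuda_version() above pulls the line containing "release" out of the nvcc --version banner and rewrites, e.g., "release 12.1," as "cu121". A standalone run of the same string handling, with an illustrative banner (real nvcc output varies by toolkit version):

# Illustrative `nvcc --version` banner; real output varies by toolkit.
sample = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 12.1, V12.1.105\n"
)
version_line = [line for line in sample.split("\n") if "release" in line][0]
cuda_version = version_line.split(" ")[-2].replace(",", "")
print("cu" + cuda_version.replace(".", ""))  # prints: cu121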
1 change: 1 addition & 0 deletions tests/layers/attention_processor_test.py
@@ -33,6 +33,7 @@ def init_process(rank, world_size, fn, run_attn_test):
 
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "12355"
+    os.environ["LOCAL_RANK"] = str(rank)
 
     init_distributed_environment(rank=rank, world_size=world_size)
     initialize_model_parallel(
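The added LOCAL_RANK mirrors what the torchrun launcher exports to each worker, so code that picks its CUDA device from the environment also works when the test spawns processes by hand. A sketch of the consuming side (assumed typical usage, not taken from this test file):

import os
import torch

# torchrun exports LOCAL_RANK per worker; the test now sets it manually,
# so device selection like this also works under a hand-rolled spawn.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)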
2 changes: 1 addition & 1 deletion xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
@@ -26,7 +26,7 @@ def ring_flash_attn_forward(
         raise ValueError(
             f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
         )
-    elif joint_strategy is not "none" and (
+    elif joint_strategy != "none" and (
         joint_tensor_key is None or joint_tensor_value is None
     ):
         raise ValueError(
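This one-token change fixes a genuine bug: 'is not' compares object identity rather than value, so joint_strategy is not "none" can be True even when the strings are equal (CPython interns some literals, but not strings built at runtime), and Python 3.8+ emits a SyntaxWarning for identity tests against literals. A quick demonstration:

a = "none"
b = "".join(["no", "ne"])  # equal value, distinct object
print(a != b)      # False: the values match
print(a is not b)  # True: different objects, the bug being fixed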
