[ray] launch multiple GPU with ray #396
Conversation
Force-pushed from 475668a to 61101c3
This PR implements multi-process launching via Ray. Following vLLM, it uses RayGPUExecutor to manage multiple workers, with each worker running the diffusers pipeline logic. However, this launch path currently differs too much from the torchrun entrypoint (example.py). I suggest designing a Ray-distributed version of DiffusionPipeline, e.g. RayDiffusionPipeline, which would expose from_pretrained, forward, and similar interfaces. The PR also hardcodes a few things, such as the text_encoder initialization: since the text_encoder is not currently sharded across GPUs, each worker can simply load a full copy of it. Please keep the interface as consistent with torchrun as possible.
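A minimal sketch of what the proposed RayDiffusionPipeline interface could look like. All names here (`RayDiffusionPipeline`, `_StubWorker`, the `world_size` default) are hypothetical; the Ray actor wiring is replaced with a plain list of in-process workers so the sketch runs without Ray installed:

```python
class _StubWorker:
    """Stands in for a Ray actor. A real worker would load the full
    text_encoder plus its shard of the transformer and run the
    diffusers pipeline logic."""

    def __init__(self, rank: int, model_path: str):
        self.rank = rank
        self.model_path = model_path

    def execute(self, prompt: str) -> str:
        # Placeholder for the actual diffusion forward pass.
        return f"rank{self.rank}:{prompt}"


class RayDiffusionPipeline:
    """Hypothetical Ray-distributed counterpart of DiffusionPipeline,
    mirroring its from_pretrained / forward surface."""

    def __init__(self, workers):
        self.workers = workers

    @classmethod
    def from_pretrained(cls, model_path: str, world_size: int = 2):
        # With Ray this would create `world_size` remote actors.
        workers = [_StubWorker(r, model_path) for r in range(world_size)]
        return cls(workers)

    def forward(self, prompt: str):
        # With Ray: ray.get([w.execute.remote(prompt) for w in self.workers])
        outputs = [w.execute(prompt) for w in self.workers]
        # Mirrors the PR's convention: the last worker holds the result.
        return outputs[-1]


pipe = RayDiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")
print(pipe.forward("a cat"))  # -> rank1:a cat
```

The point of the sketch is that callers would not need to know about Ray at all: the launch-mode difference is hidden behind the same two calls that the torchrun path already uses.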
examples/ray/ray_flux_example.py
Outdated
# output is a list of results from each worker, we take the last one
for i, image in enumerate(output[-1].images):
    image.save(
        f"/data/results/{model_name}_result_{i}.png"
Save to a relative path instead, e.g. ./results/xxx
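The suggested change can be sketched as follows. This is a runnable stand-in: `_StubImage` substitutes for the PIL images in the real pipeline output, and `model_name` is a placeholder value; the point is only the relative, portable output path:

```python
import os


class _StubImage:
    """Stands in for a PIL image returned by the pipeline."""

    def save(self, path: str) -> None:
        with open(path, "w") as f:
            f.write("stub")


# Create the relative output directory instead of assuming /data/results.
os.makedirs("./results", exist_ok=True)

model_name = "flux"  # hypothetical placeholder
images = [_StubImage(), _StubImage()]
for i, image in enumerate(images):
    image.save(f"./results/{model_name}_result_{i}.png")
```

Using a path relative to the working directory keeps the example runnable on machines without a /data mount.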
@@ -188,6 +192,9 @@ class ParallelConfig:
    sp_config: SequenceParallelConfig
    pp_config: PipeFusionParallelConfig
    tp_config: TensorParallelConfig
    distributed_executor_backend: Optional[str] = None
    world_size: int = 1  # FIXME: remove this
    worker_cls: str = "xfuser.ray.worker.worker.Worker"
Do we need both distributed_executor_backend and worker_cls?
We don't need distributed_executor_backend, but we do need worker_cls so that Ray can initialize the worker from its qualified class name:
def init_worker(self, *args, **kwargs):
    worker_class = resolve_obj_by_qualname(
        self.worker_cls)
    self.worker = worker_class(*args, **kwargs)
    assert self.worker is not None
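For context, `resolve_obj_by_qualname` can be implemented as a small import helper; this is a sketch of the idea, not necessarily xfuser's exact implementation, demonstrated with a stdlib class instead of `xfuser.ray.worker.worker.Worker`:

```python
import importlib


def resolve_obj_by_qualname(qualname: str):
    """Resolve a dotted name like 'pkg.module.Class' to the object itself:
    split off the final attribute, import the module, fetch the attribute."""
    module_name, _, obj_name = qualname.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# Example with a stdlib class standing in for the Worker class:
cls = resolve_obj_by_qualname("collections.OrderedDict")
print(cls.__name__)  # -> OrderedDict
```

Passing the class as a string rather than an object matters here because Ray serializes actor constructor arguments; a qualified name lets each remote process import the class locally.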
LGTM
Support launching the pipeline with Ray