[RFC]: Refactor Worker and ModelRunner to consolidate control plane communication #5552
Comments
Great RFC! Have you had the chance to verify this with PP? #4412
Yes, definitely. Actually ideally we should design these together. One possibility is to merge #4412 in two parts, one for changes to worker-local execution and then a second PR to add the control plane to glue different workers together. That way, we can follow the interfaces that I'm proposing here and keep the control plane separate from model execution. Concretely, the first PR would then contain these changes:
This way, it will be easier to try different control plane methods. We can use the approach you have in #4412. Another option is a new backend in Ray that we've been developing to improve performance for static task graphs. I wrote an integration for this based off of an earlier version of #4412: graph definition and how to call it.
Yes, agreed. We should chat more about this; what you're suggesting makes sense to me. There are 3-4 optimizations that I know of that we can do on top of #4412, but my current plan is to have #4412 merged as a base on top of the current logic, in order to have basic PP fully functional as soon as possible before moving on to performance refactoring.
Does it make sense:

class WorkerBase:

    class WorkerInput:
        def get_broadcastable_data(self):
            pass

        @staticmethod
        def from_broadcast_data(data):
            pass

    def prepare_worker_input(self, seq_group_metadata_list) -> WorkerInput:
        pass

    def execute_worker(self, input_data: WorkerInput):
        pass

    def execute_model(self, seq_group_metadata_list):
        if self.is_driver_worker:
            worker_input = self.prepare_worker_input(seq_group_metadata_list)
            self.execute_worker(worker_input)
            model_input = self.model_runner.prepare_model_input(seq_group_metadata_list)
            data_to_broadcast = worker_input.get_broadcastable_data()
            data_to_broadcast.update(model_input.get_broadcastable_data())
            broadcast_tensor_dict(data_to_broadcast, src=0)
        else:
            data_to_broadcast = broadcast_tensor_dict(src=0)
            worker_input = self.WorkerInput.from_broadcast_data(data_to_broadcast)
            self.execute_worker(worker_input)
            model_input = self.ModelRunnerInput.from_broadcast_data(data_to_broadcast)
        self.model_runner.execute_model(model_input)


class ModelRunnerBase:

    class ModelRunnerInput:
        def get_broadcastable_data(self):
            pass

        @staticmethod
        def from_broadcast_data(data):
            pass

    def prepare_model_input(self, seq_group_metadata_list) -> ModelRunnerInput:
        pass

    def execute_model(self, input_data: ModelRunnerInput):
        pass

Then, say we want to add GPU worker and GPUModelRunner:

class GPUWorker(WorkerBase):

    class WorkerInput(WorkerBase.WorkerInput):
        def get_broadcastable_data(self):
            pass

        @staticmethod
        def from_broadcast_data(data):
            pass

    def prepare_worker_input(self, seq_group_metadata_list) -> WorkerInput:
        pass

    def execute_worker(self, input_data: WorkerInput):
        pass


class GPUModelRunner(ModelRunnerBase):

    class ModelRunnerInput(ModelRunnerBase.ModelRunnerInput):
        def get_broadcastable_data(self):
            pass

        @staticmethod
        def from_broadcast_data(data):
            pass

    def prepare_model_input(self, seq_group_metadata_list) -> ModelRunnerInput:
        pass

    def execute_model(self, input_data: ModelRunnerInput):
        pass

The control-plane communication is centralized in WorkerBase.execute_model.
Great! Can you say a bit more about what optimizations you were thinking of? The reason I suggested splitting #4412 is that I think it will be easier to introduce some optimizations for PP if we can merge in this refactor first.
@youkaichao that sounds good to me. I can make those changes in #5408.
Off the top of my head, the following optimizations are possible:
In general, I agree that it would potentially make optimizations easier (particularly 2 above). My concern is that we are prioritizing optimization prematurely here and further delaying the already-delayed PP feature. IMO we should have the functionality fully available first, and then implement this refactoring on top of that. cc: @zhuohan123 - this is what we chatted about last week
Sounds good. Yes, I also don't want to block PP; I just think that it may actually be faster long-term to merge a version that's compatible with this refactor. If you do not want to split #4412, at least I think we need to move the p2p communication out of the model definitions and into the
Makes sense! I'm not opposed to splitting #4412, but just think that if we are going to do so it's a good idea to coordinate more closely on the details so we have a good plan to get everything in with low friction. We can meet next week to talk in more detail if you are free.
To give more context, I had thought about this initially but decided against it at first, since different models send/recv different numbers of tensors. Given what you proposed above, it also gets a little more complicated in the context of this PR, if I understand correctly:
We would have to bubble up hidden_states and residuals through the model definition to
Yes, will reach out to find some time to chat!
For this, I don't think we need to bubble up the sampling procedure to

IntermediateOutput = Dict[str, torch.Tensor]

class ModelRunner:
    def execute_model(self, model_input: ModelInput) -> Union[List[SamplingOutput], IntermediateOutput]:
        pass
I think this RFC makes a lot of sense. It's a great idea to put all of the communication logic in the same place. I previously misunderstood this as a bigger scope change that changes how we do control-plane communication. Some smaller questions about this RFC:
Yes! For now this is just proposing a refactor that would allow changing the control plane more easily. This RFC doesn't propose any behavior changes (except for squashing some broadcasts).
Ah, yes, I actually updated the RFC since @youkaichao suggested something similar. Now there is only
This would make it easier to integrate Ray DAG / any other control plane method that broadcasts the ExecuteModelRequest to all workers instead of broadcasting the ModelInput. If we want to support Ray DAG right now, we need to update the control flow in the worker and model runner, e.g., to skip the tensor broadcasts. We can do that for the main Worker codepath, but it makes the code pretty messy and we'd have to do the same thing for every other worker and model runner that we want to support. With the new APIs, we can just override execute_model:

class RayDAGWorkerBase(LocalOrDistributedWorkerBase):
    def execute_model(self, execute_model_req: ExecuteModelRequest):
        worker_input = self.prepare_worker_input(execute_model_req)
        self.execute_worker(worker_input)
        model_input = self.model_runner.prepare_model_input(execute_model_req)
        return self.model_runner.execute_model(model_input)
Merged |
Motivation.
Currently, both the Worker and the ModelRunner classes contain multi-GPU control plane communication code, i.e. broadcast_tensor_dict calls. They look something like the sketch below. Because the ModelRunner class contains both model execution code and multi-GPU control plane communication code, it is difficult to improve upon the performance.
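As a rough, hedged illustration of this status quo (simplified pseudocode; the class layouts, attribute names, and helper methods such as cache_swap and prepare_inputs are approximations, not the exact vLLM code), note how broadcast_tensor_dict appears in both layers:

class Worker:
    def execute_model(self, seq_group_metadata_list, blocks_to_swap_in,
                      blocks_to_swap_out, blocks_to_copy):
        if self.is_driver_worker:
            # Control plane call #1: broadcast worker-level metadata (cache operations).
            broadcast_tensor_dict({
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            }, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]
        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
        return self.model_runner.execute_model(seq_group_metadata_list)


class ModelRunner:
    def execute_model(self, seq_group_metadata_list):
        if self.is_driver_worker:
            # Control plane call #2: broadcast the prepared model inputs.
            input_tokens, input_positions = self.prepare_inputs(seq_group_metadata_list)
            broadcast_tensor_dict({
                "input_tokens": input_tokens,
                "input_positions": input_positions,
            }, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            input_tokens = data["input_tokens"]
            input_positions = data["input_positions"]
        # Model execution is interleaved with the communication above.
        return self.model(input_tokens, input_positions)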
Proposed Change.
Refactor the Worker and ModelRunner classes to consolidate all control plane communication to the Worker class. Both the Worker and ModelRunner classes should implement new worker-local methods to prepare inputs and execute the model. There should be no control plane communication within these methods; this could be enforced using runtime checks.

Here is the new proposed interface. The Worker and ModelRunner create a WorkerInput and a ModelInput respectively from the ExecuteModelRequest. The contract is that ExecuteModelRequest contains CPU-only metadata, while any tensors in WorkerInput and ModelInput should already be on the correct device. Now, the ModelRunnerBase class looks approximately like the sketch below.

This interface allows for cleaner separation between control plane communication vs. single-GPU logic. Each ModelRunner needs to explicitly state what inputs it requires by defining a ModelInput (subclass). This requires a bit more developer effort but should make it easier to introduce the optimizations discussed above.
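A hedged sketch of what this interface could look like (the ModelInput name comes from this RFC; the method names as_broadcastable_tensor_dict and from_broadcasted_tensor_dict are illustrative and may differ from what #5408 actually uses):

from dataclasses import dataclass
from typing import Dict, List, Optional

import torch


@dataclass(frozen=True)
class ModelInput:
    """Per-step model inputs; any tensors here are already on the target device."""

    def as_broadcastable_tensor_dict(self) -> Dict[str, torch.Tensor]:
        # Flatten this input into a dict suitable for broadcast_tensor_dict.
        raise NotImplementedError

    @classmethod
    def from_broadcasted_tensor_dict(cls, tensor_dict: Dict[str, torch.Tensor]) -> "ModelInput":
        # Reconstruct the input on a non-driver worker from the broadcasted dict.
        raise NotImplementedError


class ModelRunnerBase:
    def prepare_model_input(self, execute_model_req) -> ModelInput:
        # CPU-only ExecuteModelRequest in, device tensors out. No collective calls here.
        raise NotImplementedError

    def execute_model(self, model_input: ModelInput) -> Optional[List["SamplerOutput"]]:
        # Pure worker-local model execution; no control plane communication allowed.
        raise NotImplementedError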
We also add a new LocalOrDistributedWorkerBase. The idea behind this class is that as long as the developer implements this interface plus a ModelRunnerBase, they will get support out-of-the-box for both local and distributed execution. This class has a default implementation for execute_model that contains all of the control plane communication needed for distributed execution (see the sketch below).

Custom model runners: For workers / model runners that need some custom logic, they can inherit directly from the generic WorkerBase and do not need to follow these interfaces. In that case, they are responsible for implementing their own control plane communication too.

Speculative decoding: One complication is that the speculative decoding code goes back and forth between ExecuteModelRequest and ModelInput, whereas other workers only convert from ExecuteModelRequest to ModelInput. Thus, for the speculative decoding path, it's easier for now to keep the per-step broadcast. These extra k broadcasts could also be consolidated in the future, by either supporting ModelInput -> ExecuteModelRequest, or by making it possible to modify a ModelInput. Happily, the latter should be compatible with the solutions proposed in #5561.
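To make the LocalOrDistributedWorkerBase idea concrete, here is a hedged sketch of its default execute_model, building on the ModelInput sketch above. The WorkerInput class, the do_metadata_broadcast flag, and the broadcast helper names are assumptions for illustration, not necessarily what #5408 implements:

class LocalOrDistributedWorkerBase(WorkerBase):
    is_driver_worker: bool
    do_metadata_broadcast: bool  # illustrative flag: False when running on a single GPU
    model_runner: ModelRunnerBase

    def prepare_worker_input(self, execute_model_req) -> "WorkerInput":
        raise NotImplementedError

    def execute_worker(self, worker_input: "WorkerInput") -> None:
        raise NotImplementedError

    def execute_model(self, execute_model_req=None):
        # All control plane communication for distributed execution lives in this one method.
        if self.is_driver_worker:
            worker_input = self.prepare_worker_input(execute_model_req)
            model_input = self.model_runner.prepare_model_input(execute_model_req)
            if self.do_metadata_broadcast:
                # Squash worker-level and model-level inputs into a single broadcast.
                data = worker_input.as_broadcastable_tensor_dict()
                data.update(model_input.as_broadcastable_tensor_dict())
                broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            worker_input = WorkerInput.from_broadcasted_tensor_dict(data)
            model_input = ModelInput.from_broadcasted_tensor_dict(data)

        # Everything below is worker-local; no further communication.
        self.execute_worker(worker_input)
        return self.model_runner.execute_model(model_input)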
Pipeline parallelism: In pipeline parallelism, workers before the last PP rank will return some intermediate tensor(s) instead of a SamplerOutput. To support this case, we should define an IntermediateOutput type for models that support PP. Then, we extend ModelRunnerBase.execute_model to return a Union[SamplerOutput, IntermediateOutput] instead of just a SamplerOutput.

Feedback Period.
One week. See #5408 for code.
CC List.
@youkaichao @zhuohan123 @zhisbug @cadedaniel @rkooo567
Any Other Things.
Checklist: