diff --git a/vllm/config.py b/vllm/config.py
index be796ff7d0940..268eca605c169 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -873,6 +873,13 @@ def __init__(
                 f"distributed executor backend "
                 f"'{self.distributed_executor_backend}'.")
 
+        if current_platform.is_tpu() and self.world_size > 1:
+            if self.distributed_executor_backend is None:
+                self.distributed_executor_backend = "ray"
+            if self.distributed_executor_backend != "ray":
+                raise ValueError(
+                    "TPU backend only supports Ray for distributed inference.")
+
         if self.distributed_executor_backend is None and self.world_size > 1:
             # We use multiprocessing by default if world_size fits on the
             # current node and we aren't in a ray placement group.