From 0f3e0eb1cdb7760be00a59a8bcdce03f4b144ec5 Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Thu, 19 Dec 2024 17:02:35 +0800
Subject: [PATCH] [platform] support pytorch custom op pluggable

Signed-off-by: wangxiyuan
---
 vllm/model_executor/custom_op.py | 31 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py
index fddc8bad09ef5..1b6be2efa3df4 100644
--- a/vllm/model_executor/custom_op.py
+++ b/vllm/model_executor/custom_op.py
@@ -20,6 +20,16 @@ def __init__(self):
         super().__init__()
         self._forward_method = self.dispatch_forward()
 
+    @classmethod
+    def set_forward_method(cls, method):
+        """Provide a way to register a custom forward method for a specific
+        backend."""
+        # Refuse to clobber an existing backend implementation: out-of-tree
+        # platforms may only register a forward_<device> that is not defined.
+        if getattr(cls, f"forward_{current_platform.device_name}", None):
+            raise ValueError(
+                f"Custom op {cls.__name__} already has a "
+                f"forward_{current_platform.device_name} method")
+        setattr(cls, f"forward_{current_platform.device_name}", method)
+
     def forward(self, *args, **kwargs):
         return self._forward_method(*args, **kwargs)
 
@@ -72,18 +82,15 @@ def dispatch_forward(self):
         if not enabled:
             return self.forward_native
 
-        if current_platform.is_rocm():
-            return self.forward_hip
-        elif current_platform.is_cpu():
-            return self.forward_cpu
-        elif current_platform.is_hpu():
-            return self.forward_hpu
-        elif current_platform.is_tpu():
-            return self.forward_tpu
-        elif current_platform.is_xpu():
-            return self.forward_xpu
-        else:
-            return self.forward_cuda
+        custom_forward_func = \
+            getattr(self, f"forward_{current_platform.device_name}", None)
+        if not custom_forward_func:
+            logger.warning(
+                "Custom op %s is not supported on %s, falling back "
+                "to native.", self.__class__.__name__,
+                current_platform.device_name)
+            return self.forward_native
+        return custom_forward_func
 
     @classmethod
     def enabled(cls) -> bool: