From f22c8ac54da291495feb99fabae91fd64dcc8488 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 20 Jan 2025 11:37:50 +0800 Subject: [PATCH] [torch.compile] fix sym_tensor_indices (#12191) Signed-off-by: youkaichao Signed-off-by: Isotr0py <2037008807@qq.com> --- vllm/compilation/backends.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d7f4dcb7a20fc..955c25f300512 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -624,9 +624,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: ] # index of tensors that have symbolic shapes (batch size) + # for weights and static buffers, they will have concrete shapes. + # symbolic shape only happens for input tensors. + from torch.fx.experimental.symbolic_shapes import is_symbolic self.sym_tensor_indices = [ i for i, x in enumerate(fake_args) - if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \ + any(is_symbolic(d) for d in x.size()) ] # compiler managed cudagraph input buffers