diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d7f4dcb7a20fc..955c25f300512 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -624,9 +624,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: ] # index of tensors that have symbolic shapes (batch size) + # for weights and static buffers, they will have concrete shapes. + # symbolic shape only happens for input tensors. + from torch.fx.experimental.symbolic_shapes import is_symbolic self.sym_tensor_indices = [ i for i, x in enumerate(fake_args) - if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) + if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \ + any(is_symbolic(d) for d in x.size()) ] # compiler managed cudagraph input buffers