diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 87655530cead4..157e3f7f39c9c 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -251,15 +251,27 @@ def _check_can_cache(*args, **kwargs): def _get_shape_env() -> AlwaysHitShapeEnv: return AlwaysHitShapeEnv() - with patch(# for hijacking the hash of the compiled graph - "torch._inductor.codecache.compiled_fx_graph_hash", - hijack_compiled_fx_graph_hash), \ - patch(# for providing a dummy shape environment - "torch._inductor.codecache.FxGraphCache._get_shape_env", - _get_shape_env), \ - patch(# for forcing the graph to be cached - "torch._inductor.codecache.FxGraphCache._check_can_cache", - _check_can_cache): + with ExitStack() as stack: + if not cache_data.disabled: + # compilation cache is enabled, patch several functions + + # for hijacking the hash of the compiled graph + stack.enter_context( + patch("torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash)) + + # for providing a dummy shape environment + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env)) + + # for forcing the graph to be cached + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache)) + compiled_graph = compile_fx(graph, example_inputs, config_patches=current_config)