diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 9d633ad259b13..143cb49697f5b 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -92,7 +92,7 @@ def test_simple_piecewise_compile(): num_graphs_seen=1, # one graph for the model num_piecewise_graphs_seen=5, # 2 * num_layers + 1 num_piecewise_capturable_graphs_seen=3, # 1 + num_layers - num_inductor_compilations=3, # num_piecewise_capturable_graphs_seen + num_backend_compilations=3, # num_piecewise_capturable_graphs_seen num_cudagraph_caputured= 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 0404722bab891..021bd4cc46356 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -322,7 +322,7 @@ def test_toy_llama(): num_graphs_seen=0, num_piecewise_graphs_seen=0, num_piecewise_capturable_graphs_seen=0, - num_inductor_compilations=0, + num_backend_compilations=0, num_cudagraph_caputured=0, ): outputs.append(run_model(llama_config, use_compile=False)) @@ -332,7 +332,7 @@ def test_toy_llama(): num_graphs_seen=1, # one graph for the model num_piecewise_graphs_seen=1, num_piecewise_capturable_graphs_seen=1, - num_inductor_compilations=1, # num_piecewise_capturable_graphs_seen + num_backend_compilations=1, # num_piecewise_capturable_graphs_seen num_cudagraph_caputured= 2, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): @@ -345,7 +345,7 @@ def test_toy_llama(): 1, # 2 * num_layers + 1 num_piecewise_capturable_graphs_seen=1 + llama_config.num_layers, # 1 + num_layers - num_inductor_compilations=1 + + num_backend_compilations=1 + llama_config.num_layers, # num_piecewise_capturable_graphs_seen num_cudagraph_caputured=2 * (1 + llama_config.num_layers diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 979890170c16b..b972f03c9685b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,12 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 import ast -import copy import dataclasses import os import pprint import time -from collections import defaultdict from contextlib import ExitStack from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from unittest.mock import patch @@ -19,6 +17,7 @@ from vllm.logger import init_logger from vllm.utils import weak_ref_tensors +from .compiler_interface import EagerAdaptor, InductorAdaptor from .counter import compilation_counter from .inductor_pass import InductorPass from .monitor import end_monitoring_torch_compile @@ -27,306 +26,128 @@ logger = init_logger(__name__) -@dataclasses.dataclass -class InductorArtifact: - hash_str: str = "" - file_path: str = "" +class CompilerManager: + """ + A manager to manage the compilation process, including + caching the compiled graph, loading the compiled graph, + and compiling the graph. + The cache is a dict mapping + `(runtime_shape, graph_index, backend_name)` + to `any_data` returned from the compiler. -class InductorHashCache: + When serializing the cache, we save it to a Python file + for readability. We don't use json here because json doesn't + support int as key. """ - Disk format: a Python list of tuples, each tuple is - (runtime_shape, graph_index, hash_str, file_path) - We use list of tuple for readability. - In-memory format: a defaultdict of dict, where the key is - runtime_shape, and the value is a dict of graph_index to hash_str. + def __init__(self, use_inductor: bool): + self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict() + cls = InductorAdaptor if use_inductor else EagerAdaptor + self.compiler = cls() - The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`, - we don't use json here because json doesn't support int as key. - - TODO: better off-the-shelf solution to serialize the data? - """ + def compute_hash(self, vllm_config: VllmConfig) -> str: + return self.compiler.compute_hash(vllm_config) - def __init__(self, cache_dir: str, disabled: bool = False): - self.cache: Dict[Optional[int], - Dict[int, InductorArtifact]] = defaultdict(dict) - self.disabled = disabled + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + self.disable_cache = disable_cache self.cache_dir = cache_dir - self.cache_file_path = os.path.join(cache_dir, - "inductor_hash_cache.py") - if disabled: - return - # set flags so that Inductor and Triton store their cache - # in the cache_dir, then users only need to copy the cache_dir - # to another machine to reuse the cache. - inductor_cache = os.path.join(cache_dir, "inductor_cache") - os.makedirs(inductor_cache, exist_ok=True) - os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache - triton_cache = os.path.join(cache_dir, "triton_cache") - os.makedirs(triton_cache, exist_ok=True) - os.environ["TRITON_CACHE_DIR"] = triton_cache - if os.path.exists(self.cache_file_path): + self.cache_file_path = os.path.join(cache_dir, "vllm_compile_cache.py") + + if not disable_cache and os.path.exists(self.cache_file_path): + # load the cache from the file with open(self.cache_file_path) as f: - self.deserialize(f.read()) - - def deserialize(self, data: str): - # we use ast.literal_eval to parse the data - # because it is a safe way to parse Python literals. - # do not use eval(), it is unsafe. - list_data = ast.literal_eval(data) - for item in list_data: - runtime_shape = item[0] - graph_index = item[1] - hash_str = item[2] - # for compatibility of old version, - # where we don't have file_path. - # NOTE: after running the new code, the file_path - # will be updated. - file_path = "" if len(item) == 3 else item[3] - self.cache[runtime_shape][graph_index] = InductorArtifact( - hash_str=hash_str, file_path=file_path) - - def serialize(self) -> str: - data = [] - for runtime_shape, value in self.cache.items(): - for graph_index, inductor_artifact in value.items(): - data.append( - (runtime_shape, graph_index, inductor_artifact.hash_str, - inductor_artifact.file_path)) - printer = pprint.PrettyPrinter(indent=4) - return printer.pformat(data) + # we use ast.literal_eval to parse the data + # because it is a safe way to parse Python literals. + # do not use eval(), it is unsafe. + self.cache = ast.literal_eval(f.read()) + + self.compiler.initialize_cache(cache_dir=cache_dir, + disable_cache=disable_cache) def save_to_file(self): - if self.disabled: + if self.disable_cache: return with open(self.cache_file_path, "w") as f: - f.write(self.serialize()) - - def __contains__(self, key: Tuple[Optional[int], int]) -> bool: - if self.disabled: - return False - runtime_shape, graph_index = key - return runtime_shape in self.cache and graph_index in self.cache[ - runtime_shape] - - def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact: - if self.disabled: - raise KeyError("cannot read from disabled cache") - runtime_shape, graph_index = key - return self.cache[runtime_shape][graph_index] - - def __setitem__(self, key: Tuple[Optional[int], int], - value: InductorArtifact): - # setitem for disabled cache is fine, because we - # don't actually write to the disk - runtime_shape, graph_index = key - self.cache[runtime_shape][graph_index] = value - - -class AlwaysHitShapeEnv: - """ - Why do we need this class: - - For normal `torch.compile` usage, every compilation will have - one Dynamo bytecode compilation and one Inductor compilation. - The Inductor compilation happens under the context of the - Dynamo bytecode compilation, and that context is used to - determine the dynamic shape information, etc. - - For our use case, we only run Dynamo bytecode compilation once, - and run Inductor compilation multiple times with different shapes - plus a general shape. The compilation for specific shapes happens - outside of the context of the Dynamo bytecode compilation. At that - time, we don't have shape environment to provide to Inductor, and - it will fail the Inductor code cache lookup. - - By providing a dummy shape environment that always hits, we can - make the Inductor code cache lookup always hit, and we can - compile the graph for different shapes as needed. - - The following dummy methods are obtained by trial-and-error - until it works. - """ - - def __init__(self) -> None: - self.guards: List[Any] = [] - - def evaluate_guards_expression(self, *args, **kwargs): - return True - - def get_pruned_guards(self, *args, **kwargs): - return [] - - def produce_guards_expression(self, *args, **kwargs): - return "" - - -def wrap_inductor(graph: fx.GraphModule, - example_inputs, - additional_inductor_config, - compilation_config: CompilationConfig, - vllm_backend: "VllmBackend", - graph_index: int = 0, - num_graphs: int = 1, - runtime_shape: Optional[int] = None, - use_inductor: bool = True) -> Any: - if graph_index == 0: - # before compiling the first graph, record the start time - global compilation_start_time - compilation_start_time = time.time() - - if not use_inductor: - return graph - - compilation_counter.num_inductor_compilations += 1 - - from torch._inductor import config - current_config = config.get_config_copy() - from torch._inductor.compile_fx import compile_fx - - if additional_inductor_config is not None: - current_config.update(additional_inductor_config) - - if isinstance(runtime_shape, int): - # for a specific batchsize, tuning triton kernel parameters - # can be beneficial - current_config["max_autotune"] = True - current_config["coordinate_descent_tuning"] = True - - # inductor can inplace modify the graph, so we need to copy it - # see https://github.com/pytorch/pytorch/issues/138980 - graph = copy.deepcopy(graph) - - cache_data = vllm_backend.inductor_hash_cache - if (runtime_shape, graph_index) in cache_data: - # we compiled this graph before - # so we can directly lookup the compiled graph via hash - inductor_artifact = cache_data[(runtime_shape, graph_index)] - hash_str = inductor_artifact.hash_str - if graph_index == 0: - # adds some info logging for the first graph - logger.info( - "Directly lookup the graph for shape %s from the cache", - str(runtime_shape)) # noqa + printer = pprint.PrettyPrinter(indent=4) + data = printer.pformat(self.cache) + f.write(data) + + def load(self, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Optional[Callable]: + if (runtime_shape, graph_index, self.compiler.name) not in self.cache: + return None + handle = self.cache[(runtime_shape, graph_index, self.compiler.name)] + compiled_graph = self.compiler.load(handle, graph, example_inputs, + graph_index, runtime_shape) logger.debug( - "directly lookup the %s-th graph for shape %s via hash %s", - graph_index, str(runtime_shape), hash_str) - from torch._inductor.codecache import FxGraphCache - with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", - lambda *args, **kwargs: AlwaysHitShapeEnv()): - inductor_compiled_graph = FxGraphCache._lookup_graph( - hash_str, example_inputs, True, False) - assert inductor_compiled_graph is not None, ( - "Inductor cache lookup failed. Please remove" - f"the cache file {cache_data.cache_file_path} and try again." # noqa - ) - inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa - - # Inductor calling convention (function signature): - # f(list) -> tuple - # Dynamo calling convention (function signature): - # f(*args) -> Any - - # need to know if the graph returns a tuple - from torch._inductor.compile_fx import graph_returns_tuple - returns_tuple = graph_returns_tuple(graph) - - # this is the callable we return to Dynamo to run - def compiled_graph(*args): - # convert args to list - list_args = list(args) - graph_output = inductor_compiled_graph(list_args) - # unpack the tuple if needed - if returns_tuple: - return graph_output - else: - return graph_output[0] - else: - # it's the first time we compile this graph - # the assumption is that we don't have nested Inductor compilation. - # compiled_fx_graph_hash will only be called once, and we can hook - # it to get the hash of the compiled graph directly. - - inductor_artifact = InductorArtifact() - from torch._inductor.codecache import (FxGraphCache, - compiled_fx_graph_hash) - original_load = FxGraphCache.load - - def hijack_load(*args, **kwargs): - inductor_compiled_graph = original_load(*args, **kwargs) - inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa - return inductor_compiled_graph - - def hijack_compiled_fx_graph_hash(*args, **kwargs): - out = compiled_fx_graph_hash(*args, **kwargs) - inductor_artifact.hash_str = out[0] - return out - - def _check_can_cache(*args, **kwargs): - # no error means it can be cached. - # Inductor refuses to cache the graph outside of Dynamo - # tracing context, and also disables caching for graphs - # with high-order ops. - # For vLLM, in either case, we want to cache the graph. - # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa - return - - def _get_shape_env() -> AlwaysHitShapeEnv: - return AlwaysHitShapeEnv() - - with ExitStack() as stack: - if not cache_data.disabled: - # compilation cache is enabled, patch several functions - - # hijack to get the compiled graph itself - stack.enter_context( - patch("torch._inductor.codecache.FxGraphCache.load", - hijack_load)) - - # for hijacking the hash of the compiled graph - stack.enter_context( - patch("torch._inductor.codecache.compiled_fx_graph_hash", - hijack_compiled_fx_graph_hash)) - - # for providing a dummy shape environment - stack.enter_context( - patch( - "torch._inductor.codecache.FxGraphCache._get_shape_env", - _get_shape_env)) - - # for forcing the graph to be cached - stack.enter_context( - patch( - "torch._inductor.codecache.FxGraphCache._check_can_cache", - _check_can_cache)) - - compiled_graph = compile_fx(graph, - example_inputs, - config_patches=current_config) - # store the inductor_artifact in the cache - cache_data[(runtime_shape, graph_index)] = inductor_artifact + "Directly load the %s-th graph for shape %s from %s via " + "handle %s", graph_index, str(runtime_shape), self.compiler.name, + handle) + return compiled_graph + + def compile(self, + graph: fx.GraphModule, + example_inputs, + additional_inductor_config, + compilation_config: CompilationConfig, + graph_index: int = 0, + num_graphs: int = 1, + runtime_shape: Optional[int] = None) -> Any: if graph_index == 0: - # adds some info logging for the first graph - logger.info("Cache the graph of shape %s for later use", - str(runtime_shape)) - logger.debug( - "store the %s-th graph for shape %s via hash %s from file %s", - graph_index, str(runtime_shape), inductor_artifact.hash_str, - inductor_artifact.file_path) - # after compiling the last graph, record the end time - if graph_index == num_graphs - 1: - now = time.time() - elapsed = now - compilation_start_time - compilation_config.compilation_time += elapsed - if runtime_shape is None: - logger.info("Compiling a graph for general shape takes %.2f s", - elapsed) - else: - logger.info("Compiling a graph for shape %s takes %.2f s", - runtime_shape, elapsed) + # before compiling the first graph, record the start time + global compilation_start_time + compilation_start_time = time.time() + + compilation_counter.num_backend_compilations += 1 + + compiled_graph = None + + # try to load from the cache + compiled_graph = self.load(graph, example_inputs, graph_index, + runtime_shape) + if compiled_graph is not None: + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Directly load the compiled graph for shape %s " + "from the cache", str(runtime_shape)) # noqa + return compiled_graph + + # no compiler cached the graph, or the cache is disabled, + # we need to compile it + compiled_graph, handle = self.compiler.compile( + graph, example_inputs, additional_inductor_config, runtime_shape) + + assert compiled_graph is not None, "Failed to compile the graph" + + # store the artifact in the cache + if handle is not None: + self.cache[(runtime_shape, graph_index, + self.compiler.name)] = handle + if graph_index == 0: + # adds some info logging for the first graph + logger.info("Cache the graph of shape %s for later use", + str(runtime_shape)) + logger.debug( + "store the %s-th graph for shape %s from %s via handle %s", + graph_index, str(runtime_shape), self.compiler.name, handle) + + # after compiling the last graph, record the end time + if graph_index == num_graphs - 1: + now = time.time() + elapsed = now - compilation_start_time + compilation_config.compilation_time += elapsed + if runtime_shape is None: + logger.info("Compiling a graph for general shape takes %.2f s", + elapsed) + else: + logger.info("Compiling a graph for shape %s takes %.2f s", + runtime_shape, elapsed) - return compiled_graph + return compiled_graph @dataclasses.dataclass @@ -436,16 +257,15 @@ def call_module(self, target: torch.fx.node.Target, i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] global compilation_start_time - compiled_graph_for_general_shape = wrap_inductor( + compiled_graph_for_general_shape = self.vllm_backend.\ + compiler_manager.compile( submod, args, self.compilation_config.inductor_compile_config, self.compilation_config, - self.vllm_backend, graph_index=index, num_graphs=len(self.compile_submod_names), - runtime_shape=None, - use_inductor=self.compilation_config.use_inductor) + runtime_shape=None) self.module.__dict__[target] = PiecewiseBackend( submod, self.vllm_config, self.graph_pool, index, @@ -483,7 +303,7 @@ class VllmBackend: post_grad_passes: Sequence[Callable] sym_tensor_indices: List[int] input_buffers: List[torch.Tensor] - inductor_hash_cache: InductorHashCache + compiler_manager: CompilerManager def __init__( self, @@ -507,6 +327,9 @@ def __init__( self.vllm_config = vllm_config self.compilation_config = vllm_config.compilation_config + self.compiler_manager: CompilerManager = CompilerManager( + self.compilation_config.use_inductor) + # `torch.compile` is JIT compiled, so we don't need to # do anything here @@ -533,9 +356,11 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # the cache dir will be the same so that we can reuse the compiled # graph. + factors = [] # 1. factors come from the vllm_config (it mainly summarizes how the # model is created) config_hash = vllm_config.compute_hash() + factors.append(config_hash) # 2. factors come from the code files that are traced by Dynamo ( # it mainly summarizes how the model is used in forward pass) @@ -553,10 +378,15 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: import hashlib code_hash = hashlib.md5( "\n".join(hash_content).encode()).hexdigest() + factors.append(code_hash) + + # 3. compiler hash + compiler_hash = self.compiler_manager.compute_hash(vllm_config) + factors.append(compiler_hash) + + # combine all factors to generate the cache dir + hash_key = hashlib.md5(str(factors).encode()).hexdigest()[:10] - # combine the two hashes to generate the cache dir - hash_key = hashlib.md5( - f"{config_hash}_{code_hash}".encode()).hexdigest()[:10] cache_dir = os.path.join( envs.VLLM_CACHE_ROOT, "torch_compile_cache", @@ -570,15 +400,16 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: cache_dir, f"rank_{vllm_config.parallel_config.rank}") self.compilation_config.local_cache_dir = local_cache_dir - disabled = envs.VLLM_DISABLE_COMPILE_CACHE - self.inductor_hash_cache: InductorHashCache = InductorHashCache( - local_cache_dir, disabled=disabled) - if disabled: + disable_cache = envs.VLLM_DISABLE_COMPILE_CACHE + + if disable_cache: logger.info("vLLM's torch.compile cache is disabled.") else: logger.info("Using cache directory: %s for vLLM's torch.compile", local_cache_dir) + self.compiler_manager.initialize_cache(local_cache_dir, disable_cache) + # when dynamo calls the backend, it means the bytecode # transform and analysis are done compilation_counter.num_graphs_seen += 1 @@ -759,7 +590,7 @@ def check_for_ending_compilation(self): if self.is_last_graph and not self.to_be_compiled_sizes: # no specific sizes to compile # save the hash of the inductor graph for the next run - self.vllm_backend.inductor_hash_cache.save_to_file() + self.vllm_backend.compiler_manager.save_to_file() end_monitoring_torch_compile(self.vllm_config) def __call__(self, *args) -> Any: @@ -782,16 +613,14 @@ def __call__(self, *args) -> Any: entry.compiled = True self.to_be_compiled_sizes.remove(runtime_shape) # args are real arguments - entry.runnable = wrap_inductor( + entry.runnable = self.vllm_backend.compiler_manager.compile( self.graph, args, self.compilation_config.inductor_compile_config, self.compilation_config, - self.vllm_backend, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, - runtime_shape=runtime_shape, - use_inductor=self.compilation_config.use_inductor) + runtime_shape=runtime_shape) # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py new file mode 100644 index 0000000000000..ac0544ad64037 --- /dev/null +++ b/vllm/compilation/compiler_interface.py @@ -0,0 +1,340 @@ +# SPDX-License-Identifier: Apache-2.0 +import copy +import hashlib +import os +from contextlib import ExitStack +from typing import Any, Callable, Dict, List, Optional, Tuple +from unittest.mock import patch + +import torch +import torch._inductor.compile_fx +import torch.fx as fx + +from vllm.config import VllmConfig + + +class CompilerInterface: + """ + The interface for a compiler that can be used by vLLM. + """ + # The name of the compiler, e.g. inductor. + # This is a class-level attribute. + name: str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + """ + when the vLLM process uses `cache_dir` as the cache directory, + the compiler should initialize itself with the cache directory, + e.g. by re-directing its own cache directory to a sub-directory. + """ + pass + + def compute_hash(self, vllm_config: VllmConfig) -> str: + """ + Gather all the relevant information from the VLLM config, + to compute a hash so that we can cache the compiled model. + + See :meth:`VllmConfig.compute_hash` to check what information + is already considered by default. This function should only + consider the information that is specific to the compiler. + """ + return "" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: + """ + Compile the graph with the given example inputs and compiler config, + with a runtime shape. If the `runtime_shape` is None, it means + the `example_inputs` have a dynamic shape. Otherwise, the + `runtime_shape` specifies the shape of the inputs. Right now we only + support one variable shape for all inputs, which is the batchsize + (number of tokens) during inference. + + Dynamo will make sure `graph(*example_inputs)` is valid. + + The function should return a compiled callable function, as well as + a handle that can be used to directly load the compiled function. + + The handle should be a plain Python object, preferably a string or a + file path for readability. + + If the compiler doesn't support caching, it should return None for the + handle. If the compiler fails to compile the graph, it should return + None for the compiled function as well. + """ + return None, None + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + """ + Load the compiled function from the handle. + Raises an error if the handle is invalid. + + The handle is the second return value of the `compile` function. + """ + raise NotImplementedError("caching is not supported") + + +class AlwaysHitShapeEnv: + """ + Why do we need this class: + + For normal `torch.compile` usage, every compilation will have + one Dynamo bytecode compilation and one Inductor compilation. + The Inductor compilation happens under the context of the + Dynamo bytecode compilation, and that context is used to + determine the dynamic shape information, etc. + + For our use case, we only run Dynamo bytecode compilation once, + and run Inductor compilation multiple times with different shapes + plus a general shape. The compilation for specific shapes happens + outside of the context of the Dynamo bytecode compilation. At that + time, we don't have shape environment to provide to Inductor, and + it will fail the Inductor code cache lookup. + + By providing a dummy shape environment that always hits, we can + make the Inductor code cache lookup always hit, and we can + compile the graph for different shapes as needed. + + The following dummy methods are obtained by trial-and-error + until it works. + """ + + def __init__(self) -> None: + self.guards: List[Any] = [] + + def evaluate_guards_expression(self, *args, **kwargs): + return True + + def get_pruned_guards(self, *args, **kwargs): + return [] + + def produce_guards_expression(self, *args, **kwargs): + return "" + + +class InductorAdaptor(CompilerInterface): + """ + The adaptor for the Inductor compiler, version 2.5 and 2.6. + """ + name = "inductor" + + def compute_hash(self, vllm_config: VllmConfig) -> str: + factors: List[Any] = [] + # summarize system state + from torch._inductor.codecache import CacheBase + system_factors = CacheBase.get_system() + factors.append(system_factors) + + # summarize pytorch state + from torch._inductor.codecache import torch_key + torch_factors = torch_key() + factors.append(torch_factors) + hash_str = hashlib.md5(str(factors).encode()).hexdigest()[:10] + return hash_str + + def initialize_cache(self, cache_dir: str, disable_cache: bool = False): + if disable_cache: + return + # redirect the cache directory to a sub-directory + # set flags so that Inductor and Triton store their cache + # in the cache_dir, then users only need to copy the cache_dir + # to another machine to reuse the cache. + inductor_cache = os.path.join(cache_dir, "inductor_cache") + os.makedirs(inductor_cache, exist_ok=True) + os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache + triton_cache = os.path.join(cache_dir, "triton_cache") + os.makedirs(triton_cache, exist_ok=True) + os.environ["TRITON_CACHE_DIR"] = triton_cache + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: + from torch._inductor import config + current_config = config.get_config_copy() + from torch._inductor.compile_fx import compile_fx + + # disable remote cache + current_config["fx_graph_cache"] = True + current_config["fx_graph_remote_cache"] = False + + if compiler_config is not None: + current_config.update(compiler_config) + + if isinstance(runtime_shape, int): + # for a specific batchsize, tuning triton kernel parameters + # can be beneficial + current_config["max_autotune"] = True + current_config["coordinate_descent_tuning"] = True + + # inductor can inplace modify the graph, so we need to copy it + # see https://github.com/pytorch/pytorch/issues/138980 + graph = copy.deepcopy(graph) + + # it's the first time we compile this graph + # the assumption is that we don't have nested Inductor compilation. + # compiled_fx_graph_hash will only be called once, and we can hook + # it to get the hash of the compiled graph directly. + + hash_str, file_path = None, None + from torch._inductor.codecache import (FxGraphCache, + compiled_fx_graph_hash) + + if torch.__version__.startswith("2.5"): + original_load = FxGraphCache.load + original_load_name = "torch._inductor.codecache.FxGraphCache.load" + + def hijack_load(*args, **kwargs): + inductor_compiled_graph = original_load(*args, **kwargs) + nonlocal file_path + file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + return inductor_compiled_graph + + hijacked_compile_fx_inner = torch._inductor.compile_fx.compile_fx_inner # noqa + elif torch.__version__ >= "2.6": + # function renamed in 2.6 + original_load_name = None + + def hijacked_compile_fx_inner(*args, **kwargs): + output = torch._inductor.compile_fx.compile_fx_inner( + *args, **kwargs) + nonlocal hash_str + inductor_compiled_graph = output + if inductor_compiled_graph is not None: + nonlocal file_path + file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa + hash_str = inductor_compiled_graph._fx_graph_cache_key + return output + + def hijack_compiled_fx_graph_hash(*args, **kwargs): + out = compiled_fx_graph_hash(*args, **kwargs) + nonlocal hash_str + hash_str = out[0] + return out + + def _check_can_cache(*args, **kwargs): + # no error means it can be cached. + # Inductor refuses to cache the graph outside of Dynamo + # tracing context, and also disables caching for graphs + # with high-order ops. + # For vLLM, in either case, we want to cache the graph. + # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221 # noqa + return + + def _get_shape_env() -> AlwaysHitShapeEnv: + return AlwaysHitShapeEnv() + + with ExitStack() as stack: + # hijack to get the compiled graph itself + if original_load_name is not None: + stack.enter_context(patch(original_load_name, hijack_load)) + + # for hijacking the hash of the compiled graph + stack.enter_context( + patch("torch._inductor.codecache.compiled_fx_graph_hash", + hijack_compiled_fx_graph_hash)) + + # for providing a dummy shape environment + stack.enter_context( + patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + _get_shape_env)) + + # for forcing the graph to be cached + stack.enter_context( + patch( + "torch._inductor.codecache.FxGraphCache._check_can_cache", + _check_can_cache)) + + compiled_graph = compile_fx( + graph, + example_inputs, + inner_compile=hijacked_compile_fx_inner, + config_patches=current_config) + + assert hash_str is not None, ( + "failed to get the hash of the compiled graph") + assert file_path is not None, ( + "failed to get the file path of the compiled graph") + return compiled_graph, (hash_str, file_path) + + def load(self, + handle: Any, + graph: fx.GraphModule, + example_inputs: List[Any], + graph_index: int, + runtime_shape: Optional[int] = None) -> Callable: + assert isinstance(handle, tuple) + assert isinstance(handle[0], str) + assert isinstance(handle[1], str) + hash_str = handle[0] + + from torch._inductor.codecache import FxGraphCache + with patch("torch._inductor.codecache.FxGraphCache._get_shape_env", + lambda *args, **kwargs: AlwaysHitShapeEnv()): + if torch.__version__.startswith("2.5"): + inductor_compiled_graph = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, False) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) + elif torch.__version__ >= "2.6": + from torch._inductor.output_code import ( + CompiledFxGraphConstantsWithGm) + constants = CompiledFxGraphConstantsWithGm(graph) + inductor_compiled_graph, _ = FxGraphCache._lookup_graph( + hash_str, example_inputs, True, None, constants) + assert inductor_compiled_graph is not None, ( + "Inductor cache lookup failed. Please remove" + f"the cache directory and try again." # noqa + ) + + # Inductor calling convention (function signature): + # f(list) -> tuple + # Dynamo calling convention (function signature): + # f(*args) -> Any + + # need to know if the graph returns a tuple + from torch._inductor.compile_fx import graph_returns_tuple + returns_tuple = graph_returns_tuple(graph) + + # this is the callable we return to Dynamo to run + def compiled_graph(*args): + # convert args to list + list_args = list(args) + graph_output = inductor_compiled_graph(list_args) + # unpack the tuple if needed + if returns_tuple: + return graph_output + else: + return graph_output[0] + + return compiled_graph + + +class EagerAdaptor(CompilerInterface): + name = "eager" + + def compile( + self, + graph: fx.GraphModule, + example_inputs: List[Any], + compiler_config: Dict[str, Any], + runtime_shape: Optional[int] = None + ) -> Tuple[Optional[Callable], Optional[Any]]: + # we don't need to compile the graph, just return the graph itself. + # It does not support caching, return None for the handle. + return graph, None diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index a6f11a3af4d4c..5be452593c620 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -13,7 +13,7 @@ class CompilationCounter: num_piecewise_graphs_seen: int = 0 # not including the splitting ops num_piecewise_capturable_graphs_seen: int = 0 - num_inductor_compilations: int = 0 + num_backend_compilations: int = 0 num_cudagraph_caputured: int = 0 def clone(self) -> "CompilationCounter": diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index be663946f4d81..1fea927aac31f 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -13,7 +13,6 @@ class InductorPass(ABC): """ General custom inductor pass interface. - TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass """ @abstractmethod diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index c7387fb7c2db9..52f8c3b1ec15a 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List +import torch from torch import fx as fx from vllm.config import CompilationConfig @@ -15,7 +16,17 @@ logger = init_logger(__name__) -class PostGradPassManager: +class PlaceHolder: + pass + + +if torch.__version__ < "2.6": + Parent = PlaceHolder # type: ignore +else: + Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore + + +class PostGradPassManager(Parent): """ The pass manager for post-grad passes. It handles configuration, adding custom passes, and running passes. @@ -55,6 +66,9 @@ def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) self.passes.append(pass_) + def uuid(self): + return self.__getstate__() + def __getstate__(self) -> Dict[str, List[Any]]: """ Custom pickling for the pass manager, as some passes cannot be pickled. diff --git a/vllm/config.py b/vllm/config.py index 9ba4975761245..5579d6936d105 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3072,15 +3072,6 @@ def compute_hash(self) -> str: the final hidden states. """ factors: List[Any] = [] - # summarize system state - from torch._inductor.codecache import CacheBase - system_factors = CacheBase.get_system() - factors.append(system_factors) - - # summarize pytorch state - from torch._inductor.codecache import torch_key - torch_factors = torch_key() - factors.append(torch_factors) # summarize vllm config vllm_factors: List[Any] = []