[torch.compile] store inductor compiled Python file (#12182)
Signed-off-by: youkaichao <youkaichao@gmail.com>
youkaichao authored Jan 19, 2025
1 parent 630eb5b commit e66faf4
Showing 2 changed files with 60 additions and 33 deletions.
80 changes: 58 additions & 22 deletions vllm/compilation/backends.py
@@ -25,23 +25,30 @@
logger = init_logger(__name__)


@dataclasses.dataclass
class InductorArtifact:
hash_str: str = ""
file_path: str = ""


class InductorHashCache:
"""
Disk format: a Python list of tuples, each tuple is
(runtime_shape, graph_index, hash_str)
(runtime_shape, graph_index, hash_str, file_path)
We use list of tuple for readability.
In-memory format: a defaultdict of dict, where the key is
runtime_shape, and the value is a dict of graph_index to InductorArtifact.
The data is essentially `Dict[Optional[int], Dict[int, str]]`,
The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
we don't use json here because json doesn't support int as key.
TODO: better off-the-shelf solution to serialize the data?
"""

def __init__(self, cache_dir: str, disabled: bool = False):
self.cache: defaultdict = defaultdict(dict)
self.cache: Dict[Optional[int],
Dict[int, InductorArtifact]] = defaultdict(dict)
self.disabled = disabled
self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir,
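
The in-memory layout introduced here is a nested mapping from runtime shape to graph index to artifact. A minimal standalone sketch of that data model (illustrative names and values, not taken from the repository):

    import dataclasses
    from collections import defaultdict
    from typing import Dict, Optional

    @dataclasses.dataclass
    class InductorArtifact:
        hash_str: str = ""
        file_path: str = ""

    # outer key: runtime shape (None is used here for the shape-agnostic graph),
    # inner key: graph index within the split model
    cache: Dict[Optional[int], Dict[int, InductorArtifact]] = defaultdict(dict)
    cache[None][0] = InductorArtifact(hash_str="fabc123",
                                      file_path="/tmp/torchinductor_user/xy/output_code.py")
    cache[16][1] = InductorArtifact(hash_str="f456def")
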
@@ -66,14 +73,25 @@ def deserialize(self, data: str):
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
list_data = ast.literal_eval(data)
for runtime_shape, graph_index, hash_str in list_data:
self.cache[runtime_shape][graph_index] = hash_str
for item in list_data:
runtime_shape = item[0]
graph_index = item[1]
hash_str = item[2]
# for compatibility with the old format,
# which does not have file_path.
# NOTE: after running the new code, the file_path
# will be updated.
file_path = "" if len(item) == 3 else item[3]
self.cache[runtime_shape][graph_index] = InductorArtifact(
hash_str=hash_str, file_path=file_path)

def serialize(self) -> str:
data = []
for runtime_shape, graph_index_to_hash_str in self.cache.items():
for graph_index, hash_str in graph_index_to_hash_str.items():
data.append((runtime_shape, graph_index, hash_str))
for runtime_shape, value in self.cache.items():
for graph_index, inductor_artifact in value.items():
data.append(
(runtime_shape, graph_index, inductor_artifact.hash_str,
inductor_artifact.file_path))
printer = pprint.PrettyPrinter(indent=4)
return printer.pformat(data)
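
The disk round trip relies on pprint.pformat to write a plain Python literal and ast.literal_eval to read it back, which keeps int and None keys intact (JSON would not) and never executes code. A hedged sketch of the same round trip, including the fallback for old three-element entries, reusing the InductorArtifact dataclass from the sketch above:

    import ast
    import pprint
    from collections import defaultdict

    def serialize(cache):
        # flatten the nested dict into a list of 4-tuples
        data = [(shape, idx, art.hash_str, art.file_path)
                for shape, graphs in cache.items()
                for idx, art in graphs.items()]
        return pprint.PrettyPrinter(indent=4).pformat(data)

    def deserialize(text):
        cache = defaultdict(dict)
        for item in ast.literal_eval(text):        # parses literals only, never eval()
            shape, idx, hash_str = item[0], item[1], item[2]
            file_path = item[3] if len(item) > 3 else ""   # old entries have no file_path
            cache[shape][idx] = InductorArtifact(hash_str=hash_str,
                                                 file_path=file_path)
        return cache
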

@@ -90,13 +108,14 @@ def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
return runtime_shape in self.cache and graph_index in self.cache[
runtime_shape]

def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
if self.disabled:
raise KeyError("cannot read from disabled cache")
runtime_shape, graph_index = key
return self.cache[runtime_shape][graph_index]

def __setitem__(self, key: Tuple[Optional[int], int], value: str):
def __setitem__(self, key: Tuple[Optional[int], int],
value: InductorArtifact):
# setitem for disabled cache is fine, because we
# don't actually write to the disk
runtime_shape, graph_index = key
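
These dunder methods let callers index the cache with a (runtime_shape, graph_index) tuple instead of reaching into the nested dicts. Illustrative usage, assuming hash_cache is a hypothetical, already populated InductorHashCache-like object:

    key = (None, 0)                    # (runtime_shape, graph_index)
    if key in hash_cache:              # __contains__ unpacks the tuple
        artifact = hash_cache[key]     # __getitem__ returns an InductorArtifact
        print(artifact.hash_str, artifact.file_path)
    hash_cache[(16, 1)] = InductorArtifact(hash_str="fdeadbeef")   # __setitem__
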
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
if (runtime_shape, graph_index) in cache_data:
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
hash_str = cache_data[(runtime_shape, graph_index)]
inductor_artifact = cache_data[(runtime_shape, graph_index)]
hash_str = inductor_artifact.hash_str
if graph_index == 0:
# adds some info logging for the first graph
logger.info(
@@ -199,6 +219,7 @@
"Inductor cache lookup failed. Please remove"
f"the cache file {cache_data.cache_file_path} and try again." # noqa
)
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa

# Inductor calling convention (function signature):
# f(list) -> tuple
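
In the cache-hit path above, file_path is refreshed from the compiled callable itself: every Python function object records the file it was defined in, so the module Inductor generated can be located through __code__.co_filename. A generic illustration of the attribute (plain Python, independent of Inductor):

    def example():
        pass

    print(example.__code__.co_filename)   # path of the source file defining `example`

    # On an Inductor-compiled graph, the same attribute of
    # inductor_compiled_graph.current_callable points at the generated
    # output_code.py-style file; the exact path layout depends on the torch version.
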
@@ -224,19 +245,20 @@ def compiled_graph(*args):
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
from torch._inductor.codecache import compiled_fx_graph_hash

inductor_artifact = InductorArtifact()
from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash)
original_load = FxGraphCache.load

def hijack_load(*args, **kwargs):
inductor_compiled_graph = original_load(*args, **kwargs)
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
return inductor_compiled_graph

def hijack_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
# store the hash in the cache
nonlocal cache_data
cache_data[(runtime_shape, graph_index)] = out[0]
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
logger.debug("store the %s-th graph for shape %s via hash %s",
graph_index, str(runtime_shape), out[0])
inductor_artifact.hash_str = out[0]
return out

def _check_can_cache(*args, **kwargs):
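
Both hijack_load and hijack_compiled_fx_graph_hash above follow the same closure-capture pattern: call the original function, stash one detail of its result in an enclosing variable, and return the result unchanged. A minimal sketch of that pattern with made-up names:

    def make_recorder(original_fn, record):
        """Wrap original_fn and stash its result in `record` as a side channel."""
        def wrapper(*args, **kwargs):
            result = original_fn(*args, **kwargs)
            record["last_result"] = result     # capture without changing behaviour
            return result
        return wrapper

    record = {}
    recorded_len = make_recorder(len, record)
    recorded_len("hello")
    print(record["last_result"])   # 5
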
@@ -255,6 +277,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
if not cache_data.disabled:
# compilation cache is enabled, patch several functions

# hijack to get the compiled graph itself
stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache.load",
hijack_load))

# for hijacking the hash of the compiled graph
stack.enter_context(
patch("torch._inductor.codecache.compiled_fx_graph_hash",
@@ -275,7 +302,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)

# store the inductor_artifact in the cache
cache_data[(runtime_shape, graph_index)] = inductor_artifact
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
logger.debug(
"store the %s-th graph for shape %s via hash %s from file %s",
graph_index, str(runtime_shape), inductor_artifact.hash_str,
inductor_artifact.file_path)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
13 changes: 2 additions & 11 deletions vllm/config.py
@@ -2862,17 +2862,8 @@ def model_post_init(self, __context: Any) -> None:
"vllm.unified_attention_with_output",
]
else:
# v0 can use full graph compilation without splitting,
# splitting is optional.
# right now we still need it. kv cache shape
# will be included in the graph if we don't split
# the graph.
# TODO: hide kv cache in static forward context
# so that inductor does not see it.
self.splitting_ops = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
# v0 uses full graph compilation
self.splitting_ops = []

for k, v in self.inductor_passes.items():
if not isinstance(v, str):
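
With this commit, v0 hands the whole graph to Inductor (splitting_ops is empty), while v1 still splits at the two unified-attention ops. As a generic illustration of what a list of splitting ops implies, here is a sketch using torch.fx's stock split_module pass on a toy model, with torch.relu standing in for an attention op; this is not vLLM's own splitting code:

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.split_module import split_module

    class Toy(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            x = torch.relu(x)      # stand-in for a splitting op such as unified_attention
            return x * 2

    gm = symbolic_trace(Toy())

    partition = [0]
    def split_callback(node):
        # open a new partition at every splitting op, so those ops become
        # boundaries between separately compiled submodules
        if node.op == "call_function" and "relu" in str(node.target):
            partition[0] += 1
        return partition[0]

    split_gm = split_module(gm, Toy(), split_callback)
    print([name for name, _ in split_gm.named_children()])   # e.g. ['submod_0', 'submod_1']
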
