Skip to content

Commit

Permalink
Log individual Triton kernel compilation times to dynamo_compile (#14…
Browse files Browse the repository at this point in the history
…7022)

Summary:
Gather the compilation time of individual triton kernels and log them to dynamo_compile:
* Time compilation in `_worker_compile_triton` and pass back to the main process and logged from `get_result()`.
* Added a way to track the "top N" (or N most-expensive compiles) in the metrics_context. I did this because I doubt we really care to capture potentially thousands of kernel compile times. That would be problematic for scuba logging anyway, so let's limit the number we track from the beginning. Arbitrarily chose 25 for now.
* Format the list of compile times as a json string before logging.

X-link: pytorch/pytorch#147022
Approved by: https://github.com/jamesjwu

Reviewed By: wdvr

Differential Revision: D70512505

fbshipit-source-id: b0b26cea64a4d3f34e3386bf42ea203de46a6e3b
  • Loading branch information
masnesral authored and facebook-github-bot committed Mar 4, 2025
1 parent 182024e commit 0dba3d9
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,7 @@ class CompilationMetrics:
guard_latency_us: Optional[float] = None
recompile_reason: Optional[str] = None
num_graph_breaks: Optional[int] = None
triton_kernel_compile_times_us: Optional[str] = None

@classmethod
def create(cls, metrics: dict[str, Any]):
Expand Down Expand Up @@ -1249,6 +1250,14 @@ def safe_str(item: Any) -> str:

return ",".join(safe_str(item) for item in sorted(metric))

def collection_to_json_str(metric: Optional[Any]) -> Optional[str]:
if metric is None:
return None
try:
return json.dumps(list(metric))
except Exception:
return "<unknown>"

# TODO: The following are legacy fields, populated from the fields that replace
# them. Remove these when we decide we can really deprecate them.
legacy_metrics = {
Expand Down Expand Up @@ -1288,6 +1297,9 @@ def safe_str(item: Any) -> str:
all_metrics["inductor_fx_remote_cache_miss_keys"] = collection_to_str(
all_metrics.get("inductor_fx_remote_cache_miss_keys")
)
all_metrics["triton_kernel_compile_times_us"] = collection_to_json_str(
all_metrics.get("triton_kernel_compile_times_us")
)
compile_id = all_metrics.get("compile_id")
all_metrics["compile_id"] = str(compile_id) if compile_id else None

Expand Down

0 comments on commit 0dba3d9

Please sign in to comment.