Skip to content

Commit

Permalink
Add per node profiler (#54)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: https://github.com/pytorch/fx2trt/pull/54

Given a graph module, it does node by node profiling.

Reviewed By: frank-wei

Differential Revision: D35636500

fbshipit-source-id: 71cabb3239892d36884d8d96ed3c3bc94ed1f565
  • Loading branch information
Yinghai Lu authored and Wei Wei committed Jun 4, 2022
1 parent fea0469 commit 2e3b265
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions fx/tools/node_profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
from torch import fx
from typing import Any


class NodeProfiler(fx.Interpreter):
"""
This is basically a variant of shape prop in
https://github.com/pytorch/pytorch/blob/74849d9188de30d93f7c523d4eeceeef044147a9/torch/fx/passes/shape_prop.py#L65.
Instead of propagating just the shape, we record all the intermediate node Tensor values.
This is useful to debug some of lowering pass issue where we want to check a specific
tensor value. Note that output value can be tuple(Tensor) as well as Tensor.
"""

def __init__(self, module: fx.GraphModule):
super().__init__(module)
self.execution_time = {}
self.node_map = {}
self.iter = 100

def run_node(self, n: fx.Node) -> Any:
result = super().run_node(n)
if n.op not in {"call_function", "call_method", "call_module"}:
return result

torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

for _ in range(self.iter):
result = super().run_node(n)

end_event.record()
torch.cuda.synchronize()

self.execution_time[f"{n.name}"] = start_event.elapsed_time(end_event) / self.iter
self.node_map[n.name] = n
return result

def propagate(self, *args):
"""
Run `module` via interpretation and return the result and
record the shape and type of each node.
Args:
*args (Tensor): the sample input.
Returns:
Any: The value returned from executing the Module
"""
return super().run(*args)

0 comments on commit 2e3b265

Please sign in to comment.