Skip to content

Commit

Permalink
add an example of aten2trt, fix batch norm pass (#1685)
Browse files Browse the repository at this point in the history
Co-authored-by: Wei Wei <wwei6@fb.com>
  • Loading branch information
frank-wei and Wei Wei authored Feb 22, 2023
1 parent deda87b commit cefb2f2
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 2 deletions.
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ commands:
parameters:
torch-build:
type: string
default: "2.0.0.dev20230129+cu117"
default: "2.0.0.dev20230219+cu117"
torch-build-index:
type: string
default: "https://download.pytorch.org/whl/nightly/cu117"
Expand Down Expand Up @@ -1026,7 +1026,7 @@ parameters:
# Nightly platform config
torch-build:
type: string
default: "2.0.0.dev20230129+cu117"
default: "2.0.0.dev20230219+cu117"
torch-build-index:
type: string
default: "https://download.pytorch.org/whl/nightly/cu117"
Expand Down
1 change: 1 addition & 0 deletions examples/fx/lower_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def run_configuration_benchmark(
input,
max_batch_size=conf.batch_size,
lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
explicit_batch_dimension=True,
)
time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
else:
Expand Down
196 changes: 196 additions & 0 deletions examples/fx/lower_example_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import typing as t
from copy import deepcopy
from dataclasses import dataclass, field, replace

import torch
import torchvision
from torch_tensorrt.fx import compile
from torch_tensorrt.fx.utils import LowerPrecision


"""
The purpose of this example is to demonstrate the overall flow of lowering a PyTorch model
to TensorRT conveniently with lower.py.
"""


@dataclass
class Configuration:
    """
    Specify the configuration used for fx2trt lowering and benchmark.

    To extend, add a new configuration field to this class, and modify the
    lowering or benchmark behavior in `run_configuration_benchmark()`
    correspondingly.

    It automatically prints all its values thanks to being a dataclass.
    """

    # Number of inference iterations to run when benchmarking
    batch_iter: int

    # Input batch size (first dimension of the input tensor)
    batch_size: int

    # Friendly name of the configuration, used in printed reports
    name: str = ""

    # Whether to apply TRT lowering to the model before benchmarking
    trt: bool = False

    # Whether to apply engine holder (TorchScript wrapping) to the lowered model
    jit: bool = False

    # Whether to enable FP16 mode for TRT lowering
    fp16: bool = False

    # Relative tolerance for accuracy check after lowering. -1 means do not
    # check accuracy.
    accuracy_rtol: float = -1  # disable


@dataclass
class Result:
    """Holds and computes the benchmark results.

    Stores the raw essential benchmark values (module, inputs, configuration,
    wall-clock duration) and derives secondary metrics such as latency per
    iteration and QPS as read-only properties.
    """

    module: torch.nn.Module = field(repr=False)
    input: t.Any = field(repr=False)
    conf: Configuration
    time_sec: float
    accuracy_res: t.Optional[bool] = None

    @property
    def time_per_iter_ms(self) -> float:
        # Convert seconds to milliseconds.
        milliseconds_per_second = 1.0e3
        return self.time_sec * milliseconds_per_second

    @property
    def qps(self) -> float:
        # Queries per second: one iteration processes a full batch.
        return self.conf.batch_size / self.time_sec

    def format(self) -> str:
        header = f"== Benchmark Result for: {self.conf}\n"
        body = (
            f"BS: {self.conf.batch_size}, "
            f"Time per iter: {self.time_per_iter_ms:.2f}ms, "
            f"QPS: {self.qps:.2f}, "
            f"Accuracy: {self.accuracy_res} (rtol={self.conf.accuracy_rtol})"
        )
        return header + body


@torch.inference_mode()
def benchmark(
    model,
    inputs,
    batch_iter: int,
    batch_size: int,
) -> None:
    """
    Run fx2trt lowering and benchmark the given model according to the
    specified benchmark configuration. Prints the benchmark result for each
    configuration at the end of the run.

    Args:
        model: the model to benchmark; moved to CUDA and set to eval mode here
        inputs: list of input tensors for one forward call; moved to CUDA here
        batch_iter: number of inference iterations per configuration
        batch_size: input batch size, recorded in each configuration
    """

    model = model.cuda().eval()
    inputs = [tensor.cuda() for tensor in inputs]

    # Base configuration shared by every benchmarked variant.
    base_conf = Configuration(batch_iter=batch_iter, batch_size=batch_size)

    configurations = [
        # Baseline: plain CUDA eager execution
        replace(base_conf, name="CUDA Eager", trt=False),
        # TRT-lowered FP16 variant with a loose accuracy check
        replace(
            base_conf,
            name="TRT FP16 Eager",
            trt=True,
            jit=False,
            fp16=True,
            accuracy_rtol=1e-2,
        ),
    ]

    results = []
    for cfg in configurations:
        # deepcopy so each configuration starts from an unmodified model
        results.append(run_configuration_benchmark(deepcopy(model), inputs, cfg))

    for outcome in results:
        print(outcome.format())


def benchmark_torch_function(iters: int, f, *args) -> float:
    """Estimate the average duration in seconds of a single inference call.

    If the input is batched, the estimate covers one full-batch call.

    Args:
        iters: number of inference iterations to run
        f: a function performing a single inference call
        *args: positional arguments forwarded to `f`

    Returns:
        Estimated average duration in seconds of a single inference call.
    """
    # Warm-up call outside the timed region.
    with torch.inference_mode():
        f(*args)
    torch.cuda.synchronize()

    # CUDA events give device-side timings in milliseconds.
    begin = torch.cuda.Event(enable_timing=True)
    finish = torch.cuda.Event(enable_timing=True)
    print("== Start benchmark iterations")
    with torch.inference_mode():
        begin.record()
        for _ in range(iters):
            f(*args)
        finish.record()
    torch.cuda.synchronize()
    print("== End benchmark iterations")

    # elapsed_time() is in ms; convert to seconds, then average per iteration.
    return (begin.elapsed_time(finish) * 1.0e-3) / iters


def run_configuration_benchmark(
    module,
    input,
    conf: Configuration,
) -> Result:
    """
    Runs `module` through the aten-based fx2trt lowering logic (when enabled
    by `conf`) and benchmarks the module.

    Args:
        module: the model to benchmark (expected on CUDA, in eval mode)
        input: list of input tensors for a single forward call
        conf: benchmark configuration controlling lowering and precision

    Returns:
        A Result holding the module, its inputs, the configuration and the
        measured per-iteration time in seconds (-1.0 if the configuration
        was skipped).
    """
    # Fix: the original passed a literal "green"/"red" string as a second
    # positional argument to builtin print() — leftover from a colored-print
    # helper — which printed the color name verbatim; the stray args are removed.
    print(f"=== Running benchmark for: {conf}")
    time = -1.0

    if conf.fp16:
        # Cast model weights and inputs to half precision for the FP16 run.
        module = module.half()
        input = [i.half() for i in input]

    if not conf.trt:
        # Run eager mode benchmark
        time = benchmark_torch_function(conf.batch_iter, lambda: module(*input))
    elif not conf.jit:
        # Lower via the aten tracer path, then benchmark the TRT module.
        lowered_module = compile(
            module,
            input,
            max_batch_size=conf.batch_size,
            lower_precision=LowerPrecision.FP16 if conf.fp16 else LowerPrecision.FP32,
            explicit_batch_dimension=True,
            is_aten=True,
        )
        time = benchmark_torch_function(conf.batch_iter, lambda: lowered_module(*input))
    else:
        print("Lowering with JIT is not available!")

    result = Result(module=module, input=input, conf=conf, time_sec=time)
    return result


if __name__ == "__main__":
    # Benchmark ResNet-18 on a batch of 128 random 224x224 RGB images,
    # running 50 timed iterations per configuration.
    resnet = torchvision.models.resnet18(pretrained=True)
    sample_inputs = [torch.rand(128, 3, 224, 224)]  # type: ignore[attr-defined]
    benchmark(resnet, sample_inputs, 50, 128)
2 changes: 2 additions & 0 deletions py/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,13 +353,15 @@ def run(self):
"torch_tensorrt.fx.passes",
"torch_tensorrt.fx.tools",
"torch_tensorrt.fx.tracer.acc_tracer",
"torch_tensorrt.fx.tracer.dispatch_tracer",
]
package_dir = {
"torch_tensorrt.fx": "torch_tensorrt/fx",
"torch_tensorrt.fx.converters": "torch_tensorrt/fx/converters",
"torch_tensorrt.fx.passes": "torch_tensorrt/fx/passes",
"torch_tensorrt.fx.tools": "torch_tensorrt/fx/tools",
"torch_tensorrt.fx.tracer.acc_tracer": "torch_tensorrt/fx/tracer/acc_tracer",
"torch_tensorrt.fx.tracer.dispatch_tracer": "torch_tensorrt/fx/tracer/dispatch_tracer",
}

with open("README.md", "r", encoding="utf-8") as fh:
Expand Down
11 changes: 11 additions & 0 deletions py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.Graph
torch.ops.aten.max_pool3d_with_indices.default,
torch.ops.aten.native_batch_norm.default,
torch.ops.aten._native_batch_norm_legit.default,
torch.ops.aten._native_batch_norm_legit_no_training.default,
):
modified = True
if len(n.users) != 1:
Expand All @@ -185,6 +186,16 @@ def replace_aten_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.Graph
new_args = list(n.args)
new_args.append(False)
new_args = tuple(new_args)
elif (
n.target == torch.ops.aten._native_batch_norm_legit_no_training.default
):
new_op = torch.ops.aten.batch_norm
new_args = list(n.args)
new_args.append(False)
# _native_batch_norm_legit_no_training doesn't take in a training arg (assumed to be false)
# but batchnorm takes in a training arg at position 5.
new_args.insert(5, False)
new_args = tuple(new_args)

getitem_node = next(iter(n.users))
with module.graph.inserting_after(getitem_node):
Expand Down

0 comments on commit cefb2f2

Please sign in to comment.