-
Notifications
You must be signed in to change notification settings - Fork 360
/
Copy pathengine_caching_example.py
288 lines (240 loc) · 10.7 KB
/
engine_caching_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
"""
.. _engine_caching_example:

Engine Caching
=======================

As model sizes increase, the cost of compilation increases as well. With AOT methods
like ``torch_tensorrt.dynamo.compile``, this cost is paid upfront. However, if the weights
change, the session ends, or you are using JIT methods like ``torch.compile``, graphs are
re-compiled whenever they get invalidated, so this cost is paid repeatedly.
Engine caching is a way to mitigate this cost by saving constructed engines to disk
and re-using them when possible. This tutorial demonstrates how to use engine caching
with TensorRT in PyTorch. Engine caching can significantly speed up subsequent model
compilations by reusing previously built TensorRT engines.

We'll explore two approaches:

1. Using ``torch_tensorrt.dynamo.compile``
2. Using ``torch.compile`` with the TensorRT backend

The example uses a pre-trained ResNet18 model and shows the
differences between compilation without caching, with caching enabled,
and when reusing cached engines.
"""
import os
from typing import Dict, Optional
import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
# Seed NumPy and PyTorch so the randomly generated inputs are reproducible across runs.
np.random.seed(0)
torch.manual_seed(0)

# ``weights=`` replaces torchvision's deprecated ``pretrained=True`` flag;
# ``IMAGENET1K_V1`` is the checkpoint that ``pretrained=True`` used to load.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval().to("cuda")

# Compilation settings passed to the compile calls below.
enabled_precisions = {torch.float}
debug = False
min_block_size = 1
use_python_runtime = False
def remove_timing_cache(path=TIMING_CACHE_PATH):
    """
    Delete the timing cache file at ``path`` when one is present.

    Args:
        path: Location of the timing cache file. Defaults to ``TIMING_CACHE_PATH``.

    Returns:
        None
    """
    # Nothing to clean up when the file was never written.
    if not os.path.exists(path):
        return
    os.remove(path)
# %%
# Engine Caching for JIT Compilation
# ----------------------------------
#
# The primary goal of engine caching is to help speed up JIT workflows. ``torch.compile``
# provides a great deal of flexibility in model construction which makes it a good
# first tool to try when looking to speed up your workflow. However, historically
# the cost of compilation and in particular recompilation has been a barrier to entry
# for many users. Prior to the addition of engine caching, if for some reason a subgraph got invalidated,
# that graph was reconstructed from scratch. Now, as engines are constructed with ``cache_built_engines=True``,
# engines are saved to disk tied to a hash of their corresponding PyTorch subgraph. If a subgraph with a
# matching hash is compiled again, either as part of this session or a new session, the cache will
# pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude.
# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
# the engine must be refittable (``immutable_weights=False``). See :ref:`refit_engine_example` for more details.
def torch_compile(iterations=3):
    """
    Time ``torch.compile`` with the TensorRT backend, with and without engine caching.

    The first iteration compiles with engine caching disabled; every later iteration
    enables both ``cache_built_engines`` and ``reuse_cached_engines``. Each iteration
    is timed with CUDA events from just before ``torch.compile`` until the first
    forward pass (which triggers the actual compilation) has finished, and the
    elapsed times are printed in milliseconds.

    Args:
        iterations (int): Number of compile runs to time. Timings are reported for at
            most the first three runs; with fewer runs, fewer lines are printed.

    Returns:
        None
    """
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # The 1st iteration is to measure the compilation time without engine caching
    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
    for i in range(iterations):
        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
        # remove timing cache and reset dynamo just for engine caching measurement
        remove_timing_cache()
        torch._dynamo.reset()

        # Caching is off for the baseline run and on for every run after it.
        use_engine_cache = i > 0

        start.record()
        compiled_model = torch.compile(
            model,
            backend="tensorrt",
            options={
                # NOTE(review): hard-coded to True here, while the module-level
                # ``use_python_runtime`` flag (False) is what dynamo_compile uses —
                # confirm this difference is intentional.
                "use_python_runtime": True,
                "enabled_precisions": enabled_precisions,
                "debug": debug,
                "min_block_size": min_block_size,
                "immutable_weights": False,
                "cache_built_engines": use_engine_cache,
                "reuse_cached_engines": use_engine_cache,
            },
        )
        compiled_model(*inputs)  # trigger the compilation
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("----------------torch_compile----------------")
    labels = (
        "disable engine caching, used:",
        "enable engine caching to cache engines, used:",
        "enable engine caching to reuse engines, used:",
    )
    # zip stops at the shorter sequence, so calling with fewer than three
    # iterations reports what was measured instead of raising IndexError.
    for label, elapsed in zip(labels, times):
        print(label, elapsed, "ms")
# Run the torch.compile (JIT) benchmark with its default of three iterations.
torch_compile()
# %%
# Engine Caching for AOT Compilation
# ----------------------------------
# Similarly to the JIT workflow, AOT workflows can benefit from engine caching.
# As the same architecture or common subgraphs get recompiled, the cache will pull
# previously built engines and refit the weights.
def dynamo_compile(iterations=3):
    """
    Time ``torch_trt.dynamo.compile`` (AOT) with and without engine caching.

    The model is exported once with a dynamic batch dimension, then compiled once
    per iteration. The first iteration compiles with engine caching disabled; every
    later iteration enables both ``cache_built_engines`` and ``reuse_cached_engines``.
    Only the compile call is timed (the compiled module is not executed), and the
    elapsed times are printed in milliseconds.

    Args:
        iterations (int): Number of compile runs to time. Timings are reported for at
            most the first three runs; with fewer runs, fewer lines are printed.

    Returns:
        None
    """
    times = []
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
    # Mark the dim0 of inputs as dynamic
    batch = torch.export.Dim("batch", min=1, max=200)
    exp_program = torch.export.export(
        model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
    )

    # The 1st iteration is to measure the compilation time without engine caching
    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
    for i in range(iterations):
        # Each iteration uses a different batch size (100, 101, 102, ...).
        inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
        remove_timing_cache()  # remove timing cache just for engine caching measurement

        # Caching is off for the baseline run and on for every run after it.
        use_engine_cache = i > 0

        start.record()
        trt_gm = torch_trt.dynamo.compile(
            exp_program,
            tuple(inputs),
            use_python_runtime=use_python_runtime,
            enabled_precisions=enabled_precisions,
            debug=debug,
            min_block_size=min_block_size,
            immutable_weights=False,
            cache_built_engines=use_engine_cache,
            reuse_cached_engines=use_engine_cache,
            engine_cache_size=1 << 30,  # 1GB
        )
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    print("----------------dynamo_compile----------------")
    labels = (
        "disable engine caching, used:",
        "enable engine caching to cache engines, used:",
        "enable engine caching to reuse engines, used:",
    )
    # zip stops at the shorter sequence, so calling with fewer than three
    # iterations reports what was measured instead of raising IndexError.
    for label, elapsed in zip(labels, times):
        print(label, elapsed, "ms")
# Run the AOT (torch_trt.dynamo.compile) benchmark with its default of three iterations.
dynamo_compile()
# %%
# Custom Engine Cache
# ----------------------
#
# By default, the engine cache is stored in the system's temporary directory. Both the cache directory and
# size limit can be customized by passing ``engine_cache_dir`` and ``engine_cache_size``.
# Users can also define their own engine cache implementation by extending the ``BaseEngineCache`` class.
# This allows for remote or shared caching if so desired.
#
# The custom engine cache should implement the following methods:
# - ``save``: Save the engine blob to the cache.
# - ``load``: Load the engine blob from the cache.
#
# The hash provided by the cache system is a weight-agnostic hash of the originating PyTorch subgraph (post lowering).
# The blob contains a serialized engine, calling spec data, and weight map information in the pickle format.
#
# Below is an example of a custom engine cache implementation that implements a ``RAMEngineCache``.
class RAMEngineCache(BaseEngineCache):
    """Engine cache that keeps engine blobs in an in-memory dictionary."""

    def __init__(self) -> None:
        """
        Create an empty, user-held engine cache that lives in memory.
        """
        # Maps the hash key to the stored engine blob.
        self.engine_cache: Dict[str, bytes] = {}

    def save(self, hash: str, blob: bytes) -> None:
        """
        Store an engine blob in the cache under the given hash.

        Args:
            hash (str): Key to file the engine blob under.
            blob (bytes): Engine blob to keep.

        Returns:
            None
        """
        self.engine_cache[hash] = blob

    def load(self, hash: str) -> Optional[bytes]:
        """
        Look up an engine blob in the cache.

        Args:
            hash (str): Key of the engine blob to fetch.

        Returns:
            Optional[bytes]: The stored blob, or None when the key is not cached.
        """
        # dict.get already yields None for a missing key.
        return self.engine_cache.get(hash)
def torch_compile_my_cache(iterations=3):
    """
    Time ``torch.compile`` with the TensorRT backend using a ``RAMEngineCache``.

    A single ``RAMEngineCache`` instance is created up front and passed to every
    compile call as ``custom_engine_cache``. The first iteration compiles with
    engine caching disabled; every later iteration enables both
    ``cache_built_engines`` and ``reuse_cached_engines``. Each iteration is timed
    with CUDA events from just before ``torch.compile`` until the first forward
    pass (which triggers the actual compilation) has finished, and the elapsed
    times are printed in milliseconds.

    Args:
        iterations (int): Number of compile runs to time. Timings are reported for at
            most the first three runs; with fewer runs, fewer lines are printed.

    Returns:
        None
    """
    times = []
    # Shared across iterations so engines saved by one run can be loaded by the next.
    engine_cache = RAMEngineCache()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # The 1st iteration is to measure the compilation time without engine caching
    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
    for i in range(iterations):
        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
        # remove timing cache and reset dynamo just for engine caching measurement
        remove_timing_cache()
        torch._dynamo.reset()

        # Caching is off for the baseline run and on for every run after it.
        use_engine_cache = i > 0

        start.record()
        compiled_model = torch.compile(
            model,
            backend="tensorrt",
            options={
                "use_python_runtime": True,
                "enabled_precisions": enabled_precisions,
                "debug": debug,
                "min_block_size": min_block_size,
                "immutable_weights": False,
                "cache_built_engines": use_engine_cache,
                "reuse_cached_engines": use_engine_cache,
                "custom_engine_cache": engine_cache,
            },
        )
        compiled_model(*inputs)  # trigger the compilation
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    # Banner names this function so its output is distinguishable from torch_compile's.
    print("----------------torch_compile_my_cache----------------")
    labels = (
        "disable engine caching, used:",
        "enable engine caching to cache engines, used:",
        "enable engine caching to reuse engines, used:",
    )
    # zip stops at the shorter sequence, so calling with fewer than three
    # iterations reports what was measured instead of raising IndexError.
    for label, elapsed in zip(labels, times):
        print(label, elapsed, "ms")
# Run the custom-cache (RAMEngineCache) benchmark with its default of three iterations.
torch_compile_my_cache()