-
Notifications
You must be signed in to change notification settings - Fork 358
/
Copy pathconverter_registry.py
327 lines (263 loc) · 12.6 KB
/
converter_registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional, Sequence, Union
from enum import Enum, auto
from torch.fx.node import Target, Node, _get_qualified_name
from torch_tensorrt.fx.converter_registry import CONVERTERS
class ConverterPriority(Enum):
    """Relative priority of a converter among converters sharing one target

    Used at registration time to decide whether a converter is placed ahead
    of, or behind, the converters already registered for the same target.
    """

    # Default level
    STANDARD = auto()
    # Elevated level
    HIGH = auto()
@dataclass(frozen=True)
class ConverterSupport:
    """Immutable pairing of a converter with the check that gates its use

    Args:
        converter_implementation: Function which converts a node to a TRT equivalent
        capability_validator: Function which takes in a Node and returns a bool
            indicating whether that node can be supported by the companion
            converter. It must not modify the node or its graph. When omitted,
            every node is reported as supported
    """

    converter_implementation: Callable
    # Permissive default: without an explicit validator, any node is accepted
    capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
# Dictionary representing Dynamo aten-only converters
# Each target maps to a sequence of at least one ConverterSupport object(s)
# The dictionary starts out empty and is populated by the
# dynamo_tensorrt_converter decorator, which stores a list per target and
# inserts high-priority converters at the front of that list
DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {}
def dynamo_tensorrt_converter(
    key: Target,
    enabled: bool = True,
    capability_validator: Optional[Callable[[Node], bool]] = None,
    priority: ConverterPriority = ConverterPriority.STANDARD,
) -> Callable[[Any], Any]:
    """Decorator for Dynamo TensorRT Converter

    Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry

    Args:
        key: Node target which the converter implements
            (for example, torch.ops.aten.add.Tensor)
        enabled: Whether the converter should be registered. When False, the
            decorated function is returned untouched and nothing is registered
        capability_validator: Function which evaluates whether a node is valid for
            conversion by the decorated converter. See ConverterSupport for more
            details. Defaults to None, implying the capability_validator function
            is always true - this means all nodes of "key" kind can be supported
            by this converter
        priority: Converter's level of priority relative to other converters with
            the same target
    Returns:
        A decorator which returns the converter being decorated
    Raises:
        AssertionError: When applying the decorator with enabled=True, if the
            decorated object or the capability_validator is not callable
    """

    def register_converter(converter):
        """Helper function to register the converter, then return it"""
        # Raised explicitly instead of via `assert`, so the checks are not
        # stripped when Python runs with -O
        if not callable(converter):
            raise AssertionError("Converter function must be callable")

        # If no capability_validator function is specified, rely on the
        # ConverterSupport default, which always returns True
        if capability_validator is None:
            converter_support = ConverterSupport(converter_implementation=converter)
        else:
            if not callable(capability_validator):
                raise AssertionError("Argument checking function must be callable")
            converter_support = ConverterSupport(
                converter_implementation=converter,
                capability_validator=capability_validator,
            )

        # Fetch the list of converters for this target, starting a new one if
        # this is the first registration
        registered = DYNAMO_ATEN_CONVERTERS.setdefault(key, [])

        # High priority converters are inserted at the front of the list,
        # so they can be checked first by the registry
        if priority is ConverterPriority.HIGH:
            registered.insert(0, converter_support)
        else:
            registered.append(converter_support)

        return converter

    def disable_converter(converter):
        """Return the converter unchanged, without registering it"""
        return converter

    # Select whether to register the converter or pass it through untouched
    if enabled:
        return register_converter
    else:
        return disable_converter
class ConverterRegistry:
    """Registry for storing multiple converter dictionaries

    Capable of storing dictionaries with the following signature:
    Dict[Target, Union[Callable, Sequence[ConverterSupport]]]

    Also able to validate converter implementations against user-provided
    argument-checking functions

    The registry does not support item assignment or deletion; the wrapped
    dictionaries must be modified directly instead.

    Args:
        registries: List of dictionaries representing converter registries.
            The order of the provided dictionaries is the order in which they
            will be traversed during lookups.
        registry_names: Optional display names, one per registry, in the same
            order. Defaults to "Registry 1", "Registry 2", ...
    """

    def __init__(
        self,
        registries: "Sequence[Dict[Target, Union[Callable, Sequence[ConverterSupport]]]]",
        registry_names: Optional[Sequence[str]] = None,
    ):
        # The list is a copy, but each dictionary is held by reference, so
        # entries added to a dictionary later are visible through this object
        self.registries = list(registries)

        if registry_names is None:
            self.registry_names = [
                f"Registry {i + 1}" for i in range(len(self.registries))
            ]
        else:
            self.registry_names = list(registry_names)
            # Raised explicitly so the check is not stripped under python -O
            if len(self.registries) != len(self.registry_names):
                raise AssertionError(
                    f"Expected one name per registry, but got {len(self.registries)} "
                    + f"registries and {len(self.registry_names)} names"
                )

        self.validate_invariants()

    def validate_invariants(self):
        """Validates the invariants required of the dictionaries in the registries

        Raises AssertionError if any invariants have been violated
        """
        # Errors are raised explicitly rather than via `assert`, so that the
        # documented AssertionError is still produced under python -O

        # All registries must be dictionaries
        if not all(isinstance(elt, dict) for elt in self.registries):
            raise AssertionError("All converter registries must be dictionaries")

        # Every dictionary in the registry must have one of two signatures:
        # Dict[Target, Callable] or Dict[Target, Sequence[ConverterSupport]]
        # Where, for the latter, the sequence must be non-empty
        for registry in self.registries:
            for converters in registry.values():
                if isinstance(converters, (list, tuple)):
                    if not converters or not all(
                        isinstance(c, ConverterSupport) for c in converters
                    ):
                        raise AssertionError(
                            "Converter sequences must be non-empty and contain "
                            + "only ConverterSupport objects"
                        )
                elif not callable(converters):
                    raise AssertionError("Converter function must be callable")

    def __getitem_without_validation__(self, key: Target):
        """Get the first-found converter in any registry

        Searches all registries in order and returns the first converter encountered.
        Raises KeyError if the key is a Node, or if no registry has an entry for it
        """
        if isinstance(key, Node):
            raise KeyError(
                "Unvalidated accesses to the Converter registry can only be "
                + "made with node targets. Try accessing the registry with node.target"
            )

        self.validate_invariants()

        # Iterate over all registries and return the first converter found
        for registry in self.registries:
            if key in registry:
                converters = registry[key]

                # For a sequence of ConverterSupport objects, the first one wins
                if isinstance(converters, (list, tuple)):
                    return converters[0].converter_implementation
                return converters

        raise KeyError(f"None of the converter registries have an entry for {key}")

    def __getitem__(self, node: Node):
        """Get the first-found validated converter in any registry

        Searches all registries in order and returns the first converter
        which passes validation on the input node.
        Raises KeyError if the input is not a Node, or if no converter validates
        """
        if not isinstance(node, Node):
            raise KeyError(
                "Validated accesses to the Converter registry can only be "
                + "made with node inputs. Try accessing the registry with a node "
                + "or use get_unvalidated to access without node validation."
            )

        self.validate_invariants()
        key = node.target

        # Iterate over all registries, validating the converter on the input node
        # If no capability_validator function is found, assume full coverage
        for registry in self.registries:
            if key in registry:
                converters = registry[key]

                if isinstance(converters, (list, tuple)):
                    for candidate in converters:
                        if candidate.capability_validator(node):
                            return candidate.converter_implementation
                    # No candidate accepted the node; try the next registry
                else:
                    return converters

        raise KeyError(
            f"None of the converter registries have a validated entry for {key}, with node {node}"
        )

    def keys(self):
        """Get all unique targets across all dictionaries"""
        return self.unique_targets()

    def get_unvalidated(self, key: Target, value=None):
        """Get unvalidated converter for input target with a default return"""
        try:
            return self.__getitem_without_validation__(key)
        except KeyError:
            return value

    def get(self, node: Node, value=None):
        """Get validated converter for input node with a default return"""
        try:
            return self.__getitem__(node)
        except KeyError:
            return value

    def __contains__(self, key: Union[Target, Node]):
        """Check whether a converter for an input node or target exists"""
        try:
            # Nodes use the validated lookup, targets the unvalidated one
            if isinstance(key, Node):
                self.__getitem__(key)
            else:
                self.__getitem_without_validation__(key)

            return True
        except KeyError:
            return False

    def get_all_converters_with_target(
        self, key: Target, return_registry_info: bool = False
    ):
        """Get all converters across all registries for the target

        Returns a list of all converters having the specified target. If
        return_registry_info is True, returns a tuple of that list and a
        dictionary mapping each registry name to the number of converters
        it holds for the target
        """
        self.validate_invariants()
        converters_with_target = []

        # Store count of number of registered converters per registry
        if return_registry_info:
            registry_data = {name: 0 for name in self.registry_names}

        for index, registry in enumerate(self.registries):
            if key not in registry:
                continue

            converters = registry[key]

            # Normalize both storage forms to a list of implementations
            if isinstance(converters, (list, tuple)):
                implementations = [c.converter_implementation for c in converters]
            else:
                implementations = [converters]

            converters_with_target.extend(implementations)

            # Add converter count to registry name storage
            if return_registry_info:
                registry_data[self.registry_names[index]] += len(implementations)

        if return_registry_info:
            return converters_with_target, registry_data
        else:
            return converters_with_target

    def __setitem__(self, key, value):
        """Always raises AssertionError: direct assignment is not supported"""
        raise AssertionError(
            "Do not set registry members directly through the ConverterRegistry object. "
            + f"Attempted to set {key}: {value} via direct assignment to ConverterRegistry."
        )

    def __delitem__(self, key):
        """Always raises AssertionError: direct deletion is not supported"""
        raise AssertionError(
            "Do not delete registry members directly through the ConverterRegistry object. "
            + f"Attempted to delete {key} via direct del on ConverterRegistry."
        )

    def __len__(self):
        """Returns the sum of lengths of all registries stored"""
        return sum(len(registry) for registry in self.registries)

    def unique_targets(self):
        """Returns the set of unique converter targets stored across all registries"""
        # Starting from an empty set keeps this valid when there are no
        # registries; iterating a dictionary yields its keys
        return set().union(*self.registries)

    def qualified_name_or_str(self, target: Target) -> str:
        """Returns string representation of an FX Node target"""
        if isinstance(target, str):
            return target
        else:
            return _get_qualified_name(target)

    def display_all_available_converters(self) -> str:
        """Returns a string with all converters and their source, separated by newlines"""
        lines = ["Available converters in ATen registries with counts:"]

        # Targets are listed in order of their string representation
        for target in sorted(self.unique_targets(), key=self.qualified_name_or_str):
            _, registry_data = self.get_all_converters_with_target(
                target, return_registry_info=True
            )
            lines.append(
                f"Node: {self.qualified_name_or_str(target)} - Registry Presence Counts: {registry_data}"
            )

        # Every line, including the last, is newline-terminated
        return "\n".join(lines) + "\n"
# Combined registry over the Dynamo ATen and FX converter dictionaries
# Registries are searched in the order given, so listing the Dynamo
# dictionary first gives its converters precedence over the FX ones
DYNAMO_CONVERTERS: ConverterRegistry = ConverterRegistry(
    registries=[DYNAMO_ATEN_CONVERTERS, CONVERTERS],
    registry_names=[
        "Dynamo ATen Converters Registry",
        "FX ATen Converters Registry",
    ],
)