feat: Add Selective ATen decompositions #2173
Code conforms to C++ style guidelines
Code conforms to Python style guidelines
34a190e
to
bdb06d8
Compare
bdb06d8
to
368a20e
Compare
ENABLED_TORCH_DECOMPOSITIONS: Dict[
    torch._ops.OpOverload, Callable
] = get_torch_decompositions(enabled_decompositions)
Currently, the decompositions are sourced directly from Torch's main registry (via `get_torch_decompositions`) and may not exactly match `_core_aten_decompositions`. This is because certain decompositions which we depend on (such as `native_layer_norm`) may occasionally be removed from the core set.
Whenever Torch versions are upgraded, this list should be updated as well.
ENABLED_TORCH_DECOMPOSITIONS: Dict[
    torch._ops.OpOverload, Callable
] = get_torch_decompositions(enabled_decompositions)
TORCH_TRT_DECOMPOSITIONS: Dict[torch._ops.OpOverload, Callable] = {}
The decompositions are three dictionaries:
- `ENABLED_TORCH_DECOMPOSITIONS` - the enabled decompositions we've pre-selected
- `CORE_ATEN_DECOMPOSITIONS_FILTERED` (defined in `get_decompositions` below) - the complete set of `_core_aten_decompositions` Torch provides, minus the set of disabled decompositions. Note that `TORCH_DECOMPOSITIONS` may not be a subset of this set
- `TORCH_TRT_DECOMPOSITIONS` - the decompositions we've written ourselves
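As a rough illustration of how these three dictionaries relate, here is a minimal sketch. It uses plain strings in place of `torch._ops.OpOverload` keys and a dummy callable in place of real decompositions; `torch_registry`, `core_aten`, and the specific op names in each set are assumptions made up for the example, not the lists from this PR.

```python
from typing import Callable, Dict, Set


def _placeholder(*args, **kwargs):
    """Dummy body; a real decomposition would rewrite an op into simpler ops."""
    return None


# Hypothetical stand-in for Torch's main decomposition registry.
torch_registry: Dict[str, Callable] = {
    "aten.addmm": _placeholder,
    "aten.native_layer_norm": _placeholder,
    "aten.tanh_backward": _placeholder,
    "aten.upsample_nearest2d": _placeholder,
}

# Hypothetical stand-in for the core ATen set; native_layer_norm is
# deliberately absent to mimic an op being dropped from the core set.
core_aten: Dict[str, Callable] = {
    "aten.addmm": _placeholder,
    "aten.tanh_backward": _placeholder,
    "aten.upsample_nearest2d": _placeholder,
}

enabled_decompositions: Set[str] = {"aten.addmm", "aten.native_layer_norm"}
disabled_decompositions: Set[str] = {"aten.upsample_nearest2d"}

# Pre-selected decompositions, looked up in the main registry.
ENABLED_TORCH_DECOMPOSITIONS = {
    op: fn for op, fn in torch_registry.items() if op in enabled_decompositions
}

# Everything in the core set, minus the explicitly disabled ops.
CORE_ATEN_DECOMPOSITIONS_FILTERED = {
    op: fn for op, fn in core_aten.items() if op not in disabled_decompositions
}

# Hand-written decompositions live in their own dictionary.
TORCH_TRT_DECOMPOSITIONS: Dict[str, Callable] = {"aten.reciprocal": _placeholder}

# Enabled ops that the filtered core set does not contain.
outside_core = sorted(
    set(ENABLED_TORCH_DECOMPOSITIONS) - set(CORE_ATEN_DECOMPOSITIONS_FILTERED)
)
print(outside_core)
```

In this toy setup the pre-selected dictionary contains an op that the filtered core set lacks, which is the "may not be a subset" situation described above.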
Would `TORCH_DECOMPOSITIONS` only include decompositions from the `get_torch_decompositions` set?
> Note that `TORCH_DECOMPOSITIONS` may not be a subset of this set

From what I understand,
- It seems like some decompositions in `_core_aten_decompositions` have been removed, e.g. `aten.native_layer_norm`. `ENABLED_TORCH_DECOMPOSITIONS` is a more complete set (from a previous commit, maybe). Is this correct?
- In that case, what if we move `aten.native_layer_norm` (and maybe other useful ones) to `TORCH_TRT_DECOMPOSITIONS`, since it is useful to us, instead of maintaining an `ENABLED_TORCH_DECOMPOSITIONS` set which overlaps with the `_core_aten_decompositions` one?
@narendasan `ENABLED_TORCH_DECOMPOSITIONS` would only include decompositions from the `_core_aten_decompositions` set which are not also in `disabled_decompositions`
- The interpretation of `ENABLED_TORCH_DECOMPOSITIONS` is correct
- My initial intent for `TORCH_TRT_DECOMPOSITIONS` was that it would only store decompositions we specifically (custom) wrote, not ones sourced from Torch, as `layer_norm` would be.
def reciprocal_replacement(
    input_: torch.Tensor,
) -> torch.Tensor:
    return torch.div(1, input_)


def get_decompositions():
    return DECOMPOSITIONS
def get_decompositions(
The dictionary returned by `get_decompositions` is either `ENABLED_TORCH_DECOMPOSITIONS` or `CORE_ATEN_DECOMPOSITIONS_FILTERED`, concatenated with our `TORCH_TRT_DECOMPOSITIONS`.
Have we thought about what this might look like if it's user accessible?
Can we add a tool to monitor these decomposition sets similar to the opset coverage tool?
8b12de5
to
806b348
Compare
The existing opset coverage tool is compatible with this PR, meaning that
806b348
to
7fa036f
Compare
7fa036f
to
d49cadb
Compare
- Add sets to selectively enable or disable decompositions in Torch
- Add new runtime argument `enable_experimental_decompositions` to enable all core aten decompositions, or a pre-selected subset thereof
- Improve documentation of compilation settings overall
- Add decorator-wrapper to perform import-time checks on decompositions and alert the user if any custom decompositions conflict with existing registered or specified operators
- Simplify code logic for dictionary merging in `get_decompositions` function
- Add safety logic to ensure invariants about the decompositions are not violated
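A decorator-wrapper of the kind mentioned in this commit message might be sketched as follows. The name `register_decomposition_checked`, the string op keys, the example set contents, and the choice to warn (rather than raise) on a conflict are all hypothetical; this is not the PR's implementation.

```python
import warnings
from typing import Callable, Dict, Set

# Hypothetical stand-ins; strings replace torch._ops.OpOverload keys.
ENABLED_TORCH_DECOMPOSITIONS: Dict[str, Callable] = {"aten.addmm": lambda *a: None}
disabled_decompositions: Set[str] = {"aten.upsample_nearest2d"}
TORCH_TRT_DECOMPOSITIONS: Dict[str, Callable] = {}


def register_decomposition_checked(op: str):
    """Register a custom decomposition, warning on conflicts at import time."""

    def decorator(fn: Callable) -> Callable:
        if op in TORCH_TRT_DECOMPOSITIONS:
            warnings.warn(f"{op} already has a custom decomposition; overriding it")
        elif op in ENABLED_TORCH_DECOMPOSITIONS:
            warnings.warn(f"{op} is also in the enabled Torch set; custom one is used")
        elif op in disabled_decompositions:
            warnings.warn(f"{op} is listed as disabled but has a custom decomposition")
        TORCH_TRT_DECOMPOSITIONS[op] = fn
        return fn

    return decorator


# The decorator body runs when the module is imported, so a conflict is
# reported before any compilation takes place.
@register_decomposition_checked("aten.reciprocal")
def reciprocal_replacement(x: float) -> float:
    # Plain-float analogue of the torch.div(1, input_) decomposition.
    return 1 / x
```

Because the check runs inside the decorator, a conflicting registration surfaces as soon as the defining module is imported.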
d49cadb
to
1e3d12e
Compare
LGTM
1e3d12e
to
2064f4f
Compare
Description
- Add new runtime argument `enable_experimental_decompositions` to enable all core aten decompositions, or a pre-selected subset thereof

Fixes #2160
Type of change
Checklist: