fix: Device casting issues with certain aten operators #1416
Conversation
%false: bool = prim::Constant[value=0]()
%mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
%self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
%out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value)
This could be replaced by aten::masked_fill, for which a converter exists. However, a few bugs arise in edge cases, for example when %value is a float but %self_cuda is an int, and similar scenarios that the converter does not handle directly and that can cause errors in TRT. As a result, I opted for the unimplemented aten::masked_fill_ version with casted tensors in the meantime.
// should be casted to CUDA to avoid device mismatch errors
std::string unpacked_pattern = R"IR(
  graph(%self, %mask, %value):
    %device: Device = prim::Constant[value="cuda"]()
What happens in the case of multi-gpu systems?
Potentially could add an argument to take the target device.
Take a look at snprintf and at modifying core::CompileSpec::lower_info to add a device field that replicates the device info from the external API; then you should be able to determine the target device at lowering time.
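A minimal sketch of what that could look like, assuming a LowerInfo struct in core/lowering and reusing the shared Device struct; the field, helper, and header names here are assumptions, not the final implementation:

```cpp
// core/lowering/lowering.h (sketch only; field and helper names are assumptions)
#include <string>
#include "core/ir/ir.h"  // assumed location of the shared Device struct

namespace torch_tensorrt {
namespace core {
namespace lowering {

struct LowerInfo {
  // ... existing lowering options ...

  // Hypothetical field mirroring the device passed in through the external API,
  // so lowering passes can build device strings such as "cuda:1" at lower time.
  ir::Device target_device;

  std::string getGPUDeviceString() const {
    return "cuda:" + std::to_string(target_device.gpu_id);
  }
};

} // namespace lowering
} // namespace core
} // namespace torch_tensorrt
```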
%false: bool = prim::Constant[value=0]()
%mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
%self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
%out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value)
Right now this would handle the in-place case. Can we handle the functional case here too?
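For reference, a sketch of what the same device-casting approach might look like for the functional variant (the surrounding pass structure is assumed):

```cpp
// Sketch only: the functional op handled with the same device-casting pattern
std::string unpacked_functional_pattern = R"IR(
  graph(%self, %mask, %value):
    %device: Device = prim::Constant[value="cuda"]()
    %dtype: NoneType = prim::Constant()
    %false: bool = prim::Constant[value=0]()
    %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
    %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
    %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
    return (%out))IR";
```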
nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b) {
Thanks for the suggestion - this actually turns out to be a bug in the aten::masked_fill converter, as it behaves differently than Torch when the types of the input and value are different. Specifically, the converter throws an error whereas Torch just inherits the type of the first argument. I will make a new PR + test cases for this, as it is an unrelated bug.
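A quick libtorch check of the behavior described above might look like this (a sketch; the expected result is taken from the comment, not verified here):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Int tensor filled with a float value: per the comment above, Torch keeps the
  // dtype of the first argument, whereas the converter currently throws an error.
  auto self = torch::zeros({3}, torch::kInt32);
  auto mask = torch::tensor({1, 0, 1}).to(torch::kBool);
  auto out = self.masked_fill(mask, 1.5);
  std::cout << out.dtype() << std::endl;  // expected (per the comment): Int, matching self
  return 0;
}
```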
Fix in PR #1430
std::string num_to_tensor_clean_pattern = R"IR(
  graph(%1: Scalar):
    %2: Tensor = prim::NumToTensor(%1)
    %device: Device = prim::Constant[value="cuda"]()
See above about correct device
// to avoid device mismatch issues
std::string full_clean_pattern = R"IR(
  graph(%1, %2, %3, %4, %5, %6):
    %cuda: Device = prim::Constant[value="cuda"]()
""
std::string clean_pattern_part_1 = R"IR(
  graph(%1: Scalar):
    %2: Tensor = prim::NumToTensor(%1)
    %device: Device = prim::Constant[value=")IR";

std::string clean_pattern_part_2 = R"IR("]()
    %dtype: NoneType = prim::Constant()
    %false: bool = prim::Constant[value=0]()
    %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
    return (%3))IR";

auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
Had to use this paradigm instead of snprintf because the % symbols in the IR are interpreted as format specifiers by snprintf, which made it difficult to insert the device string.
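For illustration, a sketch of the escaping snprintf would have required: every literal % in the IR has to be doubled so that only the final %s is treated as a conversion specifier (example only, not code from this PR):

```cpp
#include <cstdio>
#include <string>

int main() {
  char buf[256];
  // Each literal '%' in the IR text is written as "%%"; the lone "%s" receives
  // the target device string.
  std::snprintf(buf, sizeof(buf),
                "graph(%%1: Scalar):\n"
                "  %%2: Tensor = prim::NumToTensor(%%1)\n"
                "  %%device: Device = prim::Constant[value=\"%s\"]()",
                "cuda:0");
  std::string pattern(buf);
  return 0;
}
```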
  for (auto& in : inputs) {
    in = in.to(torch::Device(target_device));
  }
} else {
Doesn't need to be an else, could just be a second check.
Updated the else block to just assign the cuda target device name, and the runtime device check is now applied as a second check.
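A rough sketch of the restructured control flow described here, with assumed variable and function names:

```cpp
#include <string>
#include <vector>
#include <torch/torch.h>

// Sketch: pick a target device name (defaulting to "cuda" when none was recorded),
// then apply the runtime device check unconditionally as a second step.
void ensure_inputs_on_target_device(std::vector<torch::Tensor>& inputs,
                                    const std::string& recorded_device_name) {
  std::string target_device = "cuda";
  if (!recorded_device_name.empty()) {
    target_device = recorded_device_name;  // e.g. "cuda:1"
  }
  // Second check: move any input that is not already on the target device.
  for (auto& in : inputs) {
    if (in.device() != torch::Device(target_device)) {
      in = in.to(torch::Device(target_device));
    }
  }
}
```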
fixes: #1446
- Investigated an issue arising with the BART-base model (https://huggingface.co/facebook/bart-base) where certain tensor inputs to TensorRT were on the CPU, despite users explicitly casting all inputs properly
- Traced the issue to internally-generated 0D tensors, mask tensors, and operations returning CPU tensors passed between Torch and Torch-TensorRT engines
- Added lowering passes to ensure these function edge cases are handled appropriately and tensors are located on the proper device at runtime, and added a runtime validation check to avoid models crashing due to device mismatches
- Added testing for the lowering passes to ensure output values are accurate
…evice
- Added field to LowerInfo to hold device information
- Update internal Device struct location to allow streamlined imports
- Update BUILD files
- Build strings in lowering phase using user-specified target device
- Update CMakeLists to reflect IR dependency in lowering
- Update runtime device location code to run regardless of whether a switch is required or not
Description

Type of change
- aten functions and casting

Checklist: