chore: dynamic shape support for any/sort/trunc ops #3026
Conversation
shuffle_layer.get_output(0),
TRT_TOPK_MAX_ELEMENT,
)
ctx.net.add_assertion(
When dynamic k > 3840, there is a segmentation fault rather than a runtime error. The error looks like this:
Python error: Segmentation fault
Thread 0x00007fda56bfe640 (most recent call first):
File "/root/.pyenv/versions/3.10.14/lib/python3.10/threading.py", line 324 in wait
File "/root/.pyenv/versions/3.10.14/lib/python3.10/threading.py", line 607 in wait
File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tqdm/_monitor.py", line 60 in run
File "/root/.pyenv/versions/3.10.14/lib/python3.10/threading.py", line 1016 in _bootstrap_inner
File "/root/.pyenv/versions/3.10.14/lib/python3.10/threading.py", line 973 in _bootstrap
Current thread 0x00007fdc02f65740 (most recent call first):
File "/root/trt/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py", line 287 in forward
...
Extension modules: numpy.core._multiarray_umath, ...
Segmentation fault
I tried to catch it with add_assertion(), but it doesn't work here, even though a simple standalone test of add_assertion() was fine. Any idea how to use it? Or can we accept this behavior in order to enable a dynamic k value for sort?
There is no problem if the dynamic k value is within range, as in the test code.
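For context, the guard being discussed can be sketched roughly as follows (a minimal sketch pieced together from the hunks in this thread, not the exact converter code; the k tensor's name, shape, and dtype are assumptions):

import numpy as np
import tensorrt as trt

TRT_TOPK_MAX_ELEMENT = 3840  # TopK size limit discussed in this thread

# k_tensor: int32 ITensor of shape (1,) holding the dynamic k value (assumed built earlier)
limit = ctx.net.add_constant(
    (1,), np.array([TRT_TOPK_MAX_ELEMENT + 1], dtype=np.int32)
).get_output(0)
cond = ctx.net.add_elementwise(
    k_tensor, limit, trt.ElementWiseOperation.LESS
).get_output(0)
# The assertion fires at runtime if k exceeds the limit; as noted above,
# in practice this surfaced as a segfault rather than a clean runtime error.
ctx.net.add_assertion(cond, f"k must be <= {TRT_TOPK_MAX_ELEMENT} for TopK")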
I think the issue is due to TRT's limitation. One possible solution is to fall back to PyTorch if k > 3840. We don't expect any errors to be thrown from TRT.
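In the dynamo converters, that kind of fallback is usually expressed through a capability validator on the converter registration, so unsupported nodes simply stay in PyTorch. A rough sketch (the decorator and validator names here are assumptions, not the exact code in this PR):

# A validator returning False keeps the node in PyTorch instead of converting it,
# so no error is ever raised from the TRT side.
@dynamo_tensorrt_converter(
    torch.ops.aten.sort.default,
    capability_validator=sort_validator,  # e.g. rejects dynamic or out-of-range k
)
def aten_ops_sort(ctx, target, args, kwargs, name):
    # delegate to the usual TopK-based lowering (implementation elided)
    ...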
Thanks for the review. As we discussed, I changed it to allow this op only when the k value is static, because we cannot determine a dynamic k value at compile time.
shape=[1],
stride=[1],
)
set_layer_name(layer, target, name)
Please pass in source_ir and add a suffix to the layer's name.
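For example (a sketch; the exact suffix is hypothetical):

# pass source_ir through and disambiguate the layer name with a suffix
set_layer_name(layer, target, f"{name}_slice", source_ir)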
set_layer_name(layer, target, name)

# Get scalar tensor from 1d tensor
shuffle_layer = ctx.net.add_shuffle(layer.get_output(0))
shuffle_layer.reshape_dims = trt.Dims()
set_layer_name(shuffle_layer, target, name, source_ir)
Please add a suffix to the layer's name.
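For example (a sketch based on the hunk above; the suffix itself is hypothetical):

# collapse the 1-d tensor to a 0-d (scalar) tensor via an empty reshape
shuffle_layer = ctx.net.add_shuffle(layer.get_output(0))
shuffle_layer.reshape_dims = trt.Dims()  # empty Dims() -> scalar output
set_layer_name(shuffle_layer, target, f"{name}_squeeze_to_scalar", source_ir)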
Force-pushed from d17af0e to e1491b0
dim = node.args[1]
dim = get_positive_dim(dim, len(shape))
k = shape[dim]
if not isinstance(k, int):
If the k dim is static and the other dims are dynamic, we can still support the sort op. This check validates whether k is static.
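Roughly, the check amounts to the following (a sketch completing the hunk above; the surrounding validator body and metadata access are assumptions):

# inside the sort validator: k is the size of the sorted dim taken from tensor metadata
shape = node.args[0].meta["tensor_meta"].shape
dim = node.args[1]
dim = get_positive_dim(dim, len(shape))
k = shape[dim]
if not isinstance(k, int):
    # symbolic (dynamic) dim: k cannot be bounded at compile time, fall back to PyTorch
    return False
return True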
Thanks for the comment. Is it possible that k is -1?
Thanks for bringing this to my attention.
I checked that only dynamic dims are passed to export() in torch_tensorrt.dynamo.trace():
https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/_tracer.py#L81
So I think a static dim value in the metadata will always be a valid dim value (> 0). If you think a static shape dim value from the metadata could be -1, I will update the validator to handle this case as well.
elif isinstance(k, int) and k < 0:
    return False
Sort(),
input_specs,
output_dtypes=[torch.float, torch.int64],
use_dynamo_tracer=True,
Tensor metadata is available when use_dynamo_tracer=True.
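For reference, such a dynamic-shape test is typically shaped like this (a rough sketch; the module, shapes, and dim here are assumptions rather than the exact test in this PR):

class Sort(torch.nn.Module):
    def forward(self, x):
        # sort along the last dim, which is kept static so k is known at compile time
        return torch.sort(x, dim=-1, descending=True)

input_specs = [
    Input(
        min_shape=(1, 1, 3),
        opt_shape=(2, 2, 3),
        max_shape=(3, 3, 3),
        dtype=torch.float,
    ),
]
self.run_test_with_dynamic_shape(
    Sort(),
    input_specs,
    output_dtypes=[torch.float, torch.int64],
    use_dynamo_tracer=True,  # dynamo tracing populates tensor metadata for the validator
)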
LGTM
Description
dynamic shape support for any/sort/trunc ops
Fixes # (issue)