
Add non-TS'able _resize_image_and_masks variant with less tensor ops #7592

Merged
3 commits merged into main on May 20, 2023

Conversation

@ezyang (Contributor) commented May 16, 2023:

We did some horrible things to _resize_image_and_masks to make it TorchScriptable, and those horrible things cause weird divergences when you send the float computation to a real compiler that is willing to apply fastmath optimizations to floating point; see pytorch/pytorch#93598.

This PR adds a non-TS-goopified version of the operator which doesn't have this problem, since it does the size computation the "normal way" (and consequently doesn't get fastmath'ified).

Signed-off-by: Edward Z. Yang ezyang@meta.com
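
For concreteness, here is a minimal sketch (not the exact diff in this PR) of the two styles of size computation; the function names and parameters are illustrative:

    import torch

    def scale_factor_ts_style(image: torch.Tensor, self_min_size: float, self_max_size: float) -> float:
        # TorchScript-era style: route scalar size math through tensor ops.
        # Once this lands in a compiled graph, a fastmath-enabled backend is
        # free to reassociate/perturb the float arithmetic.
        im_shape = torch.tensor(image.shape[-2:])
        min_size = torch.min(im_shape).to(dtype=torch.float32)
        max_size = torch.max(im_shape).to(dtype=torch.float32)
        return torch.min(self_min_size / min_size, self_max_size / max_size).item()

    def scale_factor_eager_style(image: torch.Tensor, self_min_size: float, self_max_size: float) -> float:
        # The "normal way": plain Python arithmetic on the shape. This never
        # enters the graph, so fastmath cannot change the result.
        min_size = float(min(image.shape[-2:]))
        max_size = float(max(image.shape[-2:]))
        return min(self_min_size / min_size, self_max_size / max_size)

In eager mode both produce the same scale factor; the difference is that the second keeps the size arithmetic out of the traced graph entirely.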

@pytorch-bot (bot) commented May 16, 2023:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7592

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit d8043a9:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ezyang ezyang requested a review from NicolasHug May 16, 2023 00:52
ezyang added a commit to pytorch/pytorch that referenced this pull request on May 16, 2023:

The bulk of the heavy lifting is happening in pytorch/vision#7592

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: d8f8c954c3d7c45f595847c642d56f97e3322b6f
Pull Request resolved: #101477
@NicolasHug (Member) left a comment:

Thanks for the PR @ezyang. I am growing wary of the maintenance cost we're facing with recent PRs related to PT 2.0 support (#7587, now this PR). Both PRs add a separate implementation for an existing function.

Do we expect a lot more of these PT 2.0 support issues in the future? And if yes, is there an alternative to what we're currently doing, which is to duplicate all implementations?

Supporting the cross product of JIT x ONNX x all_platforms x <insert your preferred tech here> has been a massive challenge in torchvision (and I'm being diplomatic), and I fear that adding yet another factor to that is going to be... err... difficult.

if self.training:
    if self._skip_resize:
        return image, target
    size = random.choice(self.min_size)
@NicolasHug (Member) commented May 16, 2023:

Tests are failing because this needs an import. But please use the RNG from torch instead of Python's, so that users can fully control randomness just by calling torch.manual_seed() and similar mechanisms.
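
A minimal sketch of that suggestion (the candidate sizes here are made up); torch.randint draws from torch's global generator, so torch.manual_seed() controls it:

    import torch

    min_size = (480, 512, 544)  # hypothetical candidate sizes
    # torch-RNG replacement for random.choice(min_size):
    size = min_size[int(torch.randint(len(min_size), (1,)))]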

@ezyang (Contributor, Author) commented May 16, 2023:

> PT 2.0 support (#7587)

I want to treat this and #7587 separately. #7587 has nothing to do with PT2 support per se; it's entirely about support for deterministic algorithms (which I ran into while working on PT2, sure, but it stands on its own).

On the subject of deterministic algorithms, there really are two main approaches for adding a deterministic version of a CUDA kernel that uses gpuAtomicAdd: (1) you can write a decomposition (which is what the PR does), or (2) you can write another copy of the CUDA kernel by hand that doesn't have atomic adds in it (the easiest approach is to change the iteration space from grad_output to grad_input, so that the summation happens from a single thread).

Writing the decomposition has the added benefit that you can use it to compile things in PT2 (though I don't actually do this in the PR), and it can serve as a nice, pure-Python reference implementation for testing and experimentation. So it seems preferable to banging out another CUDA kernel. This seems... like a fair trade for "code duplication"? In the limit, we'd be applying this treatment to every custom operator in torchvision. A long-term vision for PT2 is that you wouldn't need to write hand-written CUDA code at all; you could write the pure Python code and generate a kernel automatically from it, but we're still a little bit away from that.
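
To make the iteration-space flip concrete, here is a toy sketch; a 1-D gather stands in for the real kernel, and none of this is the actual code from #7587:

    import torch

    def gather_backward_atomic(grad_out: torch.Tensor, idx: torch.Tensor, num_in: int) -> torch.Tensor:
        # Iterate over grad_output and scatter-add into grad_input. With
        # duplicate indices, the CUDA analogue is gpuAtomicAdd, so the float
        # accumulation order (and hence the result) is nondeterministic.
        grad_in = torch.zeros(num_in, dtype=grad_out.dtype, device=grad_out.device)
        return grad_in.index_add_(0, idx, grad_out)

    def gather_backward_decomposed(grad_out: torch.Tensor, idx: torch.Tensor, num_in: int) -> torch.Tensor:
        # Flip the iteration space: compute each grad_input element as a
        # single fixed-order reduction over grad_output. Deterministic, and
        # usable as a pure-Python reference implementation.
        mask = idx.unsqueeze(0) == torch.arange(num_in, device=idx.device).unsqueeze(1)
        return (mask.to(grad_out.dtype) * grad_out.unsqueeze(0)).sum(dim=1)

With idx = torch.tensor([0, 2, 0]) and grad_out = torch.tensor([1., 2., 3.]), both return tensor([4., 0., 2.]); the decomposed version materializes a num_in x len(idx) mask, so it is a reference rather than a fast kernel.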

> Supporting the cross product of JIT x ONNX x all_platforms x <insert your preferred tech here> has been a massive challenge in torchvision (and I'm being diplomatic), and I fear that adding yet another factor to that is going to be... err... difficult.

I think the question I would ask you is, if you didn't have to support TorchScript JIT / ONNX, what would this code look like? I argue that your code would look like the new version I've posted: why would you intentionally create a tensor just to do shape computation and convert it back out again? The version of the code here is the clear, idiomatic PyTorch eager implementation of the function.

@ezyang (Contributor, Author) commented May 16, 2023:

Attempted simplifying the duplication

ezyang added several more commits to pytorch/pytorch that referenced this pull request on May 16, 2023, each noting:

The bulk of the heavy lifting is happening in pytorch/vision#7592

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx desertfire

ghstack-source-id: 28e7047b3ed1058cf3e8009cd29e7146cacc4426
Pull Request resolved: #101477
@NicolasHug (Member) left a comment:
Thanks for trying to minimize the code duplication, @ezyang. As discussed offline, the ONNX tests are red, but I'll stamp to unblock.

ezyang added further commits to pytorch/pytorch that referenced this pull request on May 20, 2023:

The bulk of the heavy lifting is happening in pytorch/vision#7592

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: 3a058371e45a21cf35cdb91ce1d313295f2d80ff
Pull Request resolved: #101477
@ezyang ezyang merged commit 300a909 into main May 20, 2023
@github-actions (bot) commented:
Hey @ezyang!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

ezyang added more commits to pytorch/pytorch that referenced this pull request on May 20, 2023 (one titled "…krcnn"):

The bulk of the heavy lifting is happening in pytorch/vision#7592

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

ghstack-source-id: e4b86e6f0b9aad0142e227f01b3bb3561cd272cd
Pull Request resolved: #101477
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request May 21, 2023
The bulk of the heavy lifting is happening in
pytorch/vision#7592

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #101477
Approved by: https://github.com/voznesenskym
facebook-github-bot pushed a commit that referenced this pull request May 23, 2023
…nsor ops (#7592)

Summary: Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Reviewed By: vmoens

Differential Revision: D46071415

fbshipit-source-id: bd1575ba95700e29f5565b720d9d1be070736fe8
ezyang added a commit to ezyang/vision that referenced this pull request Sep 7, 2023
This is a small follow-up to pytorch#7592 that makes this Dynamo exportable.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
@fmassa fmassa deleted the maskrcnn-descale branch September 26, 2023 08:53