
Add deterministic, pure-Python roi_align implementation #7587

Merged: 11 commits into main on May 16, 2023

Conversation

ezyang (Contributor) commented May 14, 2023

pytorch-bot bot commented May 14, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7587

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit d7f4e80:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ezyang added 3 commits May 14, 2023 05:53, each Signed-off-by: Edward Z. Yang <ezyang@meta.com>
NicolasHug (Member) left a comment:

Thanks @ezyang. The PR looks good; I just have minor questions and comments below. There's nothing critical, so I'll approve now.

I haven't checked the correctness of the new implementation too closely; I'm mostly relying on our tests for that. As a precaution, I parametrized test_forward() over 50 seeds locally and they all passed.
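A dependency-free sketch of that kind of seed sweep (the test body here is a hypothetical stand-in; the real test_forward lives in torchvision's test suite and would typically be swept with pytest.mark.parametrize):

```python
import random

def run_forward_check(seed):
    # Hypothetical stand-in for torchvision's test_forward(): seed the RNG,
    # build a small random input, and assert a deterministic invariant.
    random.seed(seed)
    data = [random.random() for _ in range(8)]
    assert len(data) == 8
    assert all(0.0 <= x < 1.0 for x in data)
    return data

# Sweep the check over 50 seeds, mirroring the parametrization above.
results = [run_forward_check(seed) for seed in range(50)]
assert len(results) == 50

# Determinism: re-running with the same seed reproduces the input exactly.
assert run_forward_check(0) == run_forward_check(0)
```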

```python
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    from functorch.dim import dims
```
Review comment (Member):

Any reason to lazy import? Isn't the functorch namespace always available now?

ezyang (Contributor, Author) replied:

It is supposed to always be available but @zou3519 mentioned to me that xplat mumble mumble doesn't have working torchdims build mumble? In any case, there is not much harm in making it lazy like this, so I went ahead and did it this way.

zou3519 (Contributor) commented May 15, 2023:

A better reason, which I remembered just now, is that import functorch.dim monkey-patches torch, and we really do not want to monkey-patch torch by default.
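The lazy-import pattern under discussion can be sketched like this, using a stdlib module as a stand-in for functorch.dim (whose top-level import has side effects); the function name and arithmetic here are purely illustrative:

```python
def pooled_output_size(height, width, ratio):
    # Deferred import: the module loads the first time this function runs,
    # not when the enclosing module is imported. If the import has side
    # effects (as importing functorch.dim does, by monkey-patching torch),
    # merely importing this module stays side-effect free.
    from fractions import Fraction  # stand-in for `from functorch.dim import dims`
    return Fraction(height, ratio), Fraction(width, ratio)

# The side-effectful import only happens once the function is called.
h, w = pooled_output_size(14, 14, 2)
assert (h, w) == (7, 7)
```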

ezyang (Contributor, Author) replied:

For safety purposes, I hand-inserted all of the expands necessary to remove the first-class dims implementation. We should still keep the first-class dims version around for documentation purposes, though.
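For reference, the heart of any roi_align implementation is bilinear sampling at fractional coordinates. A minimal pure-Python sketch of that kernel (illustrative only; the PR's actual code operates on batched tensors with explicit expands rather than per-pixel loops):

```python
def bilinear_sample(img, y, x):
    """Bilinearly sample img (a list of rows) at fractional (y, x).

    A sketch of the interpolation at the core of roi_align: out-of-range
    coordinates contribute zero, in-range ones blend the four neighbors.
    """
    h, w = len(img), len(img[0])
    if y < -1.0 or y > h or x < -1.0 or x > w:
        return 0.0
    y, x = max(y, 0.0), max(x, 0.0)
    y0, x0 = int(y), int(x)
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    ly, lx = y - y0, x - x0          # fractional offsets
    hy, hx = 1.0 - ly, 1.0 - lx      # complementary weights
    return (hy * hx * img[y0][x0] + hy * lx * img[y0][x1]
            + ly * hx * img[y1][x0] + ly * lx * img[y1][x1])

# Example: the center of a 2x2 image averages all four pixels.
assert bilinear_sample([[0.0, 1.0], [2.0, 3.0]], 0.5, 0.5) == 1.5
```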

ezyang added 7 commits May 15, 2023 06:32, each Signed-off-by: Edward Z. Yang <ezyang@meta.com>
@ezyang ezyang merged commit fc838ad into main May 16, 2023
github-actions bot commented:

Hey @ezyang!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request May 16, 2023
Summary: Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Reviewed By: vmoens

Differential Revision: D45903818

fbshipit-source-id: 4d80da89da5e149f64c5fde2d31bf9c490232b91
4 participants