Assert RAFT input resolution is 128 x 128 or higher #7339
Conversation
Thank you for your pull request and welcome to our community. Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged. If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Thanks for the PR @ChristophReich1996 , LGTM with a minor modification below
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Thanks a lot @ChristophReich1996, LGTM. Let's just wait for the checks to be green before merging.
The RAFT tests fail due to the low resolution used in the tests. The issue (file size vs. resolution) is described here: Lines 1032 to 1044 in b030e93.
Can we go to 128 x 128 in this test?
Ah, I remember that now... That's unfortunate. I guess this was the reason we didn't hard-code a check like the one proposed in this PR. Our tests are pretty terrible overall and we've been thinking about revamping them for a long time now. I don't know when this will be done but unfortunately, I'm not sure we'll be able to change the size of the input for now. That would make the expected weights files too big, which would then increase the size of the repo. Sorry @ChristophReich1996 , unfortunately we might not be able to merge that PR right now 😞
Idea: Let's not check if the image is large enough, but rather check if the feature maps are large enough for the correlation block. As the test shows, whether the resolution is sufficient also depends on the correlation block parameters.
Sure @ChristophReich1996, that sounds good. Happy to review a proposal.
@NicolasHug the current version takes the spatial resolution of the feature maps and checks whether the downsampled version in the correlation block is larger than or equal to two. If this is not the case, an error is raised (vision/torchvision/models/optical_flow/raft.py, Lines 484 to 487 in 930fcf7). Additionally, the error message provides the lowest resolution required for the used RAFT configuration (vision/torchvision/models/optical_flow/raft.py, Lines 488 to 491 in 930fcf7).
Sorry for the late review @ChristophReich1996 and thanks for the changes. I made a comment below to hopefully simplify the computations, LMK what you think. I'd also suggest putting those checks within the `corr_block.build_pyramid()` method, instead of in the `forward()` pass. It's functionally equivalent, but it makes sense to put the pyramid-related checks within the pyramid class.
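As a rough illustration of what such a check inside the pyramid class could look like (the class and method names below are simplified stand-ins for illustration, not torchvision's actual implementation):

```python
class CorrBlockCheck:
    """Simplified stand-in for the correlation block's resolution check;
    names and structure are assumptions for illustration only."""

    def __init__(self, num_levels: int = 4):
        self.num_levels = num_levels

    def check_pyramid_input(self, fmap_shape):
        # fmap_shape: (N, C, H, W) of the encoder feature maps
        _, _, h, w = fmap_shape
        # The coarsest pyramid level downsamples the feature maps by
        # 2 ** (num_levels - 1); both spatial dims must stay >= 2 there.
        factor = 2 ** (self.num_levels - 1)
        if h // factor < 2 or w // factor < 2:
            # The encoder downsamples the input by 8, so the minimum
            # input resolution is 2 * factor * 8 (128 for num_levels=4).
            min_res = 2 * factor * 8
            raise ValueError(
                f"feature maps are too small for {self.num_levels} correlation "
                f"levels; the input resolution should be at least {min_res} x {min_res}"
            )
```

Running the check in `build_pyramid()` rather than `forward()` keeps the constraint next to the pyramid construction that actually imposes it.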
```python
_, _, h_fmap, w_fmap = fmap1.shape
if not (((h_fmap // 2**(self.corr_block.num_levels - 1))) < 2) and (
    ((w_fmap // 2**(self.corr_block.num_levels - 1))) < 2
):
    min_res = 2 * 2**(self.corr_block.num_levels - 1) * 8
    raise ValueError(
        f"input image resolution is too small image resolution should be at least {min_res} (h) and {min_res} (w), got {h} (h) and {w} (w)"
    )
```
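As an aside on the snippet above: `not` binds tighter than `and` in Python, so `not A and B` parses as `(not A) and B`. With A and B being the height and width checks, the error is only raised when the height is large enough and the width is too small; when both dimensions are too small, nothing is raised. A minimal demonstration of the operator binding:

```python
# A = height check failed, B = width check failed (both dims too small)
h_too_small = True
w_too_small = True

# `not A and B` parses as `(not A) and B`, so with both dims too small
# the condition is False and no error would be raised:
would_raise = not h_too_small and w_too_small
assert would_raise is False

# Raising when *either* dimension is too small needs `or` instead:
assert (h_too_small or w_too_small) is True
```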
It should be equivalent to simply check `h_fmap` and `w_fmap` against `min_res`? Something like:
Suggested change:

```python
_, _, h_fmap, w_fmap = fmap1.shape
min_res = 2 * 2**(self.corr_block.num_levels - 1) * 8
if (h_fmap, w_fmap) < (min_res, min_res):
    raise ValueError(...)
```
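One subtlety with the tuple comparison in the suggestion above: Python compares tuples lexicographically, not element-wise, so `(h_fmap, w_fmap) < (min_res, min_res)` can miss the case where only the width is too small. A short illustration (with `min_res = 16` as a stand-in value):

```python
min_res = 16  # stand-in threshold for illustration

# Lexicographic comparison only looks at the second element when the
# first elements are equal:
assert (8, 100) < (min_res, min_res)      # height too small -> caught
assert not (20, 5) < (min_res, min_res)   # only width too small -> missed

# An element-wise check catches both cases:
def too_small(h, w, m=min_res):
    return h < m or w < m

assert too_small(8, 100)
assert too_small(20, 5)
assert not too_small(16, 16)
```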
Hi @NicolasHug, thanks for the feedback. I agree that the check should be moved to `corr_block.build_pyramid()`. However, checking `h_fmap` and `w_fmap` against `min_res` is not equivalent to the current check: `min_res` is the minimum input image resolution, whereas `h_fmap` and `w_fmap` are the spatial dimensions of the feature maps produced by the encoder (input resolution // 8). But we can check for
```python
_, _, h_fmap, w_fmap = fmap1.shape
min_res = 2 * 2**(self.corr_block.num_levels - 1) * 8
if (min_res // 8, min_res // 8) <= (h_fmap, w_fmap):
    raise ValueError(...)
```
This is definitely cleaner than the current version. Let me draft a new version.
New version is ready :)
Thanks @ChristophReich1996. There were still some failures in the tests; I think something was wrong with the condition. I took the liberty to push something in 3f353ec - I think this is correct. LMK what you think.
This looks good to me, thanks!
Hey @NicolasHug! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
Thanks a lot for the PR @ChristophReich1996 !
Summary: Co-authored-by: Nicolas Hug <nicolashug@meta.com> Reviewed By: vmoens Differential Revision: D45903816 fbshipit-source-id: 0b8e8250dcc5abb5b27031fd342d4bd8b75a0bea
Hi torchvision community,

As described in issue #7338, the RAFT architecture is only suitable for input resolutions of 128 x 128 or higher (with the resolution divisible by 8). Otherwise, when performing grid_sample, a division by zero occurs in the lowest resolution stage, and the full output of the RAFT model will be `nan` for resolutions smaller than 128 x 128. This pull request adds a check that the input images have a resolution of 128 x 128 or higher to avoid `nan`s.
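For reference, the 128 x 128 figure follows directly from the default configuration: the encoder downsamples the input by 8, the correlation pyramid has 4 levels (so the coarsest level divides the feature maps by a further 2 ** 3), and the coarsest level needs at least 2 pixels per dimension. A quick sanity check (the function and parameter names here are illustrative, not part of the torchvision API):

```python
def min_input_resolution(num_levels: int = 4, encoder_stride: int = 8) -> int:
    # The coarsest correlation level downsamples the feature maps by
    # 2 ** (num_levels - 1); requiring >= 2 pixels there and undoing
    # the encoder stride gives the minimum input resolution.
    return 2 * 2 ** (num_levels - 1) * encoder_stride

print(min_input_resolution())  # 128 for the default RAFT configuration
```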