
Assert RAFT input resolution is 128 x 128 or higher #7339

Merged: 9 commits into pytorch:main, May 11, 2023

Conversation

ChristophReich1996
Contributor

Hi torchvision community,

As described in issue #7338, the RAFT architecture is only suitable for input resolutions of 128 x 128 or higher (with both dimensions divisible by 8). Otherwise, grid_sample performs a division by zero in the lowest-resolution stage, and the full output of the RAFT model becomes nan for resolutions smaller than 128 x 128.
This pull request adds a check that the input images have a resolution of 128 x 128 or higher, to avoid nans.
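For intuition, the 128 x 128 minimum follows from the downsampling arithmetic: RAFT's encoder reduces the input by a factor of 8, and the correlation pyramid halves the feature maps num_levels - 1 more times, so the coarsest level keeps a spatial extent of at least 2 only when each input dimension is at least 2 * 2**(num_levels - 1) * 8. A minimal sketch of this arithmetic, assuming the default num_levels=4 (the helper names are illustrative, not torchvision API):

```python
# Sketch of the resolution arithmetic behind the 128 x 128 minimum.
# Assumed defaults: encoder downsampling factor 8, correlation pyramid
# with num_levels=4.

def min_input_resolution(num_levels=4, downsample=8, min_extent=2):
    """Smallest input H/W for which the coarsest pyramid level
    still has a spatial extent of at least `min_extent`."""
    return min_extent * 2 ** (num_levels - 1) * downsample

def coarsest_extent(size, num_levels=4, downsample=8):
    """Spatial extent of the coarsest correlation-pyramid level."""
    return (size // downsample) // 2 ** (num_levels - 1)

print(min_input_resolution())  # 128
print(coarsest_extent(128))    # 2 -> valid
print(coarsest_extent(120))    # 1 -> grid_sample produces nans
```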

@facebook-github-bot

Hi @ChristophReich1996!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

Member

@NicolasHug NicolasHug left a comment


Thanks for the PR @ChristophReich1996, LGTM with a minor modification below

Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
@NicolasHug NicolasHug changed the title Assume RAFT input resolution is 128 x 128 or higher Assert RAFT input resolution is 128 x 128 or higher Feb 27, 2023
NicolasHug
NicolasHug previously approved these changes Feb 27, 2023
Member

@NicolasHug NicolasHug left a comment


Thanks a lot @ChristophReich1996, LGTM. Let's just wait for the checks to be green before merging.

@ChristophReich1996
Contributor Author

The RAFT tests fail due to the low resolution in the tests. The issue (file size vs res.) is described here:

vision/test/test_models.py

Lines 1032 to 1044 in b030e93

# We need very small images, otherwise the pickle size would exceed the 50KB
# As a result we need to override the correlation pyramid to not downsample
# too much, otherwise we would get nan values (effective H and W would be
# reduced to 1)
corr_block = models.optical_flow.raft.CorrBlock(num_levels=2, radius=2)
model = model_fn(corr_block=corr_block).eval().to("cuda")
if scripted:
    model = torch.jit.script(model)
bs = 1
img1 = torch.rand(bs, 3, 80, 72).cuda()
img2 = torch.rand(bs, 3, 80, 72).cuda()

Can we go to 128 x 128 in this test?
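For reference, a quick sketch of why the overridden CorrBlock(num_levels=2) keeps the 80 x 72 test images valid while the default num_levels=4 does not (the helper mirrors the downsampling arithmetic and is illustrative, not torchvision code):

```python
# Illustrative helper: spatial extent of the coarsest correlation-pyramid
# level, assuming the encoder downsamples the input by a factor of 8.
def coarsest_extent(size, num_levels, downsample=8):
    return (size // downsample) // 2 ** (num_levels - 1)

# Default num_levels=4: the 80 x 72 test images collapse to an extent of 1,
# which is what produces the nan values mentioned in the test comment.
print(coarsest_extent(80, num_levels=4))  # 1
# Overridden num_levels=2: the coarsest level keeps an extent >= 2.
print(coarsest_extent(80, num_levels=2))  # 5
print(coarsest_extent(72, num_levels=2))  # 4
```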

@NicolasHug
Member

Ah, I remember that now... That's unfortunate. I guess this was the reason we didn't hard-code a check like the one proposed in this PR.

Our tests are pretty terrible overall and we've been thinking about revamping them for a long time now. I don't know when this will be done but unfortunately, I'm not sure we'll be able to change the size of the input for now. That would make the expected weights files too big, which would then increase the size of the repo.

Sorry @ChristophReich1996 , unfortunately we might not be able to merge that PR right now 😞

@NicolasHug NicolasHug dismissed their stale review February 27, 2023 14:13

Broken test

@ChristophReich1996
Contributor Author

Idea: let's not check whether the image is large enough, but whether the feature maps are large enough for the correlation block, since, as in the test, whether nans occur also depends on the correlation block parameters. Let me know what you think!

self.corr_block.build_pyramid(fmap1, fmap2)

@NicolasHug
Member

Sure @ChristophReich1996 , that sounds good. Happy to review a proposal

@ChristophReich1996
Contributor Author

@NicolasHug the current version takes the spatial resolution of the feature maps and checks whether the downsampled version in the correlation block is larger than or equal to two. If this is not the case, grid_sample will produce nans.

_, _, h_fmap, w_fmap = fmap1.shape
if not ((h_fmap // 2 ** (self.corr_block.num_levels - 1)) < 2) and (
    (w_fmap // 2 ** (self.corr_block.num_levels - 1)) < 2
):

Additionally, the error message provides the lowest resolution required for the used RAFT configuration.
min_res = 2 * 2 ** (self.corr_block.num_levels - 1) * 8
raise ValueError(
    f"input image resolution is too small image resolution should be at least {min_res} (h) and {min_res} (w), got {h} (h) and {w} (w)"
)
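As a standalone illustration of what the proposed check is meant to do (a sketch with hypothetical names, where num_levels plays the role of self.corr_block.num_levels and the encoder is assumed to downsample by a factor of 8; the merged torchvision code may differ):

```python
# Hypothetical standalone version of the feature-map resolution check.
def check_fmap_resolution(h_fmap, w_fmap, num_levels):
    factor = 2 ** (num_levels - 1)
    # Either dimension collapsing below 2 at the coarsest pyramid level
    # would make grid_sample produce nans.
    if h_fmap // factor < 2 or w_fmap // factor < 2:
        min_res = 2 * factor * 8  # minimum *input image* resolution
        raise ValueError(
            f"input image resolution is too small; it should be at least "
            f"{min_res} (h) and {min_res} (w)"
        )

check_fmap_resolution(16, 16, num_levels=4)   # 128 x 128 input: passes
# check_fmap_resolution(15, 16, num_levels=4) # would raise ValueError
```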

@pytorch-bot

pytorch-bot bot commented May 10, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/7339

Note: Links to docs will display an error until the docs builds have been completed.

❌ 25 New Failures

As of commit 949c0b0:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Member

@NicolasHug NicolasHug left a comment


Sorry for the late review @ChristophReich1996 and thanks for the changes. I made a comment below to hopefully simplify the computations, LMK what you think. I'd also suggest to put those checks within the corr_block.build_pyramid() method, instead of in the forward() pass. It's functionally equivalent but it makes sense to put the pyramid-related checks within the pyramid class.

Comment on lines 484 to 491

_, _, h_fmap, w_fmap = fmap1.shape
if not ((h_fmap // 2 ** (self.corr_block.num_levels - 1)) < 2) and (
    (w_fmap // 2 ** (self.corr_block.num_levels - 1)) < 2
):
    min_res = 2 * 2 ** (self.corr_block.num_levels - 1) * 8
    raise ValueError(
        f"input image resolution is too small image resolution should be at least {min_res} (h) and {min_res} (w), got {h} (h) and {w} (w)"
    )
Member


It should be equivalent to simply check h_fmap and w_fmap against min_res? Something like:

Suggested change, from:

_, _, h_fmap, w_fmap = fmap1.shape
if not ((h_fmap // 2 ** (self.corr_block.num_levels - 1)) < 2) and (
    (w_fmap // 2 ** (self.corr_block.num_levels - 1)) < 2
):
    min_res = 2 * 2 ** (self.corr_block.num_levels - 1) * 8
    raise ValueError(
        f"input image resolution is too small image resolution should be at least {min_res} (h) and {min_res} (w), got {h} (h) and {w} (w)"
    )

to:

_, _, h_fmap, w_fmap = fmap1.shape
min_res = 2 * 2 ** (self.corr_block.num_levels - 1) * 8
if (h_fmap, w_fmap) < (min_res, min_res):
    raise ValueError(...)

Contributor Author


Hi @NicolasHug, thanks for the feedback. I agree that the check should be moved to corr_block.build_pyramid(). However, checking h_fmap and w_fmap against min_res is not equivalent to the current check: min_res is the minimum input image resolution, whereas h_fmap and w_fmap are the spatial dimensions of the feature maps produced by the encoder (input resolution // 8). But we can check for

_, _, h_fmap, w_fmap = fmap1.shape
min_res = 2 * 2 ** (self.corr_block.num_levels - 1) * 8
if (min_res // 8, min_res // 8) <= (h_fmap, w_fmap):
    raise ValueError(...)

this is definitely cleaner than the current version. Let me draft a new version.
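The scale difference can be illustrated numerically (a sketch with assumed defaults num_levels=4 and encoder downsampling factor 8, not the merged code):

```python
num_levels = 4
min_res = 2 * 2 ** (num_levels - 1) * 8  # 128: minimum *input* resolution
h_fmap = w_fmap = 128 // 8               # 16: feature maps of a valid input

# Comparing feature-map dimensions directly against min_res mixes scales
# and would flag a perfectly valid 128 x 128 input as too small:
print((h_fmap, w_fmap) < (min_res, min_res))            # True
# Dividing min_res by the encoder factor compares at feature-map scale:
print((h_fmap, w_fmap) < (min_res // 8, min_res // 8))  # False
# Caveat: Python compares tuples lexicographically, not element-wise, so
# (h, w) < (m, m) does not check both dimensions independently.
```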

Contributor Author


New version is ready :)

Member


Thanks @ChristophReich1996

There were still some failures in the tests; I think something was wrong with the condition. I took the liberty to push something in 3f353ec, which I think is correct. LMK what you think

Contributor Author


This looks good to me, thanks!

@NicolasHug NicolasHug merged commit b06ea39 into pytorch:main May 11, 2023
@github-actions

Hey @NicolasHug!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

@NicolasHug
Member

Thanks a lot for the PR @ChristophReich1996 !

facebook-github-bot pushed a commit that referenced this pull request May 16, 2023
Summary: Co-authored-by: Nicolas Hug <nicolashug@meta.com>

Reviewed By: vmoens

Differential Revision: D45903816

fbshipit-source-id: 0b8e8250dcc5abb5b27031fd342d4bd8b75a0bea