fix F.interpolate() for large batch sizes
#1006
Merged
Fixes #984.
It seems that `F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")` breaks for batch sizes > 64 when `hidden_states` uses the channels-last memory format. See pytorch/pytorch#81665 and #984.

This PR proposes to force a contiguous format for `hidden_states` when `bsz > 64`. Credits to @pcuenca for the find.

The following now works (after applying the memory efficient PR + this PR)
cc @pcuenca @patrickvonplaten @patil-suraj
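
For reference, the workaround described above can be sketched roughly as follows. This is a minimal illustration, not the actual diff of this PR: the helper name `upsample_nearest_2x` is hypothetical, and the threshold simply mirrors the `bsz > 64` condition stated in the description.

```python
import torch
import torch.nn.functional as F


def upsample_nearest_2x(hidden_states: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper illustrating the workaround (not the PR's code)."""
    # Assumption: the batch dimension is dim 0 and the threshold matches
    # the "bsz > 64" condition from the PR description.
    if hidden_states.shape[0] > 64:
        # Drop the channels-last layout before calling F.interpolate.
        hidden_states = hidden_states.contiguous()
    return F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")


if __name__ == "__main__":
    # Small spatial size so the sketch also runs quickly on CPU.
    x = torch.randn(65, 4, 8, 8).to(memory_format=torch.channels_last)
    out = upsample_nearest_2x(x)
    print(out.shape)
```

The sketch only demonstrates the layout conversion and the resulting shape; it does not reproduce the failure itself, which is discussed in pytorch/pytorch#81665.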