
ConvTranspose produces nonsense results for 3D (aka 5D) inputs #978

Closed
bcdarwin opened this issue Dec 27, 2019 · 4 comments

Comments

bcdarwin commented Dec 27, 2019

E.g.,

julia> ConvTranspose((2, 2, 2), 4=>8, stride=2)(rand(16, 16, 16, 4, 3))
# ERROR: DimensionMismatch("arrays could not be broadcast to a common size")

When the batch size is 1, the application 'succeeds' by coincidence due to broadcasting, but the output is the wrong size, as sketched below.
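For illustration, a hedged sketch of that batch-size-1 case (the expected shape is stated from the layer's definition; the actual wrong shape follows from the dimension swap analysed below):

julia> ConvTranspose((2, 2, 2), 4=>8, stride=2)(rand(16, 16, 16, 4, 1)) |> size
# runs without an error, but the channel and batch dims come out permuted
# before the bias broadcast, so the result is not the expected (32, 32, 32, 8, 1)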

The PR (#311) that added this operation explicitly noted that only 2D support was added, but the docstring (@doc) doesn't mention this, and in any case an error should be signalled when the functionality is missing.

Examining the definition reveals that the channel and batch dims somehow end up swapped in the 2D vs. 3D case:

using Flux: ConvTranspose, conv_transpose_dims   # conv_transpose_dims is a Flux internal
using NNlib: ∇conv_data

ct2 = ConvTranspose((2, 2),    4=>5, stride=2)
ct3 = ConvTranspose((2, 2, 2), 4=>5, stride=2)
mat2 = rand(    16, 16, 4, 3);   # W × H × channels × batch
mat3 = rand(16, 16, 16, 4, 3);   # W × H × D × channels × batch
∇conv_data(mat3, ct3.weight, conv_transpose_dims(ct3, mat3)) |> size
# (32, 32, 32, 3, 5)   <- channel and batch dims swapped
∇conv_data(mat2, ct2.weight, conv_transpose_dims(ct2, mat2)) |> size
# (32, 32, 5, 3)       <- expected layout
bcdarwin changed the title from 'ConvTranspose produces nonsense results for 3D ("5D") inputs' to 'ConvTranspose produces nonsense results for 3D (aka 5D) inputs' on Dec 28, 2019
bcdarwin (Author) commented

On some investigation, it looks like the issue is that the 2D case throws away the return value of the ∇conv_data call and thus misses the final transpose_swapbatch that the 3D case undergoes. However, the 2D result is the 'correct' one, so I'm not immediately sure how to restructure the code to fix this...
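For reference, a minimal sketch of the kind of batch/channel swap transpose_swapbatch performs (an illustration of the behaviour only, not NNlib's actual definition):

# Swap the last two dims (channel and batch) of an N-dimensional array.
swapbatch(x) = permutedims(x, (ntuple(identity, ndims(x) - 2)..., ndims(x), ndims(x) - 1))

size(swapbatch(zeros(32, 32, 32, 3, 5)))
# (32, 32, 32, 5, 3) -- applying the swap to the 3D result above recovers the expected layout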


bcdarwin commented Jan 2, 2020

Closing here since this is an NNlib issue.

bcdarwin closed this as completed on Jan 2, 2020

haampie commented Jan 27, 2020

Pinging @tejank10 for this issue as well; maybe he has some insights :)

OK, this is indeed in NNlib internals. For anyone coming across this issue: it may have happened because you mixed Float64 with Float32 (e.g. Float32 weights with Float64 input), since that takes a different code path. An all-Float32 setup does not hit this issue.
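A sketch of the workaround implied above, assuming Flux's default Float32 weights and the shapes from the repro at the top of the thread:

using Flux

x32 = rand(Float32, 16, 16, 16, 4, 3)   # input eltype matches the Float32 weights
ConvTranspose((2, 2, 2), 4=>8, stride=2)(x32) |> size
# (32, 32, 32, 8, 3) -- the all-Float32 path avoids the broken fallback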

bcdarwin (Author) commented

Moved here.
