diff --git a/docs/src/reference.md b/docs/src/reference.md index 69dc7ed5c..f4ae5fdbf 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -71,6 +71,8 @@ ConvDims depthwiseconv DepthwiseConvDims DenseConvDims +unfold +fold ``` ## Upsampling diff --git a/src/NNlib.jl b/src/NNlib.jl index 0fd1f4cbc..19658d2be 100644 --- a/src/NNlib.jl +++ b/src/NNlib.jl @@ -61,6 +61,8 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter, include("conv_bias_act.jl") export conv_bias_act, conv_bias_act! +include("fold.jl") + include("ctc.jl") export ctc_loss diff --git a/src/fold.jl b/src/fold.jl new file mode 100644 index 000000000..dcde60e4f --- /dev/null +++ b/src/fold.jl @@ -0,0 +1,199 @@ + +""" + unfold(x, kernel_size; stride = 1, pad = 0, dilation = 0, flipped = true) + +Places sliding windows of x into a container tensor of size `(num_windows, +window_size, batchsize)`. The window size is determined by the `prod(spatial dims +of kernel)*input_channels`. The number of sliding windows will match those of +convolution (`conv`) with the same kernel_size and arguments. Note that +by default `conv` flips the spatial dimensions of its kernel (default +`flipped=false`), whereas `unfold` does not (default `flipped=true`). +Uses `NNlib.im2col!` as backend. + +See also [`fold`](@ref), the adjoint/transpose operator +and a potential inverse of `unfold`. + +# Example +The below example demonstrates that `unfold` uses the same sliding windows as `conv`. +In general [`batched_mul`](@ref) + `unfold` should not be used to achieve convolution. +```jldoctest +julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1); # 1D data, 1 channel, batch of 1 + +julia> w = reshape([1 0 -1], 3, 1, 1); # 1D conv kernel of length 3 + +julia> kws = (pad=1, stride=2, flipped=true); # use same args for conv and unfold + +julia> z = NNlib.unfold(x, size(w); kws...) +4×3×1 Array{Int64, 3}: +[:, :, 1] = + 0 100 2 + 2 3 40 + 40 5 6 + 6 700 0 + +julia> y1 = conv(x, w; kws...) +4×1×1 Array{Int64, 3}: +[:, :, 1] = + -2 + -38 + 34 + 6 + +julia> y2 = z ⊠ w # ⊠ (\\boxtimes) is NNlib.batched_mul +4×1×1 Array{Int64, 3}: +[:, :, 1] = + -2 + -38 + 34 + 6 +``` +""" +function unfold(x::AbstractArray{T, N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N} + stride = expand(Val(N - 2), stride) + padding = expand(Val(N - 2), pad) + dilation = expand(Val(N - 2), dilation) + cdims = DenseConvDims(size(x), kernel_size; stride, padding, dilation, flipkernel=flipped) + return unfold(x, cdims) +end + +""" + fold(y, output_size, kernel_size; stride = 1, pad = 0, dilation = 0, flipped = true) + +The adjoint/transpose operator of `unfold`. It accumulates sliding windows from +the output of `unfold` into a container tensor of size `output_size`. An inverse +to `unfold` may be obtained (in some cases) by using `fold` and accounting for scaling issues +with a divisor (see example). Uses `NNlib.col2im!` as backend. + +See also [`unfold`](@ref). + +# Example +```jldoctest +julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1); # 1D data, 1 channel, batch of 1 + +julia> y = NNlib.unfold(x, (3,1,1)) # sliding window of size 3 +5×3×1 Array{Int64, 3}: +[:, :, 1] = + 100 2 3 + 2 3 40 + 3 40 5 + 40 5 6 + 5 6 700 + +julia> z = NNlib.fold(y, size(x), (3,1,1)) # sum of contributions in y. 100 appears once, 40 three times +7×1×1 Array{Int64, 3}: +[:, :, 1] = + 100 + 4 + 9 + 120 + 15 + 12 + 700 + +julia> divisor = NNlib.fold(NNlib.unfold(ones(size(x)...), (3,1,1)), size(x), (3,1,1)) +7×1×1 Array{Float64, 3}: +[:, :, 1] = + 1.0 + 2.0 + 3.0 + 3.0 + 3.0 + 2.0 + 1.0 + +julia> z ./ divisor +7×1×1 Array{Float64, 3}: +[:, :, 1] = + 100.0 + 2.0 + 3.0 + 40.0 + 5.0 + 6.0 + 700.0 +``` +In general, an inverse to `unfold` does not exist if `divisor` contains zeros. +""" +function fold(x::AbstractArray{T, 3}, output_size::NTuple{N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N} + stride = expand(Val(N - 2), stride) + padding = expand(Val(N - 2), pad) + dilation = expand(Val(N - 2), dilation) + cdims = DenseConvDims(output_size, kernel_size; stride, padding, dilation, flipkernel=flipped) + return fold(x, output_size, cdims) +end + +# im2col_dims returns (numblocks, blocksize, threadnum) where thread dim is used as thread-local +# workspace for multithreaded conv. Ultimately, we want to threadnum with batchsize. +unfold_dims(cdims::DenseConvDims) = im2col_dims(cdims)[1:2] + +# auto-allocating versions +function unfold(x::AbstractArray{T, N}, cdims::DenseConvDims) where {T, N} + y = similar(x, unfold_dims(cdims)..., size(x, N)) # (numblocks, blocksize, batchsize) + return unfold!(y, x, cdims) +end + +function fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T} + x = similar(y, output_size) + return fold!(x, y, cdims) +end + +# N < 5 -dimension in-place versions +function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N} + unfold!( + y, + insert_singleton_spatial_dimension(x, 5-N), + insert_singleton_spatial_dimension(cdims, 5-N), + ) + return y +end + +function fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N} + fold!( + insert_singleton_spatial_dimension(x, 5-N), + y, + insert_singleton_spatial_dimension(cdims, 5-N), + ) + return x +end + +# 5-dimension in-place versions +function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 5}, cdims::DenseConvDims) where {yT, xT} + @threads for batch_idx in 1:size(x, 5) + y_slice = view(y, :, :, batch_idx) + im2col!(y_slice, view(x, :, :, :, :, batch_idx), cdims) + end + return y +end + +function fold!(x::AbstractArray{xT, 5}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {xT, yT} + @threads for batch_idx in 1:size(x, 5) + y_slice = view(y, :, :, batch_idx) + col2im!(view(x, :, :, :, :, batch_idx), y_slice, cdims) + end + return x +end + +# reverse diff rules +function rrule(::typeof(unfold), x, cdims::DenseConvDims; kw...) + function unfold_pullback(Δ) + return ( + NoTangent(), + fold(unthunk(Δ), size(x), cdims; kw...), + NoTangent(), + ) + end + return unfold(x, cdims; kw...), unfold_pullback +end + +function rrule(::typeof(fold), x, output_size, cdims::DenseConvDims; kw...) + function fold_pullback(Δ) + return ( + NoTangent(), + unfold(unthunk(Δ), cdims; kw...), + NoTangent(), + NoTangent(), + ) + end + return fold(x, output_size, cdims; kw...), fold_pullback +end + diff --git a/test/fold.jl b/test/fold.jl new file mode 100644 index 000000000..35296fb47 --- /dev/null +++ b/test/fold.jl @@ -0,0 +1,40 @@ +using NNlib, Test + +@testset "unfold wrapper" begin + x = rand(rng, 16, 16, 3, 10) + w = rand(rng, 5, 5, 3, 2) + @test size(NNlib.unfold(x, size(w))) == (144, 75, 10) + @test size(NNlib.unfold(x, size(w); pad=2)) == (256, 75, 10) + @test size(NNlib.unfold(x, size(w); stride=2)) == (36, 75, 10) + @test size(NNlib.unfold(x, size(w); dilation=2)) == (64, 75, 10) +end + +@testset "Inverses: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([8], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + cdims = DenseConvDims(x, w; padding=1) + y = NNlib.unfold(x, cdims) + z = NNlib.fold(y, size(x), cdims) + divisor = NNlib.fold(NNlib.unfold(ones(eltype(x), size(x)...), cdims), size(x), cdims) + @test isapprox(z ./ divisor, x, rtol=1.0e-7) + + # introduce stride + cdims = DenseConvDims(x, w; padding=1, stride=2) + y = NNlib.unfold(x, cdims) + z = NNlib.fold(y, size(x), cdims) + divisor = NNlib.fold(NNlib.unfold(ones(eltype(x), size(x)...), cdims), size(x), cdims) + @test isapprox(z ./ divisor, x, rtol=1.0e-7) +end + +@testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) + x = rand(rng, repeat([5], spatial_rank)..., 3, 2) + w = rand(rng, repeat([3], spatial_rank)..., 3, 3) + cdims = DenseConvDims(x, w) + gradtest(x -> NNlib.unfold(x, cdims), x) + test_rrule(NNlib.unfold, x, cdims) + + y = NNlib.unfold(x, cdims) + gradtest(y -> NNlib.fold(y, size(x), cdims), y) + test_rrule(NNlib.fold, y, size(x), cdims) +end + diff --git a/test/runtests.jl b/test/runtests.jl index afb2c0802..22f51779f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,6 +52,10 @@ include("test_utils.jl") include("ctc.jl") end + @testset "Fold/Unfold" begin + include("fold.jl") + end + @testset "Inference" begin include("inference.jl") end