Skip to content

Commit

Permalink
use gather with onehot input
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Jul 13, 2021
1 parent e54440b commit a9618af
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
9 changes: 7 additions & 2 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,16 @@ end

Embedding(in::Integer, out::Integer; init = randn32) = Embedding(init(out, in))

(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
(m::Embedding)(x::Integer) = m([x])

(m::Embedding)(x::Integer) = m.weight[:, x]
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)

function (m::Embedding)(x::Union{OneHotVector{T,L}, OneHotMatrix{T,L}}) where {T,L}
size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L"))
return m(onecold(x))
end

function Base.show(io::IO, m::Embedding)
print(io, "Embedding($(size(m.weight, 2)), $(size(m.weight, 1)))")
end
22 changes: 15 additions & 7 deletions test/cuda/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,21 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
# test
if test_cpu
@test y_gpu y_cpu rtol=1f-3 atol=1f-3
@test Array(xg_gpu) xg_cpu rtol=1f-3 atol=1f-3
if isnothing(xg_cpu)
@test isnothing(xg_gpu)
else
@test Array(xg_gpu) xg_cpu rtol=1f-3 atol=1f-3
end
end
@test gs_gpu isa Flux.Zygote.Grads
for (p_cpu, p_gpu) in zip(ps_cpu, ps_gpu)
@test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
if test_cpu
@test Array(gs_gpu[p_gpu]) gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
if isnothing(xg_cpu)
@test isnothing(xg_gpu)
else
@test gs_gpu[p_gpu] isa Flux.CUDA.CuArray
if test_cpu
@test Array(gs_gpu[p_gpu]) gs_cpu[p_cpu] rtol=1f-3 atol=1f-3
end
end
end
end
Expand Down Expand Up @@ -114,14 +122,14 @@ pixelshuffle = [PixelShuffle]
gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)

embedding = [Embedding]
embedding = [Flux.Embedding]
gpu_gradtest("Embedding", embedding, [1,3,5], 5, 2)
gpu_gradtest("Embedding repeated indices", embedding, [1,3,5,3], 5, 2)
gpu_gradtest("Embedding integer index", embedding, 1, 5, 2)
gpu_gradtest("Embedding 2d index", embedding, [1 2; 3 4], 5, 2)
gpu_gradtest("Embedding OneHotVec index", embedding, OneHotVector(1, 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix index", embedding, OneHotMatrix([1,2,3], 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix repeated indices", OneHotMatrix([1,2,2], 5), 5, 2)
gpu_gradtest("Embedding OneHotMatrix repeated indices", embedding, OneHotMatrix([1,2,2], 5), 5, 2)

@testset "function layers" begin
x = rand(Float32, 3,3)
Expand All @@ -144,7 +152,7 @@ end
end

@testset "Dense with Zeros bias" begin
l = Dense(ones(Float32, 4,3), Flux.Zeros()) |> gpu
l = Dense(ones(Float32, 4, 3), Flux.Zeros()) |> gpu
ip = zeros(Float32, 3, 7) |> gpu

@test sum(l(ip)) 0.f0
Expand Down

0 comments on commit a9618af

Please sign in to comment.