diff --git a/Project.toml b/Project.toml index 763e4d921a..c444870dcc 100644 --- a/Project.toml +++ b/Project.toml @@ -54,6 +54,7 @@ MPI = "0.20.19" MacroTools = "0.5" NCCL = "0.1.1" NNlib = "0.9.22" +Metal = "0.5, 1" OneHotArrays = "0.2.4" Optimisers = "0.3.3" Preferences = "1" diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 254f06db0c..c0ba62ed22 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -170,9 +170,8 @@ end function (a::Dense)(x::AbstractVecOrMat) _size_check(a, x, 1 => size(a.weight, 2)) - σ = NNlib.fast_act(a.σ, x) # replaces tanh => tanh_fast, etc xT = _match_eltype(a, x) # fixes Float64 input, etc. - return σ.(a.weight * xT .+ a.bias) + NNlib.bias_act!(a.σ, a.weight * xT, a.bias) # does σ.(W*x .+ b), with fast paths end function (a::Dense)(x::AbstractArray) @@ -450,7 +449,7 @@ function (a::Bilinear)(x::AbstractMatrix, y::AbstractMatrix) Z = reshape(Wyx, (d_z, :)) # @einsum out[o,s] := σ(Z[o,i] + b[o]) - σ.(Z .+ b) + NNlib.bias_act!(σ, Z, b) # σ.(Z .+ b) end (a::Bilinear)(x::AbstractVecOrMat) = a(x, x) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 8ba07b95a8..d09d0bcd69 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -196,10 +196,9 @@ ChainRulesCore.@non_differentiable conv_dims(::Any, ::Any) function (c::Conv)(x::AbstractArray) _conv_size_check(c, x) - σ = NNlib.fast_act(c.σ, x) cdims = conv_dims(c, x) xT = _match_eltype(c, x) - σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c)) + NNlib.bias_act!(c.σ, conv(xT, c.weight, cdims), conv_reshape_bias(c)) end _channels_in(l::Conv) = size(l.weight, ndims(l.weight)-1) * l.groups @@ -350,10 +349,9 @@ ChainRulesCore.@non_differentiable conv_transpose_dims(::Any, ::Any) function (c::ConvTranspose)(x::AbstractArray) _conv_size_check(c, x) - σ = NNlib.fast_act(c.σ, x) cdims = conv_transpose_dims(c, x) xT = _match_eltype(c, x) - σ.(∇conv_data(xT, c.weight, cdims) .+ conv_reshape_bias(c)) + NNlib.bias_act!(c.σ, ∇conv_data(xT, c.weight, cdims), conv_reshape_bias(c)) end function Base.show(io::IO, l::ConvTranspose) @@ -493,10 +491,9 @@ ChainRulesCore.@non_differentiable crosscor_dims(::Any, ::Any) function (c::CrossCor)(x::AbstractArray) _conv_size_check(c, x) - σ = NNlib.fast_act(c.σ, x) cdims = crosscor_dims(c, x) xT = _match_eltype(c, x) - σ.(crosscor(xT, c.weight, cdims) .+ conv_reshape_bias(c)) + NNlib.bias_act!(c.σ, crosscor(xT, c.weight, cdims), conv_reshape_bias(c)) end function Base.show(io::IO, l::CrossCor) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 99092f9756..9d294e3e6e 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -245,7 +245,7 @@ function _norm_layer_forward( β = reshape(l.β, affine_shape) scale = γ ./ sqrt.(σ² .+ eps) - bias = -scale .* μ .+ β + bias = .-scale .* μ .+ β l.λ.(scale .* x .+ bias) end