Skip to content

Commit

Permalink
use gather; fix outdated docs
Browse files Browse the repository at this point in the history
Co-authored-by: Manikya <manikyabard@gmail.com>
  • Loading branch information
CarloLucibello and manikyabard committed Jul 10, 2021
1 parent 4d3944c commit 9553267
Show file tree
Hide file tree
Showing 10 changed files with 72 additions and 64 deletions.
56 changes: 28 additions & 28 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ version = "3.2.1"

[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "720fa9a9ce61ff18842a40f501d6a1f8ba771c64"
git-tree-sha1 = "85c579fa131b5545eef874a5b413bb3b783e21c6"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.8.6"
version = "0.8.21"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
git-tree-sha1 = "8b31cc69cbc38c5c826aaa1c890c694be3622d99"
git-tree-sha1 = "dcc25ff085cf548bc8befad5ce048391a7c07d40"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.10.3"
version = "0.10.11"

[[CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
Expand Down Expand Up @@ -87,18 +87,18 @@ version = "0.3.0"

[[Compat]]
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
git-tree-sha1 = "e4e2b39db08f967cc1360951f01e8a75ec441cab"
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "3.30.0"
version = "3.31.0"

[[CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"

[[DataAPI]]
git-tree-sha1 = "dfb3b7e89e395be1e25c2ad6d7690dc29cc53b1d"
git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.6.0"
version = "1.7.0"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
Expand Down Expand Up @@ -141,15 +141,15 @@ deps = ["ArgTools", "LibCURL", "NetworkOptions"]
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"

[[ExprTools]]
git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e"
git-tree-sha1 = "b7e3d17636b348f005f11040025ae8c6f645fe92"
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
version = "0.1.3"
version = "0.1.6"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "31939159aeb8ffad1d4d8ee44d07f8558273120a"
git-tree-sha1 = "693210145367e7685d8604aee33d9bfb85db8b31"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.7"
version = "0.11.9"

[[FixedPointNumbers]]
deps = ["Statistics"]
Expand Down Expand Up @@ -183,9 +183,9 @@ version = "0.11.5"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510"
git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.4.2"
version = "0.4.3"

[[IfElse]]
git-tree-sha1 = "28e837ff3e7a6c3cdb252ce49fb412c8eb3caeef"
Expand All @@ -210,9 +210,9 @@ version = "0.8.4"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "b499c68a45249b0385585c62f4a9b62b5db8e691"
git-tree-sha1 = "f57ac3fd2045b50d3db081663837ac5b4096947e"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "3.7.1"
version = "3.9.0"

[[LazyArtifacts]]
deps = ["Artifacts", "Pkg"]
Expand Down Expand Up @@ -243,9 +243,9 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[[LogExpFunctions]]
deps = ["DocStringExtensions", "LinearAlgebra"]
git-tree-sha1 = "1ba664552f1ef15325e68dc4c05c3ef8c2d5d885"
git-tree-sha1 = "7bd5f6565d80b6bf753738d2bc40a5dfea072070"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.2.4"
version = "0.2.5"

[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand Down Expand Up @@ -290,9 +290,9 @@ uuid = "14a3606d-f60d-562e-9121-12d972cd8159"

[[NNlib]]
deps = ["Adapt", "ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
git-tree-sha1 = "0bf1fbb9dc557f2af9fb7e1337366d69de0dc78c"
git-tree-sha1 = "7e6f31cfa39b1ff1c541cc8580b14b0ff4ba22d0"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.7.21"
version = "0.7.23"

[[NNlibCUDA]]
deps = ["CUDA", "LinearAlgebra", "NNlib", "Random", "Statistics"]
Expand Down Expand Up @@ -347,9 +347,9 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[Random123]]
deps = ["Libdl", "Random", "RandomNumbers"]
git-tree-sha1 = "7c6710c8198fd4444b5eb6a3840b7d47bd3593c5"
git-tree-sha1 = "0e8b146557ad1c6deb1367655e052276690e71a3"
uuid = "74087812-796a-5b5d-8853-05524746bad3"
version = "1.3.1"
version = "1.4.2"

[[RandomNumbers]]
deps = ["Random", "Requires"]
Expand Down Expand Up @@ -389,9 +389,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SortingAlgorithms]]
deps = ["DataStructures"]
git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96"
git-tree-sha1 = "b3363d7460f7d098ca0912c69b082f75625d7508"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
version = "1.0.0"
version = "1.0.1"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
Expand All @@ -411,9 +411,9 @@ version = "0.2.5"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "42378d3bab8b4f57aa1ca443821b752850592668"
git-tree-sha1 = "a43a7b58a6e7dc933b2fa2e0ca653ccf8bb8fd0e"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.2.2"
version = "1.2.6"

[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
Expand Down Expand Up @@ -444,9 +444,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[[TimerOutputs]]
deps = ["ExprTools", "Printf"]
git-tree-sha1 = "bf8aacc899a1bd16522d0350e1e2310510d77236"
git-tree-sha1 = "209a8326c4f955e2442c07b56029e88bb48299c7"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
version = "0.5.9"
version = "0.5.12"

[[TranscodingStreams]]
deps = ["Random", "Test"]
Expand Down
2 changes: 1 addition & 1 deletion docs/src/gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ If you define a structured model, like a `Dense` layer or `Chain`, you just need
```julia
d = Dense(10, 5, σ)
d = fmap(cu, d)
d.W # CuArray
d.weight # CuArray
d(cu(rand(10))) # CuArray output

m = Chain(Dense(10, 5, σ), Dense(5, 2), softmax)
Expand Down
2 changes: 1 addition & 1 deletion docs/src/models/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ by simply deleting it from `ps`:

```julia
ps = params(m)
delete!(ps, m[2].b)
delete!(ps, m[2].bias)
```

## Custom multiple input or output layer
Expand Down
7 changes: 7 additions & 0 deletions docs/src/models/nnlib.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,10 @@ NNlib.batched_mul!
NNlib.batched_adjoint
NNlib.batched_transpose
```

## Gather and Scatter

```@docs
NNlib.gather
NNlib.scatter
```
38 changes: 19 additions & 19 deletions docs/src/models/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Here's how you'd use Flux to build and train the most basic of models, step by s

This example will predict the output of the function `4x + 2`. First, import `Flux` and define the function we want to simulate:

```
```julia
julia> using Flux

julia> actual(x) = 4x + 2
Expand All @@ -28,7 +28,7 @@ This example will build a model to approximate the `actual` function.

Use the `actual` function to build sets of data for training and verification:

```
```julia
julia> x_train, x_test = hcat(0:5...), hcat(6:10...)
([0 1 4 5], [6 7 9 10])

Expand All @@ -42,38 +42,38 @@ Normally, your training and test data come from real world observations, but thi

Now, build a model to make predictions with `1` input and `1` output:

```
```julia
julia> model = Dense(1, 1)
Dense(1, 1)

julia> model.W
1-element Array{Float64,1}:
-0.99009055
julia> model.weight
1×1 Matrix{Float32}:
-1.4925033

julia> model.b
1-element Array{Float64,1}:
julia> model.bias
1-element Vector{Float32}:
0.0
```

Under the hood, a dense layer is a struct with fields `W` and `b`. `W` represents a weight and `b` represents a bias. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:
Under the hood, a dense layer is a struct with fields `weight` and `bias`. `weight` represents a weights' matrix and `bias` represents a bias vector. There's another way to think about a model. In Flux, *models are conceptually predictive functions*:

```
```julia
julia> predict = Dense(1, 1)
```

`Dense(1, 1)` also implements the function `σ(Wx+b)` where `W` and `b` are the weights and biases. `σ` is an activation function (more on activations later). Our model has one weight and one bias, but typical models will have many more. Think of weights and biases as knobs and levers Flux can use to tune predictions. Activation functions are transformations that tailor models to your needs.

This model will already make predictions, though not accurate ones yet:

```
```julia
julia> predict(x_train)
1×6 Array{Float32,2}:
-1.98018 -5.94054 -9.90091 -13.8613 -17.8216 -21.782
1×6 Matrix{Float32}:
0.0 -1.4925 -2.98501 -4.47751 -5.97001 -7.46252
```

In order to make better predictions, you'll need to provide a *loss function* to tell Flux how to objectively *evaluate* the quality of a prediction. Loss functions compute the cumulative distance between actual values and predictions.

```
```julia
julia> loss(x, y) = Flux.Losses.mse(predict(x), y)
loss (generic function with 1 method)

Expand All @@ -87,7 +87,7 @@ More accurate predictions will yield a lower loss. You can write your own loss f

Under the hood, the Flux [`train!`](@ref) function uses *a loss function* and *training data* to improve the *parameters* of your model based on a pluggable [`optimiser`](../training/optimisers.md):

```
```julia
julia> using Flux: train!

julia> opt = Descent()
Expand All @@ -100,12 +100,12 @@ julia> data = [(x_train, y_train)]

Now, we have the optimiser and data we'll pass to `train!`. All that remains are the parameters of the model. Remember, each model is a Julia struct with a function and configurable parameters. Remember, the dense layer has weights and biases that depend on the dimensions of the inputs and outputs:

```
julia> predict.W
```julia
julia> predict.weight
1-element Array{Float64,1}:
-0.99009055

julia> predict.b
julia> predict.bias
1-element Array{Float64,1}:
0.0
```
Expand All @@ -120,7 +120,7 @@ Params([[-0.99009055], [0.0]])
These are the parameters Flux will change, one step at a time, to improve predictions. Each of the parameters comes from the `predict` model:

```
julia> predict.W in parameters, predict.b in parameters
julia> predict.weight in parameters, predict.bias in parameters
(true, true)
```
Expand Down
4 changes: 2 additions & 2 deletions docs/src/models/regularisation.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ m = Dense(10, 5)
loss(x, y) = logitcrossentropy(m(x), y)
```

We can apply L2 regularisation by taking the squared norm of the parameters , `m.W` and `m.b`.
We can apply L2 regularisation by taking the squared norm of the parameters , `m.weight` and `m.bias`.

```julia
penalty() = sum(abs2, m.W) + sum(abs2, m.b)
penalty() = sum(abs2, m.weight) + sum(abs2, m.bias)
loss(x, y) = logitcrossentropy(m(x), y) + penalty()
```

Expand Down
3 changes: 2 additions & 1 deletion src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ function Embedding(in::Integer, out::Integer;
end

(m::Embedding)(x::Union{OneHotVector, OneHotMatrix}) = m.weight * x # equivalent to m.weight[:,onecold(x)]
(m::Embedding)(x::Union{Int,AbstractVector}) = m.weight[:, x]
(m::Embedding)(x::Integer) = m([x])
(m::Embedding)(x::AbstractVector) = NNlib.gather(m.weight, x)
(m::Embedding)(x::AbstractArray) = reshape(m(vec(x)), :, size(x)...)

function Base.show(io::IO, m::Embedding)
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@r
julia> layer = Dense(10, 20)
Dense(10, 20)
julia> Flux.nfan(size(layer.W))
julia> Flux.nfan(size(layer.weight))
(10, 20)
julia> layer = Conv((3, 3), 2=>10)
Expand Down
2 changes: 1 addition & 1 deletion test/cuda/layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ end

@test sum(l(ip)) 0.f0
gs = gradient(() -> sum(l(ip)), Flux.params(l))
@test l.b gs.params
@test l.bias gs.params
end

@testset "Extended BatchNorm" begin
Expand Down
20 changes: 10 additions & 10 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,19 +226,19 @@ end
m = Chain(Dense(10, 5, relu), Dense(5, 2))
x64 = rand(Float64, 10)
x32 = rand(Float32, 10)
@test eltype(m[1].W) == Float32
@test eltype(m[1].weight) == Float32
@test eltype(m(x32)) == Float32
@test eltype(m(x64)) == Float64
@test eltype(f64(m)(x32)) == Float64
@test eltype(f64(m)(x64)) == Float64
@test eltype(f64(m)[1].W) == Float64
@test eltype(f32(f64(m))[1].W) == Float32
@test eltype(f64(m)[1].weight) == Float64
@test eltype(f32(f64(m))[1].weight) == Float32
end

@testset "Zeros" begin
m = Dense(3,2; bias=false)
@test f64(m).b === m.b === Zeros()
@test f32(m).b === m.b === Zeros()
@test f64(m).bias === m.bias === Zeros()
@test f32(m).bias === m.bias === Zeros()

@testset "Gradients for broadcasted $op with sizes $s" for op in (+,-,*), s in ((1,), (2,3))
o = ones(s)
Expand Down Expand Up @@ -340,19 +340,19 @@ end

nobias(n) = Zeros()
testdense(m, bt) = @testset "Check layer $i" for (i, (l1, l2)) in enumerate(zip(m, dm(bt)))
@test l1.W == l2.W
@test l1.b == l2.b
@test_skip typeof(l1.b) === typeof(l2.b)
@test l1.weight == l2.weight
@test l1.bias == l2.bias
@test_skip typeof(l1.bias) === typeof(l2.bias)
end

@testset "loadparams!" begin
import Flux: loadparams!
pars(w, b) = [w, b]
import Flux: loadparams!, Zeros
pars(w, b::Zeros) = [w, Flux.zeros(size(w,1))]
pars(l) = pars(l.W, l.b)
pars(l) = pars(l.weight, l.bias)
pararray(m) = mapreduce(pars, vcat, m)
weights(m) = mapreduce(l -> [l.W], vcat, m)
weights(m) = mapreduce(l -> [l.weight], vcat, m)
@testset "Bias type $bt" for bt in (Flux.zeros, nobias)
m = dm(bt)
loadparams!(m, params(m))
Expand Down

0 comments on commit 9553267

Please sign in to comment.