Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

port DataLoader from Flux #22

Merged
merged 4 commits into from
Jan 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using Random
using Statistics
using ShowCases: ShowLimit
import StatsBase: sample
using Base: @propagate_inbounds
using Random: AbstractRNG, shuffle!, GLOBAL_RNG

include("observation.jl")
export numobs,
Expand All @@ -24,6 +26,9 @@ include("dataiterator.jl")
export eachobs,
eachbatch

include("dataloader.jl")
export DataLoader

include("folds.jl")
export kfolds,
leavepout
Expand Down
113 changes: 113 additions & 0 deletions src/dataloader.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Adapted from Knet's src/data.jl (author: Deniz Yuret)

struct DataLoader{D,R<:AbstractRNG}
    data::D              # original dataset, preserved exactly as passed by the user
    batchsize::Int       # observations per mini-batch (clamped to nobs by the constructor)
    nobs::Int            # total number of observations in `data` (numobs(data))
    partial::Bool        # if false, a trailing batch smaller than `batchsize` is dropped
    imax::Int            # iteration stops once the consumed count reaches this (see Base.iterate)
    indices::Vector{Int} # observation index permutation; shuffled in place each epoch if `shuffle`
    shuffle::Bool        # whether to reshuffle `indices` at the start of every epoch
    rng::R               # RNG used for the shuffling
end

"""
DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)

An object that iterates over mini-batches of `data`,
each mini-batch containing `batchsize` observations
(except possibly the last one).

Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
The last dimension in each tensor is the observation dimension, i.e. the one
divided into mini-batches.

If `shuffle=true`, it shuffles the observations each time iterations are re-started.
If `partial=false` and the number of observations is not divisible by the batchsize,
then the last mini-batch is dropped.

The original data is preserved in the `data` field of the DataLoader.

# Examples

```jldoctest
julia> Xtrain = rand(10, 100);

julia> array_loader = DataLoader(Xtrain, batchsize=2);

julia> for x in array_loader
@assert size(x) == (10, 2)
# do something with x, 50 times
end

julia> array_loader.data === Xtrain
true

julia> tuple_loader = DataLoader((Xtrain,), batchsize=2); # similar, but yielding 1-element tuples

julia> for x in tuple_loader
@assert x isa Tuple{Matrix}
@assert size(x[1]) == (10, 2)
end

julia> Ytrain = rand('a':'z', 100); # now make a DataLoader yielding 2-element named tuples

julia> train_loader = DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);

julia> for epoch in 1:100
for (x, y) in train_loader # access via tuple destructuring
@assert size(x) == (10, 5)
@assert size(y) == (5,)
# loss += f(x, y) # etc, runs 100 * 20 times
end
end

julia> first(train_loader).label isa Vector{Char} # access via property name
true

julia> first(train_loader).label == Ytrain[1:5] # because of shuffle=true
false

julia> foreach(println∘summary, DataLoader(rand(Int8, 10, 64), batchsize=30)) # partial=false would omit last
10×30 Matrix{Int8}
10×30 Matrix{Int8}
10×4 Matrix{Int8}
```
"""
function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
    # Validate up front so the struct invariants (positive batchsize) always hold.
    if batchsize <= 0
        throw(ArgumentError("Need positive batchsize"))
    end

    n = numobs(data)
    if n < batchsize
        @warn "Number of observations less than batchsize, decreasing the batchsize to $n"
        batchsize = n
    end
    # With `partial`, a short trailing batch is yielded, so any start index up to
    # `n` is valid; otherwise the last batch must fit entirely.
    imax = partial ? n : n - batchsize + 1
    return DataLoader(data, batchsize, n, partial, imax, collect(1:n), shuffle, rng)
end

# returns data in d.indices[i+1:i+batchsize]
# Iteration protocol: the state is the number of observations already consumed.
# Each step yields the data at `d.indices[state+1 : state+batchsize]` (the final
# batch may be shorter); `d.imax` encodes where iteration must stop.
@propagate_inbounds function Base.iterate(d::DataLoader, consumed=0)
    consumed >= d.imax && return nothing
    # Reshuffle the index permutation once, at the start of each fresh epoch.
    if d.shuffle && consumed == 0
        shuffle!(d.rng, d.indices)
    end
    upto = min(consumed + d.batchsize, d.nobs)
    batch = getobs(d.data, d.indices[(consumed + 1):upto])
    return batch, upto
end

# Number of batches yielded per epoch.
# Uses exact integer division (`cld`/`fld`) instead of Float64 division followed
# by `ceil`/`floor`: the float route can misround when `nobs` exceeds Float64's
# 53-bit integer precision, whereas `cld`/`fld` are exact for all `Int`s.
function Base.length(d::DataLoader)
    return d.partial ? cld(d.nobs, d.batchsize) : fld(d.nobs, d.batchsize)
end


# The element (batch) type of a DataLoader is derived statically from the type
# parameter `D` of the wrapped data via `batchtype`.
function Base.eltype(::Type{<:DataLoader{D}}) where {D}
    return batchtype(D)
end

# Map the static type of the underlying data to the type of a yielded batch.
# Fallback: nothing is known about how an arbitrary container batches.
batchtype(::Type) = Any
# An array batches to an array of the same type (observation dim is sliced).
batchtype(::Type{A}) where {A<:AbstractArray} = A
# Tuples batch element-wise.
batchtype(::Type{T}) where {T<:Tuple} = Tuple{map(batchtype, (T.parameters...,))...}
# Named tuples keep their keys; the value tuple batches element-wise.
batchtype(::Type{<:NamedTuple{K,V}}) where {K,V} = NamedTuple{K,batchtype(V)}
# Dicts keep their key type; values batch according to their own type.
batchtype(::Type{<:Dict{K,V}}) where {K,V} = Dict{K,batchtype(V)}
99 changes: 99 additions & 0 deletions test/dataloader.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@

@testset "DataLoader" begin
    X = reshape([1:10;], (2, 5))
    Y = [1:5;]

    # plain array, partial=true (default): 5 obs / batchsize 2 -> 2 full + 1 short batch
    d = DataLoader(X, batchsize=2)
    @inferred first(d)
    batches = collect(d)
    @test eltype(batches) == eltype(d) == typeof(X)
    @test length(batches) == 3
    @test batches[1] == X[:,1:2]
    @test batches[2] == X[:,3:4]
    @test batches[3] == X[:,5:5]

    # partial=false drops the trailing short batch
    d = DataLoader(X, batchsize=2, partial=false)
    @inferred first(d)
    batches = collect(d)
    @test eltype(batches) == eltype(d) == typeof(X)
    @test length(batches) == 2
    @test batches[1] == X[:,1:2]
    @test batches[2] == X[:,3:4]

    # 1-element tuple of data: batches are yielded as 1-element tuples
    d = DataLoader((X,), batchsize=2, partial=false)
    @inferred first(d)
    batches = collect(d)
    @test eltype(batches) == eltype(d) == Tuple{typeof(X)}
    @test length(batches) == 2
    @test batches[1] == (X[:,1:2],)
    @test batches[2] == (X[:,3:4],)

    # 2-element tuple: both components are sliced along the observation dim in lockstep
    d = DataLoader((X, Y), batchsize=2)
    @inferred first(d)
    batches = collect(d)
    @test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
    @test length(batches) == 3
    @test length(batches[1]) == 2
    @test length(batches[2]) == 2
    @test length(batches[3]) == 2
    @test batches[1][1] == X[:,1:2]
    @test batches[1][2] == Y[1:2]
    @test batches[2][1] == X[:,3:4]
    @test batches[2][2] == Y[3:4]
    @test batches[3][1] == X[:,5:5]
    @test batches[3][2] == Y[5:5]

    # test with NamedTuple: components accessible both by index and by name
    d = DataLoader((x=X, y=Y), batchsize=2)
    @inferred first(d)
    batches = collect(d)
    @test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
    @test length(batches) == 3
    @test length(batches[1]) == 2
    @test length(batches[2]) == 2
    @test length(batches[3]) == 2
    @test batches[1][1] == batches[1].x == X[:,1:2]
    @test batches[1][2] == batches[1].y == Y[1:2]
    @test batches[2][1] == batches[2].x == X[:,3:4]
    @test batches[2][2] == batches[2].y == Y[3:4]
    @test batches[3][1] == batches[3].x == X[:,5:5]
    @test batches[3][2] == batches[3].y == Y[5:5]

    # default batchsize=1: each batch is a single-observation slice
    # (θ is vestigial from the Flux `train!` test this was adapted from)
    θ = ones(2)
    X = zeros(2, 10)
    d = DataLoader(X)
    for x in d
        @test size(x) == (2,1)
    end

    # default batchsize=1 with a tuple of data
    # (θ again vestigial from the original Flux `train!` test)
    θ = zeros(2)
    X = ones(2, 10)
    Y = fill(5, 10)
    d = DataLoader((X, Y))
    for (x, y) in d
        @test size(x) == (2,1)
        @test y == [5]
    end
    # specify the rng: smoke test that an explicit seeded RNG is accepted
    d = map(identity, DataLoader(X, batchsize=2; shuffle=true, rng=Random.seed!(Random.default_rng(), 5)))

    # numobs/getobs compatibility with a custom data container
    d = DataLoader(CustomType(), batchsize=2)
    @test first(d) == [1, 2]
    @test length(collect(d)) == 8

    @testset "Dict" begin
        # Dict values with differing array ranks: eltype widens to Array{Float64}
        data = Dict("x" => rand(2,4), "y" => rand(4))
        dloader = DataLoader(data, batchsize=2)
        @test eltype(dloader) == Dict{String, Array{Float64}}
        c = collect(dloader)
        @test c[1] == Dict("x" => data["x"][:,1:2], "y" => data["y"][1:2])
        @test c[2] == Dict("x" => data["x"][:,3:4], "y" => data["y"][3:4])

        # homogeneous value types: eltype stays concrete (Matrix{Float64})
        data = Dict("x" => rand(2,4), "y" => rand(2,4))
        dloader = DataLoader(data, batchsize=2)
        @test eltype(dloader) == Dict{String, Matrix{Float64}}
    end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ MLUtils.getobs(::CustomType, i::AbstractVector) = collect(i)

@testset "batchview" begin; include("batchview.jl"); end
@testset "dataiterator" begin; include("dataiterator.jl"); end
@testset "dataloader" begin; include("dataloader.jl"); end
@testset "folds" begin; include("folds.jl"); end
@testset "observation" begin; include("observation.jl"); end
@testset "obsview" begin; include("obsview.jl"); end
Expand Down