# Adapted from Knet's src/data.jl (author: Deniz Yuret)

struct DataLoader{D,R<:AbstractRNG}
    data::D               # the user-supplied dataset, kept as-is (accessed via numobs/getobs)
    batchsize::Int        # number of observations per mini-batch
    nobs::Int             # total number of observations in `data`
    partial::Bool         # whether a smaller trailing batch is yielded
    imax::Int             # largest state value from which a batch may still start
    indices::Vector{Int}  # observation order; shuffled in place when `shuffle == true`
    shuffle::Bool         # re-shuffle `indices` each time iteration is restarted
    rng::R                # RNG used for shuffling, for reproducibility
end

"""
    DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)

An object that iterates over mini-batches of `data`,
each mini-batch containing `batchsize` observations
(except possibly the last one).

Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
The last dimension in each tensor is the observation dimension, i.e. the one
divided into mini-batches.

If `shuffle=true`, it shuffles the observations each time iterations are re-started.
If `partial=false` and the number of observations is not divisible by the batchsize,
then the last mini-batch is dropped.

The original data is preserved in the `data` field of the DataLoader.

# Examples

```jldoctest
julia> Xtrain = rand(10, 100);

julia> array_loader = DataLoader(Xtrain, batchsize=2);

julia> for x in array_loader
         @assert size(x) == (10, 2)
         # do something with x, 50 times
       end

julia> array_loader.data === Xtrain
true

julia> tuple_loader = DataLoader((Xtrain,), batchsize=2);  # similar, but yielding 1-element tuples

julia> for x in tuple_loader
         @assert x isa Tuple{Matrix}
         @assert size(x[1]) == (10, 2)
       end

julia> Ytrain = rand('a':'z', 100);  # now make a DataLoader yielding 2-element named tuples

julia> train_loader = DataLoader((data=Xtrain, label=Ytrain), batchsize=5, shuffle=true);

julia> for epoch in 1:100
         for (x, y) in train_loader  # access via tuple destructuring
           @assert size(x) == (10, 5)
           @assert size(y) == (5,)
           # loss += f(x, y) # etc, runs 100 * 20 times
         end
       end

julia> first(train_loader).label isa Vector{Char}  # access via property name
true

julia> first(train_loader).label == Ytrain[1:5]  # because of shuffle=true
false

julia> foreach(println∘summary, DataLoader(rand(Int8, 10, 64), batchsize=30))  # partial=false would omit last
10×30 Matrix{Int8}
10×30 Matrix{Int8}
10×4 Matrix{Int8}
```
"""
function DataLoader(data; batchsize=1, shuffle=false, partial=true, rng=GLOBAL_RNG)
    batchsize > 0 || throw(ArgumentError("Need positive batchsize"))

    n = numobs(data)
    if n < batchsize
        @warn "Number of observations less than batchsize, decreasing the batchsize to $n"
        batchsize = n
    end
    # With `partial=true` a batch may start at any remaining index; otherwise a
    # start index must leave room for a full batch of `batchsize` observations.
    imax = partial ? n : n - batchsize + 1
    return DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle, rng)
end

# Yields `getobs(d.data, d.indices[i+1:i+batchsize])`; the iteration state `i`
# counts observations already served.
@propagate_inbounds function Base.iterate(d::DataLoader, i=0)
    i >= d.imax && return nothing
    if d.shuffle && i == 0
        # Re-shuffle at the start of every fresh iteration pass.
        shuffle!(d.rng, d.indices)
    end
    nexti = min(i + d.batchsize, d.nobs)
    ids = d.indices[i+1:nexti]
    batch = getobs(d.data, ids)
    return (batch, nexti)
end

# Number of mini-batches the loader yields.
function Base.length(d::DataLoader)
    # Empty dataset: the constructor lowers `batchsize` to 0, so guard before
    # dividing (the old `nobs / batchsize` produced NaN -> InexactError here).
    d.nobs == 0 && return 0
    # Exact integer division: going through Float64 with ceil/floor can round
    # incorrectly for very large observation counts.
    return d.partial ? cld(d.nobs, d.batchsize) : fld(d.nobs, d.batchsize)
end

# Element type of the iterator, used by `collect` and friends.
# `batchtype` mirrors what `getobs` returns for each supported container.
Base.eltype(::Type{<:DataLoader{D}}) where D = batchtype(D)

batchtype(D::Type) = Any
batchtype(D::Type{<:AbstractArray}) = D
batchtype(D::Type{<:Tuple}) = Tuple{batchtype.(D.parameters)...}
batchtype(D::Type{<:NamedTuple{K,V}}) where {K,V} = NamedTuple{K, batchtype(V)}
batchtype(D::Type{<:Dict{K,V}}) where {K,V} = Dict{K, batchtype(V)}