diff --git a/src/containers/hdf5dataset.jl b/src/containers/hdf5dataset.jl
index e604dbe6..2bdc6ac9 100644
--- a/src/containers/hdf5dataset.jl
+++ b/src/containers/hdf5dataset.jl
@@ -6,9 +6,9 @@ end
 
 """
     HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths)
-    HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset})
-    HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString})
-    HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes)
+    HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}})
+    HDF5Dataset(fid::HDF5.File, paths::Union{AbstractString, Vector{<:AbstractString}})
+    HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}, shapes)
 
 Wrap several HDF5 datasets (`paths`) as a single dataset container.
 Each dataset `p` in `paths` should be accessible as `fid[p]`.
@@ -19,34 +19,38 @@ See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer.
 For array datasets, the last dimension is assumed to be the observation dimension.
 For scalar datasets, the stored value is returned by `getobs` for any index.
 """
-struct HDF5Dataset <: AbstractDataContainer
+struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer
     fid::HDF5.File
-    paths::Vector{HDF5.Dataset}
+    paths::T
     shapes::Vector{Tuple}
 
-    function HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes::Vector)
+    function HDF5Dataset(fid::HDF5.File, paths::T, shapes::Vector) where T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}
        _check_hdf5_shapes(shapes) ||
            throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations."))
 
-        new(fid, paths, shapes)
+        new{T}(fid, paths, shapes)
     end
 end
 
+HDF5Dataset(fid::HDF5.File, path::HDF5.Dataset) = HDF5Dataset(fid, path, [size(path)])
 HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) =
     HDF5Dataset(fid, paths, map(size, paths))
+HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path])
 HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) =
     HDF5Dataset(fid, map(p -> fid[p], paths))
 HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths) =
     HDF5Dataset(h5open(file, "r"), paths)
 
-MLUtils.getobs(dataset::HDF5Dataset, i) = Tuple(map(dataset.paths, dataset.shapes) do path, shape
-    if isempty(shape)
-        return read(path)
-    else
-        I = map(s -> 1:s, shape[1:(end - 1)])
-        return path[I..., i]
-    end
-end)
+_getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset)
+function _getobs_hdf5(dataset::HDF5.Dataset, shape, i)
+    I = map(s -> 1:s, shape[1:(end - 1)])
+
+    return dataset[I..., i]
+end
+MLUtils.getobs(dataset::HDF5Dataset{HDF5.Dataset}, i) =
+    _getobs_hdf5(dataset.paths, only(dataset.shapes), i)
+MLUtils.getobs(dataset::HDF5Dataset{<:Vector}, i) =
+    Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes))
 MLUtils.numobs(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes)))
 
 """
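To make the new single-path behaviour concrete, here is a small usage sketch (not part of the patch). The file name `test.h5` and the dataset name `d1` are made up for illustration; the point is that a single path now produces an `HDF5Dataset{HDF5.Dataset}` whose `getobs` returns the observation itself rather than a one-element tuple.

```julia
# Illustrative sketch, assuming a hypothetical "test.h5" with a 2x10 dataset "d1".
using HDF5, MLUtils, MLDatasets

fid = h5open("test.h5", "w")
fid["d1"] = rand(2, 10)   # 10 observations of length 2
close(fid)

single = HDF5Dataset("test.h5", "d1")        # HDF5Dataset{HDF5.Dataset}
getobs(single, 3)                            # 2-element Vector, not a 1-tuple
numobs(single)                               # 10

multi = HDF5Dataset("test.h5", ["d1", "d1"]) # HDF5Dataset{Vector{HDF5.Dataset}}
getobs(multi, 3)                             # Tuple of one observation per path

close(single); close(multi)
```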
diff --git a/src/containers/jld2dataset.jl b/src/containers/jld2dataset.jl
index edbd9384..9fb91150 100644
--- a/src/containers/jld2dataset.jl
+++ b/src/containers/jld2dataset.jl
@@ -2,7 +2,7 @@ _check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end])
 
 """
     JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths)
-    JLD2Dataset(fid::JLD2.JLDFile, paths)
+    JLD2Dataset(fid::JLD2.JLDFile, paths::Union{String, Vector{String}})
 
 Wrap several JLD2 datasets (`paths`) as a single dataset container.
 Each dataset `p` in `paths` should be accessible as `fid[p]`.
@@ -10,24 +10,27 @@ Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on
 each dataset in `paths`.
 
 See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer.
 """
-struct JLD2Dataset{T<:JLD2.JLDFile} <: AbstractDataContainer
+struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer
     fid::T
-    paths::Vector{String}
+    paths::S
 
-    function JLD2Dataset(fid::JLD2.JLDFile, paths::Vector{String})
-        nobs = map(p -> numobs(fid[p]), paths)
+    function JLD2Dataset(fid::JLD2.JLDFile, paths)
+        _paths = Tuple(map(p -> fid[p], paths))
+        nobs = map(numobs, _paths)
         _check_jld2_nobs(nobs) ||
-            throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations."))
+            throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations (got $nobs)."))
 
-        new{typeof(fid)}(fid, paths)
+        new{typeof(fid), typeof(_paths)}(fid, _paths)
     end
 end
 
+JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,))
 JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) = JLD2Dataset(jldopen(file, "r"), paths)
 
-MLUtils.getobs(dataset::JLD2Dataset, i) = Tuple(map(p -> getobs(dataset.fid[p], i), dataset.paths))
-MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.fid[dataset.paths[1]])
+MLUtils.getobs(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i)
+MLUtils.getobs(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths)
+MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.paths[1])
 
 """
     close(dataset::JLD2Dataset)
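A similar sketch for `JLD2Dataset` (again not part of the patch; the file `test.jld2` and the datasets `d1`/`d2` are hypothetical). Storing the opened datasets in a tuple lets the single-path case return the observation directly while the multi-path case stays a tuple.

```julia
# Illustrative sketch, assuming a hypothetical "test.jld2" with two datasets
# that hold 10 observations each.
using JLD2, MLUtils, MLDatasets

jldsave("test.jld2"; d1 = rand(2, 10), d2 = rand(10))

single = JLD2Dataset("test.jld2", "d1")        # paths stored as a 1-tuple
getobs(single, 3)                              # the observation itself, no wrapping tuple

multi = JLD2Dataset("test.jld2", ["d1", "d2"]) # paths stored as a 2-tuple
getobs(multi, 3)                               # (observation from d1, observation from d2)
numobs(multi)                                  # 10

close(single); close(multi)
```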
diff --git a/src/containers/tabledataset.jl b/src/containers/tabledataset.jl
index 56cbcfed..f2756ae7 100644
--- a/src/containers/tabledataset.jl
+++ b/src/containers/tabledataset.jl
@@ -23,14 +23,18 @@ TableDataset(path::Union{AbstractPath, AbstractString}) =
     TableDataset(DataFrame(CSV.File(path)))
 
 # slow accesses based on Tables.jl
+_getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1)))
+function _getobs_column(x, i)
+    colnames = Tuple(Tables.columnnames(x))
+    rowvals = ntuple(j -> Tables.getcolumn(x, j)[i], length(colnames))
+
+    return NamedTuple{colnames}(rowvals)
+end
 function MLUtils.getobs(dataset::TableDataset, i)
     if Tables.rowaccess(dataset.table)
-        row, _ = Iterators.peel(Iterators.drop(Tables.rows(dataset.table), i - 1))
-        return row
+        return _getobs_row(Tables.rows(dataset.table), i)
     elseif Tables.columnaccess(dataset.table)
-        colnames = Tables.columnnames(dataset.table)
-        rowvals = [Tables.getcolumn(dataset.table, j)[i] for j in 1:length(colnames)]
-        return (; zip(colnames, rowvals)...)
+        return _getobs_column(dataset.table, i)
     else
         error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
     end
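The column-access path can be sketched as follows (not part of the patch; the table below is a made-up NamedTuple column table). Building the row as `NamedTuple{colnames}(rowvals)` through `ntuple` replaces the keyword splat used before, so the observation keeps concrete element types.

```julia
# Illustrative sketch: a column-only Tables.jl table goes through _getobs_column.
using Tables, MLUtils, MLDatasets

coltable = (a = [1, 2, 3], b = ["x", "y", "z"])  # Tables.columnaccess is true here
td = TableDataset(coltable)

getobs(td, 2)   # expected (a = 2, b = "y"), a NamedTuple
numobs(td)      # expected 3
```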
diff --git a/test/containers/cacheddataset.jl b/test/containers/cacheddataset.jl
index ff70aab4..c9726133 100644
--- a/test/containers/cacheddataset.jl
+++ b/test/containers/cacheddataset.jl
@@ -13,13 +13,29 @@
 
     @testset "CachedDataset(::HDF5Dataset)" begin
         paths, datas = setup_hdf5dataset_test()
-        hdataset = HDF5Dataset("test.h5", ["d1"])
+        hdataset = HDF5Dataset("test.h5", "d1")
         cdataset = CachedDataset(hdataset, 5)
 
         @test numobs(cdataset) == numobs(hdataset)
+        @test cdataset.cache isa Matrix{Float64}
         @test cdataset.cache == getobs(hdataset, 1:5)
         @test all(getobs(cdataset, i) == getobs(hdataset, i) for i in 1:10)
 
+        close(hdataset)
         cleanup_hdf5dataset_test()
     end
+
+    @testset "CachedDataset(::JLD2Dataset)" begin
+        paths, datas = setup_jld2dataset_test()
+        jdataset = JLD2Dataset("test.jld2", "d1")
+        cdataset = CachedDataset(jdataset, 5)
+
+        @test numobs(cdataset) == numobs(jdataset)
+        @test cdataset.cache isa Matrix{Float64}
+        @test cdataset.cache == getobs(jdataset, 1:5)
+        @test all(@inferred(getobs(cdataset, i)) == getobs(jdataset, i) for i in 1:10)
+
+        close(jdataset)
+        cleanup_jld2dataset_test()
+    end
 end
diff --git a/test/containers/hdf5dataset.jl b/test/containers/hdf5dataset.jl
index 878cc262..8d79e033 100644
--- a/test/containers/hdf5dataset.jl
+++ b/test/containers/hdf5dataset.jl
@@ -18,12 +18,22 @@ cleanup_hdf5dataset_test() = rm("test.h5")
 
 @testset "HDF5Dataset" begin
     paths, datas = setup_hdf5dataset_test()
-    dataset = HDF5Dataset("test.h5", paths)
-    for i in 1:10
-        data = Tuple(map(x -> (x isa String) ? x : getobs(x, i), datas))
-        @test getobs(dataset, i) == data
+    @testset "Single path" begin
+        dataset = HDF5Dataset("test.h5", "d1")
+        for i in 1:10
+            @test getobs(dataset, i) == getobs(datas[1], i)
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
+    end
+    @testset "Multiple paths" begin
+        dataset = HDF5Dataset("test.h5", paths)
+        for i in 1:10
+            data = Tuple(map(x -> (x isa String) ? x : getobs(x, i), datas))
+            @test @inferred(Tuple, getobs(dataset, i)) == data
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
     end
-    @test numobs(dataset) == 10
-    close(dataset)
     cleanup_hdf5dataset_test()
 end
diff --git a/test/containers/jld2dataset.jl b/test/containers/jld2dataset.jl
index d849d8fa..d1656b91 100644
--- a/test/containers/jld2dataset.jl
+++ b/test/containers/jld2dataset.jl
@@ -18,11 +18,21 @@ cleanup_jld2dataset_test() = rm("test.jld2")
 
 @testset "JLD2Dataset" begin
     paths, datas = setup_jld2dataset_test()
-    dataset = JLD2Dataset("test.jld2", paths)
-    for i in 1:10
-        @test getobs(dataset, i) == getobs(datas, i)
+    @testset "Single path" begin
+        dataset = JLD2Dataset("test.jld2", "d1")
+        for i in 1:10
+            @test @inferred(getobs(dataset, i)) == getobs(datas[1], i)
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
+    end
+    @testset "Multiple paths" begin
+        dataset = JLD2Dataset("test.jld2", paths)
+        for i in 1:10
+            @test @inferred(getobs(dataset, i)) == getobs(datas, i)
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
     end
-    @test numobs(dataset) == 10
-    close(dataset)
     cleanup_jld2dataset_test()
 end
diff --git a/test/containers/tabledataset.jl b/test/containers/tabledataset.jl
index 92d76847..9154484c 100644
--- a/test/containers/tabledataset.jl
+++ b/test/containers/tabledataset.jl
@@ -6,7 +6,7 @@
         testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
         td = TableDataset(testtable)
 
-        @test all(getobs(td, 1) .== [1, 4.0, "7"])
+        @test collect(@inferred(getobs(td, 1))) == [1, 4.0, "7"]
         @test numobs(td) == 3
     end
 
@@ -17,7 +17,7 @@
         testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
         td = TableDataset(testtable)
 
-        @test [data for data in getobs(td, 2)] == [2, 5.0, "8"]
+        @test collect(@inferred(NamedTuple, getobs(td, 2))) == [2, 5.0, "8"]
         @test numobs(td) == 3
 
         @test getobs(td, 1) isa NamedTuple
@@ -35,7 +35,7 @@
         td = TableDataset(testtable)
 
         @test td isa TableDataset{<:DataFrame}
-        @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"]
+        @test collect(@inferred(getobs(td, 1))) == [1, "a", 10, "A", 100.0, "train"]
         @test numobs(td) == 5
     end
 
@@ -46,7 +46,7 @@
         testtable = CSV.File("test.csv")
         td = TableDataset(testtable)
         @test td isa TableDataset{<:CSV.File}
-        @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"]
+        @test collect(@inferred(getobs(td, 1))) == [1, "a", 10, "A", 100.0, "train"]
        @test numobs(td) == 1
        rm("test.csv")
     end
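Finally, a sketch of what the updated `CachedDataset` tests rely on (not part of the patch; it assumes the hypothetical `test.h5` file from the first sketch). Because a single-path container yields plain arrays, the cached block of observations materializes as a concrete `Matrix{Float64}` rather than a vector of one-element tuples.

```julia
# Illustrative sketch: caching the first 5 observations of a single-path dataset.
using MLUtils, MLDatasets

hdataset = HDF5Dataset("test.h5", "d1")
cdataset = CachedDataset(hdataset, 5)

cdataset.cache isa Matrix{Float64}          # expected true: getobs(hdataset, 1:5) is a 2x5 Matrix
getobs(cdataset, 3) == getobs(hdataset, 3)  # cached and uncached observations agree

close(hdataset)
```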