Special case single path HDF5 and JLD2 datasets and add @inferred tests
darsnack committed Feb 22, 2022
1 parent a30bfab commit b7ba9c4
Showing 7 changed files with 92 additions and 45 deletions.
34 changes: 19 additions & 15 deletions src/containers/hdf5dataset.jl
@@ -6,9 +6,9 @@ end

"""
HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths)
HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset})
HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString})
HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes)
HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}})
HDF5Dataset(fid::HDF5.File, paths::Union{AbstractString, Vector{<:AbstractString}})
HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}, shapes)
Wrap several HDF5 datasets (`paths`) as a single dataset container.
Each dataset `p` in `paths` should be accessible as `fid[p]`.
@@ -19,34 +19,38 @@ See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer.
For array datasets, the last dimension is assumed to be the observation dimension.
For scalar datasets, the stored value is returned by `getobs` for any index.
"""
struct HDF5Dataset <: AbstractDataContainer
struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer
fid::HDF5.File
paths::Vector{HDF5.Dataset}
paths::T
shapes::Vector{Tuple}

function HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes::Vector)
function HDF5Dataset(fid::HDF5.File, paths::T, shapes::Vector) where T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}
_check_hdf5_shapes(shapes) ||
throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations."))

new(fid, paths, shapes)
new{T}(fid, paths, shapes)
end
end

HDF5Dataset(fid::HDF5.File, path::HDF5.Dataset) = HDF5Dataset(fid, path, [size(path)])
HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) =
HDF5Dataset(fid, paths, map(size, paths))
HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path])
HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) =
HDF5Dataset(fid, map(p -> fid[p], paths))
HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths) =
HDF5Dataset(h5open(file, "r"), paths)

MLUtils.getobs(dataset::HDF5Dataset, i) = Tuple(map(dataset.paths, dataset.shapes) do path, shape
if isempty(shape)
return read(path)
else
I = map(s -> 1:s, shape[1:(end - 1)])
return path[I..., i]
end
end)
_getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset)
function _getobs_hdf5(dataset::HDF5.Dataset, shape, i)
I = map(s -> 1:s, shape[1:(end - 1)])

return dataset[I..., i]
end
MLUtils.getobs(dataset::HDF5Dataset{HDF5.Dataset}, i) =
_getobs_hdf5(dataset.paths, only(dataset.shapes), i)
MLUtils.getobs(dataset::HDF5Dataset{<:Vector}, i) =
Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes))
MLUtils.numobs(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes)))

"""
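A rough usage sketch of the new single-path behaviour (illustration only, not part of the diff; the file name `mydata.h5` and dataset name `features` are invented, and `HDF5Dataset` is assumed to be exported by the package this commit modifies):

using HDF5, MLUtils

# Write a tiny file purely for illustration: 10 observations of length 3,
# with the observation dimension last.
h5open("mydata.h5", "w") do fid
    fid["features"] = rand(Float64, 3, 10)
end

# A single string path now produces an HDF5Dataset{HDF5.Dataset}, so getobs
# returns the slice itself instead of a one-element Tuple.
dataset = HDF5Dataset("mydata.h5", "features")
getobs(dataset, 1)   # 3-element Vector{Float64}
numobs(dataset)      # 10
close(dataset)       # closes the underlying HDF5 file handle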
21 changes: 12 additions & 9 deletions src/containers/jld2dataset.jl
@@ -2,32 +2,35 @@ _check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end])

"""
JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths)
JLD2Dataset(fid::JLD2.JLDFile, paths)
JLD2Dataset(fid::JLD2.JLDFile, paths::Union{String, Vector{String}})
Wrap several JLD2 datasets (`paths`) as a single dataset container.
Each dataset `p` in `paths` should be accessible as `fid[p]`.
Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on
each dataset in `paths`.
See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer.
"""
struct JLD2Dataset{T<:JLD2.JLDFile} <: AbstractDataContainer
struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer
fid::T
paths::Vector{String}
paths::S

function JLD2Dataset(fid::JLD2.JLDFile, paths::Vector{String})
nobs = map(p -> numobs(fid[p]), paths)
function JLD2Dataset(fid::JLD2.JLDFile, paths)
_paths = Tuple(map(p -> fid[p], paths))
nobs = map(numobs, _paths)
_check_jld2_nobs(nobs) ||
throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations."))
throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations (got $nobs)."))

new{typeof(fid)}(fid, paths)
new{typeof(fid), typeof(_paths)}(fid, _paths)
end
end

JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,))
JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) =
JLD2Dataset(jldopen(file, "r"), paths)

MLUtils.getobs(dataset::JLD2Dataset, i) = Tuple(map(p -> getobs(dataset.fid[p], i), dataset.paths))
MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.fid[dataset.paths[1]])
MLUtils.getobs(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i)
MLUtils.getobs(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths)
MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.paths[1])

"""
close(dataset::JLD2Dataset)
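The same pattern for `JLD2Dataset`, again as a sketch (the file name and the keys `d1`/`d2` are invented for the example):

using JLD2, MLUtils

# Two arrays that share the number of observations along the last dimension.
jldsave("mydata.jld2"; d1 = rand(Float64, 3, 10), d2 = rand(Float64, 2, 10))

# Single path: getobs returns the slice directly, and the call is type-inferrable
# because the loaded datasets are now stored in the Tuple type parameter S.
single = JLD2Dataset("mydata.jld2", "d1")
getobs(single, 1)   # 3-element Vector{Float64}, no 1-tuple wrapping
close(single)

# Multiple paths: getobs returns one slice per stored dataset as a Tuple.
pair = JLD2Dataset("mydata.jld2", ["d1", "d2"])
x1, x2 = getobs(pair, 1)
numobs(pair)        # 10
close(pair)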
14 changes: 9 additions & 5 deletions src/containers/tabledataset.jl
@@ -23,14 +23,18 @@ TableDataset(path::Union{AbstractPath, AbstractString}) =
TableDataset(DataFrame(CSV.File(path)))

# slow accesses based on Tables.jl
_getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1)))
function _getobs_column(x, i)
colnames = Tuple(Tables.columnnames(x))
rowvals = ntuple(j -> Tables.getcolumn(x, j)[i], length(colnames))

return NamedTuple{colnames}(rowvals)
end
function MLUtils.getobs(dataset::TableDataset, i)
if Tables.rowaccess(dataset.table)
row, _ = Iterators.peel(Iterators.drop(Tables.rows(dataset.table), i - 1))
return row
return _getobs_row(Tables.rows(dataset.table), i)
elseif Tables.columnaccess(dataset.table)
colnames = Tables.columnnames(dataset.table)
rowvals = [Tables.getcolumn(dataset.table, j)[i] for j in 1:length(colnames)]
return (; zip(colnames, rowvals)...)
return _getobs_column(dataset.table, i)
else
error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
end
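The row- and column-access branches are pulled out into the `_getobs_row` and `_getobs_column` helpers so the compiler can infer `getobs` for each branch, which the new `@inferred` tests below exercise. A small sketch, using the same toy table as the tests (assuming `TableDataset` is exported by this package; the exact object returned depends on whether the table supports row or column access, but its values are the row's entries either way):

using Tables, MLUtils

tbl = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])  # 3 observations (rows)
td  = TableDataset(tbl)

collect(getobs(td, 2))   # == [2, 5.0, "8"], the second row's values
numobs(td)               # 3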
18 changes: 17 additions & 1 deletion test/containers/cacheddataset.jl
@@ -13,13 +13,29 @@

@testset "CachedDataset(::HDF5Dataset)" begin
paths, datas = setup_hdf5dataset_test()
hdataset = HDF5Dataset("test.h5", ["d1"])
hdataset = HDF5Dataset("test.h5", "d1")
cdataset = CachedDataset(hdataset, 5)

@test numobs(cdataset) == numobs(hdataset)
@test cdataset.cache isa Matrix{Float64}
@test cdataset.cache == getobs(hdataset, 1:5)
@test all(getobs(cdataset, i) == getobs(hdataset, i) for i in 1:10)

close(hdataset)
cleanup_hdf5dataset_test()
end

@testset "CachedDataset(::JLD2Dataset)" begin
paths, datas = setup_jld2dataset_test()
jdataset = JLD2Dataset("test.jld2", "d1")
cdataset = CachedDataset(jdataset, 5)

@test numobs(cdataset) == numobs(jdataset)
@test cdataset.cache isa Matrix{Float64}
@test cdataset.cache == getobs(jdataset, 1:5)
@test all(@inferred(getobs(cdataset, i)) == getobs(jdataset, i) for i in 1:10)

close(jdataset)
cleanup_jld2dataset_test()
end
end
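The `@inferred` checks added throughout these tests come from the `Test` standard library: the macro errors when the type Julia infers for a call does not match the type of the value actually returned, so it catches type instabilities such as results whose shape depends on runtime values. A minimal standalone illustration (toy functions, unrelated to the package):

using Test

stable(x::Int) = x + 1
unstable(x::Int) = x > 0 ? 1 : 1.0   # return type depends on the runtime value

@inferred stable(1)      # passes: inferred Int64, got Int64
# @inferred unstable(1)  # errors: inferred Union{Float64, Int64}, got Int64

The two-argument form used in some tests, e.g. `@inferred(Tuple, getobs(dataset, i))`, roughly relaxes the check so that a return type that is a subtype of the given type (here `Tuple`) is also accepted.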
22 changes: 16 additions & 6 deletions test/containers/hdf5dataset.jl
@@ -18,12 +18,22 @@ cleanup_hdf5dataset_test() = rm("test.h5")

@testset "HDF5Dataset" begin
paths, datas = setup_hdf5dataset_test()
dataset = HDF5Dataset("test.h5", paths)
for i in 1:10
data = Tuple(map(x -> (x isa String) ? x : getobs(x, i), datas))
@test getobs(dataset, i) == data
@testset "Single path" begin
dataset = HDF5Dataset("test.h5", "d1")
for i in 1:10
@test getobs(dataset, i) == getobs(datas[1], i)
end
@test numobs(dataset) == 10
close(dataset)
end
@testset "Multiple paths" begin
dataset = HDF5Dataset("test.h5", paths)
for i in 1:10
data = Tuple(map(x -> (x isa String) ? x : getobs(x, i), datas))
@test @inferred(Tuple, getobs(dataset, i)) == data
end
@test numobs(dataset) == 10
close(dataset)
end
@test numobs(dataset) == 10
close(dataset)
cleanup_hdf5dataset_test()
end
20 changes: 15 additions & 5 deletions test/containers/jld2dataset.jl
@@ -18,11 +18,21 @@ cleanup_jld2dataset_test() = rm("test.jld2")

@testset "JLD2Dataset" begin
paths, datas = setup_jld2dataset_test()
dataset = JLD2Dataset("test.jld2", paths)
for i in 1:10
@test getobs(dataset, i) == getobs(datas, i)
@testset "Single path" begin
dataset = JLD2Dataset("test.jld2", "d1")
for i in 1:10
@test @inferred(getobs(dataset, i)) == getobs(datas[1], i)
end
@test numobs(dataset) == 10
close(dataset)
end
@testset "Multiple paths" begin
dataset = JLD2Dataset("test.jld2", paths)
for i in 1:10
@test @inferred(getobs(dataset, i)) == getobs(datas, i)
end
@test numobs(dataset) == 10
close(dataset)
end
@test numobs(dataset) == 10
close(dataset)
cleanup_jld2dataset_test()
end
8 changes: 4 additions & 4 deletions test/containers/tabledataset.jl
@@ -6,7 +6,7 @@
testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
td = TableDataset(testtable)

@test all(getobs(td, 1) .== [1, 4.0, "7"])
@test collect(@inferred(getobs(td, 1))) == [1, 4.0, "7"]
@test numobs(td) == 3
end

@@ -17,7 +17,7 @@
testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
td = TableDataset(testtable)

@test [data for data in getobs(td, 2)] == [2, 5.0, "8"]
@test collect(@inferred(NamedTuple, getobs(td, 2))) == [2, 5.0, "8"]
@test numobs(td) == 3

@test getobs(td, 1) isa NamedTuple
@@ -35,7 +35,7 @@
td = TableDataset(testtable)
@test td isa TableDataset{<:DataFrame}

@test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"]
@test collect(@inferred(getobs(td, 1))) == [1, "a", 10, "A", 100.0, "train"]
@test numobs(td) == 5
end

@@ -46,7 +46,7 @@
testtable = CSV.File("test.csv")
td = TableDataset(testtable)
@test td isa TableDataset{<:CSV.File}
@test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"]
@test collect(@inferred(getobs(td, 1))) == [1, "a", 10, "A", 100.0, "train"]
@test numobs(td) == 1
rm("test.csv")
end