Redesign package to be built on top of reusable dataset containers #96
Merged
Commits (19)
834d9b8  darsnack  Initial port of FastAI dataset containers
83452ed  darsnack  Drop Julia < 1.6 support and add bounds
757b706  darsnack  Add some docstrings and test for FileDataset
4e6807f  darsnack  Add HDF5 dataset
9a45312  darsnack  Close HDF5 file before deleting in tests
3ecc35b  darsnack  Update HDF5 docstrings
7bba084  darsnack  Fix broken HDF5 string tests
92e0d06  darsnack  Add JLD2Dataset
3d48893  darsnack  Support custom loading function in FileDataset
a30bfab  darsnack  Add CachedDataset
b7ba9c4  darsnack  Special case single path HDF5 and JLD2 datasets and add `@inferred` t…
7db9468  darsnack  Add more compat entries
950e111  darsnack  Switch to MLUtils v0.2
ebc6938  darsnack  Get rid of cruft from FastAI.jl
15fa280  darsnack  Remove some outdated deps
860fc14  darsnack  Remove reference to AbstractPath
5a882bd  darsnack  Add FileIO into usings
df48e5e  darsnack  Fix tests and add Tables.jl interface
2ff0d29  darsnack  Fix doc errors
@@ -0,0 +1,14 @@

# Dataset Containers

MLDatasets.jl contains several reusable data containers for accessing datasets in common storage formats. This feature is a work-in-progress and subject to change.

```@docs
FileDataset
TableDataset
HDF5Dataset
Base.close(::HDF5Dataset)
JLD2Dataset
Base.close(::JLD2Dataset)
CachedDataset
MLDatasets.make_cache
```
@@ -0,0 +1,38 @@

"""
    make_cache(source, cacheidx)

Return an in-memory copy of `source` at observation indices `cacheidx`.
Defaults to `getobs(source, cacheidx)`.
"""
make_cache(source, cacheidx) = getobs(source, cacheidx)

"""
    CachedDataset(source, cachesize = numobs(source))
    CachedDataset(source, cacheidx = 1:numobs(source))
    CachedDataset(source, cacheidx, cache)

Wrap a `source` data container and cache `cachesize` samples in memory.
This can be useful for improving read speeds when `source` is a lazy data container,
but your system memory is large enough to store a sizeable chunk of it.

By default the observation indices `1:cachesize` are cached.
You can manually pass in a set of `cacheidx` as well.

See also [`make_cache`](@ref) for customizing the default cache creation for `source`.
"""
struct CachedDataset{T, S}
    source::T
    cacheidx::Vector{Int}
    cache::S
end

CachedDataset(source, cacheidx::AbstractVector{<:Integer} = 1:numobs(source)) =
    CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx))
CachedDataset(source, cachesize::Int = numobs(source)) = CachedDataset(source, 1:cachesize)

function Base.getindex(dataset::CachedDataset, i::Integer)
    _i = findfirst(==(i), dataset.cacheidx)

    return isnothing(_i) ? getobs(dataset.source, i) : getobs(dataset.cache, _i)
end
Base.length(dataset::CachedDataset) = numobs(dataset.source)
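A usage sketch of the new `CachedDataset` (not part of this diff; the directory path and cache size are illustrative, and `FileDataset` here stands in for any lazy source):

```julia
using MLDatasets: FileDataset, CachedDataset
using MLUtils: getobs

# A lazy container over files on disk (hypothetical path).
files = FileDataset("path/to/images", "*.png")

# Eagerly cache the first 100 observations in memory via `make_cache`;
# everything else still reads from the underlying source on demand.
cached = CachedDataset(files, 100)

getobs(cached, 1)    # served from the in-memory cache
getobs(cached, 101)  # index not in `cacheidx`, falls through to `files`
```

Note that lookup does a linear `findfirst` over `cacheidx` on every access, so very large custom index sets trade cache-hit cost against read speed.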
@@ -0,0 +1,35 @@

"""
    rglob(filepattern, dir = pwd(), depth = 4)

Recursive glob up to `depth` layers deep within `dir`.
"""
function rglob(filepattern = "*", dir = pwd(), depth = 4)
    patterns = [repeat("*/", i) * filepattern for i in 0:(depth - 1)]

    return vcat([glob(pattern, dir) for pattern in patterns]...)
end

"""
    FileDataset([loadfn = FileIO.load,] paths)
    FileDataset([loadfn = FileIO.load,] dir, pattern = "*", depth = 4)

Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`).
Alternatively, specify a `dir` and collect all paths that match a glob `pattern`
(recursively globbing by `depth`). The glob order determines the traversal order.
"""
struct FileDataset{F, T<:AbstractString} <: AbstractDataContainer
    loadfn::F
    paths::Vector{T}
end

FileDataset(paths) = FileDataset(FileIO.load, paths)
FileDataset(loadfn,
            dir::AbstractString,
            pattern::AbstractString = "*",
            depth = 4) = FileDataset(loadfn, rglob(pattern, string(dir), depth))
FileDataset(dir::AbstractString, pattern::AbstractString = "*", depth = 4) =
    FileDataset(FileIO.load, dir, pattern, depth)

Base.getindex(dataset::FileDataset, i::Integer) = dataset.loadfn(dataset.paths[i])
Base.getindex(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is)
Base.length(dataset::FileDataset) = length(dataset.paths)
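A usage sketch for `FileDataset` (the directory names, glob patterns, and the choice of `read` as a custom loader are illustrative, not part of this diff):

```julia
using MLDatasets: FileDataset
using FileIO

# All files under `data/` matching "*.png", searched up to 4 directory
# levels deep; each access calls FileIO.load on the matched path.
ds = FileDataset("data", "*.png")

# Or supply a custom loading function, e.g. raw bytes via Base.read:
raw = FileDataset(read, "data", "*.bin")

length(ds)  # number of matched files
ds[1]       # loads and returns the first file (glob order)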
@@ -0,0 +1,60 @@

function _check_hdf5_shapes(shapes)
    nobs = map(last, filter(!isempty, shapes))

    return all(==(first(nobs)), nobs[2:end])
end

"""
    HDF5Dataset(file::AbstractString, paths)
    HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}})
    HDF5Dataset(fid::HDF5.File, paths::Union{AbstractString, Vector{<:AbstractString}})
    HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}, shapes)

Wrap several HDF5 datasets (`paths`) as a single dataset container.
Each dataset `p` in `paths` should be accessible as `fid[p]`.
Calling `getobs` on an `HDF5Dataset` returns a tuple with each element corresponding
to the observation from each dataset in `paths`.
See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer.

For array datasets, the last dimension is assumed to be the observation dimension.
For scalar datasets, the stored value is returned by `getobs` for any index.
"""
struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer
    fid::HDF5.File
    paths::T
    shapes::Vector{Tuple}

    function HDF5Dataset(fid::HDF5.File, paths::T, shapes::Vector) where T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}
        _check_hdf5_shapes(shapes) ||
            throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations."))

        new{T}(fid, paths, shapes)
    end
end

HDF5Dataset(fid::HDF5.File, path::HDF5.Dataset) = HDF5Dataset(fid, path, [size(path)])
HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) =
    HDF5Dataset(fid, paths, map(size, paths))
HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path])
HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) =
    HDF5Dataset(fid, map(p -> fid[p], paths))
HDF5Dataset(file::AbstractString, paths) = HDF5Dataset(h5open(file, "r"), paths)

_getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset)
function _getobs_hdf5(dataset::HDF5.Dataset, shape, i)
    I = map(s -> 1:s, shape[1:(end - 1)])

    return dataset[I..., i]
end
Base.getindex(dataset::HDF5Dataset{HDF5.Dataset}, i) =
    _getobs_hdf5(dataset.paths, only(dataset.shapes), i)
Base.getindex(dataset::HDF5Dataset{<:Vector}, i) =
    Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes))
Base.length(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes)))

"""
    close(dataset::HDF5Dataset)

Close the underlying HDF5 file pointer for `dataset`.
"""
Base.close(dataset::HDF5Dataset) = close(dataset.fid)
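A usage sketch for `HDF5Dataset` (the file name `"data.h5"` and the dataset paths `"features"`/`"targets"` are illustrative, not part of this diff):

```julia
using MLDatasets: HDF5Dataset
using MLUtils: getobs, numobs

# Open the file read-only and wrap two HDF5 datasets together.
ds = HDF5Dataset("data.h5", ["features", "targets"])

# Each observation is a tuple, one element per wrapped dataset; the last
# array dimension is treated as the observation dimension.
x, y = getobs(ds, 1)
numobs(ds)

# Close the underlying HDF5 file handle when done.
close(ds)
```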
@@ -0,0 +1,39 @@

_check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end])

"""
    JLD2Dataset(file::AbstractString, paths)
    JLD2Dataset(fid::JLD2.JLDFile, paths::Union{String, Vector{String}})

Wrap several JLD2 datasets (`paths`) as a single dataset container.
Each dataset `p` in `paths` should be accessible as `fid[p]`.
Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on
each dataset in `paths`.
See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer.
"""
struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer
    fid::T
    paths::S

    function JLD2Dataset(fid::JLD2.JLDFile, paths)
        _paths = Tuple(map(p -> fid[p], paths))
        nobs = map(numobs, _paths)
        _check_jld2_nobs(nobs) ||
            throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations (got $nobs)."))

        new{typeof(fid), typeof(_paths)}(fid, _paths)
    end
end

JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,))
JLD2Dataset(file::AbstractString, paths) = JLD2Dataset(jldopen(file, "r"), paths)

Base.getindex(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i)
Base.getindex(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths)
Base.length(dataset::JLD2Dataset) = numobs(dataset.paths[1])

"""
    close(dataset::JLD2Dataset)

Close the underlying JLD2 file pointer for `dataset`.
"""
Base.close(dataset::JLD2Dataset) = close(dataset.fid)
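A usage sketch for `JLD2Dataset` (the file name `"data.jld2"` and the keys `"x_train"`/`"y_train"` are illustrative, not part of this diff):

```julia
using MLDatasets: JLD2Dataset
using MLUtils: getobs

# Wrap two arrays stored in one JLD2 file as a single container.
ds = JLD2Dataset("data.jld2", ["x_train", "y_train"])
x, y = getobs(ds, 1)  # tuple of observations, one per stored dataset

# A single path is special-cased: observations are returned directly
# rather than as a 1-tuple.
xs = JLD2Dataset("data.jld2", "x_train")
x1 = getobs(xs, 1)

close(ds)
```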
@@ -0,0 +1,65 @@

"""
    TableDataset(table)
    TableDataset(path::AbstractString)

Wrap a Tables.jl-compatible `table` as a dataset container.
Alternatively, specify the `path` to a CSV file directly
to load it with CSV.jl + DataFrames.jl.
"""
struct TableDataset{T} <: AbstractDataContainer
    table::T

    # TableDatasets must implement the Tables.jl interface
    function TableDataset{T}(table::T) where {T}
        Tables.istable(table) ||
            throw(ArgumentError("TableDatasets must implement the Tables.jl interface"))

        new{T}(table)
    end
end

TableDataset(table::T) where {T} = TableDataset{T}(table)
TableDataset(path::AbstractString) = TableDataset(DataFrame(CSV.File(path)))

# slow accesses based on Tables.jl
_getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1)))
function _getobs_column(x, i)
    colnames = Tuple(Tables.columnnames(x))
    rowvals = ntuple(j -> Tables.getcolumn(x, j)[i], length(colnames))

    return NamedTuple{colnames}(rowvals)
end
function Base.getindex(dataset::TableDataset, i)
    if Tables.rowaccess(dataset.table)
        return _getobs_row(Tables.rows(dataset.table), i)
    elseif Tables.columnaccess(dataset.table)
        return _getobs_column(dataset.table, i)
    else
        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
    end
end
function Base.length(dataset::TableDataset)
    if Tables.columnaccess(dataset.table)
        return length(Tables.getcolumn(dataset.table, 1))
    elseif Tables.rowaccess(dataset.table)
        # length might not be defined, but has to be for this to work.
        return length(Tables.rows(dataset.table))
    else
        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
    end
end

# fast access for DataFrame
Base.getindex(dataset::TableDataset{<:DataFrame}, i) = dataset.table[i, :]
Base.length(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table)

# fast access for CSV.File
Base.getindex(dataset::TableDataset{<:CSV.File}, i) = dataset.table[i]
Base.length(dataset::TableDataset{<:CSV.File}) = length(dataset.table)

## Tables.jl interface

Tables.istable(::TableDataset) = true
for fn in (:rowaccess, :rows, :columnaccess, :columns, :schema, :materializer)
    @eval Tables.$fn(dataset::TableDataset) = Tables.$fn(dataset.table)
end
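A usage sketch for `TableDataset` (the DataFrame contents and the CSV path are illustrative, not part of this diff):

```julia
using MLDatasets: TableDataset
using DataFrames

df = DataFrame(a = 1:3, b = ["x", "y", "z"])
ds = TableDataset(df)

length(ds)  # 3, via the fast nrow path specialized for DataFrame
ds[2]       # the second row, via the DataFrame fast-access method

# Loading a CSV file directly goes through CSV.jl + DataFrames.jl
# (hypothetical path):
# csv_ds = TableDataset("data.csv")
```

Because `TableDataset` forwards the Tables.jl interface to the wrapped table, it can itself be passed anywhere a Tables.jl source is accepted.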
Review conversation:

> Should this go in MLUtils.jl?

> Think so.

> I can make a PR, though this is only useful for data that isn't already in memory. I had trouble thinking of cases where that's true but the data isn't a dataset.