
Redesign package to be built on top of reusable dataset containers #96

Merged: 19 commits, Mar 5, 2022
4 changes: 2 additions & 2 deletions .github/workflows/UnitTest.yml

@@ -16,12 +16,12 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        julia-version: ['1.3', '1', 'nightly']
+        julia-version: ['1.6', '1', 'nightly']
         os: [ubuntu-latest, windows-latest, macOS-latest, macos-11]
     env:
       PYTHON: ""
     steps:
-      - uses: actions/checkout@v1.0.0
+      - uses: actions/checkout@v2
       - name: "Set up Julia"
         uses: julia-actions/setup-julia@v1
         with:
18 changes: 17 additions & 1 deletion Project.toml

@@ -4,29 +4,45 @@ version = "0.5.15"
 
 [deps]
 BinDeps = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
 DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
+FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
 FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
 GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63"
+Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
+HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
 BinDeps = "1"
+CSV = "0.10.2"
 ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11"
 DataDeps = "0.3, 0.4, 0.5, 0.6, 0.7"
+DataFrames = "1.3"
+FileIO = "1.13"
 FixedPointNumbers = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
 GZip = "0.5"
+Glob = "1.3"
+HDF5 = "0.16.2"
 ImageCore = "0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
+JLD2 = "0.4.21"
 JSON3 = "1"
 MAT = "0.7, 0.8, 0.9, 0.10"
+MLUtils = "0.2.0"
 Pickle = "0.2, 0.3"
 Requires = "1"
-julia = "1.3"
+Tables = "1.6"
+julia = "1.6"
 
 [extras]
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
5 changes: 3 additions & 2 deletions docs/make.jl

@@ -37,7 +37,6 @@ makedocs(
             "Mutagenesis" => "datasets/Mutagenesis.md",
             "Titanic" => "datasets/Titanic.md",
         ],
-
        "Text" => Any[
             "PTBLM" => "datasets/PTBLM.md",
             "UD_English" => "datasets/UD_English.md",
@@ -52,9 +51,11 @@ makedocs(
 
         ],
         "Utils" => "utils.md",
+        "Data Containers" => "containers/overview.md",
         "LICENSE.md",
     ],
-    strict = true
+    strict = true,
+    checkdocs = :exports
 )
14 changes: 14 additions & 0 deletions docs/src/containers/overview.md
# Dataset Containers

MLDatasets.jl contains several reusable data containers for accessing datasets in common storage formats. This feature is a work-in-progress and subject to change.

```@docs
FileDataset
TableDataset
HDF5Dataset
Base.close(::HDF5Dataset)
JLD2Dataset
Base.close(::JLD2Dataset)
CachedDataset
MLDatasets.make_cache
```
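
All of these containers implement the MLUtils.jl `getobs`/`numobs` access interface, so downstream code can stay agnostic to the storage format. A minimal sketch of that idea, assuming a hypothetical `data.csv` file:

```julia
using MLDatasets
using MLUtils: getobs, numobs

# Any container exposes the same observation interface,
# regardless of the underlying storage format.
dataset = TableDataset("data.csv")  # hypothetical CSV file

for i in 1:numobs(dataset)
    row = getobs(dataset, i)  # one observation (here, a table row)
    # ... process `row` ...
end
```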
19 changes: 19 additions & 0 deletions src/MLDatasets.jl

@@ -7,6 +7,15 @@ using DelimitedFiles: readdlm
 using FixedPointNumbers, ColorTypes
 using Pickle
 using SparseArrays
+using FileIO
+using DataFrames, CSV, Tables
+using Glob
+using HDF5
+using JLD2
+
+import MLUtils
+using MLUtils: getobs, numobs, AbstractDataContainer
+export getobs, numobs
 
 # Julia 1.0 compatibility
 if !isdefined(Base, :isnothing)
@@ -36,6 +45,16 @@ end
 
 include("download.jl")
 
+include("containers/filedataset.jl")
+export FileDataset
+include("containers/tabledataset.jl")
+export TableDataset
+include("containers/hdf5dataset.jl")
+export HDF5Dataset
+include("containers/jld2dataset.jl")
+export JLD2Dataset
+include("containers/cacheddataset.jl")
+export CachedDataset
 
 # Misc.
 include("BostonHousing/BostonHousing.jl")
38 changes: 38 additions & 0 deletions src/containers/cacheddataset.jl

"""
    make_cache(source, cacheidx)

Return an in-memory copy of `source` at observation indices `cacheidx`.
Defaults to `getobs(source, cacheidx)`.
"""
make_cache(source, cacheidx) = getobs(source, cacheidx)

"""
    CachedDataset(source, cachesize = numobs(source))
    CachedDataset(source, cacheidx = 1:numobs(source))
    CachedDataset(source, cacheidx, cache)

Wrap a `source` data container and cache `cachesize` samples in memory.
This can be useful for improving read speeds when `source` is a lazy data container,
but your system memory is large enough to store a sizeable chunk of it.

By default the observation indices `1:cachesize` are cached.
You can manually pass in a set of `cacheidx` as well.

See also [`make_cache`](@ref) for customizing the default cache creation for `source`.
"""
struct CachedDataset{T, S}
    source::T
    cacheidx::Vector{Int}
    cache::S
end

Review thread on `struct CachedDataset`:
- Member: Should this go in MLUtils.jl?
- Contributor: Think so
- Member (author): I can make a PR, though this is only useful for data that isn't already in memory. I had trouble thinking of cases where that's true but the data isn't a dataset.

CachedDataset(source, cacheidx::AbstractVector{<:Integer} = 1:numobs(source)) =
    CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx))
CachedDataset(source, cachesize::Int = numobs(source)) = CachedDataset(source, 1:cachesize)

function Base.getindex(dataset::CachedDataset, i::Integer)
    _i = findfirst(==(i), dataset.cacheidx)

    return isnothing(_i) ? getobs(dataset.source, i) : getobs(dataset.cache, _i)
end
Base.length(dataset::CachedDataset) = numobs(dataset.source)
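
A sketch of how `CachedDataset` composes with a lazy container such as `FileDataset`. The `images/` directory, file count, and cache size are assumptions for illustration:

```julia
using MLDatasets
using MLUtils: getobs

# Wrap a lazy, disk-backed container and eagerly cache the first
# 100 observations in memory (assumes "images/" holds > 500 PNGs).
files = FileDataset("images", "*.png")
cached = CachedDataset(files, 100)

getobs(cached, 5)    # index is in 1:100, served from the in-memory cache
getobs(cached, 500)  # cache miss: falls back to loading from disk
```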
35 changes: 35 additions & 0 deletions src/containers/filedataset.jl

"""
    rglob(filepattern, dir = pwd(), depth = 4)

Recursive glob up to `depth` layers deep within `dir`.
"""
function rglob(filepattern = "*", dir = pwd(), depth = 4)
    patterns = [repeat("*/", i) * filepattern for i in 0:(depth - 1)]

    return vcat([glob(pattern, dir) for pattern in patterns]...)
end

"""
    FileDataset([loadfn = FileIO.load,] paths)
    FileDataset([loadfn = FileIO.load,] dir, pattern = "*", depth = 4)

Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`).
Alternatively, specify a `dir` and collect all paths that match a glob `pattern`
(recursively globbing by `depth`). The glob order determines the traversal order.
"""
struct FileDataset{F, T<:AbstractString} <: AbstractDataContainer
    loadfn::F
    paths::Vector{T}
end

FileDataset(paths) = FileDataset(FileIO.load, paths)
FileDataset(loadfn,
            dir::AbstractString,
            pattern::AbstractString = "*",
            depth = 4) = FileDataset(loadfn, rglob(pattern, string(dir), depth))
FileDataset(dir::AbstractString, pattern::AbstractString = "*", depth = 4) =
    FileDataset(FileIO.load, dir, pattern, depth)

Base.getindex(dataset::FileDataset, i::Integer) = dataset.loadfn(dataset.paths[i])
Base.getindex(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is)
Base.length(dataset::FileDataset) = length(dataset.paths)
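
A usage sketch, assuming a hypothetical `images/` directory tree of PNG files:

```julia
using MLDatasets
using MLUtils: getobs, numobs

# Match every PNG up to 4 directory levels below "images/";
# each file is loaded lazily with FileIO.load on access.
dataset = FileDataset("images", "*.png")

numobs(dataset)           # number of matched files
img = getobs(dataset, 1)  # loads only the first file from disk

# A custom loader can replace FileIO.load, e.g. Base.read for raw bytes:
bytes = FileDataset(read, "images", "*.png")
```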
60 changes: 60 additions & 0 deletions src/containers/hdf5dataset.jl

function _check_hdf5_shapes(shapes)
    nobs = map(last, filter(!isempty, shapes))

    return all(==(first(nobs)), nobs[2:end])
end

"""
    HDF5Dataset(file::AbstractString, paths)
    HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}})
    HDF5Dataset(fid::HDF5.File, paths::Union{AbstractString, Vector{<:AbstractString}})
    HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}, shapes)

Wrap several HDF5 datasets (`paths`) as a single dataset container.
Each dataset `p` in `paths` should be accessible as `fid[p]`.
Calling `getobs` on a `HDF5Dataset` returns a tuple with each element corresponding
to the observation from each dataset in `paths`.
See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer.

For array datasets, the last dimension is assumed to be the observation dimension.
For scalar datasets, the stored value is returned by `getobs` for any index.
"""
struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer
    fid::HDF5.File
    paths::T
    shapes::Vector{Tuple}

    function HDF5Dataset(fid::HDF5.File, paths::T, shapes::Vector) where T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}
        _check_hdf5_shapes(shapes) ||
            throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations."))

        new{T}(fid, paths, shapes)
    end
end

HDF5Dataset(fid::HDF5.File, path::HDF5.Dataset) = HDF5Dataset(fid, path, [size(path)])
HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) =
    HDF5Dataset(fid, paths, map(size, paths))
HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path])
HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) =
    HDF5Dataset(fid, map(p -> fid[p], paths))
HDF5Dataset(file::AbstractString, paths) = HDF5Dataset(h5open(file, "r"), paths)

_getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset)
function _getobs_hdf5(dataset::HDF5.Dataset, shape, i)
    I = map(s -> 1:s, shape[1:(end - 1)])

    return dataset[I..., i]
end
Base.getindex(dataset::HDF5Dataset{HDF5.Dataset}, i) =
    _getobs_hdf5(dataset.paths, only(dataset.shapes), i)
Base.getindex(dataset::HDF5Dataset{<:Vector}, i) =
    Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes))
Base.length(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes)))

"""
    close(dataset::HDF5Dataset)

Close the underlying HDF5 file pointer for `dataset`.
"""
Base.close(dataset::HDF5Dataset) = close(dataset.fid)
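
A usage sketch, assuming a hypothetical `data.h5` file whose layout matches the docstring's conventions:

```julia
using MLDatasets
using MLUtils: getobs, numobs

# Suppose data.h5 stores "features" as a d × n matrix and "targets"
# as a length-n vector (last dimension = observation dimension).
dataset = HDF5Dataset("data.h5", ["features", "targets"])

x, y = getobs(dataset, 3)  # tuple: (features[:, 3], targets[3])
numobs(dataset)            # n, read from the stored shapes
close(dataset)             # release the underlying HDF5 file handle
```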
39 changes: 39 additions & 0 deletions src/containers/jld2dataset.jl

_check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end])

"""
    JLD2Dataset(file::AbstractString, paths)
    JLD2Dataset(fid::JLD2.JLDFile, paths::Union{String, Vector{String}})

Wrap several JLD2 datasets (`paths`) as a single dataset container.
Each dataset `p` in `paths` should be accessible as `fid[p]`.
Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on
each dataset in `paths`.
See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer.
"""
struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer
    fid::T
    paths::S

    function JLD2Dataset(fid::JLD2.JLDFile, paths)
        _paths = Tuple(map(p -> fid[p], paths))
        nobs = map(numobs, _paths)
        _check_jld2_nobs(nobs) ||
            throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations (got $nobs)."))

        new{typeof(fid), typeof(_paths)}(fid, _paths)
    end
end

JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,))
JLD2Dataset(file::AbstractString, paths) = JLD2Dataset(jldopen(file, "r"), paths)

Base.getindex(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i)
Base.getindex(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths)
Base.length(dataset::JLD2Dataset) = numobs(dataset.paths[1])

"""
    close(dataset::JLD2Dataset)

Close the underlying JLD2 file pointer for `dataset`.
"""
Base.close(dataset::JLD2Dataset) = close(dataset.fid)
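
A usage sketch, assuming a hypothetical `data.jld2` file with two stored arrays:

```julia
using MLDatasets
using MLUtils: getobs, numobs

# Suppose data.jld2 stores arrays under keys "x_train" and "y_train"
# with matching observation counts along their last dimension.
dataset = JLD2Dataset("data.jld2", ["x_train", "y_train"])

x, y = getobs(dataset, 1)  # observation 1 from each stored array
numobs(dataset)
close(dataset)             # release the underlying JLD2 file handle
```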
65 changes: 65 additions & 0 deletions src/containers/tabledataset.jl

"""
    TableDataset(table)
    TableDataset(path::AbstractString)

Wrap a Tables.jl-compatible `table` as a dataset container.
Alternatively, specify the `path` to a CSV file directly
to load it with CSV.jl + DataFrames.jl.
"""
struct TableDataset{T} <: AbstractDataContainer
    table::T

    # TableDatasets must implement the Tables.jl interface
    function TableDataset{T}(table::T) where {T}
        Tables.istable(table) ||
            throw(ArgumentError("TableDatasets must implement the Tables.jl interface"))

        new{T}(table)
    end
end

TableDataset(table::T) where {T} = TableDataset{T}(table)
TableDataset(path::AbstractString) = TableDataset(DataFrame(CSV.File(path)))

# slow accesses based on Tables.jl
_getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1)))
function _getobs_column(x, i)
    colnames = Tuple(Tables.columnnames(x))
    rowvals = ntuple(j -> Tables.getcolumn(x, j)[i], length(colnames))

    return NamedTuple{colnames}(rowvals)
end
function Base.getindex(dataset::TableDataset, i)
    if Tables.rowaccess(dataset.table)
        return _getobs_row(Tables.rows(dataset.table), i)
    elseif Tables.columnaccess(dataset.table)
        return _getobs_column(dataset.table, i)
    else
        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
    end
end
function Base.length(dataset::TableDataset)
    if Tables.columnaccess(dataset.table)
        return length(Tables.getcolumn(dataset.table, 1))
    elseif Tables.rowaccess(dataset.table)
        # length might not be defined, but has to be for this to work.
        return length(Tables.rows(dataset.table))
    else
        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
    end
end

# fast access for DataFrame
Base.getindex(dataset::TableDataset{<:DataFrame}, i) = dataset.table[i, :]
Base.length(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table)

# fast access for CSV.File
Base.getindex(dataset::TableDataset{<:CSV.File}, i) = dataset.table[i]
Base.length(dataset::TableDataset{<:CSV.File}) = length(dataset.table)

## Tables.jl interface

Tables.istable(::TableDataset) = true
for fn in (:rowaccess, :rows, :columnaccess, :columns, :schema, :materializer)
    @eval Tables.$fn(dataset::TableDataset) = Tables.$fn(dataset.table)
end
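
A minimal usage sketch with an in-memory `DataFrame`, which takes the fast specialized path:

```julia
using MLDatasets, DataFrames
using MLUtils: getobs, numobs

df = DataFrame(a = [1, 2, 3], b = ["x", "y", "z"])
dataset = TableDataset(df)

numobs(dataset)     # 3 (via nrow)
getobs(dataset, 2)  # second row, returned as a DataFrameRow

# Or point directly at a CSV file (hypothetical path):
# dataset = TableDataset("data.csv")
```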