From 834d9b83597a25f0d08598d32c02663b32f1cdca Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 19 Feb 2022 12:12:18 -0600 Subject: [PATCH 01/19] Initial port of FastAI dataset containers --- Project.toml | 6 ++++ src/MLDatasets.jl | 11 +++++++ src/containers/filedataset.jl | 38 +++++++++++++++++++++++ src/containers/tabledataset.jl | 47 +++++++++++++++++++++++++++++ test/containers/tabledataset.jl | 53 +++++++++++++++++++++++++++++++++ test/runtests.jl | 15 ++++++++-- 6 files changed, 167 insertions(+), 3 deletions(-) create mode 100644 src/containers/filedataset.jl create mode 100644 src/containers/tabledataset.jl create mode 100644 test/containers/tabledataset.jl diff --git a/Project.toml b/Project.toml index c6b684a5..ffa89571 100644 --- a/Project.toml +++ b/Project.toml @@ -4,16 +4,22 @@ version = "0.5.15" [deps] BinDeps = "9e28174c-4ba2-5203-b857-d8d62c4213ee" +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" +FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f" FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" MAT = "23992714-dd62-5051-b70f-ba57cb901cac" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] BinDeps = "1" diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl index 25c7f19b..8f30fb75 100644 --- a/src/MLDatasets.jl +++ b/src/MLDatasets.jl @@ -7,6 +7,15 @@ using DelimitedFiles: readdlm using FixedPointNumbers, ColorTypes using Pickle using SparseArrays +using DataFrames, CSV, Tables 
+using FilePathsBase +using FilePathsBase: AbstractPath + +import MLUtils +using MLUtils: getobs, numobs, AbstractDataContainer + +export FileDataset, TableDataset +export getobs, numobs # Julia 1.0 compatibility if !isdefined(Base, :isnothing) @@ -36,6 +45,8 @@ end include("download.jl") +include("containers/filedataset.jl") +include("containers/tabledataset.jl") # Misc. include("BostonHousing/BostonHousing.jl") diff --git a/src/containers/filedataset.jl b/src/containers/filedataset.jl new file mode 100644 index 00000000..af4a1bec --- /dev/null +++ b/src/containers/filedataset.jl @@ -0,0 +1,38 @@ +# FileDataset + +""" + rglob(filepattern, dir = pwd(), depth = 4) + +Recursive glob up to `depth` layers deep within `dir`. +""" +function rglob(filepattern = "*", dir = pwd(), depth = 4) + patterns = [repeat("*/", i) * filepattern for i in 0:(depth - 1)] + + return vcat([glob(pattern, dir) for pattern in patterns]...) +end + +""" + loadfile(file) + +Load a file from disk into the appropriate format. 
+""" +function loadfile(file::String) + if isimagefile(file) + # faster image loading + return FileIO.load(file, view = true) + elseif endswith(file, ".csv") + return DataFrame(CSV.File(file)) + else + return FileIO.load(file) + end +end +loadfile(file::AbstractPath) = loadfile(string(file)) + +struct FileDataset{T} <: AbstractDataContainer + paths::T +end + +FileDataset(dir, pattern = "*", depth = 4) = rglob(pattern, string(dir), depth) + +MLUtils.getobs(dataset::FileDataset, i) = loadfile(dataset.paths[i]) +MLUtils.numobs(dataset::FileDataset) = length(dataset.paths) diff --git a/src/containers/tabledataset.jl b/src/containers/tabledataset.jl new file mode 100644 index 00000000..c1c75bfc --- /dev/null +++ b/src/containers/tabledataset.jl @@ -0,0 +1,47 @@ +struct TableDataset{T} <: AbstractDataContainer + table::T + + # TableDatasets must implement the Tables.jl interface + function TableDataset{T}(table::T) where {T} + Tables.istable(table) || + throw(ArgumentError("TableDatasets must implement the Tabels.jl interface")) + + new{T}(table) + end +end + +TableDataset(table::T) where {T} = TableDataset{T}(table) +TableDataset(path::Union{AbstractPath, AbstractString}) = + TableDataset(DataFrame(CSV.File(path))) + +# slow accesses based on Tables.jl +function MLUtils.getobs(dataset::TableDataset, i) + if Tables.rowaccess(dataset.table) + row, _ = Iterators.peel(Iterators.drop(Tables.rows(dataset.table), i - 1)) + return row + elseif Tables.columnaccess(dataset.table) + colnames = Tables.columnnames(dataset.table) + rowvals = [Tables.getcolumn(dataset.table, j)[i] for j in 1:length(colnames)] + return (; zip(colnames, rowvals)...) 
+ else + error("The Tables.jl implementation used should have either rowaccess or columnaccess.") + end +end +function MLUtils.numobs(dataset::TableDataset) + if Tables.columnaccess(dataset.table) + return length(Tables.getcolumn(dataset.table, 1)) + elseif Tables.rowaccess(dataset.table) + # length might not be defined, but has to be for this to work. + return length(Tables.rows(dataset.table)) + else + error("The Tables.jl implementation used should have either rowaccess or columnaccess.") + end +end + +# fast access for DataFrame +MLUtils.getobs(dataset::TableDataset{<:DataFrame}, i) = dataset.table[i, :] +MLUtils.numobs(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table) + +# fast access for CSV.File +MLUtils.getobs(dataset::TableDataset{<:CSV.File}, i) = dataset.table[i] +MLUtils.numobs(dataset::TableDataset{<:CSV.File}) = length(dataset.table) diff --git a/test/containers/tabledataset.jl b/test/containers/tabledataset.jl new file mode 100644 index 00000000..92d76847 --- /dev/null +++ b/test/containers/tabledataset.jl @@ -0,0 +1,53 @@ +@testset "TableDataset" begin + @testset "TableDataset from rowaccess table" begin + Tables.columnaccess(::Type{<:Tables.MatrixTable}) = false + Tables.rowaccess(::Type{<:Tables.MatrixTable}) = true + + testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"]) + td = TableDataset(testtable) + + @test all(getobs(td, 1) .== [1, 4.0, "7"]) + @test numobs(td) == 3 + end + + @testset "TableDataset from columnaccess table" begin + Tables.columnaccess(::Type{<:Tables.MatrixTable}) = true + Tables.rowaccess(::Type{<:Tables.MatrixTable}) = false + + testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"]) + td = TableDataset(testtable) + + @test [data for data in getobs(td, 2)] == [2, 5.0, "8"] + @test numobs(td) == 3 + + @test getobs(td, 1) isa NamedTuple + end + + @testset "TableDataset from DataFrames" begin + testtable = DataFrame( + col1 = [1, 2, 3, 4, 5], + col2 = ["a", "b", "c", "d", "e"], + col3 = [10, 20, 30, 40, 50], 
+ col4 = ["A", "B", "C", "D", "E"], + col5 = [100.0, 200.0, 300.0, 400.0, 500.0], + split = ["train", "train", "train", "valid", "valid"], + ) + td = TableDataset(testtable) + @test td isa TableDataset{<:DataFrame} + + @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"] + @test numobs(td) == 5 + end + + @testset "TableDataset from CSV" begin + open("test.csv", "w") do io + write(io, "col1,col2,col3,col4,col5, split\n1,a,10,A,100.,train") + end + testtable = CSV.File("test.csv") + td = TableDataset(testtable) + @test td isa TableDataset{<:CSV.File} + @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"] + @test numobs(td) == 1 + rm("test.csv") + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d11559db..0e0e8fc9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,11 +2,12 @@ using Test using MLDatasets using ImageCore using DataDeps - +using MLUtils: getobs, numobs +using DataFrames, CSV, Tables ENV["DATADEPS_ALWAYS_ACCEPT"] = true -tests = [ +dataset_tests = [ # misc "tst_iris.jl", "tst_boston_housing.jl", @@ -26,12 +27,20 @@ tests = [ "tst_tudataset.jl", ] -for t in tests +container_tests = [ + "containers/tabledataset.jl", +] + +@testset "Datasets" for t in dataset_tests @testset "$t" begin include(t) end end +@testset "Containers" for t in container_tests + include(t) +end + #temporary to not stress CI if !parse(Bool, get(ENV, "CI", "false")) @testset "other tests" begin From 83452ed09b2d99ff2f0343cb5bbd2fb9b511f25e Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 19 Feb 2022 12:37:51 -0600 Subject: [PATCH 02/19] Drop Julia < 1.6 support and add bounds --- .github/workflows/UnitTest.yml | 4 ++-- Project.toml | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/UnitTest.yml b/.github/workflows/UnitTest.yml index 6e920762..be18c8fa 100644 --- a/.github/workflows/UnitTest.yml +++ b/.github/workflows/UnitTest.yml @@ -16,12 +16,12 @@ jobs: strategy: 
fail-fast: false matrix: - julia-version: ['1.3', '1', 'nightly'] + julia-version: ['1.6', '1', 'nightly'] os: [ubuntu-latest, windows-latest, macOS-latest, macos-11] env: PYTHON: "" steps: - - uses: actions/checkout@v1.0.0 + - uses: actions/checkout@v2 - name: "Set up Julia" uses: julia-actions/setup-julia@v1 with: diff --git a/Project.toml b/Project.toml index ffa89571..ed4b6d96 100644 --- a/Project.toml +++ b/Project.toml @@ -24,15 +24,18 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] BinDeps = "1" ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11" +CSV = "0.10.2" DataDeps = "0.3, 0.4, 0.5, 0.6, 0.7" +DataFrames = "1.3" FixedPointNumbers = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8" GZip = "0.5" ImageCore = "0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8" JSON3 = "1" MAT = "0.7, 0.8, 0.9, 0.10" +MLUtils = "0.1.4" Pickle = "0.2, 0.3" Requires = "1" -julia = "1.3" +julia = "1.6" [extras] ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534" From 757b706bcd7a0e0441f653d8d92be39382efb233 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 19 Feb 2022 13:13:01 -0600 Subject: [PATCH 03/19] Add some docstrings and test for FileDataset --- Project.toml | 4 +++- src/MLDatasets.jl | 1 + src/containers/filedataset.jl | 15 ++++++++++++-- src/containers/tabledataset.jl | 8 ++++++++ test/containers/filedataset.jl | 36 ++++++++++++++++++++++++++++++++++ test/runtests.jl | 7 +++++-- 6 files changed, 66 insertions(+), 5 deletions(-) create mode 100644 test/containers/filedataset.jl diff --git a/Project.toml b/Project.toml index ed4b6d96..89bc8f68 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f" FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" +Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" MAT = "23992714-dd62-5051-b70f-ba57cb901cac" MLUtils = 
"f1d291b0-491e-4a28-83b9-f70985020b54" @@ -23,11 +24,12 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] BinDeps = "1" -ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11" CSV = "0.10.2" +ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11" DataDeps = "0.3, 0.4, 0.5, 0.6, 0.7" DataFrames = "1.3" FixedPointNumbers = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8" +Glob = "1.3" GZip = "0.5" ImageCore = "0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8" JSON3 = "1" diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl index 8f30fb75..e34b2839 100644 --- a/src/MLDatasets.jl +++ b/src/MLDatasets.jl @@ -10,6 +10,7 @@ using SparseArrays using DataFrames, CSV, Tables using FilePathsBase using FilePathsBase: AbstractPath +using Glob import MLUtils using MLUtils: getobs, numobs, AbstractDataContainer diff --git a/src/containers/filedataset.jl b/src/containers/filedataset.jl index af4a1bec..b0c677b7 100644 --- a/src/containers/filedataset.jl +++ b/src/containers/filedataset.jl @@ -1,4 +1,7 @@ -# FileDataset +matches(re::Regex) = f -> matches(re, f) +matches(re::Regex, f) = !isnothing(match(re, f)) +const RE_IMAGEFILE = r".*\.(gif|jpe?g|tiff?|png|webp|bmp)$"i +isimagefile(f) = matches(RE_IMAGEFILE, f) """ rglob(filepattern, dir = pwd(), depth = 4) @@ -28,11 +31,19 @@ function loadfile(file::String) end loadfile(file::AbstractPath) = loadfile(string(file)) +""" + FileDataset(paths) + FileDataset(dir, pattern = "*", depth = 4) + +Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`). +Alternatively, specify a `dir` and collect all paths that match a glob `pattern` +(recursively globbing by `depth`). The glob order determines the traversal order. 
+""" struct FileDataset{T} <: AbstractDataContainer paths::T end -FileDataset(dir, pattern = "*", depth = 4) = rglob(pattern, string(dir), depth) +FileDataset(dir, pattern = "*", depth = 4) = FileDataset(rglob(pattern, string(dir), depth)) MLUtils.getobs(dataset::FileDataset, i) = loadfile(dataset.paths[i]) MLUtils.numobs(dataset::FileDataset) = length(dataset.paths) diff --git a/src/containers/tabledataset.jl b/src/containers/tabledataset.jl index c1c75bfc..56cbcfed 100644 --- a/src/containers/tabledataset.jl +++ b/src/containers/tabledataset.jl @@ -1,3 +1,11 @@ +""" + TableDataset(table) + TableDataset(path::Union{AbstractPath, AbstractString}) + +Wrap a Tables.jl-compatible `table` as a dataset container. +Alternatively, specify the `path` to a CSV file directly +to load it with CSV.jl + DataFrames.jl. +""" struct TableDataset{T} <: AbstractDataContainer table::T diff --git a/test/containers/filedataset.jl b/test/containers/filedataset.jl new file mode 100644 index 00000000..3a07fb88 --- /dev/null +++ b/test/containers/filedataset.jl @@ -0,0 +1,36 @@ +function setup_filedataset_test() + files = [ + "root/p1/f1.csv", + "root/p2/f2.csv", + "root/p2/p2p1/f2.csv", + "root/p3/p3p1/f1.csv" + ] + + for (i, file) in enumerate(files) + paths = splitpath(file)[1:(end - 1)] + root = "" + for p in paths + fullp = joinpath(root, p) + isdir(fullp) || mkdir(fullp) + root = fullp + end + + open(file, "w") do io + write(io, "a,b,c\n") + write(io, join(i .* [1, 2, 3], ",")) + end + end + + return files +end + +@testset "FileDataset" begin + files = setup_filedataset_test() + dataset = FileDataset("root", "*.csv") + @test numobs(dataset) == length(files) + for (i, file) in enumerate(files) + true_obs = MLDatasets.loadfile(file) + @test getobs(dataset, i) == true_obs + end + rm("root"; recursive = true) +end diff --git a/test/runtests.jl b/test/runtests.jl index 0e0e8fc9..0791d8b6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,7 @@ dataset_tests = [ ] 
container_tests = [ + "containers/filedataset.jl", "containers/tabledataset.jl", ] @@ -37,8 +38,10 @@ container_tests = [ end end -@testset "Containers" for t in container_tests - include(t) +@testset "Containers" begin + for t in container_tests + include(t) + end end #temporary to not stress CI From 4e6807fea054ad6e5f600da94cfb974b0c532337 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 20 Feb 2022 12:18:26 -0600 Subject: [PATCH 04/19] Add HDF5 dataset --- Project.toml | 3 +- src/MLDatasets.jl | 4 ++- src/containers/filedataset.jl | 4 +-- src/containers/hdf5dataset.jl | 50 ++++++++++++++++++++++++++++++++++ test/containers/hdf5dataset.jl | 28 +++++++++++++++++++ test/runtests.jl | 2 ++ 6 files changed, 87 insertions(+), 4 deletions(-) create mode 100644 src/containers/hdf5dataset.jl create mode 100644 test/containers/hdf5dataset.jl diff --git a/Project.toml b/Project.toml index 89bc8f68..be523682 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f" FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" +HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" MAT = "23992714-dd62-5051-b70f-ba57cb901cac" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" @@ -29,8 +30,8 @@ ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11" DataDeps = "0.3, 0.4, 0.5, 0.6, 0.7" DataFrames = "1.3" FixedPointNumbers = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8" -Glob = "1.3" GZip = "0.5" +Glob = "1.3" ImageCore = "0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8" JSON3 = "1" MAT = "0.7, 0.8, 0.9, 0.10" diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl index e34b2839..5f83f9f6 100644 --- a/src/MLDatasets.jl +++ b/src/MLDatasets.jl @@ -11,11 +11,12 @@ using DataFrames, CSV, Tables using FilePathsBase using FilePathsBase: AbstractPath using Glob +using HDF5 import MLUtils using MLUtils: getobs, numobs, 
AbstractDataContainer -export FileDataset, TableDataset +export FileDataset, TableDataset, HDF5Dataset export getobs, numobs # Julia 1.0 compatibility @@ -48,6 +49,7 @@ include("download.jl") include("containers/filedataset.jl") include("containers/tabledataset.jl") +include("containers/hdf5dataset.jl") # Misc. include("BostonHousing/BostonHousing.jl") diff --git a/src/containers/filedataset.jl b/src/containers/filedataset.jl index b0c677b7..0d0c463b 100644 --- a/src/containers/filedataset.jl +++ b/src/containers/filedataset.jl @@ -39,8 +39,8 @@ Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`) Alternatively, specify a `dir` and collect all paths that match a glob `pattern` (recursively globbing by `depth`). The glob order determines the traversal order. """ -struct FileDataset{T} <: AbstractDataContainer - paths::T +struct FileDataset{T<:Union{AbstractPath, AbstractString}} <: AbstractDataContainer + paths::Vector{T} end FileDataset(dir, pattern = "*", depth = 4) = FileDataset(rglob(pattern, string(dir), depth)) diff --git a/src/containers/hdf5dataset.jl b/src/containers/hdf5dataset.jl new file mode 100644 index 00000000..e80160c8 --- /dev/null +++ b/src/containers/hdf5dataset.jl @@ -0,0 +1,50 @@ +function _check_hdf5_shapes(shapes) + nobs = map(last, filter(!isempty, shapes)) + + return all(==(first(nobs)), nobs[2:end]) +end + +""" + HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths) + HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) + HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) + HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes) + +Wrap several HDF5 datasets (`paths`) as a single dataset container. +Each dataset `p` in `paths` should be accessible as `fid[p]`. +See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer. + +For array datasets, the last dimension is assumed to be the observation dimension. 
+For scalar datasets, the stored value is returned by `getobs` for any index. +""" +struct HDF5Dataset + fid::HDF5.File + paths::Vector{HDF5.Dataset} + shapes::Vector{Tuple} + + function HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes::Vector) + _check_hdf5_shapes(shapes) || + throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatch number of observations.")) + + new(fid, paths, shapes) + end +end + +HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) = + HDF5Dataset(fid, paths, map(size, paths)) +HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) = + HDF5Dataset(fid, map(p -> fid[p], paths)) +HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths) = + HDF5Dataset(h5open(file, "r"), paths) + +MLUtils.getobs(dataset::HDF5Dataset, i) = Tuple(map(dataset.paths, dataset.shapes) do path, shape + if isempty(shape) + return read(path) + else + I = map(s -> 1:s, shape[1:(end - 1)]) + return path[I..., i] + end +end) +MLUtils.numobs(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes))) + +Base.close(dataset::HDF5Dataset) = close(dataset.fid) diff --git a/test/containers/hdf5dataset.jl b/test/containers/hdf5dataset.jl new file mode 100644 index 00000000..5398b6c1 --- /dev/null +++ b/test/containers/hdf5dataset.jl @@ -0,0 +1,28 @@ +function setup_hdf5dataset_test() + datasets = [ + ("d1", rand(2, 10)), + ("g1/d1", rand(10)), + # these are broken at HDF5 level + # ("g1/d2", string.('a':'j')), + # ("g2/g1/d1", "test") + ] + + fid = h5open("test.h5", "w") + for (path, data) in datasets + fid[path] = data + end + close(fid) + + return first.(datasets), last.(datasets) +end + +@testset "HDF5Dataset" begin + paths, datas = setup_hdf5dataset_test() + dataset = HDF5Dataset("test.h5", paths) + for i in 1:10 + data = Tuple(map(x -> getobs(x, i), datas)) + @test getobs(dataset, i) == data + end + @test numobs(dataset) == 10 + rm("test.h5") +end diff --git a/test/runtests.jl b/test/runtests.jl index 0791d8b6..5f34f58d 
100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using ImageCore using DataDeps using MLUtils: getobs, numobs using DataFrames, CSV, Tables +using HDF5 ENV["DATADEPS_ALWAYS_ACCEPT"] = true @@ -30,6 +31,7 @@ dataset_tests = [ container_tests = [ "containers/filedataset.jl", "containers/tabledataset.jl", + "containers/hdf5dataset.jl", ] @testset "Datasets" for t in dataset_tests From 9a4531202ef3c837a3bd466fe3b502c1cbd0528a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 20 Feb 2022 12:36:01 -0600 Subject: [PATCH 05/19] Close HDF5 file before deleting in tests --- test/containers/hdf5dataset.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/containers/hdf5dataset.jl b/test/containers/hdf5dataset.jl index 5398b6c1..c1def2c2 100644 --- a/test/containers/hdf5dataset.jl +++ b/test/containers/hdf5dataset.jl @@ -24,5 +24,6 @@ end @test getobs(dataset, i) == data end @test numobs(dataset) == 10 + close(dataset) rm("test.h5") end From 3ecc35b469eb64ae3c1d74bcf8e977a8943bd78e Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 20 Feb 2022 12:38:35 -0600 Subject: [PATCH 06/19] Update HDF5 docstrings --- src/containers/hdf5dataset.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/containers/hdf5dataset.jl b/src/containers/hdf5dataset.jl index e80160c8..94a6ae40 100644 --- a/src/containers/hdf5dataset.jl +++ b/src/containers/hdf5dataset.jl @@ -12,6 +12,8 @@ end Wrap several HDF5 datasets (`paths`) as a single dataset container. Each dataset `p` in `paths` should be accessible as `fid[p]`. +Calling `getobs` on a `HDF5Dataset` returns a tuple with each element corresponding +to the observation from each dataset in `paths`. See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer. For array datasets, the last dimension is assumed to be the observation dimension. 
@@ -47,4 +49,9 @@ MLUtils.getobs(dataset::HDF5Dataset, i) = Tuple(map(dataset.paths, dataset.shape end) MLUtils.numobs(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes))) +""" + close(dataset::HDF5Dataset) + +Close the underlying HDF5 file pointer for `dataset`. +""" Base.close(dataset::HDF5Dataset) = close(dataset.fid) From 7bba084edf1fe321f346e39c32fc9fb492ad1481 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 20 Feb 2022 12:55:54 -0600 Subject: [PATCH 07/19] Fix broken HDF5 string tests --- test/containers/hdf5dataset.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/containers/hdf5dataset.jl b/test/containers/hdf5dataset.jl index c1def2c2..b53454d2 100644 --- a/test/containers/hdf5dataset.jl +++ b/test/containers/hdf5dataset.jl @@ -2,9 +2,8 @@ function setup_hdf5dataset_test() datasets = [ ("d1", rand(2, 10)), ("g1/d1", rand(10)), - # these are broken at HDF5 level - # ("g1/d2", string.('a':'j')), - # ("g2/g1/d1", "test") + ("g1/d2", string.('a':'j')), + ("g2/g1/d1", "test") ] fid = h5open("test.h5", "w") @@ -20,7 +19,7 @@ end paths, datas = setup_hdf5dataset_test() dataset = HDF5Dataset("test.h5", paths) for i in 1:10 - data = Tuple(map(x -> getobs(x, i), datas)) + data = Tuple(map(x -> (x isa String) ? 
x : getobs(x, i), datas)) @test getobs(dataset, i) == data end @test numobs(dataset) == 10 From 92e0d066e6d83b465932883f18ed09cd56a7837e Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 20 Feb 2022 13:42:38 -0600 Subject: [PATCH 08/19] Add JLD2Dataset --- Project.toml | 1 + src/MLDatasets.jl | 4 +++- src/containers/hdf5dataset.jl | 4 ++-- src/containers/jld2dataset.jl | 37 ++++++++++++++++++++++++++++++++++ test/containers/jld2dataset.jl | 27 +++++++++++++++++++++++++ test/runtests.jl | 2 ++ 6 files changed, 72 insertions(+), 3 deletions(-) create mode 100644 src/containers/jld2dataset.jl create mode 100644 test/containers/jld2dataset.jl diff --git a/Project.toml b/Project.toml index be523682..d89e92a4 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" MAT = "23992714-dd62-5051-b70f-ba57cb901cac" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl index 5f83f9f6..126396f3 100644 --- a/src/MLDatasets.jl +++ b/src/MLDatasets.jl @@ -12,11 +12,12 @@ using FilePathsBase using FilePathsBase: AbstractPath using Glob using HDF5 +using JLD2 import MLUtils using MLUtils: getobs, numobs, AbstractDataContainer -export FileDataset, TableDataset, HDF5Dataset +export FileDataset, TableDataset, HDF5Dataset, JLD2Dataset export getobs, numobs # Julia 1.0 compatibility @@ -50,6 +51,7 @@ include("download.jl") include("containers/filedataset.jl") include("containers/tabledataset.jl") include("containers/hdf5dataset.jl") +include("containers/jld2dataset.jl") # Misc. 
include("BostonHousing/BostonHousing.jl") diff --git a/src/containers/hdf5dataset.jl b/src/containers/hdf5dataset.jl index 94a6ae40..e604dbe6 100644 --- a/src/containers/hdf5dataset.jl +++ b/src/containers/hdf5dataset.jl @@ -19,14 +19,14 @@ See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer. For array datasets, the last dimension is assumed to be the observation dimension. For scalar datasets, the stored value is returned by `getobs` for any index. """ -struct HDF5Dataset +struct HDF5Dataset <: AbstractDataContainer fid::HDF5.File paths::Vector{HDF5.Dataset} shapes::Vector{Tuple} function HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes::Vector) _check_hdf5_shapes(shapes) || - throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatch number of observations.")) + throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations.")) new(fid, paths, shapes) end diff --git a/src/containers/jld2dataset.jl b/src/containers/jld2dataset.jl new file mode 100644 index 00000000..edbd9384 --- /dev/null +++ b/src/containers/jld2dataset.jl @@ -0,0 +1,37 @@ +_check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end]) + +""" + JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) + JLD2Dataset(fid::JLD2.JLDFile, paths) + +Wrap several JLD2 datasets (`paths`) as a single dataset container. +Each dataset `p` in `paths` should be accessible as `fid[p]`. +Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on +each dataset in `paths`. +See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer. 
+""" +struct JLD2Dataset{T<:JLD2.JLDFile} <: AbstractDataContainer + fid::T + paths::Vector{String} + + function JLD2Dataset(fid::JLD2.JLDFile, paths::Vector{String}) + nobs = map(p -> numobs(fid[p]), paths) + _check_jld2_nobs(nobs) || + throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations.")) + + new{typeof(fid)}(fid, paths) + end +end + +JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) = + JLD2Dataset(jldopen(file, "r"), paths) + +MLUtils.getobs(dataset::JLD2Dataset, i) = Tuple(map(p -> getobs(dataset.fid[p], i), dataset.paths)) +MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.fid[dataset.paths[1]]) + +""" + close(dataset::JLD2Dataset) + +Close the underlying JLD2 file pointer for `dataset`. +""" +Base.close(dataset::JLD2Dataset) = close(dataset.fid) diff --git a/test/containers/jld2dataset.jl b/test/containers/jld2dataset.jl new file mode 100644 index 00000000..2f91c730 --- /dev/null +++ b/test/containers/jld2dataset.jl @@ -0,0 +1,27 @@ +function setup_jld2dataset_test() + datasets = [ + ("d1", rand(2, 10)), + ("g1/d1", rand(10)), + ("g1/d2", string.('a':'j')), + ("g2/g1/d1", rand(Bool, 3, 3, 10)) + ] + + fid = jldopen("test.jld2", "w") + for (path, data) in datasets + fid[path] = data + end + close(fid) + + return first.(datasets), Tuple(last.(datasets)) +end + +@testset "JLD2Dataset" begin + paths, datas = setup_jld2dataset_test() + dataset = JLD2Dataset("test.jld2", paths) + for i in 1:10 + @test getobs(dataset, i) == getobs(datas, i) + end + @test numobs(dataset) == 10 + close(dataset) + rm("test.jld2") +end diff --git a/test/runtests.jl b/test/runtests.jl index 5f34f58d..ca0c401e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using DataDeps using MLUtils: getobs, numobs using DataFrames, CSV, Tables using HDF5 +using JLD2 ENV["DATADEPS_ALWAYS_ACCEPT"] = true @@ -32,6 +33,7 @@ container_tests = [ "containers/filedataset.jl", "containers/tabledataset.jl", 
"containers/hdf5dataset.jl", + "containers/jld2dataset.jl", ] @testset "Datasets" for t in dataset_tests From 3d48893edd7a0120f8698421638556681139989c Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 21 Feb 2022 13:30:54 -0600 Subject: [PATCH 09/19] Support custom loading function in FileDataset --- src/containers/filedataset.jl | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/containers/filedataset.jl b/src/containers/filedataset.jl index 0d0c463b..c126bc49 100644 --- a/src/containers/filedataset.jl +++ b/src/containers/filedataset.jl @@ -32,18 +32,25 @@ end loadfile(file::AbstractPath) = loadfile(string(file)) """ - FileDataset(paths) - FileDataset(dir, pattern = "*", depth = 4) + FileDataset([loadfn = loadfile,] paths) + FileDataset([loadfn = loadfile,] dir, pattern = "*", depth = 4) Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`). Alternatively, specify a `dir` and collect all paths that match a glob `pattern` (recursively globbing by `depth`). The glob order determines the traversal order. 
""" -struct FileDataset{T<:Union{AbstractPath, AbstractString}} <: AbstractDataContainer +struct FileDataset{F, T<:Union{AbstractPath, AbstractString}} <: AbstractDataContainer + loadfn::F paths::Vector{T} end -FileDataset(dir, pattern = "*", depth = 4) = FileDataset(rglob(pattern, string(dir), depth)) +FileDataset(paths) = FileDataset(loadfile, paths) +FileDataset(loadfn, + dir::Union{AbstractPath, AbstractString}, + pattern::AbstractString = "*", + depth = 4) = FileDataset(loadfn, rglob(pattern, string(dir), depth)) +FileDataset(dir::Union{AbstractPath, AbstractString}, pattern::AbstractString = "*", depth = 4) = + FileDataset(loadfile, dir, pattern, depth) MLUtils.getobs(dataset::FileDataset, i) = loadfile(dataset.paths[i]) MLUtils.numobs(dataset::FileDataset) = length(dataset.paths) From a30bfabe9f0e6c7f58d03c595914d4fc70e43552 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 22 Feb 2022 14:05:07 -0600 Subject: [PATCH 10/19] Add CachedDataset --- src/MLDatasets.jl | 8 +++++-- src/containers/cacheddataset.jl | 39 ++++++++++++++++++++++++++++++++ src/containers/filedataset.jl | 3 ++- test/containers/cacheddataset.jl | 25 ++++++++++++++++++++ test/containers/filedataset.jl | 3 ++- test/containers/hdf5dataset.jl | 3 ++- test/containers/jld2dataset.jl | 3 ++- test/runtests.jl | 2 +- 8 files changed, 79 insertions(+), 7 deletions(-) create mode 100644 src/containers/cacheddataset.jl create mode 100644 test/containers/cacheddataset.jl diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl index 126396f3..384a9679 100644 --- a/src/MLDatasets.jl +++ b/src/MLDatasets.jl @@ -16,8 +16,6 @@ using JLD2 import MLUtils using MLUtils: getobs, numobs, AbstractDataContainer - -export FileDataset, TableDataset, HDF5Dataset, JLD2Dataset export getobs, numobs # Julia 1.0 compatibility @@ -49,9 +47,15 @@ end include("download.jl") include("containers/filedataset.jl") +export FileDataset include("containers/tabledataset.jl") +export TableDataset 
include("containers/hdf5dataset.jl") +export HDF5Dataset include("containers/jld2dataset.jl") +export JLD2Dataset +include("containers/cacheddataset.jl") +export CachedDataset # Misc. include("BostonHousing/BostonHousing.jl") diff --git a/src/containers/cacheddataset.jl b/src/containers/cacheddataset.jl new file mode 100644 index 00000000..aa64ed5d --- /dev/null +++ b/src/containers/cacheddataset.jl @@ -0,0 +1,39 @@ +""" + make_cache(source, cacheidx) + +Return a in-memory copy of `source` at observation indices `cacheidx`. +Defaults to `getobs(source, cacheidx)`. +""" +make_cache(source, cacheidx) = getobs(source, cacheidx) + +""" + CachedDataset(source, cachesize = numbobs(source)) + CachedDataset(source, cacheidx, cache) + +Wrap a `source` data container and cache `cachesize` samples in memory. +This can be useful for improving read speeds when `source` is a lazy data container, +but your system memory is large enough to store a sizeable chunk of it. + +By default the observation indices `1:cachesize` are cached. +You can manually pass in a `cache` and set of `cacheidx` as well. + +See also [`make_cache`](@ref) for customizing the default cache creation for `source`. +""" +struct CachedDataset{T, S} + source::T + cacheidx::Vector{Int} + cache::S +end + +function CachedDataset(source, cachesize::Int = numobs(source)) + cacheidx = 1:cachesize + + CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx)) +end + +function MLUtils.getobs(dataset::CachedDataset, i::Integer) + _i = findfirst(==(i), dataset.cacheidx) + + return isnothing(_i) ? 
getobs(dataset.source, i) : getobs(dataset.cache, _i) +end +MLUtils.numobs(dataset::CachedDataset) = numobs(dataset.source) diff --git a/src/containers/filedataset.jl b/src/containers/filedataset.jl index c126bc49..ae6434ad 100644 --- a/src/containers/filedataset.jl +++ b/src/containers/filedataset.jl @@ -52,5 +52,6 @@ FileDataset(loadfn, FileDataset(dir::Union{AbstractPath, AbstractString}, pattern::AbstractString = "*", depth = 4) = FileDataset(loadfile, dir, pattern, depth) -MLUtils.getobs(dataset::FileDataset, i) = loadfile(dataset.paths[i]) +MLUtils.getobs(dataset::FileDataset, i::Integer) = loadfile(dataset.paths[i]) +MLUtils.getobs(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is) MLUtils.numobs(dataset::FileDataset) = length(dataset.paths) diff --git a/test/containers/cacheddataset.jl b/test/containers/cacheddataset.jl new file mode 100644 index 00000000..ff70aab4 --- /dev/null +++ b/test/containers/cacheddataset.jl @@ -0,0 +1,25 @@ +@testset "CachedDataset" begin + @testset "CachedDataset(::FileDataset)" begin + files = setup_filedataset_test() + fdataset = FileDataset("root", "*.csv") + cdataset = CachedDataset(fdataset) + + @test numobs(cdataset) == numobs(fdataset) + @test cdataset.cache == getobs(fdataset, 1:numobs(fdataset)) + @test all(getobs(cdataset, i) == getobs(fdataset, i) for i in 1:numobs(fdataset)) + + cleanup_filedataset_test() + end + + @testset "CachedDataset(::HDF5Dataset)" begin + paths, datas = setup_hdf5dataset_test() + hdataset = HDF5Dataset("test.h5", ["d1"]) + cdataset = CachedDataset(hdataset, 5) + + @test numobs(cdataset) == numobs(hdataset) + @test cdataset.cache == getobs(hdataset, 1:5) + @test all(getobs(cdataset, i) == getobs(hdataset, i) for i in 1:10) + + cleanup_hdf5dataset_test() + end +end diff --git a/test/containers/filedataset.jl b/test/containers/filedataset.jl index 3a07fb88..f143aae3 100644 --- a/test/containers/filedataset.jl +++ b/test/containers/filedataset.jl @@ -23,6 +23,7 @@ 
function setup_filedataset_test() return files end +cleanup_filedataset_test() = rm("root"; recursive = true) @testset "FileDataset" begin files = setup_filedataset_test() @@ -32,5 +33,5 @@ end true_obs = MLDatasets.loadfile(file) @test getobs(dataset, i) == true_obs end - rm("root"; recursive = true) + cleanup_filedataset_test() end diff --git a/test/containers/hdf5dataset.jl b/test/containers/hdf5dataset.jl index b53454d2..878cc262 100644 --- a/test/containers/hdf5dataset.jl +++ b/test/containers/hdf5dataset.jl @@ -14,6 +14,7 @@ function setup_hdf5dataset_test() return first.(datasets), last.(datasets) end +cleanup_hdf5dataset_test() = rm("test.h5") @testset "HDF5Dataset" begin paths, datas = setup_hdf5dataset_test() @@ -24,5 +25,5 @@ end end @test numobs(dataset) == 10 close(dataset) - rm("test.h5") + cleanup_hdf5dataset_test() end diff --git a/test/containers/jld2dataset.jl b/test/containers/jld2dataset.jl index 2f91c730..d849d8fa 100644 --- a/test/containers/jld2dataset.jl +++ b/test/containers/jld2dataset.jl @@ -14,6 +14,7 @@ function setup_jld2dataset_test() return first.(datasets), Tuple(last.(datasets)) end +cleanup_jld2dataset_test() = rm("test.jld2") @testset "JLD2Dataset" begin paths, datas = setup_jld2dataset_test() @@ -23,5 +24,5 @@ end end @test numobs(dataset) == 10 close(dataset) - rm("test.jld2") + cleanup_jld2dataset_test() end diff --git a/test/runtests.jl b/test/runtests.jl index ca0c401e..2844d26a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,6 @@ using Test using MLDatasets using ImageCore using DataDeps -using MLUtils: getobs, numobs using DataFrames, CSV, Tables using HDF5 using JLD2 @@ -34,6 +33,7 @@ container_tests = [ "containers/tabledataset.jl", "containers/hdf5dataset.jl", "containers/jld2dataset.jl", + "containers/cacheddataset.jl", ] @testset "Datasets" for t in dataset_tests From b7ba9c4500e57bbd0bcc6141c32e1818613eeb9a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 22 Feb 2022 16:19:46 -0600 Subject: 
[PATCH 11/19] Special case single path HDF5 and JLD2 datasets and add `@inferred` tests --- src/containers/hdf5dataset.jl | 34 ++++++++++++++++++-------------- src/containers/jld2dataset.jl | 21 +++++++++++--------- src/containers/tabledataset.jl | 14 ++++++++----- test/containers/cacheddataset.jl | 18 ++++++++++++++++- test/containers/hdf5dataset.jl | 22 +++++++++++++++------ test/containers/jld2dataset.jl | 20 ++++++++++++++----- test/containers/tabledataset.jl | 8 ++++---- 7 files changed, 92 insertions(+), 45 deletions(-) diff --git a/src/containers/hdf5dataset.jl b/src/containers/hdf5dataset.jl index e604dbe6..2bdc6ac9 100644 --- a/src/containers/hdf5dataset.jl +++ b/src/containers/hdf5dataset.jl @@ -6,9 +6,9 @@ end """ HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths) - HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) - HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) - HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes) + HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}) + HDF5Dataset(fid::HDF5.File, paths::Union{AbstractString, Vector{<:AbstractString}}) + HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}, shapes) Wrap several HDF5 datasets (`paths`) as a single dataset container. Each dataset `p` in `paths` should be accessible as `fid[p]`. @@ -19,34 +19,38 @@ See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer. For array datasets, the last dimension is assumed to be the observation dimension. For scalar datasets, the stored value is returned by `getobs` for any index. 
""" -struct HDF5Dataset <: AbstractDataContainer +struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer fid::HDF5.File - paths::Vector{HDF5.Dataset} + paths::T shapes::Vector{Tuple} - function HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}, shapes::Vector) + function HDF5Dataset(fid::HDF5.File, paths::T, shapes::Vector) where T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}} _check_hdf5_shapes(shapes) || throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations.")) - new(fid, paths, shapes) + new{T}(fid, paths, shapes) end end +HDF5Dataset(fid::HDF5.File, path::HDF5.Dataset) = HDF5Dataset(fid, path, [size(path)]) HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) = HDF5Dataset(fid, paths, map(size, paths)) +HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path]) HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) = HDF5Dataset(fid, map(p -> fid[p], paths)) HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths) = HDF5Dataset(h5open(file, "r"), paths) -MLUtils.getobs(dataset::HDF5Dataset, i) = Tuple(map(dataset.paths, dataset.shapes) do path, shape - if isempty(shape) - return read(path) - else - I = map(s -> 1:s, shape[1:(end - 1)]) - return path[I..., i] - end -end) +_getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset) +function _getobs_hdf5(dataset::HDF5.Dataset, shape, i) + I = map(s -> 1:s, shape[1:(end - 1)]) + + return dataset[I..., i] +end +MLUtils.getobs(dataset::HDF5Dataset{HDF5.Dataset}, i) = + _getobs_hdf5(dataset.paths, only(dataset.shapes), i) +MLUtils.getobs(dataset::HDF5Dataset{<:Vector}, i) = + Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes)) MLUtils.numobs(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes))) """ diff --git a/src/containers/jld2dataset.jl b/src/containers/jld2dataset.jl index edbd9384..9fb91150 100644 --- a/src/containers/jld2dataset.jl +++ 
b/src/containers/jld2dataset.jl @@ -2,7 +2,7 @@ _check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end]) """ JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) - JLD2Dataset(fid::JLD2.JLDFile, paths) + JLD2Dataset(fid::JLD2.JLDFile, paths::Union{String, Vector{String}}) Wrap several JLD2 datasets (`paths`) as a single dataset container. Each dataset `p` in `paths` should be accessible as `fid[p]`. @@ -10,24 +10,27 @@ Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on each dataset in `paths`. See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer. """ -struct JLD2Dataset{T<:JLD2.JLDFile} <: AbstractDataContainer +struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer fid::T - paths::Vector{String} + paths::S - function JLD2Dataset(fid::JLD2.JLDFile, paths::Vector{String}) - nobs = map(p -> numobs(fid[p]), paths) + function JLD2Dataset(fid::JLD2.JLDFile, paths) + _paths = Tuple(map(p -> fid[p], paths)) + nobs = map(numobs, _paths) _check_jld2_nobs(nobs) || - throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations.")) + throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations (got $nobs).")) - new{typeof(fid)}(fid, paths) + new{typeof(fid), typeof(_paths)}(fid, _paths) end end +JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,)) JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) = JLD2Dataset(jldopen(file, "r"), paths) -MLUtils.getobs(dataset::JLD2Dataset, i) = Tuple(map(p -> getobs(dataset.fid[p], i), dataset.paths)) -MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.fid[dataset.paths[1]]) +MLUtils.getobs(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i) +MLUtils.getobs(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths) +MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.paths[1]) """ close(dataset::JLD2Dataset) 
diff --git a/src/containers/tabledataset.jl b/src/containers/tabledataset.jl index 56cbcfed..f2756ae7 100644 --- a/src/containers/tabledataset.jl +++ b/src/containers/tabledataset.jl @@ -23,14 +23,18 @@ TableDataset(path::Union{AbstractPath, AbstractString}) = TableDataset(DataFrame(CSV.File(path))) # slow accesses based on Tables.jl +_getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1))) +function _getobs_column(x, i) + colnames = Tuple(Tables.columnnames(x)) + rowvals = ntuple(j -> Tables.getcolumn(x, j)[i], length(colnames)) + + return NamedTuple{colnames}(rowvals) +end function MLUtils.getobs(dataset::TableDataset, i) if Tables.rowaccess(dataset.table) - row, _ = Iterators.peel(Iterators.drop(Tables.rows(dataset.table), i - 1)) - return row + return _getobs_row(Tables.rows(dataset.table), i) elseif Tables.columnaccess(dataset.table) - colnames = Tables.columnnames(dataset.table) - rowvals = [Tables.getcolumn(dataset.table, j)[i] for j in 1:length(colnames)] - return (; zip(colnames, rowvals)...) 
+ return _getobs_column(dataset.table, i) else error("The Tables.jl implementation used should have either rowaccess or columnaccess.") end diff --git a/test/containers/cacheddataset.jl b/test/containers/cacheddataset.jl index ff70aab4..c9726133 100644 --- a/test/containers/cacheddataset.jl +++ b/test/containers/cacheddataset.jl @@ -13,13 +13,29 @@ @testset "CachedDataset(::HDF5Dataset)" begin paths, datas = setup_hdf5dataset_test() - hdataset = HDF5Dataset("test.h5", ["d1"]) + hdataset = HDF5Dataset("test.h5", "d1") cdataset = CachedDataset(hdataset, 5) @test numobs(cdataset) == numobs(hdataset) + @test cdataset.cache isa Matrix{Float64} @test cdataset.cache == getobs(hdataset, 1:5) @test all(getobs(cdataset, i) == getobs(hdataset, i) for i in 1:10) + close(hdataset) cleanup_hdf5dataset_test() end + + @testset "CachedDataset(::JLD2Dataset)" begin + paths, datas = setup_jld2dataset_test() + jdataset = JLD2Dataset("test.jld2", "d1") + cdataset = CachedDataset(jdataset, 5) + + @test numobs(cdataset) == numobs(jdataset) + @test cdataset.cache isa Matrix{Float64} + @test cdataset.cache == getobs(jdataset, 1:5) + @test all(@inferred(getobs(cdataset, i)) == getobs(jdataset, i) for i in 1:10) + + close(jdataset) + cleanup_jld2dataset_test() + end end diff --git a/test/containers/hdf5dataset.jl b/test/containers/hdf5dataset.jl index 878cc262..8d79e033 100644 --- a/test/containers/hdf5dataset.jl +++ b/test/containers/hdf5dataset.jl @@ -18,12 +18,22 @@ cleanup_hdf5dataset_test() = rm("test.h5") @testset "HDF5Dataset" begin paths, datas = setup_hdf5dataset_test() - dataset = HDF5Dataset("test.h5", paths) - for i in 1:10 - data = Tuple(map(x -> (x isa String) ? 
x : getobs(x, i), datas)) - @test getobs(dataset, i) == data + @testset "Single path" begin + dataset = HDF5Dataset("test.h5", "d1") + for i in 1:10 + @test getobs(dataset, i) == getobs(datas[1], i) + end + @test numobs(dataset) == 10 + close(dataset) + end + @testset "Multiple paths" begin + dataset = HDF5Dataset("test.h5", paths) + for i in 1:10 + data = Tuple(map(x -> (x isa String) ? x : getobs(x, i), datas)) + @test @inferred(Tuple, getobs(dataset, i)) == data + end + @test numobs(dataset) == 10 + close(dataset) end - @test numobs(dataset) == 10 - close(dataset) cleanup_hdf5dataset_test() end diff --git a/test/containers/jld2dataset.jl b/test/containers/jld2dataset.jl index d849d8fa..d1656b91 100644 --- a/test/containers/jld2dataset.jl +++ b/test/containers/jld2dataset.jl @@ -18,11 +18,21 @@ cleanup_jld2dataset_test() = rm("test.jld2") @testset "JLD2Dataset" begin paths, datas = setup_jld2dataset_test() - dataset = JLD2Dataset("test.jld2", paths) - for i in 1:10 - @test getobs(dataset, i) == getobs(datas, i) + @testset "Single path" begin + dataset = JLD2Dataset("test.jld2", "d1") + for i in 1:10 + @test @inferred(getobs(dataset, i)) == getobs(datas[1], i) + end + @test numobs(dataset) == 10 + close(dataset) + end + @testset "Multiple paths" begin + dataset = JLD2Dataset("test.jld2", paths) + for i in 1:10 + @test @inferred(getobs(dataset, i)) == getobs(datas, i) + end + @test numobs(dataset) == 10 + close(dataset) end - @test numobs(dataset) == 10 - close(dataset) cleanup_jld2dataset_test() end diff --git a/test/containers/tabledataset.jl b/test/containers/tabledataset.jl index 92d76847..9154484c 100644 --- a/test/containers/tabledataset.jl +++ b/test/containers/tabledataset.jl @@ -6,7 +6,7 @@ testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"]) td = TableDataset(testtable) - @test all(getobs(td, 1) .== [1, 4.0, "7"]) + @test collect(@inferred(getobs(td, 1))) == [1, 4.0, "7"] @test numobs(td) == 3 end @@ -17,7 +17,7 @@ testtable = Tables.table([1 4.0 
"7"; 2 5.0 "8"; 3 6.0 "9"]) td = TableDataset(testtable) - @test [data for data in getobs(td, 2)] == [2, 5.0, "8"] + @test collect(@inferred(NamedTuple, getobs(td, 2))) == [2, 5.0, "8"] @test numobs(td) == 3 @test getobs(td, 1) isa NamedTuple @@ -35,7 +35,7 @@ td = TableDataset(testtable) @test td isa TableDataset{<:DataFrame} - @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"] + @test collect(@inferred(getobs(td, 1))) == [1, "a", 10, "A", 100.0, "train"] @test numobs(td) == 5 end @@ -46,7 +46,7 @@ testtable = CSV.File("test.csv") td = TableDataset(testtable) @test td isa TableDataset{<:CSV.File} - @test [data for data in getobs(td, 1)] == [1, "a", 10, "A", 100.0, "train"] + @test collect(@inferred(getobs(td, 1))) == [1, "a", 10, "A", 100.0, "train"] @test numobs(td) == 1 rm("test.csv") end From 7db946880cb0f72d073904f93e200e8fe6900d54 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 22 Feb 2022 17:02:22 -0600 Subject: [PATCH 12/19] Add more compat entries --- Project.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/Project.toml b/Project.toml index d89e92a4..dcfaf8eb 100644 --- a/Project.toml +++ b/Project.toml @@ -30,15 +30,20 @@ CSV = "0.10.2" ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11" DataDeps = "0.3, 0.4, 0.5, 0.6, 0.7" DataFrames = "1.3" +FileIO = "1.13" +FilePathsBase = "0.9.17" FixedPointNumbers = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8" GZip = "0.5" Glob = "1.3" +HDF5 = "0.16.2" ImageCore = "0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8" +JLD2 = "0.4.21" JSON3 = "1" MAT = "0.7, 0.8, 0.9, 0.10" MLUtils = "0.1.4" Pickle = "0.2, 0.3" Requires = "1" +Tables = "1.6" julia = "1.6" [extras] From 950e111d9a8a1b6844e05c57c437d50617382ae9 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 23 Feb 2022 08:13:39 -0600 Subject: [PATCH 13/19] Switch to MLUtils v0.2 --- Project.toml | 2 +- src/containers/cacheddataset.jl | 4 ++-- src/containers/filedataset.jl | 6 +++--- src/containers/hdf5dataset.jl | 6 +++--- 
src/containers/jld2dataset.jl | 6 +++--- src/containers/tabledataset.jl | 12 ++++++------ 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/Project.toml b/Project.toml index dcfaf8eb..439ac4f8 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,7 @@ ImageCore = "0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8" JLD2 = "0.4.21" JSON3 = "1" MAT = "0.7, 0.8, 0.9, 0.10" -MLUtils = "0.1.4" +MLUtils = "0.2.0" Pickle = "0.2, 0.3" Requires = "1" Tables = "1.6" diff --git a/src/containers/cacheddataset.jl b/src/containers/cacheddataset.jl index aa64ed5d..825a1d6e 100644 --- a/src/containers/cacheddataset.jl +++ b/src/containers/cacheddataset.jl @@ -31,9 +31,9 @@ function CachedDataset(source, cachesize::Int = numobs(source)) CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx)) end -function MLUtils.getobs(dataset::CachedDataset, i::Integer) +function Base.getindex(dataset::CachedDataset, i::Integer) _i = findfirst(==(i), dataset.cacheidx) return isnothing(_i) ? getobs(dataset.source, i) : getobs(dataset.cache, _i) end -MLUtils.numobs(dataset::CachedDataset) = numobs(dataset.source) +Base.length(dataset::CachedDataset) = numobs(dataset.source) diff --git a/src/containers/filedataset.jl b/src/containers/filedataset.jl index ae6434ad..539759a7 100644 --- a/src/containers/filedataset.jl +++ b/src/containers/filedataset.jl @@ -52,6 +52,6 @@ FileDataset(loadfn, FileDataset(dir::Union{AbstractPath, AbstractString}, pattern::AbstractString = "*", depth = 4) = FileDataset(loadfile, dir, pattern, depth) -MLUtils.getobs(dataset::FileDataset, i::Integer) = loadfile(dataset.paths[i]) -MLUtils.getobs(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is) -MLUtils.numobs(dataset::FileDataset) = length(dataset.paths) +Base.getindex(dataset::FileDataset, i::Integer) = loadfile(dataset.paths[i]) +Base.getindex(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is) +Base.length(dataset::FileDataset) = 
length(dataset.paths) diff --git a/src/containers/hdf5dataset.jl b/src/containers/hdf5dataset.jl index 2bdc6ac9..0e06e76c 100644 --- a/src/containers/hdf5dataset.jl +++ b/src/containers/hdf5dataset.jl @@ -47,11 +47,11 @@ function _getobs_hdf5(dataset::HDF5.Dataset, shape, i) return dataset[I..., i] end -MLUtils.getobs(dataset::HDF5Dataset{HDF5.Dataset}, i) = +Base.getindex(dataset::HDF5Dataset{HDF5.Dataset}, i) = _getobs_hdf5(dataset.paths, only(dataset.shapes), i) -MLUtils.getobs(dataset::HDF5Dataset{<:Vector}, i) = +Base.getindex(dataset::HDF5Dataset{<:Vector}, i) = Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes)) -MLUtils.numobs(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes))) +Base.length(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes))) """ close(dataset::HDF5Dataset) diff --git a/src/containers/jld2dataset.jl b/src/containers/jld2dataset.jl index 9fb91150..c624df89 100644 --- a/src/containers/jld2dataset.jl +++ b/src/containers/jld2dataset.jl @@ -28,9 +28,9 @@ JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,)) JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) = JLD2Dataset(jldopen(file, "r"), paths) -MLUtils.getobs(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i) -MLUtils.getobs(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths) -MLUtils.numobs(dataset::JLD2Dataset) = numobs(dataset.paths[1]) +Base.getindex(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i) +Base.getindex(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths) +Base.length(dataset::JLD2Dataset) = numobs(dataset.paths[1]) """ close(dataset::JLD2Dataset) diff --git a/src/containers/tabledataset.jl b/src/containers/tabledataset.jl index f2756ae7..4046680c 100644 --- a/src/containers/tabledataset.jl +++ b/src/containers/tabledataset.jl @@ -30,7 +30,7 @@ function _getobs_column(x, i) return 
NamedTuple{colnames}(rowvals) end -function MLUtils.getobs(dataset::TableDataset, i) +function Base.getindex(dataset::TableDataset, i) if Tables.rowaccess(dataset.table) return _getobs_row(Tables.rows(dataset.table), i) elseif Tables.columnaccess(dataset.table) @@ -39,7 +39,7 @@ function MLUtils.getobs(dataset::TableDataset, i) error("The Tables.jl implementation used should have either rowaccess or columnaccess.") end end -function MLUtils.numobs(dataset::TableDataset) +function Base.length(dataset::TableDataset) if Tables.columnaccess(dataset.table) return length(Tables.getcolumn(dataset.table, 1)) elseif Tables.rowaccess(dataset.table) @@ -51,9 +51,9 @@ function MLUtils.numobs(dataset::TableDataset) end # fast access for DataFrame -MLUtils.getobs(dataset::TableDataset{<:DataFrame}, i) = dataset.table[i, :] -MLUtils.numobs(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table) +Base.getindex(dataset::TableDataset{<:DataFrame}, i) = dataset.table[i, :] +Base.length(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table) # fast access for CSV.File -MLUtils.getobs(dataset::TableDataset{<:CSV.File}, i) = dataset.table[i] -MLUtils.numobs(dataset::TableDataset{<:CSV.File}) = length(dataset.table) +Base.getindex(dataset::TableDataset{<:CSV.File}, i) = dataset.table[i] +Base.length(dataset::TableDataset{<:CSV.File}) = length(dataset.table) From ebc6938195e061ddd0d2b5bbac34d98ac45cde5d Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 23 Feb 2022 08:21:43 -0600 Subject: [PATCH 14/19] Get rid of cruft from FastAI.jl --- src/containers/cacheddataset.jl | 9 ++++---- src/containers/filedataset.jl | 38 +++++++-------------------------- src/containers/hdf5dataset.jl | 5 ++--- src/containers/jld2dataset.jl | 5 ++--- src/containers/tabledataset.jl | 5 ++--- 5 files changed, 18 insertions(+), 44 deletions(-) diff --git a/src/containers/cacheddataset.jl b/src/containers/cacheddataset.jl index 825a1d6e..70451a9f 100644 --- a/src/containers/cacheddataset.jl +++ 
b/src/containers/cacheddataset.jl @@ -8,6 +8,7 @@ make_cache(source, cacheidx) = getobs(source, cacheidx) """ CachedDataset(source, cachesize = numbobs(source)) + CachedDataset(source, cacheidx = 1:numobs(source)) CachedDataset(source, cacheidx, cache) Wrap a `source` data container and cache `cachesize` samples in memory. @@ -15,7 +16,7 @@ This can be useful for improving read speeds when `source` is a lazy data contai but your system memory is large enough to store a sizeable chunk of it. By default the observation indices `1:cachesize` are cached. -You can manually pass in a `cache` and set of `cacheidx` as well. +You can manually pass in a set of `cacheidx` as well. See also [`make_cache`](@ref) for customizing the default cache creation for `source`. """ @@ -25,11 +26,9 @@ struct CachedDataset{T, S} cache::S end -function CachedDataset(source, cachesize::Int = numobs(source)) - cacheidx = 1:cachesize - + CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx)) -end +CachedDataset(source, cacheidx::AbstractVector{<:Integer} = 1:numobs(source)) = CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx)) +CachedDataset(source, cachesize::Int = numobs(source)) = CachedDataset(source, 1:cachesize) function Base.getindex(dataset::CachedDataset, i::Integer) _i = findfirst(==(i), dataset.cacheidx) diff --git a/src/containers/filedataset.jl b/src/containers/filedataset.jl index 539759a7..7b5f4599 100644 --- a/src/containers/filedataset.jl +++ b/src/containers/filedataset.jl @@ -1,8 +1,3 @@ -matches(re::Regex) = f -> matches(re, f) -matches(re::Regex, f) = !isnothing(match(re, f)) -const RE_IMAGEFILE = r".*\.(gif|jpe?g|tiff?|png|webp|bmp)$"i -isimagefile(f) = matches(RE_IMAGEFILE, f) - """ rglob(filepattern, dir = pwd(), depth = 4) @@ -15,43 +10,26 @@ function rglob(filepattern = "*", dir = pwd(), depth = 4) end """ - loadfile(file) - -Load a file from disk into the appropriate format.
-""" -function loadfile(file::String) - if isimagefile(file) - # faster image loading - return FileIO.load(file, view = true) - elseif endswith(file, ".csv") - return DataFrame(CSV.File(file)) - else - return FileIO.load(file) - end -end -loadfile(file::AbstractPath) = loadfile(string(file)) - -""" - FileDataset([loadfn = loadfile,] paths) - FileDataset([loadfn = loadfile,] dir, pattern = "*", depth = 4) + FileDataset([loadfn = FileIO.load,] paths) + FileDataset([loadfn = FileIO.load,] dir, pattern = "*", depth = 4) Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`). Alternatively, specify a `dir` and collect all paths that match a glob `pattern` (recursively globbing by `depth`). The glob order determines the traversal order. """ -struct FileDataset{F, T<:Union{AbstractPath, AbstractString}} <: AbstractDataContainer +struct FileDataset{F, T<:AbstractString} <: AbstractDataContainer loadfn::F paths::Vector{T} end -FileDataset(paths) = FileDataset(loadfile, paths) +FileDataset(paths) = FileDataset(FileIO.load, paths) FileDataset(loadfn, - dir::Union{AbstractPath, AbstractString}, + dir::AbstractString, pattern::AbstractString = "*", depth = 4) = FileDataset(loadfn, rglob(pattern, string(dir), depth)) -FileDataset(dir::Union{AbstractPath, AbstractString}, pattern::AbstractString = "*", depth = 4) = - FileDataset(loadfile, dir, pattern, depth) +FileDataset(dir::AbstractString, pattern::AbstractString = "*", depth = 4) = + FileDataset(FileIO.load, dir, pattern, depth) -Base.getindex(dataset::FileDataset, i::Integer) = loadfile(dataset.paths[i]) +Base.getindex(dataset::FileDataset, i::Integer) = dataset.loadfn(dataset.paths[i]) Base.getindex(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is) Base.length(dataset::FileDataset) = length(dataset.paths) diff --git a/src/containers/hdf5dataset.jl b/src/containers/hdf5dataset.jl index 0e06e76c..77961090 100644 --- a/src/containers/hdf5dataset.jl +++ 
b/src/containers/hdf5dataset.jl @@ -5,7 +5,7 @@ function _check_hdf5_shapes(shapes) end """ - HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths) + HDF5Dataset(file::AbstractString, paths) HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}) HDF5Dataset(fid::HDF5.File, paths::Union{AbstractString, Vector{<:AbstractString}}) HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}, shapes) @@ -38,8 +38,7 @@ HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) = HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path]) HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) = HDF5Dataset(fid, map(p -> fid[p], paths)) -HDF5Dataset(file::Union{AbstractString, AbstractPath}, paths) = - HDF5Dataset(h5open(file, "r"), paths) +HDF5Dataset(file::AbstractString, paths) = HDF5Dataset(h5open(file, "r"), paths) _getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset) function _getobs_hdf5(dataset::HDF5.Dataset, shape, i) diff --git a/src/containers/jld2dataset.jl b/src/containers/jld2dataset.jl index c624df89..e84f4479 100644 --- a/src/containers/jld2dataset.jl +++ b/src/containers/jld2dataset.jl @@ -1,7 +1,7 @@ _check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end]) """ - JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) + JLD2Dataset(file::AbstractString, paths) JLD2Dataset(fid::JLD2.JLDFile, paths::Union{String, Vector{String}}) Wrap several JLD2 datasets (`paths`) as a single dataset container. 
@@ -25,8 +25,7 @@ struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer end JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,)) -JLD2Dataset(file::Union{AbstractString, AbstractPath}, paths) = - JLD2Dataset(jldopen(file, "r"), paths) +JLD2Dataset(file::AbstractString, paths) = JLD2Dataset(jldopen(file, "r"), paths) Base.getindex(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i) Base.getindex(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths) diff --git a/src/containers/tabledataset.jl b/src/containers/tabledataset.jl index 4046680c..0684ed02 100644 --- a/src/containers/tabledataset.jl +++ b/src/containers/tabledataset.jl @@ -1,6 +1,6 @@ """ TableDataset(table) - TableDataset(path::Union{AbstractPath, AbstractString}) + TableDataset(path::AbstractString) Wrap a Tables.jl-compatible `table` as a dataset container. Alternatively, specify the `path` to a CSV file directly @@ -19,8 +19,7 @@ struct TableDataset{T} <: AbstractDataContainer end TableDataset(table::T) where {T} = TableDataset{T}(table) -TableDataset(path::Union{AbstractPath, AbstractString}) = - TableDataset(DataFrame(CSV.File(path))) +TableDataset(path::AbstractPath) = TableDataset(DataFrame(CSV.File(path))) # slow accesses based on Tables.jl _getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1))) From 15fa280e438591b8a52cf03b5e254c72e244c317 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 23 Feb 2022 08:24:18 -0600 Subject: [PATCH 15/19] Remove some outdated deps --- Project.toml | 2 -- src/MLDatasets.jl | 2 -- 2 files changed, 4 deletions(-) diff --git a/Project.toml b/Project.toml index 439ac4f8..4ab36295 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,6 @@ DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -FilePathsBase = 
"48062228-2e41-5def-b9a4-89aafe57970f" FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" @@ -31,7 +30,6 @@ ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11" DataDeps = "0.3, 0.4, 0.5, 0.6, 0.7" DataFrames = "1.3" FileIO = "1.13" -FilePathsBase = "0.9.17" FixedPointNumbers = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8" GZip = "0.5" Glob = "1.3" diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl index 384a9679..c737b243 100644 --- a/src/MLDatasets.jl +++ b/src/MLDatasets.jl @@ -8,8 +8,6 @@ using FixedPointNumbers, ColorTypes using Pickle using SparseArrays using DataFrames, CSV, Tables -using FilePathsBase -using FilePathsBase: AbstractPath using Glob using HDF5 using JLD2 From 860fc143473b4d5a8a2cc46f15307f3b4d0c943a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 2 Mar 2022 11:24:29 -0600 Subject: [PATCH 16/19] Remove reference to AbstractPath Co-authored-by: lorenzoh --- src/containers/tabledataset.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/containers/tabledataset.jl b/src/containers/tabledataset.jl index 0684ed02..91fe3de3 100644 --- a/src/containers/tabledataset.jl +++ b/src/containers/tabledataset.jl @@ -19,7 +19,7 @@ struct TableDataset{T} <: AbstractDataContainer end TableDataset(table::T) where {T} = TableDataset{T}(table) -TableDataset(path::AbstractPath) = TableDataset(DataFrame(CSV.File(path))) +TableDataset(path::AbstractString) = TableDataset(DataFrame(CSV.File(path))) # slow accesses based on Tables.jl _getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1))) From 5a882bd0e64ff4c1c440812dec719b40ab1d5c51 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 5 Mar 2022 08:56:42 -0600 Subject: [PATCH 17/19] Add FileIO into usings --- src/MLDatasets.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl index c737b243..a8abe351 100644 --- a/src/MLDatasets.jl +++ 
b/src/MLDatasets.jl @@ -7,6 +7,7 @@ using DelimitedFiles: readdlm using FixedPointNumbers, ColorTypes using Pickle using SparseArrays +using FileIO using DataFrames, CSV, Tables using Glob using HDF5 From df48e5ea64bb1837d25e88ab57ff657ba0eb4c94 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 5 Mar 2022 09:15:55 -0600 Subject: [PATCH 18/19] Fix tests and add Tables.jl interface --- src/containers/tabledataset.jl | 7 +++++++ test/containers/cacheddataset.jl | 2 +- test/containers/filedataset.jl | 4 ++-- test/containers/tabledataset.jl | 18 ++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/containers/tabledataset.jl b/src/containers/tabledataset.jl index 91fe3de3..5ef88e9f 100644 --- a/src/containers/tabledataset.jl +++ b/src/containers/tabledataset.jl @@ -56,3 +56,10 @@ Base.length(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table) # fast access for CSV.File Base.getindex(dataset::TableDataset{<:CSV.File}, i) = dataset.table[i] Base.length(dataset::TableDataset{<:CSV.File}) = length(dataset.table) + +## Tables.jl interface + +Tables.istable(::TableDataset) = true +for fn in (:rowaccess, :rows, :columnaccess, :columns, :schema, :materializer) + @eval Tables.$fn(dataset::TableDataset) = Tables.$fn(dataset.table) +end diff --git a/test/containers/cacheddataset.jl b/test/containers/cacheddataset.jl index c9726133..67de4aac 100644 --- a/test/containers/cacheddataset.jl +++ b/test/containers/cacheddataset.jl @@ -1,7 +1,7 @@ @testset "CachedDataset" begin @testset "CachedDataset(::FileDataset)" begin files = setup_filedataset_test() - fdataset = FileDataset("root", "*.csv") + fdataset = FileDataset(f -> CSV.read(f, DataFrame), "root", "*.csv") cdataset = CachedDataset(fdataset) @test numobs(cdataset) == numobs(fdataset) diff --git a/test/containers/filedataset.jl b/test/containers/filedataset.jl index f143aae3..a3e10e00 100644 --- a/test/containers/filedataset.jl +++ b/test/containers/filedataset.jl 
@@ -27,10 +27,10 @@ cleanup_filedataset_test() = rm("root"; recursive = true) @testset "FileDataset" begin files = setup_filedataset_test() - dataset = FileDataset("root", "*.csv") + dataset = FileDataset(f -> CSV.read(f, DataFrame), "root", "*.csv") @test numobs(dataset) == length(files) for (i, file) in enumerate(files) - true_obs = MLDatasets.loadfile(file) + true_obs = CSV.read(file, DataFrame) @test getobs(dataset, i) == true_obs end cleanup_filedataset_test() diff --git a/test/containers/tabledataset.jl b/test/containers/tabledataset.jl index 9154484c..6d260ea8 100644 --- a/test/containers/tabledataset.jl +++ b/test/containers/tabledataset.jl @@ -50,4 +50,22 @@ @test numobs(td) == 1 rm("test.csv") end + + @testset "TableDataset is a table" begin + testtable = DataFrame( + col1 = [1, 2, 3, 4, 5], + col2 = ["a", "b", "c", "d", "e"], + col3 = [10, 20, 30, 40, 50], + col4 = ["A", "B", "C", "D", "E"], + col5 = [100.0, 200.0, 300.0, 400.0, 500.0], + split = ["train", "train", "train", "valid", "valid"], + ) + td = TableDataset(testtable) + @testset for fn in (Tables.istable, + Tables.rowaccess, Tables.rows, + Tables.columnaccess, Tables.columns, + Tables.schema, Tables.materializer) + @test fn(td) == fn(testtable) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 2844d26a..51698f86 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using Test using MLDatasets +using FileIO using ImageCore using DataDeps using DataFrames, CSV, Tables From 2ff0d299bfde22196170ab688ed6937f79bfab0b Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 5 Mar 2022 09:48:38 -0600 Subject: [PATCH 19/19] Fix doc errors --- docs/make.jl | 5 +++-- docs/src/containers/overview.md | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 docs/src/containers/overview.md diff --git a/docs/make.jl b/docs/make.jl index 157df9a7..df38064e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -37,7 +37,6 @@ makedocs( "Mutagenesis" => 
"datasets/Mutagenesis.md", "Titanic" => "datasets/Titanic.md", ], - "Text" => Any[ "PTBLM" => "datasets/PTBLM.md", "UD_English" => "datasets/UD_English.md", @@ -52,9 +51,11 @@ makedocs( ], "Utils" => "utils.md", + "Data Containers" => "containers/overview.md", "LICENSE.md", ], - strict = true + strict = true, + checkdocs = :exports ) diff --git a/docs/src/containers/overview.md b/docs/src/containers/overview.md new file mode 100644 index 00000000..c6a7528f --- /dev/null +++ b/docs/src/containers/overview.md @@ -0,0 +1,14 @@ +# Dataset Containers + +MLDatasets.jl contains several reusable data containers for accessing datasets in common storage formats. This feature is a work-in-progress and subject to change. + +```@docs +FileDataset +TableDataset +HDF5Dataset +Base.close(::HDF5Dataset) +JLD2Dataset +Base.close(::JLD2Dataset) +CachedDataset +MLDatasets.make_cache +```