diff --git a/.github/workflows/UnitTest.yml b/.github/workflows/UnitTest.yml
index 6e920762..be18c8fa 100644
--- a/.github/workflows/UnitTest.yml
+++ b/.github/workflows/UnitTest.yml
@@ -16,12 +16,12 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        julia-version: ['1.3', '1', 'nightly']
+        julia-version: ['1.6', '1', 'nightly']
         os: [ubuntu-latest, windows-latest, macOS-latest, macos-11]
     env:
       PYTHON: ""
     steps:
-      - uses: actions/checkout@v1.0.0
+      - uses: actions/checkout@v2
       - name: "Set up Julia"
         uses: julia-actions/setup-julia@v1
         with:
diff --git a/Project.toml b/Project.toml
index c6b684a5..4ab36295 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,29 +4,45 @@ version = "0.5.15"
 
 [deps]
 BinDeps = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
+CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 ColorTypes = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
 DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
+FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
 FixedPointNumbers = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
 GZip = "92fee26a-97fe-5a0c-ad85-20a5f3185b63"
+Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
+HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
+JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
 
 [compat]
 BinDeps = "1"
+CSV = "0.10.2"
 ColorTypes = "0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.10, 0.11"
 DataDeps = "0.3, 0.4, 0.5, 0.6, 0.7"
+DataFrames = "1.3"
+FileIO = "1.13"
 FixedPointNumbers = "0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
 GZip = "0.5"
+Glob = "1.3"
+HDF5 = "0.16.2"
 ImageCore = "0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8"
+JLD2 = "0.4.21"
 JSON3 = "1"
 MAT = "0.7, 0.8, 0.9, 0.10"
+MLUtils = "0.2.0"
 Pickle = "0.2, 0.3"
 Requires = "1"
-julia = "1.3"
+Tables = "1.6"
+julia = "1.6"
 
 [extras]
 ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
diff --git a/docs/make.jl b/docs/make.jl
index 157df9a7..df38064e 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -37,7 +37,6 @@ makedocs(
             "Mutagenesis" => "datasets/Mutagenesis.md",
             "Titanic" => "datasets/Titanic.md",
         ],
-
         "Text" => Any[
             "PTBLM" => "datasets/PTBLM.md",
             "UD_English" => "datasets/UD_English.md",
@@ -52,9 +51,11 @@ makedocs(
         ],
         "Utils" => "utils.md",
+        "Data Containers" => "containers/overview.md",
         "LICENSE.md",
     ],
-    strict = true
+    strict = true,
+    checkdocs = :exports
 )
diff --git a/docs/src/containers/overview.md b/docs/src/containers/overview.md
new file mode 100644
index 00000000..c6a7528f
--- /dev/null
+++ b/docs/src/containers/overview.md
@@ -0,0 +1,31 @@
+# Dataset Containers
+
+MLDatasets.jl contains several reusable data containers for accessing datasets in common storage formats. This feature is a work-in-progress and subject to change.
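+
+As a minimal sketch (the `"data/"` directory and its CSV files are hypothetical placeholders),
+the example below wraps a folder of CSV files as a lazy `FileDataset`, then caches part of it
+in memory with a `CachedDataset`:
+
+```julia
+using MLDatasets
+using CSV, DataFrames
+
+# Lazily wrap every CSV file under "data/", loading each one as a DataFrame.
+ds = FileDataset(f -> CSV.read(f, DataFrame), "data/", "*.csv")
+x1 = getobs(ds, 1)               # reads the first matched file from disk
+
+# Keep the first 100 observations in memory for faster repeated reads.
+cached = CachedDataset(ds, 100)
+getobs(cached, 1)                # now served from the in-memory cache
+```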
+
+```@docs
+FileDataset
+TableDataset
+HDF5Dataset
+Base.close(::HDF5Dataset)
+JLD2Dataset
+Base.close(::JLD2Dataset)
+CachedDataset
+MLDatasets.make_cache
+```
diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl
index 25c7f19b..a8abe351 100644
--- a/src/MLDatasets.jl
+++ b/src/MLDatasets.jl
@@ -7,6 +7,15 @@ using DelimitedFiles: readdlm
 using FixedPointNumbers, ColorTypes
 using Pickle
 using SparseArrays
+using FileIO
+using DataFrames, CSV, Tables
+using Glob
+using HDF5
+using JLD2
+
+import MLUtils
+using MLUtils: getobs, numobs, AbstractDataContainer
+export getobs, numobs
 
 # Julia 1.0 compatibility
 if !isdefined(Base, :isnothing)
@@ -36,6 +45,16 @@ end
 
 include("download.jl")
 
+include("containers/filedataset.jl")
+export FileDataset
+include("containers/tabledataset.jl")
+export TableDataset
+include("containers/hdf5dataset.jl")
+export HDF5Dataset
+include("containers/jld2dataset.jl")
+export JLD2Dataset
+include("containers/cacheddataset.jl")
+export CachedDataset
 
 # Misc.
 include("BostonHousing/BostonHousing.jl")
diff --git a/src/containers/cacheddataset.jl b/src/containers/cacheddataset.jl
new file mode 100644
index 00000000..70451a9f
--- /dev/null
+++ b/src/containers/cacheddataset.jl
@@ -0,0 +1,38 @@
+"""
+    make_cache(source, cacheidx)
+
+Return an in-memory copy of `source` at observation indices `cacheidx`.
+Defaults to `getobs(source, cacheidx)`.
+"""
+make_cache(source, cacheidx) = getobs(source, cacheidx)
+
+"""
+    CachedDataset(source, cachesize = numobs(source))
+    CachedDataset(source, cacheidx = 1:numobs(source))
+    CachedDataset(source, cacheidx, cache)
+
+Wrap a `source` data container and cache `cachesize` samples in memory.
+This can be useful for improving read speeds when `source` is a lazy data container,
+but your system memory is large enough to store a sizeable chunk of it.
+
+By default the observation indices `1:cachesize` are cached.
+You can manually pass in a set of `cacheidx` as well.
+
+See also [`make_cache`](@ref) for customizing the default cache creation for `source`.
+"""
+struct CachedDataset{T, S}
+    source::T
+    cacheidx::Vector{Int}
+    cache::S
+end
+
+CachedDataset(source, cacheidx::AbstractVector{<:Integer} = 1:numobs(source)) =
+    CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx))
+CachedDataset(source, cachesize::Int = numobs(source)) = CachedDataset(source, 1:cachesize)
+
+function Base.getindex(dataset::CachedDataset, i::Integer)
+    _i = findfirst(==(i), dataset.cacheidx)
+
+    return isnothing(_i) ? getobs(dataset.source, i) : getobs(dataset.cache, _i)
+end
+Base.length(dataset::CachedDataset) = numobs(dataset.source)
diff --git a/src/containers/filedataset.jl b/src/containers/filedataset.jl
new file mode 100644
index 00000000..7b5f4599
--- /dev/null
+++ b/src/containers/filedataset.jl
@@ -0,0 +1,44 @@
+"""
+    rglob(filepattern = "*", dir = pwd(), depth = 4)
+
+Recursive glob up to `depth` layers deep within `dir`.
+"""
+function rglob(filepattern = "*", dir = pwd(), depth = 4)
+    patterns = [repeat("*/", i) * filepattern for i in 0:(depth - 1)]
+
+    return vcat([glob(pattern, dir) for pattern in patterns]...)
+end
+
+"""
+    FileDataset([loadfn = FileIO.load,] paths)
+    FileDataset([loadfn = FileIO.load,] dir, pattern = "*", depth = 4)
+
+Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`).
+Alternatively, specify a `dir` and collect all paths that match a glob `pattern`
+(recursively globbing by `depth`). The glob order determines the traversal order.
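+
+A minimal sketch (the `"data/"` folder is hypothetical; assumes `using CSV, DataFrames`):
+
+```julia
+# Recursively match "data/**/*.csv" and load each file with a custom loader.
+ds = FileDataset(f -> CSV.read(f, DataFrame), "data/", "*.csv")
+getobs(ds, 1)  # DataFrame for the first matched file
+numobs(ds)     # number of matched files
+```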
+""" +struct FileDataset{F, T<:AbstractString} <: AbstractDataContainer + loadfn::F + paths::Vector{T} +end + +FileDataset(paths) = FileDataset(FileIO.load, paths) +FileDataset(loadfn, + dir::AbstractString, + pattern::AbstractString = "*", + depth = 4) = FileDataset(loadfn, rglob(pattern, string(dir), depth)) +FileDataset(dir::AbstractString, pattern::AbstractString = "*", depth = 4) = + FileDataset(FileIO.load, dir, pattern, depth) + +Base.getindex(dataset::FileDataset, i::Integer) = dataset.loadfn(dataset.paths[i]) +Base.getindex(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is) +Base.length(dataset::FileDataset) = length(dataset.paths) diff --git a/src/containers/hdf5dataset.jl b/src/containers/hdf5dataset.jl new file mode 100644 index 00000000..77961090 --- /dev/null +++ b/src/containers/hdf5dataset.jl @@ -0,0 +1,60 @@ +function _check_hdf5_shapes(shapes) + nobs = map(last, filter(!isempty, shapes)) + + return all(==(first(nobs)), nobs[2:end]) +end + +""" + HDF5Dataset(file::AbstractString, paths) + HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}) + HDF5Dataset(fid::HDF5.File, paths::Union{AbstractString, Vector{<:AbstractString}}) + HDF5Dataset(fid::HDF5.File, paths::Union{HDF5.Dataset, Vector{HDF5.Dataset}}, shapes) + +Wrap several HDF5 datasets (`paths`) as a single dataset container. +Each dataset `p` in `paths` should be accessible as `fid[p]`. +Calling `getobs` on a `HDF5Dataset` returns a tuple with each element corresponding +to the observation from each dataset in `paths`. +See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer. + +For array datasets, the last dimension is assumed to be the observation dimension. +For scalar datasets, the stored value is returned by `getobs` for any index. +""" +struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer + fid::HDF5.File + paths::T + shapes::Vector{Tuple} + + function HDF5Dataset(fid::HDF5.File, paths::T, shapes::Vector) where T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}} + _check_hdf5_shapes(shapes) || + throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations.")) + + new{T}(fid, paths, shapes) + end +end + +HDF5Dataset(fid::HDF5.File, path::HDF5.Dataset) = HDF5Dataset(fid, path, [size(path)]) +HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) = + HDF5Dataset(fid, paths, map(size, paths)) +HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path]) +HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) = + HDF5Dataset(fid, map(p -> fid[p], paths)) +HDF5Dataset(file::AbstractString, paths) = HDF5Dataset(h5open(file, "r"), paths) + +_getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset) +function _getobs_hdf5(dataset::HDF5.Dataset, shape, i) + I = map(s -> 1:s, shape[1:(end - 1)]) + + return dataset[I..., i] +end +Base.getindex(dataset::HDF5Dataset{HDF5.Dataset}, i) = + _getobs_hdf5(dataset.paths, only(dataset.shapes), i) +Base.getindex(dataset::HDF5Dataset{<:Vector}, i) = + Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes)) +Base.length(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes))) + +""" + close(dataset::HDF5Dataset) + +Close the underlying HDF5 file pointer for `dataset`. 
+""" +Base.close(dataset::HDF5Dataset) = close(dataset.fid) diff --git a/src/containers/jld2dataset.jl b/src/containers/jld2dataset.jl new file mode 100644 index 00000000..e84f4479 --- /dev/null +++ b/src/containers/jld2dataset.jl @@ -0,0 +1,39 @@ +_check_jld2_nobs(nobs) = all(==(first(nobs)), nobs[2:end]) + +""" + JLD2Dataset(file::AbstractString, paths) + JLD2Dataset(fid::JLD2.JLDFile, paths::Union{String, Vector{String}}) + +Wrap several JLD2 datasets (`paths`) as a single dataset container. +Each dataset `p` in `paths` should be accessible as `fid[p]`. +Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on +each dataset in `paths`. +See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer. +""" +struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer + fid::T + paths::S + + function JLD2Dataset(fid::JLD2.JLDFile, paths) + _paths = Tuple(map(p -> fid[p], paths)) + nobs = map(numobs, _paths) + _check_jld2_nobs(nobs) || + throw(ArgumentError("Cannot create JLD2Dataset for datasets with mismatched number of observations (got $nobs).")) + + new{typeof(fid), typeof(_paths)}(fid, _paths) + end +end + +JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,)) +JLD2Dataset(file::AbstractString, paths) = JLD2Dataset(jldopen(file, "r"), paths) + +Base.getindex(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i) +Base.getindex(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths) +Base.length(dataset::JLD2Dataset) = numobs(dataset.paths[1]) + +""" + close(dataset::JLD2Dataset) + +Close the underlying JLD2 file pointer for `dataset`. +""" +Base.close(dataset::JLD2Dataset) = close(dataset.fid) diff --git a/src/containers/tabledataset.jl b/src/containers/tabledataset.jl new file mode 100644 index 00000000..5ef88e9f --- /dev/null +++ b/src/containers/tabledataset.jl @@ -0,0 +1,65 @@ +""" + TableDataset(table) + TableDataset(path::AbstractString) + +Wrap a Tables.jl-compatible `table` as a dataset container. +Alternatively, specify the `path` to a CSV file directly +to load it with CSV.jl + DataFrames.jl. +""" +struct TableDataset{T} <: AbstractDataContainer + table::T + + # TableDatasets must implement the Tables.jl interface + function TableDataset{T}(table::T) where {T} + Tables.istable(table) || + throw(ArgumentError("TableDatasets must implement the Tabels.jl interface")) + + new{T}(table) + end +end + +TableDataset(table::T) where {T} = TableDataset{T}(table) +TableDataset(path::AbstractString) = TableDataset(DataFrame(CSV.File(path))) + +# slow accesses based on Tables.jl +_getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1))) +function _getobs_column(x, i) + colnames = Tuple(Tables.columnnames(x)) + rowvals = ntuple(j -> Tables.getcolumn(x, j)[i], length(colnames)) + + return NamedTuple{colnames}(rowvals) +end +function Base.getindex(dataset::TableDataset, i) + if Tables.rowaccess(dataset.table) + return _getobs_row(Tables.rows(dataset.table), i) + elseif Tables.columnaccess(dataset.table) + return _getobs_column(dataset.table, i) + else + error("The Tables.jl implementation used should have either rowaccess or columnaccess.") + end +end +function Base.length(dataset::TableDataset) + if Tables.columnaccess(dataset.table) + return length(Tables.getcolumn(dataset.table, 1)) + elseif Tables.rowaccess(dataset.table) + # length might not be defined, but has to be for this to work. 
+        return length(Tables.rows(dataset.table))
+    else
+        error("The Tables.jl implementation used should have either rowaccess or columnaccess.")
+    end
+end
+
+# fast access for DataFrame
+Base.getindex(dataset::TableDataset{<:DataFrame}, i) = dataset.table[i, :]
+Base.length(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table)
+
+# fast access for CSV.File
+Base.getindex(dataset::TableDataset{<:CSV.File}, i) = dataset.table[i]
+Base.length(dataset::TableDataset{<:CSV.File}) = length(dataset.table)
+
+## Tables.jl interface
+
+Tables.istable(::TableDataset) = true
+for fn in (:rowaccess, :rows, :columnaccess, :columns, :schema, :materializer)
+    @eval Tables.$fn(dataset::TableDataset) = Tables.$fn(dataset.table)
+end
diff --git a/test/containers/cacheddataset.jl b/test/containers/cacheddataset.jl
new file mode 100644
index 00000000..67de4aac
--- /dev/null
+++ b/test/containers/cacheddataset.jl
@@ -0,0 +1,41 @@
+@testset "CachedDataset" begin
+    @testset "CachedDataset(::FileDataset)" begin
+        files = setup_filedataset_test()
+        fdataset = FileDataset(f -> CSV.read(f, DataFrame), "root", "*.csv")
+        cdataset = CachedDataset(fdataset)
+
+        @test numobs(cdataset) == numobs(fdataset)
+        @test cdataset.cache == getobs(fdataset, 1:numobs(fdataset))
+        @test all(getobs(cdataset, i) == getobs(fdataset, i) for i in 1:numobs(fdataset))
+
+        cleanup_filedataset_test()
+    end
+
+    @testset "CachedDataset(::HDF5Dataset)" begin
+        paths, datas = setup_hdf5dataset_test()
+        hdataset = HDF5Dataset("test.h5", "d1")
+        cdataset = CachedDataset(hdataset, 5)
+
+        @test numobs(cdataset) == numobs(hdataset)
+        @test cdataset.cache isa Matrix{Float64}
+        @test cdataset.cache == getobs(hdataset, 1:5)
+        @test all(getobs(cdataset, i) == getobs(hdataset, i) for i in 1:10)
+
+        close(hdataset)
+        cleanup_hdf5dataset_test()
+    end
+
+    @testset "CachedDataset(::JLD2Dataset)" begin
+        paths, datas = setup_jld2dataset_test()
+        jdataset = JLD2Dataset("test.jld2", "d1")
+        cdataset = CachedDataset(jdataset, 5)
+
+        @test numobs(cdataset) == numobs(jdataset)
+        @test cdataset.cache isa Matrix{Float64}
+        @test cdataset.cache == getobs(jdataset, 1:5)
+        @test all(@inferred(getobs(cdataset, i)) == getobs(jdataset, i) for i in 1:10)
+
+        close(jdataset)
+        cleanup_jld2dataset_test()
+    end
+end
diff --git a/test/containers/filedataset.jl b/test/containers/filedataset.jl
new file mode 100644
index 00000000..a3e10e00
--- /dev/null
+++ b/test/containers/filedataset.jl
@@ -0,0 +1,37 @@
+function setup_filedataset_test()
+    files = [
+        "root/p1/f1.csv",
+        "root/p2/f2.csv",
+        "root/p2/p2p1/f2.csv",
+        "root/p3/p3p1/f1.csv"
+    ]
+
+    for (i, file) in enumerate(files)
+        paths = splitpath(file)[1:(end - 1)]
+        root = ""
+        for p in paths
+            fullp = joinpath(root, p)
+            isdir(fullp) || mkdir(fullp)
+            root = fullp
+        end
+
+        open(file, "w") do io
+            write(io, "a,b,c\n")
+            write(io, join(i .* [1, 2, 3], ","))
+        end
+    end
+
+    return files
+end
+cleanup_filedataset_test() = rm("root"; recursive = true)
+
+@testset "FileDataset" begin
+    files = setup_filedataset_test()
+    dataset = FileDataset(f -> CSV.read(f, DataFrame), "root", "*.csv")
+    @test numobs(dataset) == length(files)
+    for (i, file) in enumerate(files)
+        true_obs = CSV.read(file, DataFrame)
+        @test getobs(dataset, i) == true_obs
+    end
+    cleanup_filedataset_test()
+end
diff --git a/test/containers/hdf5dataset.jl b/test/containers/hdf5dataset.jl
new file mode 100644
index 00000000..8d79e033
--- /dev/null
+++ b/test/containers/hdf5dataset.jl
@@ -0,0 +1,39 @@
+function setup_hdf5dataset_test()
+    datasets = [
+        ("d1", rand(2, 10)),
+        ("g1/d1", rand(10)),
+        ("g1/d2", string.('a':'j')),
+        ("g2/g1/d1", "test")
+    ]
+
+    fid = h5open("test.h5", "w")
+    for (path, data) in datasets
+        fid[path] = data
+    end
+    close(fid)
+
+    return first.(datasets), last.(datasets)
+end
+cleanup_hdf5dataset_test() = rm("test.h5")
+
+@testset "HDF5Dataset" begin
+    paths, datas = setup_hdf5dataset_test()
+    @testset "Single path" begin
+        dataset = HDF5Dataset("test.h5", "d1")
+        for i in 1:10
+            @test getobs(dataset, i) == getobs(datas[1], i)
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
+    end
+    @testset "Multiple paths" begin
+        dataset = HDF5Dataset("test.h5", paths)
+        for i in 1:10
+            data = Tuple(map(x -> (x isa String) ? x : getobs(x, i), datas))
+            @test @inferred(Tuple, getobs(dataset, i)) == data
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
+    end
+    cleanup_hdf5dataset_test()
+end
diff --git a/test/containers/jld2dataset.jl b/test/containers/jld2dataset.jl
new file mode 100644
index 00000000..d1656b91
--- /dev/null
+++ b/test/containers/jld2dataset.jl
@@ -0,0 +1,38 @@
+function setup_jld2dataset_test()
+    datasets = [
+        ("d1", rand(2, 10)),
+        ("g1/d1", rand(10)),
+        ("g1/d2", string.('a':'j')),
+        ("g2/g1/d1", rand(Bool, 3, 3, 10))
+    ]
+
+    fid = jldopen("test.jld2", "w")
+    for (path, data) in datasets
+        fid[path] = data
+    end
+    close(fid)
+
+    return first.(datasets), Tuple(last.(datasets))
+end
+cleanup_jld2dataset_test() = rm("test.jld2")
+
+@testset "JLD2Dataset" begin
+    paths, datas = setup_jld2dataset_test()
+    @testset "Single path" begin
+        dataset = JLD2Dataset("test.jld2", "d1")
+        for i in 1:10
+            @test @inferred(getobs(dataset, i)) == getobs(datas[1], i)
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
+    end
+    @testset "Multiple paths" begin
+        dataset = JLD2Dataset("test.jld2", paths)
+        for i in 1:10
+            @test @inferred(getobs(dataset, i)) == getobs(datas, i)
+        end
+        @test numobs(dataset) == 10
+        close(dataset)
+    end
+    cleanup_jld2dataset_test()
+end
diff --git a/test/containers/tabledataset.jl b/test/containers/tabledataset.jl
new file mode 100644
index 00000000..6d260ea8
--- /dev/null
+++ b/test/containers/tabledataset.jl
@@ -0,0 +1,71 @@
+@testset "TableDataset" begin
+    @testset "TableDataset from rowaccess table" begin
+        Tables.columnaccess(::Type{<:Tables.MatrixTable}) = false
+        Tables.rowaccess(::Type{<:Tables.MatrixTable}) = true
+
+        testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
+        td = TableDataset(testtable)
+
+        @test collect(@inferred(getobs(td, 1))) == [1, 4.0, "7"]
+        @test numobs(td) == 3
+    end
+
+    @testset "TableDataset from columnaccess table" begin
+        Tables.columnaccess(::Type{<:Tables.MatrixTable}) = true
+        Tables.rowaccess(::Type{<:Tables.MatrixTable}) = false
+
+        testtable = Tables.table([1 4.0 "7"; 2 5.0 "8"; 3 6.0 "9"])
+        td = TableDataset(testtable)
+
+        @test collect(@inferred(NamedTuple, getobs(td, 2))) == [2, 5.0, "8"]
+        @test numobs(td) == 3
+
+        @test getobs(td, 1) isa NamedTuple
+    end
+
+    @testset "TableDataset from DataFrames" begin
+        testtable = DataFrame(
+            col1 = [1, 2, 3, 4, 5],
+            col2 = ["a", "b", "c", "d", "e"],
+            col3 = [10, 20, 30, 40, 50],
+            col4 = ["A", "B", "C", "D", "E"],
+            col5 = [100.0, 200.0, 300.0, 400.0, 500.0],
+            split = ["train", "train", "train", "valid", "valid"],
+        )
+        td = TableDataset(testtable)
+        @test td isa TableDataset{<:DataFrame}
+
+        @test collect(@inferred(getobs(td, 1))) == [1, "a", 10, "A", 100.0, "train"]
+        @test numobs(td) == 5
+    end
+
+    @testset "TableDataset from CSV" begin
+        open("test.csv", "w") do io
+            write(io, "col1,col2,col3,col4,col5,split\n1,a,10,A,100.,train")
"col1,col2,col3,col4,col5, split\n1,a,10,A,100.,train") + end + testtable = CSV.File("test.csv") + td = TableDataset(testtable) + @test td isa TableDataset{<:CSV.File} + @test collect(@inferred(getobs(td, 1))) == [1, "a", 10, "A", 100.0, "train"] + @test numobs(td) == 1 + rm("test.csv") + end + + @testset "TableDataset is a table" begin + testtable = DataFrame( + col1 = [1, 2, 3, 4, 5], + col2 = ["a", "b", "c", "d", "e"], + col3 = [10, 20, 30, 40, 50], + col4 = ["A", "B", "C", "D", "E"], + col5 = [100.0, 200.0, 300.0, 400.0, 500.0], + split = ["train", "train", "train", "valid", "valid"], + ) + td = TableDataset(testtable) + @testset for fn in (Tables.istable, + Tables.rowaccess, Tables.rows, + Tables.columnaccess, Tables.columns, + Tables.schema, Tables.materializer) + @test fn(td) == fn(testtable) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d11559db..51698f86 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,12 +1,15 @@ using Test using MLDatasets +using FileIO using ImageCore using DataDeps - +using DataFrames, CSV, Tables +using HDF5 +using JLD2 ENV["DATADEPS_ALWAYS_ACCEPT"] = true -tests = [ +dataset_tests = [ # misc "tst_iris.jl", "tst_boston_housing.jl", @@ -26,12 +29,26 @@ tests = [ "tst_tudataset.jl", ] -for t in tests +container_tests = [ + "containers/filedataset.jl", + "containers/tabledataset.jl", + "containers/hdf5dataset.jl", + "containers/jld2dataset.jl", + "containers/cacheddataset.jl", +] + +@testset "Datasets" for t in dataset_tests @testset "$t" begin include(t) end end +@testset "Containers" begin + for t in container_tests + include(t) + end +end + #temporary to not stress CI if !parse(Bool, get(ENV, "CI", "false")) @testset "other tests" begin