Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow overloading of container types #2570

Merged
merged 4 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/src/manual/containers.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,24 @@ JuMP.Containers.SparseAxisArray{Tuple{Symbol,Int64},2,Tuple{Int64,Symbol}} with
[3, B] = (:B, 3)
```

## Forcing the container type

Pass `container = T` to use `T` as the container. For example:
```jldoctest; filter=r"\([1-2], [1-2]\) \=\> [2-4]"
julia> Containers.@container([i = 1:2, j = 1:2], i + j, container = Array)
2×2 Array{Int64,2}:
2 3
3 4

julia> Containers.@container([i = 1:2, j = 1:2], i + j, container = Dict)
Dict{Tuple{Int64,Int64},Int64} with 4 entries:
(1, 2) => 3
(1, 1) => 2
(2, 2) => 4
(2, 1) => 3
```
You can also pass `DenseAxisArray` or `SparseAxisArray`.

## How different container types are chosen

If the compiler can prove _at compile time_ that the index sets are rectangular,
Expand Down
13 changes: 13 additions & 0 deletions src/Containers/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,16 @@ function _sparseaxisarray(dict::Dict{Any,Any}, f, indices)
d = _container_dict(_default_eltype(indices), f, _eltype_or_any(indices))
return SparseAxisArray(d)
end

function container(f::Function, indices, D::Type{<:AbstractDict})
# Don't use length-1 tuples if there is only one index!
key(i) = length(i) == 1 ? i[1] : i
return D(key(i) => f(i...) for i in indices)
end

function container(::Function, ::Any, D::Type)
return error(
"Unable to build a container with the provided type $(D). Implement " *
"`Containers.container`.",
)
end
12 changes: 0 additions & 12 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,18 +185,6 @@ function container_code(idxvars, indices, code, requested_container)
if isempty(idxvars)
return code
end
if !(
requested_container in
[:Auto, :Array, :DenseAxisArray, :SparseAxisArray]
)
# We do this two-step interpolation, first into the string, and then
# into the expression because interpolating into a string inside an
# expression has scoping issues.
error_message =
"Invalid container type $requested_container. Must be " *
"Auto, Array, DenseAxisArray, or SparseAxisArray."
return :(error($error_message))
end
if requested_container == :DenseAxisArray
requested_container = :(Containers.DenseAxisArray)
elseif requested_container == :SparseAxisArray
Expand Down
37 changes: 37 additions & 0 deletions test/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,41 @@ using Test
throw(err)
end,)
end
@testset "Dict" begin
Containers.@container(v[i = 1:3], sin(i), container = Dict)
@test v isa Dict{Int,Float64}
@test length(v) == 3
@test v[2] ≈ sin(2)
Containers.@container(w[i = 1:3, j = 1:3], i + j, container = Dict)
@test w isa Dict{Tuple{Int,Int},Int}
@test length(w) == 9
@test w[2, 3] == 5
Containers.@container(
x[i = 1:3, j = [:a, :b]],
(j, i),
container = Dict
)
@test x isa Dict{Tuple{Int,Symbol},Tuple{Symbol,Int}}
@test length(x) == 6
@test x[2, :a] == (:a, 2)
Containers.@container(y[i = 1:3, j = 1:i], i + j, container = Dict)
@test y isa Dict{Tuple{Int,Int},Int}
@test length(y) == 6
@test y[2, 1] == 3
Containers.@container(
z[i = 1:3, j = 1:3; isodd(i + j)],
i + j,
container = Dict
)
@test z isa Dict{Tuple{Int,Int},Int}
@test length(z) == 4
@test z[1, 2] == 3
end
@testset "Invalid container" begin
@test_throws ErrorException Containers.@container(
x[i = 1:2, j = 1:2],
i + j,
container = Int
)
end
end
9 changes: 0 additions & 9 deletions test/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -538,15 +538,6 @@ end
@test length(JuMP.object_dictionary(model)) == 0
end

@testset "Invalid container" begin
model = Model()
exception = ErrorException(
"Invalid container type Oops. Must be Auto, Array, " *
"DenseAxisArray, or SparseAxisArray.",
)
@test_throws exception @variable(model, x[1:3], container = Oops)
end

@testset "Adjoints" begin
model = Model()
@variable(model, x[1:2])
Expand Down