Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow overloading of container types #2570

Merged
merged 4 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/src/manual/containers.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,24 @@ JuMP.Containers.SparseAxisArray{Tuple{Symbol,Int64},2,Tuple{Int64,Symbol}} with
[3, B] = (:B, 3)
```

## Forcing the container type

Pass `container = T` to use `T` as the container. For example:
```jldoctest; filter=r"\([1-2], [1-2]\) \=\> [2-4]"
julia> Containers.@container([i = 1:2, j = 1:2], i + j, container = Array)
2×2 Array{Int64,2}:
2 3
3 4

julia> Containers.@container([i = 1:2, j = 1:2], i + j, container = Dict)
Dict{Tuple{Int64,Int64},Int64} with 4 entries:
(1, 2) => 3
(1, 1) => 2
(2, 2) => 4
(2, 1) => 3
```
You can also pass `DenseAxisArray` or `SparseAxisArray`.

## How different container types are chosen

If the compiler can prove _at compile time_ that the index sets are rectangular,
Expand Down
15 changes: 15 additions & 0 deletions src/Containers/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,18 @@ function _sparseaxisarray(dict::Dict{Any,Any}, f, indices)
d = _container_dict(_default_eltype(indices), f, _eltype_or_any(indices))
return SparseAxisArray(d)
end

# Don't use length-1 tuples if there is only one index!
_container_key(i::Tuple) = i
_container_key(i::Tuple{T}) where {T} = i[1]

function container(f::Function, indices, D::Type{<:AbstractDict})
return D(_container_key(i) => f(i...) for i in indices)
end

function container(::Function, ::Any, D::Type)
return error(
"Unable to build a container with the provided type $(D). Implement " *
"`Containers.container`.",
)
end
12 changes: 0 additions & 12 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,18 +185,6 @@ function container_code(idxvars, indices, code, requested_container)
if isempty(idxvars)
return code
end
if !(
requested_container in
[:Auto, :Array, :DenseAxisArray, :SparseAxisArray]
)
# We do this two-step interpolation, first into the string, and then
# into the expression because interpolating into a string inside an
# expression has scoping issues.
error_message =
"Invalid container type $requested_container. Must be " *
"Auto, Array, DenseAxisArray, or SparseAxisArray."
return :(error($error_message))
end
if requested_container == :DenseAxisArray
requested_container = :(Containers.DenseAxisArray)
elseif requested_container == :SparseAxisArray
Expand Down
41 changes: 41 additions & 0 deletions test/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,45 @@ using Test
throw(err)
end,)
end
@testset "Dict" begin
Containers.@container(v[i = 1:3], sin(i), container = Dict)
@test v isa Dict{Int,Float64}
@test length(v) == 3
@test v[2] ≈ sin(2)
Containers.@container(w[i = 1:3, j = 1:3], i + j, container = Dict)
@test w isa Dict{Tuple{Int,Int},Int}
@test length(w) == 9
@test w[2, 3] == 5
Containers.@container(
x[i = 1:3, j = [:a, :b]],
(j, i),
container = Dict
)
@test x isa Dict{Tuple{Int,Symbol},Tuple{Symbol,Int}}
@test length(x) == 6
@test x[2, :a] == (:a, 2)
Containers.@container(y[i = 1:3, j = 1:i], i + j, container = Dict)
@test y isa Dict{Tuple{Int,Int},Int}
@test length(y) == 6
@test y[2, 1] == 3
Containers.@container(
z[i = 1:3, j = 1:3; isodd(i + j)],
i + j,
container = Dict
)
@test z isa Dict{Tuple{Int,Int},Int}
@test length(z) == 4
@test z[1, 2] == 3
end
@testset "Invalid container" begin
err = ErrorException(
"Unable to build a container with the provided type $(Int). " *
"Implement `Containers.container`.",
)
@test_throws err Containers.@container(
x[i = 1:2, j = 1:2],
i + j,
container = Int
)
end
end
9 changes: 0 additions & 9 deletions test/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -538,15 +538,6 @@ end
@test length(JuMP.object_dictionary(model)) == 0
end

@testset "Invalid container" begin
model = Model()
exception = ErrorException(
"Invalid container type Oops. Must be Auto, Array, " *
"DenseAxisArray, or SparseAxisArray.",
)
@test_throws exception @variable(model, x[1:3], container = Oops)
end

@testset "Adjoints" begin
model = Model()
@variable(model, x[1:2])
Expand Down