Skip to content

Commit

Permalink
Allow overloading of container types (#2570)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored May 11, 2021
1 parent 4f6abf8 commit 8062654
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 21 deletions.
18 changes: 18 additions & 0 deletions docs/src/manual/containers.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,24 @@ JuMP.Containers.SparseAxisArray{Tuple{Symbol,Int64},2,Tuple{Int64,Symbol}} with
[3, B] = (:B, 3)
```

## Forcing the container type

Pass `container = T` to use `T` as the container. For example:
```jldoctest; filter=r"\([1-2], [1-2]\) \=\> [2-4]"
julia> Containers.@container([i = 1:2, j = 1:2], i + j, container = Array)
2×2 Array{Int64,2}:
2 3
3 4
julia> Containers.@container([i = 1:2, j = 1:2], i + j, container = Dict)
Dict{Tuple{Int64,Int64},Int64} with 4 entries:
(1, 2) => 3
(1, 1) => 2
(2, 2) => 4
(2, 1) => 3
```
You can also pass `DenseAxisArray` or `SparseAxisArray`.

## How different container types are chosen

If the compiler can prove _at compile time_ that the index sets are rectangular,
Expand Down
15 changes: 15 additions & 0 deletions src/Containers/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,18 @@ function _sparseaxisarray(dict::Dict{Any,Any}, f, indices)
d = _container_dict(_default_eltype(indices), f, _eltype_or_any(indices))
return SparseAxisArray(d)
end

# Don't use length-1 tuples if there is only one index!
_container_key(i::Tuple) = i
_container_key(i::Tuple{T}) where {T} = i[1]

function container(f::Function, indices, D::Type{<:AbstractDict})
return D(_container_key(i) => f(i...) for i in indices)
end

function container(::Function, ::Any, D::Type)
return error(
"Unable to build a container with the provided type $(D). Implement " *
"`Containers.container`.",
)
end
12 changes: 0 additions & 12 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,18 +185,6 @@ function container_code(idxvars, indices, code, requested_container)
if isempty(idxvars)
return code
end
if !(
requested_container in
[:Auto, :Array, :DenseAxisArray, :SparseAxisArray]
)
# We do this two-step interpolation, first into the string, and then
# into the expression because interpolating into a string inside an
# expression has scoping issues.
error_message =
"Invalid container type $requested_container. Must be " *
"Auto, Array, DenseAxisArray, or SparseAxisArray."
return :(error($error_message))
end
if requested_container == :DenseAxisArray
requested_container = :(Containers.DenseAxisArray)
elseif requested_container == :SparseAxisArray
Expand Down
41 changes: 41 additions & 0 deletions test/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,45 @@ using Test
throw(err)
end,)
end
@testset "Dict" begin
Containers.@container(v[i = 1:3], sin(i), container = Dict)
@test v isa Dict{Int,Float64}
@test length(v) == 3
@test v[2] sin(2)
Containers.@container(w[i = 1:3, j = 1:3], i + j, container = Dict)
@test w isa Dict{Tuple{Int,Int},Int}
@test length(w) == 9
@test w[2, 3] == 5
Containers.@container(
x[i = 1:3, j = [:a, :b]],
(j, i),
container = Dict
)
@test x isa Dict{Tuple{Int,Symbol},Tuple{Symbol,Int}}
@test length(x) == 6
@test x[2, :a] == (:a, 2)
Containers.@container(y[i = 1:3, j = 1:i], i + j, container = Dict)
@test y isa Dict{Tuple{Int,Int},Int}
@test length(y) == 6
@test y[2, 1] == 3
Containers.@container(
z[i = 1:3, j = 1:3; isodd(i + j)],
i + j,
container = Dict
)
@test z isa Dict{Tuple{Int,Int},Int}
@test length(z) == 4
@test z[1, 2] == 3
end
@testset "Invalid container" begin
err = ErrorException(
"Unable to build a container with the provided type $(Int). " *
"Implement `Containers.container`.",
)
@test_throws err Containers.@container(
x[i = 1:2, j = 1:2],
i + j,
container = Int
)
end
end
9 changes: 0 additions & 9 deletions test/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -605,15 +605,6 @@ end
@test length(JuMP.object_dictionary(model)) == 0
end

@testset "Invalid container" begin
model = Model()
exception = ErrorException(
"Invalid container type Oops. Must be Auto, Array, " *
"DenseAxisArray, or SparseAxisArray.",
)
@test_throws exception @variable(model, x[1:3], container = Oops)
end

@testset "Adjoints" begin
model = Model()
@variable(model, x[1:2])
Expand Down

0 comments on commit 8062654

Please sign in to comment.