Skip to content

Commit

Permalink
Generalize build_variable for Variable(s)ConstrainedOnCreation (#2595)
Browse files Browse the repository at this point in the history
  • Loading branch information
pulsipher authored May 17, 2021
1 parent 9d1089a commit 41ebc5f
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1348,14 +1348,14 @@ end

function build_variable(
_error::Function,
variable::ScalarVariable,
variable::AbstractVariable,
set::MOI.AbstractScalarSet,
)
return VariableConstrainedOnCreation(variable, set)
end
function build_variable(
_error::Function,
variables::Vector{<:ScalarVariable},
variables::Vector{<:AbstractVariable},
set::MOI.AbstractVectorSet,
)
return VariablesConstrainedOnCreation(variables, set)
Expand Down
6 changes: 3 additions & 3 deletions src/sd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ This function is used by the [`@variable`](@ref) macro as follows:
"""
function build_variable(
_error::Function,
variables::Matrix{<:ScalarVariable},
variables::Matrix{<:AbstractVariable},
::SymMatrixSpace,
)
n = _square_side(_error, variables)
Expand All @@ -269,7 +269,7 @@ This function is used by the [`@variable`](@ref) macro as follows:
"""
function build_variable(
_error::Function,
variables::Matrix{<:ScalarVariable},
variables::Matrix{<:AbstractVariable},
::SkewSymmetricMatrixSpace,
)
n = _square_side(_error, variables)
Expand All @@ -295,7 +295,7 @@ This function is used by the [`@variable`](@ref) macro as follows:
"""
function build_variable(
_error::Function,
variables::Matrix{<:ScalarVariable},
variables::Matrix{<:AbstractVariable},
::PSDCone,
)
n = _square_side(_error, variables)
Expand Down
4 changes: 2 additions & 2 deletions src/sets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ See also: [`moi_set`](@ref).
"""
abstract type AbstractVectorSet end

# Used in `@constraint(model, [1:n] in s)`
# Used in `@variable(model, [1:n] in s)`
function build_variable(
_error::Function,
variables::Vector{<:ScalarVariable},
variables::Vector{<:AbstractVariable},
set::AbstractVectorSet,
)
return VariablesConstrainedOnCreation(
Expand Down
46 changes: 46 additions & 0 deletions test/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,52 @@ end
@test constraint_object(c).set == MOI.EqualTo(0.0)
end

struct NewVariable <: JuMP.AbstractVariable
info::JuMP.VariableInfo
end

@testset "Extension variables constrained on creation #2594" begin
function JuMP.build_variable(
_error::Function,
info::JuMP.VariableInfo,
::Type{NewVariable},
)
return NewVariable(info)
end
function JuMP.add_variable(model::Model, v::NewVariable, name::String = "")
return JuMP.add_variable(
model,
ScalarVariable(v.info),
name * "_normal_add",
)
end
function JuMP.add_variable(
model::Model,
v::VariablesConstrainedOnCreation{
MOI.SecondOrderCone,
VectorShape,
NewVariable,
},
names,
)
vs = map(i -> ScalarVariable(i.info), v.scalar_variables)
new_v = VariablesConstrainedOnCreation(vs, v.set, v.shape)
names .*= "_constr_add"
return JuMP.add_variable(model, new_v, names)
end

model = Model()
@variable(model, 0 <= x <= 1, NewVariable, Bin)
@test lower_bound(x) == 0
@test upper_bound(x) == 1
@test is_binary(x)
@test name(x) == "x_normal_add"

@variable(model, y[1:3] in SecondOrderCone(), NewVariable)
@test name.(y) == ["y[$i]_constr_add" for i in 1:3]
@test num_constraints(model, Vector{VariableRef}, MOI.SecondOrderCone) == 1
end

mutable struct MyVariable
test_kw::Int
info::JuMP.VariableInfo
Expand Down

0 comments on commit 41ebc5f

Please sign in to comment.