Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New explorers and tripartite graph plot utils #239

Merged
merged 145 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
145 commits
Select commit Hold shift + click to select a range
2acaa09
Create heterogeneous file
gostreap May 9, 2022
5eaa1c0
Add check_dimensions for hetereogeneous graph
gostreap May 9, 2022
ec06f83
First version of HeterogeneousGraphConv
louis-gautier May 9, 2022
b3ba22f
Add heterogeneous batched featured graph
gostreap May 9, 2022
271e5fe
Add HeterogeneousTrajectoryState
gostreap May 9, 2022
e14f1d3
Created HeterogeneousStateRepresentation
louis-gautier May 9, 2022
19b4761
Merge branch 'heterogeneous' of https://github.com/corail-research/Se…
louis-gautier May 9, 2022
0de9c50
Consistent order of types of nodes + naming adjustments
louis-gautier May 9, 2022
31eaa7b
Vertex ordering consistency
louis-gautier May 9, 2022
fdcaa50
Adapted the conversion of DefaultTrajectoryState
louis-gautier May 10, 2022
c1bd8aa
Add HeterogeneousGraphConvInit + Fix HeterogeneousGraphConv
gostreap May 10, 2022
5911621
Merge branch 'heterogeneous' of github.com:corail-research/SeaPearl.j…
gostreap May 10, 2022
ab000e6
Fixed syntax errors
louis-gautier May 10, 2022
605ef98
include heterogeneous file
gostreap May 10, 2022
9fcc38f
Fix HeterogeneousStateRepresentation
gostreap May 10, 2022
a719af4
Fix adjacency_matrices
gostreap May 10, 2022
0b91a56
fix HeterogeneousTrajectoryState from HeterogeneousStateRepresentation
gostreap May 10, 2022
55801bd
Fix compilation warning and multiple include of same file
gostreap May 10, 2022
e6c3d8f
New management of VarViews in tripartite graph
louis-gautier May 10, 2022
cbe0fce
Merge branch 'heterogeneous' of https://github.com/corail-research/Se…
louis-gautier May 10, 2022
5ba7f63
Add HeterogeneousCPNN
gostreap May 10, 2022
ad84934
Merge branch 'heterogeneous' of github.com:corail-research/SeaPearl.j…
gostreap May 10, 2022
a20e754
New management of VarViews for heterogeneous graphs
louis-gautier May 10, 2022
18af577
Merge branch 'heterogeneous' of https://github.com/corail-research/Se…
louis-gautier May 10, 2022
8375617
Fix compilation error
gostreap May 10, 2022
ad69d01
Fixed error for registering VarView constraints
louis-gautier May 10, 2022
b1fa79a
Merge branch 'heterogeneous' of https://github.com/corail-research/Se…
louis-gautier May 10, 2022
69b1832
Fixed indexing error
louis-gautier May 10, 2022
7af8721
Move totalLength
gostreap May 10, 2022
0f9c82a
Fixed ViewConstraint
louis-gautier May 10, 2022
1b9096f
Merge branch 'heterogeneous' of https://github.com/corail-research/Se…
louis-gautier May 10, 2022
78365a9
Adapted tests to new management of VarViews
louis-gautier May 10, 2022
d4fa24c
Add update_with_cp_models for heterogeneous + fix featurize for heter…
gostreap May 10, 2022
392c5fb
Merge branch 'heterogeneous' of github.com:corail-research/SeaPearl.j…
gostreap May 10, 2022
c65fad0
Redefine constraint_activity for ViewConstraint + fix in featurize
gostreap May 10, 2022
a1c9432
Fixed constraint type onehot encoding
louis-gautier May 10, 2022
1a93cd7
Adapted tests to new management of VarViews
louis-gautier May 10, 2022
53cbbe6
Fixing some issues in HeterogeneousCPNN
louis-gautier May 10, 2022
a540db1
Fixed HeterogeneousCPNN-related bugs
louis-gautier May 11, 2022
6a711a6
Fixed bugs in HeterogeneousFeaturedGraph
louis-gautier May 11, 2022
d29c743
Fixed bug in HeterogeneousStateRepresentation
louis-gautier May 11, 2022
7ab5c60
First test about heterogeneousSR
gostreap May 11, 2022
24f421f
Add ! to initChosenFeatures
gostreap May 11, 2022
ce86249
Add test for HeterogeneousSR with chosen_features
gostreap May 11, 2022
edf67aa
test HeterogeneousTrajectoryState constructor + heterogeneous on squa…
gostreap May 11, 2022
1d07904
test contovar and val to var in square graph
gostreap May 11, 2022
0363920
Create heterogeneous file
gostreap May 9, 2022
ca65213
Add check_dimensions for hetereogeneous graph
gostreap May 9, 2022
a996560
First version of HeterogeneousGraphConv
louis-gautier May 9, 2022
2959429
Created HeterogeneousStateRepresentation
louis-gautier May 9, 2022
24290d8
Add heterogeneous batched featured graph
gostreap May 9, 2022
f87b405
Add HeterogeneousTrajectoryState
gostreap May 9, 2022
e4b09c0
Consistent order of types of nodes + naming adjustments
louis-gautier May 9, 2022
75e4927
Vertex ordering consistency
louis-gautier May 9, 2022
f165a44
Add HeterogeneousGraphConvInit + Fix HeterogeneousGraphConv
gostreap May 10, 2022
bbe7e74
Adapted the conversion of DefaultTrajectoryState
louis-gautier May 10, 2022
9419cb3
Fixed syntax errors
louis-gautier May 10, 2022
cdd08ef
include heterogeneous file
gostreap May 10, 2022
adba191
Fix HeterogeneousStateRepresentation
gostreap May 10, 2022
0c26667
Fix adjacency_matrices
gostreap May 10, 2022
cb17e9c
fix HeterogeneousTrajectoryState from HeterogeneousStateRepresentation
gostreap May 10, 2022
e27fbc9
New management of VarViews in tripartite graph
louis-gautier May 10, 2022
84ac6b0
Fix compilation warning and multiple include of same file
gostreap May 10, 2022
d752080
New management of VarViews for heterogeneous graphs
louis-gautier May 10, 2022
3e4234d
Add HeterogeneousCPNN
gostreap May 10, 2022
a2b16b7
Fixed error for registering VarView constraints
louis-gautier May 10, 2022
b102ebc
Fix compilation error
gostreap May 10, 2022
cfb901c
Fixed indexing error
louis-gautier May 10, 2022
81496b3
Fixed ViewConstraint
louis-gautier May 10, 2022
289e2ef
Move totalLength
gostreap May 10, 2022
8579519
Add update_with_cp_models for heterogeneous + fix featurize for heter…
gostreap May 10, 2022
9525b7e
Adapted tests to new management of VarViews
louis-gautier May 10, 2022
9c1da59
Redefine constraint_activity for ViewConstraint + fix in featurize
gostreap May 10, 2022
d5b2742
Fixed constraint type onehot encoding
louis-gautier May 10, 2022
0f5395c
Adapted tests to new management of VarViews
louis-gautier May 10, 2022
48195fe
Fixing some issues in HeterogeneousCPNN
louis-gautier May 10, 2022
34281a7
Fixed HeterogeneousCPNN-related bugs
louis-gautier May 11, 2022
7e32856
Fixed bugs in HeterogeneousFeaturedGraph
louis-gautier May 11, 2022
a7bed69
Fixed bug in HeterogeneousStateRepresentation
louis-gautier May 11, 2022
d7b850d
First test about heterogeneousSR
gostreap May 11, 2022
e4311ae
Add ! to initChosenFeatures
gostreap May 11, 2022
61ba319
Add test for HeterogeneousSR with chosen_features
gostreap May 11, 2022
3c61e0d
test HeterogeneousTrajectoryState constructor + heterogeneous on squa…
gostreap May 11, 2022
ed78e3a
test contovar and val to var in square graph
gostreap May 11, 2022
7c81bc0
include test for heterogeneousSR
gostreap May 12, 2022
0f96cfa
Merge branch 'heterogeneous' of https://github.com/corail-research/Se…
louis-gautier May 12, 2022
a05c3ca
Fixed error in DefaultStateRepresentation
louis-gautier May 13, 2022
a4b1c6c
added possibleValuesIdx features
May 13, 2022
a2d7baf
added HeterogeneousVariableOutputCPNN
May 13, 2022
b5b14c9
added possibleValuesIdx field for DefaultTrajectoryState
May 15, 2022
d9deba1
fixed associated testsets
May 15, 2022
ebe5193
Merge branch 'heterogeneous' of github.com:corail-research/SeaPearl.j…
May 15, 2022
1b677b8
Add chosen_features field to SupervisedLearnedHeuristic
ziadelassal May 17, 2022
62df632
Fixes for using VariableOutputCPNN with HeterogeneousStateRepresentation
louis-gautier May 17, 2022
248ad4b
WIP HeterogeneousFFCPNN
May 17, 2022
e3c3689
Merge branch 'heterogeneous' of github.com:corail-research/SeaPearl.j…
May 17, 2022
dac0d2d
Fixed mistakes in BatchedHeterogeneousTrajectoryState
louis-gautier May 17, 2022
7a4c40f
fixed bug for defaulstaterepresentation
May 18, 2022
b31c696
action_to_value testset
May 18, 2022
8e9a3f3
Implementing HeterogeneousFullFeaturedCPNN
louis-gautier May 18, 2022
6afa930
Merge branch 'heterogeneous' of https://github.com/corail-research/Se…
louis-gautier May 18, 2022
a9a8721
added doc
May 18, 2022
a529628
fixed from_order_to_id for DefaultTrajectoryState
May 18, 2022
a048f6f
added advanced test for default trajectory state
May 18, 2022
3623ec0
Simplifying HeterogeneousCPNN
louis-gautier May 18, 2022
acb0e56
added accessors bor batchedheterogeneousfeaturedgraph
May 18, 2022
9fb99e6
fixed edge cases and bugs
May 18, 2022
3dced34
added basic testsets for heterogeneousfeaturedgraph
May 18, 2022
c37cc0b
Fix rb generator
gostreap May 19, 2022
d9bff62
Merge branch 'heterogeneous' of github.com:corail-research/SeaPearl.j…
gostreap May 19, 2022
0247aea
fix bug on varying featured graph stored in the trajectory
May 19, 2022
20ad372
added testset for trajectoryState copy
May 19, 2022
5213af6
Merge branch 'heterogeneous' of github.com:corail-research/SeaPearl.j…
May 19, 2022
4b249a7
test for HeterogeneousFFCPNN
May 19, 2022
1bbe077
added test dependency
May 19, 2022
05edfa7
fixed testsets
May 19, 2022
a26d867
Fix kep generator
gostreap May 19, 2022
11b1ff0
Merge branch 'heterogeneous' of github.com:corail-research/SeaPearl.j…
gostreap May 19, 2022
344247e
Fixed bug causing error when running KEP
louis-gautier May 19, 2022
73f769c
Cleaned model.jl
louis-gautier May 19, 2022
1cd6eae
Fixed bug causing duplicate constraints in graphcoloring
louis-gautier May 20, 2022
8da47e2
Added new features is_branchable and is_objective
louis-gautier May 20, 2022
723fd58
fix typo
May 20, 2022
e381b66
more testset for heterogeneousFFcpnn
May 20, 2022
04665e4
added heterogeneoustrajectorystate testsets
May 20, 2022
d79307c
Added unit tests for new features
louis-gautier May 20, 2022
a034981
temporary fix of the function state |> "device"
May 23, 2022
10c40b0
fix chosen_feature's attribute : variable_is_branchable
3rdCore May 23, 2022
be6f294
fix graph coloring datagen
May 23, 2022
12f441f
added type specific fc layer for HFFCPNN
May 23, 2022
dd1fbc2
fix coloring.jl for v1.7
May 23, 2022
0991b4b
Fixed mistakes for new features
louis-gautier May 24, 2022
f77933d
Fixed error for is_branchable feature on non-branchable variables
louis-gautier May 24, 2022
570ecc8
Added testset for new variable featured
louis-gautier May 24, 2022
de8daf9
added testsets
May 24, 2022
82cb106
Merge branch 'heterogeneous' of github.com:corail-research/SeaPearl.j…
May 24, 2022
0ba9158
Added softmax explorer with temperature decay
louis-gautier May 25, 2022
4bbaba1
Add UCBExplorer with mask
gostreap May 25, 2022
1e5c2bc
Added util to get insights about input graphs
louis-gautier May 25, 2022
90fbbe2
Merge branch 'new_explorers' of https://github.com/corail-research/Se…
louis-gautier May 25, 2022
fb5a2b2
Correction on list of dependencies
louis-gautier May 25, 2022
279d37b
Fixed softmax explorer
louis-gautier May 30, 2022
09d0b87
Merge remote-tracking branch 'origin/master' into new_explorers
louis-gautier Jun 15, 2022
9cb1863
Add GraphPlot in dependencies
louis-gautier Jun 15, 2022
24cfe90
Update src/RL/representation/graphplotutils.jl
louis-gautier Jul 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge remote-tracking branch 'origin/master' into new_explorers
Conflicts:
	Project.toml
	src/RL/nn_structures/heterogeneouscpnn.jl
	src/RL/nn_structures/heterogeneousvariableoutputcpnn.jl
	src/RL/representation/default/cp_layer/accessors.jl
	src/RL/representation/default/defaultstaterepresentation.jl
	src/RL/representation/default/heterogeneousstaterepresentation.jl
	src/RL/utils/geometricflux/heterogeneousgraphconv.jl
	test/CP/valueselection/learning/environment.jl
	test/RL/nn_structures/heterogeneousfullfeaturedcpnn.jl
	test/RL/representation/default/defaultstaterepresentation.jl
	test/RL/representation/default/defaulttrajectorystate.jl
	test/RL/representation/default/heterogeneousstaterepresentation.jl
	test/datagen/coloring.jl
  • Loading branch information
louis-gautier committed Jun 15, 2022
commit 09d0b878c251e086b4098e0b38315e410f86ec4f
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
Expand Down
8 changes: 2 additions & 6 deletions src/RL/nn_structures/heterogeneouscpnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ function (nn::HeterogeneousCPNN)(states::BatchedHeterogeneousTrajectoryState)

# chain working on the graph(s)
fg = nn.graphChain(states.fg)
variableFeatures = fg.varnf
valueFeatures = fg.valnf

globalFeatures = fg.gf
# extract the feature(s) of the variable(s) we're working on
Expand All @@ -36,11 +34,9 @@ function (nn::HeterogeneousCPNN)(states::BatchedHeterogeneousTrajectoryState)
# Double check that we are extracting the right variable
indices = CartesianIndex.(zip(variableIdx, 1:batchSize))
end
variableFeature = variableFeatures[:, indices]
#graphVariableEmbedding = Base.maximum(variableFeatures,dims=2)
#graphValueEmbedding = Base.maximum(valueFeatures,dims=2)
variableFeature = fg.varnf[:, indices]

# chain working on the node(s) feature(s)
#chainNodeOutput = nn.nodeChain(vcat(variableFeature, graphVariableEmbedding[:,1,:], graphValueEmbedding[:,1,:]))
chainNodeOutput = nn.nodeChain(variableFeature)
if isempty(globalFeatures)
# output layers
Expand Down
115 changes: 106 additions & 9 deletions src/RL/nn_structures/heterogeneousvariableoutputcpnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,121 @@ end
# Enable the `|> gpu` syntax from Flux
Flux.@functor HeterogeneousVariableOutputCPNN

wears_mask(s::HeterogeneousVariableOutputCPNN) = false

function (nn::HeterogeneousVariableOutputCPNN)(state::GraphTrajectoryState)
wears_mask(s::HeterogeneousVariableOutputCPNN) = true

function (nn::HeterogeneousVariableOutputCPNN)(state::HeterogeneousTrajectoryState)
variableIdx = state.variableIdx
possibleValuesIdx = state.possibleValuesIdx
# chain working on the graph(s)
fg = nn.graphChain(state.fg)
# chain working on the node(s) feature(s)
variableFeature = fg.varnf[:, variableIdx]
variableFeatures = fg.varnf[:, variableIdx]
valueFeatures = fg.valnf[:, possibleValuesIdx]
chainOutput = nn.nodeChain(hcat(variableFeature, valueFeatures))
chainOutput = nn.nodeChain(hcat(variableFeatures, valueFeatures))
variableOutput = chainOutput[:, 1]
valueOutput = chainOutput[:, 2:end]


finalInput = vcat(repeat(variableOutput, 1, length(possibleValuesIdx)), valueOutput)
finalOutput = nn.outputChain(finalInput)
output = dropdims(finalOutput; dims = 1)
return output
output = dropdims(nn.outputChain(finalInput); dims = 1)
finalOutput = zeros(Float32, size(fg.valnf, 2))
for i in 1:length(possibleValuesIdx)
finalOutput[possibleValuesIdx[i]] = output[i]
end
return finalOutput

end

function (nn::HeterogeneousVariableOutputCPNN)(state::BatchedHeterogeneousTrajectoryState)
variableIdx = state.variableIdx #Vector of size B
batchSize = length(variableIdx)
possibleValuesIdx = [deepcopy(indexes) for indexes in state.possibleValuesIdx]

@assert batchSize == length(possibleValuesIdx)
actionSpaceSizes = [length(possibleValuesIdxPerVar) for possibleValuesIdxPerVar in possibleValuesIdx]
maxActionSpaceSize = Base.maximum(actionSpaceSizes)
#the largest action space size in the batch

# chain working on the graph(s)
fg = nn.graphChain(state.fg)

# extract the feature(s) of the variable(s) we're working on
indices = nothing
Zygote.ignore() do
indices = CartesianIndex.(zip(variableIdx, 1:batchSize))
end
variableFeatures = nothing
numPadded = nothing
Zygote.ignore() do
variableFeatures = reshape(fg.varnf[:, indices], (:,1,batchSize)) # Fx1xB
# extract the feature(s) of the variable(s) we're working on
numPadded = [maxActionSpaceSize - actionSpaceSizes[i] for i in 1:batchSize] #number of padding zeros needed fo each element of the batch
end

valueIndices = nothing
Zygote.ignore() do
paddedPossibleValuesIdx = [append!(possibleValuesIdx[i], repeat([possibleValuesIdx[i][1]], numPadded[i])) for i in 1:batchSize]
paddedPossibleValuesIdx = mapreduce(identity, hcat, paddedPossibleValuesIdx) #convert from Vector to Matrix
#create a CartesianIndex matrix of size (maxActionSpaceSize x batch_size)
valueIndices = CartesianIndex.(paddedPossibleValuesIdx, repeat(transpose(1:batchSize); outer=(maxActionSpaceSize, 1)))
end

valueFeatures = fg.valnf[:, valueIndices] #FxAxB

f = size(valueFeatures, 1)
Zygote.ignore() do
for i in 1:batchSize
for j in 1:numPadded[i]
valueFeatures[:,maxActionSpaceSize-j+1,i] = zeros(Float32, f)
end
end
end

# chain working on the node(s) feature(s)
chainOutput = nn.nodeChain(hcat(variableFeatures, valueFeatures)) #F'x(A+1)xB where F' is the output size of nodeChain

variableOutput = nothing
valueOutput = nothing
Zygote.ignore() do
variableOutput = reshape(chainOutput[:,1,:], (:,1,batchSize)) #F'xB
valueOutput = chainOutput[:,2:end,:] #F'xAxB
end

finalInput = nothing
Zygote.ignore() do
finalInput = []
for i in 1:batchSize
singleFinalInput = vcat(repeat(variableOutput[:,:,i], 1, maxActionSpaceSize), valueOutput[:,:,i]) #one element of the batch
finalInput = isempty(finalInput) ? [singleFinalInput] : append!(finalInput, [singleFinalInput])
end
end
#finalInput: vector of matrices of size F'xA (total size BxF'xA)
Zygote.ignore() do
f, a = size(finalInput[1])
finalInput = reshape(collect(Iterators.flatten(finalInput)), (f, maxActionSpaceSize, batchSize)) #!!TO TEST #convert vector of matrices into a 3-dimensional matrix

end

output = dropdims(nn.outputChain(finalInput); dims=1) #AxB

finalOutput = nothing
Zygote.ignore() do
finalOutput = reshape(
Float32[-Inf32 for _ in 1:(size(fg.valnf,2)*size(fg.valnf,3))],
size(fg.valnf,2),
size(fg.valnf,3)
)
end

Zygote.ignore() do
for i in 1:batchSize
for j in 1:actionSpaceSizes[i]
#note that possibleValuesIdx is a Vector{Vector{Int64}} while output and finalOutput are Matrix{Int64}
#thus the indexing can be inverted, e.g. the batches are in dim 1 for a Vector{Vector{Int64}} and in dim 2 for a Matrix{Int64}
finalOutput[possibleValuesIdx[i][j], i] = output[j,i]
end
end
end

return finalOutput

end
6 changes: 3 additions & 3 deletions src/RL/representation/default/cp_layer/accessors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,20 +150,20 @@ function LightGraphs.adjacency_matrix(cplayergraph::CPLayerGraph)
end

function adjacency_matrices(cplayergraph::CPLayerGraph)
g = Graph(cplayergraph) # Update the graph with the new information
g = LightGraphs.Graph(cplayergraph) # Update the graph with the new information
nvar = cplayergraph.numberOfVariables
ncon = cplayergraph.numberOfConstraints
nval = cplayergraph.numberOfValues
contovar = zeros(ncon, nvar)
valtovar = zeros(nval, nvar)
for (i, node) in enumerate(cplayergraph.idToNode)
if isa(node, ConstraintVertex)
neighbors = outneighbors(g, i)
neighbors = LightGraphs.outneighbors(g, i)
for neighbor in neighbors
contovar[i, neighbor - ncon] = 1
end
elseif isa(node, ValueVertex)
neighbors = outneighbors(g, i)
neighbors = LightGraphs.outneighbors(g, i)
for neighbor in neighbors
valtovar[i - ncon - nvar, neighbor - ncon] = 1
end
Expand Down
15 changes: 14 additions & 1 deletion src/RL/representation/default/defaultstaterepresentation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ It is only necessary to specify the options you wish to activate.
"""
function featurize(sr::DefaultStateRepresentation{DefaultFeaturization,TS}; chosen_features::Union{Nothing,Dict{String,Bool}}=nothing) where {TS}
initChosenFeatures!(sr, chosen_features)

g = sr.cplayergraph
features = zeros(Float32, sr.nbFeatures, LightGraphs.nv(g))
for i in 1:LightGraphs.nv(g)
Expand Down Expand Up @@ -191,6 +190,9 @@ function featurize(sr::DefaultStateRepresentation{DefaultFeaturization,TS}; chos
if sr.chosenFeatures["variable_is_objective"][1]
features[sr.chosenFeatures["variable_is_objective"][2], i] = sr.cplayergraph.cpmodel.objective == cp_vertex.variable
end
if sr.chosenFeatures["variable_assigned_value"][1]
features[sr.chosenFeatures["variable_assigned_value"][2], i] = isbound(cp_vertex.variable) ? assignedValue(cp_vertex.variable) : 0
end
end
if isa(cp_vertex, ValueVertex)
features[3, i] = 1.0f0
Expand Down Expand Up @@ -225,6 +227,7 @@ function initChosenFeatures!(sr::DefaultStateRepresentation{DefaultFeaturization
"variable_is_bound" => (false, -1),
"variable_is_branchable" => (false, -1),
"variable_is_objective" => (false, -1),
"variable_assigned_value" => (false, -1),
"values_onehot" => (false, -1),
"values_raw" => (false, -1),
)
Expand Down Expand Up @@ -266,6 +269,16 @@ function initChosenFeatures!(sr::DefaultStateRepresentation{DefaultFeaturization
counter += 1
end

if haskey(chosen_features, "variable_assigned_value") && chosen_features["variable_assigned_value"]
sr.chosenFeatures["variable_assigned_value"] = (true, counter)
counter += 1
end

if haskey(chosen_features, "node_number_of_neighbors") && chosen_features["node_number_of_neighbors"]
sr.chosenFeatures["node_number_of_neighbors"] = (true, counter)
counter += 1
end

if haskey(chosen_features, "nb_not_bounded_variable") && chosen_features["nb_not_bounded_variable"]
sr.chosenFeatures["nb_not_bounded_variable"] = (true, counter)
counter += 1
Expand Down
52 changes: 49 additions & 3 deletions src/RL/representation/default/heterogeneousstaterepresentation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,12 @@ function featurize(sr::HeterogeneousStateRepresentation{DefaultFeaturization,TS}
valueFeatures = zeros(Float32, sr.nbValueFeatures, g.numberOfValues)
ncon = sr.cplayergraph.numberOfConstraints
nvar = sr.cplayergraph.numberOfVariables
for i in 1:nv(g)
for i in 1:LightGraphs.nv(g)
cp_vertex = SeaPearl.cpVertexFromIndex(g, i)
if isa(cp_vertex, VariableVertex)
if sr.chosenFeatures["node_number_of_neighbors"][1]
variableFeatures[sr.chosenFeatures["node_number_of_neighbors"][2], i - ncon] = length(LightGraphs.outneighbors(g, i))
end
if sr.chosenFeatures["variable_initial_domain_size"][1]
variableFeatures[sr.chosenFeatures["variable_initial_domain_size"][2], i - ncon] = length(cp_vertex.variable.domain)
end
Expand All @@ -137,8 +140,14 @@ function featurize(sr::HeterogeneousStateRepresentation{DefaultFeaturization,TS}
if sr.chosenFeatures["variable_is_objective"][1]
variableFeatures[sr.chosenFeatures["variable_is_objective"][2], i - ncon] = sr.cplayergraph.cpmodel.objective == cp_vertex.variable
end
if sr.chosenFeatures["variable_assigned_value"][1]
variableFeatures[sr.chosenFeatures["variable_assigned_value"][2], i - ncon] = isbound(cp_vertex.variable) ? assignedValue(cp_vertex.variable) : 0
end
end
if isa(cp_vertex, ConstraintVertex)
if sr.chosenFeatures["node_number_of_neighbors"][1]
constraintFeatures[sr.chosenFeatures["node_number_of_neighbors"][2], i] = length(LightGraphs.outneighbors(g, i))
end
if sr.chosenFeatures["constraint_activity"][1]
if isa(cp_vertex.constraint, ViewConstraint)
constraintFeatures[sr.chosenFeatures["constraint_activity"][2], i] = isbound(cp_vertex.constraint.parent)
Expand Down Expand Up @@ -171,6 +180,9 @@ function featurize(sr::HeterogeneousStateRepresentation{DefaultFeaturization,TS}
end
end
if isa(cp_vertex, ValueVertex)
if sr.chosenFeatures["node_number_of_neighbors"][1]
valueFeatures[sr.chosenFeatures["node_number_of_neighbors"][2], i - ncon - nvar] = length(LightGraphs.outneighbors(g, i))
end
if sr.chosenFeatures["values_raw"][1]
valueFeatures[sr.chosenFeatures["values_raw"][2], i - ncon - nvar] = cp_vertex.value
end
Expand All @@ -196,11 +208,13 @@ function initChosenFeatures!(sr::HeterogeneousStateRepresentation{DefaultFeaturi
"constraint_type" => (false, -1),
"nb_involved_constraint_propagation" => (false, -1),
"nb_not_bounded_variable" => (false, -1),
"node_number_of_neighbors" => (false, -1),
"variable_domain_size" => (false, -1),
"variable_initial_domain_size" => (false, -1),
"variable_is_bound" => (false, -1),
"variable_is_branchable" => (false, -1),
"variable_is_objective" => (false, -1),
"variable_assigned_value" => (false, -1),
"values_onehot" => (false, -1),
"values_raw" => (false, -1),
)
Expand All @@ -209,6 +223,13 @@ function initChosenFeatures!(sr::HeterogeneousStateRepresentation{DefaultFeaturi
constraint_counter = 1
value_counter = 1
if !isnothing(chosen_features)
if haskey(chosen_features, "node_number_of_neighbors") && chosen_features["node_number_of_neighbors"]
sr.chosenFeatures["node_number_of_neighbors"] = (true, constraint_counter)
constraint_counter += 1
variable_counter += 1
value_counter += 1
end

if haskey(chosen_features, "constraint_activity") && chosen_features["constraint_activity"]
sr.chosenFeatures["constraint_activity"] = (true, constraint_counter)
constraint_counter += 1
Expand Down Expand Up @@ -244,6 +265,11 @@ function initChosenFeatures!(sr::HeterogeneousStateRepresentation{DefaultFeaturi
variable_counter += 1
end

if haskey(chosen_features, "variable_assigned_value") && chosen_features["variable_assigned_value"]
sr.chosenFeatures["variable_assigned_value"] = (true, variable_counter)
variable_counter += 1
end

if haskey(chosen_features, "nb_not_bounded_variable") && chosen_features["nb_not_bounded_variable"]
sr.chosenFeatures["nb_not_bounded_variable"] = (true, constraint_counter)
constraint_counter += 1
Expand Down Expand Up @@ -294,7 +320,7 @@ function update_features!(sr::HeterogeneousStateRepresentation{DefaultFeaturizat
g = sr.cplayergraph
ncon = sr.cplayergraph.numberOfConstraints
nvar = sr.cplayergraph.numberOfVariables
for i in 1:nv(g)
for i in 1:LightGraphs.nv(g)
cp_vertex = SeaPearl.cpVertexFromIndex(g, i)
if isa(cp_vertex, VariableVertex)
if sr.chosenFeatures["variable_domain_size"][1]
Expand All @@ -304,6 +330,14 @@ function update_features!(sr::HeterogeneousStateRepresentation{DefaultFeaturizat
if sr.chosenFeatures["variable_is_bound"][1]
sr.variableNodeFeatures[sr.chosenFeatures["variable_is_bound"][2], i - ncon] = isbound(cp_vertex.variable)
end

if sr.chosenFeatures["node_number_of_neighbors"][1]
sr.variableNodeFeatures[sr.chosenFeatures["node_number_of_neighbors"][2], i - ncon] = length(LightGraphs.outneighbors(g, i))
end

if sr.chosenFeatures["variable_assigned_value"][1]
sr.variableNodeFeatures[sr.chosenFeatures["variable_assigned_value"][2], i - ncon] = isbound(cp_vertex.variable) ? assignedValue(cp_vertex.variable) : 0
end
end
if isa(cp_vertex, ConstraintVertex)
if sr.chosenFeatures["constraint_activity"][1]
Expand All @@ -315,13 +349,21 @@ function update_features!(sr::HeterogeneousStateRepresentation{DefaultFeaturizat
end

if sr.chosenFeatures["nb_involved_constraint_propagation"][1]
sr.constraintNodeFeatures[sr.chosenFeatures["nb_involved_constraint_propagation"][2], i] = sr.cplayergraph.cpmodel.statistics.numberOfTimesInvolvedInPropagation[cp_vertex.constraint]
if isa(cp_vertex.constraint, ViewConstraint)
sr.constraintNodeFeatures[sr.chosenFeatures["nb_involved_constraint_propagation"][2], i] = 0
else
sr.constraintNodeFeatures[sr.chosenFeatures["nb_involved_constraint_propagation"][2], i] = sr.cplayergraph.cpmodel.statistics.numberOfTimesInvolvedInPropagation[cp_vertex.constraint]
end
end

if sr.chosenFeatures["nb_not_bounded_variable"][1]
variables = variablesArray(cp_vertex.constraint)
sr.constraintNodeFeatures[sr.chosenFeatures["nb_not_bounded_variable"][2], i] = count(x -> !isbound(x), variables)
end

if sr.chosenFeatures["node_number_of_neighbors"][1]
sr.constraintNodeFeatures[sr.chosenFeatures["node_number_of_neighbors"][2], i] = length(LightGraphs.outneighbors(g, i))
end
end
if isa(cp_vertex, ValueVertex) # Probably useless, check before removing
if sr.chosenFeatures["values_raw"][1]
Expand All @@ -332,6 +374,10 @@ function update_features!(sr::HeterogeneousStateRepresentation{DefaultFeaturizat
cp_vertex_idx = sr.valueToPos[cp_vertex.value]
sr.valueNodeFeatures[sr.chosenFeatures["values_onehot"][2]+cp_vertex_idx-1, i - ncon - nvar] = 1
end

if sr.chosenFeatures["node_number_of_neighbors"][1]
sr.valueNodeFeatures[sr.chosenFeatures["node_number_of_neighbors"][2], i - ncon - nvar] = length(LightGraphs.outneighbors(g, i))
end
end
end
end
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.