From c00d823a1dc2da0f60b10d98156d4a7f941a03ad Mon Sep 17 00:00:00 2001 From: haris organtzidis Date: Fri, 20 Dec 2024 16:42:15 +0200 Subject: [PATCH] Switch `Connector` fields from `Vector{Vector{...}}` to `Vector{...}` (#504) * use `AdjacencyMatrix` instead of `get_adjacency` * store `Vector{Connector}` in composite bloxs without reducing * change `Vector{Vector}` to `Vector` for all `Connector` fields * adapt system construction to new `Connector` fields * `get_connector` always returns `Vector{Connector}` * accumulate connection eqs per blox like `BloxConnector` used to * adapt RL to new `Connector` * fix show Connector * fix RL test * change adjacency colormap to grays * add `get_connector`(s) util functions * use correct connectors getter * add token weight for Striatum=>Striatum connection * use correct connectors getter --- ext/MakieExtension.jl | 8 +- src/Neurographs.jl | 17 ++-- src/blox/DBS_Model_Blox_Adam_Brown.jl | 8 +- src/blox/blox_utilities.jl | 5 +- src/blox/canonicalmicrocircuit.jl | 2 +- src/blox/connections.jl | 136 ++++++++++++++++---------- src/blox/cortical.jl | 6 +- src/blox/reinforcement_learning.jl | 8 +- src/blox/subcortical_blox.jl | 12 +-- src/blox/winnertakeall.jl | 2 +- 10 files changed, 120 insertions(+), 84 deletions(-) diff --git a/ext/MakieExtension.jl b/ext/MakieExtension.jl index a7feb5ce..1e898a62 100644 --- a/ext/MakieExtension.jl +++ b/ext/MakieExtension.jl @@ -4,7 +4,7 @@ isdefined(Base, :get_extension) ? using Makie : using ..Makie using Neuroblox using Neuroblox: AbstractBlox, AbstractNeuronBlox, CompositeBlox, VLState, VLSetup -using Neuroblox: meanfield_timeseries, voltage_timeseries, detect_spikes, firing_rate, get_neurons, get_adjacency +using Neuroblox: meanfield_timeseries, voltage_timeseries, detect_spikes, firing_rate, get_neurons using Neuroblox: powerspectrum using SciMLBase: AbstractSolution, EnsembleSolution using LinearAlgebra: diag @@ -20,7 +20,7 @@ import Neuroblox: powerspectrumplot, powerspectrumplot! @recipe(Adjacency, blox_or_graph) do scene Theme( - colormap = :vanimo + colormap = :grays ) end @@ -28,7 +28,7 @@ argument_names(::Type{<: Adjacency}) = (:blox_or_graph) function Makie.plot!(p::Adjacency) blox_or_graph = p.blox_or_graph[] - adj = get_adjacency(blox_or_graph) + adj = AdjacencyMatrix(blox_or_graph) N = length(adj.names) @@ -42,7 +42,7 @@ function Makie.plot!(p::Adjacency) X, Y, D = findnz(adj.matrix) - heatmap!(p, X, Y, D; colormap = p.colormap[], colorrange = (minimum(D), maximum(D))) + heatmap!(p, Y, X, D; colormap = p.colormap[]) return p end diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 88bbff61..0da9391b 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -53,7 +53,7 @@ get_dynamics_bloxs(blox::CompositeBlox) = get_parts(blox) flatten_graph(g::MetaDiGraph) = mapreduce(get_dynamics_bloxs, vcat, get_bloxs(g)) function connectors_from_graph(g::MetaDiGraph) - conns = get_connector.(get_bloxs(g)) + conns = reduce(vcat, get_connectors.(get_bloxs(g))) for edge in edges(g) blox_src = get_prop(g, edge.src, :blox) @@ -188,22 +188,25 @@ function system_from_graph(g::MetaDiGraph, p::Vector{Num}=Num[]; name=nothing, t throw(UndefKeywordError(:name)) end - bc = connector_from_graph(g) + conns = connectors_from_graph(g) - return system_from_graph(g, bc, p; name, t_block, simplify, kwargs...) + return system_from_graph(g, conns, p; name, t_block, simplify, kwargs...) end end -function system_from_graph(g::MetaDiGraph, bc::Connector, p::Vector{Num}=Num[]; name=nothing, t_block=missing, simplify=true, graphdynamics=false, kwargs...) +function system_from_graph(g::MetaDiGraph, conns::AbstractVector{<:Connector}, p::Vector{Num}=Num[]; name=nothing, t_block=missing, simplify=true, graphdynamics=false, kwargs...) bloxs = get_bloxs(g) blox_syss = get_system.(bloxs) + bc = isempty(conns) ? Connector(name, name) : reduce(merge!, conns) + eqs = equations(bc) - accumulate_equations!(eqs, bloxs) + eqs_init = mapreduce(get_input_equations, vcat, bloxs) + accumulate_equations!(eqs_init, eqs) - connection_eqs = get_equations_with_state_lhs(eqs) + connection_eqs = get_equations_with_state_lhs(eqs_init) - discrete_cbs = identity.(generate_discrete_callbacks(g, bc, eqs; t_block)) + discrete_cbs = identity.(generate_discrete_callbacks(g, bc, eqs_init; t_block)) sys = compose(System(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs), blox_syss) if simplify diff --git a/src/blox/DBS_Model_Blox_Adam_Brown.jl b/src/blox/DBS_Model_Blox_Adam_Brown.jl index e481e46d..a8cf658e 100644 --- a/src/blox/DBS_Model_Blox_Adam_Brown.jl +++ b/src/blox/DBS_Model_Blox_Adam_Brown.jl @@ -106,7 +106,7 @@ struct Striatum_MSN_Adam <: CompositeBlox end parts = n_inh - bc = connector_from_graph(g) + bc = connectors_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) @@ -185,7 +185,7 @@ struct Striatum_FSI_Adam <: CompositeBlox parts = n_inh - bc = connector_from_graph(g) + bc = connectors_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) @@ -255,7 +255,7 @@ struct GPe_Adam <: CompositeBlox end parts = n_inh - bc = connector_from_graph(g) + bc = connectors_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) @@ -324,7 +324,7 @@ struct STN_Adam <: CompositeBlox end parts = n_exci - bc = connector_from_graph(g) + bc = connectors_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index fbce5c56..168c4a1c 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -165,7 +165,10 @@ end get_input_equations(blox) = [] -get_connector(blox::Union{CompositeBlox, Agent}) = blox.connector +get_connectors(blox::Union{CompositeBlox, Agent}) = blox.connector +get_connectors(blox) = [Connector(namespaced_nameof(blox), namespaced_nameof(blox))] + +get_connector(blox::Union{CompositeBlox, Agent}) = reduce(merge!, get_connectors(blox)) get_connector(blox) = Connector(namespaced_nameof(blox), namespaced_nameof(blox)) function get_weight(kwargs, name_blox1, name_blox2) diff --git a/src/blox/canonicalmicrocircuit.jl b/src/blox/canonicalmicrocircuit.jl index 16cf305e..0b090991 100644 --- a/src/blox/canonicalmicrocircuit.jl +++ b/src/blox/canonicalmicrocircuit.jl @@ -48,7 +48,7 @@ mutable struct CanonicalMicroCircuitBlox <: CompositeBlox add_edge!(g, ii => dp; :weight => -400.0) add_edge!(g, dp => dp; :weight => -200.0) - bc = connector_from_graph(g) + bc = connectors_from_graph(g) # If a namespace is not provided, assume that this is the highest level # and construct the ODEsystem from the graph. sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(sblox_parts; name) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index 8ab16fdf..f92514bc 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -1,9 +1,9 @@ struct Connector - source::Vector{Vector{Symbol}} - destination::Vector{Vector{Symbol}} - equation::Vector{Vector{Equation}} - weight::Vector{Vector{Num}} - delay::Vector{Vector{Num}} + source::Vector{Symbol} + destination::Vector{Symbol} + equation::Vector{Equation} + weight::Vector{Num} + delay::Vector{Num} discrete_callbacks spike_affects::Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}} learning_rule::Dict{Num, AbstractLearningRule} @@ -25,65 +25,69 @@ function Connector( learning_rule = U <: NoLearningRule ? Dict{Num, NoLearningRule}() : learning_rule Connector( - to_double_vector(src), - to_double_vector(dest), - to_double_vector(equation), - to_double_vector(weight), - to_double_vector(delay), - to_double_vector(discrete_callbacks), + to_vector(src), + to_vector(dest), + to_vector(equation), + to_vector(weight), + to_vector(delay), + to_vector(discrete_callbacks), spike_affects, learning_rule ) end function Base.isempty(conn::Connector) - return all(isempty.(conn.equation)) && all(isempty.(conn.weight)) && all(isempty.(conn.delay)) && all(isempty.(conn.discrete_callbacks)) && isempty(conn.spike_affects) && isempty(conn.learning_rule) + return isempty(conn.equation) && isempty(conn.weight) && isempty(conn.delay) && isempty(conn.discrete_callbacks) && isempty(conn.spike_affects) && isempty(conn.learning_rule) end -connection_rule(blox_src, blox_dest; kwargs...) = Connector(blox_src, blox_dest; kwargs...) - Base.show(io::IO, c::Connector) = print(io, "$(c.source) => $(c.destination) with ", c.equation) -function string_to_show(v, title) - s = string.(v) - - return string("\t $(title): ", "[", join(s, " , "), "]") +function show_field(v::AbstractVector, title) + if !isempty(v) + println(title, " :") + for val in v + println("\t $(val)") + end + end end -function string_to_show(d::Dict, title) - s = [string(k, " => ", v) for (k,v) in d] - - return string("\t $(title): ", "[", join(s, " , "), "]") +function show_field(d::Dict, title) + if !isempty(d) + println(title, " :") + for (k, v) in d + println("\t ", k, " => ", v) + end + end end function Base.show(io::IO, ::MIME"text/plain", c::Connector) - N_conns = length(c.source) - for i in Base.OneTo(N_conns) - println("Connection $(c.source[i]) => $(c.destination[i])") + println("Connections :") + for (s, d) in zip(c.source, c.destination) + println("\t $(s) => $(d)") + end - !isempty(c.equation[i]) && println(string_to_show(c.equation[i], "Equation")) - !isempty(c.weight[i]) && println(string_to_show(c.weight[i], "Weight")) - !isempty(c.delay[i]) && println(string_to_show(c.delay[i], "Delay")) + show_field(c.equation, "Equations") + show_field(c.weight, "Weights") + show_field(c.delay, "Delays") - d = Dict() - for w in c.weight[i] - if haskey(c.learning_rule, w) - d[w] = c.learning_rule[w] - end + d = Dict() + for w in c.weight + if haskey(c.learning_rule, w) + d[w] = c.learning_rule[w] end - !isempty(d) && println(string_to_show(d, "Plasticity model")) - - for s in c.source[i] - if haskey(c.spike_affects, s) - println("\t $(s) spikes affect :") - vars, vals = c.spike_affects[s] - for (var, val) in zip(vars, vals) - println("\t \t $(var) += $(val)") - end + end + show_field(d, "Plasticity model") + + for s in c.source + if haskey(c.spike_affects, s) + println("$(s) spikes affect :") + vars, vals = c.spike_affects[s] + for (var, val) in zip(vars, vals) + println("\t $(var) += $(val)") end end - end + end end function accumulate_equations!(eqs::AbstractVector{<:Equation}, bloxs) @@ -108,7 +112,21 @@ function accumulate_equations!(eqs1::Vector{<:Equation}, eqs2::Vector{<:Equation return eqs1 end -ModelingToolkit.equations(c::Connector) = reduce(accumulate_equations!, c.equation) +function accumulate_equations(eqs1::Vector{<:Equation}, eqs2::Vector{<:Equation}) + eqs = copy(eqs1) + for eq in eqs2 + lhs = eq.lhs + idx = find_eq(eqs1, lhs) + + if isnothing(idx) + push!(eqs, eq) + else + eqs[idx] = eqs[idx].lhs ~ eqs[idx].rhs + eq.rhs + end + end + + return eqs +end function tuple_append!(t1::Tuple, t2::Tuple) append!(first(t1), first(t2)) @@ -117,20 +135,24 @@ function tuple_append!(t1::Tuple, t2::Tuple) return t1 end -discrete_callbacks(c::Connector) = reduce(append!, c.discrete_callbacks) +ModelingToolkit.equations(c::Connector) = c.equation -sources(c::Connector) = reduce(append!, c.source) +discrete_callbacks(c::Connector) = c.discrete_callbacks -destinations(c::Connector) = reduce(append!, c.destination) +sources(c::Connector) = c.source -weights(c::Connector) = reduce(append!, c.weight) +destinations(c::Connector) = c.destination -delays(c::Connector) = reduce(append!, c.delay) +weights(c::Connector) = c.weight + +delays(c::Connector) = c.delay spike_affects(c::Connector) = c.spike_affects learning_rules(c::Connector) = c.learning_rule +learning_rules(conns::AbstractVector{<:Connector}) = mapreduce(c -> c.learning_rule, merge!, conns) + get_equations_with_parameter_lhs(eqs::AbstractVector{<:Equation}) = filter(eq -> isparameter(eq.lhs), eqs) get_equations_with_state_lhs(eqs::AbstractVector{<:Equation}) = filter(eq -> !isparameter(eq.lhs), eqs) @@ -183,7 +205,7 @@ end function Base.merge!(c1::Connector, c2::Connector) append!(c1.source, c2.source) append!(c1.destination, c2.destination) - append!(c1.equation, c2.equation) + accumulate_equations!(c1.equation, c2.equation) append!(c1.weight, c2.weight) append!(c1.delay, c2.delay) append!(c1.discrete_callbacks, c2.discrete_callbacks) @@ -256,13 +278,19 @@ function indegree_constrained_connections(neurons_src, neurons_dst, name_src, na return reduce(merge!, C) end -function Connector(blox_src::AbstractBlox, blox_dest::AbstractBlox; kwargs...) +connection_rule(blox_src, blox_dest; kwargs...) = Connector(blox_src, blox_dest; kwargs...) + +connection_equation(blox_src, blox_dest; kwargs...) = Connector(blox_src, blox_dest; kwargs...).equation + +function connection_equation(blox_src, blox_dest, w) end + +function Connector(blox_src, blox_dest::AbstractBlox; kwargs...) sys_src = get_namespaced_sys(blox_src) sys_dest = get_namespaced_sys(blox_dest) w = generate_weight_param(blox_src, blox_dest; kwargs...) - eq = sys_dest.jcn ~ w*sys_src.v + eq = connection_equation(blox_src, blox_dest, w) return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end @@ -767,7 +795,9 @@ function Connector( push!(dc, cb_neuron_init) end - return Connector(namespaced_nameof(blox_src), namespaced_nameof(blox_dest); discrete_callbacks=dc) + w = generate_weight_param(blox_src, blox_dest; weight=1) + + return Connector(namespaced_nameof(blox_src), namespaced_nameof(blox_dest); discrete_callbacks=dc, weight=w) end function Connector( diff --git a/src/blox/cortical.jl b/src/blox/cortical.jl index a6a97f3d..bd12d02f 100644 --- a/src/blox/cortical.jl +++ b/src/blox/cortical.jl @@ -71,7 +71,7 @@ struct CorticalBlox <: CompositeBlox add_edge!(g, N_wta+1, i, Dict(:weight => 1)) end - bc = connector_from_graph(g) + bc = connectors_from_graph(g) # If a namespace is not provided, assume that this is the highest level # and construct the ODEsystem from the graph. # If there is a higher namespace, construct only a subsystem containing the parts of this level @@ -154,7 +154,7 @@ struct LIFExciCircuitBlox <: CompositeBlox end end - bc = connector_from_graph(g) + bc = connectors_from_graph(g) if skip_system_creation sys = nothing @@ -232,7 +232,7 @@ struct LIFInhCircuitBlox <: CompositeBlox end end - bc = connector_from_graph(g) + bc = connectors_from_graph(g) if skip_system_creation sys = nothing diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl index 206245ca..f61ea1c2 100644 --- a/src/blox/reinforcement_learning.jl +++ b/src/blox/reinforcement_learning.jl @@ -190,11 +190,11 @@ mutable struct Agent{S,P,A,LR,C} connector::C function Agent(g::MetaDiGraph; name, kwargs...) - bc = connector_from_graph(g) + conns = connectors_from_graph(g) t_block = haskey(kwargs, :t_block) ? kwargs[:t_block] : missing # TODO: add another version that uses system_from_graph(g,bc,params;) - sys = system_from_graph(g, bc; name, t_block, allow_parameter=false) + sys = system_from_graph(g, conns; name, t_block, allow_parameter=false) u0 = haskey(kwargs, :u0) ? kwargs[:u0] : [] p = haskey(kwargs, :p) ? kwargs[:p] : [] @@ -202,9 +202,9 @@ mutable struct Agent{S,P,A,LR,C} prob = ODEProblem(sys, u0, (0.,1.), p) policy = action_selection_from_graph(g) - lr = narrowtype(learning_rules(bc)) + lr = narrowtype(learning_rules(conns)) - new{typeof(sys), typeof(prob), typeof(policy), typeof(lr), typeof(bc)}(sys, prob, policy, lr, bc) + new{typeof(sys), typeof(prob), typeof(policy), typeof(lr), typeof(conns)}(sys, prob, policy, lr, conns) end end diff --git a/src/blox/subcortical_blox.jl b/src/blox/subcortical_blox.jl index de105d08..fb7de5b6 100644 --- a/src/blox/subcortical_blox.jl +++ b/src/blox/subcortical_blox.jl @@ -51,7 +51,7 @@ struct Striatum <: CompositeBlox if !isnothing(namespace) add_blox!(g, matrisome) add_blox!(g, striosome) - bc = connector_from_graph(g) + bc = connectors_from_graph(g) sys = system_from_parts(parts; name) @variables t @@ -60,7 +60,7 @@ struct Striatum <: CompositeBlox new(namespace, parts, sys, bc, m) else - bc = connector_from_graph(g) + bc = connectors_from_graph(g) sys = system_from_graph(g, bc; name, simplify=false) m = [s for s in unknowns.((sys,), unknowns(sys)) if contains(string(s), "V(t)")] @@ -119,7 +119,7 @@ struct GPi <: CompositeBlox parts = n_inh - bc = connector_from_graph(g) + bc = connectors_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) m = if isnothing(namespace) @@ -175,7 +175,7 @@ struct GPe <: CompositeBlox parts = n_inh - bc = connector_from_graph(g) + bc = connectors_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) m = if isnothing(namespace) @@ -231,7 +231,7 @@ struct Thalamus <: CompositeBlox parts = n_exci - bc = connector_from_graph(g) + bc = connectors_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) m = if isnothing(namespace) @@ -286,7 +286,7 @@ struct STN <: CompositeBlox parts = n_exci - bc = connector_from_graph(g) + bc = connectors_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) # TO DO : m is a subset of unknowns to be plotted in the GUI. diff --git a/src/blox/winnertakeall.jl b/src/blox/winnertakeall.jl index ea260ec1..0b798c79 100644 --- a/src/blox/winnertakeall.jl +++ b/src/blox/winnertakeall.jl @@ -56,7 +56,7 @@ struct WinnerTakeAllBlox{P} <: CompositeBlox parts = vcat(n_inh, n_excis) - bc = connector_from_graph(g) + bc = connectors_from_graph(g) # If a namespace is not provided, assume that this is the highest level # and construct the ODEsystem from the graph. sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)