Skip to content

Commit

Permalink
Switch Connector fields from Vector{Vector{...}} to Vector{...} (
Browse files Browse the repository at this point in the history
…#504)

* use `AdjacencyMatrix` instead of `get_adjacency`

* store `Vector{Connector}` in composite bloxs without reducing

* change `Vector{Vector}` to `Vector` for all `Connector` fields

* adapt system construction to new `Connector` fields

* `get_connector` always returns `Vector{Connector}`

* accumulate connection eqs per blox like `BloxConnector` used to

* adapt RL to new `Connector`

* fix show Connector

* fix RL test

* change adjacency colormap to grays

* add `get_connector`(s) util functions

* use correct connectors getter

* add token weight for Striatum=>Striatum connection

* use correct connectors getter
  • Loading branch information
harisorgn authored and david-hofmann committed Jan 2, 2025
1 parent b8778b4 commit c00d823
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 84 deletions.
8 changes: 4 additions & 4 deletions ext/MakieExtension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ isdefined(Base, :get_extension) ? using Makie : using ..Makie

using Neuroblox
using Neuroblox: AbstractBlox, AbstractNeuronBlox, CompositeBlox, VLState, VLSetup
using Neuroblox: meanfield_timeseries, voltage_timeseries, detect_spikes, firing_rate, get_neurons, get_adjacency
using Neuroblox: meanfield_timeseries, voltage_timeseries, detect_spikes, firing_rate, get_neurons
using Neuroblox: powerspectrum
using SciMLBase: AbstractSolution, EnsembleSolution
using LinearAlgebra: diag
Expand All @@ -20,15 +20,15 @@ import Neuroblox: powerspectrumplot, powerspectrumplot!

@recipe(Adjacency, blox_or_graph) do scene
Theme(
colormap = :vanimo
colormap = :grays
)
end

argument_names(::Type{<: Adjacency}) = (:blox_or_graph)

function Makie.plot!(p::Adjacency)
blox_or_graph = p.blox_or_graph[]
adj = get_adjacency(blox_or_graph)
adj = AdjacencyMatrix(blox_or_graph)

N = length(adj.names)

Expand All @@ -42,7 +42,7 @@ function Makie.plot!(p::Adjacency)

X, Y, D = findnz(adj.matrix)

heatmap!(p, X, Y, D; colormap = p.colormap[], colorrange = (minimum(D), maximum(D)))
heatmap!(p, Y, X, D; colormap = p.colormap[])

return p
end
Expand Down
17 changes: 10 additions & 7 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ get_dynamics_bloxs(blox::CompositeBlox) = get_parts(blox)
flatten_graph(g::MetaDiGraph) = mapreduce(get_dynamics_bloxs, vcat, get_bloxs(g))

function connectors_from_graph(g::MetaDiGraph)
conns = get_connector.(get_bloxs(g))
conns = reduce(vcat, get_connectors.(get_bloxs(g)))
for edge in edges(g)

blox_src = get_prop(g, edge.src, :blox)
Expand Down Expand Up @@ -188,22 +188,25 @@ function system_from_graph(g::MetaDiGraph, p::Vector{Num}=Num[]; name=nothing, t
throw(UndefKeywordError(:name))
end

bc = connector_from_graph(g)
conns = connectors_from_graph(g)

return system_from_graph(g, bc, p; name, t_block, simplify, kwargs...)
return system_from_graph(g, conns, p; name, t_block, simplify, kwargs...)
end
end

function system_from_graph(g::MetaDiGraph, bc::Connector, p::Vector{Num}=Num[]; name=nothing, t_block=missing, simplify=true, graphdynamics=false, kwargs...)
function system_from_graph(g::MetaDiGraph, conns::AbstractVector{<:Connector}, p::Vector{Num}=Num[]; name=nothing, t_block=missing, simplify=true, graphdynamics=false, kwargs...)
bloxs = get_bloxs(g)
blox_syss = get_system.(bloxs)

bc = isempty(conns) ? Connector(name, name) : reduce(merge!, conns)

eqs = equations(bc)
accumulate_equations!(eqs, bloxs)
eqs_init = mapreduce(get_input_equations, vcat, bloxs)
accumulate_equations!(eqs_init, eqs)

connection_eqs = get_equations_with_state_lhs(eqs)
connection_eqs = get_equations_with_state_lhs(eqs_init)

discrete_cbs = identity.(generate_discrete_callbacks(g, bc, eqs; t_block))
discrete_cbs = identity.(generate_discrete_callbacks(g, bc, eqs_init; t_block))

sys = compose(System(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs), blox_syss)
if simplify
Expand Down
8 changes: 4 additions & 4 deletions src/blox/DBS_Model_Blox_Adam_Brown.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ struct Striatum_MSN_Adam <: CompositeBlox
end
parts = n_inh

bc = connector_from_graph(g)
bc = connectors_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

Expand Down Expand Up @@ -185,7 +185,7 @@ struct Striatum_FSI_Adam <: CompositeBlox

parts = n_inh

bc = connector_from_graph(g)
bc = connectors_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

Expand Down Expand Up @@ -255,7 +255,7 @@ struct GPe_Adam <: CompositeBlox
end
parts = n_inh

bc = connector_from_graph(g)
bc = connectors_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

Expand Down Expand Up @@ -324,7 +324,7 @@ struct STN_Adam <: CompositeBlox
end
parts = n_exci

bc = connector_from_graph(g)
bc = connectors_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

Expand Down
5 changes: 4 additions & 1 deletion src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,10 @@ end

get_input_equations(blox) = []

get_connector(blox::Union{CompositeBlox, Agent}) = blox.connector
get_connectors(blox::Union{CompositeBlox, Agent}) = blox.connector
get_connectors(blox) = [Connector(namespaced_nameof(blox), namespaced_nameof(blox))]

get_connector(blox::Union{CompositeBlox, Agent}) = reduce(merge!, get_connectors(blox))
get_connector(blox) = Connector(namespaced_nameof(blox), namespaced_nameof(blox))

function get_weight(kwargs, name_blox1, name_blox2)
Expand Down
2 changes: 1 addition & 1 deletion src/blox/canonicalmicrocircuit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ mutable struct CanonicalMicroCircuitBlox <: CompositeBlox
add_edge!(g, ii => dp; :weight => -400.0)
add_edge!(g, dp => dp; :weight => -200.0)

bc = connector_from_graph(g)
bc = connectors_from_graph(g)
# If a namespace is not provided, assume that this is the highest level
# and construct the ODEsystem from the graph.
sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(sblox_parts; name)
Expand Down
136 changes: 83 additions & 53 deletions src/blox/connections.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
struct Connector
source::Vector{Vector{Symbol}}
destination::Vector{Vector{Symbol}}
equation::Vector{Vector{Equation}}
weight::Vector{Vector{Num}}
delay::Vector{Vector{Num}}
source::Vector{Symbol}
destination::Vector{Symbol}
equation::Vector{Equation}
weight::Vector{Num}
delay::Vector{Num}
discrete_callbacks
spike_affects::Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}}
learning_rule::Dict{Num, AbstractLearningRule}
Expand All @@ -25,65 +25,69 @@ function Connector(
learning_rule = U <: NoLearningRule ? Dict{Num, NoLearningRule}() : learning_rule

Connector(
to_double_vector(src),
to_double_vector(dest),
to_double_vector(equation),
to_double_vector(weight),
to_double_vector(delay),
to_double_vector(discrete_callbacks),
to_vector(src),
to_vector(dest),
to_vector(equation),
to_vector(weight),
to_vector(delay),
to_vector(discrete_callbacks),
spike_affects,
learning_rule
)
end

function Base.isempty(conn::Connector)
return all(isempty.(conn.equation)) && all(isempty.(conn.weight)) && all(isempty.(conn.delay)) && all(isempty.(conn.discrete_callbacks)) && isempty(conn.spike_affects) && isempty(conn.learning_rule)
return isempty(conn.equation) && isempty(conn.weight) && isempty(conn.delay) && isempty(conn.discrete_callbacks) && isempty(conn.spike_affects) && isempty(conn.learning_rule)
end

connection_rule(blox_src, blox_dest; kwargs...) = Connector(blox_src, blox_dest; kwargs...)

Base.show(io::IO, c::Connector) = print(io, "$(c.source) => $(c.destination) with ", c.equation)

function string_to_show(v, title)
s = string.(v)

return string("\t $(title): ", "[", join(s, " , "), "]")
function show_field(v::AbstractVector, title)
if !isempty(v)
println(title, " :")
for val in v
println("\t $(val)")
end
end
end

function string_to_show(d::Dict, title)
s = [string(k, " => ", v) for (k,v) in d]

return string("\t $(title): ", "[", join(s, " , "), "]")
function show_field(d::Dict, title)
if !isempty(d)
println(title, " :")
for (k, v) in d
println("\t ", k, " => ", v)
end
end
end

function Base.show(io::IO, ::MIME"text/plain", c::Connector)
N_conns = length(c.source)

for i in Base.OneTo(N_conns)
println("Connection $(c.source[i]) => $(c.destination[i])")
println("Connections :")
for (s, d) in zip(c.source, c.destination)
println("\t $(s) => $(d)")
end

!isempty(c.equation[i]) && println(string_to_show(c.equation[i], "Equation"))
!isempty(c.weight[i]) && println(string_to_show(c.weight[i], "Weight"))
!isempty(c.delay[i]) && println(string_to_show(c.delay[i], "Delay"))
show_field(c.equation, "Equations")
show_field(c.weight, "Weights")
show_field(c.delay, "Delays")

d = Dict()
for w in c.weight[i]
if haskey(c.learning_rule, w)
d[w] = c.learning_rule[w]
end
d = Dict()
for w in c.weight
if haskey(c.learning_rule, w)
d[w] = c.learning_rule[w]
end
!isempty(d) && println(string_to_show(d, "Plasticity model"))

for s in c.source[i]
if haskey(c.spike_affects, s)
println("\t $(s) spikes affect :")
vars, vals = c.spike_affects[s]
for (var, val) in zip(vars, vals)
println("\t \t $(var) += $(val)")
end
end
show_field(d, "Plasticity model")

for s in c.source
if haskey(c.spike_affects, s)
println("$(s) spikes affect :")
vars, vals = c.spike_affects[s]
for (var, val) in zip(vars, vals)
println("\t $(var) += $(val)")
end
end
end
end
end

function accumulate_equations!(eqs::AbstractVector{<:Equation}, bloxs)
Expand All @@ -108,7 +112,21 @@ function accumulate_equations!(eqs1::Vector{<:Equation}, eqs2::Vector{<:Equation
return eqs1
end

ModelingToolkit.equations(c::Connector) = reduce(accumulate_equations!, c.equation)
function accumulate_equations(eqs1::Vector{<:Equation}, eqs2::Vector{<:Equation})
eqs = copy(eqs1)
for eq in eqs2
lhs = eq.lhs
idx = find_eq(eqs1, lhs)

if isnothing(idx)
push!(eqs, eq)
else
eqs[idx] = eqs[idx].lhs ~ eqs[idx].rhs + eq.rhs
end
end

return eqs
end

function tuple_append!(t1::Tuple, t2::Tuple)
append!(first(t1), first(t2))
Expand All @@ -117,20 +135,24 @@ function tuple_append!(t1::Tuple, t2::Tuple)
return t1
end

discrete_callbacks(c::Connector) = reduce(append!, c.discrete_callbacks)
ModelingToolkit.equations(c::Connector) = c.equation

sources(c::Connector) = reduce(append!, c.source)
discrete_callbacks(c::Connector) = c.discrete_callbacks

destinations(c::Connector) = reduce(append!, c.destination)
sources(c::Connector) = c.source

weights(c::Connector) = reduce(append!, c.weight)
destinations(c::Connector) = c.destination

delays(c::Connector) = reduce(append!, c.delay)
weights(c::Connector) = c.weight

delays(c::Connector) = c.delay

spike_affects(c::Connector) = c.spike_affects

learning_rules(c::Connector) = c.learning_rule

learning_rules(conns::AbstractVector{<:Connector}) = mapreduce(c -> c.learning_rule, merge!, conns)

get_equations_with_parameter_lhs(eqs::AbstractVector{<:Equation}) = filter(eq -> isparameter(eq.lhs), eqs)

get_equations_with_state_lhs(eqs::AbstractVector{<:Equation}) = filter(eq -> !isparameter(eq.lhs), eqs)
Expand Down Expand Up @@ -183,7 +205,7 @@ end
function Base.merge!(c1::Connector, c2::Connector)
append!(c1.source, c2.source)
append!(c1.destination, c2.destination)
append!(c1.equation, c2.equation)
accumulate_equations!(c1.equation, c2.equation)
append!(c1.weight, c2.weight)
append!(c1.delay, c2.delay)
append!(c1.discrete_callbacks, c2.discrete_callbacks)
Expand Down Expand Up @@ -256,13 +278,19 @@ function indegree_constrained_connections(neurons_src, neurons_dst, name_src, na
return reduce(merge!, C)
end

function Connector(blox_src::AbstractBlox, blox_dest::AbstractBlox; kwargs...)
connection_rule(blox_src, blox_dest; kwargs...) = Connector(blox_src, blox_dest; kwargs...)

connection_equation(blox_src, blox_dest; kwargs...) = Connector(blox_src, blox_dest; kwargs...).equation

function connection_equation(blox_src, blox_dest, w) end

function Connector(blox_src, blox_dest::AbstractBlox; kwargs...)
sys_src = get_namespaced_sys(blox_src)
sys_dest = get_namespaced_sys(blox_dest)

w = generate_weight_param(blox_src, blox_dest; kwargs...)

eq = sys_dest.jcn ~ w*sys_src.v
eq = connection_equation(blox_src, blox_dest, w)

return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w)
end
Expand Down Expand Up @@ -767,7 +795,9 @@ function Connector(
push!(dc, cb_neuron_init)
end

return Connector(namespaced_nameof(blox_src), namespaced_nameof(blox_dest); discrete_callbacks=dc)
w = generate_weight_param(blox_src, blox_dest; weight=1)

return Connector(namespaced_nameof(blox_src), namespaced_nameof(blox_dest); discrete_callbacks=dc, weight=w)
end

function Connector(
Expand Down
6 changes: 3 additions & 3 deletions src/blox/cortical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct CorticalBlox <: CompositeBlox
add_edge!(g, N_wta+1, i, Dict(:weight => 1))
end

bc = connector_from_graph(g)
bc = connectors_from_graph(g)
# If a namespace is not provided, assume that this is the highest level
# and construct the ODEsystem from the graph.
# If there is a higher namespace, construct only a subsystem containing the parts of this level
Expand Down Expand Up @@ -154,7 +154,7 @@ struct LIFExciCircuitBlox <: CompositeBlox
end
end

bc = connector_from_graph(g)
bc = connectors_from_graph(g)

if skip_system_creation
sys = nothing
Expand Down Expand Up @@ -232,7 +232,7 @@ struct LIFInhCircuitBlox <: CompositeBlox
end
end

bc = connector_from_graph(g)
bc = connectors_from_graph(g)

if skip_system_creation
sys = nothing
Expand Down
Loading

0 comments on commit c00d823

Please sign in to comment.