diff --git a/src/Neuroblox.jl b/src/Neuroblox.jl index 6aa44c1d..c274323c 100644 --- a/src/Neuroblox.jl +++ b/src/Neuroblox.jl @@ -116,11 +116,11 @@ include("blox/winnertakeall.jl") include("blox/subcortical_blox.jl") include("blox/stochastic.jl") include("blox/discrete.jl") +include("blox/ping_neuron_examples.jl") include("blox/reinforcement_learning.jl") include("gui/GUI.jl") include("blox/connections.jl") include("blox/blox_utilities.jl") -include("blox/ping_neuron_examples.jl") include("GraphDynamicsInterop/GraphDynamicsInterop.jl") include("Neurographs.jl") include("adjacency.jl") @@ -256,4 +256,5 @@ export meanfield, meanfield!, rasterplot, rasterplot!, stackplot, stackplot!, fr export powerspectrumplot, powerspectrumplot!, welch_pgram, periodogram, hanning, hamming export detect_spikes, mean_firing_rate, firing_rate export voltage_timeseries, meanfield_timeseries, state_timeseries, get_neurons, get_exci_neurons, get_inh_neurons, get_neuron_color +export AdjacencyMatrix, Connector, connection_rule, connection_equation end diff --git a/src/Neurographs.jl b/src/Neurographs.jl index 807354a1..e6b7b47e 100644 --- a/src/Neurographs.jl +++ b/src/Neurographs.jl @@ -52,30 +52,38 @@ get_dynamics_bloxs(blox::CompositeBlox) = get_parts(blox) flatten_graph(g::MetaDiGraph) = mapreduce(get_dynamics_bloxs, vcat, get_bloxs(g)) -function connector_from_graph(g::MetaDiGraph) - bloxs = get_bloxs(g) - link = BloxConnector(bloxs) +function connectors_from_graph(g::MetaDiGraph) + conns = get_connector.(get_bloxs(g)) + for edge in edges(g) - for v in vertices(g) - b = get_prop(g, v, :blox) - for vn in inneighbors(g, v) - bn = get_prop(g, vn, :blox) - kwargs = props(g, vn, v) - link(bn, b; kwargs...) - end + blox_src = get_prop(g, edge.src, :blox) + blox_dest = get_prop(g, edge.dst, :blox) + + kwargs = props(g, edge) + push!(conns, Connector(blox_src, blox_dest; kwargs...)) end - return link + + filter!(conn -> !isempty(conn), conns) + + return conns +end + +function connector_from_graph(g::MetaDiGraph) + conns = connectors_from_graph(g) + + return isempty(conns) ? Connector(:none, :none) : reduce(merge!, conns) end # Helper function to get delays from a graph function graph_delays(g::MetaDiGraph) - bc = connector_from_graph(g) - return bc.delays + conn = connector_from_graph(g) + + return conn.delay end -generate_discrete_callbacks(blox, ::BloxConnector; t_block = missing) = [] +generate_discrete_callbacks(blox, ::Connector; t_block = missing) = [] -function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, bc::BloxConnector; t_block = missing) +function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, bc::Connector; t_block = missing) spike_affects = get_spike_affects(bc) name_blox = namespaced_nameof(blox) sys = get_namespaced_sys(blox) @@ -116,7 +124,7 @@ function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, b return cb end -function generate_discrete_callbacks(blox::HHNeuronExciBlox, ::BloxConnector; t_block = missing) +function generate_discrete_callbacks(blox::HHNeuronExciBlox, ::Connector; t_block = missing) if !ismissing(t_block) nn = get_namespaced_sys(blox) eq = nn.spikes_window ~ 0 @@ -128,7 +136,7 @@ function generate_discrete_callbacks(blox::HHNeuronExciBlox, ::BloxConnector; t_ end end -function generate_discrete_callbacks(bc::BloxConnector; t_block = missing) +function generate_discrete_callbacks(bc::Connector; t_block = missing) eqs_params = get_equations_with_parameter_lhs(bc) if !ismissing(t_block) && !isempty(eqs_params) @@ -139,7 +147,7 @@ function generate_discrete_callbacks(bc::BloxConnector; t_block = missing) end end -function generate_discrete_callbacks(g::MetaDiGraph, bc::BloxConnector; t_block = missing) +function generate_discrete_callbacks(g::MetaDiGraph, bc::Connector; t_block = missing) bloxs = flatten_graph(g) cbs = mapreduce(vcat, bloxs) do blox @@ -175,30 +183,34 @@ function system_from_graph(g::MetaDiGraph, p::Vector{Num}=Num[]; name=nothing, t isempty(p) || error(ArgumentError("The GraphDynamics.jl backend does yet support extra parameter lists. Got $p.")) GraphDynamicsInterop.graphsystem_from_graph(g; kwargs...) else - bc = connector_from_graph(g) if isnothing(name) throw(UndefKeywordError(:name)) end + + bc = connector_from_graph(g) + return system_from_graph(g, bc, p; name, t_block, simplify, kwargs...) end end -function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}=Num[]; - name, t_block=missing, simplify=true, simplify_kwargs...) - blox_syss = get_system(g) +function system_from_graph(g::MetaDiGraph, bc::Connector, p::Vector{Num}=Num[]; name=nothing, t_block=missing, simplify=true, graphdynamics=false, kwargs...) + bloxs = get_bloxs(g) + blox_syss = get_system.(bloxs) + + accumulate_equations!(bc, bloxs) + connection_eqs = get_equations_with_state_lhs(bc) discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block)) sys = compose(System(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs), blox_syss) if simplify - structural_simplify(sys; simplify_kwargs...) + structural_simplify(sys; kwargs...) else sys end end - function system_from_parts(parts::AbstractVector; name) return compose(System(Equation[], t; name), get_system.(parts)) end diff --git a/src/adjacency.jl b/src/adjacency.jl index 11bfc311..61815120 100644 --- a/src/adjacency.jl +++ b/src/adjacency.jl @@ -3,28 +3,33 @@ struct AdjacencyMatrix names::Vector{Symbol} end -function AdjacencyMatrix(name) - return AdjacencyMatrix(spzeros(1,1), [name]) +function AdjacencyMatrix(names::AbstractVector) + return AdjacencyMatrix(spzeros(1,1), names) end -function Base.merge(adj1::AdjacencyMatrix, adj2::AdjacencyMatrix) - return AdjacencyMatrix( - cat(adj1.matrix, adj2.matrix; dims=(1,2)), - vcat(adj1.names, adj2.names) - ) -end +function AdjacencyMatrix(C::Connector) + weights = C.weight + srcs = C.source + dests = C.destination + names = unique(vcat(srcs, dests)) + sort!(names) -get_adjacency(bc::BloxConnector) = bc.adjacency -get_adjacency(blox::CompositeBlox) = (get_adjacency ∘ get_connector)(blox) -get_adjacency(blox) = AdjacencyMatrix(namespaced_nameof(blox)) + ADJ = AdjacencyMatrix(spzeros(length(names), length(names)), names) + for i in eachindex(srcs) + add_adjacency_edge!(ADJ, srcs[i], dests[i], weights[i]) + end -function get_adjacency(g::MetaDiGraph) - bc = connector_from_graph(g) - return get_adjacency(bc) + return ADJ end -function get_adjacency(bc::BloxConnector, sys::AbstractODESystem, prob::ODEProblem) - A = get_adjacency(bc) +AdjacencyMatrix(blox::CompositeBlox) = AdjacencyMatrix(get_connector(blox)) + +AdjacencyMatrix(blox) = AdjacencyMatrix(namespaced_nameof(blox)) + +AdjacencyMatrix(g::MetaDiGraph) = AdjacencyMatrix(connector_from_graph(g)) + +function AdjacencyMatrix(bc::Connector, sys::AbstractODESystem, prob::ODEProblem) + A = AdjacencyMatrix(bc) names = A.names mat = A.matrix @@ -43,12 +48,29 @@ function get_adjacency(bc::BloxConnector, sys::AbstractODESystem, prob::ODEProbl return AdjacencyMatrix(S, names) end -function get_adjacency(agent::Agent) +function AdjacencyMatrix(agent::Agent) prob = agent.problem sys = get_system(agent) bc = get_connector(agent) - return get_adjacency(bc, sys, prob) + return AdjacencyMatrix(bc, sys, prob) +end + +function add_adjacency_edge!(ADJ::AdjacencyMatrix, name_src, name_dest, weight) + src_idx = findfirst(x -> isequal(name_src, x), ADJ.names) + dest_idx = findfirst(x -> isequal(name_dest, x), ADJ.names) + + weight_def = ModelingToolkit.getdefault(weight) + weight_value = substitute(weight_def, map(x -> x => ModelingToolkit.getdefault(x), Symbolics.get_variables(weight_def))) + + ADJ.matrix[src_idx, dest_idx] = weight_value +end + +function Base.merge(adj1::AdjacencyMatrix, adj2::AdjacencyMatrix) + return AdjacencyMatrix( + cat(adj1.matrix, adj2.matrix; dims=(1,2)), + vcat(adj1.names, adj2.names) + ) end function adjmatrixfromdigraph(g::MetaDiGraph) diff --git a/src/blox/DBS_Model_Blox_Adam_Brown.jl b/src/blox/DBS_Model_Blox_Adam_Brown.jl index 78772bab..e481e46d 100644 --- a/src/blox/DBS_Model_Blox_Adam_Brown.jl +++ b/src/blox/DBS_Model_Blox_Adam_Brown.jl @@ -105,9 +105,9 @@ struct Striatum_MSN_Adam <: CompositeBlox end end parts = n_inh - + bc = connector_from_graph(g) - + sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) m = if isnothing(namespace) @@ -184,9 +184,9 @@ struct Striatum_FSI_Adam <: CompositeBlox end parts = n_inh - + bc = connector_from_graph(g) - + sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) m = if isnothing(namespace) @@ -254,9 +254,9 @@ struct GPe_Adam <: CompositeBlox end end parts = n_inh - + bc = connector_from_graph(g) - + sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) m = if isnothing(namespace) @@ -323,9 +323,9 @@ struct STN_Adam <: CompositeBlox end end parts = n_exci - - bc = connector_from_graph(g) + bc = connector_from_graph(g) + sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) m = if isnothing(namespace) @@ -334,6 +334,7 @@ struct STN_Adam <: CompositeBlox sys_namespace = System(Equation[], t; name=namespaced_name(namespace, name)) [s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")] end + new(namespace, parts, sys, bc, m, connection_matrix) end diff --git a/src/blox/blox_utilities.jl b/src/blox/blox_utilities.jl index 58d7e520..74e0e2de 100644 --- a/src/blox/blox_utilities.jl +++ b/src/blox/blox_utilities.jl @@ -64,6 +64,7 @@ function get_neurons(vn::AbstractVector{<:AbstractBlox}) end get_parts(blox::CompositeBlox) = blox.parts +get_parts(blox::Union{AbstractBlox, ObserverBlox}) = blox get_components(blox::CompositeBlox) = mapreduce(x -> get_components(x), vcat, get_parts(blox)) get_components(blox::Vector{<:AbstractBlox}) = mapreduce(x -> get_components(x), vcat, blox) @@ -97,6 +98,7 @@ end get_namespaced_sys(sys::AbstractODESystem) = sys nameof(blox) = (nameof ∘ get_system)(blox) +nameof(blox::AbstractActionSelection) = blox.name namespaceof(blox) = blox.namespace @@ -104,7 +106,7 @@ namespaced_nameof(blox) = namespaced_name(inner_namespaceof(blox), nameof(blox)) """ Returns the complete namespace EXCLUDING the outermost (highest) level. - This is useful for manually preparing equations (e.g. connections, see BloxConnector), + This is useful for manually preparing equations (e.g. connections, see Connector), that will later be composed and will automatically get the outermost namespace. """ function inner_namespaceof(blox) @@ -119,7 +121,7 @@ end namespaced_name(parent_name, name) = Symbol(parent_name, :₊, name) namespaced_name(::Nothing, name) = Symbol(name) -function find_eq(eqs::AbstractVector{<:Equation}, lhs) +function find_eq(eqs::Union{AbstractVector{<:Equation}, Equation}, lhs) findfirst(eqs) do eq lhs_vars = get_variables(eq.lhs) length(lhs_vars) == 1 && isequal(only(lhs_vars), lhs) @@ -136,7 +138,7 @@ end the higher-level namespaces will be added to them. If blox isa AbstractComponent, it is assumed that it contains a `connector` field, - which holds a `BloxConnector` object with all relevant connections + which holds a `Connector` object with all relevant connections from lower levels and this level. """ function get_input_equations(blox::Union{AbstractBlox, ObserverBlox}) @@ -162,28 +164,28 @@ function get_input_equations(blox::Union{AbstractBlox, ObserverBlox}) end get_connector(blox::Union{CompositeBlox, Agent}) = blox.connector +get_connector(blox) = Connector(namespaced_nameof(blox), namespaced_nameof(blox)) -get_input_equations(bc::BloxConnector) = bc.eqs -get_input_equations(blox::Union{CompositeBlox, AbstractComponent}) = (get_input_equations ∘ get_connector)(blox) +get_input_equations(bc::Connector) = bc.equation get_input_equations(blox) = [] -get_weight_parameters(bc::BloxConnector) = bc.weights +get_weight_parameters(bc::Connector) = bc.weights get_weight_parameters(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_parameters ∘ get_connector)(blox) get_weight_parameters(blox) = Num[] -get_delay_parameters(bc::BloxConnector) = bc.delays +get_delay_parameters(bc::Connector) = bc.delays get_delay_parameters(blox::Union{CompositeBlox, AbstractComponent}) = (get_delay_parameters ∘ get_connector)(blox) get_delay_parameters(blox) = Num[] -get_discrete_callbacks(bc::BloxConnector) = bc.discrete_callbacks +get_discrete_callbacks(bc::Connector) = bc.discrete_callbacks get_discrete_callbacks(blox::Union{CompositeBlox, AbstractComponent}) = (get_discrete_callbacks ∘ get_connector)(blox) get_discrete_callbacks(blox) = [] -get_spike_affects(bc::BloxConnector) = bc.spike_affects +get_spike_affects(bc::Connector) = bc.spike_affects get_spike_affects(blox::Union{CompositeBlox, AbstractComponent}) = (get_spike_affects ∘ get_connector)(blox) get_spike_affects(blox) = Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}}() -get_weight_learning_rules(bc::BloxConnector) = bc.learning_rules +get_weight_learning_rules(bc::Connector) = bc.learning_rules get_weight_learning_rules(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_learning_rules ∘ get_connector)(blox) get_weight_learning_rules(blox) = Dict{Num, AbstractLearningRule}() @@ -253,6 +255,14 @@ function get_connection_matrix(kwargs, name_out, name_in, N_out, N_in) connection_matrix end +function get_learning_rule(kwargs, name_src, name_dest) + if haskey(kwargs, :learning_rule) + return deepcopy(kwargs[:learning_rule]) + else + return NoLearningRule() + end +end + function get_weights(agent::Agent, blox_out, blox_in) ps = parameters(agent.odesystem) pv = agent.problem.p @@ -622,3 +632,16 @@ function get_sampling_info(sol::SciMLBase.AbstractSolution; sampling_rate=nothin return nothing, 1000 / sampling_rate end end + +function narrowtype_union(d::Dict) + types = unique(typeof.(values(d))) + U = Union{types...} + + return U +end + +function narrowtype(d::Dict) + U = narrowtype_union(d) + + return Dict{Num, U}(d) +end diff --git a/src/blox/canonicalmicrocircuit.jl b/src/blox/canonicalmicrocircuit.jl index dd72d05e..16cf305e 100644 --- a/src/blox/canonicalmicrocircuit.jl +++ b/src/blox/canonicalmicrocircuit.jl @@ -27,6 +27,7 @@ mutable struct CanonicalMicroCircuitBlox <: CompositeBlox parts odesystem connector + function CanonicalMicroCircuitBlox(;name, namespace=nothing, τ_ss=0.002, τ_sp=0.002, τ_ii=0.016, τ_dp=0.028, r_ss=2.0/3.0, r_sp=2.0/3.0, r_ii=2.0/3.0, r_dp=2.0/3.0) @named ss = JansenRitSPM12(;namespace=namespaced_name(namespace, name), τ=τ_ss, r=r_ss) # spiny stellate @named sp = JansenRitSPM12(;namespace=namespaced_name(namespace, name), τ=τ_sp, r=r_sp) # superficial pyramidal @@ -47,14 +48,9 @@ mutable struct CanonicalMicroCircuitBlox <: CompositeBlox add_edge!(g, ii => dp; :weight => -400.0) add_edge!(g, dp => dp; :weight => -200.0) - # Construct a BloxConnector object from the graph - # containing all connection equations from lower levels and this level. bc = connector_from_graph(g) # If a namespace is not provided, assume that this is the highest level # and construct the ODEsystem from the graph. - # If there is a higher namespace, construct only a subsystem containing the parts of this level - # and propagate the BloxConnector object `bc` to the higher level - # to potentially add more terms to the same connections. sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(sblox_parts; name) new(namespace, sblox_parts, sys, bc) diff --git a/src/blox/connections.jl b/src/blox/connections.jl index a3c7ad12..b47b1fca 100644 --- a/src/blox/connections.jl +++ b/src/blox/connections.jl @@ -1,65 +1,107 @@ -mutable struct BloxConnector - eqs::Vector{Equation} - weights::Vector{Num} - delays::Vector{Num} +struct Connector + source::Vector{Symbol} + destination::Vector{Symbol} + equation::Vector{Equation} + weight::Vector{Num} + delay::Vector{Num} discrete_callbacks spike_affects::Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}} - learning_rules - adjacency - - BloxConnector() = new(Equation[], Num[], Num[], Pair{Any, Vector{Equation}}[], Dict{Symbol, Vector{Num}}(), Dict{Num, AbstractLearningRule}()) - - function BloxConnector(bloxs) - eqs = mapreduce(get_input_equations, vcat, bloxs) - weights = mapreduce(get_weight_parameters, vcat, bloxs) - delays = mapreduce(get_delay_parameters, vcat, bloxs) - discrete_callbacks = mapreduce(get_discrete_callbacks, vcat, bloxs) - # spike_affects holds a Dictionary that maps - # the name of a source Blox to a Tuple of (states, parameters) of a destination Blox. - # The states are affected by a discrete callback of the source Blox - # and the parameters determine the amount of this affect like `states .+= parameters`. - # Typically this is used when a source Blox spikes, so its Voltage state crosses a threshold, - # and this spike affects synaptic parameters of every destination Blox that it connects to. - spike_affects = mapreduce(get_spike_affects, merge, bloxs) - learning_rules = mapreduce(get_weight_learning_rules, merge, bloxs) - adjacency = mapreduce(get_adjacency, merge, bloxs) - - new(eqs, weights, delays, discrete_callbacks, spike_affects, learning_rules, adjacency) - end + learning_rule::Dict{Num, AbstractLearningRule} end -function accumulate_equation!(bc::BloxConnector, eq) - lhs = eq.lhs - idx = find_eq(bc.eqs, lhs) - bc.eqs[idx] = bc.eqs[idx].lhs ~ bc.eqs[idx].rhs + eq.rhs +function Connector( + src::Symbol, + dest::Symbol; + equation=Equation[], + weight=Num[], + delay=Num[], + discrete_callbacks=[], + spike_affects=Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}}(), + learning_rule=Dict{Num, AbstractLearningRule}() + ) + + # Check if all weigths have NoLearningRule and if so don't keep them in the final Dict. + U = narrowtype_union(learning_rule) + learning_rule = U <: NoLearningRule ? Dict{Num, NoLearningRule}() : learning_rule + + Connector( + [src], + [dest], + to_vector(equation), + to_vector(weight), + to_vector(delay), + to_vector(discrete_callbacks), + spike_affects, + learning_rule + ) end -function accumulate_spike_affects!(bc::BloxConnector, name_blox_src, states_affect, params_affect) - if haskey(bc.spike_affects, name_blox_src) - spike_affects = bc.spike_affects[name_blox_src] - append!(spike_affects[1], states_affect) - append!(spike_affects[2], params_affect) - else - bc.spike_affects[name_blox_src] = (states_affect, params_affect) +function Base.isempty(conn::Connector) + return isempty(conn.equation) && isempty(conn.weight) && isempty(conn.delay) && isempty(conn.discrete_callbacks) && isempty(conn.spike_affects) && isempty(conn.learning_rule) +end + +connection_rule(blox_src, blox_dest; kwargs...) = Connector(blox_src, blox_dest; kwargs...) + +connection_equation(blox_src, blox_dest; kwargs...) = get_single_element(Connector(blox_src, blox_dest; kwargs...).equation) + +get_single_element(v::Union{AbstractVector, Dict}) = length(v) == 1 ? only(v) : v + +Base.show(io::IO, c::Connector) = print(io, "$(c.source) connects to $(c.destination) with ", c.equation) + +function Base.show(io::IO, ::MIME"text/plain", c::Connector) + + lines = ["Connection from $(get_single_element(c.source)) to $(get_single_element(c.destination))"] + + !isempty(c.equation) && push!(lines, "Equation : $(get_single_element(c.equation))") + !isempty(c.weight) && push!(lines, "Weight : $(get_single_element(c.weight))") + !isempty(c.delay) && push!(lines, "Delay : $(get_single_element(c.delay))") + !isempty(c.discrete_callbacks) && push!(lines, "Preset time events : $(get_single_element(c.discrete_callbacks))") + + if !isempty(c.spike_affects) + push!(lines, "$(get_single_element(c.source)) spikes affect :") + for (k, v) in c.spike_affects + var, val = get_single_element.(v) + push!(lines, "\t $(var) += $(val)") + end end + + !isempty(c.learning_rule) && push!(lines, "Plasticity rule : $(get_single_element(c.learning_rule))") + + print(io, join(lines, " \n ")) end -function add_adjacency_edge!(bc::BloxConnector, blox_src, blox_dest, weight) +function accumulate_equations!(C::Connector, bloxs) + init_eqs = mapreduce(get_input_equations, vcat, bloxs) + accumulate_equations!(C.equation, init_eqs) - n_src = namespaced_nameof(blox_src) - n_dest = namespaced_nameof(blox_dest) + return C +end - adj = get_adjacency(bc) - src_idx = findfirst(x -> isequal(n_src, x), adj.names) - dest_idx = findfirst(x -> isequal(n_dest, x), adj.names) +function accumulate_equations!(eqs1::Vector{<:Equation}, eqs2::Vector{<:Equation}) + for eq in eqs2 + lhs = eq.lhs + idx = find_eq(eqs1, lhs) + + if isnothing(idx) + push!(eqs1, eq) + else + eqs1[idx] = eqs1[idx].lhs ~ eqs1[idx].rhs + eq.rhs + end + end - weight_value = substitute(weight, map(x -> x => ModelingToolkit.getdefault(x), Symbolics.get_variables(weight))) - adj.matrix[src_idx, dest_idx] = weight_value + return eqs1 end -get_equations_with_parameter_lhs(bc) = filter(eq -> isparameter(eq.lhs), bc.eqs) +function tuple_append!(t1::Tuple, t2::Tuple) + append!(first(t1), first(t2)) + append!(last(t1), last(t2)) + + return t1 +end -get_equations_with_state_lhs(bc) = filter(eq -> !isparameter(eq.lhs), bc.eqs) +get_equations_with_parameter_lhs(bc) = filter(eq -> isparameter(eq.lhs), bc.equation) + +get_equations_with_state_lhs(bc) = filter(eq -> !isparameter(eq.lhs), bc.equation) function generate_weight_param(blox_out, blox_in; kwargs...) name_out = namespaced_nameof(blox_out) @@ -91,29 +133,62 @@ function generate_gap_weight_param(blox_out, blox_in; kwargs...) return gw end -function hypergeometric_connections!(bc, neurons_out, neurons_in, name_out, name_in; kwargs...) +""" + Helper to merge delay and weight into a single vector +""" +function params(bc::Connector) + weight = [] + for w in bc.weight + append!(weight, Symbolics.get_variables(w)) + end + if isempty(weight) + return vcat(weight, bc.delay) + else + return vcat(reduce(vcat, weight), bc.delay) + end +end + +function Base.merge!(c1::Connector, c2::Connector) + append!(c1.source, c2.source) + append!(c1.destination, c2.destination) + accumulate_equations!(c1.equation, c2.equation) + append!(c1.weight, c2.weight) + append!(c1.delay, c2.delay) + append!(c1.discrete_callbacks, c2.discrete_callbacks) + mergewith!(tuple_append!, c1.spike_affects, c2.spike_affects) + merge!(c1.learning_rule, c2.learning_rule) + + return c1 +end + +Base.merge(c1::Connector, c2::Connector) = Base.merge!(deepcopy(c1), c2) + +function hypergeometric_connections(neurons_src, neurons_dest, name_out, name_in; kwargs...) density = get_density(kwargs, name_out, name_in) - N_connects = density * length(neurons_in) * length(neurons_out) - out_degree = Int(ceil(N_connects / length(neurons_out))) - in_degree = Int(ceil(N_connects / length(neurons_in))) + N_connects = density * length(neurons_dest) * length(neurons_src) + out_degree = Int(ceil(N_connects / length(neurons_src))) + in_degree = Int(ceil(N_connects / length(neurons_dest))) wt = get_weight(kwargs,name_out, name_in) - outgoing_connections = zeros(Int, length(neurons_out)) - for neuron_postsyn in neurons_in + C = Connector[] + outgoing_connections = zeros(Int, length(neurons_src)) + for neuron_postsyn in neurons_dest rem = findall(x -> x < out_degree, outgoing_connections) idx = sample(rem, min(in_degree, length(rem)); replace=false) if length(wt) == 1 - for neuron_presyn in neurons_out[idx] - bc(neuron_presyn, neuron_postsyn; kwargs...) + for neuron_presyn in neurons_src[idx] + push!(C, Connector(neuron_presyn, neuron_postsyn; kwargs...)) end else for i in idx kwargs = (kwargs...,weight=wt[i]) - bc(neurons_out[i], neuron_postsyn; kwargs...) + push!(C, Connector(neurons_src[i], neuron_postsyn; kwargs...)) end end outgoing_connections[idx] .+= 1 end + + return reduce(merge!, C) end function indegree_constrained_connection_matrix(density, N_src, N_dst; kwargs...) @@ -129,603 +204,580 @@ function indegree_constrained_connection_matrix(density, N_src, N_dst; kwargs... conn_mat end -function indegree_constrained_connections!(bc, neurons_src, neurons_dst, name_src, name_dst; kwargs...) +function indegree_constrained_connections(neurons_src, neurons_dst, name_src, name_dst; kwargs...) N_src = length(neurons_src) N_dst = length(neurons_dst) conn_mat = get(kwargs, :connection_matrix) do density = get_density(kwargs, name_src, name_dst) indegree_constrained_connection_matrix(density, N_src, N_dst; kwargs...) end + + C = Connector[] for j ∈ 1:N_dst for i ∈ 1:N_src if conn_mat[i, j] - bc(neurons_src[i], neurons_dst[j]; kwargs...) + push!(C, Connector(neurons_src[i], neurons_dst[j]; kwargs...)) end end end -end -""" - Helper to merge delays and weights into a single vector -""" -function params(bc::BloxConnector) - weights = [] - for w in bc.weights - append!(weights, Symbolics.get_variables(w)) - end - if isempty(weights) - return vcat(weights, bc.delays) - else - return vcat(reduce(vcat, weights), bc.delays) - end + return reduce(merge!, C) end -function (bc::BloxConnector)( - HH_out::Union{HHNeuronExciBlox, HHNeuronInhibBlox, HHNeuronInhib_MSN_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}, - HH_in::Union{HHNeuronExciBlox, HHNeuronInhibBlox, HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}; - kwargs... -) - sys_out = get_namespaced_sys(HH_out) - sys_in = get_namespaced_sys(HH_in) +function Connector(blox_src::AbstractBlox, blox_dest::AbstractBlox; kwargs...) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(HH_out, HH_in; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - if haskey(kwargs, :learning_rule) - lr = deepcopy(kwargs[:learning_rule]) - maybe_set_state_pre!(lr, sys_out.spikes_cumulative) - maybe_set_state_post!(lr,sys_in.spikes_cumulative) - bc.learning_rules[w] = lr - end + eq = sys_dest.jcn ~ w*sys_src.v - STA = get_sta(kwargs, nameof(HH_out), nameof(HH_in)) - eq = if STA - sys_in.I_syn ~ -w * sys_in.Gₛₜₚ * sys_out.G * (sys_in.V - sys_out.E_syn) - else - sys_in.I_syn ~ -w * sys_out.G * (sys_in.V - sys_out.E_syn) - end - - accumulate_equation!(bc, eq) - - add_adjacency_edge!(bc, HH_out, HH_in, get_weight(kwargs, namespaced_nameof(HH_out), namespaced_nameof(HH_in))) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end -function (bc::BloxConnector)( - HH_out::HHNeuronInhib_FSI_Adam_Blox, - HH_in::Union{HHNeuronExciBlox, HHNeuronInhibBlox, HHNeuronInhib_MSN_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}; +function Connector( + blox_src::Union{HHNeuronExciBlox, HHNeuronInhibBlox, HHNeuronInhib_MSN_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}, + blox_dest::Union{HHNeuronExciBlox, HHNeuronInhibBlox, HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}; kwargs... ) - sys_out = get_namespaced_sys(HH_out) - sys_in = get_namespaced_sys(HH_in) - - w = generate_weight_param(HH_out, HH_in; kwargs...) - push!(bc.weights, w) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - eq = sys_in.I_syn ~ -w * sys_out.G * (sys_in.V - sys_out.E_syn) - - accumulate_equation!(bc, eq) -end + w = generate_weight_param(blox_src, blox_dest; kwargs...) -function (bc::BloxConnector)( - HH_out::HHNeuronInhib_FSI_Adam_Blox, - HH_in::HHNeuronInhib_FSI_Adam_Blox; - kwargs... -) - sys_out = get_namespaced_sys(HH_out) - sys_in = get_namespaced_sys(HH_in) - - w = generate_weight_param(HH_out, HH_in; kwargs...) - push!(bc.weights, w) - - eq = sys_in.I_syn ~ -w * sys_out.Gₛ * (sys_in.V - sys_out.E_syn) - - accumulate_equation!(bc, eq) - - GAP = get_gap(kwargs, nameof(HH_out), nameof(HH_in)) - if GAP - w_gap = generate_gap_weight_param(HH_out, HH_in; kwargs...) - push!(bc.weights, w_gap) - eq2 = sys_in.I_gap ~ -w_gap * (sys_in.V - sys_out.V) - accumulate_equation!(bc, eq2) - eq3 = sys_out.I_gap ~ -w_gap * (sys_out.V - sys_in.V) - accumulate_equation!(bc, eq3) + lr = get_learning_rule(kwargs, nameof(sys_src), nameof(sys_dest)) + maybe_set_state_pre!(lr, sys_src.spikes_cumulative) + maybe_set_state_post!(lr, sys_dest.spikes_cumulative) + + STA = get_sta(kwargs, nameof(blox_src), nameof(blox_dest)) + eq = if STA + sys_dest.I_syn ~ -w * sys_dest.Gₛₜₚ * sys_src.G * (sys_dest.V - sys_src.E_syn) + else + sys_dest.I_syn ~ -w * sys_src.G * (sys_dest.V - sys_src.E_syn) end -end - -function (bc::BloxConnector)( - cb_out::Union{Striatum_MSN_Adam,Striatum_FSI_Adam,GPe_Adam}, - cb_in::Union{Striatum_MSN_Adam,Striatum_FSI_Adam,GPe_Adam}; - kwargs... -) - neurons_in = get_inh_neurons(cb_in) - neurons_out = get_inh_neurons(cb_out) - indegree_constrained_connections!(bc, neurons_out, neurons_in, nameof(cb_out), nameof(cb_in); kwargs...) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w, learning_rule=Dict(w => lr)) end -function (bc::BloxConnector)( - cb_out::STN_Adam, - cb_in::Union{Striatum_MSN_Adam,Striatum_FSI_Adam,GPe_Adam}; +function Connector( + blox_src::Union{HHNeuronInhib_MSN_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}, + blox_dest::Union{HHNeuronInhib_MSN_Adam_Blox, HHNeuronInhib_FSI_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}; kwargs... ) - neurons_in = get_inh_neurons(cb_in) - neurons_out = get_exci_neurons(cb_out) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - indegree_constrained_connections!(bc, neurons_out, neurons_in, nameof(cb_out), nameof(cb_in); kwargs...) -end - -function (bc::BloxConnector)( - cb_out::Union{Striatum_MSN_Adam,Striatum_FSI_Adam,GPe_Adam}, - cb_in::STN_Adam; - kwargs... -) - neurons_in = get_exci_neurons(cb_in) - neurons_out = get_inh_neurons(cb_out) + w = generate_weight_param(blox_src, blox_dest; kwargs...) + + STA = get_sta(kwargs, nameof(blox_src), nameof(blox_dest)) + eq = if STA + sys_dest.I_syn ~ -w * sys_dest.Gₛₜₚ * sys_src.G * (sys_dest.V - sys_src.E_syn) + else + sys_dest.I_syn ~ -w * sys_src.G * (sys_dest.V - sys_src.E_syn) + end - indegree_constrained_connections!(bc, neurons_out, neurons_in, nameof(cb_out), nameof(cb_in); kwargs...) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end -function (bc::BloxConnector)( - asc_out::NextGenerationEIBlox, - HH_in::Union{HHNeuronExciBlox, HHNeuronInhibBlox}; +function Connector( + blox_src::HHNeuronInhib_FSI_Adam_Blox, + blox_dest::Union{HHNeuronExciBlox, HHNeuronInhibBlox, HHNeuronInhib_MSN_Adam_Blox, HHNeuronExci_STN_Adam_Blox, HHNeuronInhib_GPe_Adam_Blox}; kwargs... ) - sys_out = get_namespaced_sys(asc_out) - sys_in = get_namespaced_sys(HH_in) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(asc_out, HH_in; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - #Z = sys_out.Z - a = sys_out.aₑ - b = sys_out.bₑ - f = (1/(sys_out.Cₑ*π))*(1-a^2-b^2)/(1+2*a+a^2+b^2) - eq = sys_in.I_asc ~ w*f - - accumulate_equation!(bc, eq) + eq = sys_dest.I_syn ~ -w * sys_src.G * (sys_dest.V - sys_src.E_syn) + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end -function (bc::BloxConnector)( - asc_out::NextGenerationEIBlox, - cb_in::CorticalBlox; +function Connector( + blox_src::HHNeuronInhib_FSI_Adam_Blox, + blox_dest::HHNeuronInhib_FSI_Adam_Blox; kwargs... ) - neurons_in = get_inh_neurons(cb_in) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - bc(asc_out, neurons_in[end]; kwargs...) -end + w = generate_weight_param(blox_src, blox_dest; kwargs...) -function (bc::BloxConnector)( - bloxout::CanonicalMicroCircuitBlox, - bloxin::CanonicalMicroCircuitBlox; - kwargs... -) - sysparts_out = get_parts(bloxout) - sysparts_in = get_parts(bloxin) + eq = sys_dest.I_syn ~ -w * sys_src.Gₛ * (sys_dest.V - sys_src.E_syn) - wm = get_weightmatrix(kwargs, namespaced_nameof(bloxin), namespaced_nameof(bloxout)) + GAP = get_gap(kwargs, nameof(blox_src), nameof(blox_dest)) + if GAP + w_gap = generate_gap_weight_param(blox_src, blox_dest; kwargs...) + eq2 = sys_dest.I_gap ~ -w_gap * (sys_dest.V - sys_src.V) + eq3 = sys_src.I_gap ~ -w_gap * (sys_src.V - sys_dest.V) - idxs = findall(!iszero, wm) - for idx in idxs - bc(sysparts_out[idx[2]], sysparts_in[idx[1]]; weight=wm[idx]) + return Connector(nameof(sys_src), nameof(sys_dest); equation=[eq, eq2, eq3], weight=[w, w_gap]) + else + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end end -function (bc::BloxConnector)( - bloxout::StimulusBlox, - bloxin::CanonicalMicroCircuitBlox; +function Connector( + blox_src::NextGenerationEIBlox, + blox_dest::Union{HHNeuronExciBlox, HHNeuronInhibBlox}; kwargs... ) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - sysparts_in = get_parts(bloxin) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - bc(bloxout, sysparts_in[1]; kwargs...) -end - -function (bc::BloxConnector)( - bloxout::CanonicalMicroCircuitBlox, - bloxin::ObserverBlox; - kwargs... -) - sysparts_out = get_parts(bloxout) - - bc(sysparts_out[2], bloxin; kwargs...) + a = sys_src.aₑ + b = sys_src.bₑ + f = (1/(sys_src.Cₑ*π))*(1-a^2-b^2)/(1+2*a+a^2+b^2) + eq = sys_dest.I_asc ~ w*f + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end -# define a sigmoid function sigmoid(x, r) = one(x) / (one(x) + exp(-r*x)) -function (bc::BloxConnector)( - bloxout::JansenRitSPM12, - bloxin::JansenRitSPM12; +function Connector( + blox_src::JansenRitSPM12, + blox_dest::JansenRitSPM12; kwargs... ) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - x = namespace_expr(bloxout.output, sys_out) - r = namespace_expr(bloxout.params[2], sys_out) - push!(bc.weights, r) + x = namespace_expr(blox_src.output, sys_src) + r = namespace_expr(blox_src.params[2], sys_src) - eq = sys_in.jcn ~ sigmoid(x, r)*w + eq = sys_dest.jcn ~ sigmoid(x, r)*w - accumulate_equation!(bc, eq) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=[w, r]) end - -function (bc::BloxConnector)( - bloxout::NeuralMassBlox, - bloxin::NeuralMassBlox; +function Connector( + blox_src::NeuralMassBlox, + blox_dest::NeuralMassBlox; kwargs... ) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - if haskey(kwargs, :learning_rule) - lr = deepcopy(kwargs[:learning_rule]) - bc.learning_rules[w] = lr - end + lr = get_learning_rule(kwargs, nameof(sys_src), nameof(sys_dest)) + + if typeof(blox_src.output) == Num + x = namespace_expr(blox_src.output, sys_src) + eq = sys_dest.jcn ~ x*w - if typeof(bloxout.output) == Num - x = namespace_expr(bloxout.output, sys_out) - eq = sys_in.jcn ~ x*w + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w, learning_rule=Dict(w => lr)) else @variables t - delay = get_delay(kwargs, nameof(bloxout), nameof(bloxin)) - τ_name = Symbol("τ_$(nameof(sys_out))_$(nameof(sys_in))") + delay = get_delay(kwargs, nameof(sys_src), nameof(sys_dest)) + τ_name = Symbol("τ_$(nameof(sys_src))_$(nameof(sys_dest))") τ = only(@parameters $(τ_name)=delay) - push!(bc.delays, τ) - x = namespace_expr(bloxout.output, sys_out) - eq = sys_in.jcn ~ x(t-τ)*w - end - - accumulate_equation!(bc, eq) + x = namespace_expr(blox_src.output, sys_src) + eq = sys_dest.jcn ~ x(t-τ)*w + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w, delay=τ, learning_rule=Dict(w => lr)) + end end -function (bc::BloxConnector)( - bloxout::KuramotoOscillator, - bloxin::KuramotoOscillator; +function Connector( + blox_src::KuramotoOscillator, + blox_dest::KuramotoOscillator; kwargs... ) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - xₒ = namespace_expr(bloxout.output, sys_out) - xᵢ = namespace_expr(bloxin.output, sys_in) #needed because this is also the θ term of the block receiving the connection + xₒ = namespace_expr(blox_src.output, sys_src) + xᵢ = namespace_expr(blox_dest.output, sys_dest) #needed because this is also the θ term of the block receiving the connection - eq = sys_in.jcn ~ w*sin(xₒ - xᵢ) - accumulate_equation!(bc, eq) + eq = sys_dest.jcn ~ w*sin(xₒ - xᵢ) + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end # additional dispatch to connect to hemodynamic observer blox -function (bc::BloxConnector)( - bloxout::NeuralMassBlox, - bloxin::ObserverBlox; +function Connector( + blox_src::NeuralMassBlox, + blox_dest::ObserverBlox; kwargs...) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) + + if typeof(blox_src.output) == Num + w = generate_weight_param(blox_src, blox_dest; kwargs...) + x = namespace_expr(blox_src.output, sys_src, nameof(sys_src)) + eq = sys_dest.jcn ~ x*w - if typeof(bloxout.output) == Num - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) - x = namespace_expr(bloxout.output, sys_out, nameof(sys_out)) - eq = sys_in.jcn ~ x*w + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) else # Need t for the delay term @variables t # Define & accumulate delay parameter # Don't accumulate if zero - τ_name = Symbol("τ_$(nameof(sys_out))_$(nameof(sys_in))") + τ_name = Symbol("τ_$(nameof(sys_src))_$(nameof(sys_dest))") τ = only(@parameters $(τ_name)=delay) - push!(bc.delays, τ) + push!(bc.delay, τ) - w_name = Symbol("w_$(nameof(sys_out))_$(nameof(sys_in))") + w_name = Symbol("w_$(nameof(sys_src))_$(nameof(sys_dest))") w = only(@parameters $(w_name)=weight) - push!(bc.weights, w) + push!(bc.weight, w) + + x = namespace_expr(blox_src.output, sys_src, nameof(sys_src)) + eq = sys_dest.jcn ~ x(t-τ)*w - x = namespace_expr(bloxout.output, sys_out, nameof(sys_out)) - eq = sys_in.jcn ~ x(t-τ)*w + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w, delay=τ) end - - accumulate_equation!(bc, eq) end # additional dispatch to connect to a stimulus blox, first crafted for ExternalInput -function (bc::BloxConnector)( - bloxout::StimulusBlox, - bloxin::NeuralMassBlox; +function Connector( + blox_src::StimulusBlox, + blox_dest::NeuralMassBlox; kwargs...) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) + + w = generate_weight_param(blox_src, blox_dest; kwargs...) + x = namespace_expr(blox_src.output, sys_src, nameof(sys_src)) + eq = sys_dest.jcn ~ x*w + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) +end - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) +function Connector( + blox_src::Union{Striatum_MSN_Adam,Striatum_FSI_Adam,GPe_Adam}, + blox_dest::Union{Striatum_MSN_Adam,Striatum_FSI_Adam,GPe_Adam}; + kwargs... +) + neurons_dest = get_inh_neurons(blox_dest) + neurons_src = get_inh_neurons(blox_src) - x = namespace_expr(bloxout.output, sys_out, nameof(sys_out)) - eq = sys_in.jcn ~ x*w + conn = indegree_constrained_connections(neurons_src, neurons_dest, nameof(blox_src), nameof(blox_dest); kwargs...) - accumulate_equation!(bc, eq) + return conn end -# # Ok yes this is a bad dispatch but the whole compound blocks implementation is hacky and needs fixing @@ -# # Opening an issue to loop back to this during clean up week -# function (bc::BloxConnector)( -# bloxout::CompoundNOBlox, -# bloxin::CompoundNOBlox; -# weight=1, -# delay=0, -# density=0.1 -# ) +function Connector( + blox_src::STN_Adam, + blox_dest::Union{Striatum_MSN_Adam,Striatum_FSI_Adam,GPe_Adam}; + kwargs... +) + neurons_src = get_exci_neurons(blox_src) + neurons_dest = get_inh_neurons(blox_dest) -# sys_out = get_namespaced_sys(bloxout) -# sys_in = get_namespaced_sys(bloxin) + conn = indegree_constrained_connections(neurons_src, neurons_dest, nameof(blox_src), nameof(blox_dest); kwargs...) -# w_name = Symbol("w_$(nameof(sys_out))_$(nameof(sys_in))") -# if typeof(weight) == Num # Symbol -# w = weight -# else -# w = only(@parameters $(w_name)=weight) -# end -# push!(bc.weights, w) -# x = namespace_expr(bloxout.output, sys_out, nameof(sys_out)) -# eq = sys_in.nmm₊jcn ~ x*w - -# accumulate_equation!(bc, eq) -# end + return conn +end + +function Connector( + blox_src::Union{Striatum_MSN_Adam,Striatum_FSI_Adam,GPe_Adam}, + blox_dest::STN_Adam; + kwargs... +) + neurons_src = get_inh_neurons(blox_src) + neurons_dest = get_exci_neurons(blox_dest) + + conn = indegree_constrained_connections(neurons_src, neurons_dest, nameof(blox_src), nameof(blox_dest); kwargs...) + + return conn +end + +function Connector( + blox_src::NextGenerationEIBlox, + blox_dest::CorticalBlox; + kwargs... +) + neurons_dest = get_inh_neurons(blox_dest) + + conn = Connector(blox_src, neurons_dest[end]; kwargs...) + + return conn +end + +function Connector( + blox_src::CanonicalMicroCircuitBlox, + blox_dest::CanonicalMicroCircuitBlox; + kwargs... +) + sysparts_src = get_parts(blox_src) + sysparts_dest = get_parts(blox_dest) + + wm = get_weightmatrix(kwargs, namespaced_nameof(blox_src), namespaced_nameof(blox_dest)) + + idxs = findall(!iszero, wm) + + conn = mapreduce(merge!, idxs) do idx + Connector(sysparts_src[idx[2]], sysparts_dest[idx[1]]; weight=wm[idx]) + end + + return conn +end -function (bc::BloxConnector)( - wta_out::WinnerTakeAllBlox, - wta_in::WinnerTakeAllBlox; +function Connector( + blox_src::StimulusBlox, + blox_dest::CanonicalMicroCircuitBlox; + kwargs... +) + sysparts_dest = get_parts(blox_dest) + conn = Connector(blox_src, sysparts_dest[1]; kwargs...) + + return conn +end + +function Connector( + blox_src::CanonicalMicroCircuitBlox, + blox_dest::ObserverBlox; + kwargs... +) + sysparts_src = get_parts(blox_src) + conn = Connector(sysparts_src[2], blox_dest; kwargs...) + + return conn +end + +function Connector( + blox_src::WinnerTakeAllBlox, + blox_dest::WinnerTakeAllBlox; kwargs...) - neurons_out = get_exci_neurons(wta_out) - neurons_in = get_exci_neurons(wta_in) + neurons_src = get_exci_neurons(blox_src) + neurons_dest = get_exci_neurons(blox_dest) # users can supply a :connection_matrix to the graph edge, where - # connection_matrix[i, j] determines if neurons_out[i] is connected to neurons_out[j] + # connection_matrix[i, j] determines if neurons_src[i] is connected to neurons_src[j] connection_matrix = get_connection_matrix(kwargs, - namespaced_nameof(wta_out), namespaced_nameof(wta_in), - length(neurons_out), length(neurons_in)) - for (j, neuron_postsyn) in enumerate(neurons_in) + namespaced_nameof(blox_src), namespaced_nameof(blox_dest), + length(neurons_src), length(neurons_dest)) + + C = Connector[] + for (j, neuron_postsyn) in enumerate(neurons_dest) name_postsyn = namespaced_nameof(neuron_postsyn) - for (i, neuron_presyn) in enumerate(neurons_out) + for (i, neuron_presyn) in enumerate(neurons_src) name_presyn = namespaced_nameof(neuron_presyn) # Check names to avoid recurrent connections between the same neuron if (name_postsyn != name_presyn) && connection_matrix[i, j] - bc(neuron_presyn, neuron_postsyn; kwargs...) + push!(C, Connector(neuron_presyn, neuron_postsyn; kwargs...)) end end end + + # Check isempty(C) for the case of no connection being made. + # Connections between WTA neurons can be probabilistic so it's possible that none happen. + if isempty(C) + return Connector(namespaced_nameof(blox_src), namespaced_nameof(blox_dest)) + else + return reduce(merge!, C) + end end -function (bc::BloxConnector)( - neuron_out::HHNeuronInhibBlox, - wta_in::WinnerTakeAllBlox; +function Connector( + blox_src::HHNeuronInhibBlox, + blox_dest::WinnerTakeAllBlox; kwargs... ) - neurons_in = get_exci_neurons(wta_in) + neurons_dest = get_exci_neurons(blox_dest) - for neuron_postsyn in neurons_in - bc(neuron_out, neuron_postsyn; kwargs...) + conn = mapreduce(merge!, neurons_dest) do neuron_postsyn + Connector(blox_src, neuron_postsyn; kwargs...) end -end -function (bc::BloxConnector)( - cb_out::Union{CorticalBlox,STN,Thalamus}, - cb_in::Union{CorticalBlox,STN,Thalamus}; - kwargs... -) - neurons_in = get_exci_neurons(cb_in) - neurons_out = get_exci_neurons(cb_out) - - hypergeometric_connections!(bc, neurons_out, neurons_in, nameof(cb_out), nameof(cb_in); kwargs...) + return conn end -function (bc::BloxConnector)( - cb_out::Union{CorticalBlox,STN,Thalamus}, - cb_in::Union{GPi, GPe}; +function Connector( + blox_src::Union{CorticalBlox,STN,Thalamus}, + blox_dest::Union{CorticalBlox,STN,Thalamus}; kwargs... ) - neurons_in = get_inh_neurons(cb_in) - neurons_out = get_exci_neurons(cb_out) + neurons_dest = get_exci_neurons(blox_dest) + neurons_src = get_exci_neurons(blox_src) + + conn = hypergeometric_connections(neurons_src, neurons_dest, nameof(blox_src), nameof(blox_dest); kwargs...) - hypergeometric_connections!(bc, neurons_out, neurons_in, nameof(cb_out), nameof(cb_in); kwargs...) + return conn end -function (bc::BloxConnector)( - cb_out::Union{Striatum, GPi, GPe}, - cb_in::Union{CorticalBlox,STN,Thalamus}; +function Connector( + blox_src::Union{Striatum, GPi, GPe}, + blox_dest::Union{CorticalBlox,STN,Thalamus}; kwargs... ) - neurons_in = get_exci_neurons(cb_in) - neurons_out = get_inh_neurons(cb_out) + neurons_dest = get_exci_neurons(blox_dest) + neurons_src = get_inh_neurons(blox_src) + + conn = hypergeometric_connections(neurons_src, neurons_dest, nameof(blox_src), nameof(blox_dest); kwargs...) - hypergeometric_connections!(bc, neurons_out, neurons_in, nameof(cb_out), nameof(cb_in); kwargs...) + return conn end -function (bc::BloxConnector)( - cb_out::Union{Striatum, GPi, GPe}, - cb_in::Union{GPi, GPe}; +function Connector( + blox_src::Union{Striatum, GPi, GPe}, + blox_dest::Union{GPi, GPe}; kwargs... ) - neurons_in = get_inh_neurons(cb_in) - neurons_out = get_inh_neurons(cb_out) + neurons_dest = get_inh_neurons(blox_dest) + neurons_src = get_inh_neurons(blox_src) - hypergeometric_connections!(bc, neurons_out, neurons_in, nameof(cb_out), nameof(cb_in); kwargs...) + conn = hypergeometric_connections(neurons_src, neurons_dest, nameof(blox_src), nameof(blox_dest); kwargs...) + + return conn end -function (bc::BloxConnector)( +function Connector( cb::CorticalBlox, str::Striatum; kwargs... ) - neurons_in = get_inh_neurons(str) - neurons_out = get_exci_neurons(cb) + neurons_dest = get_inh_neurons(str) + neurons_src = get_exci_neurons(cb) w = get_weight(kwargs, namespaced_nameof(cb), namespaced_nameof(str)) dist = Uniform(0,1) - wt_ar = 2*w*rand(dist, length(neurons_out)) # generate a uniform distribution of weights with average value w + wt_ar = 2*w*rand(dist, length(neurons_src)) # generate a uniform distribution of weight with average value w kwargs = (kwargs..., weight=wt_ar) if haskey(kwargs, :learning_rule) - lr = deepcopy(kwargs[:learning_rule]) + lr = get_learning_rule(kwargs, namespaced_nameof(cb), namespaced_nameof(str)) sys_matr = get_namespaced_sys(get_matrisome(str)) maybe_set_state_post!(lr, sys_matr.H_learning) kwargs = (kwargs..., learning_rule=lr) end - hypergeometric_connections!(bc, neurons_out, neurons_in, nameof(cb), nameof(str); kwargs...) + conn = hypergeometric_connections(neurons_src, neurons_dest, nameof(cb), nameof(str); kwargs...) algebraic_parts = [get_matrisome(str), get_striosome(str)] - for (i,neuron_presyn) in enumerate(neurons_out) + for (i,neuron_presyn) in enumerate(neurons_src) kwargs = (kwargs...,weight=wt_ar[i]) for part in algebraic_parts - bc(neuron_presyn, part; kwargs...) + merge!(conn, Connector(neuron_presyn, part; kwargs...)) end end + + return conn end -function (bc::BloxConnector)( +function Connector( neuron::HHNeuronExciBlox, - str::Striatum; + str::Union{Striatum, GPi}; kwargs... ) - neurons_in = get_inh_neurons(str) - neuron_out = neuron + neurons_dest = get_inh_neurons(str) + neuron_src = neuron - for neuron_postsyn in neurons_in - bc(neuron_out, neuron_postsyn; kwargs...) + conn = mapreduce(merge!, neurons_dest) do neuron_dest + Connector(neuron_src, neuron_dest; kwargs...) end - + + return conn end -function (bc::BloxConnector)( - neuron::HHNeuronExciBlox, - gpi::GPi; +function Connector( + blox_src::HHNeuronExciBlox, + blox_dest::Union{Matrisome, Striosome}; kwargs... ) - neurons_in = get_inh_neurons(gpi) - neuron_out = neuron + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - for neuron_postsyn in neurons_in - bc(neuron_out, neuron_postsyn; kwargs...) - end - -end + w = generate_weight_param(blox_src, blox_dest; kwargs...) -function (bc::BloxConnector)( - neuron::HHNeuronExciBlox, - discr::Union{Matrisome, Striosome}; - kwargs... -) - sys_out = get_namespaced_sys(neuron) - sys_in = get_namespaced_sys(discr) + lr = get_learning_rule(kwargs, nameof(sys_src), nameof(sys_dest)) + maybe_set_state_pre!(lr, sys_src.spikes_cumulative) + maybe_set_state_post!(lr, sys_dest.H_learning) - w = generate_weight_param(neuron, discr; kwargs...) - push!(bc.weights, w) - if haskey(kwargs, :learning_rule) - lr = deepcopy(kwargs[:learning_rule]) - maybe_set_state_pre!(lr, sys_out.spikes_cumulative) - maybe_set_state_post!(lr, sys_in.H_learning) - bc.learning_rules[w] = lr - end + eq = sys_dest.jcn ~ w*sys_src.spikes_window - eq = sys_in.jcn ~ w*sys_out.spikes_window - accumulate_equation!(bc, eq) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w, learning_rule=Dict(w => lr)) end -function (bc::BloxConnector)( - str_out::Striatum, - str_in::Striatum; +function Connector( + blox_src::Striatum, + blox_dest::Striatum; kwargs... ) - sys_matr_out = get_namespaced_sys(get_matrisome(str_out)) - sys_matr_in = get_namespaced_sys(get_matrisome(str_in)) - sys_strios_in = get_namespaced_sys(get_striosome(str_in)) - neurons_in = get_inh_neurons(str_in) - - t_event = get_event_time(kwargs, nameof(str_out), nameof(str_in)) - cb_matr = [t_event] => [sys_matr_in.H ~ ifelse(sys_matr_out.H*sys_matr_out.jcn > sys_matr_in.H*sys_matr_in.jcn, 0, 1)] - cb_strios = [t_event] => [sys_strios_in.H ~ ifelse(sys_matr_out.H*sys_matr_out.jcn > sys_matr_in.H*sys_matr_in.jcn, 0, 1)] + sys_matr_src = get_namespaced_sys(get_matrisome(blox_src)) + sys_matr_dest = get_namespaced_sys(get_matrisome(blox_dest)) + sys_strios_dest = get_namespaced_sys(get_striosome(blox_dest)) + neurons_dest = get_inh_neurons(blox_dest) + + t_event = get_event_time(kwargs, nameof(blox_src), nameof(blox_dest)) + cb_matr = [t_event] => [sys_matr_dest.H ~ ifelse(sys_matr_src.H*sys_matr_src.jcn > sys_matr_src.H*sys_matr_src.jcn, 0, 1)] + cb_strios = [t_event] => [sys_strios_dest.H ~ ifelse(sys_matr_src.H*sys_matr_src.jcn > sys_matr_src.H*sys_matr_src.jcn, 0, 1)] # HACK: H should be reset to 1 at the beginning of each trial # Such callbacks should be moved to RL-specific functions like `run_experiment!` - cb_matr_init = [0.1] => [sys_matr_in.H ~ 1] - cb_strios_init = [0.1] => [sys_strios_in.H ~ 1] + cb_matr_init = [0.1] => [sys_matr_dest.H ~ 1] + cb_strios_init = [0.1] => [sys_strios_dest.H ~ 1] - push!(bc.discrete_callbacks, cb_matr) - push!(bc.discrete_callbacks, cb_strios) - push!(bc.discrete_callbacks, cb_matr_init) - push!(bc.discrete_callbacks, cb_strios_init) + dc = [cb_matr, cb_strios, cb_matr_init, cb_strios_init] - for neuron in neurons_in + for neuron in neurons_dest sys_neuron = get_namespaced_sys(neuron) # Large negative current added to shut down the Striatum spiking neurons. # Value is hardcoded for now, as it's more of a hack, not user option. - cb_neuron = [t_event] => [sys_neuron.I_bg ~ ifelse(sys_matr_out.H*sys_matr_out.jcn > sys_matr_in.H*sys_matr_in.jcn, -2, 0)] + cb_neuron = [t_event] => [sys_neuron.I_bg ~ ifelse(sys_matr_src.H*sys_matr_src.jcn > sys_matr_dest.H*sys_matr_dest.jcn, -2, 0)] # lateral inhibition current I_bg should be set to 0 at the beginning of each trial cb_neuron_init = [0.1] => [sys_neuron.I_bg ~ 0] - push!(bc.discrete_callbacks, cb_neuron) - push!(bc.discrete_callbacks, cb_neuron_init) + push!(dc, cb_neuron) + push!(dc, cb_neuron_init) end + + return Connector(namespaced_nameof(blox_src), namespaced_nameof(blox_dest); discrete_callbacks=dc) end -function (bc::BloxConnector)( - str::Striatum, - discr::Union{TAN, SNc}; +function Connector( + blox_src::Striatum, + blox_dest::Union{TAN, SNc}; kwargs... ) - striosome = get_striosome(str) - bc(striosome, discr; kwargs...) + striosome = get_striosome(blox_src) + + return Connector(striosome, blox_dest; kwargs...) end -function (bc::BloxConnector)( - discr_out::Striosome, - discr_in::Union{TAN, SNc}; +function Connector( + blox_src::Striosome, + blox_dest::Union{TAN, SNc}; kwargs... ) - sys_out = get_namespaced_sys(discr_out) - sys_in = get_namespaced_sys(discr_in) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(discr_out, discr_in; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - if haskey(kwargs, :learning_rule) - bc.learning_rules[w] = deepcopy(kwargs[:learning_rule]) - end + eq = sys_dest.jcn ~ w*sys_src.H*sys_src.jcn - eq = sys_in.jcn ~ w*sys_out.H*sys_out.jcn - + lr = get_learning_rule(kwargs, nameof(sys_src), nameof(sys_dest)) - accumulate_equation!(bc, eq) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w, learning_rule=Dict(w => lr)) end -function (bc::BloxConnector)( - tan::TAN, - str::Striatum; +function Connector( + blox_src::TAN, + blox_dest::Striatum; kwargs... ) - matrisome = get_matrisome(str) - bc(tan, matrisome; kwargs...) + matrisome = get_matrisome(blox_dest) + + return Connector(blox_src, matrisome; kwargs...) end sample_poisson(λ) = rand(Poisson(λ)) @register_symbolic sample_poisson(λ) - """ Non-symbolic, time-block-based way of `@register_symbolic sample_poisson(λ)`. """ @@ -735,308 +787,343 @@ function sample_affect!(integ, u, p, ctx) integ.p[p[3]] = v end -function (bc::BloxConnector)( - discr_out::TAN, - discr_in::Matrisome; +function Connector( + blox_src::TAN, + blox_dest::Matrisome; kwargs... ) - sys_out = get_namespaced_sys(discr_out) - sys_in = get_namespaced_sys(discr_in) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(discr_out, discr_in; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - if haskey(kwargs, :learning_rule) - bc.learning_rules[w] = deepcopy(kwargs[:learning_rule]) - end + t_event = get_event_time(kwargs, nameof(blox_src), nameof(blox_dest)) + cb = [t_event+sqrt(eps(t_event))] => (sample_affect!, [], [sys_src.κ, sys_src.jcn, sys_dest.TAN_spikes], []) - t_event = get_event_time(kwargs, nameof(discr_out), nameof(discr_in)) - cb = [t_event+sqrt(eps(t_event))] => (sample_affect!, [], [sys_out.κ, sys_out.jcn, sys_in.TAN_spikes], []) - push!(bc.discrete_callbacks, cb) + eq = sys_dest.jcn ~ w*sys_dest.TAN_spikes - eq = sys_in.jcn ~ w*sys_in.TAN_spikes + lr = get_learning_rule(kwargs, nameof(sys_src), nameof(sys_dest)) - accumulate_equation!(bc, eq) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w, discrete_callbacks=cb, learning_rule=Dict(w => lr)) end -function (bc::BloxConnector)( - discr_out::Matrisome, - discr_in::Matrisome; +function Connector( + blox_src::Matrisome, + blox_dest::Matrisome; kwargs... ) - sys_out = get_namespaced_sys(discr_out) - sys_in = get_namespaced_sys(discr_in) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - t_event = get_event_time(kwargs, nameof(discr_out), nameof(discr_in)) - cb = [t_event] => [sys_in.H ~ ifelse(sys_out.H*sys_out.jcn > sys_in.H*sys_in.jcn, 0, 1)] - push!(bc.discrete_callbacks, cb) + t_event = get_event_time(kwargs, nameof(blox_src), nameof(blox_dest)) + cb = [t_event] => [sys_dest.H ~ ifelse(sys_src.H*sys_src.jcn > sys_dest.H*sys_dest.jcn, 0, 1)] + + return Connector(nameof(sys_src), nameof(sys_dest); discrete_callbacks=cb) end -function (bc::BloxConnector)( +function Connector( stim::ImageStimulus, neuron::Union{HHNeuronExciBlox, HHNeuronInhibBlox}; kwargs... ) - sys_out = get_namespaced_sys(stim) - sys_in = get_namespaced_sys(neuron) + sys_src = get_namespaced_sys(stim) + sys_dest = get_namespaced_sys(neuron) - pixels = namespace_parameters(sys_out) + pixels = namespace_parameters(sys_src) w = generate_weight_param(stim, neuron; kwargs...) - push!(bc.weights, w) + # No check for kwargs[:learning_rule] here. # The connection from stimulus is conceptual, the weight can not be updated. - eq = sys_in.I_in ~ w * pixels[stim.current_pixel] + eq = sys_dest.I_in ~ w * pixels[stim.current_pixel] increment_pixel!(stim) - accumulate_equation!(bc, eq) + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end -function (bc::BloxConnector)( +function Connector( stim::ImageStimulus, cb::CorticalBlox; kwargs... ) neurons = get_exci_neurons(cb) - for neuron in neurons - bc(stim, neuron; kwargs...) + conn = mapreduce(merge!, neurons) do neuron + Connector(stim, neuron; kwargs...) end -end - -(bc::BloxConnector)(blox, as::AbstractActionSelection; kwargs...) = nothing -function connect_action_selection!(as::AbstractActionSelection, str1::Striatum, str2::Striatum) - connect_action_selection!(as, get_matrisome(str1), get_matrisome(str2)) + return conn end -function connect_action_selection!(as::AbstractActionSelection, matr1::Matrisome, matr2::Matrisome) - sys1 = get_namespaced_sys(matr1) - sys2 = get_namespaced_sys(matr2) - - as.competitor_states = [sys1.ρ_, sys2.ρ_] #HACK : accessing values of rho at a specific time after the simulation - #as.competitor_params = [sys1.H, sys2.H] -end +Connector(blox::AbstractBlox, as::AbstractActionSelection; kwargs...) = Connector(namespaced_nameof(blox), namespaced_nameof(as)) # Connects spiking neuron to another spiking neuron -# None of these neurons have delays yet -function (bc::BloxConnector)( - bloxout::AbstractNeuronBlox, - bloxin::AbstractNeuronBlox; +# None of these neurons have delay yet +function Connector( + blox_src::AbstractNeuronBlox, + blox_dest::AbstractNeuronBlox; kwargs... ) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) - - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - cr = get_connection_rule(kwargs, bloxout, bloxin, w) - eq = sys_in.jcn ~ cr + cr = get_connection_rule(kwargs, blox_src, blox_dest, w) + eq = sys_dest.jcn ~ cr - accumulate_equation!(bc, eq) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end # Connects a neural mass as a driving input to a spiking neuron # Should be used with care because units will be strange (NMM typically outputs voltage but neuron inputs are typically currents) -function (bc::BloxConnector)( - bloxout::AbstractNeuronBlox, - bloxin::NeuralMassBlox; +function Connector( + blox_src::AbstractNeuronBlox, + blox_dest::NeuralMassBlox; kwargs... ) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - if typeof(bloxout.output) == Num - x = namespace_expr(bloxout.output, sys_out) - eq = sys_in.jcn ~ x*w + if typeof(blox_src.output) == Num + x = namespace_expr(blox_src.output, sys_src) + eq = sys_dest.jcn ~ x*w + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) else @variables t - delay = get_delay(kwargs, nameof(bloxout), nameof(bloxin)) - τ_name = Symbol("τ_$(nameof(sys_out))_$(nameof(sys_in))") + delay = get_delay(kwargs, nameof(blox_src), nameof(blox_dest)) + τ_name = Symbol("τ_$(nameof(sys_src))_$(nameof(sys_dest))") τ = only(@parameters $(τ_name)=delay) - push!(bc.delays, τ) - x = namespace_expr(bloxout.output, sys_out) - eq = sys_in.jcn ~ x(t-τ)*w + x = namespace_expr(blox_src.output, sys_src) + eq = sys_dest.jcn ~ x(t-τ)*w + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w, delay=τ) end - - accumulate_equation!(bc, eq) end -function (bc::BloxConnector)( - bloxout::LIFExciNeuron, - bloxin::Union{LIFExciNeuron, LIFInhNeuron}; +function Connector( + blox_src::LIFExciNeuron, + blox_dest::Union{LIFExciNeuron, LIFInhNeuron}; kwargs... ) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - eq = sys_in.jcn ~ w * sys_out.S_NMDA * sys_in.g_NMDA * (sys_in.V - sys_in.V_E) / - (1 + sys_in.Mg * exp(-0.062 * sys_in.V) / 3.57) - accumulate_equation!(bc, eq) + eq = sys_dest.jcn ~ w * sys_src.S_NMDA * sys_dest.g_NMDA * (sys_dest.V - sys_dest.V_E) / + (1 + sys_dest.Mg * exp(-0.062 * sys_dest.V) / 3.57) # Compare the unique namespaced names of both systems - if nameof(sys_out) == nameof(sys_in) + spike_affects = if nameof(sys_src) == nameof(sys_dest) # x is the rise variable for NMDA synapses and it only applies to self-recurrent connections - accumulate_spike_affects!(bc, nameof(sys_out), [sys_in.S_AMPA, sys_in.x], [w, w]) + Dict(nameof(sys_src) => ([sys_dest.S_AMPA, sys_dest.x], [w, w])) else - accumulate_spike_affects!(bc, nameof(sys_out), [sys_in.S_AMPA], [w]) + Dict(nameof(sys_src) => ([sys_dest.S_AMPA], [w])) end + + return Connector(nameof(sys_src), nameof(sys_dest); equation = eq, weight = [w], spike_affects = spike_affects) end -function (bc::BloxConnector)( - bloxout::LIFInhNeuron, - bloxin::Union{LIFExciNeuron, LIFInhNeuron}; +function Connector( + blox_src::LIFInhNeuron, + blox_dest::Union{LIFExciNeuron, LIFInhNeuron}; kwargs... ) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) + + w = generate_weight_param(blox_src, blox_dest; kwargs...) - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + spike_affects = Dict(nameof(sys_src) => ([sys_dest.S_GABA], [w])) - accumulate_spike_affects!(bc, nameof(sys_out), [sys_in.S_GABA], [w]) + return Connector(nameof(sys_src), nameof(sys_dest); weight = w, spike_affects = spike_affects) end -function (bc::BloxConnector)( +function Connector( stim::PoissonSpikeTrain, neuron::Union{LIFExciNeuron, LIFInhNeuron}; kwargs... ) - sys_in = get_namespaced_sys(neuron) + sys_dest = get_namespaced_sys(neuron) t_spikes = generate_spike_times(stim) - cb = t_spikes => [sys_in.S_AMPA_ext ~ sys_in.S_AMPA_ext + 1] + cb = t_spikes => [sys_dest.S_AMPA_ext ~ sys_dest.S_AMPA_ext + 1] # TO DO : Consider generating spikes during simulation # to make PoissonSpikeTrain independent of `t_span` of the simulation. # something like : - # discrete_event = t > -Inf => (generate_spike, [sys_in.S_AMPA], [stim.relevant_params...], [], nothing) + # discrete_event = t > -Inf => (generate_spike, [sys_dest.S_AMPA], [stim.relevant_params...], [], nothing) # This way we need to resolve the case of multiple spikes potentially being generated within a single integrator step. - push!(bc.discrete_callbacks, cb) + return Connector(namespaced_nameof(stim), nameof(sys_dest); discrete_callbacks = cb) end -function (bc::BloxConnector)( - bloxout::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}, - bloxin::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}; +function Connector( + blox_src::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}, + blox_dest::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}; kwargs... ) - neurons_out = get_neurons(bloxout) - neurons_in = get_neurons(bloxin) - - for neuron_out in neurons_out - for neuron_in in neurons_in - bc(neuron_out, neuron_in; kwargs...) + neurons_src = get_neurons(blox_src) + neurons_dest = get_neurons(blox_dest) + + C = Vector{Connector}(undef, length(neurons_src)*length(neurons_dest)) + i = 1 + for neuron_out in neurons_src + for neuron_in in neurons_dest + C[i] = Connector(neuron_out, neuron_in; kwargs...) + i += 1 end end + + return reduce(merge!, C) end -function (bc::BloxConnector)( +function Connector( stim::PoissonSpikeTrain, - cb::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}; + blox_dest::Union{LIFExciCircuitBlox, LIFInhCircuitBlox}; kwargs... ) - neurons_in = get_neurons(cb) + neurons_dest = get_neurons(blox_dest) - for neuron in neurons_in - bc(stim, neuron; kwargs...) + conn = mapreduce(merge!, neurons_dest) do neuron + Connector(stim, neuron; kwargs...) end + + return conn end -function (bc::BloxConnector)( - bloxout::PYR_Izh, - bloxin::PYR_Izh; +function Connector( + blox_src::PYR_Izh, + blox_dest::PYR_Izh; kwargs... ) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - s_presyn = namespace_expr(bloxout.output, sys_out) - v_postsyn = namespace_expr(bloxin.voltage, sys_in) - eq = sys_in.jcn ~ w*(1-sys_in.κ)*sys_out.gₛ*s_presyn*(sys_in.eᵣ-v_postsyn) + s_presyn = namespace_expr(blox_src.output, sys_src) + v_postsyn = namespace_expr(blox_dest.voltage, sys_dest) + eq = sys_dest.jcn ~ w*(1-sys_dest.κ)*sys_src.gₛ*s_presyn*(sys_dest.eᵣ-v_postsyn) - accumulate_equation!(bc, eq) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end -function (bc::BloxConnector)( - bloxout::QIF_PING_NGNMM, - bloxin::QIF_PING_NGNMM; +function Connector( + blox_src::QIF_PING_NGNMM, + blox_dest::QIF_PING_NGNMM; kwargs... ) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - x = namespace_expr(bloxout.output, sys_out) - eq = sys_in.jcn ~ w*x + x = namespace_expr(blox_src.output, sys_src) + eq = sys_dest.jcn ~ w*x - accumulate_equation!(bc, eq) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end -function (bc::BloxConnector)( - bloxout::DBS, - bloxin::CompositeBlox; +function Connector( + blox_src::DBS, + blox_dest::CompositeBlox; kwargs... ) - components = get_components(bloxin) - for comp in components - bc(bloxout, comp; kwargs...) + components = get_components(blox_dest) + conn = mapreduce(merge!, components) do comp + Connector(blox_src, comp; kwargs...) end + + return conn +end + +function Connector( + blox_src::DBS, + blox_dest::AbstractNeuronBlox; + kwargs... +) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) + + w = generate_weight_param(blox_src, blox_dest; kwargs...) + + eq = sys_dest.I_in ~ w * sys_src.u + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end -function (bc::BloxConnector)( - bloxout::DBS, - bloxin::AbstractNeuronBlox; +function Connector( + blox_src::DBS, + blox_dest::NeuralMassBlox; kwargs... ) - sys_dbs = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + w = generate_weight_param(blox_src, blox_dest; kwargs...) - eq = sys_in.I_in ~ w * sys_dbs.u - accumulate_equation!(bc, eq) + eq = sys_dest.jcn ~ w * sys_src.u + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) end -function (bc::BloxConnector)( - bloxout::DBS, - bloxin::NeuralMassBlox; +function Connector( + blox_src::DBS, + blox_dest::HHNeuronExci_STN_Adam_Blox; kwargs... ) - sys_dbs = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) + eq = sys_dest.DBS_in ~ - sys_dest.V/sys_dest.b + sys_src.u - eq = sys_in.jcn ~ w * sys_dbs.u - accumulate_equation!(bc, eq) + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq) end -function (bc::BloxConnector)( - bloxout::DBS, - bloxin::HHNeuronExci_STN_Adam_Blox; +# Create excitatory -> inhibitory AMPA receptor conenction +function Connector( + blox_src::PINGNeuronExci, + blox_dest::PINGNeuronInhib; kwargs... ) - sys_dbs = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) + + w = generate_weight_param(blox_src, blox_dest; kwargs...) + + V_E = haskey(kwargs, :V_E) ? kwargs[:V_E] : 0.0 + + s = namespace_expr(blox_src.output, sys_src) + v_in = namespace_expr(blox_dest.voltage, sys_dest) + eq = sys_dest.jcn ~ w*s*(V_E-v_in) - eq = sys_in.DBS_in ~ - sys_in.V/sys_in.b + sys_dbs.u - accumulate_equation!(bc, eq) -end \ No newline at end of file + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) +end + +# Create inhibitory -> inhibitory/excitatory GABA_A receptor connection +function Connector( + blox_src::PINGNeuronInhib, + blox_dest::AbstractPINGNeuron; + kwargs... +) + sys_src = get_namespaced_sys(blox_src) + sys_dest = get_namespaced_sys(blox_dest) + + w = generate_weight_param(blox_src, blox_dest; kwargs...) + + V_I = haskey(kwargs, :V_I) ? kwargs[:V_I] : -80.0 + + s = namespace_expr(blox_src.output, sys_src) + v_in = namespace_expr(blox_dest.voltage, sys_dest) + eq = sys_dest.jcn ~ w*s*(V_I-v_in) + + return Connector(nameof(sys_src), nameof(sys_dest); equation=eq, weight=w) +end diff --git a/src/blox/cortical.jl b/src/blox/cortical.jl index c199a70f..a6a97f3d 100644 --- a/src/blox/cortical.jl +++ b/src/blox/cortical.jl @@ -70,13 +70,12 @@ struct CorticalBlox <: CompositeBlox end add_edge!(g, N_wta+1, i, Dict(:weight => 1)) end - # Construct a BloxConnector object from the graph - # containing all connection equations from lower levels and this level. + bc = connector_from_graph(g) # If a namespace is not provided, assume that this is the highest level # and construct the ODEsystem from the graph. # If there is a higher namespace, construct only a subsystem containing the parts of this level - # and propagate the BloxConnector object `bc` to the higher level + # and propagate the Connector object `bc` to the higher level # to potentially add more terms to the same connections. sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(vcat(wtas, n_ff_inh); name) @@ -155,11 +154,11 @@ struct LIFExciCircuitBlox <: CompositeBlox end end + bc = connector_from_graph(g) + if skip_system_creation - bc = nothing sys = nothing else - bc = connector_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(neurons; name) end new(namespace, neurons, sys, bc, kwargs) @@ -233,11 +232,11 @@ struct LIFInhCircuitBlox <: CompositeBlox end end + bc = connector_from_graph(g) + if skip_system_creation - bc = nothing sys = nothing else - bc = connector_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(neurons; name) end diff --git a/src/blox/ping_neuron_examples.jl b/src/blox/ping_neuron_examples.jl index 330fe296..d151223f 100644 --- a/src/blox/ping_neuron_examples.jl +++ b/src/blox/ping_neuron_examples.jl @@ -151,44 +151,3 @@ struct PINGNeuronInhib <: AbstractPINGNeuron end end -# Create excitatory -> inhibitory AMPA receptor conenction -function (bc::BloxConnector)( - bloxout::PINGNeuronExci, - bloxin::PINGNeuronInhib; - kwargs... -) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) - - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) - - V_E = haskey(kwargs, :V_E) ? kwargs[:V_E] : 0.0 - - s = namespace_expr(bloxout.output, sys_out) - v_in = namespace_expr(bloxin.voltage, sys_in) - eq = sys_in.jcn ~ w*s*(V_E-v_in) - - accumulate_equation!(bc, eq) -end - -# Create inhibitory -> inhibitory/excitatory GABA_A receptor connection -function (bc::BloxConnector)( - bloxout::PINGNeuronInhib, - bloxin::AbstractPINGNeuron; - kwargs... -) - sys_out = get_namespaced_sys(bloxout) - sys_in = get_namespaced_sys(bloxin) - - w = generate_weight_param(bloxout, bloxin; kwargs...) - push!(bc.weights, w) - - V_I = haskey(kwargs, :V_I) ? kwargs[:V_I] : -80.0 - - s = namespace_expr(bloxout.output, sys_out) - v_in = namespace_expr(bloxin.voltage, sys_in) - eq = sys_in.jcn ~ w*s*(V_I-v_in) - - accumulate_equation!(bc, eq) -end diff --git a/src/blox/reinforcement_learning.jl b/src/blox/reinforcement_learning.jl index 5f2061b0..ff51b108 100644 --- a/src/blox/reinforcement_learning.jl +++ b/src/blox/reinforcement_learning.jl @@ -1,6 +1,8 @@ abstract type AbstractEnvironment end abstract type AbstractLearningRule end +struct NoLearningRule <: AbstractLearningRule end + mutable struct HebbianPlasticity <:AbstractLearningRule const K::Float64 const W_lim::Float64 @@ -94,6 +96,9 @@ function maybe_set_state_post!(lr::AbstractLearningRule, state) end end +maybe_set_state_pre!(lr::NoLearningRule, state) = lr +maybe_set_state_post!(lr::NoLearningRule, state) = lr + mutable struct ClassificationEnvironment{S} <: AbstractEnvironment const name::Symbol const namespace::Symbol @@ -161,32 +166,21 @@ function (p::GreedyPolicy)(sol::SciMLBase.AbstractSciMLSolution) return argmax(comp_vals) end -get_eval_times(gp::GreedyPolicy) = [gp.t_decision] +function connect_action_selection!(as::AbstractActionSelection, str1::Striatum, str2::Striatum) + connect_action_selection!(as, get_matrisome(str1), get_matrisome(str2)) +end -get_eval_states(gp::GreedyPolicy) = gp.competitor_states +function connect_action_selection!(as::AbstractActionSelection, matr1::Matrisome, matr2::Matrisome) + sys1 = get_namespaced_sys(matr1) + sys2 = get_namespaced_sys(matr2) -""" -function (p::GreedyPolicy)(sys::ODESystem, prob::ODEProblem) - ps = parameters(sys) - params = prob.p - map_idxs = Int.(ModelingToolkit.varmap_to_vars([ps[i] => i for i in eachindex(ps)], ps)) - comp_params = p.competitor_params - idxs_cp = Int64[] - for i in eachindex(comp_params) - idxs = findall(x -> x==comp_params[i], ps) - push!(idxs_cp,idxs) - end - comp_vals = params[map_idxs[idxs_cp]] - @info comp_vals - return argmax(comp_vals) + as.competitor_states = [sys1.ρ_, sys2.ρ_] #HACK : accessing values of rho at a specific time after the simulation + #as.competitor_params = [sys1.H, sys2.H] end -""" -function narrowtype(d::Dict) - types = unique(typeof.(values(d))) - U = Union{types...} - Dict{Num, U}(d) -end +get_eval_times(gp::GreedyPolicy) = [gp.t_decision] + +get_eval_states(gp::GreedyPolicy) = gp.competitor_states mutable struct Agent{S,P,A,LR,C} odesystem::S @@ -197,7 +191,7 @@ mutable struct Agent{S,P,A,LR,C} function Agent(g::MetaDiGraph; name, kwargs...) bc = connector_from_graph(g) - + t_block = haskey(kwargs, :t_block) ? kwargs[:t_block] : missing # TODO: add another version that uses system_from_graph(g,bc,params;) sys = system_from_graph(g, bc; name, t_block, allow_parameter=false) @@ -208,7 +202,7 @@ mutable struct Agent{S,P,A,LR,C} prob = ODEProblem(sys, u0, (0.,1.), p) policy = action_selection_from_graph(g) - learning_rules = narrowtype(bc.learning_rules) + learning_rules = narrowtype(bc.learning_rule) new{typeof(sys), typeof(prob), typeof(policy), typeof(learning_rules), typeof(bc)}(sys, prob, policy, learning_rules, bc) end diff --git a/src/blox/subcortical_blox.jl b/src/blox/subcortical_blox.jl index 0e8f05ab..de105d08 100644 --- a/src/blox/subcortical_blox.jl +++ b/src/blox/subcortical_blox.jl @@ -21,53 +21,52 @@ struct Striatum <: CompositeBlox phase=zeros(N_inhib), τ_inhib=70 ) - n_inh = [ - HHNeuronInhibBlox( - name = Symbol("inh$i"), - namespace = namespaced_name(namespace, name), - E_syn = E_syn_inhib, - G_syn = G_syn_inhib, - τ = τ_inhib, - I_bg = I_bg[i], - freq = freq[i], - phase = phase[i] - ) - for i in Base.OneTo(N_inhib) - ] - - matrisome = Matrisome(; name=:matrisome, namespace=namespaced_name(namespace, name)) - striosome = Striosome(; name=:striosome, namespace=namespaced_name(namespace, name)) - - parts = vcat(n_inh, matrisome, striosome) - - g = MetaDiGraph() - add_blox!.(Ref(g), n_inh) - - # If this blox is simulated on its own, - # then only the parts with dynamics are included in the system. - # This is done to avoid messing with structural_simplify downstream. - # Also it makes sense, as the discrete parts rely exclusively on inputs/outputs, - # which are not present in this case. - if !isnothing(namespace) - add_blox!(g, matrisome) - add_blox!(g, striosome) - bc = connector_from_graph(g) - sys = system_from_parts(parts; name) - else - bc = connector_from_graph(g) - sys = system_from_graph(g, bc; name, simplify=false) - end - - m = if isnothing(namespace) - [s for s in unknowns.((sys,), unknowns(sys)) if contains(string(s), "V(t)")] - else - @variables t - sys_namespace = System(Equation[], t; name=namespaced_name(namespace, name)) - [s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")] - end + n_inh = [ + HHNeuronInhibBlox( + name = Symbol("inh$i"), + namespace = namespaced_name(namespace, name), + E_syn = E_syn_inhib, + G_syn = G_syn_inhib, + τ = τ_inhib, + I_bg = I_bg[i], + freq = freq[i], + phase = phase[i] + ) + for i in Base.OneTo(N_inhib) + ] - new(namespace, parts, sys, bc, m) + matrisome = Matrisome(; name=:matrisome, namespace=namespaced_name(namespace, name)) + striosome = Striosome(; name=:striosome, namespace=namespaced_name(namespace, name)) + + parts = vcat(n_inh, matrisome, striosome) + + g = MetaDiGraph() + add_blox!.(Ref(g), n_inh) + + # If this blox is simulated on its own, + # then only the parts with dynamics are included in the system. + # This is done to avoid messing with structural_simplify downstream. + # Also it makes sense, as the discrete parts rely exclusively on inputs/outputs, + # which are not present in this case. + if !isnothing(namespace) + add_blox!(g, matrisome) + add_blox!(g, striosome) + bc = connector_from_graph(g) + sys = system_from_parts(parts; name) + @variables t + sys_namespace = System(Equation[], t; name=namespaced_name(namespace, name)) + m = [s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")] + + new(namespace, parts, sys, bc, m) + else + bc = connector_from_graph(g) + sys = system_from_graph(g, bc; name, simplify=false) + + m = [s for s in unknowns.((sys,), unknowns(sys)) if contains(string(s), "V(t)")] + + new(namespace, parts, sys, bc, m) + end end end @@ -119,9 +118,8 @@ struct GPi <: CompositeBlox end parts = n_inh - + bc = connector_from_graph(g) - sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) m = if isnothing(namespace) @@ -176,9 +174,8 @@ struct GPe <: CompositeBlox end parts = n_inh - + bc = connector_from_graph(g) - sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) m = if isnothing(namespace) @@ -190,7 +187,6 @@ struct GPe <: CompositeBlox end new(namespace, parts, sys, bc, m) - end end @@ -234,9 +230,8 @@ struct Thalamus <: CompositeBlox end parts = n_exci - - bc = connector_from_graph(g) + bc = connector_from_graph(g) sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) m = if isnothing(namespace) @@ -249,7 +244,6 @@ struct Thalamus <: CompositeBlox new(namespace, parts, sys, bc, m) end - end @@ -271,51 +265,43 @@ struct STN <: CompositeBlox phase=zeros(N_exci), τ_exci=5 ) - n_exci = [ - HHNeuronExciBlox( - name = Symbol("exci$i"), - namespace = namespaced_name(namespace, name), - E_syn = E_syn_exci, - G_syn = G_syn_exci, - τ = τ_exci, - I_bg = I_bg[i], - freq = freq[i], - phase = phase[i] - ) - for i in Base.OneTo(N_exci) - ] - - g = MetaDiGraph() - for i in Base.OneTo(N_exci) - add_blox!(g, n_exci[i]) - end + n_exci = [ + HHNeuronExciBlox( + name = Symbol("exci$i"), + namespace = namespaced_name(namespace, name), + E_syn = E_syn_exci, + G_syn = G_syn_exci, + τ = τ_exci, + I_bg = I_bg[i], + freq = freq[i], + phase = phase[i] + ) + for i in Base.OneTo(N_exci) + ] - parts = n_exci - # Construct a BloxConnector object from the graph - # containing all connection equations from lower levels and this level. - bc = connector_from_graph(g) - # If a namespace is not provided, assume that this is the highest level - # and construct the ODEsystem from the graph. - # If there is a higher namespace, construct only a subsystem containing the parts of this level - # and propagate the BloxConnector object `bc` to the higher level - # to potentially add more terms to the same connections. - sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) - - # TO DO : m is a subset of unknowns to be plotted in the GUI. - # This can be moved to NeurobloxGUI, maybe via plotting recipes, - # since it is not an essential part of the blox. - m = if isnothing(namespace) - [s for s in unknowns.((sys,), unknowns(sys)) if contains(string(s), "V(t)")] - else - @variables t - # HACK : Need to define an empty system to add the correct namespace to unknowns. - # Adding a dispatch `ModelingToolkit.unknowns(::Symbol, ::AbstractArray)` upstream will solve this. - sys_namespace = System(Equation[], t; name=namespaced_name(namespace, name)) - [s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")] - end + g = MetaDiGraph() + for i in Base.OneTo(N_exci) + add_blox!(g, n_exci[i]) + end - new(namespace, parts, sys, bc, m) + parts = n_exci + + bc = connector_from_graph(g) + sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) + + # TO DO : m is a subset of unknowns to be plotted in the GUI. + # This can be moved to NeurobloxGUI, maybe via plotting recipes, + # since it is not an essential part of the blox. + m = if isnothing(namespace) + [s for s in unknowns.((sys,), unknowns(sys)) if contains(string(s), "V(t)")] + else + @variables t + # HACK : Need to define an empty system to add the correct namespace to unknowns. + # Adding a dispatch `ModelingToolkit.unknowns(::Symbol, ::AbstractArray)` upstream will solve this. + sys_namespace = System(Equation[], t; name=namespaced_name(namespace, name)) + [s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")] + end + new(namespace, parts, sys, bc, m) end - end diff --git a/src/blox/winnertakeall.jl b/src/blox/winnertakeall.jl index ada650c4..ea260ec1 100644 --- a/src/blox/winnertakeall.jl +++ b/src/blox/winnertakeall.jl @@ -55,14 +55,10 @@ struct WinnerTakeAllBlox{P} <: CompositeBlox end parts = vcat(n_inh, n_excis) - # Construct a BloxConnector object from the graph - # containing all connection equations from lower levels and this level. + bc = connector_from_graph(g) # If a namespace is not provided, assume that this is the highest level # and construct the ODEsystem from the graph. - # If there is a higher namespace, construct only a subsystem containing the parts of this level - # and propagate the BloxConnector object `bc` to the higher level - # to potentially add more terms to the same connections. sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name) new{Union{eltype(n_excis), typeof(n_inh)}}(namespace, parts, sys, bc) diff --git a/test/graphs.jl b/test/graphs.jl index 24456916..e95d828b 100644 --- a/test/graphs.jl +++ b/test/graphs.jl @@ -1,5 +1,4 @@ using Neuroblox -using Neuroblox: get_adjacency using Graphs using MetaGraphs using Test @@ -17,7 +16,7 @@ using Random add_edge!(g, n3 => n2 , weight = 1) add_edge!(g, n2 => n2 , weight = 1) - adj = get_adjacency(g) + adj = AdjacencyMatrix(g) A = [0 1 1 ; 0 1 0; 0 1 0] @@ -34,7 +33,7 @@ end @named cb1 = CorticalBlox(namespace = global_ns, N_wta=2, N_exci=2, connection_matrices=A, weight=1) - adj = get_adjacency(cb1) + adj = AdjacencyMatrix(cb1) adj_wta_11 = [0 1 1; 1 0 0; 1 0 0] adj_wta_12 = [[0 0 0]; hcat([0, 0], A[1,2])] @@ -47,7 +46,7 @@ end [0 1 1 0 1 1 0] ] - @test all(A .== adj.matrix) + @test sum(A) == nnz(adj.matrix) nms = [ :cb1₊wta1₊inh, @@ -59,7 +58,7 @@ end :cb1₊ff_inh ] - @test all(nms .== adj.names) + @test all(n -> n in nms, adj.names) && length(nms) == length(adj.names) end @testset "AdjacencyMatrix [Agent]" begin @@ -73,13 +72,13 @@ end add_edge!(g, VAC => AC, weight=3, density=0.1) Random.seed!(123) - A_graph = get_adjacency(g) + A_graph = AdjacencyMatrix(g) Random.seed!(123) agent = Agent(g; name=global_namespace, t_block = 1); - A_agent = get_adjacency(agent) + A_agent = AdjacencyMatrix(agent) - @test all(A_graph.matrix .== A_agent.matrix) + @test all(A_graph.matrix .== A_agent.matrix) @test all(A_graph.names .== A_agent.names) end