Skip to content

Commit

Permalink
Change BloxConnector to Connector (#502)
Browse files Browse the repository at this point in the history
* change all dispatches of `BloxConnector` to `Connector`

* export `Connector` and related convenience functions

* update all `<: CompositeBlox` to use `Connector`

* move PING neuron dispatches to `/src/blox/connections.jl`

* add `AdjacencyMatrix` constructors for `Connector`

* add `connectors_from_graph, returns `Vector{Connector}`

* update `connector_from_graph`

* update all `system_from_graph` utility functions

* update getter functions

* update graphs tests

* add `NoLearningRule`

* move `connect_action_selection!` dispatches to `src/blox/RL`

* cleanup old commented block

* include PING neurons before connections
  • Loading branch information
harisorgn authored and david-hofmann committed Jan 2, 2025
1 parent 87344a3 commit 0b321a9
Show file tree
Hide file tree
Showing 13 changed files with 923 additions and 848 deletions.
3 changes: 2 additions & 1 deletion src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ include("blox/winnertakeall.jl")
include("blox/subcortical_blox.jl")
include("blox/stochastic.jl")
include("blox/discrete.jl")
include("blox/ping_neuron_examples.jl")
include("blox/reinforcement_learning.jl")
include("gui/GUI.jl")
include("blox/connections.jl")
include("blox/blox_utilities.jl")
include("blox/ping_neuron_examples.jl")
include("GraphDynamicsInterop/GraphDynamicsInterop.jl")
include("Neurographs.jl")
include("adjacency.jl")
Expand Down Expand Up @@ -256,4 +256,5 @@ export meanfield, meanfield!, rasterplot, rasterplot!, stackplot, stackplot!, fr
export powerspectrumplot, powerspectrumplot!, welch_pgram, periodogram, hanning, hamming
export detect_spikes, mean_firing_rate, firing_rate
export voltage_timeseries, meanfield_timeseries, state_timeseries, get_neurons, get_exci_neurons, get_inh_neurons, get_neuron_color
export AdjacencyMatrix, Connector, connection_rule, connection_equation
end
60 changes: 36 additions & 24 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,38 @@ get_dynamics_bloxs(blox::CompositeBlox) = get_parts(blox)

flatten_graph(g::MetaDiGraph) = mapreduce(get_dynamics_bloxs, vcat, get_bloxs(g))

function connector_from_graph(g::MetaDiGraph)
bloxs = get_bloxs(g)
link = BloxConnector(bloxs)
function connectors_from_graph(g::MetaDiGraph)
conns = get_connector.(get_bloxs(g))
for edge in edges(g)

for v in vertices(g)
b = get_prop(g, v, :blox)
for vn in inneighbors(g, v)
bn = get_prop(g, vn, :blox)
kwargs = props(g, vn, v)
link(bn, b; kwargs...)
end
blox_src = get_prop(g, edge.src, :blox)
blox_dest = get_prop(g, edge.dst, :blox)

kwargs = props(g, edge)
push!(conns, Connector(blox_src, blox_dest; kwargs...))
end
return link

filter!(conn -> !isempty(conn), conns)

return conns
end

function connector_from_graph(g::MetaDiGraph)
conns = connectors_from_graph(g)

return isempty(conns) ? Connector(:none, :none) : reduce(merge!, conns)
end

# Helper function to get delays from a graph
function graph_delays(g::MetaDiGraph)
bc = connector_from_graph(g)
return bc.delays
conn = connector_from_graph(g)

return conn.delay
end

generate_discrete_callbacks(blox, ::BloxConnector; t_block = missing) = []
generate_discrete_callbacks(blox, ::Connector; t_block = missing) = []

function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, bc::BloxConnector; t_block = missing)
function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, bc::Connector; t_block = missing)
spike_affects = get_spike_affects(bc)
name_blox = namespaced_nameof(blox)
sys = get_namespaced_sys(blox)
Expand Down Expand Up @@ -116,7 +124,7 @@ function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, b
return cb
end

function generate_discrete_callbacks(blox::HHNeuronExciBlox, ::BloxConnector; t_block = missing)
function generate_discrete_callbacks(blox::HHNeuronExciBlox, ::Connector; t_block = missing)
if !ismissing(t_block)
nn = get_namespaced_sys(blox)
eq = nn.spikes_window ~ 0
Expand All @@ -128,7 +136,7 @@ function generate_discrete_callbacks(blox::HHNeuronExciBlox, ::BloxConnector; t_
end
end

function generate_discrete_callbacks(bc::BloxConnector; t_block = missing)
function generate_discrete_callbacks(bc::Connector; t_block = missing)
eqs_params = get_equations_with_parameter_lhs(bc)

if !ismissing(t_block) && !isempty(eqs_params)
Expand All @@ -139,7 +147,7 @@ function generate_discrete_callbacks(bc::BloxConnector; t_block = missing)
end
end

function generate_discrete_callbacks(g::MetaDiGraph, bc::BloxConnector; t_block = missing)
function generate_discrete_callbacks(g::MetaDiGraph, bc::Connector; t_block = missing)
bloxs = flatten_graph(g)

cbs = mapreduce(vcat, bloxs) do blox
Expand Down Expand Up @@ -175,30 +183,34 @@ function system_from_graph(g::MetaDiGraph, p::Vector{Num}=Num[]; name=nothing, t
isempty(p) || error(ArgumentError("The GraphDynamics.jl backend does yet support extra parameter lists. Got $p."))
GraphDynamicsInterop.graphsystem_from_graph(g; kwargs...)
else
bc = connector_from_graph(g)
if isnothing(name)
throw(UndefKeywordError(:name))
end

bc = connector_from_graph(g)

return system_from_graph(g, bc, p; name, t_block, simplify, kwargs...)
end
end

function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}=Num[];
name, t_block=missing, simplify=true, simplify_kwargs...)
blox_syss = get_system(g)
function system_from_graph(g::MetaDiGraph, bc::Connector, p::Vector{Num}=Num[]; name=nothing, t_block=missing, simplify=true, graphdynamics=false, kwargs...)
bloxs = get_bloxs(g)
blox_syss = get_system.(bloxs)

accumulate_equations!(bc, bloxs)

connection_eqs = get_equations_with_state_lhs(bc)

discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block))

sys = compose(System(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs), blox_syss)
if simplify
structural_simplify(sys; simplify_kwargs...)
structural_simplify(sys; kwargs...)
else
sys
end
end


function system_from_parts(parts::AbstractVector; name)
return compose(System(Equation[], t; name), get_system.(parts))
end
Expand Down
58 changes: 40 additions & 18 deletions src/adjacency.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,33 @@ struct AdjacencyMatrix
names::Vector{Symbol}
end

function AdjacencyMatrix(name)
return AdjacencyMatrix(spzeros(1,1), [name])
function AdjacencyMatrix(names::AbstractVector)
return AdjacencyMatrix(spzeros(1,1), names)
end

function Base.merge(adj1::AdjacencyMatrix, adj2::AdjacencyMatrix)
return AdjacencyMatrix(
cat(adj1.matrix, adj2.matrix; dims=(1,2)),
vcat(adj1.names, adj2.names)
)
end
function AdjacencyMatrix(C::Connector)
weights = C.weight
srcs = C.source
dests = C.destination
names = unique(vcat(srcs, dests))
sort!(names)

get_adjacency(bc::BloxConnector) = bc.adjacency
get_adjacency(blox::CompositeBlox) = (get_adjacency get_connector)(blox)
get_adjacency(blox) = AdjacencyMatrix(namespaced_nameof(blox))
ADJ = AdjacencyMatrix(spzeros(length(names), length(names)), names)
for i in eachindex(srcs)
add_adjacency_edge!(ADJ, srcs[i], dests[i], weights[i])
end

function get_adjacency(g::MetaDiGraph)
bc = connector_from_graph(g)
return get_adjacency(bc)
return ADJ
end

function get_adjacency(bc::BloxConnector, sys::AbstractODESystem, prob::ODEProblem)
A = get_adjacency(bc)
AdjacencyMatrix(blox::CompositeBlox) = AdjacencyMatrix(get_connector(blox))

AdjacencyMatrix(blox) = AdjacencyMatrix(namespaced_nameof(blox))

AdjacencyMatrix(g::MetaDiGraph) = AdjacencyMatrix(connector_from_graph(g))

function AdjacencyMatrix(bc::Connector, sys::AbstractODESystem, prob::ODEProblem)
A = AdjacencyMatrix(bc)
names = A.names
mat = A.matrix

Expand All @@ -43,12 +48,29 @@ function get_adjacency(bc::BloxConnector, sys::AbstractODESystem, prob::ODEProbl
return AdjacencyMatrix(S, names)
end

function get_adjacency(agent::Agent)
function AdjacencyMatrix(agent::Agent)
prob = agent.problem
sys = get_system(agent)
bc = get_connector(agent)

return get_adjacency(bc, sys, prob)
return AdjacencyMatrix(bc, sys, prob)
end

function add_adjacency_edge!(ADJ::AdjacencyMatrix, name_src, name_dest, weight)
src_idx = findfirst(x -> isequal(name_src, x), ADJ.names)
dest_idx = findfirst(x -> isequal(name_dest, x), ADJ.names)

weight_def = ModelingToolkit.getdefault(weight)
weight_value = substitute(weight_def, map(x -> x => ModelingToolkit.getdefault(x), Symbolics.get_variables(weight_def)))

ADJ.matrix[src_idx, dest_idx] = weight_value
end

function Base.merge(adj1::AdjacencyMatrix, adj2::AdjacencyMatrix)
return AdjacencyMatrix(
cat(adj1.matrix, adj2.matrix; dims=(1,2)),
vcat(adj1.names, adj2.names)
)
end

function adjmatrixfromdigraph(g::MetaDiGraph)
Expand Down
17 changes: 9 additions & 8 deletions src/blox/DBS_Model_Blox_Adam_Brown.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ struct Striatum_MSN_Adam <: CompositeBlox
end
end
parts = n_inh

bc = connector_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

m = if isnothing(namespace)
Expand Down Expand Up @@ -184,9 +184,9 @@ struct Striatum_FSI_Adam <: CompositeBlox
end

parts = n_inh

bc = connector_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

m = if isnothing(namespace)
Expand Down Expand Up @@ -254,9 +254,9 @@ struct GPe_Adam <: CompositeBlox
end
end
parts = n_inh

bc = connector_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

m = if isnothing(namespace)
Expand Down Expand Up @@ -323,9 +323,9 @@ struct STN_Adam <: CompositeBlox
end
end
parts = n_exci

bc = connector_from_graph(g)

bc = connector_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

m = if isnothing(namespace)
Expand All @@ -334,6 +334,7 @@ struct STN_Adam <: CompositeBlox
sys_namespace = System(Equation[], t; name=namespaced_name(namespace, name))
[s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")]
end

new(namespace, parts, sys, bc, m, connection_matrix)
end

Expand Down
43 changes: 33 additions & 10 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ function get_neurons(vn::AbstractVector{<:AbstractBlox})
end

get_parts(blox::CompositeBlox) = blox.parts
get_parts(blox::Union{AbstractBlox, ObserverBlox}) = blox

get_components(blox::CompositeBlox) = mapreduce(x -> get_components(x), vcat, get_parts(blox))
get_components(blox::Vector{<:AbstractBlox}) = mapreduce(x -> get_components(x), vcat, blox)
Expand Down Expand Up @@ -97,14 +98,15 @@ end
get_namespaced_sys(sys::AbstractODESystem) = sys

nameof(blox) = (nameof get_system)(blox)
nameof(blox::AbstractActionSelection) = blox.name

namespaceof(blox) = blox.namespace

namespaced_nameof(blox) = namespaced_name(inner_namespaceof(blox), nameof(blox))

"""
Returns the complete namespace EXCLUDING the outermost (highest) level.
This is useful for manually preparing equations (e.g. connections, see BloxConnector),
This is useful for manually preparing equations (e.g. connections, see Connector),
that will later be composed and will automatically get the outermost namespace.
"""
function inner_namespaceof(blox)
Expand All @@ -119,7 +121,7 @@ end
namespaced_name(parent_name, name) = Symbol(parent_name, :₊, name)
namespaced_name(::Nothing, name) = Symbol(name)

function find_eq(eqs::AbstractVector{<:Equation}, lhs)
function find_eq(eqs::Union{AbstractVector{<:Equation}, Equation}, lhs)
findfirst(eqs) do eq
lhs_vars = get_variables(eq.lhs)
length(lhs_vars) == 1 && isequal(only(lhs_vars), lhs)
Expand All @@ -136,7 +138,7 @@ end
the higher-level namespaces will be added to them.
If blox isa AbstractComponent, it is assumed that it contains a `connector` field,
which holds a `BloxConnector` object with all relevant connections
which holds a `Connector` object with all relevant connections
from lower levels and this level.
"""
function get_input_equations(blox::Union{AbstractBlox, ObserverBlox})
Expand All @@ -162,28 +164,28 @@ function get_input_equations(blox::Union{AbstractBlox, ObserverBlox})
end

get_connector(blox::Union{CompositeBlox, Agent}) = blox.connector
get_connector(blox) = Connector(namespaced_nameof(blox), namespaced_nameof(blox))

get_input_equations(bc::BloxConnector) = bc.eqs
get_input_equations(blox::Union{CompositeBlox, AbstractComponent}) = (get_input_equations get_connector)(blox)
get_input_equations(bc::Connector) = bc.equation
get_input_equations(blox) = []

get_weight_parameters(bc::BloxConnector) = bc.weights
get_weight_parameters(bc::Connector) = bc.weights
get_weight_parameters(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_parameters get_connector)(blox)
get_weight_parameters(blox) = Num[]

get_delay_parameters(bc::BloxConnector) = bc.delays
get_delay_parameters(bc::Connector) = bc.delays
get_delay_parameters(blox::Union{CompositeBlox, AbstractComponent}) = (get_delay_parameters get_connector)(blox)
get_delay_parameters(blox) = Num[]

get_discrete_callbacks(bc::BloxConnector) = bc.discrete_callbacks
get_discrete_callbacks(bc::Connector) = bc.discrete_callbacks
get_discrete_callbacks(blox::Union{CompositeBlox, AbstractComponent}) = (get_discrete_callbacks get_connector)(blox)
get_discrete_callbacks(blox) = []

get_spike_affects(bc::BloxConnector) = bc.spike_affects
get_spike_affects(bc::Connector) = bc.spike_affects
get_spike_affects(blox::Union{CompositeBlox, AbstractComponent}) = (get_spike_affects get_connector)(blox)
get_spike_affects(blox) = Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}}()

get_weight_learning_rules(bc::BloxConnector) = bc.learning_rules
get_weight_learning_rules(bc::Connector) = bc.learning_rules
get_weight_learning_rules(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_learning_rules get_connector)(blox)
get_weight_learning_rules(blox) = Dict{Num, AbstractLearningRule}()

Expand Down Expand Up @@ -253,6 +255,14 @@ function get_connection_matrix(kwargs, name_out, name_in, N_out, N_in)
connection_matrix
end

function get_learning_rule(kwargs, name_src, name_dest)
if haskey(kwargs, :learning_rule)
return deepcopy(kwargs[:learning_rule])
else
return NoLearningRule()
end
end

function get_weights(agent::Agent, blox_out, blox_in)
ps = parameters(agent.odesystem)
pv = agent.problem.p
Expand Down Expand Up @@ -622,3 +632,16 @@ function get_sampling_info(sol::SciMLBase.AbstractSolution; sampling_rate=nothin
return nothing, 1000 / sampling_rate
end
end

function narrowtype_union(d::Dict)
types = unique(typeof.(values(d)))
U = Union{types...}

return U
end

function narrowtype(d::Dict)
U = narrowtype_union(d)

return Dict{Num, U}(d)
end
Loading

0 comments on commit 0b321a9

Please sign in to comment.