Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change BloxConnector to Connector #502

Merged
merged 14 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ include("blox/winnertakeall.jl")
include("blox/subcortical_blox.jl")
include("blox/stochastic.jl")
include("blox/discrete.jl")
include("blox/ping_neuron_examples.jl")
include("blox/reinforcement_learning.jl")
include("gui/GUI.jl")
include("blox/connections.jl")
include("blox/blox_utilities.jl")
include("blox/ping_neuron_examples.jl")
include("GraphDynamicsInterop/GraphDynamicsInterop.jl")
include("Neurographs.jl")
include("adjacency.jl")
Expand Down Expand Up @@ -256,4 +256,5 @@ export meanfield, meanfield!, rasterplot, rasterplot!, stackplot, stackplot!, fr
export powerspectrumplot, powerspectrumplot!, welch_pgram, periodogram, hanning, hamming
export detect_spikes, mean_firing_rate, firing_rate
export voltage_timeseries, meanfield_timeseries, state_timeseries, get_neurons, get_exci_neurons, get_inh_neurons, get_neuron_color
export AdjacencyMatrix, Connector, connection_rule, connection_equation
end
60 changes: 36 additions & 24 deletions src/Neurographs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,38 @@ get_dynamics_bloxs(blox::CompositeBlox) = get_parts(blox)

flatten_graph(g::MetaDiGraph) = mapreduce(get_dynamics_bloxs, vcat, get_bloxs(g))

function connector_from_graph(g::MetaDiGraph)
bloxs = get_bloxs(g)
link = BloxConnector(bloxs)
function connectors_from_graph(g::MetaDiGraph)
conns = get_connector.(get_bloxs(g))
for edge in edges(g)

for v in vertices(g)
b = get_prop(g, v, :blox)
for vn in inneighbors(g, v)
bn = get_prop(g, vn, :blox)
kwargs = props(g, vn, v)
link(bn, b; kwargs...)
end
blox_src = get_prop(g, edge.src, :blox)
blox_dest = get_prop(g, edge.dst, :blox)

kwargs = props(g, edge)
push!(conns, Connector(blox_src, blox_dest; kwargs...))
end
return link

filter!(conn -> !isempty(conn), conns)

return conns
end

function connector_from_graph(g::MetaDiGraph)
conns = connectors_from_graph(g)

return isempty(conns) ? Connector(:none, :none) : reduce(merge!, conns)
end

# Helper function to get delays from a graph
function graph_delays(g::MetaDiGraph)
bc = connector_from_graph(g)
return bc.delays
conn = connector_from_graph(g)

return conn.delay
end

generate_discrete_callbacks(blox, ::BloxConnector; t_block = missing) = []
generate_discrete_callbacks(blox, ::Connector; t_block = missing) = []

function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, bc::BloxConnector; t_block = missing)
function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, bc::Connector; t_block = missing)
spike_affects = get_spike_affects(bc)
name_blox = namespaced_nameof(blox)
sys = get_namespaced_sys(blox)
Expand Down Expand Up @@ -116,7 +124,7 @@ function generate_discrete_callbacks(blox::Union{LIFExciNeuron, LIFInhNeuron}, b
return cb
end

function generate_discrete_callbacks(blox::HHNeuronExciBlox, ::BloxConnector; t_block = missing)
function generate_discrete_callbacks(blox::HHNeuronExciBlox, ::Connector; t_block = missing)
if !ismissing(t_block)
nn = get_namespaced_sys(blox)
eq = nn.spikes_window ~ 0
Expand All @@ -128,7 +136,7 @@ function generate_discrete_callbacks(blox::HHNeuronExciBlox, ::BloxConnector; t_
end
end

function generate_discrete_callbacks(bc::BloxConnector; t_block = missing)
function generate_discrete_callbacks(bc::Connector; t_block = missing)
eqs_params = get_equations_with_parameter_lhs(bc)

if !ismissing(t_block) && !isempty(eqs_params)
Expand All @@ -139,7 +147,7 @@ function generate_discrete_callbacks(bc::BloxConnector; t_block = missing)
end
end

function generate_discrete_callbacks(g::MetaDiGraph, bc::BloxConnector; t_block = missing)
function generate_discrete_callbacks(g::MetaDiGraph, bc::Connector; t_block = missing)
bloxs = flatten_graph(g)

cbs = mapreduce(vcat, bloxs) do blox
Expand Down Expand Up @@ -175,30 +183,34 @@ function system_from_graph(g::MetaDiGraph, p::Vector{Num}=Num[]; name=nothing, t
isempty(p) || error(ArgumentError("The GraphDynamics.jl backend does yet support extra parameter lists. Got $p."))
GraphDynamicsInterop.graphsystem_from_graph(g; kwargs...)
else
bc = connector_from_graph(g)
if isnothing(name)
throw(UndefKeywordError(:name))
end

bc = connector_from_graph(g)

return system_from_graph(g, bc, p; name, t_block, simplify, kwargs...)
end
end

function system_from_graph(g::MetaDiGraph, bc::BloxConnector, p::Vector{Num}=Num[];
name, t_block=missing, simplify=true, simplify_kwargs...)
blox_syss = get_system(g)
function system_from_graph(g::MetaDiGraph, bc::Connector, p::Vector{Num}=Num[]; name=nothing, t_block=missing, simplify=true, graphdynamics=false, kwargs...)
bloxs = get_bloxs(g)
blox_syss = get_system.(bloxs)

accumulate_equations!(bc, bloxs)

connection_eqs = get_equations_with_state_lhs(bc)

discrete_cbs = identity.(generate_discrete_callbacks(g, bc; t_block))

sys = compose(System(connection_eqs, t, [], vcat(params(bc), p); name, discrete_events = discrete_cbs), blox_syss)
if simplify
structural_simplify(sys; simplify_kwargs...)
structural_simplify(sys; kwargs...)
else
sys
end
end


function system_from_parts(parts::AbstractVector; name)
return compose(System(Equation[], t; name), get_system.(parts))
end
Expand Down
58 changes: 40 additions & 18 deletions src/adjacency.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,33 @@ struct AdjacencyMatrix
names::Vector{Symbol}
end

function AdjacencyMatrix(name)
return AdjacencyMatrix(spzeros(1,1), [name])
function AdjacencyMatrix(names::AbstractVector)
return AdjacencyMatrix(spzeros(1,1), names)
end

function Base.merge(adj1::AdjacencyMatrix, adj2::AdjacencyMatrix)
return AdjacencyMatrix(
cat(adj1.matrix, adj2.matrix; dims=(1,2)),
vcat(adj1.names, adj2.names)
)
end
function AdjacencyMatrix(C::Connector)
weights = C.weight
srcs = C.source
dests = C.destination
names = unique(vcat(srcs, dests))
sort!(names)

get_adjacency(bc::BloxConnector) = bc.adjacency
get_adjacency(blox::CompositeBlox) = (get_adjacency ∘ get_connector)(blox)
get_adjacency(blox) = AdjacencyMatrix(namespaced_nameof(blox))
ADJ = AdjacencyMatrix(spzeros(length(names), length(names)), names)
for i in eachindex(srcs)
add_adjacency_edge!(ADJ, srcs[i], dests[i], weights[i])
end

function get_adjacency(g::MetaDiGraph)
bc = connector_from_graph(g)
return get_adjacency(bc)
return ADJ
end

function get_adjacency(bc::BloxConnector, sys::AbstractODESystem, prob::ODEProblem)
A = get_adjacency(bc)
AdjacencyMatrix(blox::CompositeBlox) = AdjacencyMatrix(get_connector(blox))

AdjacencyMatrix(blox) = AdjacencyMatrix(namespaced_nameof(blox))

AdjacencyMatrix(g::MetaDiGraph) = AdjacencyMatrix(connector_from_graph(g))

function AdjacencyMatrix(bc::Connector, sys::AbstractODESystem, prob::ODEProblem)
A = AdjacencyMatrix(bc)
names = A.names
mat = A.matrix

Expand All @@ -43,12 +48,29 @@ function get_adjacency(bc::BloxConnector, sys::AbstractODESystem, prob::ODEProbl
return AdjacencyMatrix(S, names)
end

function get_adjacency(agent::Agent)
function AdjacencyMatrix(agent::Agent)
prob = agent.problem
sys = get_system(agent)
bc = get_connector(agent)

return get_adjacency(bc, sys, prob)
return AdjacencyMatrix(bc, sys, prob)
end

function add_adjacency_edge!(ADJ::AdjacencyMatrix, name_src, name_dest, weight)
src_idx = findfirst(x -> isequal(name_src, x), ADJ.names)
dest_idx = findfirst(x -> isequal(name_dest, x), ADJ.names)

weight_def = ModelingToolkit.getdefault(weight)
weight_value = substitute(weight_def, map(x -> x => ModelingToolkit.getdefault(x), Symbolics.get_variables(weight_def)))

ADJ.matrix[src_idx, dest_idx] = weight_value
end

function Base.merge(adj1::AdjacencyMatrix, adj2::AdjacencyMatrix)
return AdjacencyMatrix(
cat(adj1.matrix, adj2.matrix; dims=(1,2)),
vcat(adj1.names, adj2.names)
)
end

function adjmatrixfromdigraph(g::MetaDiGraph)
Expand Down
17 changes: 9 additions & 8 deletions src/blox/DBS_Model_Blox_Adam_Brown.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ struct Striatum_MSN_Adam <: CompositeBlox
end
end
parts = n_inh

bc = connector_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

m = if isnothing(namespace)
Expand Down Expand Up @@ -184,9 +184,9 @@ struct Striatum_FSI_Adam <: CompositeBlox
end

parts = n_inh

bc = connector_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

m = if isnothing(namespace)
Expand Down Expand Up @@ -254,9 +254,9 @@ struct GPe_Adam <: CompositeBlox
end
end
parts = n_inh

bc = connector_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

m = if isnothing(namespace)
Expand Down Expand Up @@ -323,9 +323,9 @@ struct STN_Adam <: CompositeBlox
end
end
parts = n_exci

bc = connector_from_graph(g)

bc = connector_from_graph(g)

sys = isnothing(namespace) ? system_from_graph(g, bc; name, simplify=false) : system_from_parts(parts; name)

m = if isnothing(namespace)
Expand All @@ -334,6 +334,7 @@ struct STN_Adam <: CompositeBlox
sys_namespace = System(Equation[], t; name=namespaced_name(namespace, name))
[s for s in unknowns.((sys_namespace,), unknowns(sys)) if contains(string(s), "V(t)")]
end

new(namespace, parts, sys, bc, m, connection_matrix)
end

Expand Down
43 changes: 33 additions & 10 deletions src/blox/blox_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ function get_neurons(vn::AbstractVector{<:AbstractBlox})
end

get_parts(blox::CompositeBlox) = blox.parts
get_parts(blox::Union{AbstractBlox, ObserverBlox}) = blox

get_components(blox::CompositeBlox) = mapreduce(x -> get_components(x), vcat, get_parts(blox))
get_components(blox::Vector{<:AbstractBlox}) = mapreduce(x -> get_components(x), vcat, blox)
Expand Down Expand Up @@ -97,14 +98,15 @@ end
get_namespaced_sys(sys::AbstractODESystem) = sys

nameof(blox) = (nameof ∘ get_system)(blox)
nameof(blox::AbstractActionSelection) = blox.name

namespaceof(blox) = blox.namespace

namespaced_nameof(blox) = namespaced_name(inner_namespaceof(blox), nameof(blox))

"""
Returns the complete namespace EXCLUDING the outermost (highest) level.
This is useful for manually preparing equations (e.g. connections, see BloxConnector),
This is useful for manually preparing equations (e.g. connections, see Connector),
that will later be composed and will automatically get the outermost namespace.
"""
function inner_namespaceof(blox)
Expand All @@ -119,7 +121,7 @@ end
namespaced_name(parent_name, name) = Symbol(parent_name, :₊, name)
namespaced_name(::Nothing, name) = Symbol(name)

function find_eq(eqs::AbstractVector{<:Equation}, lhs)
function find_eq(eqs::Union{AbstractVector{<:Equation}, Equation}, lhs)
findfirst(eqs) do eq
lhs_vars = get_variables(eq.lhs)
length(lhs_vars) == 1 && isequal(only(lhs_vars), lhs)
Expand All @@ -136,7 +138,7 @@ end
the higher-level namespaces will be added to them.

If blox isa AbstractComponent, it is assumed that it contains a `connector` field,
which holds a `BloxConnector` object with all relevant connections
which holds a `Connector` object with all relevant connections
from lower levels and this level.
"""
function get_input_equations(blox::Union{AbstractBlox, ObserverBlox})
Expand All @@ -162,28 +164,28 @@ function get_input_equations(blox::Union{AbstractBlox, ObserverBlox})
end

get_connector(blox::Union{CompositeBlox, Agent}) = blox.connector
get_connector(blox) = Connector(namespaced_nameof(blox), namespaced_nameof(blox))

get_input_equations(bc::BloxConnector) = bc.eqs
get_input_equations(blox::Union{CompositeBlox, AbstractComponent}) = (get_input_equations ∘ get_connector)(blox)
get_input_equations(bc::Connector) = bc.equation
get_input_equations(blox) = []

get_weight_parameters(bc::BloxConnector) = bc.weights
get_weight_parameters(bc::Connector) = bc.weights
get_weight_parameters(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_parameters ∘ get_connector)(blox)
get_weight_parameters(blox) = Num[]

get_delay_parameters(bc::BloxConnector) = bc.delays
get_delay_parameters(bc::Connector) = bc.delays
get_delay_parameters(blox::Union{CompositeBlox, AbstractComponent}) = (get_delay_parameters ∘ get_connector)(blox)
get_delay_parameters(blox) = Num[]

get_discrete_callbacks(bc::BloxConnector) = bc.discrete_callbacks
get_discrete_callbacks(bc::Connector) = bc.discrete_callbacks
get_discrete_callbacks(blox::Union{CompositeBlox, AbstractComponent}) = (get_discrete_callbacks ∘ get_connector)(blox)
get_discrete_callbacks(blox) = []

get_spike_affects(bc::BloxConnector) = bc.spike_affects
get_spike_affects(bc::Connector) = bc.spike_affects
get_spike_affects(blox::Union{CompositeBlox, AbstractComponent}) = (get_spike_affects ∘ get_connector)(blox)
get_spike_affects(blox) = Dict{Symbol, Tuple{Vector{Num}, Vector{Num}}}()

get_weight_learning_rules(bc::BloxConnector) = bc.learning_rules
get_weight_learning_rules(bc::Connector) = bc.learning_rules
get_weight_learning_rules(blox::Union{CompositeBlox, AbstractComponent}) = (get_weight_learning_rules ∘ get_connector)(blox)
get_weight_learning_rules(blox) = Dict{Num, AbstractLearningRule}()

Expand Down Expand Up @@ -253,6 +255,14 @@ function get_connection_matrix(kwargs, name_out, name_in, N_out, N_in)
connection_matrix
end

function get_learning_rule(kwargs, name_src, name_dest)
if haskey(kwargs, :learning_rule)
return deepcopy(kwargs[:learning_rule])
else
return NoLearningRule()
end
end

function get_weights(agent::Agent, blox_out, blox_in)
ps = parameters(agent.odesystem)
pv = agent.problem.p
Expand Down Expand Up @@ -622,3 +632,16 @@ function get_sampling_info(sol::SciMLBase.AbstractSolution; sampling_rate=nothin
return nothing, 1000 / sampling_rate
end
end

function narrowtype_union(d::Dict)
types = unique(typeof.(values(d)))
U = Union{types...}

return U
end

function narrowtype(d::Dict)
U = narrowtype_union(d)

return Dict{Num, U}(d)
end
Loading
Loading