Skip to content

Commit

Permalink
Merge pull request #337 from Neuroblox/ho/quick_perf_improve
Browse files Browse the repository at this point in the history
A few quick performance improvements
  • Loading branch information
harisorgn authored Feb 9, 2024
2 parents 76450d4 + 61a4b4a commit e7bf50d
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 32 deletions.
22 changes: 11 additions & 11 deletions src/blox/connections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ mutable struct BloxConnector
BloxConnector() = new(Equation[], Num[], Num[], Pair{Any, Vector{Equation}}[], Dict{Num, AbstractLearningRule}())

function BloxConnector(bloxs)
eqs = reduce(vcat, input_equations.(bloxs))
weights = reduce(vcat, weight_parameters.(bloxs))
delays = reduce(vcat, delay_parameters.(bloxs))
events = reduce(vcat, event_callbacks.(bloxs))
learning_rules = reduce(merge, weight_learning_rules.(bloxs))
eqs = mapreduce(input_equations, vcat, bloxs)
weights = mapreduce(weight_parameters, vcat, bloxs)
delays = mapreduce(delay_parameters, vcat, bloxs)
events = mapreduce(event_callbacks, vcat, bloxs)
learning_rules = mapreduce(weight_learning_rules, merge, bloxs)

new(eqs, weights, delays, events, learning_rules)
end
Expand Down Expand Up @@ -41,14 +41,14 @@ function get_callbacks(g, bc; t_block=missing)

end
if !isempty(eqs_params) && !isempty(eqs)
cbs_spikes = (t_block + eps(float(t_block))) => eqs
cbs_params = (t_block - eps(float(t_block))) => eqs_params
cbs_spikes = (t_block + sqrt(eps(float(t_block)))) => eqs
cbs_params = (t_block - sqrt(eps(float(t_block)))) => eqs_params
return vcat(cbs_params, cbs_spikes, bc.events)
elseif isempty(eqs_params) && !isempty(eqs)
cbs_spikes = (t_block + eps(float(t_block))) => eqs
cbs_spikes = (t_block + sqrt(eps(float(t_block)))) => eqs
return vcat(cbs_spikes, bc.events)
elseif !isempty(eqs_params) && isempty(eqs)
cbs_params = (t_block - eps(float(t_block))) => eqs_params
cbs_params = (t_block - sqrt(eps(float(t_block)))) => eqs_params
return vcat(cbs_params, bc.events)
else
return bc.events
Expand Down Expand Up @@ -564,7 +564,7 @@ sample_poisson(λ) = rand(Poisson(λ))
Non-symbolic, time-block-based way of `@register_symbolic sample_poisson(λ)`.
"""
function sample_affect!(integ, u, p, ctx)
R = minimum([integ.p[p[1]]/(integ.p[p[2]] + eps()), integ.p[p[1]]])
R = min(integ.p[p[1]]/(integ.p[p[2]] + sqrt(eps())), integ.p[p[1]])
v = rand(Poisson(R))
integ.p[p[3]] = v
end
Expand All @@ -585,7 +585,7 @@ function (bc::BloxConnector)(
end

t_event = get_event_time(kwargs, nameof(discr_out), nameof(discr_in))
cb = [t_event+eps(t_event)] => (sample_affect!, [], [sys_out.κ, sys_out.jcn, sys_in.TAN_spikes], nothing)
cb = [t_event+sqrt(eps(t_event))] => (sample_affect!, [], [sys_out.κ, sys_out.jcn, sys_in.TAN_spikes], nothing)
push!(bc.events, cb)

eq = sys_in.jcn ~ w*sys_in.TAN_spikes
Expand Down
10 changes: 5 additions & 5 deletions src/blox/discrete.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ struct Matrisome <: AbstractDiscrete
cb_eqs = [ jcn_ ~ jcn,
H_ ~ H
]
Rho_cb = [[t_event+3*eps(t_event)] => cb_eqs]
Rho_cb = [[t_event + sqrt(eps(t_event))] => cb_eqs]
sys = ODESystem(eqs, t, sts, ps; name = name, discrete_events = Rho_cb)

new(sys, namespace)
Expand Down Expand Up @@ -53,7 +53,7 @@ struct TAN <: AbstractDiscrete
sts = @variables R(t)=κ
ps = @parameters κ=κ jcn=0.0 [input=true]
eqs = [
R ~ minimum([κ, κ/*jcn + eps())])
R ~ min(κ, κ/*jcn + sqrt(eps())))
]
sys = ODESystem(eqs, t, sts, ps; name)

Expand All @@ -73,11 +73,11 @@ struct SNc <: AbstractModulator
sts = @variables R(t)=κ_DA R_(t)=κ_DA
ps = @parameters κ=κ_DA λ_DA=λ_DA jcn=0.0 [input=true] jcn_=0.0 #HACK: jcn_ stores the value of jcn at time t_event that can be accessed after the simulation
eqs = [
R ~ minimum([κ_DA, κ_DA/(λ_DA*jcn + eps())]),
R_ ~ minimum([κ_DA, κ_DA/(λ_DA*jcn_ + eps())])
R ~ min(κ_DA, κ_DA/(λ_DA*jcn + sqrt(eps()))),
R_ ~ min(κ_DA, κ_DA/(λ_DA*jcn_ + sqrt(eps())))
]

R_cb = [[t_event+3*eps(t_event)] => [jcn_ ~ jcn]]
R_cb = [[t_event + sqrt(eps(t_event))] => [jcn_ ~ jcn]]

sys = ODESystem(eqs, t, sts, ps; name = name, discrete_events = R_cb)

Expand Down
14 changes: 1 addition & 13 deletions src/blox/neuron_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ struct HHNeuronExciBlox <: AbstractExciNeuronBlox
function HHNeuronExciBlox(;
name,
namespace=nothing,
t_spike_window=90.0,
θ_spike=0.0,
E_syn=0.0,
G_syn=3,
I_bg=0,
Expand Down Expand Up @@ -196,18 +194,15 @@ struct HHNeuronExciBlox <: AbstractExciNeuronBlox
D(G)~(-1/τ₂)*G + z,
D(z)~(-1/τ₁)*z + G_asymp(V,G_syn),
D(Gₛₜₚ)~(-1/τ₃)*Gₛₜₚ + (z/5)*(kₛₜₚ-Gₛₜₚ),
# HACK : need to define a Differential equation for spikes
# HACK : need to define a Differential equation for spike counting
# the alternative of having it as an algebraic equation with [irreducible=true]
# leads to incorrect or unstable solutions. Needs more attention!
D(spikes_cumulative) ~ spk_const*G_asymp(V,G_syn),
D(spikes_window) ~ spk_const*G_asymp(V,G_syn)
]

# spike_reset_cb = [(t_spike_window + eps(float(t_spike_window))) => [spikes_window ~ 0]]

sys = ODESystem(
eqs, t, sts, ps;
#name = Symbol(name),discrete_events = spike_reset_cb
name = Symbol(name)
)

Expand All @@ -221,8 +216,6 @@ struct HHNeuronInhibBlox <: AbstractInhNeuronBlox
function HHNeuronInhibBlox(;
name,
namespace = nothing,
t_spike_window=90.0,
θ_spike=0.0,
E_syn=-70.0,
G_syn=11.5,
I_bg=0,
Expand Down Expand Up @@ -283,15 +276,10 @@ struct HHNeuronInhibBlox <: AbstractInhNeuronBlox
D(h)~ϕ*(αₕ(V)*(1-h)-βₕ(V)*h),
D(G)~(-1/τ₂)*G + z,
D(z)~(-1/τ₁)*z + G_asymp(V,G_syn)
#D(spikes_cumulative) ~ spk_const*G_asymp(V,G_syn),
#D(spikes_window) ~ spk_const*G_asymp(V,G_syn)
]

# spike_reset_cb = [(t_spike_window + eps(float(t_spike_window))) => [spikes_window ~ 0]]

sys = ODESystem(
eqs, t, sts, ps;
# name = Symbol(name), discrete_events = spike_reset_cb
name = Symbol(name)
)

Expand Down
6 changes: 3 additions & 3 deletions test/components.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,8 @@ end
# end

@testset "HH Neuron excitatory & inhibitory network" begin
nn1 = HHNeuronExciBlox(name=Symbol("nrn1"), I_bg=3, freq=4; t_spike_window=0.1)
nn2 = HHNeuronExciBlox(name=Symbol("nrn2"), I_bg=2, freq=6; t_spike_window=0.1)
nn1 = HHNeuronExciBlox(name=Symbol("nrn1"), I_bg=3, freq=4)
nn2 = HHNeuronExciBlox(name=Symbol("nrn2"), I_bg=2, freq=6)
nn3 = HHNeuronInhibBlox(name=Symbol("nrn3"), I_bg=2, freq=3)
assembly = [nn1, nn2, nn3]

Expand All @@ -358,7 +358,7 @@ end
@testset "NextGenerationEIBlox connected to neuron" begin
global_ns = :g
@named LC = NextGenerationEIBlox(;namespace=global_ns, Cₑ=2*26,Cᵢ=1*26, Δₑ=0.5, Δᵢ=0.5, η_0ₑ=10.0, η_0ᵢ=0.0, v_synₑₑ=10.0, v_synₑᵢ=-10.0, v_synᵢₑ=10.0, v_synᵢᵢ=-10.0, alpha_invₑₑ=10.0/26, alpha_invₑᵢ=0.8/26, alpha_invᵢₑ=10.0/26, alpha_invᵢᵢ=0.8/26, kₑₑ=0.0*26, kₑᵢ=0.6*26, kᵢₑ=0.6*26, kᵢᵢ=0*26)
@named nn = HHNeuronExciBlox(;namespace=global_ns, t_spike_window=0.1)
@named nn = HHNeuronExciBlox(;namespace=global_ns)
assembly = [LC, nn]
g = MetaDiGraph()
add_blox!.(Ref(g), assembly)
Expand Down

0 comments on commit e7bf50d

Please sign in to comment.