Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adam ketamine dynamics #558

Draft
wants to merge 18 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GraphDynamics = "bcd5d0fe-e6b7-4ef1-9848-780c183c7f4c"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Expand Down
2 changes: 1 addition & 1 deletion examples/adams_example_for_brain_r01.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ add_edge!(g, exci_PING => inhi_PING; weight=10.0)
add_edge!(g, inhi_PING => exci_PING; weight=10.0)

@named sys = system_from_graph(g)
sys = structural_simplify(sys)
#sys = structural_simplify(sys)

sim_dur = 1000.0
prob = SDEProblem(sys, [], (0.0, sim_dur))
Expand Down
112 changes: 112 additions & 0 deletions examples/qif_ngnmm_learning_demo.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
using DifferentialEquations
using Distributions
using Statistics
using Random
using Plots

function qif_ngnmm_params(;Δₑ=1.0,
τₘₑ=20.0,
Hₑ=1.3,
Jₑₑ=8.0,
Jₑᵢ=10.0,
Δᵢ=1.0,
τₘᵢ=10.0,
Hᵢ=-5.0,
Jᵢₑ=10.0,
Jᵢᵢ=0.0,
w₁¹=1.0,
w₂¹=5.0,
w₁²=1.0,
w₂²=5.0,
curr_stim=0.0)

return [Δₑ, τₘₑ, Hₑ, Jₑₑ, Jₑᵢ, Δᵢ, τₘᵢ, Hᵢ, Jᵢₑ, Jᵢᵢ, w₁¹, w₂¹, w₁², w₂², curr_stim]
end

p = qif_ngnmm_params()

function simple_cs_learning!(du, u, p, t)
Δₑ, τₘₑ, Hₑ, Jₑₑ, Jₑᵢ, Δᵢ, τₘᵢ, Hᵢ, Jᵢₑ, Jᵢᵢ, w₁¹, w₂¹, w₁², w₂², curr_stim = p

rₑ¹, Vₑ¹, rᵢ¹, Vᵢ¹, rₑ², Vₑ², rᵢ², Vᵢ², I₁, I₂ = u

du[1] = Δₑ/(π*τₘₑ^2) + 2*rₑ¹*Vₑ¹/τₘₑ
du[2] = (Vₑ¹^2 + Hₑ + 4*sin(5*2*π*t/1000))/τₘₑ - τₘₑ*(π*rₑ¹)^2 + Jₑₑ*rₑ¹ + Jᵢₑ*rᵢ¹ + w₁¹*I₁ + w₂¹*I₂ - 1*rₑ²
du[3] = Δᵢ/(π*τₘᵢ^2) + 2*rᵢ¹*Vᵢ¹/τₘᵢ
du[4] = (Vᵢ¹^2 + Hᵢ + 2*sin(5*2*π*t/1000))/τₘᵢ - τₘᵢ*(π*rᵢ¹)^2 + Jₑᵢ*rₑ¹ + Jᵢᵢ*rᵢ¹
du[5] = Δₑ/(π*τₘₑ^2) + 2*rₑ²*Vₑ²/τₘₑ
du[6] = (Vₑ²^2 + Hₑ + 4*sin(5*2*π*(t-100)/1000))/τₘₑ - τₘₑ*(π*rₑ²)^2 + Jₑₑ*rₑ² + Jᵢₑ*rᵢ² + w₁²*I₁ + w₂²*I₂ - 1*rₑ¹
du[7] = Δᵢ/(π*τₘᵢ^2) + 2*rᵢ²*Vᵢ²/τₘᵢ
du[8] = (Vᵢ²^2 + Hᵢ + 2*sin(5*2*π*(t-100)/1000))/τₘᵢ - τₘᵢ*(π*rᵢ²)^2 + Jₑᵢ*rₑ² + Jᵢᵢ*rᵢ²
du[9] = -u[9]/30
du[10] = -u[10]/30
end


max_time = 100000.0
stim1_times = collect(250:500:max_time)
stim2_times = collect(500:500:max_time)
all_stim_times = sort(vcat(stim1_times, stim2_times))

condtion_s1(u, t, integrator) = t ∈ stim1_times
condtion_s2(u, t, integrator) = t ∈ stim2_times

function affect_s1!(integrator)
integrator.u[9] += 30.0
integrator.p[15] = 1.0
end

function affect_s2!(integrator)
integrator.u[10] += 30.0
integrator.p[15] = 2.0
end

cb_s1 = DiscreteCallback(condtion_s1, affect_s1!)
cb_s2 = DiscreteCallback(condtion_s2, affect_s2!)

eval_times = all_stim_times[2:end]
eval_times .-= 50.0

condition_eval(u, t, integrator) = t ∈ eval_times

all_choices = []

function learn!(integrator)
u = integrator.u
p = integrator.p

str1 = u[2]
str2 = u[6]

choice = str1 > str2

if p[15] == 1 && choice
p[11] += 0.1*rand()
p[11] = min(p[11], 5.0)
push!(all_choices, 1.0)
elseif p[15] == 1 && !choice
p[13] -= 0.1*rand()
p[13] = max(p[13], -5.0)
push!(all_choices, 0.0)
elseif p[15] == 2 && !choice
p[14] += 0.1*rand()
p[14] = min(p[14], 5.0)
push!(all_choices, 1.0)
elseif p[15] == 2 && choice
p[12] -= 0.1*rand()
p[12] = max(p[12], -5.0)
push!(all_choices, 0.0)
end
end

cb_learn = DiscreteCallback(condition_eval, learn!)

tstop = sort(vcat(all_stim_times, eval_times))

u₀ = zeros(10)

tspan = (0.0, max_time)
cbs = CallbackSet(cb_s1, cb_s2, cb_learn)
prob = ODEProblem(simple_cs_learning!, u₀, tspan, p)
sol = solve(prob, Tsit5(), callback = cbs, tstops = tstop, saveat=1.0)
plot(all_choices)
9 changes: 8 additions & 1 deletion src/GraphDynamicsInterop/GraphDynamicsInterop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,14 @@ using ..Neuroblox:
PINGNeuronInhib,
AbstractPINGNeuron,
Connector,
VanDerPol
VanDerPol,
AdamPYR,
AdamINP,
AbstractAdamNeuron,
AdamNMDAR,
AbstractReceptor,
AdamGlu,
AbstractNeurotransmitter

using GraphDynamics:
GraphDynamics,
Expand Down
53 changes: 53 additions & 0 deletions src/GraphDynamicsInterop/connection_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1047,3 +1047,56 @@ function (c::PINGConnection)(blox_src::Subsystem{PINGNeuronInhib}, blox_dst::Sub
(; jcn = w * s * (V_I - V))
end

# #-------------------------
# Adam Network
# #-------------------------
struct AdamConnection <: ConnectionRule
w::Float64
V_E::Float64
V_I::Float64
end
Base.zero(::Type{AdamConnection}) = AdamConnection(0.0, 0.0, 0.0)

function get_connection(blox_src::AdamPYR, blox_dst::AbstractAdamNeuron, kwargs)
(;w_val, name) = generate_weight_param(blox_src, blox_dst, kwargs)
V_E = get(kwargs, :V_E, 0.0)
(; conn = AdamConnection(w_val, V_E, 0.0), names=[name])
end
function get_connection(blox_src::AdamINP, blox_dst::AbstractAdamNeuron, kwargs)
(;w_val, name) = generate_weight_param(blox_src, blox_dst, kwargs)
V_I = get(kwargs, :V_I, 0.0)
(; conn = AdamConnection(w_val, V_I, -80.0), names=[name])
end

function (c::AdamConnection)(blox_src::Subsystem{AdamPYR}, blox_dst::Subsystem{<:AbstractAdamNeuron})
(; w, V_E) = c
(; sₐₘₚₐ) = blox_src
(; V) = blox_dst
(; jcn = w * sₐₘₚₐ * (V - V_E))
end

function (c::AdamConnection)(blox_src::Subsystem{AdamINP}, blox_dst::Subsystem{<:AbstractAdamNeuron})
(; w, V_I) = c
(; sᵧ) = blox_src
(; V) = blox_dst
(; jcn = w * sᵧ * (V - V_I))
end

function (c::BasicConnection)(blox_src::Subsystem{AdamPYR}, blox_dst::Subsystem{AdamGlu})
w = c.weight
(; V) = blox_src
(; jcn = V)
end

# using Accessors
# function (c::AdamConnection)(sys_src::Subsystem{AdamGlu}, sys_dst::Subsystem{AdamNMDAR})
# w = c.weight
# input = initialize_input(sys_dest)
# @set input.Glu = sys_src.Glu
# end

# function (c::AdamConnection)(sys_src::Subsystem{<:AbstractAdamNeuron}, sys_dst::Subsystem{AdamNMDAR})
# w = c.weight
# input = initialize_input(sys_dest)
# @set input.V = sys_src.V
# end
30 changes: 29 additions & 1 deletion src/GraphDynamicsInterop/neuron_interop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,11 @@ for sys ∈ [HHNeuronExciBlox(name=:hhne)
VanDerPol{NonNoisy}(name=:VdP)
VanDerPol{Noisy}(name=:VdPN)
KuramotoOscillator{NonNoisy}(name=:ko)
KuramotoOscillator{Noisy}(name=:kon)]
KuramotoOscillator{Noisy}(name=:kon)
AdamPYR(name=:adam_pyr)
AdamINP(name=:adam_inp)
AdamGlu(name=:adam_glu)
AdamNMDAR(name=:adam_nmdar)]
define_neuron(sys)
end

Expand Down Expand Up @@ -281,3 +285,27 @@ function GraphDynamics.apply_discrete_event!(integrator, _, vparams, s::Subsyste
vparams[] = @set params.jcn_ = jcn
nothing
end

# GraphDynamics.initialize_input(s::Subsystem{AdamNMDAR}) = (; Glu = 0.0, V = 0.0)

# function GraphDynamics.subsystem_differential(sys::Subsystem{AdamNMDAR}, inputs, t)
# # Unpack the system
# (; C, C_A, C_AA, D_AA, O_AA, O_AAB, C_AAB, D_AAB, C_AB, C_B) = sys
# (; k_on, k_off, k_r, k_d, k_unblock, k_block, α, β) = sys
# (; Glu, V) = inputs

# # Differential equations
# SubsystemStates{AdamNMDAR}(
# #=d/dt=# C = k_off*C_A - 2*k_on*Glu*C,
# #=d/dt=# C_A = 2*k_off*C_AA + 2*k_on*Glu*C - (k_on*Glu + k_off)*C_A,
# #=d/dt=# C_AA = k_on*Glu*C_A + α*O_AA + k_r*D_AA - (2*k_off + β + k_d)*C_AA,
# #=d/dt=# D_AA = k_d*C_AA - k_r*D_AA,
# #=d/dt=# O_AA = k_r*D_AA - α*O_AA,
# #=d/dt=# O_AAB = k_unblock*C_AAB - k_block*O_AAB,
# #=d/dt=# C_AAB = k_block*O_AAB - k_unblock*C_AAB,
# #=d/dt=# D_AAB = k_d*C_AAB - k_r*D_AAB,
# #=d/dt=# C_AB = k_off*C_AAB - 2*k_on*Glu*C_AB,
# #=d/dt=# C_B = k_off*C_AB - 2*k_on*Glu*C_B
# )

# end
4 changes: 4 additions & 0 deletions src/Neuroblox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ include("blox/subcortical_blox.jl")
include("blox/stochastic.jl")
include("blox/discrete.jl")
include("blox/ping_neuron_examples.jl")
include("blox/adam_neurons.jl")
include("blox/reinforcement_learning.jl")
include("gui/GUI.jl")
include("blox/connections.jl")
Expand Down Expand Up @@ -260,4 +261,7 @@ export voltage_timeseries, meanfield_timeseries, state_timeseries, get_neurons,
export AdjacencyMatrix, Connector, connection_rule, connection_equations, connection_spike_affects, connection_learning_rules, connection_callbacks
export inputs, outputs, equations, unknowns, parameters, discrete_events
export MetabolicHHNeuron
export AdamPYR, AdamINP
export AdamGlu
export AdamNMDAR
end
133 changes: 133 additions & 0 deletions src/blox/adam_examples.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
using Neuroblox
using OrdinaryDiffEq
using Random, Distributions
using Plots
#import Neuroblox: AbstractNeuronBlox, paramscoping
#using BenchmarkTools

include("adams_graphdynamics_draft.jl")

ḡᵢ = 0.5
ḡₑ = 0.2

NE = 80
NI = 20

exci = [AdamPYR(name=Symbol("PYR$i"), Iₐₚₚ=rand(Normal(1.5, 0.05))) for i in 1:NE]
inhi = [AdamINP(name=Symbol("INP$i"), Iₐₚₚ=rand(Normal(0.1, 0.05))) for i in 1:NI] # bump up to 0.3

g = MetaDiGraph()

for ne ∈ exci
for ni ∈ inhi
make_nmda_edge!(g, ne, ni)
end
end

for ne ∈ exci
for ne ∈ exci
add_edge!(g, ne => ne; weight=1.0)
end
end

for ni ∈ inhi
for ne ∈ exci[1:20]
add_edge!(g, ni => ne; weight=ḡᵢ/NI)
end
end

begin
tspan = (0.0, 1000.0)
@time sys = system_from_graph(g, graphdynamics=true)
@time prob = ODEProblem(sys, [], tspan)
@time sol = solve(prob, Tsit5(), saveat=0.5)
end


### Older tests

## Test network without NMDAR connections
NI = 20
NE = 80

exci = [AdamPYR(name=Symbol("PYR$i"), Iₐₚₚ=rand(Normal(0.25, 0.05))) for i in 1:NE]
inhi = [AdamINP(name=Symbol("INP$i"), Iₐₚₚ=rand(Normal(0.3, 0.05))) for i in 1:NI] # bump up to 0.3

g = MetaDiGraph()

for ne ∈ exci
for ni ∈ inhi
add_edge!(g, ne => ni; weight=ḡₑ/NE)
add_edge!(g, ni => ne; weight=ḡᵢ/NI)
end
end

tspan = (0.0, 500.0)
# begin
# @btime @named sys = system_from_graph(g, graphdynamics=true)
# @btime prob = ODEProblem(sys, [], tspan)
# @btime sol = solve(prob, Tsit5(), saveat=0.5)
# end

@named sys = system_from_graph(g, graphdynamics=true)
prob = ODEProblem(sys, [], tspan)
sol = solve(prob, Tsit5(), saveat=0.5)

plot(sol, idxs=1:5:(NE+NI)*5)


### Testing single neuron connections

exci = AdamPYR(name=:PYR, Iₐₚₚ=0.25)
glur = AdamGlu(name=:Glu, θ=-59.0)
nmda = AdamNMDAR(name=:NMDA)
exci2 = AdamPYR(name=:PYR2, Iₐₚₚ=0.33)

g = MetaDiGraph()
add_edge!(g, exci => glur; weight=1.0)
add_edge!(g, glur => nmda; weight=1.0)
add_edge!(g, exci2 => nmda; weight=1.0)
add_edge!(g, nmda => exci2; weight=1.0)

tspan = (0.0, 500.0)
@named sys = system_from_graph(g, graphdynamics=true)
prob = ODEProblem(sys, [], tspan)
sol = solve(prob, Tsit5(), saveat=0.5)
plot(sol)

### Testing multiple neuron connections
NE = 800
NI = 800

exci = [AdamPYR(name=Symbol("PYR$i"), Iₐₚₚ=rand(Normal(0.25, 0.05))) for i in 1:NE]
inhi = [AdamINP(name=Symbol("INP$i"), Iₐₚₚ=rand(Normal(0.3, 0.05))) for i in 1:NI] # bump up to 0.3
nmdar = [AdamNMDAR(name=Symbol("NMDA$i")) for i in 1:NE]
glu = [AdamGlu(name=Symbol("Glu$i")) for i in 1:NI]

g = MetaDiGraph()

for i in axes(exci, 1)
add_edge!(g, exci[i] => glu[i]; weight=1.0)
add_edge!(g, glu[i] => nmdar[i]; weight=1.0)
add_edge!(g, inhi[i] => nmdar[i]; weight=1.0)
add_edge!(g, nmdar[i] => inhi[i]; weight=1.0)
end

tspan = (0.0, 500.0)
@time @named sys = system_from_graph(g, graphdynamics=true)
@time prob = ODEProblem(sys, [], tspan)
@time sol = solve(prob, Tsit5(), saveat=0.5)

## Testing Glu for threshold setting
## Commented out for now but useful for tuning later so leaving in the file
# exci = AdamPYR(name=:PYR, Iₐₚₚ=0.25)
# glur = AdamGlu(name=:Glu, θ=-59.0)

# g = MetaDiGraph()
# add_edge!(g, exci => glur; weight=1.0)

# tspan = (0.0, 500.0)
# @named sys = system_from_graph(g, graphdynamics=false)
# prob = ODEProblem(sys, [], tspan)
# sol = solve(prob, Tsit5(), saveat=0.5)
# plot(sol, idxs=6)
Loading