Skip to content

Commit

Permalink
Solved issue with new MTK parameter structure, removed split=false & …
Browse files Browse the repository at this point in the history
…some cleanup. (#546)
  • Loading branch information
david-hofmann authored Feb 19, 2025
1 parent 11582e6 commit 93802d8
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/src/tutorials/spectralDCM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ end
# For instance the the effective connections that are set to zero in the simulation:
untune = Dict(A[3] => false, A[7] => false)
fitmodel = changetune(fitmodel, untune) # 3 and 7 are not present in the simulation model
fitmodel = structural_simplify(fitmodel, split=false) # and now simplify the euqations; the `split` parameter is necessary for some ModelingToolkit peculiarities and will soon be removed. So don't lose time with it ;)
fitmodel = structural_simplify(fitmodel) # and now simplify the euqations; the `split` parameter is necessary for some ModelingToolkit peculiarities and will soon be removed. So don't lose time with it ;)

# ## Setup spectral DCM
max_iter = 128; # maximum number of iterations
Expand Down
15 changes: 2 additions & 13 deletions src/datafitting/spDCM_VL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,6 @@ function LinearAlgebra.eigen(M::Matrix{Dual{T, P, np}}) where {T, P, np}
end

function transferfunction(freq, derivatives, params, indices)
# nr = length(indices[:u])
# pars = params[indices[:dspars]]
# ∂f = derivatives([pars[1:nr^2], pars[nr^2+1:end]...])
∂f = derivatives(params[indices[:dspars]])
∂f∂x = ∂f[indices[:sts], indices[:sts]]
∂f∂u = ∂f[indices[:sts], indices[:u]]
Expand All @@ -119,7 +116,7 @@ function transferfunction(freq, derivatives, params, indices)
∂g∂v = ∂g∂x*V
∂v∂u = V\∂f∂u # u is external variable which we don't use right now. With external variable this would read V/dfdu

nfreq = size(freq, 1) # number of frequencies
nfreq = size(freq, 1) # number of frequencies
ng = size(∂g∂x, 1) # number of outputs
nu = size(∂v∂u, 2) # number of inputs
nk = size(V, 2) # number of modes
Expand Down Expand Up @@ -412,7 +409,7 @@ function setup_sDCM(data, model, initcond, csdsetup, priors, hyperpriors, indice
statevals = [v for v in values(initcond)]
append!(statevals, zeros(length(unknowns(model)) - length(statevals)))
f_model = generate_function(model; expression=Val{false})[1]
f_at(params, t) = states -> f_model(states, params, t)
f_at(params, t) = states -> f_model(states, MTKParameters(model, params)..., t)
derivatives = par -> jacobian(f_at(addnontunableparams(par, model), t), statevals)

μθ_pr = vecparam(priors.μθ_pr) # note: μθ_po is posterior and μθ_pr is prior
Expand Down Expand Up @@ -462,15 +459,7 @@ function setup_sDCM(data, model, initcond, csdsetup, priors, hyperpriors, indice
return (vlstate, vlsetup)
end

with_stack(f, n) = fetch(schedule(Task(f, n)));

function run_sDCM_iteration!(state::VLState, setup::VLSetup)
with_stack(5_000_000) do
_run_sDCM_iteration!(state, setup)
end
end

function _run_sDCM_iteration!(state::VLState, setup::VLSetup)
(;μθ_po, λ, v, ϵ_θ, dFdθ, dFdθθ) = state

f = setup.model_at_x0
Expand Down
5 changes: 2 additions & 3 deletions test/datafitting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ using MAT
untunelist[A[i]] = v == 0 ? false : true
end
neuronmodel = changetune(neuronmodel, untunelist)
neuronmodel = structural_simplify(neuronmodel, split=false)
neuronmodel = structural_simplify(neuronmodel)

# attribute initial conditions to states
_, obsvars = get_eqidx_tagged_vars(neuronmodel, "measurement") # get index of equation of bold state
Expand Down Expand Up @@ -93,7 +93,6 @@ using MAT
end
end
end
print("maxixmum iterations reached\n")

### COMPARE RESULTS WITH MATLAB RESULTS ###
@show state.F[end], vars["F"]
Expand Down Expand Up @@ -160,7 +159,7 @@ end
end
end

@named fullmodel = system_from_graph(g; split=false)
@named fullmodel = system_from_graph(g)

# attribute initial conditions to states
sts, idx_sts = get_dynamic_states(fullmodel)
Expand Down

0 comments on commit 93802d8

Please sign in to comment.