Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference improvements for timeevolution #396

Merged
merged 4 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/bloch_redfield_master.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ end

# Integrate with given fout
function integrate_br(tspan, dmaster_br, rho,
transf_op, inv_transf_op, fout::Function;
kwargs...)
transf_op, inv_transf_op, fout::F;
kwargs...) where {F}
# Pre-allocate for in-place back-transformation from eigenbasis
rho_out = copy(transf_op)
tmp = copy(transf_op)
Expand Down
8 changes: 5 additions & 3 deletions src/master.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ function master_dynamic(tspan, rho0::Operator, f;
fout=nothing,
kwargs...)
tmp = copy(rho0)
dmaster_(t, rho, drho) = dmaster_h_dynamic!(drho, f, rates, rho, tmp, t)
dmaster_ = let f = f, tmp = tmp
dmaster_(t, rho, drho) = dmaster_h_dynamic!(drho, f, rates, rho, tmp, t)
end
integrate_master(tspan, dmaster_, rho0, fout; kwargs...)
end

Expand Down Expand Up @@ -395,7 +397,7 @@ returned from `f`.
See also: [`master_dynamic`](@ref), [`dmaster_h!`](@ref), [`dmaster_nh!`](@ref),
[`dmaster_nh_dynamic!`](@ref)
"""
function dmaster_h_dynamic!(drho, f, rates, rho, drho_cache, t)
function dmaster_h_dynamic!(drho, f::F, rates, rho, drho_cache, t) where {F}
result = f(t, rho)
QO_CHECKS[] && @assert 3 <= length(result) <= 4
if length(result) == 3
Expand All @@ -418,7 +420,7 @@ equation. Optionally, rates can also be returned from `f`.
See also: [`master_dynamic`](@ref), [`dmaster_h!`](@ref), [`dmaster_nh!`](@ref),
[`dmaster_h_dynamic!`](@ref)
"""
function dmaster_nh_dynamic!(drho, f, rates, rho, drho_cache, t)
function dmaster_nh_dynamic!(drho, f::F, rates, rho, drho_cache, t) where {F}
result = f(t, rho)
QO_CHECKS[] && @assert 4 <= length(result) <= 5
if length(result) == 4
Expand Down
156 changes: 103 additions & 53 deletions src/mcwf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ function mcwf_h(tspan, psi0::Ket, H::AbstractOperator, J;
_check_const.(J)
_check_const.(Jdagger)
check_mcwf(psi0, H, J, Jdagger, rates)
f(t, psi, dpsi) = dmcwf_h!(dpsi, H, J, Jdagger, rates, psi, tmp)
j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates)
f = let H = H, J = J, Jdagger = Jdagger, rates = rates, tmp = tmp
f(t, psi, dpsi) = dmcwf_h!(dpsi, H, J, Jdagger, rates, psi, tmp)
end
probs = zeros(real(eltype(psi0)), length(J))
j = let J = J, probs = probs, rates = rates
j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, probs, rates)
end
integrate_mcwf(f, j, tspan, psi0, seed, fout;
display_beforeevent=display_beforeevent,
display_afterevent=display_afterevent,
Expand All @@ -48,8 +53,13 @@ function mcwf_nh(tspan, psi0::Ket, Hnh::AbstractOperator, J;
_check_const(Hnh)
_check_const.(J)
check_mcwf(psi0, Hnh, J, J, nothing)
f(t, psi, dpsi) = dschroedinger!(dpsi, Hnh, psi)
j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, nothing)
f = let Hnh = Hnh
f(t, psi, dpsi) = dschroedinger!(dpsi, Hnh, psi)
end
probs = zeros(real(eltype(psi0)), length(J))
j = let J = J, probs = probs
j(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, probs, nothing)
end
integrate_mcwf(f, j, tspan, psi0, seed, fout;
display_beforeevent=display_beforeevent,
display_afterevent=display_afterevent,
Expand Down Expand Up @@ -107,8 +117,13 @@ function mcwf(tspan, psi0::Ket, H::AbstractOperator, J;
isreducible = check_mcwf(psi0, H, J, Jdagger, rates)
if !isreducible
tmp = copy(psi0)
dmcwf_h_(t, psi, dpsi) = dmcwf_h!(dpsi, H, J, Jdagger, rates, psi, tmp)
j_h(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates)
dmcwf_h_ = let H = H, J = J, Jdagger = Jdagger, rates = rates, tmp = tmp
dmcwf_h_(t, psi, dpsi) = dmcwf_h!(dpsi, H, J, Jdagger, rates, psi, tmp)
end
probs = zeros(real(eltype(psi0)), length(J))
j_h = let J = J, probs = probs, rates = rates
j_h(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, probs, rates)
end
integrate_mcwf(dmcwf_h_, j_h, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
Expand All @@ -125,8 +140,13 @@ function mcwf(tspan, psi0::Ket, H::AbstractOperator, J;
Hnh -= complex(float(eltype(H)))(0.5im*rates[i])*Jdagger[i]*J[i]
end
end
dmcwf_nh_(t, psi, dpsi) = dschroedinger!(dpsi, Hnh, psi)
j_nh(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, rates)
dmcwf_nh_ = let Hnh = Hnh # Hnh type often not inferrable
dmcwf_nh_(t, psi, dpsi) = dschroedinger!(dpsi, Hnh, psi)
end
probs = zeros(real(eltype(psi0)), length(J))
j_nh = let J = J, probs = probs, rates = rates
j_nh(rng, t, psi, psi_new) = jump(rng, t, psi, J, psi_new, probs, rates)
end
integrate_mcwf(dmcwf_nh_, j_nh, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
Expand Down Expand Up @@ -177,8 +197,14 @@ function mcwf_dynamic(tspan, psi0::Ket, f;
fout=nothing, display_beforeevent=false, display_afterevent=false,
kwargs...)
tmp = copy(psi0)
dmcwf_(t, psi, dpsi) = dmcwf_h_dynamic!(dpsi, f, rates, psi, tmp, t)
j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, rates)
dmcwf_ = let f = f, tmp = tmp, rates = rates
dmcwf_(t, psi, dpsi) = dmcwf_h_dynamic!(dpsi, f, rates, psi, tmp, t)
end
J = f(first(tspan), psi0)[2]
probs = zeros(real(eltype(psi0)), length(J))
j_ = let f = f, probs = probs, rates = rates
j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, probs, rates)
end
integrate_mcwf(dmcwf_, j_, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
Expand All @@ -203,8 +229,14 @@ function mcwf_nh_dynamic(tspan, psi0::Ket, f;
seed=rand(UInt), rates=nothing,
fout=nothing, display_beforeevent=false, display_afterevent=false,
kwargs...)
dmcwf_(t, psi, dpsi) = dmcwf_nh_dynamic!(dpsi, f, psi, t)
j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, rates)
dmcwf_ = let f = f
dmcwf_(t, psi, dpsi) = dmcwf_nh_dynamic!(dpsi, f, psi, t)
end
J = f(first(tspan), psi0)[2]
probs = zeros(real(eltype(psi0)), length(J))
j_ = let f = f, probs = probs, rates = rates
j_(rng, t, psi, psi_new) = jump_dynamic(rng, t, psi, f, psi_new, probs, rates)
end
integrate_mcwf(dmcwf_, j_, tspan, psi0, seed,
fout;
display_beforeevent=display_beforeevent,
Expand All @@ -225,7 +257,7 @@ update `dpsi` according to a non-Hermitian Schrödinger equation.

See also: [`mcwf_dynamic`](@ref), [`dmcwf_h!`](@ref), [`dmcwf_nh_dynamic`](@ref)
"""
function dmcwf_h_dynamic!(dpsi, f, rates, psi, dpsi_cache, t)
function dmcwf_h_dynamic!(dpsi, f::F, rates, psi, dpsi_cache, t) where {F}
result = f(t, psi)
QO_CHECKS[] && @assert 3 <= length(result) <= 4
if length(result) == 3
Expand All @@ -246,15 +278,15 @@ and update `dpsi` according to a Schrödinger equation.

See also: [`mcwf_nh_dynamic`](@ref), [`dmcwf_nh!`](@ref), [`dschroedinger!`](@ref)
"""
function dmcwf_nh_dynamic!(dpsi, f, psi, t)
function dmcwf_nh_dynamic!(dpsi, f::F, psi, t) where {F}
result = f(t, psi)
QO_CHECKS[] && @assert 3 <= length(result) <= 4
H, J, Jdagger = result[1:3]
QO_CHECKS[] && check_mcwf(psi, H, J, Jdagger, nothing)
dschroedinger!(dpsi, H, psi)
end

function jump_dynamic(rng, t, psi, f, psi_new, rates)
function jump_dynamic(rng, t, psi, f::F, psi_new, probs_tmp, rates) where {F}
result = f(t, psi)
QO_CHECKS[] && @assert 3 <= length(result) <= 4
J = result[2]
Expand All @@ -263,7 +295,7 @@ function jump_dynamic(rng, t, psi, f, psi_new, rates)
else
rates_ = result[4]
end
jump(rng, t, psi, J, psi_new, rates_)
jump(rng, t, psi, J, psi_new, probs_tmp, rates_)
end

"""
Expand All @@ -289,15 +321,15 @@ Integrate a single Monte Carlo wave function trajectory.
an initial jump threshold. If provided, `seed` is ignored.
* `kwargs`: Further arguments are passed on to the ode solver.
"""
function integrate_mcwf(dmcwf, jumpfun, tspan,
psi0, seed, fout::Function;
function integrate_mcwf(dmcwf::T, jumpfun::J, tspan,
psi0, seed, fout;
display_beforeevent=false, display_afterevent=false,
display_jumps=false,
rng_state=nothing,
save_everystep=false, callback=nothing,
saveat=tspan,
alg=OrdinaryDiffEq.DP5(),
kwargs...)
kwargs...) where {T, J}

tspan_ = convert(Vector{float(eltype(tspan))}, tspan)
# Display before or after events
Expand All @@ -308,29 +340,33 @@ function integrate_mcwf(dmcwf, jumpfun, tspan,
affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
return nothing
end
save_before! = display_beforeevent ? save_func! : (affect!,integrator)->nothing
save_after! = display_afterevent ? save_func! : (affect!,integrator)->nothing
no_save_func!(affect!,integrator) = nothing
save_before! = display_beforeevent ? save_func! : no_save_func!
save_after! = display_afterevent ? save_func! : no_save_func!

# Display jump operator index and times
jump_t = eltype(tspan_)[]
jump_index = Int[]
save_t_index = if display_jumps
function(t,i)
push!(jump_t,t)
push!(jump_index,i)
return nothing
end
else
(t,i)->nothing
end

function fout_(x, t, integrator)
recast!(state,x)
fout(t, state)
function jump_saver(t, i)
push!(jump_t,t)
push!(jump_index,i)
return nothing
end
no_jump_saver(t, i) = nothing

save_t_index = display_jumps ? jump_saver : no_jump_saver

state = copy(psi0)
dstate = copy(psi0)

fout_ = let state = state, fout = fout
function fout_(x, t, integrator)
recast!(state,x)
fout(t, state)
end
end

out_type = pure_inference(fout, Tuple{eltype(tspan_),typeof(state)})
out = DiffEqCallbacks.SavedValues(eltype(tspan_),out_type)
scb = DiffEqCallbacks.SavingCallback(fout_,out,saveat=tspan_,
Expand All @@ -340,11 +376,14 @@ function integrate_mcwf(dmcwf, jumpfun, tspan,
cb = jump_callback(jumpfun, seed, scb, save_before!, save_after!, save_t_index, psi0, rng_state)
full_cb = OrdinaryDiffEq.CallbackSet(callback,cb,scb)

function df_(dx, x, p, t)
recast!(state,x)
recast!(dstate,dx)
dmcwf(t, state, dstate)
recast!(dx,dstate)
df_ = let state = state, dstate = dstate # help inference along
function df_(dx, x, p, t)
recast!(state,x)
recast!(dstate,dx)
dmcwf(t, state, dstate)
recast!(dx,dstate)
return nothing
end
end

prob = OrdinaryDiffEq.ODEProblem{true}(df_, as_vector(psi0), (tspan_[1],tspan_[end]))
Expand Down Expand Up @@ -396,8 +435,8 @@ end
roll!(s::JumpRNGState{T}) where T = (s.threshold = rand(s.rng, T))
threshold(s::JumpRNGState) = s.threshold

function jump_callback(jumpfun, seed, scb, save_before!,
save_after!, save_t_index, psi0, rng_state::JumpRNGState)
function jump_callback(jumpfun::F, seed, scb, save_before!::G,
save_after!::H, save_t_index::I, psi0, rng_state::JumpRNGState) where {F,G,H,I}

tmp = copy(psi0)
psi_tmp = copy(psi0)
Expand Down Expand Up @@ -431,7 +470,7 @@ jump_callback(jumpfun, seed, scb, save_before!,
as_vector(psi::StateVector) = psi.data

"""
jump(rng, t, psi, J, psi_new)
jump(rng, t, psi, J, psi_new, probs_tmp)

Default jump function.

Expand All @@ -441,41 +480,52 @@ Default jump function.
* `psi`: State vector before the jump.
* `J`: List of jump operators.
* `psi_new`: Result of jump.
* `probs_tmp`: Temporary array for holding jump probailities.
"""
function jump(rng, t, psi, J, psi_new, rates::Nothing)
function jump(rng, t, psi, J, psi_new, probs_tmp, rates::Nothing)
if length(J)==1
QuantumOpticsBase.mul!(psi_new,J[1],psi,true,false)
psi_new.data ./= norm(psi_new)
i=1
else
probs = zeros(real(eltype(psi)), length(J))
for i=1:length(J)
QuantumOpticsBase.mul!(psi_new,J[i],psi,true,false)
probs[i] = real(dot(psi_new.data, psi_new.data))
probs_tmp[i] = real(dot(psi_new.data, psi_new.data))
end
cumprobs = cumsum(probs./sum(probs))
r = rand(rng)
i = findfirst(cumprobs.>r)
QuantumOpticsBase.mul!(psi_new,J[i],psi,one(eltype(psi))/sqrt(probs[i]),zero(eltype(psi)))
total = sum(probs_tmp)
cumulative_prob = 0.0
i = 0
for p in probs_tmp
i += 1
cumulative_prob += p / total
cumulative_prob > r && break
end
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(1/sqrt(probs_tmp[i])),zero(eltype(psi)))
end
return i
end

function jump(rng, t, psi, J, psi_new, rates::AbstractVector)
function jump(rng, t, psi, J, psi_new, probs_tmp, rates::AbstractVector)
if length(J)==1
QuantumOpticsBase.mul!(psi_new,J[1],psi,eltype(psi)(sqrt(rates[1])),zero(eltype(psi)))
psi_new.data ./= norm(psi_new)
i=1
else
probs = zeros(real(eltype(psi)), length(J))
for i=1:length(J)
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(sqrt(rates[i])),zero(eltype(psi)))
probs[i] = real(dot(psi_new.data, psi_new.data))
probs_tmp[i] = real(dot(psi_new.data, psi_new.data))
end
cumprobs = cumsum(probs./sum(probs))
r = rand(rng)
i = findfirst(cumprobs.>r)
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(sqrt(rates[i]/probs[i])),zero(eltype(psi)))
total = sum(probs_tmp)
cumulative_prob = 0.0
i = 0
for p in probs_tmp
i += 1
cumulative_prob += p / total
cumulative_prob > r && break
end
QuantumOpticsBase.mul!(psi_new,J[i],psi,eltype(psi)(sqrt(rates[i]/probs_tmp[i])),zero(eltype(psi)))
end
return i
end
Expand Down
10 changes: 6 additions & 4 deletions src/schroedinger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Integrate Schroedinger equation to evolve states or compute propagators.
therefore must not be changed.
"""
function schroedinger(tspan, psi0::T, H::AbstractOperator{B,B};
fout::Union{Function,Nothing}=nothing,
fout=nothing,
kwargs...) where {B,Bo,T<:Union{AbstractOperator{B,Bo},StateVector{B}}}
_check_const(H)
dschroedinger_(t, psi, dpsi) = dschroedinger!(dpsi, H, psi)
Expand Down Expand Up @@ -44,9 +44,11 @@ Integrate time-dependent Schroedinger equation to evolve states or compute propa
Instead of a function `f`, this takes a time-dependent operator `H`.
"""
function schroedinger_dynamic(tspan, psi0, f;
fout::Union{Function,Nothing}=nothing,
fout=nothing,
kwargs...)
dschroedinger_(t, psi, dpsi) = dschroedinger_dynamic!(dpsi, f, psi, t)
dschroedinger_ = let f = f
dschroedinger_(t, psi, dpsi) = dschroedinger_dynamic!(dpsi, f, psi, t)
end
tspan, psi0 = _promote_time_and_state(psi0, f, tspan) # promote only if ForwardDiff.Dual
x0 = psi0.data
state = copy(psi0)
Expand Down Expand Up @@ -105,7 +107,7 @@ Schrödinger equation as `-im*H*psi`.

See also: [`dschroedinger!`](@ref)
"""
function dschroedinger_dynamic!(dpsi, f, psi, t)
function dschroedinger_dynamic!(dpsi, f::F, psi, t) where {F}
H = f(t, psi)
dschroedinger!(dpsi, H, psi)
end
Expand Down
Loading
Loading