From 1c4d641f7734d7b34162d2a143b2297b45db2ae4 Mon Sep 17 00:00:00 2001 From: Andy Goldschmidt Date: Mon, 11 Nov 2024 20:27:17 -0600 Subject: [PATCH] test unitary fid callback --- src/callbacks.jl | 32 ++++++++++++++++++++++++++++++-- test/test_utils.jl | 17 +++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/callbacks.jl b/src/callbacks.jl index 4e13a0ca..b5e83224 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -34,7 +34,7 @@ function best_rollout_fidelity_callback(prob::QuantumControlProblem) end function best_unitary_rollout_fidelity_callback(prob::QuantumControlProblem) - return best_rollout_callback(prob, fidelity_unitary) + return best_rollout_callback(prob, unitary_fidelity) end function trajectory_history_callback(prob::QuantumControlProblem) @@ -82,7 +82,7 @@ end @test length(trajectory_history) == 21 end -@testitem "Callback can get best trajectory" begin +@testitem "Callback can get best state trajectory" begin using MathOptInterface using NamedTrajectories const MOI = MathOptInterface @@ -110,6 +110,34 @@ end @test after ≤ best end +@testitem "Callback can get best unitary trajectory" begin + using MathOptInterface + using NamedTrajectories + const MOI = MathOptInterface + include("../test/test_utils.jl") + + prob, system = smooth_unitary_problem(return_system=true) + + callback, best_trajs = best_unitary_rollout_fidelity_callback(prob) + @test length(best_trajs) == 0 + + # measure fidelity + before = unitary_fidelity(prob) + solve!(prob, max_iter=20, callback=callback) + + # length must increase if iterations are made + @test length(best_trajs) > 0 + @test best_trajs[end] isa NamedTrajectory + + # fidelity ranking + after = unitary_fidelity(prob) + best = unitary_fidelity(best_trajs[end], system) + + @test before < after + @test before < best + @test after ≤ best +end + @testitem "Callback with full parameter test" begin using MathOptInterface using NamedTrajectories diff --git a/test/test_utils.jl b/test/test_utils.jl index 3e7bc411..f4c3b5e9 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -134,3 +134,20 @@ function smooth_quantum_state_problem(; return_system::Bool=false) return prob end end + +function smooth_unitary_problem(; return_system::Bool=false) + T = 50 + Δt = 0.2 + sys = QuantumSystem(0.1 * PAULIS[:Z], [PAULIS[:X], PAULIS[:Y]]) + U_goal = GATES[:H] + prob = UnitarySmoothPulseProblem( + sys, U_goal, T, Δt; + ipopt_options=IpoptOptions(print_level=1), + piccolo_options=PiccoloOptions(verbose=false) + ) + if return_system + return prob, sys + else + return prob + end +end