Skip to content

Commit

Permalink
bug fix: tests match initialize signature change
Browse files Browse the repository at this point in the history
  • Loading branch information
andgoldschmidt authored and jack-champagne committed Jul 3, 2024
1 parent b20f026 commit 3abfeba
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
8 changes: 7 additions & 1 deletion src/problem_templates/unitary_direct_sum_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ function UnitaryDirectSumProblem(
@assert N 2 "At least two problems are required"
@assert 0 drive_reset_ratio 1 "drive_reset_ratio must be in [0, 1]"
@assert isempty(intersect(keys(boundary_values), prob_labels)) "Boundary value keys cannot be in prob_labels"
@assert all([:dda p.trajectory.names for p in probs]) "Only smooth pulse problems are supported."
n_derivatives = 2

# Default chain graph and boundary
boundary = Tuple{Symbol, Array}[]
Expand Down Expand Up @@ -102,7 +104,11 @@ function UnitaryDirectSumProblem(
forin prob_labels
a_symb, da_symb, dda_symb = add_suffix(:a, ℓ), add_suffix(:da, ℓ), add_suffix(:dda, ℓ)
a, da, dda = TrajectoryInitialization.initialize_controls(
length(traj.components[a_symb]), traj.T, traj.bounds[a_symb], drive_derivative_σ
length(traj.components[a_symb]),
n_derivatives,
traj.T,
traj.bounds[a_symb],
drive_derivative_σ
)
update!(traj, a_symb, (1 - drive_reset_ratio) * traj[a_symb] + drive_reset_ratio * a)
update!(traj, da_symb, (1 - drive_reset_ratio) * traj[da_symb] + drive_reset_ratio * da)
Expand Down
3 changes: 2 additions & 1 deletion test/trajectory_initialization_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
@testitem "Random drive initialization" begin
T = 10
n_drives = 2
n_derivates = 2
drive_bounds = [1.0, 2.0]
drive_derivative_σ = 0.01

a, da, dda = TrajectoryInitialization.initialize_controls(n_drives, T, drive_bounds, drive_derivative_σ)
a, da, dda = TrajectoryInitialization.initialize_controls(n_drives, n_derivates, T, drive_bounds, drive_derivative_σ)

@test size(a) == (n_drives, T)
@test size(da) == (n_drives, T)
Expand Down

0 comments on commit 3abfeba

Please sign in to comment.