diff --git a/Manifest.toml b/Manifest.toml index bde5724f..a2abdfcb 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.10.3" +julia_version = "1.10.4" manifest_format = "2.0" -project_hash = "ace2e1df9d9dbcf666fc13e87dfbf3d8881dd7d1" +project_hash = "1d025dab4c6dd8f7ccf572cc68ecdd4a371d25d2" [[deps.ADTypes]] git-tree-sha1 = "99a6f5d0ce1c7c6afdb759daa30226f71c54f6b0" diff --git a/Project.toml b/Project.toml index 1ea93c3a..2b19dd2a 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Einsum = "b7d42ee7-0b51-5a75-98ca-779d3107e4c0" ExponentialAction = "e24c0720-ea99-47e8-929e-571b494574d3" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" +Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" diff --git a/src/QuantumCollocation.jl b/src/QuantumCollocation.jl index bf22a5a8..d5ce2604 100644 --- a/src/QuantumCollocation.jl +++ b/src/QuantumCollocation.jl @@ -57,6 +57,9 @@ include("rollouts.jl") include("trajectory_initialization.jl") @reexport using .TrajectoryInitialization +include("trajectory_interpolations.jl") +@reexport using .TrajectoryInterpolation + include("problem_templates/_problem_templates.jl") @reexport using .ProblemTemplates diff --git a/src/trajectory_interpolations.jl b/src/trajectory_interpolations.jl new file mode 100644 index 00000000..365a9b43 --- /dev/null +++ b/src/trajectory_interpolations.jl @@ -0,0 +1,119 @@ +module TrajectoryInterpolation + +export DataInterpolation + +using NamedTrajectories + +using Interpolations: Extrapolation, constant_interpolation, linear_interpolation +using TestItemRunner + + +struct DataInterpolation + times::Vector{Float64} + values::Matrix{Float64} + interpolants::Vector{Extrapolation} + timestep_components::Vector{Int} + values_components::Vector{Int} + + function DataInterpolation( + times::AbstractVector{Float64}, values::AbstractMatrix{Float64}; + timestep_components::AbstractVector{Int}=Int[], kind::Symbol=:linear + ) + comps = setdiff(1:size(values, 1), timestep_components) + if kind == :linear + interpolants = [linear_interpolation(times, values[c, :]) for c in comps] + elseif kind == :constant + interpolants = [constant_interpolation(times, values[c, :]) for c in comps] + else + error("Unknown interpolation kind: $kind") + end + return new(times, values, interpolants, timestep_components, comps) + end + + function DataInterpolation( + T::Int, Δt::Real, values::AbstractMatrix{Float64}; kwargs... + ) + times = range(0, Δt * (T - 1), step=Δt) + return DataInterpolation(times, values; kwargs...) + end + + function DataInterpolation( + traj::NamedTrajectory; timestep_symbol::Symbol=:Δt, kwargs... + ) + if timestep_symbol ∈ keys(traj.components) + timestep_components = traj.components[timestep_symbol] + else + timestep_components = Int[] + end + return DataInterpolation( + get_times(traj), traj.data; timestep_components=timestep_components, kwargs... + ) + end +end + +function (traj_int::DataInterpolation)(times::AbstractVector) + values = zeros(eltype(traj_int.values), size(traj_int.values, 1), length(times)) + for (c, interp) in zip(traj_int.values_components, traj_int.interpolants) + values[c, :] = interp(times) + end + if !isempty(traj_int.timestep_components) + timesteps = times[2:end] .- times[1:end-1] + # NOTE: Arbitrary choice of the last timestep + values[traj_int.timestep_components, :] = vcat(timesteps, timesteps[end]) + end + return values +end + +function (traj_int::DataInterpolation)(T::Int) + times = range(traj_int.times[1], traj_int.times[end], length=T) + return traj_int(times) +end + +# ========================================================================= + +@testitem "Trajectory interpolation test" begin + include("../test/test_utils.jl") + + # fixed time + traj = named_trajectory_type_1() + + interp = DataInterpolation(traj) + new_data = interp(get_times(traj)) + @test new_data ≈ traj.data + + new_data = interp(2 * traj.T) + @test size(new_data) == (size(traj.data, 1), 2 * traj.T) + + # free time + free_traj = named_trajectory_type_1(free_time=true) + + interp = DataInterpolation(free_traj) + new_free_data = interp(get_times(traj)) + + # Replace the final timestep with the original value (can't be known a priori) + new_free_data[free_traj.components.Δt, end] = free_traj.data[free_traj.components.Δt, end] + @test new_free_data ≈ free_traj.data + + new_free_data = interp(2 * traj.T) + @test size(new_free_data) == (size(free_traj.data, 1), 2 * traj.T) +end + +@testitem "Component interpolation test" begin + include("../test/test_utils.jl") + + traj = named_trajectory_type_1() + + # interpolate with times + interp_val1 = DataInterpolation(get_times(traj), traj.a) + @test size(interp_val1(2 * traj.T)) == (size(traj.a, 1), 2 * traj.T) + + # interpolate with steps + interp_val2 = DataInterpolation(traj.T, traj.timestep, traj.a) + @test size(interp_val2(3 * traj.T)) == (size(traj.a, 1), 3 * traj.T) + + # check if times match + @test interp_val1.times ≈ interp_val2.times +end + + +end \ No newline at end of file diff --git a/test/test_utils.jl b/test/test_utils.jl index c70c0af7..ce5d309d 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,6 +1,5 @@ -""" - utility functions for debugging tests -""" +using NamedTrajectories + """ dense(vals, structure, shape)