Skip to content

Commit

Permalink
feat: data interpolations
Browse files Browse the repository at this point in the history
  • Loading branch information
andgoldschmidt authored and aarontrowbridge committed Sep 9, 2024
1 parent 0ed62fc commit 9d4d46d
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 5 deletions.
4 changes: 2 additions & 2 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.3"
julia_version = "1.10.4"
manifest_format = "2.0"
project_hash = "ace2e1df9d9dbcf666fc13e87dfbf3d8881dd7d1"
project_hash = "1d025dab4c6dd8f7ccf572cc68ecdd4a371d25d2"

[[deps.ADTypes]]
git-tree-sha1 = "99a6f5d0ce1c7c6afdb759daa30226f71c54f6b0"
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Einsum = "b7d42ee7-0b51-5a75-98ca-779d3107e4c0"
ExponentialAction = "e24c0720-ea99-47e8-929e-571b494574d3"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand Down
3 changes: 3 additions & 0 deletions src/QuantumCollocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ include("rollouts.jl")
include("trajectory_initialization.jl")
@reexport using .TrajectoryInitialization

include("trajectory_interpolations.jl")
@reexport using .TrajectoryInterpolation

include("problem_templates/_problem_templates.jl")
@reexport using .ProblemTemplates

Expand Down
119 changes: 119 additions & 0 deletions src/trajectory_interpolations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
module TrajectoryInterpolation

export DataInterpolation

using NamedTrajectories

using Interpolations: Extrapolation, constant_interpolation, linear_interpolation
using TestItemRunner


struct DataInterpolation
times::Vector{Float64}
values::Matrix{Float64}
interpolants::Vector{Extrapolation}
timestep_components::Vector{Int}
values_components::Vector{Int}

function DataInterpolation(
times::AbstractVector{Float64}, values::AbstractMatrix{Float64};
timestep_components::AbstractVector{Int}=Int[], kind::Symbol=:linear
)
comps = setdiff(1:size(values, 1), timestep_components)
if kind == :linear
interpolants = [linear_interpolation(times, values[c, :]) for c in comps]
elseif kind == :constant
interpolants = [constant_interpolation(times, values[c, :]) for c in comps]
else
error("Unknown interpolation kind: $kind")
end
return new(times, values, interpolants, timestep_components, comps)
end

function DataInterpolation(
T::Int, Δt::Real, values::AbstractMatrix{Float64}; kwargs...
)
times = range(0, Δt * (T - 1), step=Δt)
return DataInterpolation(times, values; kwargs...)
end

function DataInterpolation(
traj::NamedTrajectory; timestep_symbol::Symbol=:Δt, kwargs...
)
if timestep_symbol keys(traj.components)
timestep_components = traj.components[timestep_symbol]
else
timestep_components = Int[]
end
return DataInterpolation(
get_times(traj), traj.data; timestep_components=timestep_components, kwargs...
)
end
end

function (traj_int::DataInterpolation)(times::AbstractVector)
values = zeros(eltype(traj_int.values), size(traj_int.values, 1), length(times))
for (c, interp) in zip(traj_int.values_components, traj_int.interpolants)
values[c, :] = interp(times)
end
if !isempty(traj_int.timestep_components)
timesteps = times[2:end] .- times[1:end-1]
# NOTE: Arbitrary choice of the last timestep
values[traj_int.timestep_components, :] = vcat(timesteps, timesteps[end])
end
return values
end

function (traj_int::DataInterpolation)(T::Int)
times = range(traj_int.times[1], traj_int.times[end], length=T)
return traj_int(times)
end

# =========================================================================

@testitem "Trajectory interpolation test" begin
include("../test/test_utils.jl")

# fixed time
traj = named_trajectory_type_1()

interp = DataInterpolation(traj)
new_data = interp(get_times(traj))
@test new_data traj.data

new_data = interp(2 * traj.T)
@test size(new_data) == (size(traj.data, 1), 2 * traj.T)

# free time
free_traj = named_trajectory_type_1(free_time=true)

interp = DataInterpolation(free_traj)
new_free_data = interp(get_times(traj))

# Replace the final timestep with the original value (can't be known a priori)
new_free_data[free_traj.components.Δt, end] = free_traj.data[free_traj.components.Δt, end]
@test new_free_data free_traj.data

new_free_data = interp(2 * traj.T)
@test size(new_free_data) == (size(free_traj.data, 1), 2 * traj.T)
end

@testitem "Component interpolation test" begin
include("../test/test_utils.jl")

traj = named_trajectory_type_1()

# interpolate with times
interp_val1 = DataInterpolation(get_times(traj), traj.a)
@test size(interp_val1(2 * traj.T)) == (size(traj.a, 1), 2 * traj.T)

# interpolate with steps
interp_val2 = DataInterpolation(traj.T, traj.timestep, traj.a)
@test size(interp_val2(3 * traj.T)) == (size(traj.a, 1), 3 * traj.T)

# check if times match
@test interp_val1.times interp_val2.times
end


end
5 changes: 2 additions & 3 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""
utility functions for debugging tests
"""
using NamedTrajectories


"""
dense(vals, structure, shape)
Expand Down

0 comments on commit 9d4d46d

Please sign in to comment.