Skip to content

Commit

Permalink
Merge pull request #47 from kestrelquantum/feat_global-data
Browse files Browse the repository at this point in the history
Feat global data
  • Loading branch information
aarontrowbridge authored Aug 14, 2024
2 parents 68f193c + ced6184 commit fea058c
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 173 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NamedTrajectories"
uuid = "538bc3a1-5ab9-4fc3-b776-35ca1e893e08"
authors = ["Aaron Trowbridge <aaron.j.trowbridge@gmail.com> and contributors"]
version = "0.1.8"
version = "0.1.9"

[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Expand Down
235 changes: 146 additions & 89 deletions src/methods_named_trajectory.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module MethodsNamedTrajectory

export vec
export get_components
export get_component_names
export add_component!
export remove_component
export remove_components
Expand All @@ -15,6 +17,9 @@ using OrderedCollections
using ..StructNamedTrajectory
using ..StructKnotPoint

# -------------------------------------------------------------- #
# Base indexing
# -------------------------------------------------------------- #

function StructKnotPoint.KnotPoint(
Z::NamedTrajectory,
Expand All @@ -25,9 +30,93 @@ function StructKnotPoint.KnotPoint(
return KnotPoint(t, Z.data[:, t], timestep, Z.components, Z.names, Z.control_names)
end

"""
getindex(traj, t::Int)::KnotPoint
Returns the knot point at time `t`.
"""
Base.getindex(traj::NamedTrajectory, t::Int) = KnotPoint(traj, t)

"""
getindex(traj, ts::AbstractVector{Int})::Vector{KnotPoint}
Returns the knot points at times `ts`.
"""
function Base.getindex(traj::NamedTrajectory, ts::AbstractVector{Int})::Vector{KnotPoint}
return [traj[t] for t ts]
end

"""
lastindex(traj::NamedTrajectory)
Returns the final time index of the trajectory.
"""
Base.lastindex(traj::NamedTrajectory) = traj.T

"""
getindex(traj, symb::Symbol)
Dispatches indexing of trajectories as either accessing a component or a property via `getproperty`.
"""
Base.getindex(traj::NamedTrajectory, symb::Symbol) = getproperty(traj, symb)

"""
getproperty(traj, symb::Symbol)
Returns the component of the trajectory with name `symb` or the property of the trajectory with name `symb`.
"""
function Base.getproperty(traj::NamedTrajectory, symb::Symbol)
if symb fieldnames(NamedTrajectory)
return getfield(traj, symb)
else
indices = traj.components[symb]
return traj.data[indices, :]
end
end

"""
setproperty!(traj, name::Symbol, val::Any)
Dispatches setting properties of trajectories as either setting a component or a property via `setfield!` or `update!`.
"""
function Base.setproperty!(traj::NamedTrajectory, symb::Symbol, val::Any)
if symb fieldnames(NamedTrajectory)
setfield!(traj, symb, val)
else
update!(traj, symb, val)
end
end

# -------------------------------------------------------------- #
# Base operations
# -------------------------------------------------------------- #

"""
vec(::NamedTrajectory)
Returns all variables of the trajectory as a vector, Z⃗.
"""
function Base.vec(Z::NamedTrajectory)
return vcat(Z.datavec, values(Z.global_data)...)
end

"""
length(::NamedTrajectory)
Returns the length of all variables of the trajectory, including global data.
"""
function Base.length(Z::NamedTrajectory)
return Z.dim * Z.T + Z.global_dim
end

"""
size(traj::NamedTrajectory) = (dim = traj.dim, T = traj.T)
Returns the size of the trajectory (dim, T), excluding global data.
TODO: Should global data be in size?
"""
Base.size(traj::NamedTrajectory) = (dim = traj.dim, T = traj.T)

"""
copy(::NamedTrajectory)
Expand All @@ -52,7 +141,6 @@ function Base.isequal(traj1::NamedTrajectory, traj2::NamedTrajectory)
end
end


"""
:(==)(traj1::NamedTrajectory, traj2::NamedTrajectory)
Expand All @@ -67,8 +155,31 @@ function Base.:(==)(traj1::NamedTrajectory, traj2::NamedTrajectory)
end
end

function Base.:*::Float64, traj::NamedTrajectory)
return NamedTrajectory* traj.datavec, traj)
end

function Base.:*(traj::NamedTrajectory, α::Float64)
return NamedTrajectory* traj.datavec, traj)
end

function Base.:+(traj1::NamedTrajectory, traj2::NamedTrajectory)
@assert traj1.names == traj2.names
@assert traj1.dim == traj2.dim
@assert traj1.T == traj2.T
return NamedTrajectory(traj1.datavec + traj2.datavec, traj1)
end

function Base.:-(traj1::NamedTrajectory, traj2::NamedTrajectory)
@assert traj1.names == traj2.names
@assert traj1.dim == traj2.dim
@assert traj1.T == traj2.T
return NamedTrajectory(traj1.datavec - traj2.datavec, traj1)
end

# -------------------------------------------------------------- #
# Methods
# -------------------------------------------------------------- #

"""
get_components(::NamedTrajectory)
Expand All @@ -83,11 +194,36 @@ end

get_components(traj::NamedTrajectory) = get_components(traj.names, traj)

function filter_by_value(f::Function, nt::NamedTuple)
return (; (k => v for (k, v) in pairs(nt) if f(v))...)
end

"""
get_component_names(traj::NamedTrajectory, comps::AbstractVector{<:Int})
Returns the name of the component with the given indices. If only one component is found,
the name is returned as a single symbol. Else, the names are returned as a vector of symbols.
The filter requires that the components are a complete subset of the given indices, so that
a partial match is excluded from the returned names.
"""
function get_component_names(traj::NamedTrajectory, comps::AbstractVector{<:Int})
name = [n for n keys(filter_by_value(x -> issubset(x, comps), traj.components)) if n traj.names]
if isempty(name)
error("Component names not found in traj")
elseif length(name) == 1
return name[1]
else
return name
end
end

"""
add_component!(traj, name::Symbol, data::AbstractVecOrMat; type={:state, :control})
Add a component to the trajectory.
NOTE: This function resizes the trajectory, so global components and components must be adjusted.
"""
function add_component!(
traj::NamedTrajectory,
Expand Down Expand Up @@ -166,6 +302,15 @@ function add_component!(

traj.datavec = vec(view(traj.data, :, :))

# update global data

global_comps_pairs::Vector{Pair{Symbol, AbstractVector{Int}}} = []
for (k, v) pairs(traj.global_components)
# increase offset for new components
push!(global_comps_pairs, k => v .+ dim * traj.T)
end
traj.global_components = NamedTuple(global_comps_pairs)

return nothing
end

Expand Down Expand Up @@ -278,8 +423,6 @@ function update_bound!(traj::NamedTrajectory, name::Symbol, new_bound::BoundType
return nothing
end



"""
get_times(traj)::Vector{Float64}
Expand Down Expand Up @@ -315,91 +458,5 @@ function get_duration(traj::NamedTrajectory)
return get_times(traj)[end]
end

"""
size(traj::NamedTrajectory) = (dim = traj.dim, T = traj.T)
"""
Base.size(traj::NamedTrajectory) = (dim = traj.dim, T = traj.T)

"""
getindex(traj, t::Int)::KnotPoint
Returns the knot point at time `t`.
"""
Base.getindex(traj::NamedTrajectory, t::Int) = KnotPoint(traj, t)

"""
getindex(traj, ts::AbstractVector{Int})::Vector{KnotPoint}
Returns the knot points at times `ts`.
"""
function Base.getindex(traj::NamedTrajectory, ts::AbstractVector{Int})::Vector{KnotPoint}
return [traj[t] for t ts]
end

"""
lastindex(traj::NamedTrajectory)
Returns the final time index of the trajectory.
"""
Base.lastindex(traj::NamedTrajectory) = traj.T

"""
getindex(traj, symb::Symbol)
Dispatches indexing of trajectories as either accessing a component or a property via `getproperty`.
"""
Base.getindex(traj::NamedTrajectory, symb::Symbol) = getproperty(traj, symb)

"""
getproperty(traj, symb::Symbol)
Returns the component of the trajectory with name `symb` or the property of the trajectory with name `symb`.
"""
function Base.getproperty(traj::NamedTrajectory, symb::Symbol)
if symb fieldnames(NamedTrajectory)
return getfield(traj, symb)
else
indices = traj.components[symb]
return traj.data[indices, :]
end
end

"""
setproperty!(traj, name::Symbol, val::Any)
Dispatches setting properties of trajectories as either setting a component or a property via `setfield!` or `update!`.
"""
function Base.setproperty!(traj::NamedTrajectory, symb::Symbol, val::Any)
if symb fieldnames(NamedTrajectory)
setfield!(traj, symb, val)
else
update!(traj, symb, val)
end
end



function Base.:*::Float64, traj::NamedTrajectory)
return NamedTrajectory* traj.datavec, traj)
end

function Base.:*(traj::NamedTrajectory, α::Float64)
return NamedTrajectory* traj.datavec, traj)
end

function Base.:+(traj1::NamedTrajectory, traj2::NamedTrajectory)
@assert traj1.names == traj2.names
@assert traj1.dim == traj2.dim
@assert traj1.T == traj2.T
return NamedTrajectory(traj1.datavec + traj2.datavec, traj1)
end

function Base.:-(traj1::NamedTrajectory, traj2::NamedTrajectory)
@assert traj1.names == traj2.names
@assert traj1.dim == traj2.dim
@assert traj1.T == traj2.T
return NamedTrajectory(traj1.datavec - traj2.datavec, traj1)
end


end
Loading

0 comments on commit fea058c

Please sign in to comment.