Skip to content

Commit

Permalink
add get component names to NT
Browse files Browse the repository at this point in the history
  • Loading branch information
andgoldschmidt committed Aug 7, 2024
1 parent 60084d5 commit ced6184
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 118 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NamedTrajectories"
uuid = "538bc3a1-5ab9-4fc3-b776-35ca1e893e08"
authors = ["Aaron Trowbridge <aaron.j.trowbridge@gmail.com> and contributors"]
version = "0.1.8"
version = "0.1.9"

[deps]
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
Expand Down
211 changes: 120 additions & 91 deletions src/methods_named_trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module MethodsNamedTrajectory

export vec
export get_components
export get_component_names
export add_component!
export remove_component
export remove_components
Expand All @@ -16,6 +17,9 @@ using OrderedCollections
using ..StructNamedTrajectory
using ..StructKnotPoint

# -------------------------------------------------------------- #
# Base indexing
# -------------------------------------------------------------- #

function StructKnotPoint.KnotPoint(
Z::NamedTrajectory,
Expand All @@ -26,6 +30,67 @@ function StructKnotPoint.KnotPoint(
return KnotPoint(t, Z.data[:, t], timestep, Z.components, Z.names, Z.control_names)
end

"""
getindex(traj, t::Int)::KnotPoint
Returns the knot point at time `t`.
"""
Base.getindex(traj::NamedTrajectory, t::Int) = KnotPoint(traj, t)

"""
getindex(traj, ts::AbstractVector{Int})::Vector{KnotPoint}
Returns the knot points at times `ts`.
"""
function Base.getindex(traj::NamedTrajectory, ts::AbstractVector{Int})::Vector{KnotPoint}
return [traj[t] for t ts]
end

"""
lastindex(traj::NamedTrajectory)
Returns the final time index of the trajectory.
"""
Base.lastindex(traj::NamedTrajectory) = traj.T

"""
getindex(traj, symb::Symbol)
Dispatches indexing of trajectories as either accessing a component or a property via `getproperty`.
"""
Base.getindex(traj::NamedTrajectory, symb::Symbol) = getproperty(traj, symb)

"""
getproperty(traj, symb::Symbol)
Returns the component of the trajectory with name `symb` or the property of the trajectory with name `symb`.
"""
function Base.getproperty(traj::NamedTrajectory, symb::Symbol)
if symb fieldnames(NamedTrajectory)
return getfield(traj, symb)
else
indices = traj.components[symb]
return traj.data[indices, :]
end
end

"""
setproperty!(traj, name::Symbol, val::Any)
Dispatches setting properties of trajectories as either setting a component or a property via `setfield!` or `update!`.
"""
function Base.setproperty!(traj::NamedTrajectory, symb::Symbol, val::Any)
if symb fieldnames(NamedTrajectory)
setfield!(traj, symb, val)
else
update!(traj, symb, val)
end
end

# -------------------------------------------------------------- #
# Base operations
# -------------------------------------------------------------- #

"""
vec(::NamedTrajectory)
Expand All @@ -35,16 +100,23 @@ function Base.vec(Z::NamedTrajectory)
return vcat(Z.datavec, values(Z.global_data)...)
end


"""
length(::NamedTrajectory)
Returns the length of all variables of the trajectory.
Returns the length of all variables of the trajectory, including global data.
"""
function Base.length(Z::NamedTrajectory)
return Z.dim * Z.T + Z.global_dim
end

"""
size(traj::NamedTrajectory) = (dim = traj.dim, T = traj.T)
Returns the size of the trajectory (dim, T), excluding global data.
TODO: Should global data be in size?
"""
Base.size(traj::NamedTrajectory) = (dim = traj.dim, T = traj.T)

"""
copy(::NamedTrajectory)
Expand All @@ -69,7 +141,6 @@ function Base.isequal(traj1::NamedTrajectory, traj2::NamedTrajectory)
end
end


"""
:(==)(traj1::NamedTrajectory, traj2::NamedTrajectory)
Expand All @@ -84,8 +155,31 @@ function Base.:(==)(traj1::NamedTrajectory, traj2::NamedTrajectory)
end
end

function Base.:*::Float64, traj::NamedTrajectory)
return NamedTrajectory* traj.datavec, traj)
end

function Base.:*(traj::NamedTrajectory, α::Float64)
return NamedTrajectory* traj.datavec, traj)
end

function Base.:+(traj1::NamedTrajectory, traj2::NamedTrajectory)
@assert traj1.names == traj2.names
@assert traj1.dim == traj2.dim
@assert traj1.T == traj2.T
return NamedTrajectory(traj1.datavec + traj2.datavec, traj1)
end

function Base.:-(traj1::NamedTrajectory, traj2::NamedTrajectory)
@assert traj1.names == traj2.names
@assert traj1.dim == traj2.dim
@assert traj1.T == traj2.T
return NamedTrajectory(traj1.datavec - traj2.datavec, traj1)
end

# -------------------------------------------------------------- #
# Methods
# -------------------------------------------------------------- #

"""
get_components(::NamedTrajectory)
Expand All @@ -100,6 +194,29 @@ end

get_components(traj::NamedTrajectory) = get_components(traj.names, traj)

function filter_by_value(f::Function, nt::NamedTuple)
return (; (k => v for (k, v) in pairs(nt) if f(v))...)
end

"""
get_component_names(traj::NamedTrajectory, comps::AbstractVector{<:Int})
Returns the name of the component with the given indices. If only one component is found,
the name is returned as a single symbol. Else, the names are returned as a vector of symbols.
The filter requires that the components are a complete subset of the given indices, so that
a partial match is excluded from the returned names.
"""
function get_component_names(traj::NamedTrajectory, comps::AbstractVector{<:Int})
name = [n for n keys(filter_by_value(x -> issubset(x, comps), traj.components)) if n traj.names]
if isempty(name)
error("Component names not found in traj")
elseif length(name) == 1
return name[1]
else
return name
end
end

"""
add_component!(traj, name::Symbol, data::AbstractVecOrMat; type={:state, :control})
Expand Down Expand Up @@ -306,8 +423,6 @@ function update_bound!(traj::NamedTrajectory, name::Symbol, new_bound::BoundType
return nothing
end



"""
get_times(traj)::Vector{Float64}
Expand Down Expand Up @@ -343,91 +458,5 @@ function get_duration(traj::NamedTrajectory)
return get_times(traj)[end]
end

"""
size(traj::NamedTrajectory) = (dim = traj.dim, T = traj.T)
"""
Base.size(traj::NamedTrajectory) = (dim = traj.dim, T = traj.T)

"""
getindex(traj, t::Int)::KnotPoint
Returns the knot point at time `t`.
"""
Base.getindex(traj::NamedTrajectory, t::Int) = KnotPoint(traj, t)

"""
getindex(traj, ts::AbstractVector{Int})::Vector{KnotPoint}
Returns the knot points at times `ts`.
"""
function Base.getindex(traj::NamedTrajectory, ts::AbstractVector{Int})::Vector{KnotPoint}
return [traj[t] for t ts]
end

"""
lastindex(traj::NamedTrajectory)
Returns the final time index of the trajectory.
"""
Base.lastindex(traj::NamedTrajectory) = traj.T

"""
getindex(traj, symb::Symbol)
Dispatches indexing of trajectories as either accessing a component or a property via `getproperty`.
"""
Base.getindex(traj::NamedTrajectory, symb::Symbol) = getproperty(traj, symb)

"""
getproperty(traj, symb::Symbol)
Returns the component of the trajectory with name `symb` or the property of the trajectory with name `symb`.
"""
function Base.getproperty(traj::NamedTrajectory, symb::Symbol)
if symb fieldnames(NamedTrajectory)
return getfield(traj, symb)
else
indices = traj.components[symb]
return traj.data[indices, :]
end
end

"""
setproperty!(traj, name::Symbol, val::Any)
Dispatches setting properties of trajectories as either setting a component or a property via `setfield!` or `update!`.
"""
function Base.setproperty!(traj::NamedTrajectory, symb::Symbol, val::Any)
if symb fieldnames(NamedTrajectory)
setfield!(traj, symb, val)
else
update!(traj, symb, val)
end
end



function Base.:*::Float64, traj::NamedTrajectory)
return NamedTrajectory* traj.datavec, traj)
end

function Base.:*(traj::NamedTrajectory, α::Float64)
return NamedTrajectory* traj.datavec, traj)
end

function Base.:+(traj1::NamedTrajectory, traj2::NamedTrajectory)
@assert traj1.names == traj2.names
@assert traj1.dim == traj2.dim
@assert traj1.T == traj2.T
return NamedTrajectory(traj1.datavec + traj2.datavec, traj1)
end

function Base.:-(traj1::NamedTrajectory, traj2::NamedTrajectory)
@assert traj1.names == traj2.names
@assert traj1.dim == traj2.dim
@assert traj1.T == traj2.T
return NamedTrajectory(traj1.datavec - traj2.datavec, traj1)
end


end
Loading

3 comments on commit ced6184

@aarontrowbridge
Copy link
Member

@aarontrowbridge aarontrowbridge commented on ced6184 Aug 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aarontrowbridge
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aarontrowbridge
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.