Skip to content

Commit

Permalink
feat: making sure algebraic states are states. Rolling.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ipuch committed Jan 26, 2025
1 parent 2187a56 commit 4da2bba
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 7 deletions.
8 changes: 5 additions & 3 deletions bioptim/dynamics/configure_new_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def _declare_cx_and_plot(self):
for node_index in range(
self.nlp.n_states_nodes if self.nlp.phase_dynamics == PhaseDynamics.ONE_PER_NODE else 1
):
n_cx = 2
n_cx = self.nlp.ode_solver.n_required_cx + 2
cx_scaled = (
self.ocp.nlp[self.nlp.use_states_from_phase_idx].algebraic_states[node_index][self.name].original_cx
if self.copy_algebraic_states
Expand All @@ -463,14 +463,16 @@ def _declare_cx_and_plot(self):
node_index,
)
if not self.skip_plot:
all_variables_in_one_subplot = True if self.name in ["m", "cov", "k"] else False
all_variables_in_one_subplot = (
True if self.name in ["m", "cov", "k"] else False
) # To Eve: This should not be there.
self.nlp.plot[f"{self.name}_algebraic"] = CustomPlot(
lambda t0, phases_dt, node_idx, x, u, p, a, d: (
a[self.nlp.algebraic_states.key_index(self.name), :]
if a.any()
else np.ndarray((cx[0][0].shape[0], 1)) * np.nan
),
plot_type=PlotType.STEP,
plot_type=PlotType.INTEGRATED,
axes_idx=self.axes_idx,
legend=self.legend,
combine_to=self.combine_name,
Expand Down
14 changes: 11 additions & 3 deletions bioptim/dynamics/integrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, ode: dict, ode_opt: dict):
self._x_sym_modified,
self.u_sym,
self.param_sym,
self.a_sym,
self._a_sym_modified,
self.numerical_timeseries_sym,
],
self.dxdt(
Expand Down Expand Up @@ -130,6 +130,10 @@ def _time_xall_from_dt_func(self) -> Function:
def _x_sym_modified(self):
return self.x_sym

@property
def _a_sym_modified(self):
return self.a_sym

@property
def _input_names(self):
return ["t_span", "x0", "u", "p", "a", "d"]
Expand Down Expand Up @@ -585,6 +589,10 @@ def _initialize(self, ode: dict, ode_opt: dict):
def _x_sym_modified(self):
return horzcat(*self.x_sym) if self.duplicate_starting_point else horzcat(*self.x_sym[1:])

@property
def _a_sym_modified(self):
return horzcat(*self.a_sym) if self.duplicate_starting_point else horzcat(*self.a_sym[1:])

@property
def _output_names(self):
return ["xf", "xall", "defects"]
Expand Down Expand Up @@ -664,7 +672,7 @@ def dxdt(
states[j + 1],
self.get_u(controls, self._integration_time[j]),
params,
algebraic_states,
algebraic_states[j],
numerical_timeseries,
)[:, self.ode_idx]
defects.append(xp_j - f_j * self.h)
Expand All @@ -676,7 +684,7 @@ def dxdt(
states[j + 1],
self.get_u(controls, self._integration_time[j]),
params,
algebraic_states,
algebraic_states[j],
numerical_timeseries,
xp_j / self.h,
)
Expand Down
9 changes: 8 additions & 1 deletion bioptim/dynamics/ode_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,15 @@ def x_ode(self, nlp):
def p_ode(self, nlp):
return nlp.controls.scaled.cx_start

# def a_ode(self, nlp):
# return nlp.algebraic_states.scaled.cx_start

def a_ode(self, nlp):
return nlp.algebraic_states.scaled.cx_start
out = [nlp.algebraic_states.scaled.cx_start]
if not self.duplicate_starting_point:
out += [nlp.algebraic_states.scaled.cx_start]
out += nlp.algebraic_states.scaled.cx_intermediates_list
return out

def d_ode(self, nlp):
return nlp.numerical_timeseries.cx_start
Expand Down
4 changes: 4 additions & 0 deletions bioptim/interfaces/interface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,10 @@ def generic_get_all_penalties(interface, nlp: NonLinearProgram, penalties, scale
tp[: u_tp.shape[0], :] = u_tp
u_tp = tp
u = horzcat(u, u_tp)
if idx != 0 and a_tp.shape[0] != a.shape[0]:
tp = ocp.cx.nan(a.shape[0], 1)
tp[: a_tp.shape[0], :] = a_tp
a_tp = tp
a = horzcat(a, a_tp)
d = horzcat(d, d_tp) if d is not None else d_tp
weight = np.concatenate((weight, [[float(weight_tp)]]), axis=1)
Expand Down

0 comments on commit 4da2bba

Please sign in to comment.