Skip to content

Commit

Permalink
new scorecritic single msg (#5)
Browse files Browse the repository at this point in the history
* new scorecritic single msg

* Vince fix

* remove duplicate function

* linting
  • Loading branch information
turtlewizard73 authored Dec 16, 2024
1 parent ee12719 commit a97c9ef
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 75 deletions.
1 change: 1 addition & 0 deletions nav2_mppi_controller/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ endforeach()
nav2_package()

rosidl_generate_interfaces(${PROJECT_NAME}
"msg/CriticScore.msg"
"msg/CriticScores.msg"
DEPENDENCIES std_msgs
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "nav2_mppi_controller/tools/trajectory_visualizer.hpp"
#include "nav2_mppi_controller/models/constraints.hpp"
#include "nav2_mppi_controller/tools/utils.hpp"
#include "nav2_mppi_controller/msg/critic_score.hpp"
#include "nav2_mppi_controller/msg/critic_scores.hpp"

#include "nav2_core/controller.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ class CriticManager
*/
void evalTrajectoriesScores(CriticData & data) const;

xt::xtensor<float, 1> evalTrajectory(CriticData & data) const;

std::vector<std::string> getCriticNames() const;

protected:
Expand Down
2 changes: 2 additions & 0 deletions nav2_mppi_controller/msg/CriticScore.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
std_msgs/String name
std_msgs/Float32 score
5 changes: 3 additions & 2 deletions nav2_mppi_controller/msg/CriticScores.msg
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
std_msgs/Header header # ROS time that this log message was sent.
std_msgs/String[] critic_names
std_msgs/Float32[] critic_scores
CriticScore[] critic_scores
# std_msgs/String[] critic_names
# std_msgs/Float32[] critic_scores
26 changes: 16 additions & 10 deletions nav2_mppi_controller/src/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ void MPPIController::cleanup()

void MPPIController::activate()
{
critics_publisher_->on_activate();
if (publish_critics_) critics_publisher_->on_activate();
trajectory_visualizer_.on_activate();
parameters_handler_->start();
RCLCPP_INFO(logger_, "Activated MPPI Controller: %s", name_.c_str());
}

void MPPIController::deactivate()
{
critics_publisher_->on_deactivate();
if (publish_critics_) critics_publisher_->on_deactivate();
trajectory_visualizer_.on_deactivate();
RCLCPP_INFO(logger_, "Deactivated MPPI Controller: %s", name_.c_str());
}
Expand Down Expand Up @@ -124,20 +124,26 @@ geometry_msgs::msg::TwistStamped MPPIController::computeVelocityCommands(

// log critic names and costs
for (size_t i = 0; i < critic_names.size(); i++) {
RCLCPP_INFO(logger_, "Critic: %s, Cost: %f", critic_names[i].c_str(), critic_costs[i]);
RCLCPP_DEBUG(logger_, "Critic: %s, Cost: %f", critic_names[i].c_str(), critic_costs[i]);
}

// make msg
auto critic_scores_ = std::make_unique<nav2_mppi_controller::msg::CriticScores>();
for (size_t i = 0; i < critic_names.size(); i++) {
std_msgs::msg::String name_msg;
name_msg.data = critic_names[i];
critic_scores_->critic_names.push_back(std::move(name_msg));
if (critic_names.size() != critic_costs.size()) {
RCLCPP_ERROR(
logger_,
"Critic names %ld and costs %ld size mismatch!",
critic_names.size(), critic_costs.size());
return cmd;
}

std_msgs::msg::Float32 cost_msg;
cost_msg.data = critic_costs[i];
critic_scores_->critic_scores.push_back(std::move(cost_msg));
for (size_t i = 0; i < critic_names.size(); i++) {
nav2_mppi_controller::msg::CriticScore critic_score;
critic_score.name.data = critic_names[i];
critic_score.score.data = critic_costs[i];
critic_scores_->critic_scores.push_back(critic_score);
}

critic_scores_->header.stamp = clock_->now();
critics_publisher_->publish(std::move(critic_scores_));
}
Expand Down
24 changes: 0 additions & 24 deletions nav2_mppi_controller/src/critic_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,28 +81,4 @@ void CriticManager::evalTrajectoriesScores(
critics_[q]->score(data);
}
}

xt::xtensor<float, 1> CriticManager::evalTrajectory(
CriticData & data) const
{
// log the critic_scores shape
RCLCPP_INFO(logger_, "(BEFORE FOR)");

xt::xtensor<float, 1> critic_scores = xt::zeros<float>({critics_.size()});

for (size_t q = 0; q < critics_.size(); q++) {
if (data.fail_flag) {
break;
}
data.costs = xt::zeros<float>({1});
// log costs values
critics_[q]->score(data);
critic_scores(q) = data.costs[0];
}
// log the for cycle finished in criticmanager
RCLCPP_INFO(logger_, "CriticManager: Critic evaluation (FOR) finished");

return critic_scores;
}

} // namespace mppi
55 changes: 18 additions & 37 deletions nav2_mppi_controller/src/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,58 +163,39 @@ void Optimizer::optimize()

xt::xtensor<float, 1> Optimizer::getOptimizationResults()
{
// get the final optimized trajectory
const xt::xtensor<float, 2> optimized_trajectory = getOptimizedTrajectory();
xt::xtensor<float, 1> costs = xt::zeros<float>({1});

/*auto size = optimized_trajectory.size(); // size = 6
auto dim = optimized_trajectory.dimension(); // dim = 2
auto shape = optimized_trajectory.shape(); // shape = {2, 3}
//log size, dim, and shape
RCLCPP_INFO(
logger_, "getOptimizedTrajectory() size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

// evalTrajectory evals multiple trajectories, but we only have one
// create a dummy_trajectories object and put the optimized in it
models::Trajectories dummy_trajectories;
/*size = dummy_trajectories.x.size();
dim = dummy_trajectories.x.dimension();
shape = dummy_trajectories.x.shape();
RCLCPP_INFO(
logger_, "[after creation] dummy_trajectories size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

dummy_trajectories.reset(1, settings_.time_steps);
/*size = dummy_trajectories.x.size();
dim = dummy_trajectories.x.dimension();
shape = dummy_trajectories.x.shape();
RCLCPP_INFO(
logger_, "[after reset] dummy_trajectories size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

dummy_trajectories.x += xt::view(optimized_trajectory, xt::all(), 0);
dummy_trajectories.y += xt::view(optimized_trajectory, xt::all(), 1);
dummy_trajectories.yaws += xt::view(optimized_trajectory, xt::all(), 2);

/*size = dummy_trajectories.x.size();
dim = dummy_trajectories.x.dimension();
shape = dummy_trajectories.x.shape();
RCLCPP_INFO(
logger_, "[after valuedump] dummy_trajectories size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

// create a dummy_data object to pass to evalTrajectory
CriticData dummy_data = {
state_, dummy_trajectories, path_, costs, settings_.model_dt,
false, critics_data_.goal_checker, critics_data_.motion_model, std::nullopt, std::nullopt};
// dummy_data.goal_checker = critics_data_.goal_checker;
// dummy_data.motion_model = critics_data_.motion_model;
dummy_data.furthest_reached_path_point.reset();
dummy_data.path_pts_valid.reset();

/*RCLCPP_INFO(
logger_, "dummy_data type: %s",
typeid(dummy_data).name());*/
// use evalTrajectoriesScores
critic_manager_.evalTrajectoriesScores(dummy_data);

size_t num_critics = critic_manager_.getCriticNames().size();

xt::xtensor<float, 1> critic_scores = xt::zeros<float>(std::vector<size_t>{num_critics});
for (size_t i = 0; i < num_critics; i++) {
critic_scores(i) = dummy_data.costs(0); // Assuming costs are updated for each critic
}

return critic_scores;

return critic_manager_.evalTrajectory(dummy_data);
// evaluate the optimized trajectory
// return critic_manager_.evalTrajectory(dummy_data);
}

std::vector<std::string> Optimizer::getCriticNames() const
Expand Down Expand Up @@ -496,7 +477,7 @@ void Optimizer::setSpeedLimit(double speed_limit, bool percentage)
s.constraints.vx_min = s.base_constraints.vx_min * ratio;
s.constraints.vy = s.base_constraints.vy * ratio;
s.constraints.wz = s.base_constraints.wz * ratio;
} else {
} else if (speed_limit < s.base_constraints.vx_max) {
// Speed limit is expressed in absolute value
double ratio = speed_limit / s.base_constraints.vx_max;
s.constraints.vx_max = s.base_constraints.vx_max * ratio;
Expand Down

0 comments on commit a97c9ef

Please sign in to comment.