Skip to content

Commit

Permalink
Mppi critic score publisher (#3)
Browse files Browse the repository at this point in the history
* merge enricosutera commits

* revert unknown dockerfile changes

* linting and small fixes

* critic publisher

* add header to criticscores msg
  • Loading branch information
turtlewizard73 authored May 29, 2024
1 parent ade87a2 commit ee12719
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 2 deletions.
14 changes: 14 additions & 0 deletions nav2_mppi_controller/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ set(XTENSOR_USE_XSIMD 1)
find_package(ament_cmake REQUIRED)
find_package(xtensor REQUIRED)
find_package(xsimd REQUIRED)
find_package(rosidl_default_generators REQUIRED)

include_directories(
include
Expand All @@ -27,12 +28,14 @@ set(dependencies_pkgs
geometry_msgs
visualization_msgs
nav_msgs
nav2_msgs
nav2_core
nav2_costmap_2d
nav2_util
tf2_geometry_msgs
tf2_eigen
tf2_ros
std_msgs
)

foreach(pkg IN LISTS dependencies_pkgs)
Expand All @@ -41,6 +44,11 @@ endforeach()

nav2_package()

rosidl_generate_interfaces(${PROJECT_NAME}
"msg/CriticScores.msg"
DEPENDENCIES std_msgs
)

include(CheckCXXCompilerFlag)

check_cxx_compiler_flag("-mno-avx512f" COMPILER_SUPPORTS_AVX512)
Expand Down Expand Up @@ -119,8 +127,14 @@ if(BUILD_TESTING)
# add_subdirectory(benchmark)
endif()

rosidl_get_typesupport_target(cpp_typesupport_target
${PROJECT_NAME} rosidl_typesupport_cpp)

target_link_libraries(mppi_controller "${cpp_typesupport_target}")

ament_export_libraries(${libraries})
ament_export_dependencies(${dependencies_pkgs})
ament_export_dependencies(rosidl_default_runtime)
ament_export_include_directories(include)
pluginlib_export_plugin_description_file(nav2_core mppic.xml)
pluginlib_export_plugin_description_file(nav2_mppi_controller critics.xml)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "nav2_mppi_controller/tools/trajectory_visualizer.hpp"
#include "nav2_mppi_controller/models/constraints.hpp"
#include "nav2_mppi_controller/tools/utils.hpp"
#include "nav2_mppi_controller/msg/critic_scores.hpp"

#include "nav2_core/controller.hpp"
#include "nav2_core/goal_checker.hpp"
Expand Down Expand Up @@ -121,10 +122,14 @@ class MPPIController : public nav2_core::Controller
TrajectoryVisualizer trajectory_visualizer_;

bool visualize_;
bool publish_critics_;

double reset_period_;
// Last time computeVelocityCommands was called
rclcpp::Time last_time_called_;

std::shared_ptr<rclcpp_lifecycle::LifecyclePublisher<nav2_mppi_controller::msg::CriticScores>>
critics_publisher_;
};

} // namespace nav2_mppi_controller
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ class CriticManager
* @brief Constructor for mppi::CriticManager
*/
CriticManager() = default;

/**
* @brief Virtual Destructor for mppi::CriticManager
*/
virtual ~CriticManager() = default;

/**
* @brief Configure critic manager on bringup and load plugins
* @param parent WeakPtr to node
Expand All @@ -69,6 +69,10 @@ class CriticManager
*/
void evalTrajectoriesScores(CriticData & data) const;

xt::xtensor<float, 1> evalTrajectory(CriticData & data) const;

std::vector<std::string> getCriticNames() const;

protected:
/**
* @brief Get parameters (critics to load)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ class Optimizer
*/
xt::xtensor<float, 2> getOptimizedTrajectory();


/**
* @brief Get the critic costs for given trajectory
* @return Names and costs of the critics
*/
xt::xtensor<float, 1> getOptimizationResults();

std::vector<std::string> getCriticNames() const;

/**
* @brief Set the maximum speed based on the speed limits callback
* @param speed_limit Limit of the speed for use
Expand Down
3 changes: 3 additions & 0 deletions nav2_mppi_controller/msg/CriticScores.msg
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
std_msgs/Header header # ROS time that this log message was sent.
std_msgs/String[] critic_names
std_msgs/Float32[] critic_scores
5 changes: 5 additions & 0 deletions nav2_mppi_controller/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

<buildtool_depend>ament_cmake</buildtool_depend>
<buildtool_depend>ament_cmake_ros</buildtool_depend>
<buildtool_depend>rosidl_default_generators</buildtool_depend>

<exec_depend>rosidl_default_runtime</exec_depend>

<depend>rclcpp</depend>
<depend>nav2_common</depend>
Expand All @@ -33,6 +36,8 @@
<test_depend>ament_lint_auto</test_depend>
<test_depend>ament_lint_common</test_depend>
<test_depend>ament_cmake_gtest</test_depend>

<member_of_group>rosidl_interface_packages</member_of_group>
<export>
<build_type>ament_cmake</build_type>
<nav2_core plugin="${prefix}/mppic.xml" />
Expand Down
32 changes: 32 additions & 0 deletions nav2_mppi_controller/src/controller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void MPPIController::configure(
// Get high-level controller parameters
auto getParam = parameters_handler_->getParamGetter(name_);
getParam(visualize_, "visualize", false);
getParam(publish_critics_, "publish_critics", false);
getParam(reset_period_, "reset_period", 1.0);

// Configure composed objects
Expand All @@ -48,6 +49,11 @@ void MPPIController::configure(
parent_, name_,
costmap_ros_->getGlobalFrameID(), parameters_handler_.get());

if (publish_critics_) {
critics_publisher_ = node->create_publisher<nav2_mppi_controller::msg::CriticScores>(
"/mppi_critic_scores", 1);
}

RCLCPP_INFO(logger_, "Configured MPPI Controller: %s", name_.c_str());
}

Expand All @@ -61,13 +67,15 @@ void MPPIController::cleanup()

void MPPIController::activate()
{
critics_publisher_->on_activate();
trajectory_visualizer_.on_activate();
parameters_handler_->start();
RCLCPP_INFO(logger_, "Activated MPPI Controller: %s", name_.c_str());
}

void MPPIController::deactivate()
{
critics_publisher_->on_deactivate();
trajectory_visualizer_.on_deactivate();
RCLCPP_INFO(logger_, "Deactivated MPPI Controller: %s", name_.c_str());
}
Expand Down Expand Up @@ -110,6 +118,30 @@ geometry_msgs::msg::TwistStamped MPPIController::computeVelocityCommands(
visualize(std::move(transformed_plan));
}

if (publish_critics_) {
std::vector<std::string> critic_names = optimizer_.getCriticNames();
xt::xtensor<float, 1> critic_costs = optimizer_.getOptimizationResults();

// log critic names and costs
for (size_t i = 0; i < critic_names.size(); i++) {
RCLCPP_INFO(logger_, "Critic: %s, Cost: %f", critic_names[i].c_str(), critic_costs[i]);
}

// make msg
auto critic_scores_ = std::make_unique<nav2_mppi_controller::msg::CriticScores>();
for (size_t i = 0; i < critic_names.size(); i++) {
std_msgs::msg::String name_msg;
name_msg.data = critic_names[i];
critic_scores_->critic_names.push_back(std::move(name_msg));

std_msgs::msg::Float32 cost_msg;
cost_msg.data = critic_costs[i];
critic_scores_->critic_scores.push_back(std::move(cost_msg));
}
critic_scores_->header.stamp = clock_->now();
critics_publisher_->publish(std::move(critic_scores_));
}

return cmd;
}

Expand Down
30 changes: 30 additions & 0 deletions nav2_mppi_controller/src/critic_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <xtensor/xtensor.hpp>

#include "nav2_mppi_controller/critic_manager.hpp"

namespace mppi
Expand Down Expand Up @@ -64,6 +66,11 @@ std::string CriticManager::getFullName(const std::string & name)
return "mppi::critics::" + name;
}

std::vector<std::string> CriticManager::getCriticNames() const
{
return critic_names_;
}

void CriticManager::evalTrajectoriesScores(
CriticData & data) const
{
Expand All @@ -75,4 +82,27 @@ void CriticManager::evalTrajectoriesScores(
}
}

xt::xtensor<float, 1> CriticManager::evalTrajectory(
CriticData & data) const
{
// log the critic_scores shape
RCLCPP_INFO(logger_, "(BEFORE FOR)");

xt::xtensor<float, 1> critic_scores = xt::zeros<float>({critics_.size()});

for (size_t q = 0; q < critics_.size(); q++) {
if (data.fail_flag) {
break;
}
data.costs = xt::zeros<float>({1});
// log costs values
critics_[q]->score(data);
critic_scores(q) = data.costs[0];
}
// log the for cycle finished in criticmanager
RCLCPP_INFO(logger_, "CriticManager: Critic evaluation (FOR) finished");

return critic_scores;
}

} // namespace mppi
61 changes: 61 additions & 0 deletions nav2_mppi_controller/src/optimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,67 @@ void Optimizer::optimize()
}
}

xt::xtensor<float, 1> Optimizer::getOptimizationResults()
{
const xt::xtensor<float, 2> optimized_trajectory = getOptimizedTrajectory();
xt::xtensor<float, 1> costs = xt::zeros<float>({1});

/*auto size = optimized_trajectory.size(); // size = 6
auto dim = optimized_trajectory.dimension(); // dim = 2
auto shape = optimized_trajectory.shape(); // shape = {2, 3}
//log size, dim, and shape
RCLCPP_INFO(
logger_, "getOptimizedTrajectory() size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

models::Trajectories dummy_trajectories;
/*size = dummy_trajectories.x.size();
dim = dummy_trajectories.x.dimension();
shape = dummy_trajectories.x.shape();
RCLCPP_INFO(
logger_, "[after creation] dummy_trajectories size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

dummy_trajectories.reset(1, settings_.time_steps);
/*size = dummy_trajectories.x.size();
dim = dummy_trajectories.x.dimension();
shape = dummy_trajectories.x.shape();
RCLCPP_INFO(
logger_, "[after reset] dummy_trajectories size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

dummy_trajectories.x += xt::view(optimized_trajectory, xt::all(), 0);
dummy_trajectories.y += xt::view(optimized_trajectory, xt::all(), 1);
dummy_trajectories.yaws += xt::view(optimized_trajectory, xt::all(), 2);

/*size = dummy_trajectories.x.size();
dim = dummy_trajectories.x.dimension();
shape = dummy_trajectories.x.shape();
RCLCPP_INFO(
logger_, "[after valuedump] dummy_trajectories size: %ld, dim: %ld, shape: {%ld, %ld}",
size, dim, shape[0], shape[1]);*/

CriticData dummy_data = {
state_, dummy_trajectories, path_, costs, settings_.model_dt,
false, critics_data_.goal_checker, critics_data_.motion_model, std::nullopt, std::nullopt};
// dummy_data.goal_checker = critics_data_.goal_checker;
// dummy_data.motion_model = critics_data_.motion_model;
dummy_data.furthest_reached_path_point.reset();
dummy_data.path_pts_valid.reset();

/*RCLCPP_INFO(
logger_, "dummy_data type: %s",
typeid(dummy_data).name());*/

return critic_manager_.evalTrajectory(dummy_data);
}

std::vector<std::string> Optimizer::getCriticNames() const
{
return critic_manager_.getCriticNames();
}

bool Optimizer::fallback(bool fail)
{
static size_t counter = 0;
Expand Down

0 comments on commit ee12719

Please sign in to comment.