-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added c++ functions to parse multi-arrays into tensors
- Loading branch information
Showing
6 changed files
with
173 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
message(STATUS "Configuring Obelisk Utils") | ||
|
||
|
||
if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") | ||
add_compile_options(-Wall -Wextra -Wpedantic) | ||
endif() | ||
|
||
# ------- ROS 2 Packages ------- # | ||
find_package(ament_cmake REQUIRED) | ||
find_package(rclcpp REQUIRED) | ||
find_package(rclcpp_lifecycle REQUIRED) | ||
find_package(rcl REQUIRED) | ||
|
||
# ------- Obelisk Messages ------- # | ||
find_package(obelisk_std_msgs REQUIRED) | ||
|
||
# ------- Eigen ------- # | ||
find_package(Eigen3 REQUIRED) | ||
|
||
# ------- Source files ------- # | ||
set(UTILS_INC "${CMAKE_CURRENT_SOURCE_DIR}/include") | ||
|
||
# ------- Making the library ------- # | ||
add_library(ObkUtils INTERFACE) | ||
add_library(Obelisk::Utils ALIAS ObkUtils) # Namespaced alias | ||
target_include_directories(ObkUtils INTERFACE ${UTILS_INC} ${mujoco_SOURCE_DIR}/include) | ||
|
||
target_link_libraries(ObkUtils INTERFACE Eigen3::Eigen) | ||
|
||
ament_target_dependencies(ObkUtils INTERFACE | ||
rclcpp | ||
rclcpp_lifecycle | ||
obelisk_std_msgs | ||
std_msgs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
#include <unsupported/Eigen/CXX11/Tensor> | ||
|
||
#include "obelisk_std_msgs/msg/float_multi_array.hpp" | ||
#include "obelisk_std_msgs/msg/u_int8_multi_array.hpp" | ||
|
||
namespace obelisk::utils::msgs { | ||
namespace internal { | ||
template <typename ScalarT, std::size_t N, std::size_t... Indices> | ||
Eigen::Tensor<ScalarT, N> CreateTensor(std::vector<ScalarT>& data, const std::array<int, N>& dims, | ||
std::index_sequence<Indices...>) { | ||
Eigen::TensorMap<Eigen::Tensor<ScalarT, N>> tensor(data.data(), dims[Indices]...); | ||
return tensor; | ||
} | ||
|
||
// template <std::size_t N, std::size_t... Indices> | ||
// void SetTensorElement(Eigen::Tensor<double, N>& tensor, double val, const std::array<int, N>& element, | ||
// std::index_sequence<Indices...>) { | ||
// tensor(element[Indices]...) = val; | ||
// } | ||
} // namespace internal | ||
|
||
template <int Size> | ||
Eigen::Tensor<double, Size> MutliArrayToTensor(const obelisk_std_msgs::msg::FloatMultiArray& msg) { | ||
|
||
// Get the flat part of the data | ||
std::vector<double> data(msg.data.begin() + msg.layout.data_offset, msg.data.end()); | ||
|
||
if (msg.layout.dim.size() != Size) { | ||
// TODO: Consider just logging this, but without a node, we don't have access to a logger. | ||
throw std::runtime_error("Templated size does not match the size provided by the message!"); | ||
} | ||
|
||
std::array<int, Size> sizes; | ||
for (int i = 0; i < Size; i++) { | ||
sizes.at(i) = msg.layout.dim.at(i).size; | ||
} | ||
|
||
auto tensor = internal::CreateTensor<double, Size>(data, sizes, std::make_index_sequence<Size>{}); | ||
|
||
return tensor; | ||
} | ||
|
||
template <int Size> | ||
Eigen::Tensor<uint8_t, Size> MutliArrayToTensor(const obelisk_std_msgs::msg::UInt8MultiArray& msg) { | ||
|
||
// Get the flat part of the data | ||
std::vector<uint8_t> data(msg.data.begin() + msg.layout.data_offset, msg.data.end()); | ||
|
||
if (msg.layout.dim.size() != Size) { | ||
// TODO: Consider just logging this, but without a node, we don't have access to a logger. | ||
throw std::runtime_error("Templated size does not match the size provided by the message!"); | ||
} | ||
|
||
std::array<int, Size> sizes; | ||
for (int i = 0; i < Size; i++) { | ||
sizes.at(i) = msg.layout.dim.at(i).size; | ||
} | ||
|
||
auto tensor = internal::CreateTensor<uint8_t, Size>(data, sizes, std::make_index_sequence<Size>{}); | ||
|
||
return tensor; | ||
} | ||
|
||
template <int Size> | ||
obelisk_std_msgs::msg::FloatMultiArray TensorToMultiArray(const Eigen::Tensor<double, Size>& tensor) { | ||
obelisk_std_msgs::msg::FloatMultiArray msg; | ||
msg.layout.data_offset = 0; | ||
|
||
// Get data into flat vector | ||
msg.data.resize(tensor.size()); | ||
std::copy(tensor.data(), tensor.data() + tensor.size(), msg.data.begin()); | ||
|
||
// // Compute stride lengths | ||
std_msgs::msg::MultiArrayDimension dim; | ||
dim.label = "dim_" + std::to_string(Size); | ||
dim.size = tensor.dimension(Size - 1); | ||
dim.stride = 1; | ||
|
||
msg.layout.dim.emplace_back(dim); // The stride for the last dimension | ||
for (int i = tensor.dimensions().size() - 2; i >= 0; --i) { | ||
dim.label = "dim_" + std::to_string(i); | ||
dim.size = tensor.dimension(i); | ||
dim.stride = msg.layout.dim.back().stride * tensor.dimension(i + 1); | ||
|
||
msg.layout.dim.emplace_back(dim); | ||
} | ||
std::reverse(msg.layout.dim.begin(), msg.layout.dim.end()); // Reverse to match dimension order | ||
|
||
return msg; | ||
} | ||
} // namespace obelisk::utils::msgs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#include <catch2/catch_test_macros.hpp> | ||
#include <iostream> | ||
|
||
#include "msg_conversions.h" | ||
|
||
TEST_CASE("Multiarray Testing", "[utils][msgs]") { | ||
for (int l = 0; l < 10; l++) { | ||
Eigen::Tensor<double, 5> tensor2(4, 4, 5, 2, 3); | ||
for (int i = 0; i < tensor2.dimension(0); i++) { | ||
for (int j = 0; j < tensor2.dimension(1); j++) { | ||
for (int k = 0; k < tensor2.dimension(2); k++) { | ||
for (int m = 0; m < tensor2.dimension(3); m++) { | ||
for (int n = 0; n < tensor2.dimension(4); n++) { | ||
tensor2(i, j, k, m, n) = rand() % 100; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
auto msg = obelisk::utils::msgs::TensorToMultiArray<5>(tensor2); | ||
Eigen::Tensor<double, 5> tensor3 = obelisk::utils::msgs::MutliArrayToTensor<5>(msg); | ||
|
||
for (int i = 0; i < tensor2.dimension(0); i++) { | ||
for (int j = 0; j < tensor2.dimension(1); j++) { | ||
for (int k = 0; k < tensor2.dimension(2); k++) { | ||
for (int m = 0; m < tensor2.dimension(3); m++) { | ||
for (int n = 0; n < tensor2.dimension(4); n++) { | ||
CHECK(tensor2(i, j, k, m, n) == tensor3(i, j, k, m, n)); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} |