diff --git a/perception/lidar_centerpoint/lib/network/network_trt.cpp b/perception/lidar_centerpoint/lib/network/network_trt.cpp
index 88319ff51fe35..2d841d22c2eb1 100644
--- a/perception/lidar_centerpoint/lib/network/network_trt.cpp
+++ b/perception/lidar_centerpoint/lib/network/network_trt.cpp
@@ -14,6 +14,8 @@
 #include "lidar_centerpoint/network/network_trt.hpp"
 
+#include <rclcpp/rclcpp.hpp>
+
 namespace centerpoint
 {
 bool VoxelEncoderTRT::setProfile(
@@ -59,6 +61,15 @@ bool HeadTRT::setProfile(
   for (std::size_t ci = 0; ci < out_channel_sizes_.size(); ci++) {
     auto out_name = network.getOutput(ci)->getName();
+
+    if (
+      out_name == std::string("heatmap") &&
+      network.getOutput(ci)->getDimensions().d[1] != static_cast<int32_t>(out_channel_sizes_[ci])) {
+      RCLCPP_ERROR(
+        rclcpp::get_logger("lidar_centerpoint"),
+        "Expected and actual number of classes do not match");
+      return false;
+    }
     auto out_dims = nvinfer1::Dims4(
       config_.batch_size_, out_channel_sizes_[ci], config_.down_grid_size_y_,
       config_.down_grid_size_x_);
diff --git a/perception/lidar_centerpoint/lib/network/tensorrt_wrapper.cpp b/perception/lidar_centerpoint/lib/network/tensorrt_wrapper.cpp
index 079c41d06c6e0..4840b63940df1 100644
--- a/perception/lidar_centerpoint/lib/network/tensorrt_wrapper.cpp
+++ b/perception/lidar_centerpoint/lib/network/tensorrt_wrapper.cpp
@@ -14,6 +14,8 @@
 #include "lidar_centerpoint/network/tensorrt_wrapper.hpp"
 
+#include <rclcpp/rclcpp.hpp>
+
 #include <fstream>
 #include <memory>
@@ -38,7 +40,7 @@
   runtime_ =
     tensorrt_common::TrtUniquePtr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(logger_));
   if (!runtime_) {
-    std::cout << "Fail to create runtime" << std::endl;
+    RCLCPP_ERROR(rclcpp::get_logger("lidar_centerpoint"), "Failed to create runtime");
     return false;
   }
@@ -57,14 +59,15 @@ bool TensorRTWrapper::init(
 bool TensorRTWrapper::createContext()
 {
   if (!engine_) {
-    std::cout << "Fail to create context: Engine isn't created" << std::endl;
+    RCLCPP_ERROR(
+      rclcpp::get_logger("lidar_centerpoint"), "Failed to create context: Engine was not created");
     return false;
   }
 
   context_ =
     tensorrt_common::TrtUniquePtr<nvinfer1::IExecutionContext>(engine_->createExecutionContext());
   if (!context_) {
-    std::cout << "Fail to create context" << std::endl;
+    RCLCPP_ERROR(rclcpp::get_logger("lidar_centerpoint"), "Failed to create context");
     return false;
   }
@@ -78,14 +81,14 @@ bool TensorRTWrapper::parseONNX(
   auto builder =
     tensorrt_common::TrtUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger_));
   if (!builder) {
-    std::cout << "Fail to create builder" << std::endl;
+    RCLCPP_ERROR(rclcpp::get_logger("lidar_centerpoint"), "Failed to create builder");
     return false;
   }
 
   auto config =
     tensorrt_common::TrtUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
   if (!config) {
-    std::cout << "Fail to create config" << std::endl;
+    RCLCPP_ERROR(rclcpp::get_logger("lidar_centerpoint"), "Failed to create config");
     return false;
   }
 #if (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 8400
@@ -95,10 +98,12 @@ bool TensorRTWrapper::parseONNX(
 #endif
   if (precision == "fp16") {
     if (builder->platformHasFastFp16()) {
-      std::cout << "use TensorRT FP16 Inference" << std::endl;
+      RCLCPP_INFO(rclcpp::get_logger("lidar_centerpoint"), "Using TensorRT FP16 Inference");
       config->setFlag(nvinfer1::BuilderFlag::kFP16);
     } else {
-      std::cout << "TensorRT FP16 Inference isn't supported in this environment" << std::endl;
+      RCLCPP_INFO(
+        rclcpp::get_logger("lidar_centerpoint"),
+        "TensorRT FP16 Inference isn't supported in this environment");
     }
   }
@@ -107,7 +112,7 @@ bool TensorRTWrapper::parseONNX(
   auto network =
     tensorrt_common::TrtUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(flag));
   if (!network) {
-    std::cout << "Fail to create network" << std::endl;
+    RCLCPP_ERROR(rclcpp::get_logger("lidar_centerpoint"), "Failed to create network");
     return false;
   }
@@ -116,22 +121,23 @@ bool TensorRTWrapper::parseONNX(
   parser->parseFromFile(onnx_path.c_str(), static_cast<int>(nvinfer1::ILogger::Severity::kERROR));
 
   if (!setProfile(*builder, *network, *config)) {
-    std::cout << "Fail to set profile" << std::endl;
+    RCLCPP_ERROR(rclcpp::get_logger("lidar_centerpoint"), "Failed to set profile");
     return false;
   }
 
-  std::cout << "Applying optimizations and building TRT CUDA engine (" << onnx_path << ") ..."
-            << std::endl;
+  RCLCPP_INFO_STREAM(
+    rclcpp::get_logger("lidar_centerpoint"),
+    "Applying optimizations and building TRT CUDA engine (" << onnx_path << ") ...");
   plan_ = tensorrt_common::TrtUniquePtr<nvinfer1::IHostMemory>(
     builder->buildSerializedNetwork(*network, *config));
   if (!plan_) {
-    std::cout << "Fail to create serialized network" << std::endl;
+    RCLCPP_ERROR(rclcpp::get_logger("lidar_centerpoint"), "Failed to create serialized network");
     return false;
   }
   engine_ = tensorrt_common::TrtUniquePtr<nvinfer1::ICudaEngine>(
     runtime_->deserializeCudaEngine(plan_->data(), plan_->size()));
   if (!engine_) {
-    std::cout << "Fail to create engine" << std::endl;
+    RCLCPP_ERROR(rclcpp::get_logger("lidar_centerpoint"), "Failed to create engine");
     return false;
   }
@@ -140,7 +146,7 @@ bool TensorRTWrapper::parseONNX(
 bool TensorRTWrapper::saveEngine(const std::string & engine_path)
 {
-  std::cout << "Writing to " << engine_path << std::endl;
+  RCLCPP_INFO_STREAM(rclcpp::get_logger("lidar_centerpoint"), "Writing to " << engine_path);
   std::ofstream file(engine_path, std::ios::out | std::ios::binary);
   file.write(reinterpret_cast<const char *>(plan_->data()), plan_->size());
   return true;
@@ -154,7 +160,7 @@ bool TensorRTWrapper::loadEngine(const std::string & engine_path)
   std::string engine_str = engine_buffer.str();
   engine_ = tensorrt_common::TrtUniquePtr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(
     reinterpret_cast<const void *>(engine_str.data()), engine_str.size()));
-  std::cout << "Loaded engine from " << engine_path << std::endl;
+  RCLCPP_INFO_STREAM(rclcpp::get_logger("lidar_centerpoint"), "Loaded engine from " << engine_path);
   return true;
 }
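
For reference, the guard added to HeadTRT::setProfile rejects an ONNX model whose "heatmap" output does not carry one channel per configured detection class, so a mismatched model fails at engine-build time rather than producing misaligned detections later. Below is a minimal, self-contained sketch of the same check with hypothetical names (heatmap_channels_match and its parameters are illustrative, not part of the patch):

// Minimal sketch of the class-count guard, independent of TensorRT types.
// `output_name` and `heatmap_channels` stand in for values read from the
// parsed network; `configured_classes` stands in for out_channel_sizes_[ci].
#include <cstddef>
#include <cstdint>
#include <string>

#include <rclcpp/rclcpp.hpp>

bool heatmap_channels_match(
  const std::string & output_name, std::int32_t heatmap_channels, std::size_t configured_classes)
{
  if (
    output_name == "heatmap" &&
    heatmap_channels != static_cast<std::int32_t>(configured_classes)) {
    // Same log message as the patch: the model and the node config disagree.
    RCLCPP_ERROR(
      rclcpp::get_logger("lidar_centerpoint"),
      "Expected and actual number of classes do not match");
    return false;
  }
  return true;
}

A caller would run such a check once per network output before fixing the output dimensions for the optimization profile, as the patched setProfile does.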