Skip to content

Commit

Permalink
Optimize function blobFromImage
Browse files Browse the repository at this point in the history
  • Loading branch information
fateshelled committed Jun 12, 2024
1 parent a9d35df commit 41c9925
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 49 deletions.
2 changes: 1 addition & 1 deletion yolov9mit_ros/yolov9mit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ endif()
find_package(ament_cmake_auto REQUIRED)
ament_auto_find_build_dependencies()

option(YOLOV9_MIT_USE_TENSORRT "Use TensorRT" ON)
option(YOLOV9_MIT_USE_TENSORRT "Use TensorRT" ON)

set(ENABLE_TENSORRT OFF)

Expand Down
58 changes: 20 additions & 38 deletions yolov9mit_ros/yolov9mit/include/yolov9mit/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace yolov9mit

struct Object
{
cv::Rect_<float> rect;
cv::Rect2f rect;
int class_id;
float confidence;
};
Expand All @@ -30,9 +30,12 @@ class AbcYOLOV9MIT
protected:
size_t input_w_;
size_t input_h_;
const size_t input_channel_ = 3;
float min_iou_;
float min_confidence_;
size_t num_classes_;
const float blob_scale = 1.0f / 255.0f;
std::vector<float> blob_data_;

cv::Mat preprocess(const cv::Mat &img)
{
Expand All @@ -43,48 +46,27 @@ class AbcYOLOV9MIT
}

// HWC -> NCHW
std::vector<float> blobFromImage(const cv::Mat &img)
void blobFromImage(const cv::Mat &img)
{
static const float scale = 1.0f / 255.0f;
static const size_t channels = 3;
const size_t img_h = img.rows;
const size_t img_w = img.cols;
std::vector<float> blob_data(channels * img_h * img_w);

for (size_t c = 0; c < channels; ++c)
{
const size_t chw = c * img_w * img_h;
for (size_t h = 0; h < img_h; ++h)
{
const size_t chw_hh = chw + h * img_w;
for (size_t w = 0; w < img_w; ++w)
{
// blob_data[c * img_w * img_h + h * img_w + w] =
// (float)img.ptr<cv::Vec3b>(h)[w][c] * scale;
blob_data[chw_hh + w] = (float)img.ptr<cv::Vec3b>(h)[w][c] * scale;
}
}
}
return blob_data;
}

// HWC -> NHWC
std::vector<float> blobFromImage_nhwc(const cv::Mat &img)
{
static const float scale = 1.0f / 255.0f;
static const size_t channels = 3;
const size_t img_hw = img.rows * img.cols;

std::vector<float> blob_data(channels * img_hw);

for (size_t i = 0; i < img_hw; ++i)
const size_t input_size = input_channel_ * input_h_ * input_w_;
blob_data_.resize(input_size);

// (input_h_, input_w_, input_channel_) -> (input_h_ * input_w_ * input_channel_)
cv::Mat flatten = img.reshape(1, 1);
std::vector<float> img_vec;
flatten.convertTo(img_vec, CV_32FC1, blob_scale);

// img_vec = [r0, g0, b0, r1, g1, b1, ... ]
// blob_data_ = [r0, r1, ..., g0, g1, ..., b0, b1, ... ]
float *blob_ptr = blob_data_.data();
float *img_vec_ptr = img_vec.data();
for (size_t c = 0; c < input_channel_; ++c)
{
for (size_t c = 0; c < channels; ++c)
for (size_t i = c; i < input_size; i += 3)
{
blob_data[i * channels + c] = (float)img.data[i * channels + c] * scale;
*blob_ptr++ = img_vec_ptr[i];
}
}
return blob_data;
}

std::vector<Object> outputs_to_objects(const std::vector<float> &prob_classes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ class YOLOV9MIT_TensorRT : public AbcYOLOV9MIT
std::vector<Object> inference(const cv::Mat &frame) override;

private:
void doInference(std::vector<float> input, std::vector<float> &output0,
std::vector<float> &output1);
void doInference(std::vector<float> &output0, std::vector<float> &output1);

int32_t device_ = 0;
MyTRTLogger trt_logger_;
Expand Down
12 changes: 7 additions & 5 deletions yolov9mit_ros/yolov9mit/src/yolov9mit_tensorrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,22 +115,24 @@ YOLOV9MIT_TensorRT::YOLOV9MIT_TensorRT(file_name_t engine_path, int device, floa

std::vector<Object> YOLOV9MIT_TensorRT::inference(const cv::Mat& frame)
{
// preprocess
const auto pr_img = preprocess(frame);
const auto input_blob = blobFromImage(pr_img);

// HWC -> NCHW
blobFromImage(pr_img);

// inference
std::vector<float> output_blob_classes;
std::vector<float> output_blob_bbox;
this->doInference(input_blob, output_blob_classes, output_blob_bbox);
this->doInference(output_blob_classes, output_blob_bbox);

const auto objects =
decode_outputs(output_blob_classes, output_blob_bbox, frame.cols, frame.rows);

return objects;
}

void YOLOV9MIT_TensorRT::doInference(std::vector<float> input, std::vector<float>& output0,
std::vector<float>& output1)
void YOLOV9MIT_TensorRT::doInference(std::vector<float>& output0, std::vector<float>& output1)
{
void* buffers[3];
output0.resize(this->output0_size_);
Expand All @@ -146,7 +148,7 @@ void YOLOV9MIT_TensorRT::doInference(std::vector<float> input, std::vector<float
cuda_check(cudaStreamCreate(&stream));

// cudaMemcpyAsync(dist, src, size, type, stream)
cuda_check(cudaMemcpyAsync(buffers[this->input_index_], input.data(),
cuda_check(cudaMemcpyAsync(buffers[this->input_index_], blob_data_.data(),
3 * this->input_h_ * this->input_w_ * sizeof(float),
cudaMemcpyHostToDevice, stream));
context_->enqueueV2(buffers, stream, nullptr);
Expand Down
7 changes: 4 additions & 3 deletions yolov9mit_ros/yolov9mit_ros/src/yolov9mit_ros.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,17 @@ void YOLOV9MIT_Node::image_callback(const sensor_msgs::msg::Image::ConstSharedPt

// time log
{
RCLCPP_INFO(this->get_logger(), "Elapsed");
auto inf_elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(t1_inf - t0_inf);
RCLCPP_INFO(this->get_logger(), "Inference: %ld ms", inf_elapsed.count());
RCLCPP_INFO(this->get_logger(), " - Inference: %ld ms", inf_elapsed.count());

auto bboxes_elapsed =
std::chrono::duration_cast<std::chrono::milliseconds>(t1_bboxes - t0_bboxes);
RCLCPP_INFO(this->get_logger(), "to Detection2DArray: %ld ms", bboxes_elapsed.count());
RCLCPP_INFO(this->get_logger(), " - to Detection2DArray: %ld ms", bboxes_elapsed.count());

auto draw_elapsed =
std::chrono::duration_cast<std::chrono::milliseconds>(t1_draw - t0_draw);
RCLCPP_INFO(this->get_logger(), "Draw objects: %ld ms", draw_elapsed.count());
RCLCPP_INFO(this->get_logger(), " - Draw objects: %ld ms", draw_elapsed.count());

RCLCPP_INFO(this->get_logger(), "Detections: %ld count", objects.size());
RCLCPP_INFO(this->get_logger(), " ");
Expand Down

0 comments on commit 41c9925

Please sign in to comment.