Skip to content

Commit

Permalink
Add to write prediction for lightNet training
Browse files Browse the repository at this point in the history
  • Loading branch information
umedan committed May 8, 2023
1 parent 52267ca commit f1e0c7b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
11 changes: 10 additions & 1 deletion modules/yolo_config_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ DEFINE_string(config_file_path1, "not-specified", "[REQUIRED] Darknet cfg file")
DEFINE_string(wts_file_path1, "not-specified", "[REQUIRED] Darknet weights file");
DEFINE_string(config_file_path2, "not-specified", "[REQUIRED] Darknet cfg file");
DEFINE_string(wts_file_path2, "not-specified", "[REQUIRED] Darknet weights file");
DEFINE_string(labels_file_path, "configs/bdd100k.names", "[REQUIRED] Object class labels file");
DEFINE_string(labels_file_path, "../configs/bdd100k.names", "[REQUIRED] Object class labels file");
DEFINE_string(precision, "kFLOAT",
"[OPTIONAL] Inference precision. Choose from kFLOAT, kHALF and kINT8.");
DEFINE_string(deviceType, "kGPU",
Expand Down Expand Up @@ -94,6 +94,9 @@ DEFINE_bool(prof, false,
DEFINE_string(dump, "not-specified",
"[OPTIONAL] Path to dump predictions for mAP calculation");

DEFINE_string(output, "not-specified",
"[OPTIONAL] Path to output predictions for pseudo-labeling");

DEFINE_bool(mp, false,
"[OPTIONAL] Flag to multi-precision");
DEFINE_int64(dla, -1, "[OPTIONAL] DLA");
Expand Down Expand Up @@ -269,3 +272,9 @@ get_cuda_flg(void)
{
return FLAGS_cuda;
}

std::string
get_output_path(void)
{
return FLAGS_output;
}
2 changes: 2 additions & 0 deletions modules/yolo_config_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,6 @@ bool
get_multi_precision_flg(void);
bool
get_cuda_flg(void);
std::string
get_output_path(void);
#endif //_YOLO_CONFIG_PARSER_
45 changes: 42 additions & 3 deletions samples/sample_detector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,34 @@ write_prediction(std::string dumpPath, std::string filename, std::vector<std::st
writing_file.close();
}

void
write_outputs(std::string outputPath, std::string filename, std::vector<std::string> names, std::vector<Result> objects, int width, int height)
{
int pos = filename.find_last_of(".");
std::string body = filename.substr(0, pos);
std::string dstName = body + ".txt";
std::ofstream writing_file;
fs::path p = outputPath;
fs::create_directory(p);
p.append(dstName);
writing_file.open(p.string(), std::ios::out);
for (const auto & object : objects) {
const auto left = object.rect.x;
const auto top = object.rect.y;
const auto right = std::clamp(left + object.rect.width, 0, width);
const auto bottom = std::clamp(top + object.rect.height, 0, height);
std::string writing_text = format("%d %f %f %f %f",
object.id,
(left+right)/2/(double)width,
(top+bottom)/2/(double)height,
object.rect.width/(double)width,
object.rect.height/(double)height
);
writing_file << writing_text << std::endl;
}
writing_file.close();
}


int main(int argc, char** argv)
{
Expand All @@ -58,6 +86,7 @@ int main(int argc, char** argv)
std::string save_path = getSaveDetectionsPath();
bool dont_show = get_dont_show_flg();
const std::string dumpPath = get_dump_path();
const std::string outputPath = get_output_path();
const bool cuda = get_cuda_flg();
Config config;
config.net_type = YOLOV4;
Expand Down Expand Up @@ -97,14 +126,22 @@ int main(int argc, char** argv)
batch_img.push_back(src);
}
detector->detect(batch_img, batch_res, cuda);
detector->segment(batch_img);
if (!dont_show) {
detector->segment(batch_img);
}

if (dumpPath != "not-specified") {
fs::path p (file.path());
std::string filename = p.filename().string();
std::vector<std::string> names = get_names();
write_prediction(dumpPath, filename, names, batch_res[0], src.cols, src.rows);
}
}
if (outputPath != "not-specified") {
fs::path p (file.path());
std::string filename = p.filename().string();
std::vector<std::string> names = get_names();
write_outputs(outputPath, filename, names, batch_res[0], src.cols, src.rows);
}
//disp
if (dont_show == true) {
continue;
Expand Down Expand Up @@ -139,7 +176,9 @@ int main(int argc, char** argv)
std::vector<cv::Mat> batch_img;
batch_img.push_back(frame);
detector->detect(batch_img, batch_res, cuda);
detector->segment(batch_img);
if (!dont_show) {
detector->segment(batch_img);
}

if (dont_show == true) {
continue;
Expand Down

0 comments on commit f1e0c7b

Please sign in to comment.