Skip to content

Commit

Permalink
make input stream configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed May 25, 2023
1 parent f68868e commit 12a4218
Show file tree
Hide file tree
Showing 12 changed files with 38 additions and 13 deletions.
2 changes: 1 addition & 1 deletion benchmark/blas/blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ Parameters for a benchmark case are:
print_general_information(extra_information);
auto exec = executor_factory.at(FLAGS_executor)(FLAGS_gpu_timer);

rapidjson::IStreamWrapper jcin(std::cin);
rapidjson::IStreamWrapper jcin(get_input_stream());
rapidjson::Document test_cases;
test_cases.ParseStream(jcin);
if (!test_cases.IsArray()) {
Expand Down
2 changes: 1 addition & 1 deletion benchmark/blas/distributed/multi_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Parameters for a benchmark case are:

auto exec = executor_factory_mpi.at(FLAGS_executor)(comm.get());

std::string json_input = broadcast_json_input(std::cin, comm);
std::string json_input = broadcast_json_input(get_input_stream(), comm);
rapidjson::Document test_cases;
test_cases.Parse(json_input.c_str());
if (!test_cases.IsArray()) {
Expand Down
2 changes: 1 addition & 1 deletion benchmark/conversions/conversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ int main(int argc, char* argv[])
auto exec = executor_factory.at(FLAGS_executor)(FLAGS_gpu_timer);
auto formats = split(FLAGS_formats, ',');

rapidjson::IStreamWrapper jcin(std::cin);
rapidjson::IStreamWrapper jcin(get_input_stream());
rapidjson::Document test_cases;
test_cases.ParseStream(jcin);
if (!test_cases.IsArray()) {
Expand Down
2 changes: 1 addition & 1 deletion benchmark/matrix_generator/matrix_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ int main(int argc, char* argv[])
std::clog << gko::version_info::get() << std::endl;

auto engine = get_engine();
rapidjson::IStreamWrapper jcin(std::cin);
rapidjson::IStreamWrapper jcin(get_input_stream());
rapidjson::Document configurations;
configurations.ParseStream(jcin);

Expand Down
2 changes: 1 addition & 1 deletion benchmark/matrix_statistics/matrix_statistics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ int main(int argc, char* argv[])

std::clog << gko::version_info::get() << std::endl;

rapidjson::IStreamWrapper jcin(std::cin);
rapidjson::IStreamWrapper jcin(get_input_stream());
rapidjson::Document test_cases;
test_cases.ParseStream(jcin);
if (!test_cases.IsArray()) {
Expand Down
2 changes: 1 addition & 1 deletion benchmark/preconditioner/preconditioner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ int main(int argc, char* argv[])
std::exit(1);
}

rapidjson::IStreamWrapper jcin(std::cin);
rapidjson::IStreamWrapper jcin(get_input_stream());
rapidjson::Document test_cases;
test_cases.ParseStream(jcin);
if (!test_cases.IsArray()) {
Expand Down
6 changes: 3 additions & 3 deletions benchmark/solver/distributed/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ int main(int argc, char* argv[])
}
}

std::string json_input = FLAGS_overhead
? R"(
std::string json_input =
FLAGS_overhead ? R"(
[{"filename": "overhead.mtx",
"optimal": {"spmv": "csr-csr"}]
)"
: broadcast_json_input(std::cin, comm);
: broadcast_json_input(get_input_stream(), comm);
rapidjson::Document test_cases;
test_cases.Parse(json_input.c_str());

Expand Down
2 changes: 1 addition & 1 deletion benchmark/solver/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ int main(int argc, char* argv[])

rapidjson::Document test_cases;
if (!FLAGS_overhead) {
rapidjson::IStreamWrapper jcin(std::cin);
rapidjson::IStreamWrapper jcin(get_input_stream());
test_cases.ParseStream(jcin);
} else {
// Fake test case to run once
Expand Down
2 changes: 1 addition & 1 deletion benchmark/sparse_blas/sparse_blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ int main(int argc, char* argv[])

auto exec = executor_factory.at(FLAGS_executor)(FLAGS_gpu_timer);

rapidjson::IStreamWrapper jcin(std::cin);
rapidjson::IStreamWrapper jcin(get_input_stream());
rapidjson::Document test_cases;
test_cases.ParseStream(jcin);
if (!test_cases.IsArray()) {
Expand Down
2 changes: 1 addition & 1 deletion benchmark/spmv/distributed/spmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ int main(int argc, char* argv[])
}
}

std::string json_input = broadcast_json_input(std::cin, comm);
std::string json_input = broadcast_json_input(get_input_stream(), comm);
rapidjson::Document test_cases;
test_cases.Parse(json_input.c_str());
if (!test_cases.IsArray()) {
Expand Down
2 changes: 1 addition & 1 deletion benchmark/spmv/spmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ int main(int argc, char* argv[])
auto exec = executor_factory.at(FLAGS_executor)(FLAGS_gpu_timer);
auto formats = split(FLAGS_formats, ',');

rapidjson::IStreamWrapper jcin(std::cin);
rapidjson::IStreamWrapper jcin(get_input_stream());
rapidjson::Document test_cases;
test_cases.ParseStream(jcin);
if (!test_cases.IsArray()) {
Expand Down
25 changes: 25 additions & 0 deletions benchmark/utils/general.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ DEFINE_string(double_buffer, "",
" buffering of backup files, in case of a"
" crash when overwriting the backup");

DEFINE_string(
input, "",
"If set, the value is used as the input for the benchmark (if set to a "
"string ending with ]) or as input file path (otherwise).");

DEFINE_bool(detailed, true,
"If set, performs several runs to obtain more detailed results");

Expand Down Expand Up @@ -243,6 +248,26 @@ std::vector<std::string> split(const std::string& s, char delimiter = ',')
}


// returns the stream to be used as input of the application
std::istream& get_input_stream()
{
static auto stream = []() -> std::unique_ptr<std::istream> {
std::string input_str(FLAGS_input);
if (input_str.empty()) {
return nullptr;
}
if (input_str.back() == ']') {
return std::make_unique<std::stringstream>(input_str);
}
return std::make_unique<std::ifstream>(input_str);
}();
if (stream) {
return *stream;
}
return std::cin;
}


// backup generation
void backup_results(rapidjson::Document& results)
{
Expand Down

0 comments on commit 12a4218

Please sign in to comment.