Skip to content

Commit

Permalink
Merge generic solver test
Browse files Browse the repository at this point in the history
This adds a test that checks for uniform behavior across all solvers in Ginkgo.
It also fixes a few of the found issues.

Related PR: #973
  • Loading branch information
upsj authored Apr 21, 2022
2 parents b7a8edf + 9fd4c48 commit 7d8f86d
Show file tree
Hide file tree
Showing 40 changed files with 1,455 additions and 496 deletions.
153 changes: 16 additions & 137 deletions benchmark/tools/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <ginkgo/core/base/mtx_io.hpp>


#include "core/utils/matrix_utils.hpp"


#ifdef GKO_TOOL_COMPLEX
using value_type = std::complex<double>;
#else
Expand All @@ -49,127 +52,6 @@ using value_type = double;
using matrix_data = gko::matrix_data<value_type, gko::int64>;


matrix_data make_lower_triangular(const matrix_data& data)
{
matrix_data out(data.size);
for (auto entry : data.nonzeros) {
if (entry.column <= entry.row) {
out.nonzeros.push_back(entry);
}
}
return out;
}


matrix_data make_upper_triangular(const matrix_data& data)
{
matrix_data out(data.size);
for (auto entry : data.nonzeros) {
if (entry.column >= entry.row) {
out.nonzeros.push_back(entry);
}
}
return out;
}


matrix_data make_remove_diagonal(matrix_data data)
{
data.nonzeros.erase(
std::remove_if(data.nonzeros.begin(), data.nonzeros.end(),
[](auto entry) { return entry.row == entry.column; }),
data.nonzeros.end());
return data;
}


matrix_data make_unit_diagonal(matrix_data data)
{
data = make_remove_diagonal(data);
auto num_diags = std::min(data.size[0], data.size[1]);
for (gko::int64 i = 0; i < num_diags; i++) {
data.nonzeros.emplace_back(i, i, 1.0);
}
data.ensure_row_major_order();
return data;
}


matrix_data make_remove_zeros(matrix_data data)
{
data.nonzeros.erase(
std::remove_if(data.nonzeros.begin(), data.nonzeros.end(),
[](auto entry) { return entry.value == value_type{}; }),
data.nonzeros.end());
return data;
}


template <typename Op>
matrix_data make_symmetric_generic(const matrix_data& data, Op op)
{
matrix_data out(data.size);
// compute A + op(A^T)
for (auto entry : data.nonzeros) {
out.nonzeros.emplace_back(entry);
out.nonzeros.emplace_back(entry.column, entry.row, op(entry.value));
}
out.ensure_row_major_order();
// combine matching nonzeros
matrix_data out_compressed(data.size);
auto it = out.nonzeros.begin();
while (it != out.nonzeros.end()) {
auto entry = *it;
it++;
for (; it != out.nonzeros.end() && it->row == entry.row &&
it->column == entry.column;
++it) {
entry.value += it->value;
}
// store sum of entries at (row, column) divided by 2
out_compressed.nonzeros.emplace_back(entry.row, entry.column,
entry.value / 2.0);
}
return out_compressed;
}

matrix_data make_diag_dominant(matrix_data data, double scale = 1.01)
{
GKO_ASSERT_IS_SQUARE_MATRIX(data.size);
std::vector<double> norms(data.size[0]);
std::vector<gko::int64> diag_positions(data.size[0], -1);
gko::int64 i{};
for (auto entry : data.nonzeros) {
if (entry.row == entry.column) {
diag_positions[entry.row] = i;
} else {
norms[entry.row] += gko::abs(entry.value);
}
i++;
}
for (gko::int64 i = 0; i < data.size[0]; i++) {
if (diag_positions[i] < 0) {
data.nonzeros.emplace_back(i, i, norms[i] * scale);
} else {
auto& diag_value = data.nonzeros[diag_positions[i]].value;
const auto diag_magnitude = gko::abs(diag_value);
const auto offdiag_magnitude = norms[i];
if (diag_magnitude < offdiag_magnitude * scale) {
const auto scaled_value =
diag_value * (offdiag_magnitude * scale / diag_magnitude);
if (gko::is_finite(scaled_value)) {
diag_value = scaled_value;
} else {
diag_value = offdiag_magnitude * scale;
}
}
}
}
data.ensure_row_major_order();
return data;
}


int main(int argc, char** argv)
{
if (argc == 1) {
Expand Down Expand Up @@ -202,33 +84,30 @@ int main(int argc, char** argv)
for (int argi = binary ? 2 : 1; argi < argc; argi++) {
std::string arg{argv[argi]};
if (arg == "lower-triangular") {
data = make_lower_triangular(data);
gko::test::make_lower_triangular(data);
} else if (arg == "upper-triangular") {
data = make_upper_triangular(data);
gko::test::make_upper_triangular(data);
} else if (arg == "remove-diagonal") {
data = make_remove_diagonal(data);
gko::test::make_remove_diagonal(data);
} else if (arg == "remove-zeros") {
data = make_remove_zeros(data);
data.remove_zeros();
} else if (arg == "unit-diagonal") {
data = make_unit_diagonal(data);
gko::test::make_unit_diagonal(data);
} else if (arg == "symmetric") {
data = make_symmetric_generic(data, [](auto v) { return v; });
gko::test::make_symmetric(data);
} else if (arg == "skew-symmetric") {
data = make_symmetric_generic(data, [](auto v) { return -v; });
gko::test::make_symmetric_generic(data, [](auto v) { return -v; });
} else if (arg == "hermitian") {
data = make_symmetric_generic(data,
[](auto v) { return gko::conj(v); });
gko::test::make_hermitian(data);
} else if (arg == "skew-hermitian") {
data = make_symmetric_generic(data,
[](auto v) { return -gko::conj(v); });
gko::test::make_symmetric_generic(
data, [](auto v) { return -gko::conj(v); });
} else if (arg == "diagonal-dominant") {
data = make_diag_dominant(data);
gko::test::make_diag_dominant(data);
} else if (arg == "spd") {
data = make_diag_dominant(
make_symmetric_generic(data, [](auto v) { return v; }));
gko::test::make_spd(data);
} else if (arg == "hpd") {
data = make_diag_dominant(make_symmetric_generic(
data, [](auto v) { return gko::conj(v); }));
gko::test::make_hpd(data);
} else {
std::cerr << "Unknown operation " << arg << std::endl;
return 1;
Expand Down
9 changes: 9 additions & 0 deletions core/test/matrix/dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,15 @@ TYPED_TEST(Dense, CanCreateSubmatrix)
}


TYPED_TEST(Dense, CanCreateEmptySubmatrix)
{
using value_type = typename TestFixture::value_type;
auto submtx = this->mtx->create_submatrix(gko::span{0, 0}, gko::span{1, 1});

EXPECT_FALSE(submtx->get_size());
}


TYPED_TEST(Dense, CanCreateSubmatrixWithStride)
{
using value_type = typename TestFixture::value_type;
Expand Down
2 changes: 1 addition & 1 deletion core/test/utils/assertions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ ::testing::AssertionResult matrices_near_impl(
<< second_expression << " is " << err << "\n"
<< "\twhich is larger than " << tolerance_expression
<< " (which is " << tolerance << ")\n";
if (num_rows * num_cols <= 1000) {
if (num_rows <= 10 && num_cols <= 10) {
fail << first_expression << " is:\n";
detail::print_matrix(fail, first);
fail << second_expression << " is:\n";
Expand Down
Loading

0 comments on commit 7d8f86d

Please sign in to comment.