Skip to content

Commit

Permalink
Make scale a member func, rem solver support
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Jan 16, 2024
1 parent dc2862a commit 233eb5d
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 197 deletions.
40 changes: 15 additions & 25 deletions core/matrix/batch_csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,21 @@ void Csr<ValueType, IndexType>::apply_impl(const MultiVector<ValueType>* alpha,
}


template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::scale(const array<ValueType>& col_scale,
const array<ValueType>& row_scale)
{
GKO_ASSERT_EQ(col_scale.get_size(),
(this->get_common_size()[1] * this->get_num_batch_items()));
GKO_ASSERT_EQ(row_scale.get_size(),
(this->get_common_size()[0] * this->get_num_batch_items()));
auto exec = this->get_executor();
exec->run(csr::make_scale(make_temporary_clone(exec, &col_scale).get(),
make_temporary_clone(exec, &row_scale).get(),
this));
}


template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::convert_to(
Csr<next_precision<ValueType>, IndexType>* result) const
Expand All @@ -205,31 +220,6 @@ void Csr<ValueType, IndexType>::move_to(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_CSR_MATRIX);


template <typename ValueType, typename IndexType>
void scale_in_place(const array<ValueType>& col_scale,
const array<ValueType>& row_scale,
batch::matrix::Csr<ValueType, IndexType>* in_out)
{
GKO_ASSERT_EQ(col_scale.get_size(), (in_out->get_common_size()[1] *
in_out->get_num_batch_items()));
GKO_ASSERT_EQ(row_scale.get_size(), (in_out->get_common_size()[0] *
in_out->get_num_batch_items()));
auto exec = in_out->get_executor();
exec->run(csr::make_scale(make_temporary_clone(exec, &col_scale).get(),
make_temporary_clone(exec, &row_scale).get(),
in_out));
}


#define GKO_DECLARE_TWO_SIDED_BATCH_SCALE(_vtype, _itype) \
void scale_in_place(const array<_vtype>& col_scale, \
const array<_vtype>& row_scale, \
batch::matrix::Csr<_vtype, _itype>* in_out)

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_TWO_SIDED_BATCH_SCALE);


} // namespace matrix
} // namespace batch
} // namespace gko
39 changes: 15 additions & 24 deletions core/matrix/batch_dense.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,21 @@ void Dense<ValueType>::apply_impl(const MultiVector<ValueType>* alpha,
}


template <typename ValueType>
void Dense<ValueType>::scale(const array<ValueType>& col_scale,
const array<ValueType>& row_scale)
{
GKO_ASSERT_EQ(col_scale.get_size(),
(this->get_common_size()[1] * this->get_num_batch_items()));
GKO_ASSERT_EQ(row_scale.get_size(),
(this->get_common_size()[0] * this->get_num_batch_items()));
auto exec = this->get_executor();
exec->run(dense::make_scale(make_temporary_clone(exec, &col_scale).get(),
make_temporary_clone(exec, &row_scale).get(),
this));
}


template <typename ValueType>
void Dense<ValueType>::convert_to(
Dense<next_precision<ValueType>>* result) const
Expand All @@ -191,30 +206,6 @@ void Dense<ValueType>::move_to(Dense<next_precision<ValueType>>* result)
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_DENSE_MATRIX);


template <typename ValueType>
void scale_in_place(const array<ValueType>& col_scale,
const array<ValueType>& row_scale,
batch::matrix::Dense<ValueType>* in_out)
{
GKO_ASSERT_EQ(col_scale.get_size(), (in_out->get_common_size()[1] *
in_out->get_num_batch_items()));
GKO_ASSERT_EQ(row_scale.get_size(), (in_out->get_common_size()[0] *
in_out->get_num_batch_items()));
auto exec = in_out->get_executor();
exec->run(dense::make_scale(make_temporary_clone(exec, &col_scale).get(),
make_temporary_clone(exec, &row_scale).get(),
in_out));
}


#define GKO_DECLARE_TWO_SIDED_BATCH_SCALE(_type) \
void scale_in_place(const array<_type>& col_scale, \
const array<_type>& row_scale, \
batch::matrix::Dense<_type>* in_out)

GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_TWO_SIDED_BATCH_SCALE);


} // namespace matrix
} // namespace batch
} // namespace gko
43 changes: 21 additions & 22 deletions core/matrix/batch_ell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,21 @@ void Ell<ValueType, IndexType>::apply_impl(const MultiVector<ValueType>* alpha,
}


template <typename ValueType, typename IndexType>
void Ell<ValueType, IndexType>::scale(const array<ValueType>& col_scale,
const array<ValueType>& row_scale)
{
GKO_ASSERT_EQ(col_scale.get_size(),
(this->get_common_size()[1] * this->get_num_batch_items()));
GKO_ASSERT_EQ(row_scale.get_size(),
(this->get_common_size()[0] * this->get_num_batch_items()));
auto exec = this->get_executor();
exec->run(ell::make_scale(make_temporary_clone(exec, &col_scale).get(),
make_temporary_clone(exec, &row_scale).get(),
this));
}


template <typename ValueType, typename IndexType>
void Ell<ValueType, IndexType>::convert_to(
Ell<next_precision<ValueType>, IndexType>* result) const
Expand All @@ -206,29 +221,13 @@ void Ell<ValueType, IndexType>::move_to(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_BATCH_ELL_MATRIX);


template <typename ValueType, typename IndexType>
void scale_in_place(const array<ValueType>& col_scale,
const array<ValueType>& row_scale,
batch::matrix::Ell<ValueType, IndexType>* in_out)
{
GKO_ASSERT_EQ(col_scale.get_size(), (in_out->get_common_size()[1] *
in_out->get_num_batch_items()));
GKO_ASSERT_EQ(row_scale.get_size(), (in_out->get_common_size()[0] *
in_out->get_num_batch_items()));
auto exec = in_out->get_executor();
exec->run(ell::make_scale(make_temporary_clone(exec, &col_scale).get(),
make_temporary_clone(exec, &row_scale).get(),
in_out));
}


#define GKO_DECLARE_TWO_SIDED_BATCH_SCALE(_vtype, _itype) \
void scale_in_place(const array<_vtype>& col_scale, \
const array<_vtype>& row_scale, \
batch::matrix::Ell<_vtype, _itype>* in_out)
// #define GKO_DECLARE_BATCH_ELL_TWO_SIDED_BATCH_SCALE(_vtype, _itype) \
// void scale(const array<_vtype>& col_scale, \
// const array<_vtype>& row_scale, \
// ptr_param<batch::matrix::Ell<_vtype, _itype>> in_out)

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
GKO_DECLARE_TWO_SIDED_BATCH_SCALE);
// GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INT32_TYPE(
// GKO_DECLARE_BATCH_ELL_TWO_SIDED_BATCH_SCALE);


} // namespace matrix
Expand Down
28 changes: 0 additions & 28 deletions core/test/solver/batch_bicgstab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,34 +245,6 @@ TYPED_TEST(BatchBicgstab, CanSetTolType)
}


TYPED_TEST(BatchBicgstab, CanSetScalingVectors)
{
using Solver = typename TestFixture::Solver;
using value_type = typename TestFixture::value_type;
using real_type = typename TestFixture::real_type;
auto scale_size = this->num_batch_items * this->num_rows;
auto col_scale = gko::array<value_type>(this->exec, scale_size);
col_scale.fill(0.5);
auto row_scale = gko::array<value_type>(this->exec, scale_size);
row_scale.fill(0.8);

auto solver_factory = Solver::build()
.with_max_iterations(22)
.with_tolerance(static_cast<real_type>(0.25))
.with_col_scaling(col_scale)
.with_row_scaling(row_scale)
.on(this->exec);
auto solver = solver_factory->generate(this->mtx);

ASSERT_EQ(solver->get_parameters().row_scaling.get_size(), scale_size);
ASSERT_EQ(solver->get_parameters().row_scaling.get_const_data()[0],
value_type{0.8});
ASSERT_EQ(solver->get_parameters().col_scaling.get_size(), scale_size);
ASSERT_EQ(solver->get_parameters().col_scaling.get_const_data()[0],
value_type{0.5});
}


TYPED_TEST(BatchBicgstab, ThrowsOnRectangularMatrixInFactory)
{
using Mtx = typename TestFixture::Mtx;
Expand Down
24 changes: 9 additions & 15 deletions include/ginkgo/core/matrix/batch_csr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,15 @@ class Csr final
ptr_param<const MultiVector<value_type>> beta,
ptr_param<MultiVector<value_type>> x) const;

/**
* Performs in-place row and column scaling for this matrix.
*
* @param col_scale the column scalars
* @param row_scale the row scalars
*/
void scale(const array<value_type>& col_scale,
const array<value_type>& row_scale);

private:
/**
* Creates an uninitialized Csr matrix of the specified size.
Expand Down Expand Up @@ -331,21 +340,6 @@ class Csr final
};


/**
* Performs in-place row and column scaling for a given matrix.
*
* @param col_scale the column scalars
* @param row_scale the row scalars
* @param in_out the matrix to be scaled
*
* @note the operation is performed in-place
*/
template <typename ValueType, typename IndexType>
void scale_in_place(const array<ValueType>& col_scale,
const array<ValueType>& row_scale,
batch::matrix::Csr<ValueType, IndexType>* in_out);


} // namespace matrix
} // namespace batch
} // namespace gko
Expand Down
24 changes: 9 additions & 15 deletions include/ginkgo/core/matrix/batch_dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,15 @@ class Dense final : public EnableBatchLinOp<Dense<ValueType>>,
ptr_param<const MultiVector<value_type>> beta,
ptr_param<MultiVector<value_type>> x) const;

/**
* Performs in-place row and column scaling for this matrix.
*
* @param col_scale the column scalars
* @param row_scale the row scalars
*/
void scale(const array<value_type>& col_scale,
const array<value_type>& row_scale);

private:
inline size_type compute_num_elems(const batch_dim<2>& size)
{
Expand Down Expand Up @@ -348,21 +357,6 @@ class Dense final : public EnableBatchLinOp<Dense<ValueType>>,
};


/**
* Performs in-place row and column scaling for a given matrix.
*
* @param col_scale the column scalars
* @param row_scale the row scalars
* @param in_out the matrix to be scaled
*
* @note the operation is performed in-place
*/
template <typename ValueType>
void scale_in_place(const array<ValueType>& col_scale,
const array<ValueType>& row_scale,
batch::matrix::Dense<ValueType>* in_out);


} // namespace matrix
} // namespace batch
} // namespace gko
Expand Down
25 changes: 10 additions & 15 deletions include/ginkgo/core/matrix/batch_ell.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <ginkgo/core/base/range_accessors.hpp>
#include <ginkgo/core/base/types.hpp>
#include <ginkgo/core/base/utils.hpp>
#include <ginkgo/core/base/utils_helper.hpp>
#include <ginkgo/core/matrix/ell.hpp>


Expand Down Expand Up @@ -283,6 +284,15 @@ class Ell final
ptr_param<const MultiVector<value_type>> beta,
ptr_param<MultiVector<value_type>> x) const;

/**
* Performs in-place row and column scaling for this matrix.
*
* @param col_scale the column scalars
* @param row_scale the row scalars
*/
void scale(const array<value_type>& col_scale,
const array<value_type>& row_scale);

private:
size_type compute_num_elems(const batch_dim<2>& size,
IndexType num_elems_per_row)
Expand Down Expand Up @@ -348,21 +358,6 @@ class Ell final
};


/**
* Performs in-place row and column scaling for a given matrix.
*
* @param col_scale the column scalars
* @param row_scale the row scalars
* @param in_out the matrix to be scaled
*
* @note the operation is performed in-place
*/
template <typename ValueType, typename IndexType>
void scale_in_place(const array<ValueType>& col_scale,
const array<ValueType>& row_scale,
batch::matrix::Ell<ValueType, IndexType>* in_out);


} // namespace matrix
} // namespace batch
} // namespace gko
Expand Down
35 changes: 0 additions & 35 deletions include/ginkgo/core/solver/batch_solver_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
#include <ginkgo/core/base/batch_multi_vector.hpp>
#include <ginkgo/core/base/utils_helper.hpp>
#include <ginkgo/core/log/batch_logger.hpp>
#include <ginkgo/core/matrix/batch_csr.hpp>
#include <ginkgo/core/matrix/batch_dense.hpp>
#include <ginkgo/core/matrix/batch_ell.hpp>
#include <ginkgo/core/matrix/batch_identity.hpp>
#include <ginkgo/core/stop/batch_stop_enum.hpp>

Expand Down Expand Up @@ -191,16 +188,6 @@ struct enable_preconditioned_iterative_solver_factory_parameters
*/
std::shared_ptr<const BatchLinOp> GKO_FACTORY_PARAMETER_SCALAR(
generated_preconditioner, nullptr);

/**
* Column scaling vector
*/
array<ValueType> GKO_FACTORY_PARAMETER_SCALAR(col_scaling, {});

/**
* Row scaling vector
*/
array<ValueType> GKO_FACTORY_PARAMETER_SCALAR(row_scaling, {});
};


Expand Down Expand Up @@ -291,28 +278,6 @@ class EnableBatchSolver
using value_type = typename ConcreteSolver::value_type;
using Identity = matrix::Identity<value_type>;
using real_type = remove_complex<value_type>;
using batch_dense = matrix::Dense<value_type>;
using batch_csr = matrix::Csr<value_type>;
using batch_ell = matrix::Ell<value_type>;

if (params.col_scaling.get_executor() &&
params.row_scaling.get_executor()) {
GKO_ASSERT_EQ(params.col_scaling.get_size(),
system_matrix->get_common_size()[0] *
system_matrix->get_num_batch_items());
GKO_ASSERT_EQ(params.col_scaling.get_size(),
params.row_scaling.get_size());
if (auto mat = as<batch_dense>(system_matrix)) {
matrix::scale_in_place(params.col_scaling, params.row_scaling,
const_cast<batch_dense*>(mat.get()));
} else if (auto mat = as<batch_csr>(system_matrix)) {
matrix::scale_in_place(params.col_scaling, params.row_scaling,
const_cast<batch_csr*>(mat.get()));
} else if (auto mat = as<batch_ell>(system_matrix)) {
matrix::scale_in_place(params.col_scaling, params.row_scaling,
const_cast<batch_ell*>(mat.get()));
}
}

if (params.generated_preconditioner) {
GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(params.generated_preconditioner,
Expand Down
Loading

0 comments on commit 233eb5d

Please sign in to comment.