Skip to content

Commit

Permalink
add pointer_param
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Feb 15, 2023
1 parent e303e8f commit 26d6fea
Show file tree
Hide file tree
Showing 23 changed files with 734 additions and 343 deletions.
22 changes: 14 additions & 8 deletions core/distributed/matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,10 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::move_to(
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
const device_matrix_data<value_type, global_index_type>& data,
const Partition<local_index_type, global_index_type>* row_partition,
const Partition<local_index_type, global_index_type>* col_partition)
pointer_param<const Partition<local_index_type, global_index_type>>
row_partition,
pointer_param<const Partition<local_index_type, global_index_type>>
col_partition)
{
const auto comm = this->get_communicator();
GKO_ASSERT_EQ(data.get_size()[0], row_partition->get_size());
Expand Down Expand Up @@ -171,8 +173,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(

// build local, non-local matrix data and communication structures
exec->run(matrix::make_build_local_nonlocal(
data, make_temporary_clone(exec, row_partition).get(),
make_temporary_clone(exec, col_partition).get(), local_part,
data, make_temporary_clone(exec, row_partition.get()).get(),
make_temporary_clone(exec, col_partition.get()).get(), local_part,
local_row_idxs, local_col_idxs, local_values, non_local_row_idxs,
non_local_col_idxs, non_local_values, recv_gather_idxs,
recv_sizes_array, non_local_to_global_));
Expand Down Expand Up @@ -228,8 +230,10 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
const matrix_data<value_type, global_index_type>& data,
const Partition<local_index_type, global_index_type>* row_partition,
const Partition<local_index_type, global_index_type>* col_partition)
pointer_param<const Partition<local_index_type, global_index_type>>
row_partition,
pointer_param<const Partition<local_index_type, global_index_type>>
col_partition)
{
this->read_distributed(
device_matrix_data<value_type, global_index_type>::create_from_host(
Expand All @@ -241,7 +245,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
const matrix_data<ValueType, global_index_type>& data,
const Partition<local_index_type, global_index_type>* partition)
pointer_param<const Partition<local_index_type, global_index_type>>
partition)
{
this->read_distributed(
device_matrix_data<value_type, global_index_type>::create_from_host(
Expand All @@ -253,7 +258,8 @@ void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
// Single-partition overload of read_distributed: forwards to the overload
// taking separate row and column partitions, passing `partition` for both.
template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void Matrix<ValueType, LocalIndexType, GlobalIndexType>::read_distributed(
const device_matrix_data<ValueType, GlobalIndexType>& data,
const Partition<local_index_type, global_index_type>* partition)
pointer_param<const Partition<local_index_type, global_index_type>>
partition)
{
this->read_distributed(data, partition, partition);
}
Expand Down
128 changes: 87 additions & 41 deletions core/distributed/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
template <typename ValueType>
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
mpi::communicator comm, dim<2> global_size,
local_vector_type* local_vector)
pointer_param<local_vector_type> local_vector)
: EnableDistributedLinOp<Vector<ValueType>>{exec, global_size},
DistributedBase{comm},
local_{exec}
Expand All @@ -114,7 +114,7 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
template <typename ValueType>
Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,
mpi::communicator comm,
local_vector_type* local_vector)
pointer_param<local_vector_type> local_vector)
: EnableDistributedLinOp<Vector<ValueType>>{exec, {}},
DistributedBase{comm},
local_{exec}
Expand All @@ -126,7 +126,7 @@ Vector<ValueType>::Vector(std::shared_ptr<const Executor> exec,

template <typename ValueType>
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_config_of(
const Vector* other)
pointer_param<const Vector> other)
{
// De-referencing `other` before calling the functions (instead of
// using operator `->`) is currently required to be compatible with
Expand All @@ -138,16 +138,18 @@ std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_config_of(

// Creates a new vector by calling `create_with_type_of_impl` on `other` with
// the given executor, empty global and local sizes, and a stride of 0.
// `other` is de-referenced explicitly (`(*other).`) rather than through `->`.
template <typename ValueType>
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_type_of(
const Vector<ValueType>* other, std::shared_ptr<const Executor> exec)
pointer_param<const Vector<ValueType>> other,
std::shared_ptr<const Executor> exec)
{
return (*other).create_with_type_of_impl(exec, {}, {}, 0);
}


template <typename ValueType>
std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_type_of(
const Vector<ValueType>* other, std::shared_ptr<const Executor> exec,
const dim<2>& global_size, const dim<2>& local_size, size_type stride)
pointer_param<const Vector<ValueType>> other,
std::shared_ptr<const Executor> exec, const dim<2>& global_size,
const dim<2>& local_size, size_type stride)
{
return (*other).create_with_type_of_impl(exec, global_size, local_size,
stride);
Expand All @@ -156,7 +158,7 @@ std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_type_of(

template <typename ValueType>
template <typename LocalIndexType, typename GlobalIndexType>
void Vector<ValueType>::read_distributed(
void Vector<ValueType>::read_distributed_impl(
const device_matrix_data<ValueType, GlobalIndexType>& data,
const Partition<LocalIndexType, GlobalIndexType>* partition)
{
Expand All @@ -175,16 +177,65 @@ void Vector<ValueType>::read_distributed(


// Thin wrapper around read_distributed_impl: the partition argument is
// unwrapped with `.get()` and passed on together with `data`.
template <typename ValueType>
template <typename LocalIndexType, typename GlobalIndexType>
void Vector<ValueType>::read_distributed(
const matrix_data<ValueType, GlobalIndexType>& data,
const Partition<LocalIndexType, GlobalIndexType>* partition)
const device_matrix_data<ValueType, int64>& data,
pointer_param<const Partition<int64, int64>> partition)
{
this->read_distributed_impl(data, partition.get());
}


// read_distributed overload for device data with int64 global indices and a
// Partition<int32, int64>. The partition wrapper is unwrapped once and the
// result is handed to the shared read_distributed_impl.
template <typename ValueType>
void Vector<ValueType>::read_distributed(
    const device_matrix_data<ValueType, int64>& data,
    pointer_param<const Partition<int32, int64>> partition)
{
    const auto unwrapped_partition = partition.get();
    this->read_distributed_impl(data, unwrapped_partition);
}


// read_distributed overload for device data with int32 global indices and a
// Partition<int32, int32>. The partition wrapper is unwrapped once and the
// result is handed to the shared read_distributed_impl.
template <typename ValueType>
void Vector<ValueType>::read_distributed(
    const device_matrix_data<ValueType, int32>& data,
    pointer_param<const Partition<int32, int32>> partition)
{
    const auto unwrapped_partition = partition.get();
    this->read_distributed_impl(data, unwrapped_partition);
}


// read_distributed overload for host matrix_data with int64 global indices
// and a Partition<int64, int64>. The host data is converted with
// device_matrix_data::create_from_host on this vector's executor, and the
// result is forwarded to read_distributed together with `partition`.
template <typename ValueType>
void Vector<ValueType>::read_distributed(
    const matrix_data<ValueType, int64>& data,
    pointer_param<const Partition<int64, int64>> partition)
{
    auto exec = this->get_executor();
    this->read_distributed(
        device_matrix_data<value_type, int64>::create_from_host(exec, data),
        partition);
}


// read_distributed overload for host matrix_data with int64 global indices
// and a Partition<int32, int64>. The host data is converted with
// device_matrix_data::create_from_host on this vector's executor, and the
// result is forwarded to read_distributed together with `partition`.
template <typename ValueType>
void Vector<ValueType>::read_distributed(
    const matrix_data<ValueType, int64>& data,
    pointer_param<const Partition<int32, int64>> partition)
{
    auto exec = this->get_executor();
    this->read_distributed(
        device_matrix_data<value_type, int64>::create_from_host(exec, data),
        partition);
}


// Host-data overload: builds a device_matrix_data from `data` via
// create_from_host on this vector's executor and forwards it, together with
// `partition`, to read_distributed.
template <typename ValueType>
void Vector<ValueType>::read_distributed(
const matrix_data<ValueType, int32>& data,
pointer_param<const Partition<int32, int32>> partition)
{
this->read_distributed(
device_matrix_data<value_type, GlobalIndexType>::create_from_host(
device_matrix_data<value_type, int32>::create_from_host(
this->get_executor(), data),
std::move(partition));
partition);
}


Expand Down Expand Up @@ -259,7 +310,8 @@ Vector<ValueType>::make_complex() const


// Delegates to make_complex of this vector's local vector, passing the
// address of `result`'s `local_` member as the output.
template <typename ValueType>
void Vector<ValueType>::make_complex(Vector::complex_type* result) const
void Vector<ValueType>::make_complex(
pointer_param<Vector::complex_type> result) const
{
this->get_local_vector()->make_complex(&result->local_);
}
Expand All @@ -279,7 +331,7 @@ Vector<ValueType>::get_real() const


// Delegates to get_real of this vector's local vector, passing the address
// of `result`'s `local_` member as the output.
template <typename ValueType>
void Vector<ValueType>::get_real(Vector::real_type* result) const
void Vector<ValueType>::get_real(pointer_param<Vector::real_type> result) const
{
this->get_local_vector()->get_real(&result->local_);
}
Expand All @@ -299,52 +351,56 @@ Vector<ValueType>::get_imag() const


// Delegates to get_imag of this vector's local vector, passing the address
// of `result`'s `local_` member as the output.
template <typename ValueType>
void Vector<ValueType>::get_imag(Vector::real_type* result) const
void Vector<ValueType>::get_imag(pointer_param<Vector::real_type> result) const
{
this->get_local_vector()->get_imag(&result->local_);
}


// Forwards `alpha` unchanged to scale() of the `local_` member.
template <typename ValueType>
void Vector<ValueType>::scale(const LinOp* alpha)
void Vector<ValueType>::scale(pointer_param<const LinOp> alpha)
{
local_.scale(alpha);
}


// Forwards `alpha` unchanged to inv_scale() of the `local_` member.
template <typename ValueType>
void Vector<ValueType>::inv_scale(const LinOp* alpha)
void Vector<ValueType>::inv_scale(pointer_param<const LinOp> alpha)
{
local_.inv_scale(alpha);
}


// Converts `b` to a Vector<ValueType> with as<>, then calls add_scaled() on
// the `local_` member with `alpha` and the local vector of the converted `b`.
template <typename ValueType>
void Vector<ValueType>::add_scaled(const LinOp* alpha, const LinOp* b)
void Vector<ValueType>::add_scaled(pointer_param<const LinOp> alpha,
pointer_param<const LinOp> b)
{
auto dense_b = as<Vector<ValueType>>(b);
local_.add_scaled(alpha, dense_b->get_local_vector());
}


// Converts `b` to a Vector<ValueType> with as<>, then calls sub_scaled() on
// the `local_` member with `alpha` and the local vector of the converted `b`.
template <typename ValueType>
void Vector<ValueType>::sub_scaled(const LinOp* alpha, const LinOp* b)
void Vector<ValueType>::sub_scaled(pointer_param<const LinOp> alpha,
pointer_param<const LinOp> b)
{
auto dense_b = as<Vector<ValueType>>(b);
local_.sub_scaled(alpha, dense_b->get_local_vector());
}


// Convenience overload: constructs an array<char> on this object's executor
// and forwards to the compute_dot overload that takes it as `tmp`.
template <typename ValueType>
void Vector<ValueType>::compute_dot(const LinOp* b, LinOp* result) const
void Vector<ValueType>::compute_dot(pointer_param<const LinOp> b,
pointer_param<LinOp> result) const
{
array<char> tmp{this->get_executor()};
this->compute_dot(b, result, tmp);
}


template <typename ValueType>
void Vector<ValueType>::compute_dot(const LinOp* b, LinOp* result,
void Vector<ValueType>::compute_dot(pointer_param<const LinOp> b,
pointer_param<LinOp> result,
array<char>& tmp) const
{
GKO_ASSERT_EQUAL_DIMENSIONS(result, dim<2>(1, this->get_size()[1]));
Expand All @@ -370,15 +426,17 @@ void Vector<ValueType>::compute_dot(const LinOp* b, LinOp* result,


// Convenience overload: constructs an array<char> on this object's executor
// and forwards to the compute_conj_dot overload that takes it as `tmp`.
template <typename ValueType>
void Vector<ValueType>::compute_conj_dot(const LinOp* b, LinOp* result) const
void Vector<ValueType>::compute_conj_dot(pointer_param<const LinOp> b,
pointer_param<LinOp> result) const
{
array<char> tmp{this->get_executor()};
this->compute_conj_dot(b, result, tmp);
}


template <typename ValueType>
void Vector<ValueType>::compute_conj_dot(const LinOp* b, LinOp* result,
void Vector<ValueType>::compute_conj_dot(pointer_param<const LinOp> b,
pointer_param<LinOp> result,
array<char>& tmp) const
{
GKO_ASSERT_EQUAL_DIMENSIONS(result, dim<2>(1, this->get_size()[1]));
Expand All @@ -404,15 +462,16 @@ void Vector<ValueType>::compute_conj_dot(const LinOp* b, LinOp* result,


// Convenience overload: constructs an array<char> on this object's executor
// and forwards to the compute_norm2 overload that takes it as `tmp`.
template <typename ValueType>
void Vector<ValueType>::compute_norm2(LinOp* result) const
void Vector<ValueType>::compute_norm2(pointer_param<LinOp> result) const
{
array<char> tmp{this->get_executor()};
this->compute_norm2(result, tmp);
}


template <typename ValueType>
void Vector<ValueType>::compute_norm2(LinOp* result, array<char>& tmp) const
void Vector<ValueType>::compute_norm2(pointer_param<LinOp> result,
array<char>& tmp) const
{
using NormVector = typename local_vector_type::absolute_type;
GKO_ASSERT_EQUAL_DIMENSIONS(result, dim<2>(1, this->get_size()[1]));
Expand All @@ -437,15 +496,16 @@ void Vector<ValueType>::compute_norm2(LinOp* result, array<char>& tmp) const


// Convenience overload: constructs an array<char> on this object's executor
// and forwards to the compute_norm1 overload that takes it as `tmp`.
template <typename ValueType>
void Vector<ValueType>::compute_norm1(LinOp* result) const
void Vector<ValueType>::compute_norm1(pointer_param<LinOp> result) const
{
array<char> tmp{this->get_executor()};
this->compute_norm1(result, tmp);
}


template <typename ValueType>
void Vector<ValueType>::compute_norm1(LinOp* result, array<char>& tmp) const
void Vector<ValueType>::compute_norm1(pointer_param<LinOp> result,
array<char>& tmp) const
{
using NormVector = typename local_vector_type::absolute_type;
GKO_ASSERT_EQUAL_DIMENSIONS(result, dim<2>(1, this->get_size()[1]));
Expand Down Expand Up @@ -570,20 +630,6 @@ std::unique_ptr<Vector<ValueType>> Vector<ValueType>::create_with_type_of_impl(
GKO_INSTANTIATE_FOR_EACH_VALUE_TYPE(GKO_DECLARE_DISTRIBUTED_VECTOR);


// Names the two read_distributed member-template signatures (one taking
// device_matrix_data, one taking matrix_data) for a given value type and
// local/global index type pair. The macro is passed to
// GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE below.
// NOTE(review): the first signature has no leading `template` keyword here;
// presumably the instantiation macro supplies it -- confirm.
#define GKO_DECLARE_DISTRIBUTED_VECTOR_READ_DISTRIBUTED( \
ValueType, LocalIndexType, GlobalIndexType) \
void Vector<ValueType>::read_distributed<LocalIndexType, GlobalIndexType>( \
const device_matrix_data<ValueType, GlobalIndexType>& data, \
const Partition<LocalIndexType, GlobalIndexType>* partition); \
template void \
Vector<ValueType>::read_distributed<LocalIndexType, GlobalIndexType>( \
const matrix_data<ValueType, GlobalIndexType>& data, \
const Partition<LocalIndexType, GlobalIndexType>* partition)

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_DISTRIBUTED_VECTOR_READ_DISTRIBUTED);


} // namespace distributed
} // namespace experimental
} // namespace gko
Loading

0 comments on commit 26d6fea

Please sign in to comment.