Skip to content

Commit

Permalink
only allow const access to local vector
Browse files Browse the repository at this point in the history
this adds in turn mutable access through get_local_values and at_local

Co-authored-by: Tobias Ribizel <ribizel@kit.edu>
  • Loading branch information
MarcelKoch and upsj committed May 4, 2022
1 parent 3e65ffd commit ab8b482
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 101 deletions.
110 changes: 57 additions & 53 deletions core/distributed/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,9 @@ void Vector<ValueType>::read_distributed(
global_cols));

auto rank = this->get_communicator().rank();
this->get_local()->fill(zero<ValueType>());
local_.fill(zero<ValueType>());
exec->run(vector::make_build_local(
data, make_temporary_clone(exec, partition).get(), rank,
this->get_local()));
data, make_temporary_clone(exec, partition).get(), rank, &local_));
}


Expand All @@ -157,7 +156,7 @@ void Vector<ValueType>::read_distributed(
// Sets the entries of the local vector to `value` by forwarding to the local
// vector's own fill().
template <typename ValueType>
void Vector<ValueType>::fill(const ValueType value)
{
// NOTE(review): the two statements below look like the before/after forms of
// the same call (accessor vs. direct member access) from a rendered diff --
// confirm that only one of them is meant to remain.
this->get_local()->fill(value);
local_.fill(value);
}


Expand All @@ -168,7 +167,7 @@ void Vector<ValueType>::convert_to(
GKO_ASSERT(this->get_communicator().size() ==
result->get_communicator().size());
result->set_size(this->get_size());
this->get_const_local()->convert_to(result->get_local());
this->get_local_vector()->convert_to(&result->local_);
}


Expand All @@ -187,10 +186,10 @@ Vector<ValueType>::compute_absolute() const

auto result =
absolute_type::create(exec, this->get_communicator(), this->get_size(),
this->get_const_local()->get_size());
this->get_local_vector()->get_size());

exec->run(vector::make_outplace_absolute_dense(this->get_const_local(),
result->get_local()));
exec->run(vector::make_outplace_absolute_dense(this->get_local_vector(),
&result->local_));

return result;
}
Expand All @@ -199,20 +198,13 @@ Vector<ValueType>::compute_absolute() const
// Forwards to compute_absolute_inplace() of the local vector.
template <typename ValueType>
void Vector<ValueType>::compute_absolute_inplace()
{
// NOTE(review): both statements perform the same forwarded call, once via
// the get_local() accessor and once via the member directly; they look like
// the old and new line of a diff -- confirm which one is intended.
this->get_local()->compute_absolute_inplace();
local_.compute_absolute_inplace();
}


// Returns a read-only pointer to the local vector member `local_`.
// The pointer refers to a member of this object, so it does not transfer
// ownership.
template <typename ValueType>
const typename Vector<ValueType>::local_vector_type*
Vector<ValueType>::get_const_local() const
{
return &local_;
}


template <typename ValueType>
typename Vector<ValueType>::local_vector_type* Vector<ValueType>::get_local()
Vector<ValueType>::get_local_vector() const
{
return &local_;
}
Expand All @@ -224,8 +216,8 @@ Vector<ValueType>::make_complex() const
{
auto result = complex_type::create(
this->get_executor(), this->get_communicator(), this->get_size(),
this->get_const_local()->get_size(),
this->get_const_local()->get_stride());
this->get_local_vector()->get_size(),
this->get_local_vector()->get_stride());
this->make_complex(result.get());
return result;
}
Expand All @@ -234,18 +226,18 @@ Vector<ValueType>::make_complex() const
// Delegates to make_complex() of this vector's local vector, writing into the
// local vector of `result`.
template <typename ValueType>
void Vector<ValueType>::make_complex(Vector::complex_type* result) const
{
// NOTE(review): the two statements below are the accessor-based and the
// direct-member form of the same delegation (apparently the old and new
// diff line) -- confirm which one is intended.
this->get_const_local()->make_complex(result->get_local());
this->get_local_vector()->make_complex(&result->local_);
}


// Creates a new real_type vector using this vector's executor, communicator,
// global size, local size and local stride, fills it through the
// get_real(result) overload and returns it.
template <typename ValueType>
std::unique_ptr<typename Vector<ValueType>::real_type>
Vector<ValueType>::get_real() const
{
// NOTE(review): `result` is declared twice below with equivalent arguments
// (get_const_local() vs. get_local_vector()); this looks like the old and
// new form of one statement from a rendered diff -- only one can remain.
auto result =
real_type::create(this->get_executor(), this->get_communicator(),
this->get_size(), this->get_const_local()->get_size(),
this->get_const_local()->get_stride());
auto result = real_type::create(this->get_executor(),
this->get_communicator(), this->get_size(),
this->get_local_vector()->get_size(),
this->get_local_vector()->get_stride());
this->get_real(result.get());
return result;
}
Expand All @@ -254,18 +246,18 @@ Vector<ValueType>::get_real() const
// Delegates to get_real() of this vector's local vector, writing into the
// local vector of `result`.
template <typename ValueType>
void Vector<ValueType>::get_real(Vector::real_type* result) const
{
// NOTE(review): accessor-based and direct-member form of the same delegation
// appear back to back (apparently old/new diff lines) -- confirm which one
// is intended.
this->get_const_local()->get_real(result->get_local());
this->get_local_vector()->get_real(&result->local_);
}


// Creates a new real_type vector using this vector's executor, communicator,
// global size, local size and local stride, fills it through the
// get_imag(result) overload and returns it.
template <typename ValueType>
std::unique_ptr<typename Vector<ValueType>::real_type>
Vector<ValueType>::get_imag() const
{
// NOTE(review): `result` is declared twice below with equivalent arguments
// (get_const_local() vs. get_local_vector()); this looks like the old and
// new form of one statement from a rendered diff -- only one can remain.
auto result =
real_type::create(this->get_executor(), this->get_communicator(),
this->get_size(), this->get_const_local()->get_size(),
this->get_const_local()->get_stride());
auto result = real_type::create(this->get_executor(),
this->get_communicator(), this->get_size(),
this->get_local_vector()->get_size(),
this->get_local_vector()->get_stride());
this->get_imag(result.get());
return result;
}
Expand All @@ -274,37 +266,37 @@ Vector<ValueType>::get_imag() const
// Delegates to get_imag() of this vector's local vector, writing into the
// local vector of `result`.
template <typename ValueType>
void Vector<ValueType>::get_imag(Vector::real_type* result) const
{
// NOTE(review): accessor-based and direct-member form of the same delegation
// appear back to back (apparently old/new diff lines) -- confirm which one
// is intended.
this->get_const_local()->get_imag(result->get_local());
this->get_local_vector()->get_imag(&result->local_);
}


// Forwards `alpha` to scale() of the local vector.
template <typename ValueType>
void Vector<ValueType>::scale(const LinOp* alpha)
{
// NOTE(review): as written the forwarded call is issued twice (via
// get_local() and via local_ directly); this looks like the old and new
// line of a diff -- confirm that only one is meant to remain.
this->get_local()->scale(alpha);
local_.scale(alpha);
}


// Forwards `alpha` to inv_scale() of the local vector.
template <typename ValueType>
void Vector<ValueType>::inv_scale(const LinOp* alpha)
{
// NOTE(review): as written the forwarded call is issued twice (via
// get_local() and via local_ directly); this looks like the old and new
// line of a diff -- confirm that only one is meant to remain.
this->get_local()->inv_scale(alpha);
local_.inv_scale(alpha);
}


// Converts `b` to a Vector<ValueType> with as<>() and forwards `alpha`
// together with b's local vector to add_scaled() of this vector's local
// vector.
template <typename ValueType>
void Vector<ValueType>::add_scaled(const LinOp* alpha, const LinOp* b)
{
auto dense_b = as<Vector<ValueType>>(b);
// NOTE(review): as written the update is issued twice (accessor-based and
// direct-member form); this looks like the old and new line of a diff --
// confirm that only one is meant to remain.
this->get_local()->add_scaled(alpha, dense_b->get_const_local());
local_.add_scaled(alpha, dense_b->get_local_vector());
}


// Converts `b` to a Vector<ValueType> with as<>() and forwards `alpha`
// together with b's local vector to sub_scaled() of this vector's local
// vector.
template <typename ValueType>
void Vector<ValueType>::sub_scaled(const LinOp* alpha, const LinOp* b)
{
auto dense_b = as<Vector<ValueType>>(b);
// NOTE(review): as written the update is issued twice (accessor-based and
// direct-member form); this looks like the old and new line of a diff --
// confirm that only one is meant to remain.
this->get_local()->sub_scaled(alpha, dense_b->get_const_local());
local_.sub_scaled(alpha, dense_b->get_local_vector());
}


Expand All @@ -315,8 +307,8 @@ void Vector<ValueType>::compute_dot(const LinOp* b, LinOp* result) const
const auto comm = this->get_communicator();
auto dense_res =
make_temporary_clone(exec, as<matrix::Dense<ValueType>>(result));
this->get_const_local()->compute_dot(as<Vector>(b)->get_const_local(),
dense_res.get());
this->get_local_vector()->compute_dot(as<Vector>(b)->get_local_vector(),
dense_res.get());
exec->synchronize();
auto use_host_buffer =
exec->get_master() != exec && !gko::mpi::is_gpu_aware();
Expand All @@ -340,8 +332,8 @@ void Vector<ValueType>::compute_conj_dot(const LinOp* b, LinOp* result) const
const auto comm = this->get_communicator();
auto dense_res =
make_temporary_clone(exec, as<matrix::Dense<ValueType>>(result));
this->get_const_local()->compute_conj_dot(as<Vector>(b)->get_const_local(),
dense_res.get());
this->get_local_vector()->compute_conj_dot(
as<Vector>(b)->get_local_vector(), dense_res.get());
exec->synchronize();
auto use_host_buffer =
exec->get_master() != exec && !gko::mpi::is_gpu_aware();
Expand All @@ -366,7 +358,7 @@ void Vector<ValueType>::compute_norm2(LinOp* result) const
auto exec = this->get_executor();
const auto comm = this->get_communicator();
auto dense_res = make_temporary_clone(exec, as<NormVector>(result));
exec->run(vector::make_compute_squared_norm2(this->get_const_local(),
exec->run(vector::make_compute_squared_norm2(this->get_local_vector(),
dense_res.get()));
exec->synchronize();
auto use_host_buffer =
Expand All @@ -393,7 +385,7 @@ void Vector<ValueType>::compute_norm1(LinOp* result) const
auto exec = this->get_executor();
const auto comm = this->get_communicator();
auto dense_res = make_temporary_clone(exec, as<NormVector>(result));
this->get_const_local()->compute_norm1(dense_res.get());
this->get_local_vector()->compute_norm1(dense_res.get());
exec->synchronize();
auto use_host_buffer =
exec->get_master() != exec && !gko::mpi::is_gpu_aware();
Expand All @@ -411,26 +403,38 @@ void Vector<ValueType>::compute_norm1(LinOp* result) const


template <typename ValueType>
void Vector<ValueType>::resize(dim<2> global_size, dim<2> local_size)
ValueType& Vector<ValueType>::at_local(size_type row, size_type col) noexcept
{
if (this->get_size() != global_size) {
this->set_size(global_size);
}
this->get_local()->resize(local_size);
return local_.at(row, col);
}

// Read-only element access: returns a copy of the local vector's entry at
// position (row, col).
template <typename ValueType>
ValueType Vector<ValueType>::at_local(size_type row,
                                      size_type col) const noexcept
{
    ValueType entry = local_.at(row, col);
    return entry;
}

template <typename ValueType>
std::unique_ptr<typename Vector<ValueType>::real_type>
Vector<ValueType>::create_real_view()
ValueType& Vector<ValueType>::at_local(size_type idx) noexcept
{
const auto num_global_rows = this->get_size()[0];
const auto num_cols =
is_complex<ValueType>() ? 2 * this->get_size()[1] : this->get_size()[1];
return local_.at(idx);
}

return real_type::create(this->get_executor(), this->get_communicator(),
dim<2>{num_global_rows, num_cols},
local_.create_real_view().get());
// Read-only element access: returns a copy of the local vector's entry at
// the linear index `idx`.
template <typename ValueType>
ValueType Vector<ValueType>::at_local(size_type idx) const noexcept
{
    ValueType entry = local_.at(idx);
    return entry;
}


// Updates the stored global size when it differs from `global_size`, then
// resizes the local vector to `local_size`.
template <typename ValueType>
void Vector<ValueType>::resize(dim<2> global_size, dim<2> local_size)
{
    const bool global_size_differs = this->get_size() != global_size;
    if (global_size_differs) {
        this->set_size(global_size);
    }
    // The local vector is resized unconditionally.
    local_.resize(local_size);
}


Expand Down
68 changes: 55 additions & 13 deletions include/ginkgo/core/distributed/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ namespace distributed {
* Using this approach the size of the global vectors, as well as the size of
* the local vectors, will be automatically inferred. It is possible to create a
* vector with specified global and local sizes and fill the local vectors using
* the accessor get_local.
* the accessor get_local_vector.
*
* @note Operations between two vectors (axpy, dot product, etc.) are only valid
* if both vectors were created using the same partition.
Expand All @@ -86,6 +86,7 @@ class Vector
friend class EnableCreateMethod<Vector<ValueType>>;
friend class EnablePolymorphicObject<Vector<ValueType>, LinOp>;
friend class Vector<to_complex<ValueType>>;
friend class Vector<remove_complex<ValueType>>;
friend class Vector<next_precision<ValueType>>;

public:
Expand Down Expand Up @@ -267,19 +268,65 @@ class Vector
void compute_norm1(LinOp* result) const;

/**
* Direct (read) access to the underlying local local_vector_type vectors.
* Returns a single element of the multi-vector.
*
* @return a constant pointer to the underlying local_vector_type vectors
* @param row the local row of the requested element
* @param col the local column of the requested element
*
* @note the method has to be called on the same Executor the multi-vector
* is stored at (e.g. trying to call this method on a GPU multi-vector from
* the OMP results in a runtime error)
*/
const local_vector_type* get_const_local() const;
value_type& at_local(size_type row, size_type col) noexcept;

/*
* Direct (read/write) access to the underlying local_vector_type Dense
* vectors.
/**
* @copydoc Vector::at_local(size_type, size_type)
*/
value_type at_local(size_type row, size_type col) const noexcept;

/**
* Returns a single element of the multi-vector.
*
* Useful for iterating across all elements of the multi-vector.
* However, it is less efficient than the two-parameter variant of this
* method.
*
* @param idx a linear index of the requested element
* (ignoring the stride)
*
* @note the method has to be called on the same Executor the multi-vector is
* stored at (e.g. trying to call this method on a GPU multi-vector from
* the OMP results in a runtime error)
*/
ValueType& at_local(size_type idx) noexcept;

/**
* @copydoc Vector::at_local(size_type)
*/
ValueType at_local(size_type idx) const noexcept;

/**
* Returns a pointer to the array of local values of the multi-vector.
*
* @return the pointer to the array of local values
*/
value_type* get_local_values();

/**
* @copydoc get_local_values()
*
* @note This is the constant version of the function, which can be
* significantly more memory efficient than the non-constant version,
* so always prefer this version.
*/
const value_type* get_const_local_values();

/**
* Direct (read) access to the underlying local vector, which is of type
* local_vector_type.
*
* @return a constant pointer to the underlying local vector
*/
local_vector_type* get_local();
const local_vector_type* get_local_vector() const;

/**
* Create a real view of the (potentially) complex original multi-vector.
Expand All @@ -288,11 +335,6 @@ class Vector
* real with a reinterpret_cast with twice the number of columns and
* double the stride.
*/
std::unique_ptr<real_type> create_real_view();

/**
* @copydoc create_real_view()
*/
std::unique_ptr<const real_type> create_real_view() const;

protected:
Expand Down
Loading

0 comments on commit ab8b482

Please sign in to comment.