Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add csr lookup implementation without storage #1583

Merged
merged 2 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions common/cuda_hip/matrix/csr_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1140,6 +1140,13 @@ __global__ __launch_bounds__(default_block_size) void build_csr_lookup(
}
return;
}
// if hash lookup is not allowed, we are done here
if (!csr_lookup_allowed(allowed, sparsity_type::hash)) {
if (lane == 0) {
row_desc[row] = static_cast<int64>(sparsity_type::none);
}
return;
}
// sparse hashmap storage
// we need at least one unfilled entry to avoid infinite loops on search
GKO_ASSERT(row_len < available_storage);
Expand Down
4 changes: 3 additions & 1 deletion common/unified/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,10 @@ void build_lookup_offsets(std::shared_ptr<const DefaultExecutor> exec,
if (csr_lookup_allowed(allowed, sparsity_type::bitmap) &&
bitmap_storage <= hashmap_storage) {
storage_offsets[row] = bitmap_storage;
} else {
} else if (csr_lookup_allowed(allowed, sparsity_type::hash)) {
storage_offsets[row] = hashmap_storage;
} else {
storage_offsets[row] = 0;
}
}
},
Expand Down
37 changes: 34 additions & 3 deletions core/matrix/csr_lookup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ namespace csr {
* a single bit set.
*/
enum class sparsity_type {
/**
* The row has no precomputed lookup-information associated with it, so the
* nonzero location needs to be located from the column indices explicitly.
*/
none = 0,
/**
* The row is dense, i.e. it contains all entries in
* `[min_col, min_col + storage_size)`.
Expand Down Expand Up @@ -148,9 +153,9 @@ struct device_sparsity_lookup {
return lookup_bitmap(col);
case sparsity_type::hash:
return lookup_hash(col);
default:
return lookup_search(col);
}
GKO_ASSERT(false);
return invalid_index<IndexType>();
}

/**
Expand All @@ -176,7 +181,8 @@ struct device_sparsity_lookup {
result = lookup_hash_unsafe(col);
break;
default:
GKO_ASSERT(false);
result = lookup_search_unsafe(col);
break;
}
GKO_ASSERT(local_cols[result] == col);
return result;
Expand Down Expand Up @@ -290,6 +296,31 @@ struct device_sparsity_lookup {
// out_idx is either correct or invalid_index, the hashmap sentinel
return out_idx;
}

GKO_ATTRIBUTES GKO_INLINE IndexType
lookup_search_unsafe(IndexType col) const
{
// binary search through the column indices
auto length = row_nnz;
IndexType offset{};
while (length > 0) {
auto half_length = length / 2;
auto mid = offset + half_length;
// this finds the first index with column index >= col
auto pred = local_cols[mid] >= col;
length = pred ? half_length : length - (half_length + 1);
offset = pred ? offset : mid + 1;
}
return offset;
}

GKO_ATTRIBUTES GKO_INLINE IndexType lookup_search(IndexType col) const
{
const auto offset = lookup_search_unsafe(col);
return offset < row_nnz && local_cols[offset] == col
? offset
: invalid_index<IndexType>();
}
};


Expand Down
10 changes: 8 additions & 2 deletions dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2799,6 +2799,7 @@ void build_lookup(std::shared_ptr<const DpcppExecutor> exec,
const IndexType* storage_offsets, int64* row_desc,
int32* storage)
{
using matrix::csr::sparsity_type;
exec->get_queue()->submit([&](sycl::handler& cgh) {
cgh.parallel_for(sycl::range<1>{num_rows}, [=](sycl::id<1> idx) {
const auto row = static_cast<size_type>(idx[0]);
Expand All @@ -2820,8 +2821,13 @@ void build_lookup(std::shared_ptr<const DpcppExecutor> exec,
row_desc[row], local_storage, local_cols);
}
if (!done) {
csr_lookup_build_hash(row_len, available_storage, row_desc[row],
local_storage, local_cols);
if (csr_lookup_allowed(allowed, sparsity_type::hash)) {
csr_lookup_build_hash(row_len, available_storage,
row_desc[row], local_storage,
local_cols);
} else {
row_desc[row] = static_cast<int64>(sparsity_type::none);
}
}
});
});
Expand Down
9 changes: 7 additions & 2 deletions omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,7 @@ void build_lookup(std::shared_ptr<const DefaultExecutor> exec,
const IndexType* storage_offsets, int64* row_desc,
int32* storage)
{
using matrix::csr::sparsity_type;
#pragma omp parallel for
for (size_type row = 0; row < num_rows; row++) {
const auto row_begin = row_ptrs[row];
Expand All @@ -1386,8 +1387,12 @@ void build_lookup(std::shared_ptr<const DefaultExecutor> exec,
row_desc[row], local_storage, local_cols);
}
if (!done) {
csr_lookup_build_hash(row_len, available_storage, row_desc[row],
local_storage, local_cols);
if (csr_lookup_allowed(allowed, sparsity_type::hash)) {
csr_lookup_build_hash(row_len, available_storage, row_desc[row],
local_storage, local_cols);
} else {
row_desc[row] = static_cast<int64>(sparsity_type::none);
}
}
}
}
Expand Down
14 changes: 11 additions & 3 deletions reference/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "core/components/prefix_sum_kernels.hpp"
#include "core/matrix/csr_accessor_helper.hpp"
#include "core/matrix/csr_builder.hpp"
#include "core/matrix/csr_lookup.hpp"
#include "reference/components/csr_spgeam.hpp"


Expand Down Expand Up @@ -1297,8 +1298,10 @@ void build_lookup_offsets(std::shared_ptr<const ReferenceExecutor> exec,
if (csr_lookup_allowed(allowed, sparsity_type::bitmap) &&
bitmap_storage <= hashmap_storage) {
storage_offsets[row] = bitmap_storage;
} else {
} else if (csr_lookup_allowed(allowed, sparsity_type::hash)) {
storage_offsets[row] = hashmap_storage;
} else {
storage_offsets[row] = 0;
}
}
}
Expand Down Expand Up @@ -1397,6 +1400,7 @@ void build_lookup(std::shared_ptr<const ReferenceExecutor> exec,
const IndexType* storage_offsets, int64* row_desc,
int32* storage)
{
using matrix::csr::sparsity_type;
for (size_type row = 0; row < num_rows; row++) {
const auto row_begin = row_ptrs[row];
const auto row_len = row_ptrs[row + 1] - row_begin;
Expand All @@ -1415,8 +1419,12 @@ void build_lookup(std::shared_ptr<const ReferenceExecutor> exec,
row_desc[row], local_storage, local_cols);
}
if (!done) {
csr_lookup_build_hash(row_len, available_storage, row_desc[row],
local_storage, local_cols);
if (csr_lookup_allowed(allowed, sparsity_type::hash)) {
csr_lookup_build_hash(row_len, available_storage, row_desc[row],
local_storage, local_cols);
} else {
row_desc[row] = static_cast<int64>(sparsity_type::none);
}
}
}
}
Expand Down
30 changes: 18 additions & 12 deletions reference/test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2598,6 +2598,7 @@ TYPED_TEST_SUITE(CsrLookup, gko::test::ValueIndexTypes,
TYPED_TEST(CsrLookup, GeneratesLookupDataOffsets)
{
using IndexType = typename TestFixture::index_type;
using gko::matrix::csr::csr_lookup_allowed;
using gko::matrix::csr::sparsity_type;
const auto num_rows = this->mtx->get_size()[0];
gko::array<IndexType> storage_offset_array(this->exec, num_rows + 1);
Expand All @@ -2608,19 +2609,19 @@ TYPED_TEST(CsrLookup, GeneratesLookupDataOffsets)
for (auto allowed :
{sparsity_type::full | sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::full | sparsity_type::hash, sparsity_type::hash}) {
sparsity_type::full | sparsity_type::hash, sparsity_type::hash,
sparsity_type::none}) {
gko::kernels::reference::csr::build_lookup_offsets(
this->exec, row_ptrs, col_idxs, num_rows, allowed, storage_offsets);
bool allow_full =
gko::matrix::csr::csr_lookup_allowed(allowed, sparsity_type::full);
bool allow_bitmap = gko::matrix::csr::csr_lookup_allowed(
allowed, sparsity_type::bitmap);
bool allow_full = csr_lookup_allowed(allowed, sparsity_type::full);
bool allow_bitmap = csr_lookup_allowed(allowed, sparsity_type::bitmap);
bool allow_hash = csr_lookup_allowed(allowed, sparsity_type::hash);

for (gko::size_type row = 0; row < num_rows; row++) {
const auto expected_size =
std::min(allow_full ? this->full_sizes[row] : 1000,
std::min(allow_bitmap ? this->bitmap_sizes[row] : 1000,
this->hash_sizes[row]));
const auto expected_size = std::min(
allow_full ? this->full_sizes[row] : 1000,
std::min(allow_bitmap ? this->bitmap_sizes[row] : 1000,
allow_hash ? this->hash_sizes[row] : IndexType{}));
const auto size = storage_offsets[row + 1] - storage_offsets[row];

ASSERT_EQ(size, expected_size);
Expand All @@ -2644,16 +2645,21 @@ TYPED_TEST(CsrLookup, GeneratesLookupData)
for (auto allowed :
{sparsity_type::full | sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::full | sparsity_type::hash, sparsity_type::hash}) {
sparsity_type::full | sparsity_type::hash, sparsity_type::hash,
sparsity_type::none}) {
gko::kernels::reference::csr::build_lookup_offsets(
this->exec, row_ptrs, col_idxs, num_rows, allowed, storage_offsets);
gko::array<gko::int32> storage_array(this->exec,
storage_offsets[num_rows]);
const auto storage = storage_array.get_data();
const auto hash_equivalent =
csr_lookup_allowed(allowed, sparsity_type::hash)
? sparsity_type::hash
: sparsity_type::none;
const auto bitmap_equivalent =
csr_lookup_allowed(allowed, sparsity_type::bitmap)
? sparsity_type::bitmap
: sparsity_type::hash;
: hash_equivalent;
const auto full_equivalent =
csr_lookup_allowed(allowed, sparsity_type::full)
? sparsity_type::full
Expand Down Expand Up @@ -2687,7 +2693,7 @@ TYPED_TEST(CsrLookup, GeneratesLookupData)
ASSERT_EQ(row_descs[0] & 0xF, static_cast<int>(full_equivalent));
ASSERT_EQ(row_descs[1] & 0xF, static_cast<int>(full_equivalent));
ASSERT_EQ(row_descs[2] & 0xF, static_cast<int>(bitmap_equivalent));
ASSERT_EQ(row_descs[3] & 0xF, static_cast<int>(sparsity_type::hash));
ASSERT_EQ(row_descs[3] & 0xF, static_cast<int>(hash_equivalent));
ASSERT_EQ(row_descs[4] & 0xF, static_cast<int>(full_equivalent));
ASSERT_EQ(row_descs[5] & 0xF, static_cast<int>(full_equivalent));
}
Expand Down
3 changes: 2 additions & 1 deletion test/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ TYPED_TEST(CsrLookup, BuildLookupWorks)
for (auto allowed :
{sparsity_type::full | sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::bitmap | sparsity_type::hash,
sparsity_type::full | sparsity_type::hash, sparsity_type::hash}) {
sparsity_type::full | sparsity_type::hash, sparsity_type::hash,
sparsity_type::none}) {
// check that storage offsets are calculated correctly
// otherwise things might crash
gko::kernels::reference::csr::build_lookup_offsets(
Expand Down
Loading