Skip to content

Commit

Permalink
Merge improved Cholesky performance
Browse files Browse the repository at this point in the history
This switches to a column Cholesky implementation, which simplifies the dependency structure
and significantly speeds up the factorization, bringing it on par with LU.

Related PR: #1366
  • Loading branch information
upsj authored Jul 26, 2023
2 parents 24223b4 + db38887 commit 945a4d8
Show file tree
Hide file tree
Showing 2 changed files with 295 additions and 412 deletions.
47 changes: 16 additions & 31 deletions common/cuda_hip/factorization/cholesky_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,6 @@ __global__ __launch_bounds__(default_block_size) void symbolic_factorize(
template <typename ValueType, typename IndexType>
__global__ __launch_bounds__(default_block_size) void factorize(
const IndexType* __restrict__ row_ptrs, const IndexType* __restrict__ cols,
const IndexType* __restrict__ elim_tree_child_ptrs,
const IndexType* __restrict__ elim_tree_children,
const IndexType* __restrict__ storage_offsets,
const int32* __restrict__ storage, const int64* __restrict__ row_descs,
const IndexType* __restrict__ diag_idxs,
Expand All @@ -171,50 +169,38 @@ __global__ __launch_bounds__(default_block_size) void factorize(
const auto row_begin = row_ptrs[row];
const auto row_diag = diag_idxs[row];
const auto row_end = row_ptrs[row + 1];
const auto child_begin = elim_tree_child_ptrs[row];
const auto child_end = elim_tree_child_ptrs[row + 1];
gko::matrix::csr::device_sparsity_lookup<IndexType> lookup{
row_ptrs, cols, storage_offsets,
storage, row_descs, static_cast<size_type>(row)};
for (auto child = child_begin; child < child_end; child++) {
const auto dep = elim_tree_children[child];
scheduler.wait(dep);
// TODO evaluate parallel waiting with __all_sync
}
// for each lower triangular entry: eliminate with corresponding row
// for each lower triangular entry: eliminate with corresponding column
for (auto lower_nz = row_begin; lower_nz < row_diag; lower_nz++) {
const auto dep = cols[lower_nz];
auto val = vals[lower_nz];
scheduler.wait(dep);
const auto scale = vals[lower_nz];
const auto diag_idx = diag_idxs[dep];
const auto dep_end = row_ptrs[dep + 1];
const auto diag = vals[diag_idx];
const auto scale = val / diag;
if (lane == 0) {
vals[lower_nz] = scale;
}
// subtract all entries past the diagonal
for (auto upper_nz = diag_idx + 1 + lane; upper_nz < dep_end;
// subtract column dep from current column
for (auto upper_nz = diag_idx + lane; upper_nz < dep_end;
upper_nz += config::warp_size) {
const auto upper_col = cols[upper_nz];
if (upper_col < row) {
if (upper_col >= row) {
const auto upper_val = vals[upper_nz];
const auto output_pos =
lookup.lookup_unsafe(upper_col) + row_begin;
vals[output_pos] -= scale * upper_val;
}
}
}
ValueType sum{};
for (auto lower_nz = row_begin + lane; lower_nz < row_diag;
lower_nz += config::warp_size) {
sum += squared_norm(vals[lower_nz]);
// copy the lower triangular entries to the transpose
vals[transpose_idxs[lower_nz]] = conj(vals[lower_nz]);
auto diag_val = sqrt(vals[row_diag]);
for (auto upper_nz = row_diag + 1 + lane; upper_nz < row_end;
upper_nz += config::warp_size) {
vals[upper_nz] /= diag_val;
// copy the upper triangular entries to the transpose
vals[transpose_idxs[upper_nz]] = conj(vals[upper_nz]);
}
sum = reduce(warp, sum, thrust::plus<ValueType>{});
if (lane == 0) {
// store computed diagonal
vals[row_diag] = sqrt(vals[row_diag] - sum);
vals[row_diag] = diag_val;
}
scheduler.mark_ready();
}
Expand Down Expand Up @@ -365,10 +351,9 @@ void factorize(std::shared_ptr<const DefaultExecutor> exec,
kernel::factorize<<<num_blocks, default_block_size, 0,
exec->get_stream()>>>(
factors->get_const_row_ptrs(), factors->get_const_col_idxs(),
forest.child_ptrs.get_const_data(),
forest.children.get_const_data(), lookup_offsets, lookup_storage,
lookup_descs, diag_idxs, transpose_idxs,
as_device_type(factors->get_values()), storage, num_rows);
lookup_offsets, lookup_storage, lookup_descs, diag_idxs,
transpose_idxs, as_device_type(factors->get_values()), storage,
num_rows);
}
}

Expand Down
Loading

0 comments on commit 945a4d8

Please sign in to comment.