From db388873d51d797484804d4908a67d36ff4aee0c Mon Sep 17 00:00:00 2001 From: Tobias Ribizel Date: Tue, 6 Jun 2023 15:02:49 +0200 Subject: [PATCH] use column Cholesky for GPU --- .../factorization/cholesky_kernels.hpp.inc | 47 +++++++------------ 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/common/cuda_hip/factorization/cholesky_kernels.hpp.inc b/common/cuda_hip/factorization/cholesky_kernels.hpp.inc index f87969a7ad0..eb90127a8ca 100644 --- a/common/cuda_hip/factorization/cholesky_kernels.hpp.inc +++ b/common/cuda_hip/factorization/cholesky_kernels.hpp.inc @@ -149,8 +149,6 @@ __global__ __launch_bounds__(default_block_size) void symbolic_factorize( template __global__ __launch_bounds__(default_block_size) void factorize( const IndexType* __restrict__ row_ptrs, const IndexType* __restrict__ cols, - const IndexType* __restrict__ elim_tree_child_ptrs, - const IndexType* __restrict__ elim_tree_children, const IndexType* __restrict__ storage_offsets, const int32* __restrict__ storage, const int64* __restrict__ row_descs, const IndexType* __restrict__ diag_idxs, @@ -171,32 +169,21 @@ __global__ __launch_bounds__(default_block_size) void factorize( const auto row_begin = row_ptrs[row]; const auto row_diag = diag_idxs[row]; const auto row_end = row_ptrs[row + 1]; - const auto child_begin = elim_tree_child_ptrs[row]; - const auto child_end = elim_tree_child_ptrs[row + 1]; gko::matrix::csr::device_sparsity_lookup lookup{ row_ptrs, cols, storage_offsets, storage, row_descs, static_cast(row)}; - for (auto child = child_begin; child < child_end; child++) { - const auto dep = elim_tree_children[child]; - scheduler.wait(dep); - // TODO evaluate parallel waiting with __all_sync - } - // for each lower triangular entry: eliminate with corresponding row + // for each lower triangular entry: eliminate with corresponding column for (auto lower_nz = row_begin; lower_nz < row_diag; lower_nz++) { const auto dep = cols[lower_nz]; - auto val = vals[lower_nz]; + scheduler.wait(dep); + const auto scale = vals[lower_nz]; const auto diag_idx = diag_idxs[dep]; const auto dep_end = row_ptrs[dep + 1]; - const auto diag = vals[diag_idx]; - const auto scale = val / diag; - if (lane == 0) { - vals[lower_nz] = scale; - } - // subtract all entries past the diagonal - for (auto upper_nz = diag_idx + 1 + lane; upper_nz < dep_end; + // subtract column dep from current column + for (auto upper_nz = diag_idx + lane; upper_nz < dep_end; upper_nz += config::warp_size) { const auto upper_col = cols[upper_nz]; - if (upper_col < row) { + if (upper_col >= row) { const auto upper_val = vals[upper_nz]; const auto output_pos = lookup.lookup_unsafe(upper_col) + row_begin; @@ -204,17 +191,16 @@ __global__ __launch_bounds__(default_block_size) void factorize( } } } - ValueType sum{}; - for (auto lower_nz = row_begin + lane; lower_nz < row_diag; - lower_nz += config::warp_size) { - sum += squared_norm(vals[lower_nz]); - // copy the lower triangular entries to the transpose - vals[transpose_idxs[lower_nz]] = conj(vals[lower_nz]); + auto diag_val = sqrt(vals[row_diag]); + for (auto upper_nz = row_diag + 1 + lane; upper_nz < row_end; + upper_nz += config::warp_size) { + vals[upper_nz] /= diag_val; + // copy the upper triangular entries to the transpose + vals[transpose_idxs[upper_nz]] = conj(vals[upper_nz]); } - sum = reduce(warp, sum, thrust::plus{}); if (lane == 0) { // store computed diagonal - vals[row_diag] = sqrt(vals[row_diag] - sum); + vals[row_diag] = diag_val; } scheduler.mark_ready(); } @@ -365,10 +351,9 @@ void factorize(std::shared_ptr exec, kernel::factorize<<get_stream()>>>( factors->get_const_row_ptrs(), factors->get_const_col_idxs(), - forest.child_ptrs.get_const_data(), - forest.children.get_const_data(), lookup_offsets, lookup_storage, - lookup_descs, diag_idxs, transpose_idxs, - as_device_type(factors->get_values()), storage, num_rows); + lookup_offsets, lookup_storage, lookup_descs, diag_idxs, + transpose_idxs, as_device_type(factors->get_values()), storage, + num_rows); } }