Skip to content

Commit

Permalink
fix CUDA 9.2 (this time for real)
Browse files Browse the repository at this point in the history
  • Loading branch information
upsj committed Sep 17, 2022
1 parent 158ae0f commit 053fda1
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions common/cuda_hip/matrix/csr_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -1122,17 +1122,11 @@ void fallback_transpose(std::shared_ptr<const DefaultExecutor> exec,
out_col_idxs);
exec->copy(nnz, in_vals, out_vals);
exec->copy(nnz, in_col_idxs, out_row_idxs.get_data());
-auto zip_it = thrust::make_zip_iterator(
-thrust::make_tuple(thrust::device_pointer_cast(out_row_idxs.get_data()),
-thrust::device_pointer_cast(out_col_idxs),
-thrust::device_pointer_cast(out_vals)));
+auto loc_it = thrust::make_zip_iterator(
+thrust::make_tuple(out_row_idxs.get_data(), out_col_idxs));
 using tuple_type =
-thrust::tuple<IndexType, IndexType, device_member_type<ValueType>>;
-thrust::sort(thrust::device, zip_it, zip_it + nnz,
-[] __device__(const tuple_type& a, const tuple_type& b) {
-return thrust::tie(thrust::get<0>(a), thrust::get<1>(a)) <
-thrust::tie(thrust::get<0>(b), thrust::get<1>(b));
-});
+thrust::tuple<IndexType, IndexType, device_type<ValueType>>;
+thrust::sort_by_key(thrust::device, loc_it, loc_it + nnz, out_vals);
components::convert_idxs_to_ptrs(exec, out_row_idxs.get_data(), nnz,
out_num_rows, out_row_ptrs);
}

0 comments on commit 053fda1

Please sign in to comment.