Skip to content

Commit

Permalink
HIP: workaround runtime bug in hipGraph support
Browse files Browse the repository at this point in the history
  • Loading branch information
IMbackK committed Feb 19, 2025
1 parent 300907b commit dbf12a6
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2558,7 +2558,14 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
#if defined(__HIP_PLATFORM_AMD__)
// Workaroudnd for https://github.com/ROCm/clr/issues/138
// The hip runtime fails to copy this and calls delete on it later, so we must
// perform an alloc here.
cuda_ctx->cuda_graph->params[i].kernelParams[1] = new void*(*updated_kernel_arg_ptr);
#else
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
#endif // defined(__HIP_PLATFORM_AMD__)
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
}
}
Expand Down

0 comments on commit dbf12a6

Please sign in to comment.