Skip to content

Commit

Permalink
[hip][cuda] Merged pending_queue_actions implementations. (#18220)
Browse files Browse the repository at this point in the history
This creates a single deferred work queue that merges the almost
identical implementations of pending_queue_actions.

This leaves the actual underlying implementation unchanged, and simply
moves the work into a shared location. There are some future cleanups
that can be made in order to improve this even further.

---------

Signed-off-by: Andrew Woloszyn <andrew.woloszyn@gmail.com>
  • Loading branch information
AWoloszyn authored Aug 15, 2024
1 parent 3f97c02 commit 75ad937
Show file tree
Hide file tree
Showing 16 changed files with 1,131 additions and 2,172 deletions.
3 changes: 1 addition & 2 deletions runtime/src/iree/hal/drivers/cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ iree_runtime_cc_library(
"nccl_channel.h",
"nop_executable_cache.c",
"nop_executable_cache.h",
"pending_queue_actions.c",
"pending_queue_actions.h",
"pipeline_layout.c",
"pipeline_layout.h",
"stream_command_buffer.c",
Expand Down Expand Up @@ -66,6 +64,7 @@ iree_runtime_cc_library(
"//runtime/src/iree/hal",
"//runtime/src/iree/hal/utils:collective_batch",
"//runtime/src/iree/hal/utils:deferred_command_buffer",
"//runtime/src/iree/hal/utils:deferred_work_queue",
"//runtime/src/iree/hal/utils:file_transfer",
"//runtime/src/iree/hal/utils:memory_file",
"//runtime/src/iree/hal/utils:resource_set",
Expand Down
3 changes: 1 addition & 2 deletions runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ iree_cc_library(
"nccl_channel.h"
"nop_executable_cache.c"
"nop_executable_cache.h"
"pending_queue_actions.c"
"pending_queue_actions.h"
"pipeline_layout.c"
"pipeline_layout.h"
"stream_command_buffer.c"
Expand All @@ -63,6 +61,7 @@ iree_cc_library(
iree::hal
iree::hal::utils::collective_batch
iree::hal::utils::deferred_command_buffer
iree::hal::utils::deferred_work_queue
iree::hal::utils::file_transfer
iree::hal::utils::memory_file
iree::hal::utils::resource_set
Expand Down
253 changes: 234 additions & 19 deletions runtime/src/iree/hal/drivers/cuda/cuda_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
#include "iree/hal/drivers/cuda/nccl_channel.h"
#include "iree/hal/drivers/cuda/nccl_dynamic_symbols.h"
#include "iree/hal/drivers/cuda/nop_executable_cache.h"
#include "iree/hal/drivers/cuda/pending_queue_actions.h"
#include "iree/hal/drivers/cuda/pipeline_layout.h"
#include "iree/hal/drivers/cuda/stream_command_buffer.h"
#include "iree/hal/drivers/cuda/timepoint_pool.h"
#include "iree/hal/drivers/cuda/tracing.h"
#include "iree/hal/utils/deferred_command_buffer.h"
#include "iree/hal/utils/deferred_work_queue.h"
#include "iree/hal/utils/file_transfer.h"
#include "iree/hal/utils/memory_file.h"

Expand Down Expand Up @@ -76,7 +76,7 @@ typedef struct iree_hal_cuda_device_t {
// are met. It buffers submissions and allocations internally before they
// are ready. This queue couples with HAL semaphores backed by iree_event_t
// and CUevent objects.
iree_hal_cuda_pending_queue_actions_t* pending_queue_actions;
iree_hal_deferred_work_queue_t* work_queue;

// Device memory pools and allocators.
bool supports_memory_pools;
Expand All @@ -88,6 +88,176 @@ typedef struct iree_hal_cuda_device_t {
} iree_hal_cuda_device_t;

static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable;
static const iree_hal_deferred_work_queue_device_interface_vtable_t
iree_hal_cuda_deferred_work_queue_device_interface_vtable;

// We put a CUEvent into a iree_hal_deferred_work_queue_native_event_t.
static_assert(sizeof(CUevent) <=
sizeof(iree_hal_deferred_work_queue_native_event_t),
"Unexpected event size");
typedef struct iree_hal_cuda_deferred_work_queue_device_interface_t {
iree_hal_deferred_work_queue_device_interface_t base;
iree_hal_device_t* device;
CUdevice cu_device;
CUcontext cu_context;
CUstream dispatch_cu_stream;
iree_allocator_t host_allocator;
const iree_hal_cuda_dynamic_symbols_t* cuda_symbols;
} iree_hal_cuda_deferred_work_queue_device_interface_t;

static void iree_hal_cuda_deferred_work_queue_device_interface_destroy(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
iree_allocator_free(device_interface->host_allocator, device_interface);
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_bind_to_thread(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols,
cuCtxSetCurrent(device_interface->cu_context),
"cuCtxSetCurrent");
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_wait_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_native_event_t event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuStreamWaitEvent(device_interface->dispatch_cu_stream, (CUevent)event,
CU_EVENT_WAIT_DEFAULT),
"cuStreamWaitEvent");
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_create_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_native_event_t* out_event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuEventCreate((CUevent*)out_event, CU_EVENT_WAIT_DEFAULT),
"cuEventCreate");
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_record_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_native_event_t event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuEventRecord((CUevent)event, device_interface->dispatch_cu_stream),
"cuEventCreate");
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_synchronize_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_native_event_t event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols,
cuEventSynchronize((CUevent)event));
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_destroy_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_native_event_t event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols,
cuEventDestroy((CUevent)event));
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_semaphore_acquire_timepoint_device_signal_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
struct iree_hal_semaphore_t* semaphore, uint64_t value,
iree_hal_deferred_work_queue_native_event_t* out_event) {
return iree_hal_cuda_event_semaphore_acquire_timepoint_device_signal(
semaphore, value, (CUevent*)out_event);
}

static bool
iree_hal_cuda_deferred_work_queue_device_interface_acquire_host_wait_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
struct iree_hal_semaphore_t* semaphore, uint64_t value,
iree_hal_deferred_work_queue_host_device_event_t* out_event) {
return iree_hal_cuda_semaphore_acquire_event_host_wait(
semaphore, value, (iree_hal_cuda_event_t**)out_event);
}

static void
iree_hal_cuda_deferred_work_queue_device_interface_release_wait_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_host_device_event_t wait_event) {
iree_hal_cuda_event_release(wait_event);
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_device_wait_on_host_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_host_device_event_t wait_event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuStreamWaitEvent(
device_interface->dispatch_cu_stream,
iree_hal_cuda_event_handle((iree_hal_cuda_event_t*)wait_event), 0),
"cuStreamWaitEvent");
}

static void*
iree_hal_cuda_deferred_work_queue_device_interface_native_event_from_wait_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_host_device_event_t event) {
return iree_hal_cuda_event_handle((iree_hal_cuda_event_t*)event);
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_create_stream_command_buffer(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t categories,
iree_hal_command_buffer_t** out) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return iree_hal_cuda_device_create_stream_command_buffer(
device_interface->device, mode, categories, 0, out);
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_submit_command_buffer(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_command_buffer_t* command_buffer) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
iree_status_t status = iree_ok_status();
if (iree_hal_cuda_stream_command_buffer_isa(command_buffer)) {
// Stream command buffer so nothing to do but notify it was submitted.
iree_hal_cuda_stream_notify_submitted_commands(command_buffer);
} else {
CUgraphExec exec =
iree_hal_cuda_graph_command_buffer_handle(command_buffer);
status = IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuGraphLaunch(exec, device_interface->dispatch_cu_stream));
if (IREE_LIKELY(iree_status_is_ok(status))) {
iree_hal_cuda_graph_tracing_notify_submitted_commands(command_buffer);
}
}
return status;
}

static iree_hal_cuda_device_t* iree_hal_cuda_device_cast(
iree_hal_device_t* base_value) {
Expand Down Expand Up @@ -152,9 +322,27 @@ static iree_status_t iree_hal_cuda_device_create_internal(
device->dispatch_cu_stream = dispatch_stream;
device->host_allocator = host_allocator;

iree_status_t status = iree_hal_cuda_pending_queue_actions_create(
cuda_symbols, cu_device, context, &device->block_pool, host_allocator,
&device->pending_queue_actions);
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface;
iree_status_t status = iree_allocator_malloc(
host_allocator,
sizeof(iree_hal_cuda_deferred_work_queue_device_interface_t),
(void**)&device_interface);
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
iree_hal_device_release((iree_hal_device_t*)device);
return status;
}
device_interface->base.vtable =
&iree_hal_cuda_deferred_work_queue_device_interface_vtable;
device_interface->cu_context = context;
device_interface->cuda_symbols = cuda_symbols;
device_interface->cu_device = cu_device;
device_interface->device = (iree_hal_device_t*)device;
device_interface->dispatch_cu_stream = dispatch_stream;
device_interface->host_allocator = host_allocator;

status = iree_hal_deferred_work_queue_create(
(iree_hal_deferred_work_queue_device_interface_t*)device_interface,
&device->block_pool, host_allocator, &device->work_queue);

// Enable tracing for the (currently only) stream - no-op if disabled.
if (iree_status_is_ok(status) && device->params.stream_tracing) {
Expand Down Expand Up @@ -297,8 +485,7 @@ static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) {
IREE_TRACE_ZONE_BEGIN(z0);

// Destroy the pending workload queue.
iree_hal_cuda_pending_queue_actions_destroy(
(iree_hal_resource_t*)device->pending_queue_actions);
iree_hal_deferred_work_queue_destroy(device->work_queue);

// There should be no more buffers live that use the allocator.
iree_hal_allocator_release(device->device_allocator);
Expand Down Expand Up @@ -620,7 +807,7 @@ static iree_status_t iree_hal_cuda_device_create_semaphore(
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
return iree_hal_cuda_event_semaphore_create(
initial_value, device->cuda_symbols, device->timepoint_pool,
device->pending_queue_actions, device->host_allocator, out_semaphore);
device->work_queue, device->host_allocator, out_semaphore);
}

static iree_hal_semaphore_compatibility_t
Expand Down Expand Up @@ -765,15 +952,13 @@ static iree_status_t iree_hal_cuda_device_queue_execute(
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);

iree_status_t status = iree_hal_cuda_pending_queue_actions_enqueue_execution(
base_device, device->dispatch_cu_stream, device->pending_queue_actions,
iree_hal_cuda_device_collect_tracing_context, device->tracing_context,
wait_semaphore_list, signal_semaphore_list, command_buffer_count,
command_buffers, binding_tables);
iree_status_t status = iree_hal_deferred_work_queue_enque(
device->work_queue, iree_hal_cuda_device_collect_tracing_context,
device->tracing_context, wait_semaphore_list, signal_semaphore_list,
command_buffer_count, command_buffers, binding_tables);
if (iree_status_is_ok(status)) {
// Try to advance the pending workload queue.
status = iree_hal_cuda_pending_queue_actions_issue(
device->pending_queue_actions);
// Try to advance the deferred work queue.
status = iree_hal_deferred_work_queue_issue(device->work_queue);
}

IREE_TRACE_ZONE_END(z0);
Expand All @@ -784,9 +969,8 @@ static iree_status_t iree_hal_cuda_device_queue_flush(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);
// Try to advance the pending workload queue.
iree_status_t status =
iree_hal_cuda_pending_queue_actions_issue(device->pending_queue_actions);
// Try to advance the deferred work queue.
iree_status_t status = iree_hal_deferred_work_queue_issue(device->work_queue);
IREE_TRACE_ZONE_END(z0);
return status;
}
Expand Down Expand Up @@ -850,3 +1034,34 @@ static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable = {
.profiling_flush = iree_hal_cuda_device_profiling_flush,
.profiling_end = iree_hal_cuda_device_profiling_end,
};

static const iree_hal_deferred_work_queue_device_interface_vtable_t
iree_hal_cuda_deferred_work_queue_device_interface_vtable = {
.destroy = iree_hal_cuda_deferred_work_queue_device_interface_destroy,
.bind_to_thread =
iree_hal_cuda_deferred_work_queue_device_interface_bind_to_thread,
.wait_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_wait_native_event,
.create_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_create_native_event,
.record_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_record_native_event,
.synchronize_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_synchronize_native_event,
.destroy_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_destroy_native_event,
.semaphore_acquire_timepoint_device_signal_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_semaphore_acquire_timepoint_device_signal_native_event,
.acquire_host_wait_event =
iree_hal_cuda_deferred_work_queue_device_interface_acquire_host_wait_event,
.device_wait_on_host_event =
iree_hal_cuda_deferred_work_queue_device_interface_device_wait_on_host_event,
.release_wait_event =
iree_hal_cuda_deferred_work_queue_device_interface_release_wait_event,
.native_event_from_wait_event =
iree_hal_cuda_deferred_work_queue_device_interface_native_event_from_wait_event,
.create_stream_command_buffer =
iree_hal_cuda_deferred_work_queue_device_interface_create_stream_command_buffer,
.submit_command_buffer =
iree_hal_cuda_deferred_work_queue_device_interface_submit_command_buffer,
};
Loading

0 comments on commit 75ad937

Please sign in to comment.