diff --git a/runtime/src/iree/hal/drivers/hip/cleanup_thread.c b/runtime/src/iree/hal/drivers/hip/cleanup_thread.c index e1b15f860288..95c7e206b7b9 100644 --- a/runtime/src/iree/hal/drivers/hip/cleanup_thread.c +++ b/runtime/src/iree/hal/drivers/hip/cleanup_thread.c @@ -115,13 +115,14 @@ iree_status_t iree_hal_hip_cleanup_thread_initialize( iree_status_t status = iree_thread_create((iree_thread_entry_t)iree_hal_hip_cleanup_thread_main, thread, params, host_allocator, &thread->thread); - if (!iree_status_is_ok(status)) { + if (iree_status_is_ok(status)) { + *out_thread = thread; + } else { iree_hal_hip_callback_queue_deinitialize(&thread->queue); iree_slim_mutex_deinitialize(&thread->mutex); iree_allocator_free(host_allocator, thread); } IREE_TRACE_ZONE_END(z0); - *out_thread = thread; return status; } diff --git a/runtime/src/iree/hal/drivers/hip/dispatch_thread.c b/runtime/src/iree/hal/drivers/hip/dispatch_thread.c index 1c435ff88b25..23313a68836f 100644 --- a/runtime/src/iree/hal/drivers/hip/dispatch_thread.c +++ b/runtime/src/iree/hal/drivers/hip/dispatch_thread.c @@ -108,13 +108,15 @@ iree_status_t iree_hal_hip_dispatch_thread_initialize( iree_status_t status = iree_thread_create((iree_thread_entry_t)iree_hal_hip_dispatch_thread_main, thread, params, host_allocator, &thread->thread); - if (!iree_status_is_ok(status)) { + + if (iree_status_is_ok(status)) { + *out_thread = thread; + } else { iree_hal_hip_dispatch_queue_deinitialize(&thread->queue); iree_slim_mutex_deinitialize(&thread->mutex); iree_allocator_free(host_allocator, thread); } IREE_TRACE_ZONE_END(z0); - *out_thread = thread; return status; } @@ -163,11 +165,10 @@ iree_status_t iree_hal_hip_dispatch_thread_add_dispatch( iree_slim_mutex_unlock(&thread->mutex); iree_notification_post(&thread->notification, IREE_ALL_WAITERS); - IREE_TRACE_ZONE_END(z0); - if (!iree_status_is_ok(status)) { iree_status_ignore(dispatch(user_data, iree_status_clone(status))); } + IREE_TRACE_ZONE_END(z0); // If this was a 
failure then it was put into thread->failure_status. return status; diff --git a/runtime/src/iree/hal/drivers/hip/dispatch_thread.h b/runtime/src/iree/hal/drivers/hip/dispatch_thread.h index e9252d23091b..1f715b094871 100644 --- a/runtime/src/iree/hal/drivers/hip/dispatch_thread.h +++ b/runtime/src/iree/hal/drivers/hip/dispatch_thread.h @@ -11,10 +11,9 @@ #include "iree/hal/api.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" -// iree_hal_hip_dispatch_thread is used simply as a way to get -// work off of the main thread. This is important to do for -// a single reason. There are 2 types of command buffer that we -// use in hip. One is a pre-recorded command buffer +// iree_hal_hip_dispatch_thread is used to get work off of the main thread. +// This is important to do for a single reason. There are 2 types of +// command buffer that we use in hip. One is a pre-recorded command buffer // iree_hal_deferred_command_buffer_t, which when executed // calls all of the associated hipStream based commands. // The other is iree_hal_hip_graph_command_buffer_t which when executed @@ -28,8 +27,8 @@ // work off of the main thread. There are a couple of // caveats, as now we have to move async allocations and deallocations // to that thread as well, as they need to remain in-order. 
- typedef struct iree_hal_hip_dispatch_thread_t iree_hal_hip_dispatch_thread_t; + typedef struct iree_hal_hip_event_t iree_hal_hip_event_t; typedef iree_status_t (*iree_hal_hip_dispatch_callback_t)(void* user_data, diff --git a/runtime/src/iree/hal/drivers/hip/event_pool.c b/runtime/src/iree/hal/drivers/hip/event_pool.c index 6cad294cfb78..95232902b803 100644 --- a/runtime/src/iree/hal/drivers/hip/event_pool.c +++ b/runtime/src/iree/hal/drivers/hip/event_pool.c @@ -172,7 +172,6 @@ iree_status_t iree_hal_hip_event_pool_allocate( event_pool->device_context = device_context; iree_status_t status = iree_hal_hip_set_context(symbols, device_context); - if (iree_status_is_ok(status)) { for (iree_host_size_t i = 0; i < available_capacity; ++i) { status = iree_hal_hip_event_create( @@ -253,22 +252,17 @@ iree_status_t iree_hal_hip_event_pool_acquire( IREE_TRACE_ZONE_APPEND_TEXT(z0, "unpooled acquire"); IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)remaining_count); - iree_status_t status = iree_hal_hip_set_context(event_pool->symbols, - event_pool->device_context); - if (iree_status_is_ok(status)) { - for (iree_host_size_t i = 0; i < remaining_count; ++i) { - status = iree_hal_hip_event_create(event_pool->symbols, event_pool, - event_pool->host_allocator, - &out_events[from_pool_count + i]); - if (!iree_status_is_ok(status)) { - // Must release all events we've acquired so far. - iree_hal_hip_event_pool_release_event(event_pool, from_pool_count + i, - out_events); - IREE_TRACE_ZONE_END(z0); - return status; - } - } + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_hal_hip_set_context(event_pool->symbols, + event_pool->device_context)); + for (iree_host_size_t i = 0; i < remaining_count; ++i) { + iree_status_t status = iree_hal_hip_event_create( + event_pool->symbols, event_pool, event_pool->host_allocator, + &out_events[from_pool_count + i]); if (!iree_status_is_ok(status)) { + // Must release all events we've acquired so far. 
+ iree_hal_hip_event_pool_release_event(event_pool, from_pool_count + i, + out_events); IREE_TRACE_ZONE_END(z0); return status; } diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c index 926bef74a5bb..94d91113bbb6 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c @@ -41,6 +41,19 @@ typedef struct iree_hal_hip_semaphore_work_item_t { struct iree_hal_hip_semaphore_work_item_t* next; } iree_hal_hip_semaphore_work_item_t; +// Work associated with a particular point in the semaphore timeline. +// +// The |work_item| is a set of callbacks to be made when the semaphore +// is guaranteed to make forward progress to the associated key value. They +// will also be cleaned up at this time. If the semaphore is failed, +// the callbacks will be called with the status code of the failure. +// If the semaphore is destroyed while callbacks are active, +// they will be called with the CANCELLED error. +// The |cpu_event| is a value for the CPU to wait on when +// we may not have to wait infinitely. For example with a multi +// wait or a non-infinite timeout. +// The |event| is a hip event that is used for GPU waits or +// infinite CPU waits. typedef struct iree_hal_hip_semaphore_queue_item_t { iree_hal_hip_event_t* event; iree_hal_hip_cpu_event_t* cpu_event; @@ -57,8 +70,11 @@ typedef struct iree_hal_hip_semaphore_t { // The symbols used to issue HIP API calls. const iree_hal_hip_dynamic_symbols_t* symbols; + // This queue represents the values in the timeline. + // The keys in the queue are the timeline values that + // are being signaled/waited on in the semaphore. + // The values are |iree_hal_hip_semaphore_queue_item_t| values. struct { - // The queue of hip events that back any GPU signals of this semaphore. iree_hal_hip_util_tree_t tree; // Inline storage for this tree. 
We expect the normal number of // nodes in use for a single semaphore to be relatively small. @@ -140,11 +156,21 @@ static void iree_hal_hip_semaphore_destroy( for (iree_hal_hip_util_tree_node_t* i = iree_hal_hip_util_tree_first(&semaphore->event_queue.tree); i != NULL; i = iree_hal_hip_util_tree_node_next(i)) { - iree_hal_hip_event_t* event = ((iree_hal_hip_semaphore_queue_item_t*) - iree_hal_hip_util_tree_node_get_value(i)) - ->event; - if (event) { - iree_hal_hip_event_release(event); + iree_hal_hip_semaphore_queue_item_t* queue_item = + (iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(i); + iree_hal_hip_event_release(queue_item->event); + iree_hal_resource_release(queue_item->cpu_event); + iree_hal_hip_semaphore_work_item_t* work_item = queue_item->work_item; + while (work_item) { + work_item->scheduled_callback( + work_item->user_data, base_semaphore, + iree_make_status( + IREE_STATUS_CANCELLED, + "semaphore was destroyed while callback is in flight")); + iree_hal_hip_semaphore_work_item_t* next = work_item->next; + iree_allocator_free(host_allocator, work_item); + work_item = next; } } iree_hal_hip_util_tree_deinitialize(&semaphore->event_queue.tree); @@ -160,11 +186,9 @@ static iree_status_t iree_hal_hip_semaphore_get_cpu_event( *out_event = NULL; iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); - IREE_TRACE_ZONE_BEGIN(z0); iree_slim_mutex_lock(&semaphore->mutex); if (value <= semaphore->current_visible_value) { iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } iree_status_t status = iree_ok_status(); @@ -211,7 +235,6 @@ static iree_status_t iree_hal_hip_semaphore_get_cpu_event( iree_allocator_free(semaphore->host_allocator, item->cpu_event); } } - IREE_TRACE_ZONE_END(z0); return status; } @@ -310,11 +333,12 @@ iree_status_t iree_hal_hip_semaphore_multi_wait( iree_hal_resource_release(&cpu_events[i]->resource); } iree_allocator_free(host_allocator, 
cpu_events); + IREE_TRACE_ZONE_END(z0); return status; } -static iree_status_t iree_hal_hip_event_semaphore_advance( +static iree_status_t iree_hal_hip_event_semaphore_run_scheduled_callbacks( iree_hal_semaphore_t* base_semaphore) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); @@ -344,12 +368,11 @@ static iree_status_t iree_hal_hip_event_semaphore_advance( iree_hal_hip_util_tree_node_get_value(node); iree_hal_hip_util_tree_erase(&semaphore->event_queue.tree, node); iree_slim_mutex_unlock(&semaphore->mutex); - if (copy.event) { - iree_hal_hip_event_release(copy.event); - } + iree_hal_hip_event_release(copy.event); if (copy.cpu_event) { iree_event_set(©.cpu_event->event); iree_hal_resource_release(©.cpu_event->resource); + iree_allocator_free(copy.cpu_event->host_allocator, copy.cpu_event); } iree_hal_hip_semaphore_work_item_t* next_work_item = copy.work_item; @@ -369,6 +392,7 @@ static iree_status_t iree_hal_hip_event_semaphore_advance( semaphore->max_value_to_be_signaled = iree_max( semaphore->max_value_to_be_signaled, semaphore->current_visible_value); iree_status_t status = iree_status_clone(semaphore->failure_status); + iree_slim_mutex_unlock(&semaphore->mutex); // Now that we have accumulated all of the work items, and we have // unlocked the semaphore, start running through the work items. 
@@ -392,7 +416,6 @@ iree_status_t iree_hal_hip_semaphore_notify_work( void* user_data) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); - IREE_TRACE_ZONE_BEGIN(z0); iree_slim_mutex_lock(&semaphore->mutex); iree_status_t status = iree_status_clone(semaphore->failure_status); @@ -436,7 +459,6 @@ iree_status_t iree_hal_hip_semaphore_notify_work( if (callback) { status = callback(user_data, base_semaphore, status); } - IREE_TRACE_ZONE_END(z0); return status; } @@ -444,12 +466,10 @@ iree_status_t iree_hal_hip_semaphore_notify_forward_progress_to( iree_hal_semaphore_t* base_semaphore, uint64_t value) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); - IREE_TRACE_ZONE_BEGIN(z0); iree_slim_mutex_lock(&semaphore->mutex); iree_status_t status = iree_status_clone(semaphore->failure_status); if (!iree_status_is_ok(status)) { iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); return status; } iree_hal_hip_semaphore_work_item_t* work_item = NULL; @@ -496,7 +516,6 @@ iree_status_t iree_hal_hip_semaphore_notify_forward_progress_to( iree_allocator_free(semaphore->host_allocator, work_item); work_item = next_work_item; } - IREE_TRACE_ZONE_END(z0); return status; } @@ -507,11 +526,9 @@ iree_status_t iree_hal_hip_semaphore_get_hip_event( iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); *out_hip_event = NULL; - IREE_TRACE_ZONE_BEGIN(z0); iree_slim_mutex_lock(&semaphore->mutex); if (value <= semaphore->current_visible_value) { iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } iree_status_t status = iree_status_clone(semaphore->failure_status); @@ -557,23 +574,17 @@ iree_status_t iree_hal_hip_semaphore_get_hip_event( } iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); - return status; } iree_status_t iree_hal_hip_semaphore_create_event_and_record_if_necessary( iree_hal_semaphore_t* base_semaphore, 
uint64_t value, - hipStream_t dispatch_stream, iree_hal_hip_event_pool_t* event_pool, - iree_hal_hip_event_t** out_hip_event) { + hipStream_t dispatch_stream, iree_hal_hip_event_pool_t* event_pool) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); - *out_hip_event = NULL; - IREE_TRACE_ZONE_BEGIN(z0); iree_slim_mutex_lock(&semaphore->mutex); if (value <= semaphore->current_visible_value) { iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } iree_status_t status = iree_status_clone(semaphore->failure_status); @@ -614,23 +625,15 @@ iree_status_t iree_hal_hip_semaphore_create_event_and_record_if_necessary( dispatch_stream)); } } - if (event) { - iree_hal_hip_event_retain(event); - } - *out_hip_event = event; } } iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); - return status; } static iree_status_t iree_hal_hip_semaphore_query_locked( iree_hal_hip_semaphore_t* semaphore, uint64_t* out_value) { - IREE_TRACE_ZONE_BEGIN(z0); - iree_status_t status = iree_ok_status(); *out_value = semaphore->current_visible_value; iree_hal_hip_util_tree_node_t* node = @@ -672,7 +675,6 @@ static iree_status_t iree_hal_hip_semaphore_query_locked( } } - IREE_TRACE_ZONE_END(z0); return status; } @@ -680,7 +682,6 @@ static iree_status_t iree_hal_hip_semaphore_query( iree_hal_semaphore_t* base_semaphore, uint64_t* out_value) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); - IREE_TRACE_ZONE_BEGIN(z0); iree_slim_mutex_lock(&semaphore->mutex); *out_value = semaphore->current_visible_value; @@ -694,9 +695,71 @@ static iree_status_t iree_hal_hip_semaphore_query( iree_status_ignore(status); status = iree_ok_status(); } + + return iree_status_join( + status, + iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore)); +} + +iree_status_t iree_hal_hip_event_semaphore_advance( + iree_hal_semaphore_t* base_semaphore) { + iree_hal_hip_semaphore_t* semaphore 
= + iree_hal_hip_semaphore_cast(base_semaphore); + + iree_slim_mutex_lock(&semaphore->mutex); + + iree_status_t status = iree_ok_status(); + iree_hal_hip_util_tree_node_t* node = + iree_hal_hip_util_tree_first(&semaphore->event_queue.tree); + + iree_host_size_t highest_value = 0; + while (node) { + if (!((iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node)) + ->event) { + node = iree_hal_hip_util_tree_node_next(node); + continue; + } + + hipError_t err = + semaphore->symbols->hipEventQuery(iree_hal_hip_event_handle( + ((iree_hal_hip_semaphore_queue_item_t*) + iree_hal_hip_util_tree_node_get_value(node)) + ->event)); + if (err == hipErrorNotReady) { + break; + } + if (err != hipSuccess) { + status = IREE_HIP_RESULT_TO_STATUS(semaphore->symbols, err); + break; + } + + highest_value = iree_hal_hip_util_tree_node_get_key(node); + node = iree_hal_hip_util_tree_node_next(node); + } + + if (iree_status_is_ok(status)) { + if (semaphore->current_visible_value < highest_value) { + semaphore->current_visible_value = highest_value; + iree_notification_post(&semaphore->state_notification, IREE_ALL_WAITERS); + } + + if (highest_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { + status = + iree_make_status(IREE_STATUS_ABORTED, "the semaphore was aborted"); + } + } + + iree_slim_mutex_unlock(&semaphore->mutex); + // If the status is aborted, we will pick up the real status from + // iree_hal_hip_event_semaphore_run_scheduled_callbacks. 
+ if (iree_status_is_aborted(status)) { + iree_status_ignore(status); + status = iree_ok_status(); + } status = iree_status_join( - status, iree_hal_hip_event_semaphore_advance(base_semaphore)); - IREE_TRACE_ZONE_END(z0); + status, + iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore)); return status; } @@ -704,8 +767,6 @@ static iree_status_t iree_hal_hip_semaphore_signal( iree_hal_semaphore_t* base_semaphore, uint64_t new_value) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); - IREE_TRACE_ZONE_BEGIN(z0); - iree_slim_mutex_lock(&semaphore->mutex); iree_status_t status = iree_ok_status(); @@ -727,9 +788,9 @@ static iree_status_t iree_hal_hip_semaphore_signal( iree_slim_mutex_unlock(&semaphore->mutex); if (iree_status_is_ok(status)) { - status = iree_hal_hip_event_semaphore_advance(base_semaphore); + status = + iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore); } - IREE_TRACE_ZONE_END(z0); return status; } @@ -737,7 +798,6 @@ static void iree_hal_hip_semaphore_fail(iree_hal_semaphore_t* base_semaphore, iree_status_t status) { iree_hal_hip_semaphore_t* semaphore = iree_hal_hip_semaphore_cast(base_semaphore); - IREE_TRACE_ZONE_BEGIN(z0); iree_slim_mutex_lock(&semaphore->mutex); @@ -745,9 +805,7 @@ static void iree_hal_hip_semaphore_fail(iree_hal_semaphore_t* base_semaphore, // do this if we are going from a valid semaphore to a failed one. if (!iree_status_is_ok(semaphore->failure_status)) { // Previous status was not OK; drop our new status. 
- IREE_IGNORE_ERROR(status); iree_slim_mutex_unlock(&semaphore->mutex); - IREE_TRACE_ZONE_END(z0); return; } @@ -756,8 +814,8 @@ static void iree_hal_hip_semaphore_fail(iree_hal_semaphore_t* base_semaphore, semaphore->failure_status = status; iree_slim_mutex_unlock(&semaphore->mutex); - iree_status_ignore(iree_hal_hip_event_semaphore_advance(base_semaphore)); - IREE_TRACE_ZONE_END(z0); + iree_status_ignore( + iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore)); } static iree_status_t iree_hal_hip_semaphore_wait( @@ -786,7 +844,8 @@ static iree_status_t iree_hal_hip_semaphore_wait( iree_slim_mutex_unlock(&semaphore->mutex); // We are going to pick up the correct status from query_locked below. - iree_status_ignore(iree_hal_hip_event_semaphore_advance(base_semaphore)); + iree_status_ignore( + iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore)); // We have to wait for the semaphore to catch up. bool committed = @@ -812,7 +871,8 @@ static iree_status_t iree_hal_hip_semaphore_wait( // value, so we can return. if (semaphore->current_visible_value >= value) { iree_slim_mutex_unlock(&semaphore->mutex); - iree_hal_hip_event_semaphore_advance(base_semaphore); + iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore); + iree_slim_mutex_lock(&semaphore->mutex); } else if (iree_timeout_is_infinite(timeout)) { // This is the fast-path. Since we have an infinite timeout, we can // wait directly on the hip event. 
@@ -826,7 +886,7 @@ static iree_status_t iree_hal_hip_semaphore_wait( IREE_ASSERT( node, "We really should either have an event in the queue that will satisfy" - "this semaphore, (we checked max_value_to_be_signaled above), or we" + "this semaphore (we checked max_value_to_be_signaled above) or we" "should already have signaled (current_visible_value above)"); iree_hal_hip_semaphore_queue_item_t* item = (iree_hal_hip_semaphore_queue_item_t*) @@ -838,11 +898,12 @@ static iree_status_t iree_hal_hip_semaphore_wait( // while we sleep on the event. iree_hal_hip_event_retain(event); iree_slim_mutex_unlock(&semaphore->mutex); - iree_hal_hip_event_semaphore_advance(base_semaphore); + iree_hal_hip_event_semaphore_run_scheduled_callbacks(base_semaphore); status = IREE_HIP_CALL_TO_STATUS( semaphore->symbols, hipEventSynchronize(iree_hal_hip_event_handle(event))); iree_hal_hip_event_release(event); + iree_slim_mutex_lock(&semaphore->mutex); } else { // If we have a non-infinite timeout, this is the slow-path. // because we will end up having to wait for either the @@ -859,17 +920,14 @@ static iree_status_t iree_hal_hip_semaphore_wait( iree_hal_resource_release(&cpu_event->resource); } } + iree_slim_mutex_lock(&semaphore->mutex); } } - // If we are ok status, we expect that the mutex was unlocked. - // If there was an error the mutex is still locked. 
if (iree_status_is_ok(status)) { - iree_slim_mutex_lock(&semaphore->mutex); if (semaphore->current_visible_value >= IREE_HAL_SEMAPHORE_FAILURE_VALUE) { - status = iree_status_join( - status, - iree_make_status(IREE_STATUS_ABORTED, "the semaphore was aborted")); + status = + iree_make_status(IREE_STATUS_ABORTED, "the semaphore was aborted"); } } iree_slim_mutex_unlock(&semaphore->mutex); diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.h b/runtime/src/iree/hal/drivers/hip/event_semaphore.h index c051a0ab6abf..ebbf23992dcb 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.h +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.h @@ -38,7 +38,7 @@ iree_status_t iree_hal_hip_semaphore_multi_wait( iree_allocator_t host_allocator); // Adds a work item to be executed once we have a forward progress -// guarantee on this semaphore to reach a paritcular value. +// guarantee on this semaphore to reach a particular value. // The event pool must be an event pool specifically // for the queue that will be doing the work. 
iree_status_t iree_hal_hip_semaphore_notify_work( @@ -68,10 +68,9 @@ iree_status_t iree_hal_hip_semaphore_get_hip_event( iree_status_t iree_hal_hip_semaphore_create_event_and_record_if_necessary( iree_hal_semaphore_t* base_semaphore, uint64_t value, - hipStream_t dispatch_stream, iree_hal_hip_event_pool_t* event_pool, - iree_hal_hip_event_t** out_hip_event); + hipStream_t dispatch_stream, iree_hal_hip_event_pool_t* event_pool); -static iree_status_t iree_hal_hip_event_semaphore_advance( +iree_status_t iree_hal_hip_event_semaphore_advance( iree_hal_semaphore_t* semaphore); #endif // IREE_HAL_DRIVERS_HIP_EVENT_SEMAPHORE_H_ diff --git a/runtime/src/iree/hal/drivers/hip/hip_allocator.c b/runtime/src/iree/hal/drivers/hip/hip_allocator.c index b6dcb76ea8af..a5fed2b279f8 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_allocator.c +++ b/runtime/src/iree/hal/drivers/hip/hip_allocator.c @@ -445,11 +445,9 @@ static iree_status_t iree_hal_hip_allocator_allocate_buffer( } } - status = iree_status_join( + return iree_status_join( status, IREE_HIP_CALL_TO_STATUS(allocator->symbols, hipCtxPopCurrent(NULL))); - - return status; } static void iree_hal_hip_allocator_deallocate_buffer( diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.c b/runtime/src/iree/hal/drivers/hip/hip_device.c index e64775a4bb35..b5f9f0287855 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.c +++ b/runtime/src/iree/hal/drivers/hip/hip_device.c @@ -433,16 +433,6 @@ iree_status_t iree_hal_hip_device_create( status = iree_hal_hip_device_initialize_internal( driver, identifier, params, device, symbols, nccl_symbols, host_allocator); - } else { - for (iree_host_size_t i = 0; i < device_count && iree_status_is_ok(status); - ++i) { - if (device->devices[i].hip_dispatch_stream) - symbols->hipStreamDestroy(device->devices[i].hip_dispatch_stream); - // NOTE: This function return hipSuccess though doesn't release the - // primaryCtx by design on HIP/HCC path. 
- if (device->devices[i].hip_context) - symbols->hipDevicePrimaryCtxRelease(device->devices[i].hip_device); - } } iree_event_pool_t* host_event_pool = NULL; @@ -453,12 +443,9 @@ iree_status_t iree_hal_hip_device_create( for (iree_host_size_t i = 0; i < device_count && iree_status_is_ok(status); ++i) { - if (iree_status_is_ok(status)) { - status = iree_hal_hip_event_pool_allocate( - symbols, params->event_pool_capacity, host_allocator, - device->devices[i].hip_context, - &device->devices[i].device_event_pool); - } + status = iree_hal_hip_event_pool_allocate( + symbols, params->event_pool_capacity, host_allocator, + device->devices[i].hip_context, &device->devices[i].device_event_pool); } if (iree_status_is_ok(status)) { @@ -483,9 +470,10 @@ const iree_hal_hip_dynamic_symbols_t* iree_hal_hip_device_dynamic_symbols( static void iree_hal_hip_device_destroy(iree_hal_device_t* base_device) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); iree_allocator_t host_allocator = iree_hal_device_host_allocator(base_device); - const iree_hal_hip_dynamic_symbols_t* symbols = device->hip_symbols; IREE_TRACE_ZONE_BEGIN(z0); + const iree_hal_hip_dynamic_symbols_t* symbols = device->hip_symbols; + iree_hal_hip_cleanup_thread_deinitialize(device->cleanup_thread); iree_hal_hip_cleanup_thread_deinitialize(device->buffer_free_thread); @@ -580,6 +568,7 @@ static iree_status_t iree_hal_hip_device_query_attribute( iree_hal_hip_device_t* device, hipDeviceAttribute_t attribute, int64_t* out_value) { IREE_ASSERT_ARGUMENT(out_value); + *out_value = 0; int value = 0; IREE_HIP_RETURN_IF_ERROR( @@ -623,10 +612,8 @@ static iree_status_t iree_hal_hip_device_query_i64( static iree_status_t iree_hal_hip_device_create_channel( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, iree_hal_channel_params_t params, iree_hal_channel_t** out_channel) { - IREE_TRACE_ZONE_BEGIN(z0); iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); if 
(!device->nccl_symbols || !device->nccl_symbols->dylib) { - IREE_TRACE_ZONE_END(z0); return iree_make_status( IREE_STATUS_UNAVAILABLE, "RCCL runtime library version %d.%d and greater not available; " @@ -641,7 +628,6 @@ static iree_status_t iree_hal_hip_device_create_channel( int requested_count = iree_math_count_ones_u64(queue_affinity); // TODO(#12206): properly assign affinity in the compiler. if (requested_count != 64 && requested_count != 1) { - IREE_TRACE_ZONE_END(z0); return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "exactly one participant is allowed in a " "channel but %d were specified", @@ -653,8 +639,7 @@ static iree_status_t iree_hal_hip_device_create_channel( if (device->channel_provider && (params.rank == IREE_HAL_CHANNEL_RANK_DEFAULT || params.count == IREE_HAL_CHANNEL_COUNT_DEFAULT)) { - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, + IREE_RETURN_IF_ERROR( iree_hal_channel_provider_query_default_rank_and_count( device->channel_provider, ¶ms.rank, ¶ms.count), "querying default collective group rank and count"); @@ -667,7 +652,6 @@ static iree_status_t iree_hal_hip_device_create_channel( if (iree_const_byte_span_is_empty(params.id)) { // User wants the default ID. if (!device->channel_provider) { - IREE_TRACE_ZONE_END(z0); return iree_make_status( IREE_STATUS_INVALID_ARGUMENT, "default collective channel ID requested but no channel provider has " @@ -675,19 +659,16 @@ static iree_status_t iree_hal_hip_device_create_channel( } if (params.rank == 0) { // Bootstrap NCCL to get the root ID. - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_hip_nccl_get_unique_id(device->nccl_symbols, &id), + IREE_RETURN_IF_ERROR( + iree_hal_hip_nccl_get_unique_id(device->nccl_symbols, &id), "bootstrapping NCCL root"); } // Exchange NCCL ID with all participants. 
- IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, - iree_hal_channel_provider_exchange_default_id( - device->channel_provider, - iree_make_byte_span((void*)&id, sizeof(id))), - "exchanging NCCL ID with other participants"); + IREE_RETURN_IF_ERROR(iree_hal_channel_provider_exchange_default_id( + device->channel_provider, + iree_make_byte_span((void*)&id, sizeof(id))), + "exchanging NCCL ID with other participants"); } else if (params.id.data_length != IREE_ARRAYSIZE(id.data)) { - IREE_TRACE_ZONE_END(z0); // User provided something but it's not what we expect. return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "NCCL ID must be %zu bytes matching the " @@ -699,7 +680,6 @@ static iree_status_t iree_hal_hip_device_create_channel( } if (iree_hal_hip_nccl_id_is_empty(&id)) { - IREE_TRACE_ZONE_END(z0); // TODO: maybe this is ok? a localhost alias or something? return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "no default NCCL ID specified (all zeros)"); @@ -711,7 +691,6 @@ static iree_status_t iree_hal_hip_device_create_channel( iree_status_t status = iree_hal_hip_nccl_channel_create( device->hip_symbols, device->nccl_symbols, &id, params.rank, params.count, device->host_allocator, out_channel); - IREE_TRACE_ZONE_END(z0); return status; } @@ -722,6 +701,7 @@ static iree_status_t iree_hal_hip_device_create_command_buffer_internal( iree_hip_device_commandbuffer_type_t type, iree_hal_command_buffer_t** out_command_buffer) { IREE_TRACE_ZONE_BEGIN(z0); + *out_command_buffer = NULL; iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); @@ -911,12 +891,13 @@ iree_hal_hip_device_query_semaphore_compatibility( return IREE_HAL_SEMAPHORE_COMPATIBILITY_HOST_ONLY; } -static iree_status_t iree_hal_hip_device_pepare_async_alloc( +static iree_status_t iree_hal_hip_device_prepare_async_alloc( iree_hal_hip_device_t* device, iree_hal_buffer_params_t params, iree_device_size_t allocation_size, iree_hal_buffer_t** IREE_RESTRICT out_buffer) { IREE_TRACE_ZONE_BEGIN(z0); 
IREE_TRACE_ZONE_APPEND_VALUE_I64(z0, (int64_t)allocation_size); + *out_buffer = NULL; iree_hal_buffer_params_canonicalize(¶ms); @@ -944,8 +925,10 @@ typedef enum iree_hal_hip_device_semaphore_buffer_operation_type_e { IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_MAX = IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_ASYNC_DEALLOC, } iree_hal_hip_device_semaphore_buffer_operation_type_t; + typedef struct iree_hal_hip_device_semaphore_buffer_operation_callback_data_t { - iree_atomic_ref_count_t wait_semaphore_count; + iree_allocator_t host_allocator; + iree_atomic_int64_t wait_semaphore_count; iree_hal_hip_device_t* device; iree_hal_queue_affinity_t queue_affinity; iree_hal_semaphore_list_t wait_semaphore_list; @@ -986,9 +969,10 @@ static iree_status_t iree_hal_hip_device_make_buffer_callback_data( (void**)&callback_data)); uint8_t* callback_ptr = (uint8_t*)callback_data + sizeof(*callback_data); - iree_atomic_ref_count_init_value(&callback_data->wait_semaphore_count, - wait_semaphore_list.count); + iree_atomic_store(&callback_data->wait_semaphore_count, + wait_semaphore_list.count, iree_memory_order_relaxed); + callback_data->host_allocator = host_allocator; callback_data->device = device; callback_data->queue_affinity = queue_affinity; @@ -1006,6 +990,9 @@ static iree_status_t iree_hal_hip_device_make_buffer_callback_data( callback_data->wait_semaphore_list.payload_values, wait_semaphore_list.payload_values, wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values)); + for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { + iree_hal_resource_retain(wait_semaphore_list.semaphores[i]); + } callback_ptr += wait_semaphore_list_size; // Copy signal list for later access. 
@@ -1023,6 +1010,9 @@ static iree_status_t iree_hal_hip_device_make_buffer_callback_data( signal_semaphore_list.payload_values, signal_semaphore_list.count * sizeof(*signal_semaphore_list.payload_values)); + for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { + iree_hal_resource_retain(signal_semaphore_list.semaphores[i]); + } callback_ptr += signal_semaphore_list_size; callback_data->buffer = buffer; @@ -1036,6 +1026,23 @@ static iree_status_t iree_hal_hip_device_make_buffer_callback_data( return iree_ok_status(); } +void iree_hal_hip_device_destroy_buffer_callback_data( + iree_hal_hip_device_semaphore_buffer_operation_callback_data_t* data) { + if (!data) { + return; + } + iree_slim_mutex_deinitialize(&data->status_mutex); + for (iree_host_size_t i = 0; i < data->wait_semaphore_list.count; ++i) { + iree_hal_resource_release(data->wait_semaphore_list.semaphores[i]); + } + for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { + iree_hal_resource_release(data->signal_semaphore_list.semaphores[i]); + } + iree_hal_buffer_release(data->buffer); + + iree_allocator_free(data->host_allocator, data); +} + static iree_status_t iree_hal_hip_device_stream_signal_semaphores_and_add_cleanup( iree_hal_hip_device_t* device, iree_hal_hip_cleanup_thread_t* thread, @@ -1043,26 +1050,14 @@ iree_hal_hip_device_stream_signal_semaphores_and_add_cleanup( iree_host_size_t device_ordinal, iree_hal_hip_cleanup_callback_t callback, void* user_data) { IREE_TRACE_ZONE_BEGIN(z0); - iree_status_t status = iree_ok_status(); + iree_status_t status = iree_ok_status(); for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { - iree_hal_hip_event_t* event = NULL; - iree_status_t status = - iree_hal_hip_semaphore_create_event_and_record_if_necessary( - signal_semaphore_list.semaphores[i], - signal_semaphore_list.payload_values[i], - device->devices[device_ordinal].hip_dispatch_stream, - device->devices[device_ordinal].device_event_pool, &event); - - if 
(!iree_status_is_ok(status)) { - break; - } - if (!event) { - status = - iree_make_status(IREE_STATUS_ABORTED, "the hip event is missing"); - break; - } - iree_hal_hip_event_release(event); + status = iree_hal_hip_semaphore_create_event_and_record_if_necessary( + signal_semaphore_list.semaphores[i], + signal_semaphore_list.payload_values[i], + device->devices[device_ordinal].hip_dispatch_stream, + device->devices[device_ordinal].device_event_pool); if (!iree_status_is_ok(status)) { break; } @@ -1076,7 +1071,6 @@ iree_hal_hip_device_stream_signal_semaphores_and_add_cleanup( } iree_hal_hip_event_t* event = NULL; - if (iree_status_is_ok(status)) { status = iree_hal_hip_event_pool_acquire( device->devices[device_ordinal].device_event_pool, 1, &event); @@ -1092,9 +1086,11 @@ iree_hal_hip_device_stream_signal_semaphores_and_add_cleanup( if (iree_status_is_ok(status)) { status = iree_hal_hip_cleanup_thread_add_cleanup(thread, event, callback, user_data); + } else { + iree_hal_hip_event_release(event); } - IREE_TRACE_ZONE_END(z0); + IREE_TRACE_ZONE_END(z0); return status; } @@ -1109,8 +1105,10 @@ static iree_status_t iree_hal_hip_device_make_buffer_free_callback_data( iree_hal_hip_device_t* device, iree_hal_queue_affinity_t queue_affinity, iree_hal_buffer_t* buffer, iree_allocator_t host_allocator, iree_hal_hip_device_buffer_free_callback_data_t** out_data) { - *out_data = NULL; IREE_TRACE_ZONE_BEGIN(z0); + + *out_data = NULL; + iree_hal_hip_device_buffer_free_callback_data_t* callback_data = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(host_allocator, sizeof(*callback_data), @@ -1129,8 +1127,6 @@ static iree_status_t iree_hal_hip_device_make_buffer_free_callback_data( static iree_status_t iree_hal_hip_async_free_buffer(void* user_data, iree_hal_hip_event_t* event, iree_status_t status) { - // Free the event we specifically created. 
- iree_hal_hip_device_buffer_free_callback_data_t* data = (iree_hal_hip_device_buffer_free_callback_data_t*)(user_data); @@ -1161,6 +1157,7 @@ static iree_status_t iree_hal_hip_async_free_buffer(void* user_data, static iree_status_t iree_hal_hip_device_complete_buffer_operation( void* user_data, iree_hal_hip_event_t* event, iree_status_t status) { IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hip_device_semaphore_buffer_operation_callback_data_t* data = (iree_hal_hip_device_semaphore_buffer_operation_callback_data_t*) user_data; @@ -1170,18 +1167,8 @@ static iree_status_t iree_hal_hip_device_complete_buffer_operation( // Notify all of the signal semaphores that they have been incremented. for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { - uint64_t unused_return_value = 0; - // We use query to force the semaphore to update. - iree_status_ignore(iree_hal_semaphore_query( - data->signal_semaphore_list.semaphores[i], &unused_return_value)); - } - - for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { - iree_hal_resource_release(data->signal_semaphore_list.semaphores[i]); - } - - for (iree_host_size_t i = 0; i < data->wait_semaphore_list.count; ++i) { - iree_hal_resource_release(data->wait_semaphore_list.semaphores[i]); + iree_status_ignore(iree_hal_hip_event_semaphore_advance( + data->signal_semaphore_list.semaphores[i])); } if (data->buffer && @@ -1204,11 +1191,8 @@ static iree_status_t iree_hal_hip_device_complete_buffer_operation( } } - // Free the iree_hal_hip_device_semaphore_buffer_operation_callback_data_t - // and the buffer attached. 
- iree_slim_mutex_deinitialize(&data->status_mutex); - iree_hal_buffer_release(data->buffer); - iree_allocator_free(data->device->host_allocator, data); + iree_hal_hip_device_destroy_buffer_callback_data(data); + IREE_TRACE_ZONE_END(z0); return status; } @@ -1218,6 +1202,7 @@ static iree_status_t iree_hal_hip_device_stream_wait_for_semaphores( iree_hal_semaphore_list_t wait_semaphore_list, iree_host_size_t device_ordinal) { IREE_TRACE_ZONE_BEGIN(z0); + iree_status_t status = iree_ok_status(); // TODO(awoloszyn): Because of how hip works, if we only have a single // physical device in the hip_device we could avoid waiting on any of these @@ -1245,24 +1230,26 @@ static iree_status_t iree_hal_hip_device_stream_wait_for_semaphores( iree_hal_hip_event_handle(event), 0)); iree_hal_hip_event_release(event); } + IREE_TRACE_ZONE_END(z0); return status; } static iree_status_t iree_hal_hip_device_perform_buffer_operation_now( void* user_data, iree_status_t status) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hip_device_semaphore_buffer_operation_callback_data_t* data = (iree_hal_hip_device_semaphore_buffer_operation_callback_data_t*) user_data; - IREE_ASSERT_LE(data->type, IREE_HAL_HIP_DEVICE_SEMAPHORE_OPERATION_MAX); - IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hip_device_t* device = data->device; // If we had a semaphore failure then we should propagate it // but not run anything. if (!iree_status_is_ok(data->status)) { - status = data->status; + status = iree_status_join(data->status, status); } int device_ordinal = iree_math_count_trailing_zeros_u64(data->queue_affinity); @@ -1304,12 +1291,9 @@ static iree_status_t iree_hal_hip_device_perform_buffer_operation_now( } } IREE_TRACE_ZONE_END(z3); + const iree_hal_hip_dynamic_symbols_t* symbols = data->device->hip_symbols; if (iree_status_is_ok(status)) { - // Retain the semaphores for the cleanup thread. 
- for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { - iree_hal_resource_retain(data->signal_semaphore_list.semaphores[i]); - } // Data may get deleted any time after adding it to the cleanup, // so retain the symbols here. status = iree_hal_hip_device_stream_signal_semaphores_and_add_cleanup( @@ -1320,17 +1304,12 @@ static iree_status_t iree_hal_hip_device_perform_buffer_operation_now( iree_hal_semaphore_fail(data->signal_semaphore_list.semaphores[i], iree_status_clone(data->status)); } - for (iree_host_size_t i = 0; i < data->wait_semaphore_list.count; ++i) { - iree_hal_resource_release(data->wait_semaphore_list.semaphores[i]); - } - iree_hal_buffer_release(data->buffer); - iree_slim_mutex_deinitialize(&data->status_mutex); - iree_allocator_free(device->host_allocator, data); + iree_hal_hip_device_destroy_buffer_callback_data(data); } - status = iree_status_join( - status, IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL))); + IREE_TRACE_ZONE_END(z0); - return status; + return iree_status_join( + status, IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL))); } static iree_status_t iree_hal_hip_device_semaphore_buffer_operation_callback( @@ -1343,15 +1322,15 @@ static iree_status_t iree_hal_hip_device_semaphore_buffer_operation_callback( data->status = iree_status_join(data->status, status); iree_slim_mutex_unlock(&data->status_mutex); } - if (iree_atomic_ref_count_dec(&data->wait_semaphore_count) != 1) { + if (iree_atomic_fetch_sub(&data->wait_semaphore_count, 1, + iree_memory_order_acq_rel) != 1) { return iree_ok_status(); } int device_ordinal = iree_math_count_trailing_zeros_u64(data->queue_affinity); - // Now the actual buffer_operation happens, as all semaphore have been // satisfied (by satisfied here, we specifically mean that the semaphore has - // been scheduled, not necessarily completed) + // been scheduled, not necessarily completed). 
return iree_hal_hip_dispatch_thread_add_dispatch( data->device->devices[device_ordinal].dispatch_thread, &iree_hal_hip_device_perform_buffer_operation_now, data); @@ -1368,9 +1347,10 @@ static iree_status_t iree_hal_hip_device_queue_alloca( iree_hal_allocator_pool_t pool, iree_hal_buffer_params_t params, iree_device_size_t allocation_size, iree_hal_buffer_t** IREE_RESTRICT out_buffer) { - *out_buffer = NULL; IREE_TRACE_ZONE_BEGIN(z0); + *out_buffer = NULL; + iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); uint64_t queue_affinity_mask = ((iree_hal_queue_affinity_t)1 << device->device_count); @@ -1386,8 +1366,8 @@ static iree_status_t iree_hal_hip_device_queue_alloca( iree_hal_hip_allocator_isa(iree_hal_device_allocator(base_device)))) { iree_hal_buffer_t* buffer = NULL; - status = iree_hal_hip_device_pepare_async_alloc(device, params, - allocation_size, &buffer); + status = iree_hal_hip_device_prepare_async_alloc(device, params, + allocation_size, &buffer); iree_hal_hip_device_semaphore_buffer_operation_callback_data_t* callback_data = NULL; @@ -1401,15 +1381,9 @@ static iree_status_t iree_hal_hip_device_queue_alloca( status = iree_hal_hip_dispatch_thread_add_dispatch( device->devices[device_ordinal].dispatch_thread, &iree_hal_hip_device_perform_buffer_operation_now, callback_data); - *out_buffer = buffer; - IREE_TRACE_ZONE_END(z0); - return status; - } - - if (iree_status_is_ok(status)) { + } else if (iree_status_is_ok(status) && wait_semaphore_list.count != 0) { for (iree_host_size_t i = 0; i < wait_semaphore_list.count && iree_status_is_ok(status); ++i) { - iree_hal_resource_retain(wait_semaphore_list.semaphores[i]); status = iree_status_join( status, iree_hal_hip_semaphore_notify_work( @@ -1420,7 +1394,7 @@ static iree_status_t iree_hal_hip_device_queue_alloca( callback_data)); } } else { - iree_allocator_free(device->host_allocator, callback_data); + iree_hal_hip_device_destroy_buffer_callback_data(callback_data); } if 
(iree_status_is_ok(status)) { @@ -1431,6 +1405,7 @@ static iree_status_t iree_hal_hip_device_queue_alloca( iree_hal_resource_release(&buffer->resource); } } + IREE_TRACE_ZONE_END(z0); return status; } @@ -1498,13 +1473,8 @@ static iree_status_t iree_hal_hip_device_queue_dealloca( status = iree_hal_hip_dispatch_thread_add_dispatch( device->devices[device_ordinal].dispatch_thread, &iree_hal_hip_device_perform_buffer_operation_now, callback_data); - IREE_TRACE_ZONE_END(z0); - return status; - } - - if (iree_status_is_ok(status)) { + } else if (iree_status_is_ok(status) && wait_semaphore_list.count != 0) { for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { - iree_hal_resource_retain(wait_semaphore_list.semaphores[i]); status = iree_status_join( status, iree_hal_hip_semaphore_notify_work( @@ -1515,8 +1485,9 @@ static iree_status_t iree_hal_hip_device_queue_dealloca( callback_data)); } } else { - iree_allocator_free(device->host_allocator, callback_data); + iree_hal_hip_device_destroy_buffer_callback_data(callback_data); } + IREE_TRACE_ZONE_END(z0); return status; } @@ -1542,6 +1513,7 @@ static iree_status_t iree_hal_hip_device_queue_dealloca( if (iree_status_is_ok(status)) { status = iree_hal_semaphore_list_signal(signal_semaphore_list); } + IREE_TRACE_ZONE_END(z0); return status; } @@ -1567,8 +1539,8 @@ static iree_status_t iree_hal_hip_device_queue_read( base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, source_file, source_offset, target_buffer, target_offset, length, flags, options)); - IREE_TRACE_ZONE_END(z0); + IREE_TRACE_ZONE_END(z0); return loop_status; } @@ -1593,12 +1565,14 @@ static iree_status_t iree_hal_hip_device_queue_write( base_device, queue_affinity, wait_semaphore_list, signal_semaphore_list, source_buffer, source_offset, target_file, target_offset, length, flags, options)); + IREE_TRACE_ZONE_END(z0); return loop_status; } typedef struct iree_hal_hip_device_semaphore_submit_callback_data_t { - 
iree_atomic_ref_count_t wait_semaphore_count; + iree_allocator_t host_allocator; + iree_atomic_int64_t wait_semaphore_count; iree_hal_hip_device_t* device; iree_hal_queue_affinity_t queue_affinity; iree_hal_command_buffer_t* command_buffer; @@ -1610,9 +1584,146 @@ typedef struct iree_hal_hip_device_semaphore_submit_callback_data_t { iree_status_t status; } iree_hal_hip_device_semaphore_submit_callback_data_t; +static iree_status_t iree_hal_hip_device_make_callback_data( + iree_hal_hip_device_t* device, iree_allocator_t host_allocator, + iree_arena_block_pool_t* block_pool, + iree_hal_queue_affinity_t queue_affinity, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_hal_command_buffer_t* command_buffer, + iree_hal_buffer_binding_table_t binding_table, + iree_hal_hip_device_semaphore_submit_callback_data_t** out_data) { + IREE_TRACE_ZONE_BEGIN(z0); + + *out_data = NULL; + + // Embed captured tables in the action allocation. 
+ iree_hal_hip_device_semaphore_submit_callback_data_t* callback_data = NULL; + + const iree_host_size_t wait_semaphore_list_size = + wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores) + + wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values); + const iree_host_size_t signal_semaphore_list_size = + signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores) + + signal_semaphore_list.count * + sizeof(*signal_semaphore_list.payload_values); + + const iree_host_size_t payload_size = + binding_table.count * sizeof(*binding_table.bindings); + + const iree_host_size_t total_callback_size = + sizeof(*callback_data) + wait_semaphore_list_size + + signal_semaphore_list_size + payload_size; + IREE_RETURN_AND_END_ZONE_IF_ERROR( + z0, iree_allocator_malloc(host_allocator, total_callback_size, + (void**)&callback_data)); + uint8_t* callback_ptr = (uint8_t*)callback_data + sizeof(*callback_data); + + callback_data->host_allocator = host_allocator; + callback_data->device = device; + + iree_atomic_store(&callback_data->wait_semaphore_count, + wait_semaphore_list.count, iree_memory_order_relaxed); + // Copy wait list for later access. 
+ callback_data->wait_semaphore_list.count = wait_semaphore_list.count; + callback_data->wait_semaphore_list.semaphores = + (iree_hal_semaphore_t**)callback_ptr; + memcpy(callback_data->wait_semaphore_list.semaphores, + wait_semaphore_list.semaphores, + wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores)); + callback_data->wait_semaphore_list.payload_values = + (uint64_t*)(callback_ptr + wait_semaphore_list.count * + sizeof(*wait_semaphore_list.semaphores)); + memcpy( + callback_data->wait_semaphore_list.payload_values, + wait_semaphore_list.payload_values, + wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values)); + for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { + iree_hal_resource_retain(wait_semaphore_list.semaphores[i]); + } + callback_ptr += wait_semaphore_list_size; + + // Copy signal list for later access. + callback_data->signal_semaphore_list.count = signal_semaphore_list.count; + callback_data->signal_semaphore_list.semaphores = + (iree_hal_semaphore_t**)callback_ptr; + memcpy( + callback_data->signal_semaphore_list.semaphores, + signal_semaphore_list.semaphores, + signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores)); + callback_data->signal_semaphore_list.payload_values = + (uint64_t*)(callback_ptr + signal_semaphore_list.count * + sizeof(*signal_semaphore_list.semaphores)); + memcpy(callback_data->signal_semaphore_list.payload_values, + signal_semaphore_list.payload_values, + signal_semaphore_list.count * + sizeof(*signal_semaphore_list.payload_values)); + for (iree_host_size_t i = 0; i < signal_semaphore_list.count; ++i) { + iree_hal_resource_retain(signal_semaphore_list.semaphores[i]); + } + callback_ptr += signal_semaphore_list_size; + + // Copy the execution resources for later access. + callback_data->queue_affinity = queue_affinity; + callback_data->command_buffer = command_buffer; + + // Retain all command buffers and semaphores. 
+ iree_status_t status = + iree_hal_resource_set_allocate(block_pool, &callback_data->resource_set); + if (iree_status_is_ok(status)) { + status = iree_hal_resource_set_insert(callback_data->resource_set, + wait_semaphore_list.count, + wait_semaphore_list.semaphores); + } + if (iree_status_is_ok(status)) { + status = iree_hal_resource_set_insert(callback_data->resource_set, + signal_semaphore_list.count, + signal_semaphore_list.semaphores); + } + if (iree_status_is_ok(status)) { + status = iree_hal_resource_set_insert(callback_data->resource_set, 1, + &command_buffer); + } + + callback_data->binding_table = binding_table; + iree_hal_buffer_binding_t* binding_element_ptr = + (iree_hal_buffer_binding_t*)callback_ptr; + callback_data->binding_table.bindings = binding_element_ptr; + memcpy(binding_element_ptr, binding_table.bindings, + sizeof(*binding_element_ptr) * binding_table.count); + status = iree_hal_resource_set_insert_strided( + callback_data->resource_set, binding_table.count, + callback_data->binding_table.bindings, + offsetof(iree_hal_buffer_binding_t, buffer), + sizeof(iree_hal_buffer_binding_t)); + + callback_data->status = iree_ok_status(); + iree_slim_mutex_initialize(&callback_data->status_mutex); + *out_data = callback_data; + IREE_TRACE_ZONE_END(z0); + return status; +} + +void iree_hal_hip_device_destroy_callback_data( + iree_hal_hip_device_semaphore_submit_callback_data_t* data) { + if (!data) { + return; + } + iree_slim_mutex_deinitialize(&data->status_mutex); + iree_hal_resource_set_free(data->resource_set); + for (iree_host_size_t i = 0; i < data->wait_semaphore_list.count; ++i) { + iree_hal_resource_release(data->wait_semaphore_list.semaphores[i]); + } + for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { + iree_hal_resource_release(data->signal_semaphore_list.semaphores[i]); + } + iree_allocator_free(data->host_allocator, data); +} + static iree_status_t iree_hal_hip_device_complete_submission( void* user_data, 
iree_hal_hip_event_t* event, iree_status_t status) { IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hip_device_semaphore_submit_callback_data_t* data = (iree_hal_hip_device_semaphore_submit_callback_data_t*)user_data; iree_hal_hip_device_t* device = data->device; @@ -1653,35 +1764,24 @@ static iree_status_t iree_hal_hip_device_complete_submission( // Notify all of the signal semaphores that they have been incremented. for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { - uint64_t unused_return_value = 0; - // We use query to force the semaphore to update. - iree_status_ignore(iree_hal_semaphore_query( - data->signal_semaphore_list.semaphores[i], &unused_return_value)); - } - - for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { - iree_hal_resource_release(data->signal_semaphore_list.semaphores[i]); + iree_status_ignore(iree_hal_hip_event_semaphore_advance( + data->signal_semaphore_list.semaphores[i])); } + iree_hal_hip_device_destroy_callback_data(data); - for (iree_host_size_t i = 0; i < data->wait_semaphore_list.count; ++i) { - iree_hal_resource_release(data->wait_semaphore_list.semaphores[i]); - } - // Free the iree_hal_hip_device_semaphore_submit_callback_data_t and - // the resource set attached. 
- iree_hal_resource_set_free(data->resource_set); - iree_slim_mutex_deinitialize(&data->status_mutex); - iree_allocator_free(device->host_allocator, data); IREE_TRACE_ZONE_END(z0); return status; } static iree_status_t iree_hal_hip_device_execute_now(void* user_data, iree_status_t status) { + IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hip_device_semaphore_submit_callback_data_t* data = (iree_hal_hip_device_semaphore_submit_callback_data_t*)user_data; - IREE_TRACE_ZONE_BEGIN(z0); IREE_ASSERT_EQ(iree_math_count_ones_u64(data->queue_affinity), 1, "Cannot execute a command buffer on more than one queue"); + iree_hal_hip_device_t* device = data->device; // If we had a semaphore failure then we should propagate it @@ -1703,7 +1803,6 @@ static iree_status_t iree_hal_hip_device_execute_now(void* user_data, } // We have satisfied all of the waits. - IREE_TRACE_ZONE_BEGIN_NAMED(z1, "iree_hal_hip_device_execute_now_launch"); iree_hal_command_buffer_t* command_buffer = data->command_buffer; if (iree_status_is_ok(status)) { @@ -1768,9 +1867,6 @@ static iree_status_t iree_hal_hip_device_execute_now(void* user_data, const iree_hal_hip_dynamic_symbols_t* symbols = data->device->hip_symbols; if (iree_status_is_ok(status)) { - for (iree_host_size_t i = 0; i < data->signal_semaphore_list.count; ++i) { - iree_hal_resource_retain(data->signal_semaphore_list.semaphores[i]); - } status = iree_hal_hip_device_stream_signal_semaphores_and_add_cleanup( data->device, data->device->cleanup_thread, data->signal_semaphore_list, device_ordinal, iree_hal_hip_device_complete_submission, data); @@ -1781,31 +1877,26 @@ static iree_status_t iree_hal_hip_device_execute_now(void* user_data, iree_hal_semaphore_fail(data->signal_semaphore_list.semaphores[i], iree_status_clone(data->status)); } - for (iree_host_size_t i = 0; i < data->wait_semaphore_list.count; ++i) { - iree_hal_resource_release(data->wait_semaphore_list.semaphores[i]); - } - iree_hal_resource_set_free(data->resource_set); - 
iree_slim_mutex_deinitialize(&data->status_mutex); - iree_allocator_free(device->host_allocator, data); + iree_hal_hip_device_destroy_callback_data(data); } - status = iree_status_join( - status, IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL))); - IREE_TRACE_ZONE_END(z0); - return status; + return iree_status_join( + status, IREE_HIP_CALL_TO_STATUS(symbols, hipCtxPopCurrent(NULL))); } static iree_status_t iree_hal_hip_device_semaphore_submit_callback( void* user_context, iree_hal_semaphore_t* semaphore, iree_status_t status) { iree_hal_hip_device_semaphore_submit_callback_data_t* data = (iree_hal_hip_device_semaphore_submit_callback_data_t*)user_context; + if (!iree_status_is_ok(status)) { iree_slim_mutex_lock(&data->status_mutex); data->status = iree_status_join(data->status, status); iree_slim_mutex_unlock(&data->status_mutex); } - if (iree_atomic_ref_count_dec(&data->wait_semaphore_count) != 1) { + if (iree_atomic_fetch_sub(&data->wait_semaphore_count, 1, + iree_memory_order_acq_rel) != 1) { return iree_ok_status(); } @@ -1819,126 +1910,16 @@ static iree_status_t iree_hal_hip_device_semaphore_submit_callback( &iree_hal_hip_device_execute_now, data); } -static iree_status_t iree_hal_hip_device_make_callback_data( - iree_hal_hip_device_t* device, iree_allocator_t host_allocator, - iree_arena_block_pool_t* block_pool, - iree_hal_queue_affinity_t queue_affinity, - const iree_hal_semaphore_list_t wait_semaphore_list, - const iree_hal_semaphore_list_t signal_semaphore_list, - iree_hal_command_buffer_t* command_buffer, - iree_hal_buffer_binding_table_t binding_table, - iree_hal_hip_device_semaphore_submit_callback_data_t** out_data) { - *out_data = NULL; - IREE_TRACE_ZONE_BEGIN(z0); - // Embed captured tables in the action allocation. 
- iree_hal_hip_device_semaphore_submit_callback_data_t* callback_data = NULL; - - const iree_host_size_t wait_semaphore_list_size = - wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores) + - wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values); - const iree_host_size_t signal_semaphore_list_size = - signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores) + - signal_semaphore_list.count * - sizeof(*signal_semaphore_list.payload_values); - - const iree_host_size_t payload_size = - binding_table.count * sizeof(*binding_table.bindings); - - const iree_host_size_t total_callback_size = - sizeof(*callback_data) + wait_semaphore_list_size + - signal_semaphore_list_size + payload_size; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(host_allocator, total_callback_size, - (void**)&callback_data)); - uint8_t* callback_ptr = (uint8_t*)callback_data + sizeof(*callback_data); - - callback_data->device = device; - - iree_atomic_ref_count_init_value(&callback_data->wait_semaphore_count, - wait_semaphore_list.count); - // Copy wait list for later access. - callback_data->wait_semaphore_list.count = wait_semaphore_list.count; - callback_data->wait_semaphore_list.semaphores = - (iree_hal_semaphore_t**)callback_ptr; - memcpy(callback_data->wait_semaphore_list.semaphores, - wait_semaphore_list.semaphores, - wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores)); - callback_data->wait_semaphore_list.payload_values = - (uint64_t*)(callback_ptr + wait_semaphore_list.count * - sizeof(*wait_semaphore_list.semaphores)); - memcpy( - callback_data->wait_semaphore_list.payload_values, - wait_semaphore_list.payload_values, - wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values)); - callback_ptr += wait_semaphore_list_size; - - // Copy signal list for later access. 
- callback_data->signal_semaphore_list.count = signal_semaphore_list.count; - callback_data->signal_semaphore_list.semaphores = - (iree_hal_semaphore_t**)callback_ptr; - memcpy( - callback_data->signal_semaphore_list.semaphores, - signal_semaphore_list.semaphores, - signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores)); - callback_data->signal_semaphore_list.payload_values = - (uint64_t*)(callback_ptr + signal_semaphore_list.count * - sizeof(*signal_semaphore_list.semaphores)); - memcpy(callback_data->signal_semaphore_list.payload_values, - signal_semaphore_list.payload_values, - signal_semaphore_list.count * - sizeof(*signal_semaphore_list.payload_values)); - callback_ptr += signal_semaphore_list_size; - - // Copy the execution resources for later access. - callback_data->queue_affinity = queue_affinity; - callback_data->command_buffer = command_buffer; - - // Retain all command buffers and semaphores. - iree_status_t status = - iree_hal_resource_set_allocate(block_pool, &callback_data->resource_set); - if (iree_status_is_ok(status)) { - status = iree_hal_resource_set_insert(callback_data->resource_set, - wait_semaphore_list.count, - wait_semaphore_list.semaphores); - } - if (iree_status_is_ok(status)) { - status = iree_hal_resource_set_insert(callback_data->resource_set, - signal_semaphore_list.count, - signal_semaphore_list.semaphores); - } - if (iree_status_is_ok(status)) { - status = iree_hal_resource_set_insert(callback_data->resource_set, 1, - &command_buffer); - } - - callback_data->binding_table = binding_table; - iree_hal_buffer_binding_t* binding_element_ptr = - (iree_hal_buffer_binding_t*)callback_ptr; - callback_data->binding_table.bindings = binding_element_ptr; - memcpy(binding_element_ptr, binding_table.bindings, - sizeof(*binding_element_ptr) * binding_table.count); - status = iree_hal_resource_set_insert_strided( - callback_data->resource_set, binding_table.count, - callback_data->binding_table.bindings, - 
offsetof(iree_hal_buffer_binding_t, buffer), - sizeof(iree_hal_buffer_binding_t)); - - callback_data->status = iree_ok_status(); - iree_slim_mutex_initialize(&callback_data->status_mutex); - *out_data = callback_data; - IREE_TRACE_ZONE_END(z0); - return status; -} - static iree_status_t iree_hal_hip_device_queue_execute( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity, const iree_hal_semaphore_list_t wait_semaphore_list, const iree_hal_semaphore_list_t signal_semaphore_list, iree_hal_command_buffer_t* command_buffer, iree_hal_buffer_binding_table_t binding_table) { - iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_TRACE_ZONE_BEGIN(z0); + iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); + if (queue_affinity == IREE_HAL_QUEUE_AFFINITY_ANY) { queue_affinity = 0x1; } @@ -1970,7 +1951,6 @@ static iree_status_t iree_hal_hip_device_queue_execute( if (iree_status_is_ok(status)) { for (iree_host_size_t i = 0; i < wait_semaphore_list.count; ++i) { - iree_hal_resource_retain(wait_semaphore_list.semaphores[i]); status = iree_status_join( status, iree_hal_hip_semaphore_notify_work( @@ -1980,8 +1960,9 @@ static iree_status_t iree_hal_hip_device_queue_execute( &iree_hal_hip_device_semaphore_submit_callback, callback_data)); } } else { - iree_allocator_free(device->host_allocator, callback_data); + iree_hal_hip_device_destroy_callback_data(callback_data); } + IREE_TRACE_ZONE_END(z0); return status; }