From f8fa189493d9c1d2d173af00dee63c5a9cac4721 Mon Sep 17 00:00:00 2001 From: Ewan Crawford Date: Thu, 13 Feb 2025 16:09:23 +0000 Subject: [PATCH] [SYCL][Graph] Batch graph updates Uses PR https://github.com/oneapi-src/unified-runtime/pull/2666 to pass a list of updates to UR in a single host call per UR command-buffer in the graph, rather than making N calls to UR for N nodes to update. --- sycl/cmake/modules/FetchUnifiedRuntime.cmake | 2 +- sycl/cmake/modules/UnifiedRuntimeTag.cmake | 2 +- .../sycl_ext_oneapi_graph.asciidoc | 5 + sycl/source/detail/graph_impl.cpp | 174 +++++++++++++----- sycl/source/detail/graph_impl.hpp | 21 ++- sycl/source/detail/scheduler/commands.cpp | 4 +- 6 files changed, 155 insertions(+), 53 deletions(-) diff --git a/sycl/cmake/modules/FetchUnifiedRuntime.cmake b/sycl/cmake/modules/FetchUnifiedRuntime.cmake index b782b017191ed..bcdbf4b136c29 100644 --- a/sycl/cmake/modules/FetchUnifiedRuntime.cmake +++ b/sycl/cmake/modules/FetchUnifiedRuntime.cmake @@ -122,7 +122,7 @@ elseif(SYCL_UR_USE_FETCH_CONTENT) CACHE PATH "Path to external '${name}' adapter source dir" FORCE) endfunction() - set(UNIFIED_RUNTIME_REPO "https://github.com/oneapi-src/unified-runtime.git") + set(UNIFIED_RUNTIME_REPO "https://github.com/Bensuo/unified-runtime.git") include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules/UnifiedRuntimeTag.cmake) set(UMF_BUILD_EXAMPLES OFF CACHE INTERNAL "EXAMPLES") diff --git a/sycl/cmake/modules/UnifiedRuntimeTag.cmake b/sycl/cmake/modules/UnifiedRuntimeTag.cmake index 62720fd524b37..b5f71618b1a87 100644 --- a/sycl/cmake/modules/UnifiedRuntimeTag.cmake +++ b/sycl/cmake/modules/UnifiedRuntimeTag.cmake @@ -4,4 +4,4 @@ # Date: Thu Feb 13 11:43:34 2025 +0000 # Merge pull request #2680 from ldorau/Set_UMF_CUDA_INCLUDE_DIR_to_not_fetch_cudart_from_gitlab # Do not fetch cudart from gitlab for UMF -set(UNIFIED_RUNTIME_TAG d03f19a88e42cb98be9604ff24b61190d1e48727) +set(UNIFIED_RUNTIME_TAG "ewan/update_list") diff --git a/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc index 649c00fb474b5..40cc961676458 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc @@ -1397,6 +1397,8 @@ Exceptions: * Throws with error code `invalid` if `node` is not part of the graph. +* TODO - define behavior on exception + | [source,c++] ---- @@ -1431,6 +1433,7 @@ Exceptions: * Throws with error code `invalid` if any node in `nodes` is not part of the graph. +* TODO - define behavior on exception | [source, c++] ---- @@ -1482,6 +1485,8 @@ Exceptions: * Throws synchronously with error code `invalid` if `property::graph::updatable` was not set when the executable graph was created. + +* TODO - define behavior on exception |=== Table {counter: tableNumber}. Member functions of the `command_graph` class for diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index c89353efc0f4d..531321eef5384 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -1375,20 +1375,14 @@ void exec_graph_impl::update(std::shared_ptr Node) { this->update(std::vector>{Node}); } -void exec_graph_impl::update( - const std::vector> Nodes) { - - if (!MIsUpdatable) { - throw sycl::exception(sycl::make_error_code(errc::invalid), - "update() cannot be called on a executable graph " - "which was not created with property::updatable"); - } +bool exec_graph_impl::needsScheduledUpdate( + const std::vector> &Nodes, + std::vector &UpdateRequirements) { // If there are any accessor requirements, we have to update through the // scheduler to ensure that any allocations have taken place before trying to // update. bool NeedScheduledUpdate = false; - std::vector UpdateRequirements; // At worst we may have as many requirements as there are for the entire graph // for updating. UpdateRequirements.reserve(MRequirements.size()); @@ -1431,37 +1425,44 @@ void exec_graph_impl::update( // ensure it is ordered correctly. NeedScheduledUpdate |= MExecutionEvents.size() > 0; - if (NeedScheduledUpdate) { - auto AllocaQueue = std::make_shared( - sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()), - sycl::detail::getSyclObjImpl(MGraphImpl->getContext()), - sycl::async_handler{}, sycl::property_list{}); - // Don't need to care about the return event here because it is synchronous - sycl::detail::Scheduler::getInstance().addCommandGraphUpdate( - this, Nodes, AllocaQueue, UpdateRequirements, MExecutionEvents); - } else { - for (auto &Node : Nodes) { - updateImpl(Node); + return NeedScheduledUpdate; +} + +std::map, std::vector>> +exec_graph_impl::getPartitionForNodes( + const std::vector> &Nodes) { + + std::map, std::vector>> + PartitionedNodes; + for (const auto &Partition : MPartitions) { + std::vector> NodesForPartition; + + for (auto &N : Nodes) { + auto ExecNode = MIDCache.find(N->MID); + assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); + + if (std::find_if(Partition->MSchedule.begin(), Partition->MSchedule.end(), + [ExecNode](const std::shared_ptr &P) { + return P->MID == ExecNode->second->MID; + }) != Partition->MSchedule.end()) { + NodesForPartition.push_back(N); + } + } + if (!NodesForPartition.empty()) { + PartitionedNodes.insert({Partition, NodesForPartition}); } } - // Rebuild cached requirements for this graph with updated nodes - MRequirements.clear(); - for (auto &Node : MNodeStorage) { - if (!Node->MCommandGroup) - continue; - MRequirements.insert(MRequirements.end(), - Node->MCommandGroup->getRequirements().begin(), - Node->MCommandGroup->getRequirements().end()); - } + return PartitionedNodes; } -void exec_graph_impl::updateImpl(std::shared_ptr Node) { - // Kernel node update is the only command type supported in UR for update. - // Updating any other types of nodes, e.g. empty & barrier nodes is a no-op. - if (Node->MCGType != sycl::detail::CGType::Kernel) { - return; - } +void exec_graph_impl::populateUpdateStruct( + std::shared_ptr &Node, + std::vector &MemobjDescs, + std::vector &PtrDescs, + std::vector &ValueDescs, + sycl::detail::NDRDescT &NDRDesc, + ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc) { auto ContextImpl = sycl::detail::getSyclObjImpl(MContext); const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter(); auto DeviceImpl = sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()); @@ -1472,7 +1473,7 @@ void exec_graph_impl::updateImpl(std::shared_ptr Node) { // Copy args because we may modify them std::vector NodeArgs = ExecCG.getArguments(); // Copy NDR desc since we need to modify it - auto NDRDesc = ExecCG.MNDRDesc; + NDRDesc = ExecCG.MNDRDesc; ur_program_handle_t UrProgram = nullptr; ur_kernel_handle_t UrKernel = nullptr; @@ -1535,17 +1536,12 @@ void exec_graph_impl::updateImpl(std::shared_ptr Node) { if (EnforcedLocalSize) LocalSize = RequiredWGSize; } - // Create update descriptor // Storage for individual arg descriptors - std::vector MemobjDescs; - std::vector PtrDescs; - std::vector ValueDescs; MemobjDescs.reserve(MaskedArgs.size()); PtrDescs.reserve(MaskedArgs.size()); ValueDescs.reserve(MaskedArgs.size()); - ur_exp_command_buffer_update_kernel_launch_desc_t UpdateDesc{}; UpdateDesc.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC; UpdateDesc.pNext = nullptr; @@ -1622,25 +1618,105 @@ void exec_graph_impl::updateImpl(std::shared_ptr Node) { auto ExecNode = MIDCache.find(Node->MID); assert(ExecNode != MIDCache.end() && "Node ID was not found in ID cache"); + ur_exp_command_buffer_command_handle_t Command = + MCommandMap[ExecNode->second]; + UpdateDesc.hCommand = Command; + // Update ExecNode with new values from Node, in case we ever need to // rebuild the command buffers ExecNode->second->updateFromOtherNode(Node); - ur_exp_command_buffer_command_handle_t Command = - MCommandMap[ExecNode->second]; - ur_result_t Res = Adapter->call_nocheck< - sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>( - Command, &UpdateDesc); - + // TODO if (UrProgram) { // We retained these objects by calling getOrCreateKernel() Adapter->call(UrKernel); Adapter->call(UrProgram); } +} + +void exec_graph_impl::update( + const std::vector> Nodes) { + + if (!MIsUpdatable) { + throw sycl::exception(sycl::make_error_code(errc::invalid), + "update() cannot be called on a executable graph " + "which was not created with property::updatable"); + } - if (Res != UR_RESULT_SUCCESS) { - throw sycl::exception(errc::invalid, "Error updating command_graph"); + // TODO - To get UR handles to free somehow + std::map, std::vector>> + PartitionedNodes = getPartitionForNodes(Nodes); + + for (auto It = PartitionedNodes.begin(); It != PartitionedNodes.end(); It++) { + auto CommandBuffer = It->first->MCommandBuffers[MDevice]; + + // If there are any accessor requirements, we have to update through the + // scheduler to ensure that any allocations have taken place before trying + // to update. + std::vector UpdateRequirements; + bool NeedScheduledUpdate = needsScheduledUpdate(Nodes, UpdateRequirements); + if (NeedScheduledUpdate) { + auto AllocaQueue = std::make_shared( + sycl::detail::getSyclObjImpl(MGraphImpl->getDevice()), + sycl::detail::getSyclObjImpl(MGraphImpl->getContext()), + sycl::async_handler{}, sycl::property_list{}); + // Don't need to care about the return event here because it is + // synchronous + sycl::detail::Scheduler::getInstance().addCommandGraphUpdate( + this, It->second, AllocaQueue, UpdateRequirements, MExecutionEvents); + } else { + updateImpl(CommandBuffer, It->second); + } } + + // Rebuild cached requirements for this graph with updated nodes + MRequirements.clear(); + for (auto &Node : MNodeStorage) { + if (!Node->MCommandGroup) + continue; + MRequirements.insert(MRequirements.end(), + Node->MCommandGroup->getRequirements().begin(), + Node->MCommandGroup->getRequirements().end()); + } +} + +void exec_graph_impl::updateImpl( + ur_exp_command_buffer_handle_t CommandBuffer, + std::vector> &Nodes) { + + std::vector> + MemobjDescsList(Nodes.size()); + std::vector> + PtrDescsList(Nodes.size()); + std::vector> + ValueDescsList(Nodes.size()); + std::vector NDRDescList(Nodes.size()); + std::vector UpdateDescList( + Nodes.size()); + + for (size_t i = 0; i < Nodes.size(); i++) { + + auto &N = Nodes[i]; + + // Kernel node update is the only command type supported in UR for update. + // Updating any other types of nodes, e.g. empty & barrier nodes is a no-op. + if (N->MCGType != sycl::detail::CGType::Kernel) { + return; + } + + auto &MemobjDescs = MemobjDescsList[i]; + auto &PtrDescs = PtrDescsList[i]; + auto &ValueDescs = ValueDescsList[i]; + auto &NDRDesc = NDRDescList[i]; + auto &UpdateDesc = UpdateDescList[i]; + populateUpdateStruct(N, MemobjDescs, PtrDescs, ValueDescs, NDRDesc, + UpdateDesc); + } + + auto ContextImpl = sycl::detail::getSyclObjImpl(MContext); + const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter(); + Adapter->call( + CommandBuffer, UpdateDescList.size(), UpdateDescList.data()); } modifiable_command_graph::modifiable_command_graph( diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index 1a1f2ef9cf55f..863d0622c17a8 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -1301,7 +1301,8 @@ class exec_graph_impl { void update(std::shared_ptr Node); void update(const std::vector> Nodes); - void updateImpl(std::shared_ptr NodeImpl); + void updateImpl(ur_exp_command_buffer_handle_t CommandBuffer, + std::vector> &Nodes); unsigned long long getID() const { return MID; } @@ -1371,6 +1372,24 @@ class exec_graph_impl { Stream.close(); } + // TODO comment + bool needsScheduledUpdate( + const std::vector> &Nodes, + std::vector &UpdateRequirements); + + // TODO comment + std::map, std::vector>> + getPartitionForNodes(const std::vector> &Nodes); + + // TODO comment + void populateUpdateStruct( + std::shared_ptr &Node, + std::vector &MemobjDescs, + std::vector &PtrDescs, + std::vector &ValueDescs, + sycl::detail::NDRDescT &NDRDesc, + ur_exp_command_buffer_update_kernel_launch_desc_t &UpdateDesc); + /// Execution schedule of nodes in the graph. std::list> MSchedule; /// Pointer to the modifiable graph impl associated with this executable diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index 005008a74ebd0..0d3e5186b3d03 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -3729,9 +3729,11 @@ ur_result_t UpdateCommandBufferCommand::enqueueImp() { } } } - MGraph->updateImpl(Node); } + ur_exp_command_buffer_handle_t CommandBuffer{}; // TODO + MGraph->updateImpl(CommandBuffer, MNodes); + return UR_RESULT_SUCCESS; }