From 60b567fc4466d3233fe916c4c179ae18005772d4 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 23 Nov 2024 20:35:28 +0800 Subject: [PATCH 01/10] Enable rocm and vulkan in CI workflow of PJRT plugin Signed-off-by: PragmaTwice --- .github/workflows/pkgci_test_pjrt.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pkgci_test_pjrt.yml b/.github/workflows/pkgci_test_pjrt.yml index 8237648b8540..c908f8b6b379 100644 --- a/.github/workflows/pkgci_test_pjrt.yml +++ b/.github/workflows/pkgci_test_pjrt.yml @@ -27,11 +27,10 @@ jobs: # TODO: cuda runner is not available yet, refer to #18814 # - runner: some-cuda-available-runner # pjrt_platform: cuda - # TODO: enable these AMD runners - # - runner: nodai-amdgpu-w7900-x86-64 - # pjrt_platform: rocm - # - runner: nodai-amdgpu-w7900-x86-64 - # pjrt_platform: vulkan + - runner: nodai-amdgpu-w7900-x86-64 + pjrt_platform: rocm + - runner: nodai-amdgpu-w7900-x86-64 + pjrt_platform: vulkan name: Build and test runs-on: ${{ matrix.runner }} env: From 6e577c3bf7b6305b541319624ab72cbe3ab5eadd Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 23 Nov 2024 21:22:10 +0800 Subject: [PATCH 02/10] Add the missing cmake package to pyproject.toml Signed-off-by: PragmaTwice --- integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml | 1 + .../pjrt/python_packages/iree_cuda_plugin/pyproject.toml | 1 + .../pjrt/python_packages/iree_rocm_plugin/pyproject.toml | 1 + .../pjrt/python_packages/iree_vulkan_plugin/pyproject.toml | 1 + 4 files changed, 4 insertions(+) diff --git a/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml index f539b95a4ede..682d01b62e5d 100644 --- a/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml @@ -3,5 +3,6 @@ requires = [ "setuptools>=42", "wheel", "ninja", + "cmake", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml index f539b95a4ede..682d01b62e5d 100644 --- a/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml @@ -3,5 +3,6 @@ requires = [ "setuptools>=42", "wheel", "ninja", + "cmake", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml index f539b95a4ede..682d01b62e5d 100644 --- a/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml @@ -3,5 +3,6 @@ requires = [ "setuptools>=42", "wheel", "ninja", + "cmake", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml index f539b95a4ede..682d01b62e5d 100644 --- a/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml @@ -3,5 +3,6 @@ requires = [ "setuptools>=42", "wheel", "ninja", + "cmake", ] build-backend = "setuptools.build_meta" From bf355b0f9fa69f2647428fe8dffe940314ce5795 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 23 Nov 2024 21:50:11 +0800 Subject: [PATCH 03/10] Fix value of the cmake option IREE_EXTERNAL_HAL_DRIVERS Signed-off-by: PragmaTwice --- integrations/pjrt/python_packages/iree_rocm_plugin/setup.py | 2 +- integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py index 0cba11d4577a..3afbec55f18d 100644 --- a/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py +++ b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py @@ -32,7 +32,7 @@ def build_default_configuration(self): print("*****************************", file=sys.stderr) self.build_configuration( os.path.join(THIS_DIR, "build", "cmake"), - extra_cmake_args=("-DIREE_EXTERNAL_HAL_DRIVERS=rocm",), + extra_cmake_args=("-DIREE_HAL_DRIVER_HIP=ON",), ) print("Target populated.", file=sys.stderr) diff --git a/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt index 8992ecc3099c..10b5d72a5c7e 100644 --- a/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt +++ b/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt @@ -13,8 +13,8 @@ iree_cc_library( "client.cc" DEPS iree_pjrt::common - iree::experimental::rocm - iree::experimental::rocm::registration + iree::hal::drivers::hip + iree::hal::drivers::hip::registration ) iree_cc_library( From 4eb40d45f30d8310f0e396e33aefd7aefda1dabd Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 23 Nov 2024 22:10:45 +0800 Subject: [PATCH 04/10] Fix hal driver calls in rocm PJRT plugin Signed-off-by: PragmaTwice --- integrations/pjrt/src/iree_pjrt/rocm/client.cc | 4 ++-- integrations/pjrt/src/iree_pjrt/rocm/client.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.cc b/integrations/pjrt/src/iree_pjrt/rocm/client.cc index 5f290d0ddab4..1c30b010c30f 100644 --- a/integrations/pjrt/src/iree_pjrt/rocm/client.cc +++ b/integrations/pjrt/src/iree_pjrt/rocm/client.cc @@ -6,7 +6,7 @@ #include "iree_pjrt/rocm/client.h" -#include "experimental/rocm/registration/driver_module.h" +#include "iree/hal/drivers/hip/registration/driver_module.h" namespace iree::pjrt::rocm { @@ -17,7 +17,7 @@ ROCMClientInstance::ROCMClientInstance(std::unique_ptr platform) // TODO: Get this when constructing the client so it is guaranteed to // match. cached_platform_name_ = "iree_rocm"; - IREE_CHECK_OK(iree_hal_rocm_driver_module_register(driver_registry_)); + IREE_CHECK_OK(iree_hal_hip_driver_module_register(driver_registry_)); } ROCMClientInstance::~ROCMClientInstance() {} diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.h b/integrations/pjrt/src/iree_pjrt/rocm/client.h index e2b78da2e880..0ebb793b6fa7 100644 --- a/integrations/pjrt/src/iree_pjrt/rocm/client.h +++ b/integrations/pjrt/src/iree_pjrt/rocm/client.h @@ -7,7 +7,7 @@ #ifndef IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_ #define IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_ -#include "experimental/rocm/api.h" +#include "iree/hal/drivers/hip/api.h" #include "iree_pjrt/common/api_impl.h" namespace iree::pjrt::rocm { From 45abc51c42c7babe6b5a90227d100f3ba1409de5 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 23 Nov 2024 22:27:18 +0800 Subject: [PATCH 05/10] Fix cmake for rocm PJRT plugin Signed-off-by: PragmaTwice --- integrations/pjrt/src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/pjrt/src/CMakeLists.txt b/integrations/pjrt/src/CMakeLists.txt index d0f49479b392..93d67279a79d 100644 --- a/integrations/pjrt/src/CMakeLists.txt +++ b/integrations/pjrt/src/CMakeLists.txt @@ -27,7 +27,7 @@ endif() if(IREE_HAL_DRIVER_CUDA) add_subdirectory(iree_pjrt/cuda) endif() -if("rocm" IN_LIST IREE_EXTERNAL_HAL_DRIVERS) +if(IREE_HAL_DRIVER_HIP) add_subdirectory(iree_pjrt/rocm) endif() if(IREE_HAL_DRIVER_VULKAN) From d06204ee737a9391979e5d4a3b144c6f135daef6 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 23 Nov 2024 22:28:29 +0800 Subject: [PATCH 06/10] Fix backend name for vulkan PJRT plugin Signed-off-by: PragmaTwice --- integrations/pjrt/src/iree_pjrt/vulkan/client.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/client.cc b/integrations/pjrt/src/iree_pjrt/vulkan/client.cc index 853ead814b24..228cbe9eca8c 100644 --- a/integrations/pjrt/src/iree_pjrt/vulkan/client.cc +++ b/integrations/pjrt/src/iree_pjrt/vulkan/client.cc @@ -32,7 +32,7 @@ iree_status_t VulkanClientInstance::CreateDriver( } bool VulkanClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) { - return compiler_job->SetFlag("--iree-hal-target-backends=vulkan"); + return compiler_job->SetFlag("--iree-hal-target-backends=vulkan-spirv"); } } // namespace iree::pjrt::vulkan From 2426649130d226597b54c351a40db287801e79f1 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 23 Nov 2024 22:41:15 +0800 Subject: [PATCH 07/10] Rename driver name from rocm to hip Signed-off-by: PragmaTwice --- integrations/pjrt/src/iree_pjrt/rocm/client.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.cc b/integrations/pjrt/src/iree_pjrt/rocm/client.cc index 1c30b010c30f..e22585f3b5b8 100644 --- a/integrations/pjrt/src/iree_pjrt/rocm/client.cc +++ b/integrations/pjrt/src/iree_pjrt/rocm/client.cc @@ -23,7 +23,7 @@ ROCMClientInstance::ROCMClientInstance(std::unique_ptr platform) ROCMClientInstance::~ROCMClientInstance() {} iree_status_t ROCMClientInstance::CreateDriver(iree_hal_driver_t** out_driver) { - iree_string_view_t driver_name = iree_make_cstring_view("rocm"); + iree_string_view_t driver_name = iree_make_cstring_view("hip"); IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create( driver_registry_, driver_name, host_allocator_, out_driver)); logger().debug("ROCM driver created"); From f378748cdfc21f3d6e5cf2d70321c50219115581 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sat, 23 Nov 2024 23:08:37 +0800 Subject: [PATCH 08/10] Add compiler flag --iree-hip-target Signed-off-by: PragmaTwice --- integrations/pjrt/src/iree_pjrt/rocm/client.cc | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.cc b/integrations/pjrt/src/iree_pjrt/rocm/client.cc index e22585f3b5b8..de4c0626d4e4 100644 --- a/integrations/pjrt/src/iree_pjrt/rocm/client.cc +++ b/integrations/pjrt/src/iree_pjrt/rocm/client.cc @@ -31,7 +31,18 @@ iree_status_t ROCMClientInstance::CreateDriver(iree_hal_driver_t** out_driver) { } bool ROCMClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) { - return compiler_job->SetFlag("--iree-hal-target-backends=rocm"); + auto flags = { + "--iree-hal-target-backends=rocm", + + // TODO: gfx908 is just a placeholder here to make it work, + // we should instead detect the device target on the fly + "--iree-hip-target=gfx908", + }; + + for (auto flag : flags) { + if (!compiler_job->SetFlag(flag)) return false; + } + return true; } } // namespace iree::pjrt::rocm From f0258ad8095e56e01548f4b9f8058f8044dc51b4 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Sun, 24 Nov 2024 13:53:59 +0800 Subject: [PATCH 09/10] Detect HIP target automatically Signed-off-by: PragmaTwice --- .../pjrt/src/iree_pjrt/rocm/client.cc | 56 ++++++++++++++++--- integrations/pjrt/src/iree_pjrt/rocm/client.h | 1 + runtime/src/iree/hal/drivers/hip/hip_device.h | 5 ++ runtime/src/iree/hal/drivers/hip/hip_driver.c | 14 +++++ 4 files changed, 67 insertions(+), 9 deletions(-) diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.cc b/integrations/pjrt/src/iree_pjrt/rocm/client.cc index de4c0626d4e4..f3565f3fb93e 100644 --- a/integrations/pjrt/src/iree_pjrt/rocm/client.cc +++ b/integrations/pjrt/src/iree_pjrt/rocm/client.cc @@ -6,6 +6,8 @@ #include "iree_pjrt/rocm/client.h" +#include "iree/hal/drivers/hip/api.h" +#include "iree/hal/drivers/hip/hip_device.h" #include "iree/hal/drivers/hip/registration/driver_module.h" namespace iree::pjrt::rocm { @@ -17,30 +19,66 @@ ROCMClientInstance::ROCMClientInstance(std::unique_ptr platform) // TODO: Get this when constructing the client so it is guaranteed to // match. cached_platform_name_ = "iree_rocm"; - IREE_CHECK_OK(iree_hal_hip_driver_module_register(driver_registry_)); } ROCMClientInstance::~ROCMClientInstance() {} iree_status_t ROCMClientInstance::CreateDriver(iree_hal_driver_t** out_driver) { iree_string_view_t driver_name = iree_make_cstring_view("hip"); - IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create( - driver_registry_, driver_name, host_allocator_, out_driver)); + + // Device params. + iree_hal_hip_device_params_t default_params; + iree_hal_hip_device_params_initialize(&default_params); + + // Driver params. + iree_hal_hip_driver_options_t driver_options; + iree_hal_hip_driver_options_initialize(&driver_options); + + IREE_RETURN_IF_ERROR(iree_hal_hip_driver_create(driver_name, &driver_options, + &default_params, + host_allocator_, out_driver)); logger().debug("ROCM driver created"); + + // retrieve the target name of current available device + iree_host_size_t device_info_count; + iree_hal_device_info_t* device_infos; + IREE_RETURN_IF_ERROR(iree_hal_driver_query_available_devices( + *out_driver, host_allocator_, &device_info_count, &device_infos)); + + // TODO: here we just use the target name of the first available device, + // but ideally we should find the device which will run the program + if (device_info_count > 0) { + hipDeviceProp_tR0000 props; + IREE_RETURN_IF_ERROR(iree_hal_hip_get_device_properties( + *out_driver, device_infos->device_id, &props)); + + // `gcnArchName` comes back like gfx90a:sramecc+:xnack- for a fully + // specified target. However the IREE target-chip flag only expects the + // prefix. refer to + // https://github.com/iree-org/iree-turbine/blob/965247e/iree/turbine/runtime/device.py#L495 + std::string_view target = props.gcnArchName; + if (auto pos = target.find(':'); pos != target.npos) { + target = target.substr(0, pos); + } + + hip_target_ = target; + logger().debug("HIP target detected: " + hip_target_); + } + return iree_ok_status(); } bool ROCMClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) { - auto flags = { + std::vector flags = { "--iree-hal-target-backends=rocm", - - // TODO: gfx908 is just a placeholder here to make it work, - // we should instead detect the device target on the fly - "--iree-hip-target=gfx908", }; + if (!hip_target_.empty()) { + flags.push_back("--iree-hip-target=" + hip_target_); + } + for (auto flag : flags) { - if (!compiler_job->SetFlag(flag)) return false; + if (!compiler_job->SetFlag(flag.c_str())) return false; } return true; } diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.h b/integrations/pjrt/src/iree_pjrt/rocm/client.h index 0ebb793b6fa7..b4ed002ca80e 100644 --- a/integrations/pjrt/src/iree_pjrt/rocm/client.h +++ b/integrations/pjrt/src/iree_pjrt/rocm/client.h @@ -20,6 +20,7 @@ class ROCMClientInstance final : public ClientInstance { bool SetDefaultCompilerFlags(CompilerJob* compiler_job) override; private: + std::string hip_target_; }; } // namespace iree::pjrt::rocm diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.h b/runtime/src/iree/hal/drivers/hip/hip_device.h index 044f4d53f844..a93dcf62cf3c 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.h +++ b/runtime/src/iree/hal/drivers/hip/hip_device.h @@ -42,6 +42,11 @@ iree_status_t iree_hal_hip_device_create_stream_command_buffer( // contexts and the context may be in use on other threads. hipCtx_t iree_hal_hip_device_context(iree_hal_device_t* device); +// Retrieve device properties for the given |device_id| to |out_props| +iree_status_t iree_hal_hip_get_device_properties( + iree_hal_driver_t* driver, iree_hal_device_id_t device_id, + hipDeviceProp_tR0000* out_props); + // Returns the dynamic symbol table from the |device| if it is a HIP device // and otherwise returns NULL. // diff --git a/runtime/src/iree/hal/drivers/hip/hip_driver.c b/runtime/src/iree/hal/drivers/hip/hip_driver.c index 4600d48b086d..bdeba68dd0c4 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_driver.c +++ b/runtime/src/iree/hal/drivers/hip/hip_driver.c @@ -242,6 +242,20 @@ static iree_status_t iree_hal_hip_driver_query_available_devices( return status; } +iree_status_t iree_hal_hip_get_device_properties( + iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, + hipDeviceProp_tR0000* out_props) { + iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver); + + hipDevice_t device = IREE_DEVICE_ID_TO_HIPDEVICE(device_id); + + IREE_HIP_RETURN_IF_ERROR(&driver->hip_symbols, + hipGetDeviceProperties(out_props, device), + "hipGetDeviceProperties"); + + return iree_ok_status(); +} + static iree_status_t iree_hal_hip_driver_dump_device_info( iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, iree_string_builder_t* builder) { From dd972e6c7d48dc6c7e72fab7cd076152696903c0 Mon Sep 17 00:00:00 2001 From: PragmaTwice Date: Tue, 26 Nov 2024 21:43:52 +0800 Subject: [PATCH 10/10] Disable fail-fast Signed-off-by: PragmaTwice --- .github/workflows/pkgci_test_pjrt.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pkgci_test_pjrt.yml b/.github/workflows/pkgci_test_pjrt.yml index c908f8b6b379..5adbb2b16184 100644 --- a/.github/workflows/pkgci_test_pjrt.yml +++ b/.github/workflows/pkgci_test_pjrt.yml @@ -20,6 +20,7 @@ on: jobs: build_and_test: strategy: + fail-fast: false matrix: include: - runner: ubuntu-20.04