diff --git a/.github/workflows/pkgci_test_pjrt.yml b/.github/workflows/pkgci_test_pjrt.yml index a32b1c4503da..e6ee719827f2 100644 --- a/.github/workflows/pkgci_test_pjrt.yml +++ b/.github/workflows/pkgci_test_pjrt.yml @@ -20,6 +20,7 @@ on: jobs: build_and_test: strategy: + fail-fast: false matrix: include: - runner: ubuntu-20.04 @@ -27,11 +28,10 @@ jobs: # TODO: cuda runner is not available yet, refer to #18814 # - runner: some-cuda-available-runner # pjrt_platform: cuda - # TODO: enable these AMD runners - # - runner: nodai-amdgpu-w7900-x86-64 - # pjrt_platform: rocm - # - runner: nodai-amdgpu-w7900-x86-64 - # pjrt_platform: vulkan + - runner: nodai-amdgpu-w7900-x86-64 + pjrt_platform: rocm + - runner: nodai-amdgpu-w7900-x86-64 + pjrt_platform: vulkan name: Build and test runs-on: ${{ matrix.runner }} env: diff --git a/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml index f539b95a4ede..682d01b62e5d 100644 --- a/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_cpu_plugin/pyproject.toml @@ -3,5 +3,6 @@ requires = [ "setuptools>=42", "wheel", "ninja", + "cmake", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml index f539b95a4ede..682d01b62e5d 100644 --- a/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_cuda_plugin/pyproject.toml @@ -3,5 +3,6 @@ requires = [ "setuptools>=42", "wheel", "ninja", + "cmake", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml index f539b95a4ede..682d01b62e5d 100644 --- a/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_rocm_plugin/pyproject.toml @@ -3,5 +3,6 @@ requires = [ "setuptools>=42", "wheel", "ninja", + "cmake", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py index 0cba11d4577a..3afbec55f18d 100644 --- a/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py +++ b/integrations/pjrt/python_packages/iree_rocm_plugin/setup.py @@ -32,7 +32,7 @@ def build_default_configuration(self): print("*****************************", file=sys.stderr) self.build_configuration( os.path.join(THIS_DIR, "build", "cmake"), - extra_cmake_args=("-DIREE_EXTERNAL_HAL_DRIVERS=rocm",), + extra_cmake_args=("-DIREE_HAL_DRIVER_HIP=ON",), ) print("Target populated.", file=sys.stderr) diff --git a/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml b/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml index f539b95a4ede..682d01b62e5d 100644 --- a/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml +++ b/integrations/pjrt/python_packages/iree_vulkan_plugin/pyproject.toml @@ -3,5 +3,6 @@ requires = [ "setuptools>=42", "wheel", "ninja", + "cmake", ] build-backend = "setuptools.build_meta" diff --git a/integrations/pjrt/src/CMakeLists.txt b/integrations/pjrt/src/CMakeLists.txt index d0f49479b392..93d67279a79d 100644 --- a/integrations/pjrt/src/CMakeLists.txt +++ b/integrations/pjrt/src/CMakeLists.txt @@ -27,7 +27,7 @@ endif() if(IREE_HAL_DRIVER_CUDA) add_subdirectory(iree_pjrt/cuda) endif() -if("rocm" IN_LIST IREE_EXTERNAL_HAL_DRIVERS) +if(IREE_HAL_DRIVER_HIP) add_subdirectory(iree_pjrt/rocm) endif() if(IREE_HAL_DRIVER_VULKAN) diff --git a/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt b/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt index 8992ecc3099c..10b5d72a5c7e 100644 --- a/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt +++ b/integrations/pjrt/src/iree_pjrt/rocm/CMakeLists.txt @@ -13,8 +13,8 @@ iree_cc_library( "client.cc" DEPS iree_pjrt::common - iree::experimental::rocm - iree::experimental::rocm::registration + iree::hal::drivers::hip + iree::hal::drivers::hip::registration ) iree_cc_library( diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.cc b/integrations/pjrt/src/iree_pjrt/rocm/client.cc index 5f290d0ddab4..f3565f3fb93e 100644 --- a/integrations/pjrt/src/iree_pjrt/rocm/client.cc +++ b/integrations/pjrt/src/iree_pjrt/rocm/client.cc @@ -6,7 +6,9 @@ #include "iree_pjrt/rocm/client.h" -#include "experimental/rocm/registration/driver_module.h" +#include "iree/hal/drivers/hip/api.h" +#include "iree/hal/drivers/hip/hip_device.h" +#include "iree/hal/drivers/hip/registration/driver_module.h" namespace iree::pjrt::rocm { @@ -17,21 +19,68 @@ ROCMClientInstance::ROCMClientInstance(std::unique_ptr platform) // TODO: Get this when constructing the client so it is guaranteed to // match. cached_platform_name_ = "iree_rocm"; - IREE_CHECK_OK(iree_hal_rocm_driver_module_register(driver_registry_)); } ROCMClientInstance::~ROCMClientInstance() {} iree_status_t ROCMClientInstance::CreateDriver(iree_hal_driver_t** out_driver) { - iree_string_view_t driver_name = iree_make_cstring_view("rocm"); - IREE_RETURN_IF_ERROR(iree_hal_driver_registry_try_create( - driver_registry_, driver_name, host_allocator_, out_driver)); + iree_string_view_t driver_name = iree_make_cstring_view("hip"); + + // Device params. + iree_hal_hip_device_params_t default_params; + iree_hal_hip_device_params_initialize(&default_params); + + // Driver params. + iree_hal_hip_driver_options_t driver_options; + iree_hal_hip_driver_options_initialize(&driver_options); + + IREE_RETURN_IF_ERROR(iree_hal_hip_driver_create(driver_name, &driver_options, + &default_params, + host_allocator_, out_driver)); logger().debug("ROCM driver created"); + + // retrieve the target name of current available device + iree_host_size_t device_info_count; + iree_hal_device_info_t* device_infos; + IREE_RETURN_IF_ERROR(iree_hal_driver_query_available_devices( + *out_driver, host_allocator_, &device_info_count, &device_infos)); + + // TODO: here we just use the target name of the first available device, + // but ideally we should find the device which will run the program + if (device_info_count > 0) { + hipDeviceProp_tR0000 props; + IREE_RETURN_IF_ERROR(iree_hal_hip_get_device_properties( + *out_driver, device_infos->device_id, &props)); + + // `gcnArchName` comes back like gfx90a:sramecc+:xnack- for a fully + // specified target. However the IREE target-chip flag only expects the + // prefix. refer to + // https://github.com/iree-org/iree-turbine/blob/965247e/iree/turbine/runtime/device.py#L495 + std::string_view target = props.gcnArchName; + if (auto pos = target.find(':'); pos != target.npos) { + target = target.substr(0, pos); + } + + hip_target_ = target; + logger().debug("HIP target detected: " + hip_target_); + } + return iree_ok_status(); } bool ROCMClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) { - return compiler_job->SetFlag("--iree-hal-target-backends=rocm"); + std::vector flags = { + "--iree-hal-target-backends=rocm", + }; + + if (!hip_target_.empty()) { + flags.push_back("--iree-hip-target=" + hip_target_); + } + + for (auto flag : flags) { + if (!compiler_job->SetFlag(flag.c_str())) return false; + } + return true; } } // namespace iree::pjrt::rocm diff --git a/integrations/pjrt/src/iree_pjrt/rocm/client.h b/integrations/pjrt/src/iree_pjrt/rocm/client.h index e2b78da2e880..b4ed002ca80e 100644 --- a/integrations/pjrt/src/iree_pjrt/rocm/client.h +++ b/integrations/pjrt/src/iree_pjrt/rocm/client.h @@ -7,7 +7,7 @@ #ifndef IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_ #define IREE_PJRT_PLUGIN_PJRT_ROCM_CLIENT_H_ -#include "experimental/rocm/api.h" +#include "iree/hal/drivers/hip/api.h" #include "iree_pjrt/common/api_impl.h" namespace iree::pjrt::rocm { @@ -20,6 +20,7 @@ class ROCMClientInstance final : public ClientInstance { bool SetDefaultCompilerFlags(CompilerJob* compiler_job) override; private: + std::string hip_target_; }; } // namespace iree::pjrt::rocm diff --git a/integrations/pjrt/src/iree_pjrt/vulkan/client.cc b/integrations/pjrt/src/iree_pjrt/vulkan/client.cc index 853ead814b24..228cbe9eca8c 100644 --- a/integrations/pjrt/src/iree_pjrt/vulkan/client.cc +++ b/integrations/pjrt/src/iree_pjrt/vulkan/client.cc @@ -32,7 +32,7 @@ iree_status_t VulkanClientInstance::CreateDriver( } bool VulkanClientInstance::SetDefaultCompilerFlags(CompilerJob* compiler_job) { - return compiler_job->SetFlag("--iree-hal-target-backends=vulkan"); + return compiler_job->SetFlag("--iree-hal-target-backends=vulkan-spirv"); } } // namespace iree::pjrt::vulkan diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.h b/runtime/src/iree/hal/drivers/hip/hip_device.h index 044f4d53f844..a93dcf62cf3c 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.h +++ b/runtime/src/iree/hal/drivers/hip/hip_device.h @@ -42,6 +42,11 @@ iree_status_t iree_hal_hip_device_create_stream_command_buffer( // contexts and the context may be in use on other threads. hipCtx_t iree_hal_hip_device_context(iree_hal_device_t* device); +// Retrieve device properties for the given |device_id| to |out_props| +iree_status_t iree_hal_hip_get_device_properties( + iree_hal_driver_t* driver, iree_hal_device_id_t device_id, + hipDeviceProp_tR0000* out_props); + // Returns the dynamic symbol table from the |device| if it is a HIP device // and otherwise returns NULL. // diff --git a/runtime/src/iree/hal/drivers/hip/hip_driver.c b/runtime/src/iree/hal/drivers/hip/hip_driver.c index 4600d48b086d..bdeba68dd0c4 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_driver.c +++ b/runtime/src/iree/hal/drivers/hip/hip_driver.c @@ -242,6 +242,20 @@ static iree_status_t iree_hal_hip_driver_query_available_devices( return status; } +iree_status_t iree_hal_hip_get_device_properties( + iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, + hipDeviceProp_tR0000* out_props) { + iree_hal_hip_driver_t* driver = iree_hal_hip_driver_cast(base_driver); + + hipDevice_t device = IREE_DEVICE_ID_TO_HIPDEVICE(device_id); + + IREE_HIP_RETURN_IF_ERROR(&driver->hip_symbols, + hipGetDeviceProperties(out_props, device), + "hipGetDeviceProperties"); + + return iree_ok_status(); +} + static iree_status_t iree_hal_hip_driver_dump_device_info( iree_hal_driver_t* base_driver, iree_hal_device_id_t device_id, iree_string_builder_t* builder) {