From 80226e65d26990545787b88f7f84bfe9fe4afd74 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Thu, 12 Sep 2024 03:20:53 -0700 Subject: [PATCH] Attempt to fix test. Signed-off-by: JackAKirk --- .../device/urDeviceGetGlobalTimestamps.cpp | 48 +++++++++++++++++-- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/test/conformance/device/urDeviceGetGlobalTimestamps.cpp b/test/conformance/device/urDeviceGetGlobalTimestamps.cpp index 2b874ceaf2..e0a83467f6 100644 --- a/test/conformance/device/urDeviceGetGlobalTimestamps.cpp +++ b/test/conformance/device/urDeviceGetGlobalTimestamps.cpp @@ -27,6 +27,7 @@ T absolute_difference(T a, T b) { return std::max(a, b) - std::min(a, b); } +/* void create_context_if_cuda(uur::raii::Context context, const ur_device_handle_t *devs) { void *vendorValue = nullptr; @@ -38,13 +39,22 @@ void create_context_if_cuda(uur::raii::Context context, ASSERT_SUCCESS(urContextCreate(1, devs, nullptr, context.ptr())); } } +create_context_if_cuda(context, devices.data()); +*/ using urDeviceGetGlobalTimestampTest = uur::urAllDevicesTest; TEST_F(urDeviceGetGlobalTimestampTest, Success) { uur::raii::Context context = nullptr; - create_context_if_cuda(context, devices.data()); + void *vendorValue = nullptr; + size_t *pPropSizeRet = nullptr; + urDeviceGetInfo(devices[0], UR_DEVICE_INFO_VENDOR_ID, sizeof(uint32_t), + vendorValue, pPropSizeRet); + // cuda backend defines global timestamp only over a context. + if (*reinterpret_cast(vendorValue) == 4318u) { + ASSERT_SUCCESS(urContextCreate(1, devices.data(), nullptr, context.ptr())); + } for (auto device : devices) { uint64_t device_time = 0; @@ -60,7 +70,14 @@ TEST_F(urDeviceGetGlobalTimestampTest, Success) { TEST_F(urDeviceGetGlobalTimestampTest, SuccessHostTimer) { uur::raii::Context context = nullptr; - create_context_if_cuda(context, devices.data()); + void *vendorValue = nullptr; + size_t *pPropSizeRet = nullptr; + urDeviceGetInfo(devices[0], UR_DEVICE_INFO_VENDOR_ID, sizeof(uint32_t), + vendorValue, pPropSizeRet); + // cuda backend defines global timestamp only over a context. + if (*reinterpret_cast(vendorValue) == 4318u) { + ASSERT_SUCCESS(urContextCreate(1, devices.data(), nullptr, context.ptr())); + } for (auto device : devices) { uint64_t host_time = 0; @@ -73,7 +90,14 @@ TEST_F(urDeviceGetGlobalTimestampTest, SuccessHostTimer) { TEST_F(urDeviceGetGlobalTimestampTest, SuccessNoTimers) { uur::raii::Context context = nullptr; - create_context_if_cuda(context, devices.data()); + void *vendorValue = nullptr; + size_t *pPropSizeRet = nullptr; + urDeviceGetInfo(devices[0], UR_DEVICE_INFO_VENDOR_ID, sizeof(uint32_t), + vendorValue, pPropSizeRet); + // cuda backend defines global timestamp only over a context. + if (*reinterpret_cast(vendorValue) == 4318u) { + ASSERT_SUCCESS(urContextCreate(1, devices.data(), nullptr, context.ptr())); + } for (auto device : devices) { ASSERT_SUCCESS(urDeviceGetGlobalTimestamps(device, nullptr, nullptr)); @@ -83,7 +107,14 @@ TEST_F(urDeviceGetGlobalTimestampTest, SuccessNoTimers) { TEST_F(urDeviceGetGlobalTimestampTest, SuccessSynchronizedTime) { uur::raii::Context context = nullptr; - create_context_if_cuda(context, devices.data()); + void *vendorValue = nullptr; + size_t *pPropSizeRet = nullptr; + urDeviceGetInfo(devices[0], UR_DEVICE_INFO_VENDOR_ID, sizeof(uint32_t), + vendorValue, pPropSizeRet); + // cuda backend defines global timestamp only over a context. + if (*reinterpret_cast(vendorValue) == 4318u) { + ASSERT_SUCCESS(urContextCreate(1, devices.data(), nullptr, context.ptr())); + } for (auto device : devices) { // get the timer resolution of the device @@ -139,7 +170,14 @@ TEST_F(urDeviceGetGlobalTimestampTest, SuccessSynchronizedTime) { TEST_F(urDeviceGetGlobalTimestampTest, InvalidNullHandleDevice) { uur::raii::Context context = nullptr; - create_context_if_cuda(context, devices.data()); + void *vendorValue = nullptr; + size_t *pPropSizeRet = nullptr; + urDeviceGetInfo(devices[0], UR_DEVICE_INFO_VENDOR_ID, sizeof(uint32_t), + vendorValue, pPropSizeRet); + // cuda backend defines global timestamp only over a context. + if (*reinterpret_cast(vendorValue) == 4318u) { + ASSERT_SUCCESS(urContextCreate(1, devices.data(), nullptr, context.ptr())); + } uint64_t device_time = 0; uint64_t host_time = 0;