Skip to content

Commit

Permalink
vulkan: select only one device for single gpu with multiple drivers
Browse files Browse the repository at this point in the history
  • Loading branch information
Adriankhl committed Jun 6, 2024
1 parent b864b50 commit cc612eb
Showing 1 changed file with 72 additions and 3 deletions.
75 changes: 72 additions & 3 deletions ggml-vulkan.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include "ggml-vulkan.h"

#ifdef GGML_VULKAN_RUN_TESTS
#include <chrono>
#endif
Expand All @@ -9,12 +8,13 @@
#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>
#include <tuple>
#include <vector>
#include <sstream>
#include <utility>
#include <memory>
#include <limits>
#include <map>

#include "ggml.h"
#include "ggml-backend-impl.h"
Expand Down Expand Up @@ -1691,7 +1691,76 @@ void ggml_vk_instance_init() {
vk::PhysicalDeviceProperties props = devices[i].getProperties();

if (props.deviceType == vk::PhysicalDeviceType::eDiscreteGpu) {
vk_instance.device_indices.push_back(i);
// Check if there are two physical devices corresponding to the same GPU
auto old_device = std::find_if(
vk_instance.device_indices.begin(),
vk_instance.device_indices.end(),
[&devices, &props](const size_t k){ return devices[k].getProperties().deviceID == props.deviceID; }
);
if (old_device == vk_instance.device_indices.end()) {
vk_instance.device_indices.push_back(i);
} else {
// There can be two physical devices corresponding to the same GPU if there are 2 different drivers
// This can cause error when splitting layers aross the devices, need to keep only 1
#ifdef GGML_VULKAN_DEBUG
std::cerr << "Device " << i << " and device " << *old_device << " have the same device id" << std::endl;
#endif

vk::PhysicalDeviceProperties2 old_prop;
vk::PhysicalDeviceDriverProperties old_driver;
old_prop.pNext = &old_driver;
devices[*old_device].getProperties2(&old_prop);

vk::PhysicalDeviceProperties2 new_prop;
vk::PhysicalDeviceDriverProperties new_driver;
new_prop.pNext = &new_driver;
devices[i].getProperties2(&new_prop);

std::map<vk::DriverId, int> driver_priorities {};
int old_priority = std::numeric_limits<int>::max();
int new_priority = std::numeric_limits<int>::max();

// Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id
// Smaller number -> higher priority
switch (old_prop.properties.vendorID) {
case VK_VENDOR_ID_AMD:
driver_priorities[static_cast<vk::DriverId>(VkDriverId::VK_DRIVER_ID_MESA_RADV)] = 1;
driver_priorities[static_cast<vk::DriverId>(VkDriverId::VK_DRIVER_ID_AMD_OPEN_SOURCE)] = 2;
driver_priorities[static_cast<vk::DriverId>(VkDriverId::VK_DRIVER_ID_AMD_PROPRIETARY)] = 3;
break;
case VK_VENDOR_ID_INTEL:
driver_priorities[static_cast<vk::DriverId>(VkDriverId::VK_DRIVER_ID_INTEL_OPEN_SOURCE_MESA)] = 1;
driver_priorities[static_cast<vk::DriverId>(VkDriverId::VK_DRIVER_ID_INTEL_PROPRIETARY_WINDOWS)] = 2;
break;
case VK_VENDOR_ID_NVIDIA:
driver_priorities[static_cast<vk::DriverId>(VkDriverId::VK_DRIVER_ID_NVIDIA_PROPRIETARY)] = 1;
driver_priorities[static_cast<vk::DriverId>(VkDriverId::VK_DRIVER_ID_MESA_NVK)] = 2;
break;
}

if (driver_priorities.count(old_driver.driverID)) {
old_priority = driver_priorities[old_driver.driverID];
}
if (driver_priorities.count(new_driver.driverID)) {
new_priority = driver_priorities[new_driver.driverID];
}

if (new_priority < old_priority) {
auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device);
vk_instance.device_indices.erase(r, vk_instance.device_indices.end());
vk_instance.device_indices.push_back(i);

#ifdef GGML_VULKAN_DEBUG
std::cerr << "Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName << std::endl;
#endif
}
#ifdef GGML_VULKAN_DEBUG
else {
std::cerr << "Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl;

}
#endif
}
}
}

Expand Down

0 comments on commit cc612eb

Please sign in to comment.