diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index 97b6c7c18c555..687250754e133 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -10,6 +10,8 @@
 #include <string_view>
 #include <vector>
 
+using namespace std::string_literals;
+
 #define LLMODEL_MAX_PROMPT_BATCH 128
 
 class Dlhandle;
@@ -51,6 +53,8 @@ class LLModel {
         GPUDevice(const char *backend, int index, int type, size_t heapSize, std::string name, std::string vendor):
             backend(backend), index(index), type(type), heapSize(heapSize), name(std::move(name)),
             vendor(std::move(vendor)) {}
+
+        std::string uiName() const { return backend == "cuda"s ? "CUDA: " + name : name; }
     };
 
     class Implementation {
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index a0ab7fb664528..733652189231a 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -302,19 +302,23 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     QElapsedTimer modelLoadTimer;
     modelLoadTimer.start();
 
+    auto requestedDevice = MySettings::globalInstance()->device();
     auto n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
     m_ctx.n_ctx = n_ctx;
     auto ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);
 
-    std::string buildVariant = "auto";
-#if defined(Q_OS_MAC) && defined(__arm__)
+    std::string backend = "auto";
+#if !defined(Q_OS_MAC)
+    if (requestedDevice.startsWith("CUDA: "))
+        backend = "cuda";
+#elif defined(__arm__)
     if (m_forceMetal)
-        buildVariant = "metal";
+        backend = "metal";
 #endif
     QString constructError;
     m_llModelInfo.model = nullptr;
     try {
-        m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
+        m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), backend, n_ctx);
     } catch (const LLModel::MissingImplementationError &e) {
         modelLoadProps.insert("error", "missing_model_impl");
         constructError = e.what();
@@ -346,7 +350,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
 
         // Pick the best match for the device
         QString actualDevice = m_llModelInfo.model->implementation().buildVariant() == "metal" ? "Metal" : "CPU";
-        const QString requestedDevice = MySettings::globalInstance()->device();
         if (requestedDevice == "CPU") {
             emit reportFallbackReason(""); // fallback not applicable
         } else {
@@ -354,11 +357,12 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
             LLModel::GPUDevice *device = nullptr;
 
+            // NB: relies on the fact that Kompute devices are listed first
            if (!availableDevices.empty() && requestedDevice == "Auto" && availableDevices.front().type == 2 /*a discrete gpu*/) {
                 device = &availableDevices.front();
             } else {
                 for (LLModel::GPUDevice &d : availableDevices) {
-                    if (QString::fromStdString(d.name) == requestedDevice) {
+                    if (QString::fromStdString(d.uiName()) == requestedDevice) {
                         device = &d;
                         break;
                     }
diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp
index 3feaea21b1f67..73a27f6cfa852 100644
--- a/gpt4all-chat/mysettings.cpp
+++ b/gpt4all-chat/mysettings.cpp
@@ -68,7 +68,7 @@ MySettings::MySettings()
     std::vector<LLModel::GPUDevice> devices = LLModel::Implementation::availableGPUDevices();
     QVector<QString> deviceList{ "Auto" };
     for (LLModel::GPUDevice &d : devices)
-        deviceList << QString::fromStdString(d.name);
+        deviceList << QString::fromStdString(d.uiName());
     deviceList << "CPU";
     setDeviceList(deviceList);
 }
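
For reference (not part of the patch): a minimal, self-contained sketch of the round trip this change establishes between the settings list and device matching. The simplified GPUDevice below keeps only the backend and name fields; uiName() copies the logic added in llmodel.h, while the device names and the "kompute" backend string are illustrative assumptions.

#include <iostream>
#include <string>
#include <vector>

using namespace std::string_literals;

// Simplified stand-in for LLModel::GPUDevice (illustrative; the real struct
// also carries index, type, heapSize, and vendor).
struct GPUDevice {
    std::string backend; // e.g. "kompute" or "cuda" (backend strings assumed)
    std::string name;    // device name as reported by the backend

    // Same logic as the uiName() added in llmodel.h: CUDA devices get a
    // "CUDA: " prefix so one physical GPU enumerated by two backends
    // produces two distinct, selectable entries.
    std::string uiName() const { return backend == "cuda"s ? "CUDA: " + name : name; }
};

int main()
{
    // One GPU visible through both backends; raw names alone would collide.
    std::vector<GPUDevice> devices {
        {"kompute", "NVIDIA GeForce RTX 3060"},
        {"cuda",    "NVIDIA GeForce RTX 3060"},
    };

    // What mysettings.cpp now puts in the device list:
    //   NVIDIA GeForce RTX 3060
    //   CUDA: NVIDIA GeForce RTX 3060
    for (const GPUDevice &d : devices)
        std::cout << d.uiName() << '\n';

    // How chatllm.cpp now resolves the user's selection back to a device:
    // comparing against uiName() instead of name keeps the entries distinct.
    std::string requestedDevice = "CUDA: NVIDIA GeForce RTX 3060";
    for (GPUDevice &d : devices) {
        if (d.uiName() == requestedDevice) {
            std::cout << "selected backend: " << d.backend << '\n'; // "cuda"
            break;
        }
    }
}

Matching on uiName() rather than name is what lets a requested "CUDA: ..." entry select the CUDA copy of a device that the Kompute backend also lists; the "CUDA: " prefix check added in chatllm.cpp then routes construct() to the CUDA backend before any device is chosen.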