chat: implement basic UI backend selection
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
cebtenzzre committed May 6, 2024
1 parent 81140d4 commit c67f868
Showing 3 changed files with 15 additions and 7 deletions.
4 changes: 4 additions & 0 deletions gpt4all-backend/llmodel.h
@@ -10,6 +10,8 @@
 #include <string_view>
 #include <vector>
 
+using namespace std::string_literals;
+
 #define LLMODEL_MAX_PROMPT_BATCH 128
 
 class Dlhandle;
@@ -51,6 +53,8 @@ class LLModel {
         GPUDevice(const char *backend, int index, int type, size_t heapSize, std::string name, std::string vendor):
             backend(backend), index(index), type(type), heapSize(heapSize), name(std::move(name)),
             vendor(std::move(vendor)) {}
+
+        std::string uiName() const { return backend == "cuda"s ? "CUDA: " + name : name; }
     };
 
     class Implementation {
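
Annotation (not part of the commit): a minimal standalone sketch of how the new uiName() helper labels devices. CUDA devices get a "CUDA: " prefix so the UI string also identifies the backend, while other devices keep their plain name. The "kompute" backend string and the device names below are assumptions for illustration only.

// Sketch only: GPUDevice reduced to the fields uiName() touches.
#include <iostream>
#include <string>

using namespace std::string_literals;

struct GPUDevice {
    const char *backend;   // "cuda", or e.g. "kompute" for the Vulkan path (assumed)
    std::string name;

    std::string uiName() const { return backend == "cuda"s ? "CUDA: " + name : name; }
};

int main() {
    GPUDevice cuda   {"cuda",    "NVIDIA GeForce RTX 3060"};   // hypothetical device
    GPUDevice kompute{"kompute", "NVIDIA GeForce RTX 3060"};
    std::cout << cuda.uiName()    << '\n';  // prints "CUDA: NVIDIA GeForce RTX 3060"
    std::cout << kompute.uiName() << '\n';  // prints "NVIDIA GeForce RTX 3060"
    return 0;
}
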
16 changes: 10 additions & 6 deletions gpt4all-chat/chatllm.cpp
@@ -302,19 +302,23 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     QElapsedTimer modelLoadTimer;
     modelLoadTimer.start();
 
+    auto requestedDevice = MySettings::globalInstance()->device();
     auto n_ctx = MySettings::globalInstance()->modelContextLength(modelInfo);
     m_ctx.n_ctx = n_ctx;
     auto ngl = MySettings::globalInstance()->modelGpuLayers(modelInfo);
 
-    std::string buildVariant = "auto";
-#if defined(Q_OS_MAC) && defined(__arm__)
+    std::string backend = "auto";
+#if !defined(Q_OS_MAC)
+    if (requestedDevice.startsWith("CUDA: "))
+        backend = "cuda";
+#elif defined(__arm__)
     if (m_forceMetal)
-        buildVariant = "metal";
+        backend = "metal";
 #endif
     QString constructError;
     m_llModelInfo.model = nullptr;
     try {
-        m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
+        m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), backend, n_ctx);
     } catch (const LLModel::MissingImplementationError &e) {
         modelLoadProps.insert("error", "missing_model_impl");
         constructError = e.what();
@@ -346,19 +350,19 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
 
     // Pick the best match for the device
     QString actualDevice = m_llModelInfo.model->implementation().buildVariant() == "metal" ? "Metal" : "CPU";
-    const QString requestedDevice = MySettings::globalInstance()->device();
     if (requestedDevice == "CPU") {
         emit reportFallbackReason(""); // fallback not applicable
     } else {
         const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx, ngl);
         std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
         LLModel::GPUDevice *device = nullptr;
 
+        // NB: relies on the fact that Kompute devices are listed first
         if (!availableDevices.empty() && requestedDevice == "Auto" && availableDevices.front().type == 2 /*a discrete gpu*/) {
             device = &availableDevices.front();
         } else {
             for (LLModel::GPUDevice &d : availableDevices) {
-                if (QString::fromStdString(d.name) == requestedDevice) {
+                if (QString::fromStdString(d.uiName()) == requestedDevice) {
                     device = &d;
                     break;
                 }
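
Annotation (not part of the commit): a hedged sketch of the control flow the hunks above implement, using illustrative helper names rather than the actual chat code. The combo-box string chosen in settings now doubles as the backend hint passed to LLModel::Implementation::construct(), and the chosen entry is later matched against the same uiName() labels.

#include <string>
#include <vector>
#include <QString>

// requestedDevice is the settings combo-box entry: "Auto", "CPU",
// "CUDA: <name>", or a plain (Kompute) device name.
static std::string pickBackend(const QString &requestedDevice, bool forceMetal)
{
#if !defined(Q_OS_MAC)
    // A "CUDA: " prefix means the user picked a CUDA device, so request the CUDA backend.
    if (requestedDevice.startsWith("CUDA: "))
        return "cuda";
#elif defined(__arm__)
    // Apple Silicon: Metal only when explicitly forced.
    if (forceMetal)
        return "metal";
#endif
    return "auto"; // let the implementation loader decide (Kompute, CPU, ...)
}

// After construction, the requested entry is compared against uiName(), so the
// CUDA and Kompute entries for the same card remain distinguishable.
template <typename Device>
static Device *matchDevice(std::vector<Device> &devices, const QString &requestedDevice)
{
    for (Device &d : devices) {
        if (QString::fromStdString(d.uiName()) == requestedDevice)
            return &d;
    }
    return nullptr;
}

As the NB comment in the hunk notes, "Auto" still takes the first listed discrete GPU, which depends on Kompute devices being enumerated before CUDA ones.
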
2 changes: 1 addition & 1 deletion gpt4all-chat/mysettings.cpp
@@ -68,7 +68,7 @@ MySettings::MySettings()
     std::vector<LLModel::GPUDevice> devices = LLModel::Implementation::availableGPUDevices();
     QVector<QString> deviceList{ "Auto" };
     for (LLModel::GPUDevice &d : devices)
-        deviceList << QString::fromStdString(d.name);
+        deviceList << QString::fromStdString(d.uiName());
    deviceList << "CPU";
    setDeviceList(deviceList);
 }
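
Annotation (not part of the commit): with the settings change above, the device list shown to the user carries the backend in the label. Below is a rough, self-contained sketch of the resulting entries when one card is visible through both Kompute and CUDA; the stand-in struct and device name are illustrative, the real code uses LLModel::GPUDevice and QVector<QString>.

#include <iostream>
#include <string>
#include <vector>

using namespace std::string_literals;

// Stand-in for LLModel::GPUDevice with just what the device list needs.
struct GPUDevice {
    const char *backend;
    std::string name;
    std::string uiName() const { return backend == "cuda"s ? "CUDA: " + name : name; }
};

int main() {
    // Hypothetical machine: the same card enumerated by both backends.
    std::vector<GPUDevice> devices {
        {"kompute", "NVIDIA GeForce RTX 3060"},
        {"cuda",    "NVIDIA GeForce RTX 3060"},
    };

    std::vector<std::string> deviceList { "Auto" };
    for (const GPUDevice &d : devices)
        deviceList.push_back(d.uiName());
    deviceList.push_back("CPU");

    // Expected entries:
    //   Auto
    //   NVIDIA GeForce RTX 3060         (Kompute)
    //   CUDA: NVIDIA GeForce RTX 3060
    //   CPU
    for (const std::string &entry : deviceList)
        std::cout << entry << '\n';
    return 0;
}
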
