Add a /health endpoint to the server #4860

Merged
merged 12 commits on Jan 10, 2024
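
The diff below registers a GET /health route whose JSON body and status code track the model-loading state. As a quick way to exercise it, here is a minimal probe sketch using cpp-httplib's client API (the same library the server embeds); the "localhost" host and port 8080 are placeholder assumptions, so point it at whatever host/port the server was actually started with.

// health_probe.cpp - minimal sketch, assuming the server runs on localhost:8080
#include <cstdio>
#include "httplib.h"

int main() {
    httplib::Client cli("localhost", 8080); // placeholder host/port

    auto res = cli.Get("/health");
    if (!res) {
        std::printf("server unreachable\n");
        return 1;
    }

    // Expected responses, per the handler in this PR:
    //   200 {"status": "ok"}             - model loaded, ready to serve
    //   503 {"status": "loading model"}  - server up, weights still loading
    //   500 {"status": "error", ...}     - load_model failed
    std::printf("HTTP %d: %s\n", res->status, res->body.c_str());
    return res->status == 200 ? 0 : 1;
}

A probe like this (or an equivalent HTTP check) can back a container readiness check, since the endpoint distinguishes "still loading" from "failed to load".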
199 changes: 113 additions & 86 deletions examples/server/server.cpp
@@ -26,6 +26,7 @@
#include <mutex>
#include <chrono>
#include <condition_variable>
#include <atomic>

#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
@@ -146,6 +147,12 @@ static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
// parallel
//

enum ServerState {
LOADING_MODEL, // Server is starting up, model not fully loaded yet
READY, // Server is ready and model is loaded
ERROR // An error occurred, load_model failed
};

enum task_type {
COMPLETION_TASK,
CANCEL_TASK
@@ -2453,7 +2460,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
}


static std::string random_string()
{
static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
@@ -2790,15 +2796,117 @@ int main(int argc, char **argv)
{"system_info", llama_print_system_info()},
});

// load the model
if (!llama.load_model(params))
httplib::Server svr;

std::atomic<ServerState> server_state{LOADING_MODEL};

svr.set_default_headers({{"Server", "llama.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}});

svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) {
ServerState current_state = server_state.load();
switch(current_state) {
case READY:
res.set_content(R"({"status": "ok"})", "application/json");
res.status = 200; // HTTP OK
break;
case LOADING_MODEL:
res.set_content(R"({"status": "loading model"})", "application/json");
res.status = 503; // HTTP Service Unavailable
break;
case ERROR:
res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
res.status = 500; // HTTP Internal Server Error
break;
}
});
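// note: the /health handler above only reads the atomic server_state, so
// httplib's worker threads can answer probes while the main thread is still
// loading the model further down in main()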

svr.set_logger(log_server_request);

svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep)
{
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
try
{
std::rethrow_exception(std::move(ep));
}
catch (std::exception &e)
{
snprintf(buf, sizeof(buf), fmt, e.what());
}
catch (...)
{
snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
}
res.set_content(buf, "text/plain; charset=utf-8");
res.status = 500;
});

svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
{
if (res.status == 401)
{
res.set_content("Unauthorized", "text/plain; charset=utf-8");
}
if (res.status == 400)
{
res.set_content("Invalid request", "text/plain; charset=utf-8");
}
else if (res.status == 404)
{
res.set_content("File Not Found", "text/plain; charset=utf-8");
res.status = 404;
}
});

// set timeouts and change hostname and port
svr.set_read_timeout (sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);

if (!svr.bind_to_port(sparams.hostname, sparams.port))
{
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
return 1;
}

llama.initialize();
// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);

httplib::Server svr;
// to make it ctrl+clickable:
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);

std::unordered_map<std::string, std::string> log_data;
log_data["hostname"] = sparams.hostname;
log_data["port"] = std::to_string(sparams.port);

if (!sparams.api_key.empty()) {
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
}

LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below
std::thread t([&]()
{
if (!svr.listen_after_bind())
{
server_state.store(ERROR);
return 1;
}

return 0;
});
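// note: the socket is bound and the listener thread is started before
// load_model() runs, so /health can answer 503 "loading model" while the
// weights are still being read; server_state flips to READY (or ERROR) once
// loading finishes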

// load the model
if (!llama.load_model(params))
{
server_state.store(ERROR);
return 1;
} else {
llama.initialize();
server_state.store(READY);
}

// Middleware for API key validation
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
@@ -2826,10 +2934,6 @@ int main(int argc, char **argv)
return false;
};

svr.set_default_headers({{"Server", "llama.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}});

// this is only called if no index.html is found in the public --path
svr.Get("/", [](const httplib::Request &, httplib::Response &res)
{
@@ -2937,8 +3041,6 @@ int main(int argc, char **argv)
}
});



svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res)
{
std::time_t t = std::time(0);
@@ -3157,81 +3259,6 @@ int main(int argc, char **argv)
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
});

svr.set_logger(log_server_request);

svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep)
{
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
try
{
std::rethrow_exception(std::move(ep));
}
catch (std::exception &e)
{
snprintf(buf, sizeof(buf), fmt, e.what());
}
catch (...)
{
snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
}
res.set_content(buf, "text/plain; charset=utf-8");
res.status = 500;
});

svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
{
if (res.status == 401)
{
res.set_content("Unauthorized", "text/plain; charset=utf-8");
}
if (res.status == 400)
{
res.set_content("Invalid request", "text/plain; charset=utf-8");
}
else if (res.status == 404)
{
res.set_content("File Not Found", "text/plain; charset=utf-8");
res.status = 404;
}
});

// set timeouts and change hostname and port
svr.set_read_timeout (sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);

if (!svr.bind_to_port(sparams.hostname, sparams.port))
{
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
return 1;
}

// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);

// to make it ctrl+clickable:
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);

std::unordered_map<std::string, std::string> log_data;
log_data["hostname"] = sparams.hostname;
log_data["port"] = std::to_string(sparams.port);

if (!sparams.api_key.empty()) {
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
}

LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below
std::thread t([&]()
{
if (!svr.listen_after_bind())
{
return 1;
}

return 0;
});

// GG: if I put the main loop inside a thread, it crashes on the first request when built in Debug!?
// "Bus error: 10" - this is on macOS, it does not crash on Linux
//std::thread t2([&]()