Skip to content

Commit

Permalink
[serving] Fail ping if error rate exceed (#2040)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Jun 10, 2024
1 parent b806678 commit bc328ec
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,14 @@ public class InferenceRequestHandler extends HttpRequestHandler {

private RequestParser requestParser;
private int chunkReadTime;
private ConfigManager config;
private static boolean exceedErrorRate;

/** default constructor. */
public InferenceRequestHandler() {
this.requestParser = new RequestParser();
chunkReadTime = ConfigManager.getInstance().getChunkedReadTimeout();
config = ConfigManager.getInstance();
chunkReadTime = config.getChunkedReadTimeout();
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -113,13 +116,16 @@ protected void handleRequest(
boolean hasPending = (boolean) w.get("hasPending");

HttpResponseStatus status;
if (hasFailure) {
if (exceedErrorRate) {
logger.info("PING FAILED: error rate exceed");
status = HttpResponseStatus.SERVICE_UNAVAILABLE;
} else if (hasFailure) {
logger.info(
"PING FAILED: {}",
JsonUtils.GSON.toJson(w.get("data")));
status = HttpResponseStatus.INTERNAL_SERVER_ERROR;
} else if (hasPending) {
if (ConfigManager.getInstance().allowsMultiStatus()) {
if (config.allowsMultiStatus()) {
status = HttpResponseStatus.MULTI_STATUS;
} else {
status = HttpResponseStatus.OK;
Expand Down Expand Up @@ -230,7 +236,6 @@ private void predict(
}

ModelManager modelManager = ModelManager.getInstance();
ConfigManager config = ConfigManager.getInstance();
Workflow workflow = modelManager.getWorkflow(workflowName, version, true);
if (workflow == null) {
String regex = config.getModelUrlPattern();
Expand Down Expand Up @@ -384,8 +389,16 @@ void sendOutput(Output output, ChannelHandlerContext ctx) {
} else {
if (code >= 500) {
SERVER_METRIC.info("{}", RESPONSE_5_XX);
if (!exceedErrorRate && config.onServerError()) {
exceedErrorRate = true;
}
} else if (code >= 400) {
SERVER_METRIC.info("{}", RESPONSE_4_XX);
if (code == 424) {
if (!exceedErrorRate && config.onServerError()) {
exceedErrorRate = true;
}
}
} else {
SERVER_METRIC.info("{}", RESPONSE_2_XX);
}
Expand Down Expand Up @@ -465,11 +478,17 @@ void onException(Throwable t, ChannelHandlerContext ctx) {
SERVER_METRIC.info("{}", RESPONSE_5_XX);
SERVER_METRIC.info("{}", WLM_ERROR);
status = HttpResponseStatus.SERVICE_UNAVAILABLE;
if (!exceedErrorRate && config.onWlmError()) {
exceedErrorRate = true;
}
} else {
logger.warn("Unexpected error", t);
SERVER_METRIC.info("{}", RESPONSE_5_XX);
SERVER_METRIC.info("{}", SERVER_ERROR);
status = HttpResponseStatus.INTERNAL_SERVER_ERROR;
if (!exceedErrorRate && config.onServerError()) {
exceedErrorRate = true;
}
}

/*
Expand Down
52 changes: 52 additions & 0 deletions serving/src/main/java/ai/djl/serving/util/ConfigManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;

/** A class that hold configuration information. */
Expand Down Expand Up @@ -79,6 +80,10 @@ public final class ConfigManager {
private static final String LOAD_ON_DEVICES = "load_on_devices";
private static final String PLUGIN_FOLDER = "plugin_folder";
private static final String CHUNKED_READ_TIMEOUT = "chunked_read_timeout";
private static final String ERROR_RATE_WLM = "error_rate_wlm";
private static final String ERROR_RATE_SERVER = "error_rate_server";
private static final String ERROR_RATE_MODEL = "error_rate_model";
private static final String ERROR_RATE_ANY = "error_rate_any";

// Configuration which are not documented or enabled through environment variables
private static final String USE_NATIVE_IO = "use_native_io";
Expand All @@ -89,9 +94,11 @@ public final class ConfigManager {
private static ConfigManager instance;

private Properties prop;
private Map<String, RateLimiter> limiters;

private ConfigManager(Arguments args) {
prop = new Properties();
limiters = new ConcurrentHashMap<>();

Path file = args.getConfigFile();
if (file != null) {
Expand All @@ -118,6 +125,12 @@ private ConfigManager(Arguments args) {
prop.put(key.substring(8).toLowerCase(Locale.ROOT), env.getValue());
}
}
for (Map.Entry<Object, Object> entry : prop.entrySet()) {
String key = (String) entry.getKey();
if (key.startsWith("error_rate_")) {
limiters.put(key, RateLimiter.parse(entry.getValue().toString()));
}
}
}

/**
Expand Down Expand Up @@ -171,6 +184,45 @@ public static ConfigManager getInstance() {
return instance;
}

/**
 * Counts one WLM error against the configured error rate limits.
 *
 * <p>Delegates to {@code onError} with the {@code error_rate_wlm} key.
 *
 * @return {@code true} when the WLM error rate limit has been exceeded
 */
public boolean onWlmError() {
    final boolean limitExceeded = onError(ERROR_RATE_WLM);
    return limitExceeded;
}

/**
 * Counts one server error against the configured error rate limits.
 *
 * <p>Delegates to {@code onError} with the {@code error_rate_server} key.
 *
 * @return {@code true} when the server error rate limit has been exceeded
 */
public boolean onServerError() {
    final boolean limitExceeded = onError(ERROR_RATE_SERVER);
    return limitExceeded;
}

/**
 * Counts one model error against the configured error rate limits.
 *
 * <p>Delegates to {@code onError} with the {@code error_rate_model} key.
 *
 * @return {@code true} when the model error rate limit has been exceeded
 */
public boolean onModelError() {
    final boolean limitExceeded = onError(ERROR_RATE_MODEL);
    return limitExceeded;
}

/**
 * Checks the limiter registered for {@code key} and, if that one is absent or not yet
 * exceeded, falls back to the catch-all {@code error_rate_any} limiter.
 *
 * @param key the configuration key of the category specific limiter
 * @return {@code true} if either consulted limiter reports that its limit is exceeded
 */
private boolean onError(String key) {
    RateLimiter specific = limiters.get(key);
    boolean specificExceeded = specific != null && specific.exceed();
    if (specificExceeded) {
        // The catch-all limiter is deliberately not consulted in this case.
        return true;
    }

    RateLimiter catchAll = limiters.get(ERROR_RATE_ANY);
    return catchAll != null && catchAll.exceed();
}

/**
* Returns the models server socket connector.
*
Expand Down
77 changes: 77 additions & 0 deletions serving/src/main/java/ai/djl/serving/util/RateLimiter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.serving.util;

import java.time.Duration;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.LongSupplier;

/**
 * A rate limiter that grants at most {@code threshold} permits per time window.
 *
 * <p>The limiter starts with a full set of permits. Once at least one whole time window has
 * passed since the last refill, the next call to {@link #exceed()} restores the permits to
 * {@code threshold} and starts a new window at that moment.
 *
 * <p>This class is thread-safe: {@link #exceed()} updates the permit count and the refill
 * timestamp atomically under the instance lock.
 */
public class RateLimiter {

    private final long threshold;
    private final long timeWindow;
    private final LongSupplier clock;

    // Both fields are guarded by the instance lock taken in exceed().
    private long tokens;
    private long lastRefillTimestamp;

    /**
     * Constructs a {@code RateLimiter} with the specified threshold.
     *
     * @param threshold the maximum number of tokens that can be accumulated; if it is not
     *     positive, every call to {@link #exceed()} reports the limit as hit
     * @param timeWindow the limit time window, must be at least one millisecond
     * @throws IllegalArgumentException if {@code timeWindow} is shorter than one millisecond
     */
    public RateLimiter(long threshold, Duration timeWindow) {
        this(threshold, timeWindow, RateLimiter::monotonicMillis);
    }

    /**
     * Constructs a {@code RateLimiter} that reads the current time from the given source.
     *
     * @param threshold the maximum number of tokens that can be accumulated
     * @param timeWindow the limit time window, must be at least one millisecond
     * @param clock a source of the current time in milliseconds; only differences between
     *     readings are used, so its origin is irrelevant
     * @throws IllegalArgumentException if {@code timeWindow} is shorter than one millisecond
     */
    RateLimiter(long threshold, Duration timeWindow, LongSupplier clock) {
        long millis = timeWindow.toMillis();
        if (millis <= 0) {
            // Fail at configuration time; a zero window has no meaningful rate.
            throw new IllegalArgumentException(
                    "timeWindow must be at least 1 millisecond: " + timeWindow);
        }
        this.threshold = threshold;
        this.timeWindow = millis;
        this.clock = clock;
        tokens = threshold;
        lastRefillTimestamp = clock.getAsLong();
    }

    /**
     * Obtains a {@code RateLimiter} from a text string.
     *
     * <p>The expected format is {@code <threshold>[/<ISO-8601 duration>]}, for example {@code
     * 10/PT5M}. The window defaults to one minute when the duration part is omitted. Whitespace
     * around either part is ignored.
     *
     * @param limit the string representation of the {@code RateLimiter}
     * @return a instance of {@code RateLimiter}
     * @throws NumberFormatException if the threshold is not a valid {@code long}
     * @throws java.time.format.DateTimeParseException if the duration cannot be parsed
     * @throws IllegalArgumentException if the duration is shorter than one millisecond
     */
    public static RateLimiter parse(String limit) {
        String[] pair = limit.trim().split("/", 2);
        long threshold = Long.parseLong(pair[0].trim());
        Duration duration;
        if (pair.length == 2) {
            duration = Duration.parse(pair[1].trim());
        } else {
            duration = Duration.ofMinutes(1);
        }
        return new RateLimiter(threshold, duration);
    }

    /**
     * Check if rate limit is hit.
     *
     * <p>Each call that returns {@code false} consumes one token.
     *
     * @return true if rate limit is hit
     */
    public synchronized boolean exceed() {
        long now = clock.getAsLong();
        if (now - lastRefillTimestamp >= timeWindow) {
            // Tokens are capped at threshold, so one or more elapsed windows always
            // restores exactly threshold tokens; assigning it avoids multiplication overflow.
            tokens = threshold;
            lastRefillTimestamp = now;
        }

        if (tokens > 0) {
            --tokens;
            return false;
        }
        return true;
    }

    /** Returns a monotonic millisecond reading that is unaffected by wall clock changes. */
    private static long monotonicMillis() {
        return Math.floorDiv(System.nanoTime(), 1_000_000L);
    }
}
4 changes: 4 additions & 0 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,10 @@ private void testServiceUnavailable() throws InterruptedException {
assertEquals(resp.getCode(), HttpResponseStatus.SERVICE_UNAVAILABLE.code());
assertEquals(resp.getMessage(), "All model workers has been shutdown: mlp_2");
}

channel = connect(Connector.ConnectorType.INFERENCE);
request(channel, HttpMethod.GET, "/ping");
assertHttpCode(503);
}

private void testKServeV2HealthReady(Channel channel) throws InterruptedException {
Expand Down
1 change: 1 addition & 0 deletions serving/src/test/resources/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ load_models=djl://ai.djl.zoo/mlp,mlp:v1:PyTorch=djl://ai.djl.zoo/mlp
private_key_file=src/test/resources/key.pem
certificate_file=src/test/resources/certs.pem
max_request_size=10485760
error_rate_any=1/PT5M

0 comments on commit bc328ec

Please sign in to comment.