Skip to content

Commit

Permalink
fix: Fix failure expiry race condition
Browse files Browse the repository at this point in the history
Motivation

A bug was encountered whereby model load failure records would not be cleaned up after expiration while there were continuous (failing) requests for that model, meaning no reload attempt occurred and those requests failed indefinitely.

Modifications

- Add unit test to reproduce the problem, which required
   - Making expiry time configurable
   - Modifying the behaviour of the dummy model so that triggered loading failures only occur on the first attempt
- Add some log statements
- Fix the main bug: When a failed cache entry without any record in the registry is encountered during a load attempt, remove it rather than treating as a load failure
- Fix some other logic related to failure expiry
- Use a shorter expiry time (2/3 of configured time) for recently-used models

Result

Failures expire as intended in all situations

Signed-off-by: Nick Hill <nickhill@us.ibm.com>
  • Loading branch information
njhill committed Oct 31, 2022
1 parent 29597da commit 9000ebf
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 41 deletions.
108 changes: 71 additions & 37 deletions src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ public abstract class ModelMesh extends ThriftService
// used in routing decisions, gets set to Math.max(3000L, loadTimeoutMs/3)
protected /*final*/ long defaultAssumeLoadedAfterMs;

public static final long LOAD_FAILURE_EXPIRY_MS = 600_000L; // 10mins for now
// time after which loading failure records expire (allowing for re-attempts)
public final long LOAD_FAILURE_EXPIRY_MS = getLongParameter(LOAD_FAILURE_EXPIRY_ENV_VAR, 600_000L); // default 10mins
// shorter expiry time for "in use" models (receiving recent requests)
public final long IN_USE_LOAD_FAILURE_EXPIRY_MS = (LOAD_FAILURE_EXPIRY_MS * 2) / 3;
public static final int MAX_LOAD_FAILURES = 3;
// if unable to invoke in this many places, don't continue to load
public static final int MAX_LOAD_LOCATIONS = 5;
Expand Down Expand Up @@ -4508,14 +4511,15 @@ private void checkLoadLocationCount(ModelRecord mr, Collection<String> explicitE
}

// check if model load failures have breached the maximum allowed limit
private static void checkLoadFailureCount(ModelRecord mr, ModelLoadException loadFailureSeen)
private void checkLoadFailureCount(ModelRecord mr, ModelLoadException loadFailureSeen)
throws ModelLoadException {
Map<String, Long> failedInInstances = mr.getLoadFailedInstanceIds();
if (!failedInInstances.isEmpty()) {
int count = 0;
final long expiryCutoffTime = currentTimeMillis() - IN_USE_LOAD_FAILURE_EXPIRY_MS;
for (Long failTime : failedInInstances.values()) {
if (failTime > LOAD_FAILURE_EXPIRY_MS) {
count++;
if (failTime > expiryCutoffTime) {
count++; // not yet expired
}
if (count >= MAX_LOAD_FAILURES) {
if (loadFailureSeen != null) {
Expand Down Expand Up @@ -4971,6 +4975,7 @@ protected CacheEntry<?> loadLocal(String modelId, ModelRecord[] mrh, long lastUs
break; // success
}

logger.info("Encountered existing cache entry while loading model " + modelId);
synchronized (existCe) {
// A cache entry was already there - most likely that another thread
// in this instance is also loading this model (in this same method).
Expand All @@ -4986,6 +4991,7 @@ protected CacheEntry<?> loadLocal(String modelId, ModelRecord[] mrh, long lastUs
if (latestMr == null) {
mrh[0] = null;
existCe.remove();
logger.info("Existing cache entry for model " + modelId + " now gone");
return INSTANCES_CHANGED; // ModelNotFoundException will be thrown
}

Expand All @@ -4995,6 +5001,7 @@ protected CacheEntry<?> loadLocal(String modelId, ModelRecord[] mrh, long lastUs
|| !Objects.equals(latestMr.getLoadFailedInstanceIds(),
mr.getLoadFailedInstanceIds())) {
// model registrations changed, re-start main loop
logger.info("Registrations changed for " + modelId + ", will reevaluate");
return INSTANCES_CHANGED;
}
mr = latestMr;
Expand All @@ -5007,6 +5014,7 @@ protected CacheEntry<?> loadLocal(String modelId, ModelRecord[] mrh, long lastUs
// Odd situation, similar to janitor logic for when a local
// cache entry is found without corresponding model record entry,
// we just "recycle" the already-loading/loaded one
logger.info("Recycling existing entry for " + modelId + "(state=" + ceStateString(stateNow) + ")");
ce = existCe;
break;
}
Expand All @@ -5017,15 +5025,20 @@ protected CacheEntry<?> loadLocal(String modelId, ModelRecord[] mrh, long lastUs
if (existCe.isFailed()) {
assert stateNow == CacheEntry.FAILED;
latestMr = handleUnexpectedFailedCacheEntry(existCe, mr);
if (latestMr != mr) {
mrh[0] = mr = latestMr;
return INSTANCES_CHANGED;
mrh[0] = latestMr;
if (!existCe.isRemoved()) {
if (latestMr != mr) {
return INSTANCES_CHANGED;
}
logger.info("Unexpected failed cache entry for model " + modelId
+ ", treating as load failure");
return existCe;
}
mrh[0] = mr = latestMr;
return existCe;
// else continue to loop now that the entry has been removed
} else {
// We'll continue to loop in this case for now
assert stateNow == CacheEntry.NEW;
}
// We'll continue to loop in this case for now
assert stateNow == CacheEntry.NEW;
}

existCe = null;
Expand Down Expand Up @@ -5202,33 +5215,40 @@ private ModelRecord handleUnexpectedFailedCacheEntry(CacheEntry<?> ce, ModelReco
if (failure == null) {
return mr; // safeguard timeout case (didn't see load fail but timed out waiting for it)
}
if (failure instanceof ModelLoadException
&& ((ModelLoadException) failure).getTimeout() == KVSTORE_LOAD_FAILURE) {
long failureAge = currentTimeMillis() - ce.loadCompleteTimestamp;
if (failureAge > 30_000 && failureAge > 30_000
+ ThreadLocalRandom.current().nextLong(30_000)) { // Randomize to avoid thunder
ModelRecord newMr = registry.get(ce.modelId);
if (newMr == null ? mr != null : (mr == null || newMr.getVersion() == mr.getVersion())) {
// First replace the entry with a later-expiring one to block concurrent attempts
CacheEntry<?> replacement = new CacheEntry<>(ce);
if (runtimeCache.replaceQuietly(ce.modelId, ce, replacement)) {
ce.remove();
ce = replacement;
if (!(failure instanceof ModelLoadException)
|| ((ModelLoadException) failure).getTimeout() != KVSTORE_LOAD_FAILURE) {
// We assume that this is an expired entry yet to be cleaned up
if (ce.remove()) {
logger.info("Removed kv-store failure cache entry for model " + ce.modelId);
}
return mr;
}

long failureAge = currentTimeMillis() - ce.loadCompleteTimestamp;
if (failureAge > 30_000 && failureAge > 30_000
+ ThreadLocalRandom.current().nextLong(30_000)) { // Randomize to avoid thunder

ModelRecord newMr = registry.get(ce.modelId);
if (newMr == null ? mr == null : (mr != null && newMr.getVersion() == mr.getVersion())) {
// First replace the entry with a later-expiring one to block concurrent attempts
CacheEntry<?> replacement = new CacheEntry<>(ce);
if (runtimeCache.replaceQuietly(ce.modelId, ce, replacement)) {
ce.remove();
ce = replacement;
try {
// this might throw if there are still KV store issues
try {
newMr = registry.getStrong(ce.modelId);
if (ce.remove()) {
logger.info("Removed kv-store failure cache entry for model " + ce.modelId);
}
} catch (Exception e) {
// Cannot verify / still KV store problems
logger.warn("Failed to retrieve model record after kv-store failure entry expiry"
+ " for model " + ce.modelId);
newMr = registry.getStrong(ce.modelId);
if (ce.remove()) {
logger.info("Removed kv-store failure cache entry for model " + ce.modelId);
}
} catch (Exception e) {
// Cannot verify / still KV store problems
logger.warn("Failed to retrieve model record after kv-store failure entry expiry"
+ " for model " + ce.modelId);
}
}
return newMr;
}
return newMr;
}
if (mr != null) {
// Allow failure to propagate (e.g. to invokeLocalModel() after load attempt)
Expand Down Expand Up @@ -5907,10 +5927,22 @@ public void run() {
if (ce == null) {
ce = runtimeCache.getQuietly(modelId);
}
long lastUsed = -1L;
long lastUsed = -2L;
boolean remLoaded = loaded && (ce == null || ce.isFailed());
boolean remFailed = failedTime != null
&& ((ce != null && !ce.isFailed()) || now - failedTime > LOAD_FAILURE_EXPIRY_MS);
boolean remFailed = false;
if (failedTime != null) {
if (ce != null && !ce.isFailed()) {
remFailed = true;
} else {
lastUsed = ce != null ? runtimeCache.getLastUsedTime(modelId) : -1L;
// Use shorter expiry age if model was used in last 3 minutes
final long expiryAge = (lastUsed > 0 && (now - lastUsed) < 180_000L)
? IN_USE_LOAD_FAILURE_EXPIRY_MS : LOAD_FAILURE_EXPIRY_MS;
if (now - failedTime > expiryAge) {
remFailed = true;
}
}
}
if (remLoaded || remFailed) {
if (shuttingDown) {
return;
Expand All @@ -5924,7 +5956,9 @@ public void run() {
mr.removeLoadFailure(instanceId);
}
if (ce != null) {
lastUsed = runtimeCache.getLastUsedTime(modelId);
if (lastUsed == -2) {
lastUsed = runtimeCache.getLastUsedTime(modelId);
}
if (lastUsed > 0L) {
mr.updateLastUsed(lastUsed);
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ private ModelMeshEnvVars() {}
public static final String MAX_INFLIGHT_PER_COPY_ENV_VAR = "MM_MAX_INFLIGHT_PER_MODEL_COPY";
public static final String CONC_SCALEUP_BANDWIDTH_PCT_ENV_VAR = "MM_CONC_SCALEUP_BANDWIDTH_PCT";

public static final String LOAD_FAILURE_EXPIRY_ENV_VAR = "MM_LOAD_FAILURE_EXPIRY_TIME_MS";

public static final String MMESH_METRICS_ENV_VAR = "MM_METRICS";

public static final String LOG_EACH_INVOKE_ENV_VAR = "MM_LOG_EACH_INVOKE";
Expand Down
140 changes: 140 additions & 0 deletions src/test/java/com/ibm/watson/modelmesh/ModelMeshFailureExpiryTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright 2022 IBM Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.ibm.watson.modelmesh;

import com.ibm.watson.modelmesh.api.ModelInfo;
import com.ibm.watson.modelmesh.api.ModelMeshGrpc;
import com.ibm.watson.modelmesh.api.ModelMeshGrpc.ModelMeshBlockingStub;
import com.ibm.watson.modelmesh.api.ModelStatusInfo;
import com.ibm.watson.modelmesh.api.ModelStatusInfo.ModelStatus;
import com.ibm.watson.modelmesh.api.RegisterModelRequest;
import com.ibm.watson.modelmesh.api.UnregisterModelRequest;
import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc;
import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc.ExamplePredictorBlockingStub;
import com.ibm.watson.modelmesh.example.api.Predictor.PredictRequest;
import com.ibm.watson.modelmesh.example.api.Predictor.PredictResponse;
import io.grpc.ManagedChannel;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.NettyChannelBuilder;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.ibm.watson.modelmesh.ModelMeshEnvVars.LOAD_FAILURE_EXPIRY_ENV_VAR;
import static org.junit.jupiter.api.Assertions.*;

/**
* Test load failure expiry while model is receiving continuous stream of inference requests
*/
public class ModelMeshFailureExpiryTest extends SingleInstanceModelMeshTest {

@Override
@BeforeEach
public void initialize() throws Exception {
System.setProperty("tas.janitor_freq_secs", "1"); // set janitor to run every second
System.setProperty(LOAD_FAILURE_EXPIRY_ENV_VAR, "2000"); // expire failures in 2 seconds
try {
super.initialize();
} finally {
System.clearProperty("tas.janitor_freq_secs");
System.clearProperty(LOAD_FAILURE_EXPIRY_ENV_VAR);
}
}

@Test
public void failureExpiryTest() throws Exception {
ExecutorService es = Executors.newSingleThreadExecutor();

ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 8088)
.usePlaintext().build();
String modelId = "myFailExpireModel";
ModelMeshBlockingStub manageModels = ModelMeshGrpc.newBlockingStub(channel);

final AtomicBoolean stop = new AtomicBoolean();
try {
ExamplePredictorBlockingStub useModels = ExamplePredictorGrpc.newBlockingStub(channel);

String errorMessage = "load should fail with this error";

// add a model which should fail to load
ModelStatusInfo statusInfo = manageModels.registerModel(RegisterModelRequest.newBuilder()
.setModelId(modelId).setModelInfo(ModelInfo.newBuilder().setType("ExampleType")
.setPath("FAIL_" + errorMessage).build()) // special prefix to trigger load in dummy runtime
.setLoadNow(true).setSync(true).build());

System.out.println("registerModel returned: " + statusInfo);
assertEquals(ModelStatus.LOADING_FAILED, statusInfo.getStatus());


final PredictRequest req = PredictRequest.newBuilder().setText("predict me!").build();

es.execute(() -> {
while (!stop.get()) try {
forModel(useModels, modelId).predict(req);
} catch (StatusRuntimeException sre) {
// ignore
} finally {
try {
Thread.sleep(400L);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}
});

// call predict on the model
try {
forModel(useModels, modelId).predict(req);
fail("failed model should fail predict");
} catch (StatusRuntimeException sre) {
assertTrue(sre.getStatus().getDescription().endsWith(errorMessage));
}

Thread.sleep(1000);

// not expired yet
try {
forModel(useModels, modelId).predict(req);
fail("failed model should fail predict");
} catch (StatusRuntimeException sre) {
assertTrue(sre.getStatus().getDescription().endsWith(errorMessage));
}

Thread.sleep(3000);

// failure should now be expired and it should load successfully second time
// per logic in dummy Model class

PredictResponse response = forModel(useModels, modelId).predict(req);
System.out.println("predict returned: " + response.getResultsList());
assertEquals(1.0, response.getResults(0).getConfidence(), 0);
assertEquals("classification for predict me! by model " + modelId,
response.getResults(0).getCategory());

} finally {
stop.set(true);
manageModels.unregisterModel(UnregisterModelRequest.newBuilder()
.setModelId(modelId).build());
channel.shutdown();
es.shutdown();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public void grpcTest() throws Exception {
public void loadFailureTest() throws Exception {
ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 8088)
.usePlaintext().build();
String modelId = "myModel";
String modelId = "myModel_" + getClass().getSimpleName();
ModelMeshBlockingStub manageModels = ModelMeshGrpc.newBlockingStub(channel);
try {
String errorMessage = "load should fail with this error";
Expand Down
12 changes: 9 additions & 3 deletions src/test/java/com/ibm/watson/modelmesh/example/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@

package com.ibm.watson.modelmesh.example;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.ibm.watson.modelmesh.example.ExampleModelRuntime.FAST_MODE;

public class Model {

// Keeps track of how many load attempts there have been per model id
final static Map<String, Integer> loadCounts = new ConcurrentHashMap<>();

final String id;

final String location;
Expand All @@ -39,7 +44,8 @@ public Model(String id, String location) {
static final Pattern LOC_PATT = Pattern.compile("SIZE_(\\d+)_(\\d+)_(\\d+)");

public String load() {
System.out.println("Loading model " + id + "...");
int loadCount = loadCounts.merge(id, 1, (i1, i2) -> i1 + i2); // increment
System.out.println("Loading model " + id + "(count = " + loadCount + ")...");

Matcher m = location != null ? LOC_PATT.matcher(location) : null;

Expand All @@ -61,8 +67,8 @@ public String load() {

sleep(loadingTime); // simulate time to load

if (location != null && location.startsWith("FAIL_")) {
return location.substring(5); // simulate loading failure
if (loadCount <= 1 && location != null && location.startsWith("FAIL_")) {
return location.substring(5); // simulate loading failure for first load
}

return null;
Expand Down

0 comments on commit 9000ebf

Please sign in to comment.