Skip to content

Commit

Permalink
fix: Fix failure expiry race condition
Browse files Browse the repository at this point in the history
description TODO

Signed-off-by: Nick Hill <nickhill@us.ibm.com>
  • Loading branch information
njhill committed Oct 31, 2022
1 parent 29597da commit 88b5494
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 42 deletions.
108 changes: 71 additions & 37 deletions src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ public abstract class ModelMesh extends ThriftService
// used in routing decisions, gets set to Math.max(3000L, loadTimeoutMs/3)
protected /*final*/ long defaultAssumeLoadedAfterMs;

public static final long LOAD_FAILURE_EXPIRY_MS = 600_000L; // 10mins for now
// time after which loading failure records expire (allowing for re-attempts)
public final long LOAD_FAILURE_EXPIRY_MS = getLongParameter(LOAD_FAILURE_EXPIRY_ENV_VAR, 600_000L); // default 10mins
// shorter expiry time for "in use" models (receiving recent requests)
public final long IN_USE_LOAD_FAILURE_EXPIRY_MS = (LOAD_FAILURE_EXPIRY_MS * 2) / 3;
public static final int MAX_LOAD_FAILURES = 3;
// if unable to invoke in this many places, don't continue to load
public static final int MAX_LOAD_LOCATIONS = 5;
Expand Down Expand Up @@ -4508,14 +4511,15 @@ private void checkLoadLocationCount(ModelRecord mr, Collection<String> explicitE
}

// check if model load failures have breached the maximum allowed limit
private static void checkLoadFailureCount(ModelRecord mr, ModelLoadException loadFailureSeen)
private void checkLoadFailureCount(ModelRecord mr, ModelLoadException loadFailureSeen)
throws ModelLoadException {
Map<String, Long> failedInInstances = mr.getLoadFailedInstanceIds();
if (!failedInInstances.isEmpty()) {
int count = 0;
final long expiryCutoffTime = currentTimeMillis() - IN_USE_LOAD_FAILURE_EXPIRY_MS;
for (Long failTime : failedInInstances.values()) {
if (failTime > LOAD_FAILURE_EXPIRY_MS) {
count++;
if (failTime > expiryCutoffTime) {
count++; // not yet expired
}
if (count >= MAX_LOAD_FAILURES) {
if (loadFailureSeen != null) {
Expand Down Expand Up @@ -4971,6 +4975,7 @@ protected CacheEntry<?> loadLocal(String modelId, ModelRecord[] mrh, long lastUs
break; // success
}

logger.info("Encountered existing cache entry while loading model " + modelId);
synchronized (existCe) {
// A cache entry was already there - most likely that another thread
// in this instance is also loading this model (in this same method).
Expand All @@ -4986,6 +4991,7 @@ protected CacheEntry<?> loadLocal(String modelId, ModelRecord[] mrh, long lastUs
if (latestMr == null) {
mrh[0] = null;
existCe.remove();
logger.info("Existing cache entry for model " + modelId + " now gone");
return INSTANCES_CHANGED; // ModelNotFoundException will be thrown
}

Expand All @@ -4995,6 +5001,7 @@ protected CacheEntry<?> loadLocal(String modelId, ModelRecord[] mrh, long lastUs
|| !Objects.equals(latestMr.getLoadFailedInstanceIds(),
mr.getLoadFailedInstanceIds())) {
// model registrations changed, re-start main loop
logger.info("Registrations changed for " + modelId + ", will reevaluate");
return INSTANCES_CHANGED;
}
mr = latestMr;
Expand All @@ -5007,6 +5014,7 @@ protected CacheEntry<?> loadLocal(String modelId, ModelRecord[] mrh, long lastUs
// Odd situation, similar to janitor logic for when a local
// cache entry is found without corresponding model record entry,
// we just "recycle" the already-loading/loaded one
logger.info("Recycling existing entry for " + modelId + "(state=" + ceStateString(stateNow) + ")");
ce = existCe;
break;
}
Expand All @@ -5017,15 +5025,20 @@ protected CacheEntry<?> loadLocal(String modelId, ModelRecord[] mrh, long lastUs
if (existCe.isFailed()) {
assert stateNow == CacheEntry.FAILED;
latestMr = handleUnexpectedFailedCacheEntry(existCe, mr);
if (latestMr != mr) {
mrh[0] = mr = latestMr;
return INSTANCES_CHANGED;
mrh[0] = latestMr;
if (!existCe.isRemoved()) {
if (latestMr != mr) {
return INSTANCES_CHANGED;
}
logger.info("Unexpected failed cache entry for model " + modelId
+ ", treating as load failure");
return existCe;
}
mrh[0] = mr = latestMr;
return existCe;
// else continue to loop now that the entry has been removed
} else {
// We'll continue to loop in this case for now
assert stateNow == CacheEntry.NEW;
}
// We'll continue to loop in this case for now
assert stateNow == CacheEntry.NEW;
}

existCe = null;
Expand Down Expand Up @@ -5202,33 +5215,40 @@ private ModelRecord handleUnexpectedFailedCacheEntry(CacheEntry<?> ce, ModelReco
if (failure == null) {
return mr; // safeguard timeout case (didn't see load fail but timed out waiting for it)
}
if (failure instanceof ModelLoadException
&& ((ModelLoadException) failure).getTimeout() == KVSTORE_LOAD_FAILURE) {
long failureAge = currentTimeMillis() - ce.loadCompleteTimestamp;
if (failureAge > 30_000 && failureAge > 30_000
+ ThreadLocalRandom.current().nextLong(30_000)) { // Randomize to avoid thunder
ModelRecord newMr = registry.get(ce.modelId);
if (newMr == null ? mr != null : (mr == null || newMr.getVersion() == mr.getVersion())) {
// First replace the entry with a later-expiring one to block concurrent attempts
CacheEntry<?> replacement = new CacheEntry<>(ce);
if (runtimeCache.replaceQuietly(ce.modelId, ce, replacement)) {
ce.remove();
ce = replacement;
if (!(failure instanceof ModelLoadException)
|| ((ModelLoadException) failure).getTimeout() != KVSTORE_LOAD_FAILURE) {
// We assume that this is an expired entry yet to be cleaned up
if (ce.remove()) {
logger.info("Removed kv-store failure cache entry for model " + ce.modelId);
}
return mr;
}

long failureAge = currentTimeMillis() - ce.loadCompleteTimestamp;
if (failureAge > 30_000 && failureAge > 30_000
+ ThreadLocalRandom.current().nextLong(30_000)) { // Randomize to avoid thunder

ModelRecord newMr = registry.get(ce.modelId);
if (newMr == null ? mr == null : (mr != null && newMr.getVersion() == mr.getVersion())) {
// First replace the entry with a later-expiring one to block concurrent attempts
CacheEntry<?> replacement = new CacheEntry<>(ce);
if (runtimeCache.replaceQuietly(ce.modelId, ce, replacement)) {
ce.remove();
ce = replacement;
try {
// this might throw if there are still KV store issues
try {
newMr = registry.getStrong(ce.modelId);
if (ce.remove()) {
logger.info("Removed kv-store failure cache entry for model " + ce.modelId);
}
} catch (Exception e) {
// Cannot verify / still KV store problems
logger.warn("Failed to retrieve model record after kv-store failure entry expiry"
+ " for model " + ce.modelId);
newMr = registry.getStrong(ce.modelId);
if (ce.remove()) {
logger.info("Removed kv-store failure cache entry for model " + ce.modelId);
}
} catch (Exception e) {
// Cannot verify / still KV store problems
logger.warn("Failed to retrieve model record after kv-store failure entry expiry"
+ " for model " + ce.modelId);
}
}
return newMr;
}
return newMr;
}
if (mr != null) {
// Allow failure to propagate (e.g. to invokeLocalModel() after load attempt)
Expand Down Expand Up @@ -5907,10 +5927,22 @@ public void run() {
if (ce == null) {
ce = runtimeCache.getQuietly(modelId);
}
long lastUsed = -1L;
long lastUsed = -2L;
boolean remLoaded = loaded && (ce == null || ce.isFailed());
boolean remFailed = failedTime != null
&& ((ce != null && !ce.isFailed()) || now - failedTime > LOAD_FAILURE_EXPIRY_MS);
boolean remFailed = false;
if (failedTime != null) {
if (ce != null && !ce.isFailed()) {
remFailed = true;
} else {
lastUsed = ce != null ? runtimeCache.getLastUsedTime(modelId) : -1L;
// Use shorter expiry age if model was used in last 3 minutes
final long expiryAge = (lastUsed > 0 && (now - lastUsed) < 180_000L)
? IN_USE_LOAD_FAILURE_EXPIRY_MS : LOAD_FAILURE_EXPIRY_MS;
if (now - failedTime > expiryAge) {
remFailed = true;
}
}
}
if (remLoaded || remFailed) {
if (shuttingDown) {
return;
Expand All @@ -5924,7 +5956,9 @@ public void run() {
mr.removeLoadFailure(instanceId);
}
if (ce != null) {
lastUsed = runtimeCache.getLastUsedTime(modelId);
if (lastUsed == -2) {
lastUsed = runtimeCache.getLastUsedTime(modelId);
}
if (lastUsed > 0L) {
mr.updateLastUsed(lastUsed);
}
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ private ModelMeshEnvVars() {}
public static final String MAX_INFLIGHT_PER_COPY_ENV_VAR = "MM_MAX_INFLIGHT_PER_MODEL_COPY";
public static final String CONC_SCALEUP_BANDWIDTH_PCT_ENV_VAR = "MM_CONC_SCALEUP_BANDWIDTH_PCT";

public static final String LOAD_FAILURE_EXPIRY_ENV_VAR = "MM_LOAD_FAILURE_EXPIRY_TIME_MS";

public static final String MMESH_METRICS_ENV_VAR = "MM_METRICS";

public static final String LOG_EACH_INVOKE_ENV_VAR = "MM_LOG_EACH_INVOKE";
Expand Down
140 changes: 140 additions & 0 deletions src/test/java/com/ibm/watson/modelmesh/ModelMeshFailureExpiryTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright 2021 IBM Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.ibm.watson.modelmesh;

import com.ibm.watson.modelmesh.api.ModelInfo;
import com.ibm.watson.modelmesh.api.ModelMeshGrpc;
import com.ibm.watson.modelmesh.api.ModelMeshGrpc.ModelMeshBlockingStub;
import com.ibm.watson.modelmesh.api.ModelStatusInfo;
import com.ibm.watson.modelmesh.api.ModelStatusInfo.ModelStatus;
import com.ibm.watson.modelmesh.api.RegisterModelRequest;
import com.ibm.watson.modelmesh.api.UnregisterModelRequest;
import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc;
import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc.ExamplePredictorBlockingStub;
import com.ibm.watson.modelmesh.example.api.Predictor.PredictRequest;
import com.ibm.watson.modelmesh.example.api.Predictor.PredictResponse;
import io.grpc.ManagedChannel;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.NettyChannelBuilder;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.ibm.watson.modelmesh.ModelMeshEnvVars.LOAD_FAILURE_EXPIRY_ENV_VAR;
import static org.junit.jupiter.api.Assertions.*;

/**
 * Tests load failure expiry while the model is receiving a continuous stream
 * of inference requests.
 * <p>
 * The instance is started with the janitor running every second and with load
 * failure records configured to expire after two seconds, so that the expiry
 * can be observed within the lifetime of a single test.
 */
public class ModelMeshFailureExpiryTest extends SingleInstanceModelMeshTest {

    /**
     * Starts the model-mesh instance with a fast janitor cycle and a short
     * load failure expiry time. The system properties are only needed while
     * the instance initializes, so they are cleared again afterwards to avoid
     * leaking into other tests running in the same JVM.
     */
    @Override
    @BeforeEach
    public void initialize() throws Exception {
        System.setProperty("tas.janitor_freq_secs", "1"); // set janitor to run every second
        System.setProperty(LOAD_FAILURE_EXPIRY_ENV_VAR, "2000"); // expire failures in 2 seconds
        try {
            super.initialize();
        } finally {
            System.clearProperty("tas.janitor_freq_secs");
            System.clearProperty(LOAD_FAILURE_EXPIRY_ENV_VAR);
        }
    }

    /**
     * Registers a model whose first load fails, keeps it "in use" with a
     * background stream of predict requests, and verifies that predictions
     * fail until the failure record expires and succeed afterwards.
     */
    @Test
    public void failureExpiryTest() throws Exception {
        ExecutorService es = Executors.newSingleThreadExecutor();

        ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 8088)
                .usePlaintext().build();
        String modelId = "myFailExpireModel";
        ModelMeshBlockingStub manageModels = ModelMeshGrpc.newBlockingStub(channel);

        final AtomicBoolean stop = new AtomicBoolean();
        try {
            ExamplePredictorBlockingStub useModels = ExamplePredictorGrpc.newBlockingStub(channel);

            String errorMessage = "load should fail with this error";

            // add a model which should fail to load
            ModelStatusInfo statusInfo = manageModels.registerModel(RegisterModelRequest.newBuilder()
                    .setModelId(modelId).setModelInfo(ModelInfo.newBuilder().setType("ExampleType")
                            // special prefix to trigger a load failure in the dummy runtime
                            .setPath("FAIL_" + errorMessage).build())
                    .setLoadNow(true).setSync(true).build());

            System.out.println("registerModel returned: " + statusInfo);
            assertEquals(ModelStatus.LOADING_FAILED, statusInfo.getStatus());

            final PredictRequest req = PredictRequest.newBuilder().setText("predict me!").build();

            // Background request stream so that the model counts as recently used
            // for the whole duration of the test
            es.execute(() -> {
                while (!stop.get()) {
                    try {
                        forModel(useModels, modelId).predict(req);
                    } catch (StatusRuntimeException sre) {
                        // expected while the load failure record is still in place
                    }
                    try {
                        Thread.sleep(400L);
                    } catch (InterruptedException e) {
                        // Preserve the interrupt status and stop polling
                        Thread.currentThread().interrupt();
                        return;
                    }
                }
            });

            // call predict on the model
            assertPredictFailsWith(useModels, modelId, req, errorMessage);

            Thread.sleep(1000);

            // not expired yet
            assertPredictFailsWith(useModels, modelId, req, errorMessage);

            Thread.sleep(3000);

            // failure should now be expired and it should load successfully second time
            // per logic in dummy Model class

            PredictResponse response = forModel(useModels, modelId).predict(req);
            System.out.println("predict returned: " + response.getResultsList());
            assertEquals(1.0, response.getResults(0).getConfidence(), 0);
            assertEquals("classification for predict me! by model " + modelId,
                    response.getResults(0).getCategory());

        } finally {
            stop.set(true);
            try {
                manageModels.unregisterModel(UnregisterModelRequest.newBuilder()
                        .setModelId(modelId).build());
            } finally {
                // Release the channel and executor even if unregistration fails
                channel.shutdown();
                es.shutdown();
            }
        }
    }

    /**
     * Asserts that a predict call against the given model is rejected with a
     * {@link StatusRuntimeException} whose description ends with the expected
     * load failure message.
     */
    private void assertPredictFailsWith(ExamplePredictorBlockingStub useModels, String modelId,
            PredictRequest req, String errorMessage) {
        StatusRuntimeException sre = assertThrows(StatusRuntimeException.class,
                () -> forModel(useModels, modelId).predict(req),
                "failed model should fail predict");
        String description = sre.getStatus().getDescription();
        // A missing description should surface as an assertion failure, not an NPE
        assertNotNull(description, "failure status should carry a description");
        assertTrue(description.endsWith(errorMessage),
                "unexpected failure description: " + description);
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public void grpcTest() throws Exception {
public void loadFailureTest() throws Exception {
ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 8088)
.usePlaintext().build();
String modelId = "myModel";
String modelId = "myModel_" + getClass().getSimpleName();
ModelMeshBlockingStub manageModels = ModelMeshGrpc.newBlockingStub(channel);
try {
String errorMessage = "load should fail with this error";
Expand All @@ -137,7 +137,7 @@ public void loadFailureTest() throws Exception {

System.out.println("registerModel returned: " + statusInfo);

verifyFailureResponse(statusInfo, errorMessage);
verifyFailureResponse(statusInfo, errorMessage);

Thread.sleep(200);
// Check that the response is consistent
Expand Down
11 changes: 8 additions & 3 deletions src/test/java/com/ibm/watson/modelmesh/example/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@

package com.ibm.watson.modelmesh.example;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.ibm.watson.modelmesh.example.ExampleModelRuntime.FAST_MODE;

public class Model {

final static Map<String, Integer> loadCounts = new ConcurrentHashMap<>();

final String id;

final String location;
Expand All @@ -39,7 +43,8 @@ public Model(String id, String location) {
static final Pattern LOC_PATT = Pattern.compile("SIZE_(\\d+)_(\\d+)_(\\d+)");

public String load() {
System.out.println("Loading model " + id + "...");
int loadCount = loadCounts.merge(id, 1, (i1, i2) -> i1 + i2); // increment
System.out.println("Loading model " + id + "(count = " + loadCount + ")...");

Matcher m = location != null ? LOC_PATT.matcher(location) : null;

Expand All @@ -61,8 +66,8 @@ public String load() {

sleep(loadingTime); // simulate time to load

if (location != null && location.startsWith("FAIL_")) {
return location.substring(5); // simulate loading failure
if (loadCount <= 1 && location != null && location.startsWith("FAIL_")) {
return location.substring(5); // simulate loading failure for first load
}

return null;
Expand Down

0 comments on commit 88b5494

Please sign in to comment.