From 9000ebf77328ea396bf6ec4b57c49ec32e3fa1b9 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 29 Oct 2022 17:02:21 -0700 Subject: [PATCH] fix: Fix failure expiry race condition Motivation A bug was encountered whereby model load failure records would not be cleaned up after expiration while there were continuous (failing) requests for that model, meaning no reload attempt occurred and those requests failed indefinitely. Modifications - Add unit test to reproduce the problem, which required - Making expiry time configurable - Modifying the behaviour of the dummy model so that triggered loading failures only occur on the first attempt - Add some log statements - Fix the main bug: When a failed cache entry without any record in the registry is encountered during a load attempt, remove it rather than treating as a load failure - Fix some other logic related to failure expiry - Use a shorter expiry time (2/3 of configured time) for recently-used models Result Failures expire as intended in all situations Signed-off-by: Nick Hill --- .../com/ibm/watson/modelmesh/ModelMesh.java | 108 +++++++++----- .../watson/modelmesh/ModelMeshEnvVars.java | 2 + .../modelmesh/ModelMeshFailureExpiryTest.java | 140 ++++++++++++++++++ .../modelmesh/SidecarModelMeshTest.java | 2 +- .../ibm/watson/modelmesh/example/Model.java | 12 +- 5 files changed, 223 insertions(+), 41 deletions(-) create mode 100644 src/test/java/com/ibm/watson/modelmesh/ModelMeshFailureExpiryTest.java diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java index 6ada12928..1af0e83b2 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java @@ -208,7 +208,10 @@ public abstract class ModelMesh extends ThriftService // used in routing decisions, gets set to Math.max(3000L, loadTimeoutMs/3) protected /*final*/ long defaultAssumeLoadedAfterMs; - public static final long LOAD_FAILURE_EXPIRY_MS = 600_000L; // 10mins for now + // time after which loading failure records expire (allowing for re-attempts) + public final long LOAD_FAILURE_EXPIRY_MS = getLongParameter(LOAD_FAILURE_EXPIRY_ENV_VAR, 600_000L); // default 10mins + // shorter expiry time for "in use" models (receiving recent requests) + public final long IN_USE_LOAD_FAILURE_EXPIRY_MS = (LOAD_FAILURE_EXPIRY_MS * 2) / 3; public static final int MAX_LOAD_FAILURES = 3; // if unable to invoke in this many places, don't continue to load public static final int MAX_LOAD_LOCATIONS = 5; @@ -4508,14 +4511,15 @@ private void checkLoadLocationCount(ModelRecord mr, Collection explicitE } // check if model load failures have breached the maximum allowed limit - private static void checkLoadFailureCount(ModelRecord mr, ModelLoadException loadFailureSeen) + private void checkLoadFailureCount(ModelRecord mr, ModelLoadException loadFailureSeen) throws ModelLoadException { Map failedInInstances = mr.getLoadFailedInstanceIds(); if (!failedInInstances.isEmpty()) { int count = 0; + final long expiryCutoffTime = currentTimeMillis() - IN_USE_LOAD_FAILURE_EXPIRY_MS; for (Long failTime : failedInInstances.values()) { - if (failTime > LOAD_FAILURE_EXPIRY_MS) { - count++; + if (failTime > expiryCutoffTime) { + count++; // not yet expired } if (count >= MAX_LOAD_FAILURES) { if (loadFailureSeen != null) { @@ -4971,6 +4975,7 @@ protected CacheEntry loadLocal(String modelId, ModelRecord[] mrh, long lastUs break; // success } + logger.info("Encountered existing cache entry while loading model " + modelId); synchronized (existCe) { // A cache entry was already there - most likely that another thread // in this instance is also loading this model (in this same method). @@ -4986,6 +4991,7 @@ protected CacheEntry loadLocal(String modelId, ModelRecord[] mrh, long lastUs if (latestMr == null) { mrh[0] = null; existCe.remove(); + logger.info("Existing cache entry for model " + modelId + " now gone"); return INSTANCES_CHANGED; // ModelNotFoundException will be thrown } @@ -4995,6 +5001,7 @@ protected CacheEntry loadLocal(String modelId, ModelRecord[] mrh, long lastUs || !Objects.equals(latestMr.getLoadFailedInstanceIds(), mr.getLoadFailedInstanceIds())) { // model registrations changed, re-start main loop + logger.info("Registrations changed for " + modelId + ", will reevaluate"); return INSTANCES_CHANGED; } mr = latestMr; @@ -5007,6 +5014,7 @@ protected CacheEntry loadLocal(String modelId, ModelRecord[] mrh, long lastUs // Odd situation, similar to janitor logic for when a local // cache entry is found without corresponding model record entry, // we just "recycle" the already-loading/loaded one + logger.info("Recycling existing entry for " + modelId + "(state=" + ceStateString(stateNow) + ")"); ce = existCe; break; } @@ -5017,15 +5025,20 @@ protected CacheEntry loadLocal(String modelId, ModelRecord[] mrh, long lastUs if (existCe.isFailed()) { assert stateNow == CacheEntry.FAILED; latestMr = handleUnexpectedFailedCacheEntry(existCe, mr); - if (latestMr != mr) { - mrh[0] = mr = latestMr; - return INSTANCES_CHANGED; + mrh[0] = latestMr; + if (!existCe.isRemoved()) { + if (latestMr != mr) { + return INSTANCES_CHANGED; + } + logger.info("Unexpected failed cache entry for model " + modelId + + ", treating as load failure"); + return existCe; } - mrh[0] = mr = latestMr; - return existCe; + // else continue to loop now that the entry has been removed + } else { + // We'll continue to loop in this case for now + assert stateNow == CacheEntry.NEW; } - // We'll continue to loop in this case for now - assert stateNow == CacheEntry.NEW; } existCe = null; @@ -5202,33 +5215,40 @@ private ModelRecord handleUnexpectedFailedCacheEntry(CacheEntry ce, ModelReco if (failure == null) { return mr; // safeguard timeout case (didn't see load fail but timed out waiting for it) } - if (failure instanceof ModelLoadException - && ((ModelLoadException) failure).getTimeout() == KVSTORE_LOAD_FAILURE) { - long failureAge = currentTimeMillis() - ce.loadCompleteTimestamp; - if (failureAge > 30_000 && failureAge > 30_000 - + ThreadLocalRandom.current().nextLong(30_000)) { // Randomize to avoid thunder - ModelRecord newMr = registry.get(ce.modelId); - if (newMr == null ? mr != null : (mr == null || newMr.getVersion() == mr.getVersion())) { - // First replace the entry with a later-expiring one to block concurrent attempts - CacheEntry replacement = new CacheEntry<>(ce); - if (runtimeCache.replaceQuietly(ce.modelId, ce, replacement)) { - ce.remove(); - ce = replacement; + if (!(failure instanceof ModelLoadException) + || ((ModelLoadException) failure).getTimeout() != KVSTORE_LOAD_FAILURE) { + // We assume that this is an expired entry yet to be cleaned up + if (ce.remove()) { + logger.info("Removed kv-store failure cache entry for model " + ce.modelId); + } + return mr; + } + + long failureAge = currentTimeMillis() - ce.loadCompleteTimestamp; + if (failureAge > 30_000 && failureAge > 30_000 + + ThreadLocalRandom.current().nextLong(30_000)) { // Randomize to avoid thunder + + ModelRecord newMr = registry.get(ce.modelId); + if (newMr == null ? mr == null : (mr != null && newMr.getVersion() == mr.getVersion())) { + // First replace the entry with a later-expiring one to block concurrent attempts + CacheEntry replacement = new CacheEntry<>(ce); + if (runtimeCache.replaceQuietly(ce.modelId, ce, replacement)) { + ce.remove(); + ce = replacement; + try { // this might throw if there are still KV store issues - try { - newMr = registry.getStrong(ce.modelId); - if (ce.remove()) { - logger.info("Removed kv-store failure cache entry for model " + ce.modelId); - } - } catch (Exception e) { - // Cannot verify / still KV store problems - logger.warn("Failed to retrieve model record after kv-store failure entry expiry" - + " for model " + ce.modelId); + newMr = registry.getStrong(ce.modelId); + if (ce.remove()) { + logger.info("Removed kv-store failure cache entry for model " + ce.modelId); } + } catch (Exception e) { + // Cannot verify / still KV store problems + logger.warn("Failed to retrieve model record after kv-store failure entry expiry" + + " for model " + ce.modelId); } } - return newMr; } + return newMr; } if (mr != null) { // Allow failure to propagate (e.g. to invokeLocalModel() after load attempt) @@ -5907,10 +5927,22 @@ public void run() { if (ce == null) { ce = runtimeCache.getQuietly(modelId); } - long lastUsed = -1L; + long lastUsed = -2L; boolean remLoaded = loaded && (ce == null || ce.isFailed()); - boolean remFailed = failedTime != null - && ((ce != null && !ce.isFailed()) || now - failedTime > LOAD_FAILURE_EXPIRY_MS); + boolean remFailed = false; + if (failedTime != null) { + if (ce != null && !ce.isFailed()) { + remFailed = true; + } else { + lastUsed = ce != null ? runtimeCache.getLastUsedTime(modelId) : -1L; + // Use shorter expiry age if model was used in last 3 minutes + final long expiryAge = (lastUsed > 0 && (now - lastUsed) < 180_000L) + ? IN_USE_LOAD_FAILURE_EXPIRY_MS : LOAD_FAILURE_EXPIRY_MS; + if (now - failedTime > expiryAge) { + remFailed = true; + } + } + } if (remLoaded || remFailed) { if (shuttingDown) { return; @@ -5924,7 +5956,9 @@ public void run() { mr.removeLoadFailure(instanceId); } if (ce != null) { - lastUsed = runtimeCache.getLastUsedTime(modelId); + if (lastUsed == -2) { + lastUsed = runtimeCache.getLastUsedTime(modelId); + } if (lastUsed > 0L) { mr.updateLastUsed(lastUsed); } diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java b/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java index 4025ddf40..6351f5f64 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java @@ -43,6 +43,8 @@ private ModelMeshEnvVars() {} public static final String MAX_INFLIGHT_PER_COPY_ENV_VAR = "MM_MAX_INFLIGHT_PER_MODEL_COPY"; public static final String CONC_SCALEUP_BANDWIDTH_PCT_ENV_VAR = "MM_CONC_SCALEUP_BANDWIDTH_PCT"; + public static final String LOAD_FAILURE_EXPIRY_ENV_VAR = "MM_LOAD_FAILURE_EXPIRY_TIME_MS"; + public static final String MMESH_METRICS_ENV_VAR = "MM_METRICS"; public static final String LOG_EACH_INVOKE_ENV_VAR = "MM_LOG_EACH_INVOKE"; diff --git a/src/test/java/com/ibm/watson/modelmesh/ModelMeshFailureExpiryTest.java b/src/test/java/com/ibm/watson/modelmesh/ModelMeshFailureExpiryTest.java new file mode 100644 index 000000000..7d88c9f81 --- /dev/null +++ b/src/test/java/com/ibm/watson/modelmesh/ModelMeshFailureExpiryTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2022 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh; + +import com.ibm.watson.modelmesh.api.ModelInfo; +import com.ibm.watson.modelmesh.api.ModelMeshGrpc; +import com.ibm.watson.modelmesh.api.ModelMeshGrpc.ModelMeshBlockingStub; +import com.ibm.watson.modelmesh.api.ModelStatusInfo; +import com.ibm.watson.modelmesh.api.ModelStatusInfo.ModelStatus; +import com.ibm.watson.modelmesh.api.RegisterModelRequest; +import com.ibm.watson.modelmesh.api.UnregisterModelRequest; +import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc; +import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc.ExamplePredictorBlockingStub; +import com.ibm.watson.modelmesh.example.api.Predictor.PredictRequest; +import com.ibm.watson.modelmesh.example.api.Predictor.PredictResponse; +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import io.grpc.netty.NettyChannelBuilder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.ibm.watson.modelmesh.ModelMeshEnvVars.LOAD_FAILURE_EXPIRY_ENV_VAR; +import static org.junit.jupiter.api.Assertions.*; + +/** + * Test load failure expiry while model is receiving continuous stream of inference requests + */ +public class ModelMeshFailureExpiryTest extends SingleInstanceModelMeshTest { + + @Override + @BeforeEach + public void initialize() throws Exception { + System.setProperty("tas.janitor_freq_secs", "1"); // set janitor to run every second + System.setProperty(LOAD_FAILURE_EXPIRY_ENV_VAR, "2000"); // expire failures in 2 seconds + try { + super.initialize(); + } finally { + System.clearProperty("tas.janitor_freq_secs"); + System.clearProperty(LOAD_FAILURE_EXPIRY_ENV_VAR); + } + } + + @Test + public void failureExpiryTest() throws Exception { + ExecutorService es = Executors.newSingleThreadExecutor(); + + ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 8088) + .usePlaintext().build(); + String modelId = "myFailExpireModel"; + ModelMeshBlockingStub manageModels = ModelMeshGrpc.newBlockingStub(channel); + + final AtomicBoolean stop = new AtomicBoolean(); + try { + ExamplePredictorBlockingStub useModels = ExamplePredictorGrpc.newBlockingStub(channel); + + String errorMessage = "load should fail with this error"; + + // add a model which should fail to load + ModelStatusInfo statusInfo = manageModels.registerModel(RegisterModelRequest.newBuilder() + .setModelId(modelId).setModelInfo(ModelInfo.newBuilder().setType("ExampleType") + .setPath("FAIL_" + errorMessage).build()) // special prefix to trigger load in dummy runtime + .setLoadNow(true).setSync(true).build()); + + System.out.println("registerModel returned: " + statusInfo); + assertEquals(ModelStatus.LOADING_FAILED, statusInfo.getStatus()); + + + final PredictRequest req = PredictRequest.newBuilder().setText("predict me!").build(); + + es.execute(() -> { + while (!stop.get()) try { + forModel(useModels, modelId).predict(req); + } catch (StatusRuntimeException sre) { + // ignore + } finally { + try { + Thread.sleep(400L); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }); + + // call predict on the model + try { + forModel(useModels, modelId).predict(req); + fail("failed model should fail predict"); + } catch (StatusRuntimeException sre) { + assertTrue(sre.getStatus().getDescription().endsWith(errorMessage)); + } + + Thread.sleep(1000); + + // not expired yet + try { + forModel(useModels, modelId).predict(req); + fail("failed model should fail predict"); + } catch (StatusRuntimeException sre) { + assertTrue(sre.getStatus().getDescription().endsWith(errorMessage)); + } + + Thread.sleep(3000); + + // failure should now be expired and it should load successfully second time + // per logic in dummy Model class + + PredictResponse response = forModel(useModels, modelId).predict(req); + System.out.println("predict returned: " + response.getResultsList()); + assertEquals(1.0, response.getResults(0).getConfidence(), 0); + assertEquals("classification for predict me! by model " + modelId, + response.getResults(0).getCategory()); + + } finally { + stop.set(true); + manageModels.unregisterModel(UnregisterModelRequest.newBuilder() + .setModelId(modelId).build()); + channel.shutdown(); + es.shutdown(); + } + } + +} diff --git a/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java b/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java index 524954558..d290eb1e8 100644 --- a/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java +++ b/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java @@ -124,7 +124,7 @@ public void grpcTest() throws Exception { public void loadFailureTest() throws Exception { ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 8088) .usePlaintext().build(); - String modelId = "myModel"; + String modelId = "myModel_" + getClass().getSimpleName(); ModelMeshBlockingStub manageModels = ModelMeshGrpc.newBlockingStub(channel); try { String errorMessage = "load should fail with this error"; diff --git a/src/test/java/com/ibm/watson/modelmesh/example/Model.java b/src/test/java/com/ibm/watson/modelmesh/example/Model.java index 11917bc3c..8a5cc9339 100644 --- a/src/test/java/com/ibm/watson/modelmesh/example/Model.java +++ b/src/test/java/com/ibm/watson/modelmesh/example/Model.java @@ -16,6 +16,8 @@ package com.ibm.watson.modelmesh.example; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -23,6 +25,9 @@ public class Model { + // Keeps track of how many load attempts there have been per model id + final static Map loadCounts = new ConcurrentHashMap<>(); + final String id; final String location; @@ -39,7 +44,8 @@ public Model(String id, String location) { static final Pattern LOC_PATT = Pattern.compile("SIZE_(\\d+)_(\\d+)_(\\d+)"); public String load() { - System.out.println("Loading model " + id + "..."); + int loadCount = loadCounts.merge(id, 1, (i1, i2) -> i1 + i2); // increment + System.out.println("Loading model " + id + "(count = " + loadCount + ")..."); Matcher m = location != null ? LOC_PATT.matcher(location) : null; @@ -61,8 +67,8 @@ public String load() { sleep(loadingTime); // simulate time to load - if (location != null && location.startsWith("FAIL_")) { - return location.substring(5); // simulate loading failure + if (loadCount <= 1 && location != null && location.startsWith("FAIL_")) { + return location.substring(5); // simulate loading failure for first load } return null;