diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java index 6ada1292..14dcf320 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMesh.java @@ -208,7 +208,10 @@ public abstract class ModelMesh extends ThriftService // used in routing decisions, gets set to Math.max(3000L, loadTimeoutMs/3) protected /*final*/ long defaultAssumeLoadedAfterMs; - public static final long LOAD_FAILURE_EXPIRY_MS = 600_000L; // 10mins for now + // time after which loading failure records expire (allowing for re-attempts) + public final long LOAD_FAILURE_EXPIRY_MS = getLongParameter(LOAD_FAILURE_EXPIRY_ENV_VAR, 600_000L); // default 10mins + // shorter expiry time for "in use" models (receiving recent requests) + public final long IN_USE_LOAD_FAILURE_EXPIRY_MS = (LOAD_FAILURE_EXPIRY_MS * 2) / 3; public static final int MAX_LOAD_FAILURES = 3; // if unable to invoke in this many places, don't continue to load public static final int MAX_LOAD_LOCATIONS = 5; @@ -265,6 +268,10 @@ public abstract class ModelMesh extends ThriftService // time before which we don't wait for migrated models to load elsewhere during pre-shutdown protected static final long CUTOFF_AGE_MS = 60 * 60_000L; // 1 hour + // when expiring failure records, use the shorter age if recent requests for the model + // have been seen within this time + protected static final long SHORT_EXPIRY_RECENT_USE_TIME_MS = 3 * 60_000L; // 3mins + // max combined number of cache-hit/miss retries per request - mainly just a safeguard protected static final int MAX_ITERATIONS = 8; @@ -4508,14 +4515,15 @@ private void checkLoadLocationCount(ModelRecord mr, Collection explicitE } // check if model load failures have breached the maximum allowed limit - private static void checkLoadFailureCount(ModelRecord mr, ModelLoadException loadFailureSeen) + private void checkLoadFailureCount(ModelRecord mr, 
ModelLoadException loadFailureSeen) throws ModelLoadException { Map failedInInstances = mr.getLoadFailedInstanceIds(); if (!failedInInstances.isEmpty()) { int count = 0; + final long expiryCutoffTime = currentTimeMillis() - IN_USE_LOAD_FAILURE_EXPIRY_MS; for (Long failTime : failedInInstances.values()) { - if (failTime > LOAD_FAILURE_EXPIRY_MS) { - count++; + if (failTime > expiryCutoffTime) { + count++; // not yet expired } if (count >= MAX_LOAD_FAILURES) { if (loadFailureSeen != null) { @@ -4971,6 +4979,7 @@ protected CacheEntry loadLocal(String modelId, ModelRecord[] mrh, long lastUs break; // success } + logger.info("Encountered existing cache entry while loading model " + modelId); synchronized (existCe) { // A cache entry was already there - most likely that another thread // in this instance is also loading this model (in this same method). @@ -4986,6 +4995,7 @@ protected CacheEntry loadLocal(String modelId, ModelRecord[] mrh, long lastUs if (latestMr == null) { mrh[0] = null; existCe.remove(); + logger.info("Existing cache entry for model " + modelId + " now gone"); return INSTANCES_CHANGED; // ModelNotFoundException will be thrown } @@ -4995,6 +5005,7 @@ protected CacheEntry loadLocal(String modelId, ModelRecord[] mrh, long lastUs || !Objects.equals(latestMr.getLoadFailedInstanceIds(), mr.getLoadFailedInstanceIds())) { // model registrations changed, re-start main loop + logger.info("Registrations changed for " + modelId + ", will reevaluate"); return INSTANCES_CHANGED; } mr = latestMr; @@ -5007,6 +5018,7 @@ protected CacheEntry loadLocal(String modelId, ModelRecord[] mrh, long lastUs // Odd situation, similar to janitor logic for when a local // cache entry is found without corresponding model record entry, // we just "recycle" the already-loading/loaded one + logger.info("Recycling existing entry for " + modelId + "(state=" + ceStateString(stateNow) + ")"); ce = existCe; break; } @@ -5017,15 +5029,20 @@ protected CacheEntry loadLocal(String modelId, 
ModelRecord[] mrh, long lastUs if (existCe.isFailed()) { assert stateNow == CacheEntry.FAILED; latestMr = handleUnexpectedFailedCacheEntry(existCe, mr); - if (latestMr != mr) { - mrh[0] = mr = latestMr; - return INSTANCES_CHANGED; + mrh[0] = latestMr; + if (!existCe.isRemoved()) { + if (latestMr != mr) { + return INSTANCES_CHANGED; + } + logger.info("Unexpected failed cache entry for model " + modelId + + ", treating as load failure"); + return existCe; } - mrh[0] = mr = latestMr; - return existCe; + // else continue to loop now that the entry has been removed + } else { + // We'll continue to loop in this case for now + assert stateNow == CacheEntry.NEW; } - // We'll continue to loop in this case for now - assert stateNow == CacheEntry.NEW; } existCe = null; @@ -5202,33 +5219,40 @@ private ModelRecord handleUnexpectedFailedCacheEntry(CacheEntry ce, ModelReco if (failure == null) { return mr; // safeguard timeout case (didn't see load fail but timed out waiting for it) } - if (failure instanceof ModelLoadException - && ((ModelLoadException) failure).getTimeout() == KVSTORE_LOAD_FAILURE) { - long failureAge = currentTimeMillis() - ce.loadCompleteTimestamp; - if (failureAge > 30_000 && failureAge > 30_000 - + ThreadLocalRandom.current().nextLong(30_000)) { // Randomize to avoid thunder - ModelRecord newMr = registry.get(ce.modelId); - if (newMr == null ? 
mr != null : (mr == null || newMr.getVersion() == mr.getVersion())) { - // First replace the entry with a later-expiring one to block concurrent attempts - CacheEntry replacement = new CacheEntry<>(ce); - if (runtimeCache.replaceQuietly(ce.modelId, ce, replacement)) { - ce.remove(); - ce = replacement; + if (!(failure instanceof ModelLoadException) + || ((ModelLoadException) failure).getTimeout() != KVSTORE_LOAD_FAILURE) { + // We assume that this is an expired entry yet to be cleaned up + if (ce.remove()) { + logger.info("Removed kv-store failure cache entry for model " + ce.modelId); + } + return mr; + } + + long failureAge = currentTimeMillis() - ce.loadCompleteTimestamp; + if (failureAge > 30_000 && failureAge > 30_000 + + ThreadLocalRandom.current().nextLong(30_000)) { // Randomize to avoid thunder + + ModelRecord newMr = registry.get(ce.modelId); + if (newMr == null ? mr == null : (mr != null && newMr.getVersion() == mr.getVersion())) { + // First replace the entry with a later-expiring one to block concurrent attempts + CacheEntry replacement = new CacheEntry<>(ce); + if (runtimeCache.replaceQuietly(ce.modelId, ce, replacement)) { + ce.remove(); + ce = replacement; + try { // this might throw if there are still KV store issues - try { - newMr = registry.getStrong(ce.modelId); - if (ce.remove()) { - logger.info("Removed kv-store failure cache entry for model " + ce.modelId); - } - } catch (Exception e) { - // Cannot verify / still KV store problems - logger.warn("Failed to retrieve model record after kv-store failure entry expiry" - + " for model " + ce.modelId); + newMr = registry.getStrong(ce.modelId); + if (ce.remove()) { + logger.info("Removed kv-store failure cache entry for model " + ce.modelId); } + } catch (Exception e) { + // Cannot verify / still KV store problems + logger.warn("Failed to retrieve model record after kv-store failure entry expiry" + + " for model " + ce.modelId); } } - return newMr; } + return newMr; } if (mr != null) { // Allow 
failure to propagate (e.g. to invokeLocalModel() after load attempt) @@ -5907,10 +5931,22 @@ public void run() { if (ce == null) { ce = runtimeCache.getQuietly(modelId); } - long lastUsed = -1L; + long lastUsed = -2L; boolean remLoaded = loaded && (ce == null || ce.isFailed()); - boolean remFailed = failedTime != null - && ((ce != null && !ce.isFailed()) || now - failedTime > LOAD_FAILURE_EXPIRY_MS); + boolean remFailed = false; + if (failedTime != null) { + if (ce != null && !ce.isFailed()) { + remFailed = true; + } else { + lastUsed = ce != null ? runtimeCache.getLastUsedTime(modelId) : -1L; + // Use shorter expiry age if model was used in last 3 minutes + final long expiryAge = (lastUsed > 0 && (now - lastUsed) < SHORT_EXPIRY_RECENT_USE_TIME_MS) + ? IN_USE_LOAD_FAILURE_EXPIRY_MS : LOAD_FAILURE_EXPIRY_MS; + if (now - failedTime > expiryAge) { + remFailed = true; + } + } + } if (remLoaded || remFailed) { if (shuttingDown) { return; @@ -5924,7 +5960,9 @@ public void run() { mr.removeLoadFailure(instanceId); } if (ce != null) { - lastUsed = runtimeCache.getLastUsedTime(modelId); + if (lastUsed == -2) { + lastUsed = runtimeCache.getLastUsedTime(modelId); + } if (lastUsed > 0L) { mr.updateLastUsed(lastUsed); } diff --git a/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java b/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java index 4025ddf4..6351f5f6 100644 --- a/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java +++ b/src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java @@ -43,6 +43,8 @@ private ModelMeshEnvVars() {} public static final String MAX_INFLIGHT_PER_COPY_ENV_VAR = "MM_MAX_INFLIGHT_PER_MODEL_COPY"; public static final String CONC_SCALEUP_BANDWIDTH_PCT_ENV_VAR = "MM_CONC_SCALEUP_BANDWIDTH_PCT"; + public static final String LOAD_FAILURE_EXPIRY_ENV_VAR = "MM_LOAD_FAILURE_EXPIRY_TIME_MS"; + public static final String MMESH_METRICS_ENV_VAR = "MM_METRICS"; public static final String LOG_EACH_INVOKE_ENV_VAR = 
"MM_LOG_EACH_INVOKE"; diff --git a/src/test/java/com/ibm/watson/modelmesh/ModelMeshFailureExpiryTest.java b/src/test/java/com/ibm/watson/modelmesh/ModelMeshFailureExpiryTest.java new file mode 100644 index 00000000..7d88c9f8 --- /dev/null +++ b/src/test/java/com/ibm/watson/modelmesh/ModelMeshFailureExpiryTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2022 IBM Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy + * of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package com.ibm.watson.modelmesh; + +import com.ibm.watson.modelmesh.api.ModelInfo; +import com.ibm.watson.modelmesh.api.ModelMeshGrpc; +import com.ibm.watson.modelmesh.api.ModelMeshGrpc.ModelMeshBlockingStub; +import com.ibm.watson.modelmesh.api.ModelStatusInfo; +import com.ibm.watson.modelmesh.api.ModelStatusInfo.ModelStatus; +import com.ibm.watson.modelmesh.api.RegisterModelRequest; +import com.ibm.watson.modelmesh.api.UnregisterModelRequest; +import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc; +import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc.ExamplePredictorBlockingStub; +import com.ibm.watson.modelmesh.example.api.Predictor.PredictRequest; +import com.ibm.watson.modelmesh.example.api.Predictor.PredictResponse; +import io.grpc.ManagedChannel; +import io.grpc.StatusRuntimeException; +import io.grpc.netty.NettyChannelBuilder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import 
java.util.concurrent.atomic.AtomicBoolean; + +import static com.ibm.watson.modelmesh.ModelMeshEnvVars.LOAD_FAILURE_EXPIRY_ENV_VAR; +import static org.junit.jupiter.api.Assertions.*; + +/** + * Test load failure expiry while model is receiving continuous stream of inference requests + */ +public class ModelMeshFailureExpiryTest extends SingleInstanceModelMeshTest { + + @Override + @BeforeEach + public void initialize() throws Exception { + System.setProperty("tas.janitor_freq_secs", "1"); // set janitor to run every second + System.setProperty(LOAD_FAILURE_EXPIRY_ENV_VAR, "2000"); // expire failures in 2 seconds + try { + super.initialize(); + } finally { + System.clearProperty("tas.janitor_freq_secs"); + System.clearProperty(LOAD_FAILURE_EXPIRY_ENV_VAR); + } + } + + @Test + public void failureExpiryTest() throws Exception { + ExecutorService es = Executors.newSingleThreadExecutor(); + + ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 8088) + .usePlaintext().build(); + String modelId = "myFailExpireModel"; + ModelMeshBlockingStub manageModels = ModelMeshGrpc.newBlockingStub(channel); + + final AtomicBoolean stop = new AtomicBoolean(); + try { + ExamplePredictorBlockingStub useModels = ExamplePredictorGrpc.newBlockingStub(channel); + + String errorMessage = "load should fail with this error"; + + // add a model which should fail to load + ModelStatusInfo statusInfo = manageModels.registerModel(RegisterModelRequest.newBuilder() + .setModelId(modelId).setModelInfo(ModelInfo.newBuilder().setType("ExampleType") + .setPath("FAIL_" + errorMessage).build()) // special prefix to trigger load failure in dummy runtime + .setLoadNow(true).setSync(true).build()); + + System.out.println("registerModel returned: " + statusInfo); + assertEquals(ModelStatus.LOADING_FAILED, statusInfo.getStatus()); + + + final PredictRequest req = PredictRequest.newBuilder().setText("predict me!").build(); + + es.execute(() -> { + while (!stop.get()) try { + forModel(useModels, 
modelId).predict(req); + } catch (StatusRuntimeException sre) { + // ignore + } finally { + try { + Thread.sleep(400L); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }); + + // call predict on the model + try { + forModel(useModels, modelId).predict(req); + fail("failed model should fail predict"); + } catch (StatusRuntimeException sre) { + assertTrue(sre.getStatus().getDescription().endsWith(errorMessage)); + } + + Thread.sleep(1000); + + // not expired yet + try { + forModel(useModels, modelId).predict(req); + fail("failed model should fail predict"); + } catch (StatusRuntimeException sre) { + assertTrue(sre.getStatus().getDescription().endsWith(errorMessage)); + } + + Thread.sleep(3000); + + // failure should now be expired and it should load successfully second time + // per logic in dummy Model class + + PredictResponse response = forModel(useModels, modelId).predict(req); + System.out.println("predict returned: " + response.getResultsList()); + assertEquals(1.0, response.getResults(0).getConfidence(), 0); + assertEquals("classification for predict me! 
by model " + modelId, + response.getResults(0).getCategory()); + + } finally { + stop.set(true); + manageModels.unregisterModel(UnregisterModelRequest.newBuilder() + .setModelId(modelId).build()); + channel.shutdown(); + es.shutdown(); + } + } + +} diff --git a/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java b/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java index 52495455..d290eb1e 100644 --- a/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java +++ b/src/test/java/com/ibm/watson/modelmesh/SidecarModelMeshTest.java @@ -124,7 +124,7 @@ public void grpcTest() throws Exception { public void loadFailureTest() throws Exception { ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 8088) .usePlaintext().build(); - String modelId = "myModel"; + String modelId = "myModel_" + getClass().getSimpleName(); ModelMeshBlockingStub manageModels = ModelMeshGrpc.newBlockingStub(channel); try { String errorMessage = "load should fail with this error"; diff --git a/src/test/java/com/ibm/watson/modelmesh/example/Model.java b/src/test/java/com/ibm/watson/modelmesh/example/Model.java index 11917bc3..8a5cc933 100644 --- a/src/test/java/com/ibm/watson/modelmesh/example/Model.java +++ b/src/test/java/com/ibm/watson/modelmesh/example/Model.java @@ -16,6 +16,8 @@ package com.ibm.watson.modelmesh.example; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -23,6 +25,9 @@ public class Model { + // Keeps track of how many load attempts there have been per model id + final static Map loadCounts = new ConcurrentHashMap<>(); + final String id; final String location; @@ -39,7 +44,8 @@ public Model(String id, String location) { static final Pattern LOC_PATT = Pattern.compile("SIZE_(\\d+)_(\\d+)_(\\d+)"); public String load() { - System.out.println("Loading model " + id + "..."); + int loadCount = loadCounts.merge(id, 1, (i1, i2) -> i1 + i2); // increment 
+ System.out.println("Loading model " + id + "(count = " + loadCount + ")..."); Matcher m = location != null ? LOC_PATT.matcher(location) : null; @@ -61,8 +67,8 @@ public String load() { sleep(loadingTime); // simulate time to load - if (location != null && location.startsWith("FAIL_")) { - return location.substring(5); // simulate loading failure + if (loadCount <= 1 && location != null && location.startsWith("FAIL_")) { + return location.substring(5); // simulate loading failure for first load } return null;