Skip to content

Commit

Permalink
Ensure that remote exceptions to be checked are unwrapped; add test
Browse files Browse the repository at this point in the history
Signed-off-by: Nick Hill <nickhill@us.ibm.com>
  • Loading branch information
njhill committed Oct 3, 2022
1 parent 4cfdf2e commit 1ade6d8
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 23 deletions.
47 changes: 26 additions & 21 deletions src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -2538,7 +2538,7 @@ static <T extends Throwable> T noStack(T exception) {
static final ApplierException QUEUE_BREACH_EXCEPTION = noStack(
new ApplierException("Model queue overload", null, RESOURCE_EXHAUSTED));

static boolean isExhausted(Exception e) {
static boolean isExhausted(Throwable e) {
return e instanceof ApplierException && RESOURCE_EXHAUSTED.equals(((ApplierException) e).getGrpcStatusCode());
}

Expand Down Expand Up @@ -3551,15 +3551,18 @@ protected Object invokeModel(final String modelId, final Method method, final Me
Object result = invokeRemote(runtimeClient, method, remoteMeth, modelId, args);
return method == null && externalReq ? updateWithModelCopyInfo(result, mr) : result;
} catch (Exception e) {
boolean callFailed = processRemoteInvocationException(e, modelId); // this may throw
final Throwable t = e instanceof InvocationTargetException ? e.getCause() : e;
final boolean callFailed = processRemoteInvocationException(t, modelId); // this may throw
if (callFailed) {
if (e instanceof ModelLoadException) {
loadFailureSeen = (ModelLoadException) e;
if (t instanceof ModelLoadException) {
loadFailureSeen = (ModelLoadException) t;
updateLocalModelRecordAfterRemoteLoadFailure(mr, loadFailureSeen);
} else if (e instanceof InternalException) {
internalFailureSeen = (InternalException) e;
} else if (isExhausted(e) && ++resExaustedCount >= MAX_RES_EXHAUSTED) {
throw e;
} else if (t instanceof InternalException) {
internalFailureSeen = (InternalException) t;
} else if (isExhausted(t) && ++resExaustedCount >= MAX_RES_EXHAUSTED) {
Throwables.throwIfInstanceOf(t, Error.class);
Throwables.throwIfInstanceOf(t, Exception.class);
throw new IllegalStateException(t); // should not happen
}
continue;
}
Expand Down Expand Up @@ -3717,16 +3720,19 @@ else if (mr.getInstanceIds().containsKey(instanceId)) {
Object result = invokeRemote(cacheMissClient, method, remoteMeth, modelId, args);
return method == null && externalReq ? updateWithModelCopyInfo(result, mr) : result;
} catch (Exception e) {
boolean callFailed = processRemoteInvocationException(e, modelId); // this may throw
final Throwable t = e instanceof InvocationTargetException ? e.getCause() : e;
final boolean callFailed = processRemoteInvocationException(t, modelId); // this may throw
//TODO handle "stale" case here
if (callFailed) {
if (e instanceof ModelLoadException) {
loadFailureSeen = (ModelLoadException) e;
if (t instanceof ModelLoadException) {
loadFailureSeen = (ModelLoadException) t;
updateLocalModelRecordAfterRemoteLoadFailure(mr, loadFailureSeen);
} else if (e instanceof InternalException) {
internalFailureSeen = (InternalException) e;
} else if (isExhausted(e) && ++resExaustedCount >= MAX_RES_EXHAUSTED) {
throw e;
} else if (t instanceof InternalException) {
internalFailureSeen = (InternalException) t;
} else if (isExhausted(t) && ++resExaustedCount >= MAX_RES_EXHAUSTED) {
Throwables.throwIfInstanceOf(t, Error.class);
Throwables.throwIfInstanceOf(t, Exception.class);
throw new IllegalStateException(t); // should not happen
}
// continue inner loop
if (++n >= MAX_ITERATIONS) {
Expand Down Expand Up @@ -4113,17 +4119,16 @@ static Map<String, String> ensureContextMapIsMutable(Map<String, String> context
}

/**
* @param e
* @param t
* @return true if remote call failed, false if call wasn't made (due to unavailability or
* indication that local attempt should be made)
* @throws TException
*/
protected boolean processRemoteInvocationException(Exception e, String modelId) throws TException {
if (e instanceof IllegalAccessException || e instanceof RuntimeException) {
protected boolean processRemoteInvocationException(Throwable t, String modelId) throws TException {
if (t instanceof IllegalAccessException || t instanceof RuntimeException) {
throw newInternalException(
"Unexpected exception while attempting remote invocation for model " + modelId, e);
"Unexpected exception while attempting remote invocation for model " + modelId, t);
} else {
Throwable t = e instanceof InvocationTargetException ? e.getCause() : e;
if (t.getCause() instanceof ServiceUnavailableException) {
return false;
} else if (t instanceof ModelNotHereException) {
Expand Down Expand Up @@ -4155,7 +4160,7 @@ protected boolean processRemoteInvocationException(Exception e, String modelId)
}
Throwables.throwIfInstanceOf(t, Error.class);
Throwables.throwIfInstanceOf(t, TException.class); // other app-defined exceptions or ModelNotFoundException
throw new IllegalStateException(e); // should not happen
throw new IllegalStateException(t); // should not happen
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ protected PodCloser[] getPodClosers() {
return podClosers;
}

/**
 * Returns the extra environment variables to use when starting the model-mesh pod
 * for one particular replica.
 * <p>
 * This default implementation ignores {@code replicaId} and returns whatever the
 * no-argument {@code extraEnvVars()} returns, so every replica gets the same
 * variables. Can be overridden to configure individual replicas differently.
 *
 * @param replicaId identifier of the replica being started; {@code initialize()}
 *                  passes the pod's port number rendered as a decimal string
 * @return the environment variables passed to {@code startModelMeshPod} for that replica
 */
protected Map<String, String> extraEnvVars(String replicaId) {
return extraEnvVars();
}

@BeforeAll
public void initialize() throws Exception {
//shared infrastructure
Expand All @@ -104,8 +109,10 @@ public void initialize() throws Exception {
String replicaSetId = "RS1";
podClosers = new PodCloser[replicaCount()];
for (int i = 0; i < podClosers.length; i++) {
podClosers[i] = startModelMeshPod(kvStoreString, replicaSetId, 9000 + i * 4,
extraEnvVars, extraRtEnvVars, extraJvmArgs, extraLlArgs,
int port = 9000 + i * 4;
String replicaId = Integer.toString(port);
podClosers[i] = startModelMeshPod(kvStoreString, replicaSetId, port,
extraEnvVars(replicaId), extraRtEnvVars, extraJvmArgs, extraLlArgs,
useDifferentInternalPortForInference(), inheritIo());
}
System.out.println("started");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright 2022 IBM Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy
* of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/

package com.ibm.watson.modelmesh;

import com.google.common.collect.ImmutableMap;
import com.ibm.watson.modelmesh.api.GetStatusRequest;
import com.ibm.watson.modelmesh.api.ModelInfo;
import com.ibm.watson.modelmesh.api.ModelMeshGrpc;
import com.ibm.watson.modelmesh.api.ModelMeshGrpc.ModelMeshBlockingStub;
import com.ibm.watson.modelmesh.api.ModelStatusInfo;
import com.ibm.watson.modelmesh.api.ModelStatusInfo.ModelStatus;
import com.ibm.watson.modelmesh.api.RegisterModelRequest;
import com.ibm.watson.modelmesh.api.UnregisterModelRequest;
import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc;
import com.ibm.watson.modelmesh.example.api.ExamplePredictorGrpc.ExamplePredictorBlockingStub;
import com.ibm.watson.modelmesh.example.api.Predictor.PredictRequest;
import com.ibm.watson.modelmesh.example.api.Predictor.PredictResponse;
import io.grpc.ManagedChannel;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.NettyChannelBuilder;
import org.junit.jupiter.api.Test;

import java.util.Map;

import static org.junit.jupiter.api.Assertions.*;

/**
 * Verifies that an error returned by the model runtime for a prediction request
 * reaches the client with the same gRPC status code and description, whichever
 * model-mesh instance the external request is sent to.
 * <p>
 * Two replicas are started with the same type constraints, but only the replica
 * whose id is {@code "9000"} is given the label required by model type
 * {@code my-type-1}, so that the test model can only be loaded in one of the two
 * instances.
 */
public class ModelMeshErrorPropagationTest extends AbstractModelMeshClusterTest {

    @Override
    protected int replicaCount() {
        return 2;
    }

    /** Type constraints JSON: models of type {@code my-type-1} require label {@code my-label-1}. */
    static final String CONSTRAINTS = "{\n" +
            " \"my-type-1\": {\n" +
            " \"required\": [\"my-label-1\"]\n" +
            " }\n" +
            "}";

    /**
     * Applies the type constraints to every replica and additionally assigns
     * {@code my-label-1} to the replica with id {@code "9000"} only.
     *
     * @param replicaId identifier of the replica being started
     * @return environment variables for that replica
     */
    @Override
    protected Map<String, String> extraEnvVars(String replicaId) {
        return !"9000".equals(replicaId) ? ImmutableMap.of("MM_TYPE_CONSTRAINTS", CONSTRAINTS)
                : ImmutableMap.of("MM_TYPE_CONSTRAINTS", CONSTRAINTS, "MM_LABELS", "my-label-1");
    }

    /**
     * Registers a model, checks that a normal prediction succeeds, then sends a
     * request that makes the runtime fail with a specific error and checks that the
     * same error is seen through both instances without changing the model's status.
     *
     * @throws Exception if any step fails unexpectedly
     */
    @Test
    public void errorPropagationTest() throws Exception {

        ManagedChannel channel = NettyChannelBuilder.forAddress("localhost", 9000).usePlaintext().build();
        ManagedChannel channel2 = NettyChannelBuilder.forAddress("localhost", 9004).usePlaintext().build();
        try {
            ModelMeshBlockingStub manageModels = ModelMeshGrpc.newBlockingStub(channel);

            ExamplePredictorBlockingStub useModels = ExamplePredictorGrpc.newBlockingStub(channel);
            ExamplePredictorBlockingStub useModels2 = ExamplePredictorGrpc.newBlockingStub(channel2);

            // Add a model - with the type constraints it can only be loaded in one of the two instances
            String modelId = "myModel22";
            ModelStatusInfo statusInfo = manageModels.registerModel(RegisterModelRequest.newBuilder()
                    .setModelId(modelId).setModelInfo(ModelInfo.newBuilder().setType("my-type-1").build())
                    .setLoadNow(true).build());

            System.out.println("registerModel returned: " + statusInfo.getStatus());

            // call predict on the model to ensure that it's loaded and working
            PredictRequest req = PredictRequest.newBuilder().setText("predict me!").build();
            PredictResponse response = forModel(useModels, modelId).predict(req);
            assertEquals(1.0, response.getResults(0).getConfidence(), 0);
            assertEquals("classification for predict me! by model myModel22",
                    response.getResults(0).getCategory());

            // verify that there's only one copy loaded
            assertSingleLoadedCopy(manageModels, modelId);

            // Send a poison request to make the runtime return a specific error
            PredictRequest poisonReq = PredictRequest.newBuilder()
                    .setText("test:error:code=UNAVAILABLE:message=Fake prediction error message").build();

            // Ensure client receives consistent error whichever instance the external request is sent to
            assertPredictFailsAsExpected(useModels, modelId, poisonReq);
            assertPredictFailsAsExpected(useModels2, modelId, poisonReq);

            // verify that there's still only one copy loaded
            assertSingleLoadedCopy(manageModels, modelId);

            // delete
            manageModels.unregisterModel(UnregisterModelRequest.newBuilder()
                    .setModelId(modelId).build());
        } finally {
            channel.shutdown();
            channel2.shutdown();
        }
    }

    /** Asserts that the model is LOADED with no errors and exactly one model copy reported. */
    private static void assertSingleLoadedCopy(ModelMeshBlockingStub manageModels, String modelId) {
        ModelStatusInfo status = manageModels.getModelStatus(GetStatusRequest.newBuilder()
                .setModelId(modelId).build());
        assertEquals(ModelStatus.LOADED, status.getStatus());
        assertEquals(0, status.getErrorsCount());
        assertEquals(1, status.getModelCopyInfosCount());
    }

    /** Sends the request via the given stub and asserts that it fails with the expected error. */
    private void assertPredictFailsAsExpected(ExamplePredictorBlockingStub stub, String modelId,
            PredictRequest request) {
        try {
            forModel(stub, modelId).predict(request);
            // fail() throws an assertion error, which the catch below does not intercept
            fail("predict call should have failed");
        } catch (StatusRuntimeException sre) {
            assertExpectedException(sre);
        }
    }

    /**
     * Asserts that the exception carries status code INTERNAL and the description
     * produced for the fake UNAVAILABLE runtime error.
     *
     * @param sre the exception thrown by the failed predict call
     */
    static void assertExpectedException(StatusRuntimeException sre) {
        Status status = Status.fromThrowable(sre);
        assertEquals(Status.Code.INTERNAL, status.getCode());
        assertEquals("ModelRuntime UNAVAILABLE: mmesh.ExamplePredictor/predict: UNAVAILABLE: Fake prediction error message",
                status.getDescription());
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.LongAdder;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static java.lang.Integer.parseInt;

Expand Down Expand Up @@ -393,6 +395,8 @@ public void multiPredict(PredictRequest request, StreamObserver<MultiPredictResp
response.onCompleted();
}

// Matches request text of the form "test:error:code=<CODE>:message=<description>", used to
// make doPredict return a specific error. Group 1 (word characters) is resolved with
// io.grpc.Status.Code.valueOf and group 2 (the rest of the text) becomes the status description.
private static final Pattern ERR_REQ_PATT = Pattern.compile("test:error:code=(\\w+):message=(.+)");

private PredictResponse doPredict(String modelId,
PredictRequest request, StreamObserver<?> response) {

Expand Down Expand Up @@ -421,6 +425,15 @@ private PredictResponse doPredict(String modelId,

performSpecialActions(stringToClassify);

Matcher m = ERR_REQ_PATT.matcher(stringToClassify);
if (m.matches()) {
// Decode request intended to return a specific error
response.onError(io.grpc.Status.fromCode(
io.grpc.Status.Code.valueOf(m.group(1)))
.withDescription(m.group(2)).asException());
return null;
}

// perform the inferencing
String classification = model.classify(stringToClassify);

Expand Down

0 comments on commit 1ade6d8

Please sign in to comment.