Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Make it possible to attach a PayloadProcessor to process model predictions #84

Merged
merged 45 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
ce5f99d
FAI-948 - resetting head
tteofili Feb 17, 2023
91cd53b
Merge branch 'main' of https://github.com/kserve/modelmesh into FAI-948
tteofili Feb 21, 2023
e50d2b6
FAI-948 - refactored payload processors to sit in MMApi and process r…
tteofili Feb 22, 2023
915061d
FAI-948 - async processor refactoring
tteofili Feb 22, 2023
cec75eb
FAI-948 - moving the payload processor instance into MMApi
tteofili Feb 22, 2023
5a9d731
FAI-948 - logging a warning in case of processing issues in MMApi
tteofili Feb 22, 2023
bcfbba4
FAI-948 - added payload uuid
tteofili Feb 23, 2023
6aba77e
FAI-948 - dropped filewriter processor
tteofili Feb 23, 2023
5e94101
FAI-948 - minor javadoc improvement
tteofili Feb 23, 2023
9c6046c
FAI-948 - fixed capacity queue test improved
tteofili Feb 23, 2023
c88c4f7
FAI-948 - minor improvements to payload data processor
tteofili Feb 23, 2023
ad1e746
FAI-948 - minor javadoc improvement
tteofili Feb 23, 2023
fad3984
FAI-948 - simplified payload processor initialization, added remote p…
tteofili Feb 27, 2023
f45edd3
FAI-948 - ignored useless imports
tteofili Feb 27, 2023
8128593
FAI-948 - improved matching payload processor, remote processor doesn…
tteofili Feb 28, 2023
d26ca03
FAI-948 - added base64 encoding for remote processor, minor improvements
tteofili Feb 28, 2023
bc3ea0e
FAI-948 - refactored remote processor
tteofili Mar 1, 2023
981d82a
FAI-948 - removed unused imports
tteofili Mar 1, 2023
32eee9d
FAI-948 - null check for payload data in remote processor
tteofili Mar 1, 2023
a47ec00
FAI-948 - base64 encoding only payload bytebuf, minor tweaks
tteofili Mar 1, 2023
6cbd71b
FAI-948 - request payload needs to be processed before releasing, gua…
tteofili Mar 1, 2023
219b4a0
FAI-948 - async processor should be aware of refCnt
tteofili Mar 1, 2023
5b98f99
FAI-948 - added example config, retain payload data
tteofili Mar 2, 2023
eacbd03
FAI-948 - debugging line dropped
tteofili Mar 2, 2023
5b1eeb6
Merge branch 'main' of https://github.com/kserve/modelmesh into FAI-948
tteofili Mar 2, 2023
6e4cd9e
FAI-948 - keep/restore bytebuf indexes positions instead of retainDup…
tteofili Mar 2, 2023
564942b
FAI-948 - move away from uuid to id, unified payload processing call,…
tteofili Mar 3, 2023
55e14f2
FAI-948 - implemented PR comments for payloads async processing
tteofili Mar 3, 2023
8454161
FAI-948 - minor improvements
tteofili Mar 3, 2023
1ab932d
FAI-948 - only retain if payload is enqueued
tteofili Mar 6, 2023
9b405eb
FAI-948 - remaining comments addressed
tteofili Mar 6, 2023
494e581
FAI-948 - minor adjustments
tteofili Mar 6, 2023
51dba80
FAI-948 - minor adjustments to README
tteofili Mar 7, 2023
9dac439
FAI-948 - more testing, minor improvements
tteofili Mar 7, 2023
8a1f612
FAI-948 - integrated Nick's patch
tteofili Mar 9, 2023
523aa55
FAI-948 - making PP closable
tteofili Mar 9, 2023
3394cd7
FAI-948 - default queue capacity to 64
tteofili Mar 9, 2023
67c6008
FAI-948 - configurable no threads, capacity, minor fix
tteofili Mar 9, 2023
829e5d7
FAI-948 - minor fix
tteofili Mar 9, 2023
5922f4f
FAI-948 - javadoc improvements
tteofili Mar 10, 2023
586240f
FAI-948 - addressing Nick's comments
tteofili Mar 10, 2023
ba0cf0e
FAI-948 - addressing Nick's comments
tteofili Mar 10, 2023
54be2bf
FAI-948 - moving back to netty base64 encoding
tteofili Mar 13, 2023
80d76c6
FAI-948 - minor fix
tteofili Mar 13, 2023
04a3d2f
FAI-948 - minor fixes
tteofili Mar 13, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config/base/kustomization.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ vars:

#patchesStrategicMerge:
# - patches/etcd.yaml
# - patches/logger.yaml
# - patches/tls.yaml
# - patches/uds.yaml
# - patches/max_msg_size.yaml
Expand Down
29 changes: 29 additions & 0 deletions config/base/patches/logger.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2023 IBM Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Use this patch to change the max size in bytes allowed
# per proxied gRPC message, for headers and data
#
apiVersion: apps/v1
kind: Deployment
metadata:
name: model-mesh
spec:
template:
spec:
containers:
- name: mm
env:
- name: MM_PAYLOAD_PROCESSORS
value: "logger://*"
tteofili marked this conversation as resolved.
Show resolved Hide resolved
44 changes: 43 additions & 1 deletion src/main/java/com/ibm/watson/modelmesh/ModelMesh.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@
import com.ibm.watson.modelmesh.TypeConstraintManager.ProhibitedTypeSet;
import com.ibm.watson.modelmesh.clhm.ConcurrentLinkedHashMap;
import com.ibm.watson.modelmesh.clhm.ConcurrentLinkedHashMap.EvictionListenerWithTime;
import com.ibm.watson.modelmesh.payload.AsyncPayloadProcessor;
import com.ibm.watson.modelmesh.payload.CompositePayloadProcessor;
import com.ibm.watson.modelmesh.payload.LoggingPayloadProcessor;
import com.ibm.watson.modelmesh.payload.MatchingPayloadProcessor;
import com.ibm.watson.modelmesh.payload.PayloadProcessor;
import com.ibm.watson.modelmesh.payload.RemotePayloadProcessor;
import com.ibm.watson.modelmesh.thrift.ApplierException;
import com.ibm.watson.modelmesh.thrift.BaseModelMeshService;
import com.ibm.watson.modelmesh.thrift.InternalException;
Expand Down Expand Up @@ -101,6 +107,7 @@
import java.lang.management.MemoryUsage;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.nio.channels.ClosedByInterruptException;
import java.text.ParseException;
import java.text.SimpleDateFormat;
Expand Down Expand Up @@ -421,6 +428,40 @@ public abstract class ModelMesh extends ThriftService
}
}

private PayloadProcessor initPayloadProcessor() {
String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null);
logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions);
if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) {
List<PayloadProcessor> payloadProcessors = new ArrayList<>();
for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) {
try {
URI uri = URI.create(processorDefinition);
String processorName = uri.getScheme();
PayloadProcessor processor = null;
String modelId = uri.getQuery();
String method = uri.getFragment();
if ("http".equals(processorName)) {
processor = new RemotePayloadProcessor(uri);
} else if ("logger".equals(processorName)) {
processor = new LoggingPayloadProcessor();
}
if (processor != null) {
MatchingPayloadProcessor p = MatchingPayloadProcessor.from(modelId, method, processor);
payloadProcessors.add(p);
logger.info("Added PayloadProcessor {}", p.getName());
}
} catch (IllegalArgumentException iae) {
logger.error("Unable to parse PayloadProcessor URI definition {}", processorDefinition);
}
}
return new AsyncPayloadProcessor(new CompositePayloadProcessor(payloadProcessors), 1, MINUTES,
Executors.newScheduledThreadPool(getIntParameter(MM_PAYLOAD_PROCESSORS_THREADS, 2)),
getIntParameter(MM_PAYLOAD_PROCESSORS_CAPACITY, 64));
} else {
return null;
}
}

/* ---------------------------------- initialization --------------------------------------------------------- */

@Override
Expand Down Expand Up @@ -854,10 +895,11 @@ protected final TProcessor initialize() throws Exception {
}

LogRequestHeaders logHeaders = LogRequestHeaders.getConfiguredLogRequestHeaders();
PayloadProcessor payloadProcessor = initPayloadProcessor();

grpcServer = new ModelMeshApi((SidecarModelMesh) this, vModelManager, GRPC_PORT, keyCertFile, privateKeyFile,
privateKeyPassphrase, clientAuth, caCertFiles, maxGrpcMessageSize, maxGrpcHeadersSize,
maxGrpcConnectionAge, maxGrpcConnectionAgeGrace, logHeaders);
maxGrpcConnectionAge, maxGrpcConnectionAgeGrace, logHeaders, payloadProcessor);
}

if (grpcServer != null) {
Expand Down
123 changes: 92 additions & 31 deletions src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
import com.ibm.watson.modelmesh.api.UnregisterModelRequest;
import com.ibm.watson.modelmesh.api.UnregisterModelResponse;
import com.ibm.watson.modelmesh.api.VModelStatusInfo;
import com.ibm.watson.modelmesh.payload.Payload;
import com.ibm.watson.modelmesh.payload.PayloadProcessor;
import com.ibm.watson.modelmesh.thrift.ApplierException;
import com.ibm.watson.modelmesh.thrift.InvalidInputException;
import com.ibm.watson.modelmesh.thrift.InvalidStateException;
Expand Down Expand Up @@ -156,6 +158,10 @@ public final class ModelMeshApi extends ModelMeshGrpc.ModelMeshImplBase
// null if header logging is not enabled.
protected final LogRequestHeaders logHeaders;

private final PayloadProcessor payloadProcessor;

private final ThreadLocal<long[]> localIdCounter = ThreadLocal.withInitial(() -> new long[1]);

/**
* Create <b>and start</b> the server.
*
Expand All @@ -171,16 +177,18 @@ public final class ModelMeshApi extends ModelMeshGrpc.ModelMeshImplBase
* @param maxConnectionAge in seconds
* @param maxConnectionAgeGrace in seconds, custom grace time for graceful connection termination
* @param logHeaders
* @param payloadProcessor a processor of payloads
* @throws IOException
*/
public ModelMeshApi(SidecarModelMesh delegate, VModelManager vmm, int port, File keyCert, File privateKey,
String privateKeyPassphrase, ClientAuth clientAuth, File[] trustCerts,
int maxMessageSize, int maxHeadersSize, long maxConnectionAge, long maxConnectionAgeGrace,
LogRequestHeaders logHeaders) throws IOException {
LogRequestHeaders logHeaders, PayloadProcessor payloadProcessor) throws IOException {
tteofili marked this conversation as resolved.
Show resolved Hide resolved

this.delegate = delegate;
this.vmm = vmm;
this.logHeaders = logHeaders;
this.payloadProcessor = payloadProcessor;

this.multiParallelism = getMultiParallelism();

Expand Down Expand Up @@ -293,6 +301,13 @@ public void shutdown(long timeout, TimeUnit unit) throws InterruptedException {
if (!done) {
server.shutdownNow();
}
if (payloadProcessor != null) {
try {
payloadProcessor.close();
} catch (IOException e) {
logger.warn("Error closing PayloadProcessor {}: {}", payloadProcessor, e.getMessage());
}
}
tteofili marked this conversation as resolved.
Show resolved Hide resolved
threads.shutdownNow();
shutdownEventLoops();
}
Expand Down Expand Up @@ -686,49 +701,57 @@ public void onHalfClose() {
call.close(INTERNAL.withDescription("Half-closed without a request"), emptyMeta());
return;
}
final int reqSize = reqMessage.readableBytes();
int reqReaderIndex = reqMessage.readerIndex();
int reqSize = reqMessage.readableBytes();
int respSize = -1;
int respReaderIndex = 0;

io.grpc.Status status = INTERNAL;
String modelId = null;
String requestId = null;
ModelResponse response = null;
try (InterruptingListener cancelListener = newInterruptingListener()) {
if (logHeaders != null) {
logHeaders.addToMDC(headers); // MDC cleared in finally block
}
ModelResponse response = null;
if (payloadProcessor != null) {
requestId = Thread.currentThread().getId() + "-" + ++localIdCounter.get()[0];
}
try {
try {
String balancedMetaVal = headers.get(BALANCED_META_KEY);
Iterator<String> midIt = modelIds.iterator();
// guaranteed at least one
String modelId = validateModelId(midIt.next(), isVModel);
if (!midIt.hasNext()) {
// single model case (most common)
response = callModel(modelId, isVModel, methodName,
balancedMetaVal, headers, reqMessage).retain();
} else {
// multi-model case (specialized)
boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY));
List<String> idList = new ArrayList<>();
idList.add(modelId);
while (midIt.hasNext()) {
idList.add(validateModelId(midIt.next(), isVModel));
}
response = applyParallelMultiModel(idList, isVModel, methodName,
balancedMetaVal, headers, reqMessage, allRequired);
String balancedMetaVal = headers.get(BALANCED_META_KEY);
Iterator<String> midIt = modelIds.iterator();
// guaranteed at least one
modelId = validateModelId(midIt.next(), isVModel);
if (!midIt.hasNext()) {
// single model case (most common)
response = callModel(modelId, isVModel, methodName,
balancedMetaVal, headers, reqMessage).retain();
} else {
// multi-model case (specialized)
boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY));
List<String> idList = new ArrayList<>();
idList.add(modelId);
while (midIt.hasNext()) {
idList.add(validateModelId(midIt.next(), isVModel));
}
} finally {
releaseReqMessage();
response = applyParallelMultiModel(idList, isVModel, methodName,
balancedMetaVal, headers, reqMessage, allRequired);
}

respSize = response.data.readableBytes();
call.sendHeaders(response.metadata);
call.sendMessage(response.data);
response = null;
} finally {
if (response != null) {
response.release();
if (payloadProcessor != null) {
processPayload(reqMessage.readerIndex(reqReaderIndex),
requestId, modelId, methodName, headers, null, true);
} else {
releaseReqMessage();
}
reqMessage = null; // ownership released or transferred
}

respReaderIndex = response.data.readerIndex();
respSize = response.data.readableBytes();
call.sendHeaders(response.metadata);
call.sendMessage(response.data);
// response is released via ReleaseAfterResponse.releaseAll()
status = OK;
} catch (Exception e) {
status = toStatus(e);
Expand All @@ -745,6 +768,15 @@ public void onHalfClose() {
evictMethodDescriptor(methodName);
}
} finally {
if (payloadProcessor != null) {
ByteBuf data = null;
Metadata metadata = null;
if (response != null) {
data = response.data.readerIndex(respReaderIndex);
metadata = response.metadata;
}
processPayload(data, requestId, modelId, methodName, metadata, status, false);
}
ReleaseAfterResponse.releaseAll();
clearThreadLocals();
//TODO(maybe) additional trailer info in exception case?
Expand All @@ -757,6 +789,35 @@ public void onHalfClose() {
}
}

/**
* Invoke PayloadProcessor on the request/response data
* @param data the binary data
* @param payloadId the id of the request
* @param modelId the id of the model
* @param methodName the name of the invoked method
* @param metadata the method name metadata
* @param status null for requests, non-null for responses
* @param takeOwnership whether the processor should take ownership
*/
private void processPayload(ByteBuf data, String payloadId, String modelId, String methodName,
Metadata metadata, io.grpc.Status status, boolean takeOwnership) {
Payload payload = null;
try {
assert payloadProcessor != null;
if (!takeOwnership) {
data.retain();
}
payload = new Payload(payloadId, modelId, methodName, metadata, data, status);
if (payloadProcessor.process(payload)) {
data = null; // ownership transferred
}
} catch (Throwable t) {
logger.warn("Error while processing payload: {}", payload, t);
} finally {
ReferenceCountUtil.release(data);
}
}

@Override
public void onComplete() {
releaseReqMessage();
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ public final class ModelMeshEnvVars {

private ModelMeshEnvVars() {}

public static final String MM_PAYLOAD_PROCESSORS = "MM_PAYLOAD_PROCESSORS";
public static final String MM_PAYLOAD_PROCESSORS_THREADS = "MM_PAYLOAD_PROCESSORS_THREADS";
public static final String MM_PAYLOAD_PROCESSORS_CAPACITY = "MM_PAYLOAD_PROCESSORS_CAPACITY";

// This must not be changed after model-mesh is already deployed to a particular env
public static final String KV_STORE_PREFIX = "MM_KVSTORE_PREFIX";

Expand Down
Loading