Skip to content

Commit

Permalink
feat: Add vModelId to PayloadProcessor Payload
Browse files Browse the repository at this point in the history
Motivation

Currently the payloads passed to PayloadProcessors only contain the modelId, which in the case of vModels will be a "resolved" modelId corresponding to a particular model revision (in particular this will be true when used in KServe modelmesh-serving). It would be useful to include the vModelId too.

Modifications

Add a vModelId field to the Payload class and correspondingly update built-in PayloadProcessor implementations where applicable.

It may be null if the request was directed at a concrete modelId rather than a vModelId.

Result

Both modelId and vModelId are available to PayloadProcessors

Signed-off-by: Nick Hill <nickhill@us.ibm.com>
  • Loading branch information
njhill committed Nov 16, 2023
1 parent 8b36883 commit 2e4d1d6
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 31 deletions.
8 changes: 4 additions & 4 deletions src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ public void onHalfClose() {
} finally {
if (payloadProcessor != null) {
processPayload(reqMessage.readerIndex(reqReaderIndex),
requestId, resolvedModelId, methodName, headers, null, true);
requestId, resolvedModelId, vModelId, methodName, headers, null, true);
} else {
releaseReqMessage();
}
Expand Down Expand Up @@ -803,7 +803,7 @@ public void onHalfClose() {
data = response.data.readerIndex(respReaderIndex);
metadata = response.metadata;
}
processPayload(data, requestId, resolvedModelId, methodName, metadata, status, releaseResponse);
processPayload(data, requestId, resolvedModelId, vModelId, methodName, metadata, status, releaseResponse);
} else if (releaseResponse && response != null) {
response.release();
}
Expand All @@ -829,15 +829,15 @@ public void onHalfClose() {
* @param status null for requests, non-null for responses
* @param takeOwnership whether the processor should take ownership
*/
private void processPayload(ByteBuf data, String payloadId, String modelId, String methodName,
private void processPayload(ByteBuf data, String payloadId, String vModelId, String modelId, String methodName,
Metadata metadata, io.grpc.Status status, boolean takeOwnership) {
Payload payload = null;
try {
assert payloadProcessor != null;
if (!takeOwnership) {
ReferenceCountUtil.retain(data);
}
payload = new Payload(payloadId, modelId, methodName, metadata, data, status);
payload = new Payload(payloadId, modelId, vModelId, methodName, metadata, data, status);
if (payloadProcessor.process(payload)) {
data = null; // ownership transferred
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.ibm.watson.modelmesh.payload;

import java.io.IOException;
import java.util.Objects;

/**
* A {@link PayloadProcessor} that processes {@link Payload}s only if they match with given model ID or method name.
Expand All @@ -29,10 +30,13 @@ public class MatchingPayloadProcessor implements PayloadProcessor {

private final String modelId;

MatchingPayloadProcessor(PayloadProcessor delegate, String methodName, String modelId) {
private final String vModelId;

MatchingPayloadProcessor(PayloadProcessor delegate, String methodName, String modelId, String vModelId) {
this.delegate = delegate;
this.methodName = methodName;
this.modelId = modelId;
this.vModelId = vModelId;
}

@Override
Expand All @@ -42,40 +46,49 @@ public String getName() {

@Override
public boolean process(Payload payload) {
boolean processed = false;
boolean methodMatches = true;
if (this.methodName != null) {
methodMatches = payload.getMethod() != null && this.methodName.equals(payload.getMethod());
}
boolean methodMatches = this.methodName == null || Objects.equals(this.methodName, payload.getMethod());
if (methodMatches) {
boolean modelIdMatches = true;
if (this.modelId != null) {
modelIdMatches = this.modelId.equals(payload.getModelId());
}
boolean modelIdMatches = this.modelId == null || this.modelId.equals(payload.getModelId());
if (modelIdMatches) {
processed = delegate.process(payload);
boolean vModelIdMatches = this.vModelId == null || this.vModelId.equals(payload.getVModelId());
if (vModelIdMatches) {
return delegate.process(payload);
}
}
}
return processed;
return false;
}

public static MatchingPayloadProcessor from(String modelId, String method, PayloadProcessor processor) {
return from(modelId, null, method, processor);
}

public static MatchingPayloadProcessor from(String modelId, String vModelId,
String method, PayloadProcessor processor) {
if (modelId != null) {
if (modelId.length() > 0) {
if (!modelId.isEmpty()) {
modelId = modelId.replaceFirst("/", "");
if (modelId.length() == 0 || modelId.equals("*")) {
if (modelId.isEmpty() || modelId.equals("*")) {
modelId = null;
}
} else {
modelId = null;
}
}
if (method != null) {
if (method.length() == 0 || method.equals("*")) {
method = null;
if (vModelId != null) {
if (!vModelId.isEmpty()) {
vModelId = vModelId.replaceFirst("/", "");
if (vModelId.isEmpty() || vModelId.equals("*")) {
vModelId = null;
}
} else {
vModelId = null;
}
}
return new MatchingPayloadProcessor(processor, method, modelId);
if (method != null && (method.isEmpty() || method.equals("*"))) {
method = null;
}
return new MatchingPayloadProcessor(processor, method, modelId, vModelId);
}

@Override
Expand Down
20 changes: 20 additions & 0 deletions src/main/java/com/ibm/watson/modelmesh/payload/Payload.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ public enum Kind {

private final String modelId;

private final String vModelId;

private final String method;

private final Metadata metadata;
Expand All @@ -48,10 +50,17 @@ public enum Kind {
// null for requests, non-null for responses
private final Status status;


public Payload(@Nonnull String id, @Nonnull String modelId, @Nullable String method, @Nullable Metadata metadata,
@Nullable ByteBuf data, @Nullable Status status) {
this(id, modelId, null, method, metadata, data, status);
}

public Payload(@Nonnull String id, @Nonnull String modelId, @Nullable String vModelId, @Nullable String method,
@Nullable Metadata metadata, @Nullable ByteBuf data, @Nullable Status status) {
this.id = id;
this.modelId = modelId;
this.vModelId = vModelId;
this.method = method;
this.metadata = metadata;
this.data = data;
Expand All @@ -68,6 +77,16 @@ public String getModelId() {
return modelId;
}

@CheckForNull
public String getVModelId() {
return vModelId;
}

@Nonnull
public String getVModelIdOrModelId() {
return vModelId != null ? vModelId : modelId;
}

@CheckForNull
public String getMethod() {
return method;
Expand Down Expand Up @@ -101,6 +120,7 @@ public void release() {
public String toString() {
return "Payload{" +
"id='" + id + '\'' +
", vModelId=" + (vModelId != null ? ('\'' + vModelId + '\'') : "null") +
", modelId='" + modelId + '\'' +
", method='" + method + '\'' +
", status=" + (status == null ? "request" : String.valueOf(status)) +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,10 @@ public boolean process(Payload payload) {
private static PayloadContent prepareContentBody(Payload payload) {
String id = payload.getId();
String modelId = payload.getModelId();
String vModelId = payload.getVModelId();
String kind = payload.getKind().toString().toLowerCase();
ByteBuf byteBuf = payload.getData();
String data;
if (byteBuf != null) {
data = encodeBinaryToString(byteBuf);
} else {
data = "";
}
String data = byteBuf != null ? encodeBinaryToString(byteBuf) : "";
Metadata metadata = payload.getMetadata();
Map<String, String> metadataMap = new HashMap<>();
if (metadata != null) {
Expand All @@ -79,7 +75,7 @@ private static PayloadContent prepareContentBody(Payload payload) {
}
}
String status = payload.getStatus() != null ? payload.getStatus().getCode().toString() : "";
return new PayloadContent(id, modelId, data, kind, status, metadataMap);
return new PayloadContent(id, modelId, vModelId, data, kind, status, metadataMap);
}

private static String encodeBinaryToString(ByteBuf byteBuf) {
Expand Down Expand Up @@ -116,15 +112,17 @@ private static class PayloadContent {

private final String id;
private final String modelid;
private final String vModelId;
private final String data;
private final String kind;
private final String status;
private final Map<String, String> metadata;

private PayloadContent(String id, String modelid, String data, String kind, String status,
Map<String, String> metadata) {
private PayloadContent(String id, String modelid, String vModelId, String data, String kind,
String status, Map<String, String> metadata) {
this.id = id;
this.modelid = modelid;
this.vModelId = vModelId;
this.data = data;
this.kind = kind;
this.status = status;
Expand All @@ -143,6 +141,10 @@ public String getModelid() {
return modelid;
}

public String getvModelId() {
return vModelId;
}

public String getData() {
return data;
}
Expand All @@ -160,6 +162,7 @@ public String toString() {
return "PayloadContent{" +
"id='" + id + '\'' +
", modelid='" + modelid + '\'' +
", vModelId=" + (vModelId != null ? ('\'' + vModelId + '\'') : "null") +
", data='" + data + '\'' +
", kind='" + kind + '\'' +
", status='" + status + '\'' +
Expand Down

0 comments on commit 2e4d1d6

Please sign in to comment.