Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let RemotePayloadProcessor send also gRPC call Metadata #96

Merged
merged 11 commits into from
May 26, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.grpc.Metadata;
import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.base64.Base64;
import org.slf4j.Logger;
Expand Down Expand Up @@ -58,16 +61,32 @@ private static PayloadContent prepareContentBody(Payload payload) {
ByteBuf byteBuf = payload.getData();
String data;
if (byteBuf != null) {
ByteBuf encoded = Base64.encode(byteBuf, byteBuf.readerIndex(), byteBuf.readableBytes(), false);
//TODO custom jackson serialization for this field to avoid round-tripping to string
data = encoded.toString(StandardCharsets.US_ASCII);
data = encodeBinaryToString(byteBuf);
} else {
data = "";
}
Metadata metadata = payload.getMetadata();
Map<String, String> metadataMap = new HashMap<>();
if (metadata != null) {
for (String key : metadata.keys()) {
if (key.endsWith("-bin")) {
byte[] bytes = metadata.get(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER));
metadataMap.put(key, java.util.Base64.getEncoder().encodeToString(bytes));
} else {
String value = metadata.get(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER));
metadataMap.put(key, value);
}
}
}
String status = payload.getStatus() != null ? payload.getStatus().getCode().toString() : "";
return new PayloadContent(id, modelId, data, kind, status);
return new PayloadContent(id, modelId, data, kind, status, metadataMap);
}

private static String encodeBinaryToString(ByteBuf byteBuf) {
ByteBuf encodedBinary = Base64.encode(byteBuf, byteBuf.readerIndex(), byteBuf.readableBytes(), false);
//TODO custom jackson serialization for this field to avoid round-tripping to string
return encodedBinary.toString(StandardCharsets.US_ASCII);
}

private boolean sendPayload(Payload payload) {
try {
Expand All @@ -94,18 +113,22 @@ public String getName() {
}

private static class PayloadContent {

private final String id;
private final String modelid;
private final String data;
private final String kind;
private final String status;
private final Map<String, String> metadata;

private PayloadContent(String id, String modelid, String data, String kind, String status) {
private PayloadContent(String id, String modelid, String data, String kind, String status,
Map<String, String> metadata) {
this.id = id;
this.modelid = modelid;
this.data = data;
this.kind = kind;
this.status = status;
this.metadata = metadata;
}

public String getId() {
Expand All @@ -128,6 +151,10 @@ public String getStatus() {
return status;
}

public Map<String, String> getMetadata() {
return metadata;
}

@Override
public String toString() {
return "PayloadContent{" +
Expand All @@ -136,6 +163,7 @@ public String toString() {
", data='" + data + '\'' +
", kind='" + kind + '\'' +
", status='" + status + '\'' +
", metadata='" + metadata + '\'' +
'}';
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class SidecarModelMeshPayloadProcessingTest extends SingleInstanceModelMe

@BeforeEach
public void initialize() throws Exception {
System.setProperty(ModelMeshEnvVars.MM_PAYLOAD_PROCESSORS, "logger://*");
System.setProperty(ModelMeshEnvVars.MM_PAYLOAD_PROCESSORS, "http://localhost:8080/consumer/kserve/v2");
super.initialize();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ void testDestinationUnreachable() {
String method = "predict";
Status kind = Status.INVALID_ARGUMENT;
Metadata metadata = new Metadata();
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
ByteBuf data = Unpooled.buffer(4);
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
assertFalse(remotePayloadProcessor.process(payload));
Expand Down