Skip to content

Commit

Permalink
Modified the MessageHeader#fromMessage to be safer.
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper committed Jun 5, 2023
1 parent 7ca190b commit 2bb068b
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class MessageHeader {
private final short flags;
private final int length;

private MessageHeader(MessageType type, short flags, int length) {
public MessageHeader(MessageType type, short flags, int length) {
this.type = type;
this.flags = flags;
this.length = length;
Expand Down Expand Up @@ -53,13 +53,7 @@ public static MessageHeader parse(long encoded) throws ProtocolException {
}

public static MessageHeader fromMessage(MessageLite msg) {
if (msg instanceof Protocol.StartMessage) {
// We set PARTIAL_STATE here only for tests, in prod code this branch should be never hit.
return new MessageHeader(
MessageType.StartMessage, PARTIAL_STATE_FLAG, msg.getSerializedSize());
} else if (msg instanceof Protocol.CompletionMessage) {
return new MessageHeader(MessageType.CompletionMessage, (short) 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.SuspensionMessage) {
if (msg instanceof Protocol.SuspensionMessage) {
return new MessageHeader(MessageType.SuspensionMessage, (short) 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.PollInputStreamEntryMessage) {
return new MessageHeader(
Expand Down Expand Up @@ -114,6 +108,10 @@ public static MessageHeader fromMessage(MessageLite msg) {
} else if (msg instanceof Java.SideEffectEntryMessage) {
return new MessageHeader(
MessageType.SideEffectEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize());
} else if (msg instanceof Protocol.StartMessage) {
throw new IllegalArgumentException("SDK should never send a StartMessage");
} else if (msg instanceof Protocol.CompletionMessage) {
throw new IllegalArgumentException("SDK should never send a CompletionMessage");
}
throw new IllegalStateException();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.restate.sdk.core.impl;

import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation;
import static dev.restate.sdk.core.impl.ProtoUtils.headerFromMessage;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.params.provider.Arguments.arguments;

Expand Down Expand Up @@ -169,7 +170,7 @@ WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) {
return new WithInputBuilder(
svc,
method,
List.of(InvocationInput.of(MessageHeader.fromMessage(msg).copyWithFlags(flags), msg)));
List.of(InvocationInput.of(headerFromMessage(msg).copyWithFlags(flags), msg)));
}

WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
Expand All @@ -180,7 +181,7 @@ WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
.map(
msgOrBuilder -> {
MessageLite msg = ProtoUtils.build(msgOrBuilder);
return InvocationInput.of(MessageHeader.fromMessage(msg), msg);
return InvocationInput.of(headerFromMessage(msg), msg);
})
.collect(Collectors.toList()));
}
Expand All @@ -196,8 +197,7 @@ static class WithInputBuilder extends TestInvocationBuilder {

WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) {
MessageLite msg = ProtoUtils.build(msgOrBuilder);
this.input.add(
InvocationInput.of(MessageHeader.fromMessage(msg).copyWithFlags(flags), msg));
this.input.add(InvocationInput.of(headerFromMessage(msg).copyWithFlags(flags), msg));
return this;
}

Expand All @@ -207,7 +207,7 @@ WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
.map(
msgOrBuilder -> {
MessageLite msg = ProtoUtils.build(msgOrBuilder);
return InvocationInput.of(MessageHeader.fromMessage(msg), msg);
return InvocationInput.of(headerFromMessage(msg), msg);
})
.collect(Collectors.toList()));
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@

public class ProtoUtils {

/**
* Variant of {@link MessageHeader#fromMessage(MessageLite)} supporting StartMessage and
* CompletionMessage.
*/
public static MessageHeader headerFromMessage(MessageLite msg) {
if (msg instanceof Protocol.StartMessage) {
return new MessageHeader(
MessageType.StartMessage, MessageHeader.PARTIAL_STATE_FLAG, msg.getSerializedSize());
} else if (msg instanceof Protocol.CompletionMessage) {
return new MessageHeader(MessageType.CompletionMessage, (short) 0, msg.getSerializedSize());
}
return MessageHeader.fromMessage(msg);
}

public static Protocol.StartMessage.Builder startMessage(int entries) {
return Protocol.StartMessage.newBuilder()
.setInstanceKey(ByteString.copyFromUtf8("abc"))
Expand Down
1 change: 1 addition & 0 deletions sdk-lambda/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies {

testImplementation(project(":sdk-blocking"))
testImplementation(project(":sdk-kotlin"))
testImplementation(project(":sdk-core-impl", "testArchive"))
testImplementation(testingLibs.junit.jupiter)
testImplementation(testingLibs.assertj)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import dev.restate.generated.service.discovery.Discovery;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.sdk.core.impl.MessageHeader;
import dev.restate.sdk.core.impl.ProtoUtils;
import dev.restate.sdk.lambda.testservices.CounterRequest;
import dev.restate.sdk.lambda.testservices.JavaCounterGrpc;
import dev.restate.sdk.lambda.testservices.KotlinCounterGrpc;
Expand Down Expand Up @@ -109,7 +110,7 @@ private static byte[] serializeEntries(MessageLite... msgs) throws IOException {
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
for (MessageLite msg : msgs) {
ByteBuffer headerBuf = ByteBuffer.allocate(8);
headerBuf.putLong(MessageHeader.fromMessage(msg).encode());
headerBuf.putLong(ProtoUtils.headerFromMessage(msg).encode());
outputStream.write(headerBuf.array());
msg.writeTo(outputStream);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package dev.restate.sdk.testing;

import com.google.protobuf.MessageLite;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.sdk.core.impl.InvocationFlow;
import dev.restate.sdk.core.impl.MessageHeader;
import dev.restate.sdk.core.impl.MessageType;
import java.util.Queue;
import java.util.concurrent.Flow;

Expand All @@ -21,10 +23,21 @@ class PublishSubscription implements Flow.Subscription {
public void request(long l) {
while (l != 0 && !this.queue.isEmpty()) {
MessageLite msg = queue.remove();
subscriber.onNext(InvocationFlow.InvocationInput.of(MessageHeader.fromMessage(msg), msg));
MessageHeader header = headerFromMessage(msg);
subscriber.onNext(InvocationFlow.InvocationInput.of(header, msg));
}
}

@Override
public void cancel() {}

static MessageHeader headerFromMessage(MessageLite msg) {
if (msg instanceof Protocol.StartMessage) {
return new MessageHeader(
MessageType.StartMessage, MessageHeader.PARTIAL_STATE_FLAG, msg.getSerializedSize());
} else if (msg instanceof Protocol.CompletionMessage) {
return new MessageHeader(MessageType.CompletionMessage, (short) 0, msg.getSerializedSize());
}
return MessageHeader.fromMessage(msg);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class InvocationProcessor
private final String serviceName;
private final String instanceKey;
private final String functionInvocationId;
private final Collection<MessageLite> elements;
private final Collection<MessageLite> inputMessages;

private Flow.Subscriber<? super InvocationFlow.InvocationInput>
publisher; // publisher = ExceptionCatchingInvocationInputSubscriber
Expand All @@ -209,11 +209,11 @@ public InvocationProcessor(
String functionInvocationId,
String serviceName,
String instanceKey,
Collection<MessageLite> elements) {
Collection<MessageLite> inputMessages) {
this.serviceName = serviceName;
this.instanceKey = instanceKey;
this.functionInvocationId = functionInvocationId;
this.elements = elements;
this.inputMessages = inputMessages;
}

// PUBLISHER LOGIC: to send messages to the service
Expand All @@ -222,7 +222,7 @@ public InvocationProcessor(
public void subscribe(Flow.Subscriber<? super InvocationFlow.InvocationInput> publisher) {
this.publisher = publisher;
this.currentJournalIndex = 0;
this.outputSubscription = new PublishSubscription(publisher, new ArrayDeque<>(elements));
this.outputSubscription = new PublishSubscription(publisher, new ArrayDeque<>(inputMessages));

publisher.onSubscribe(this.outputSubscription);
}
Expand Down Expand Up @@ -266,7 +266,8 @@ public void onComplete() {
private void routeMessage(MessageLite t) {
if (t instanceof Protocol.CompletionMessage) {
LOG.trace("Sending completion message");
publisher.onNext(InvocationFlow.InvocationInput.of(MessageHeader.fromMessage(t), t));
publisher.onNext(
InvocationFlow.InvocationInput.of(PublishSubscription.headerFromMessage(t), t));
} else if (t instanceof Protocol.PollInputStreamEntryMessage) {
LOG.trace("Sending poll input stream message");
publisher.onNext(InvocationFlow.InvocationInput.of(MessageHeader.fromMessage(t), t));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,8 @@ internal class RestateHttpEndpointTest {
request.setChunked(true).putHeader(HttpHeaders.CONTENT_TYPE, "application/restate")

// Send start message and PollInputStreamEntry
request.write(MessageEncoder.encode(Buffer.buffer(), startMessage(1).build()))
request.write(
MessageEncoder.encode(
Buffer.buffer(), inputMessage(greetingRequest { name = "Francesco" })))
request.write(encode(startMessage(1).build()))
request.write(encode(inputMessage(greetingRequest { name = "Francesco" })))

val response = request.response().await()

Expand All @@ -97,7 +95,7 @@ internal class RestateHttpEndpointTest {
.returns(ByteString.copyFromUtf8("counter"), GetStateEntryMessage::getKey)

// Send completion
request.write(MessageEncoder.encode(Buffer.buffer(), completionMessage(1, "2")))
request.write(encode(completionMessage(1, "2")))

// Wait for Set State Entry
val setStateEntry = inputChannel.receive()
Expand Down Expand Up @@ -145,7 +143,7 @@ internal class RestateHttpEndpointTest {

// Prepare request header
request.setChunked(true).putHeader(HttpHeaders.CONTENT_TYPE, "application/restate")
request.write(MessageEncoder.encode(Buffer.buffer(), startMessage(0).build()))
request.write(encode(startMessage(0).build()))

val response = request.response().await()

Expand Down Expand Up @@ -193,4 +191,14 @@ internal class RestateHttpEndpointTest {
.containsExactlyInAnyOrder(
"dev/restate/ext.proto", "google/protobuf/descriptor.proto", "greeter.proto")
}

fun encode(msg: MessageLite): Buffer {
val buffer = Buffer.buffer(MessageEncoder.encodeLength(msg))
val header = headerFromMessage(msg)

buffer.appendLong(header.encode())
buffer.appendBytes(msg.toByteArray())

return buffer
}
}

0 comments on commit 2bb068b

Please sign in to comment.