diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageHeader.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageHeader.java index 237ab493..9814fc35 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageHeader.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/MessageHeader.java @@ -14,7 +14,7 @@ public class MessageHeader { private final short flags; private final int length; - private MessageHeader(MessageType type, short flags, int length) { + public MessageHeader(MessageType type, short flags, int length) { this.type = type; this.flags = flags; this.length = length; @@ -53,13 +53,7 @@ public static MessageHeader parse(long encoded) throws ProtocolException { } public static MessageHeader fromMessage(MessageLite msg) { - if (msg instanceof Protocol.StartMessage) { - // We set PARTIAL_STATE here only for tests, in prod code this branch should be never hit. - return new MessageHeader( - MessageType.StartMessage, PARTIAL_STATE_FLAG, msg.getSerializedSize()); - } else if (msg instanceof Protocol.CompletionMessage) { - return new MessageHeader(MessageType.CompletionMessage, (short) 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.SuspensionMessage) { + if (msg instanceof Protocol.SuspensionMessage) { return new MessageHeader(MessageType.SuspensionMessage, (short) 0, msg.getSerializedSize()); } else if (msg instanceof Protocol.PollInputStreamEntryMessage) { return new MessageHeader( @@ -114,6 +108,10 @@ public static MessageHeader fromMessage(MessageLite msg) { } else if (msg instanceof Java.SideEffectEntryMessage) { return new MessageHeader( MessageType.SideEffectEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize()); + } else if (msg instanceof Protocol.StartMessage) { + throw new IllegalArgumentException("SDK should never send a StartMessage"); + } else if (msg instanceof Protocol.CompletionMessage) { + throw new IllegalArgumentException("SDK should never send a CompletionMessage"); } throw new IllegalStateException(); } diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java index 78c275c9..5f311b34 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/CoreTestRunner.java @@ -1,6 +1,7 @@ package dev.restate.sdk.core.impl; import static dev.restate.sdk.core.impl.CoreTestRunner.TestCaseBuilder.testInvocation; +import static dev.restate.sdk.core.impl.ProtoUtils.headerFromMessage; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.params.provider.Arguments.arguments; @@ -169,7 +170,7 @@ WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) { return new WithInputBuilder( svc, method, - List.of(InvocationInput.of(MessageHeader.fromMessage(msg).copyWithFlags(flags), msg))); + List.of(InvocationInput.of(headerFromMessage(msg).copyWithFlags(flags), msg))); } WithInputBuilder withInput(MessageLiteOrBuilder... messages) { @@ -180,7 +181,7 @@ WithInputBuilder withInput(MessageLiteOrBuilder... messages) { .map( msgOrBuilder -> { MessageLite msg = ProtoUtils.build(msgOrBuilder); - return InvocationInput.of(MessageHeader.fromMessage(msg), msg); + return InvocationInput.of(headerFromMessage(msg), msg); }) .collect(Collectors.toList())); } @@ -196,8 +197,7 @@ static class WithInputBuilder extends TestInvocationBuilder { WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) { MessageLite msg = ProtoUtils.build(msgOrBuilder); - this.input.add( - InvocationInput.of(MessageHeader.fromMessage(msg).copyWithFlags(flags), msg)); + this.input.add(InvocationInput.of(headerFromMessage(msg).copyWithFlags(flags), msg)); return this; } @@ -207,7 +207,7 @@ WithInputBuilder withInput(MessageLiteOrBuilder... messages) { .map( msgOrBuilder -> { MessageLite msg = ProtoUtils.build(msgOrBuilder); - return InvocationInput.of(MessageHeader.fromMessage(msg), msg); + return InvocationInput.of(headerFromMessage(msg), msg); }) .collect(Collectors.toList())); return this; diff --git a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java index e0aa5318..52eb6c2d 100644 --- a/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java +++ b/sdk-core-impl/src/test/java/dev/restate/sdk/core/impl/ProtoUtils.java @@ -17,6 +17,20 @@ public class ProtoUtils { + /** + * Variant of {@link MessageHeader#fromMessage(MessageLite)} supporting StartMessage and + * CompletionMessage. + */ + public static MessageHeader headerFromMessage(MessageLite msg) { + if (msg instanceof Protocol.StartMessage) { + return new MessageHeader( + MessageType.StartMessage, MessageHeader.PARTIAL_STATE_FLAG, msg.getSerializedSize()); + } else if (msg instanceof Protocol.CompletionMessage) { + return new MessageHeader(MessageType.CompletionMessage, (short) 0, msg.getSerializedSize()); + } + return MessageHeader.fromMessage(msg); + } + public static Protocol.StartMessage.Builder startMessage(int entries) { return Protocol.StartMessage.newBuilder() .setInstanceKey(ByteString.copyFromUtf8("abc")) diff --git a/sdk-lambda/build.gradle.kts b/sdk-lambda/build.gradle.kts index 7f3394da..2e7eb713 100644 --- a/sdk-lambda/build.gradle.kts +++ b/sdk-lambda/build.gradle.kts @@ -29,6 +29,7 @@ dependencies { testImplementation(project(":sdk-blocking")) testImplementation(project(":sdk-kotlin")) + testImplementation(project(":sdk-core-impl", "testArchive")) testImplementation(testingLibs.junit.jupiter) testImplementation(testingLibs.assertj) diff --git a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java b/sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java index d9c29689..24f92449 100644 --- a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java +++ b/sdk-lambda/src/test/java/dev/restate/sdk/lambda/LambdaHandlerTest.java @@ -15,6 +15,7 @@ import dev.restate.generated.service.discovery.Discovery; import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.core.impl.MessageHeader; +import dev.restate.sdk.core.impl.ProtoUtils; import dev.restate.sdk.lambda.testservices.CounterRequest; import dev.restate.sdk.lambda.testservices.JavaCounterGrpc; import dev.restate.sdk.lambda.testservices.KotlinCounterGrpc; @@ -109,7 +110,7 @@ private static byte[] serializeEntries(MessageLite... msgs) throws IOException { ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); for (MessageLite msg : msgs) { ByteBuffer headerBuf = ByteBuffer.allocate(8); - headerBuf.putLong(MessageHeader.fromMessage(msg).encode()); + headerBuf.putLong(ProtoUtils.headerFromMessage(msg).encode()); outputStream.write(headerBuf.array()); msg.writeTo(outputStream); } diff --git a/sdk-testing/src/main/java/dev/restate/sdk/testing/PublishSubscription.java b/sdk-testing/src/main/java/dev/restate/sdk/testing/PublishSubscription.java index ca72a71e..b7cc3fa8 100644 --- a/sdk-testing/src/main/java/dev/restate/sdk/testing/PublishSubscription.java +++ b/sdk-testing/src/main/java/dev/restate/sdk/testing/PublishSubscription.java @@ -1,8 +1,10 @@ package dev.restate.sdk.testing; import com.google.protobuf.MessageLite; +import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.core.impl.InvocationFlow; import dev.restate.sdk.core.impl.MessageHeader; +import dev.restate.sdk.core.impl.MessageType; import java.util.Queue; import java.util.concurrent.Flow; @@ -21,10 +23,21 @@ class PublishSubscription implements Flow.Subscription { public void request(long l) { while (l != 0 && !this.queue.isEmpty()) { MessageLite msg = queue.remove(); - subscriber.onNext(InvocationFlow.InvocationInput.of(MessageHeader.fromMessage(msg), msg)); + MessageHeader header = headerFromMessage(msg); + subscriber.onNext(InvocationFlow.InvocationInput.of(header, msg)); } } @Override public void cancel() {} + + static MessageHeader headerFromMessage(MessageLite msg) { + if (msg instanceof Protocol.StartMessage) { + return new MessageHeader( + MessageType.StartMessage, MessageHeader.PARTIAL_STATE_FLAG, msg.getSerializedSize()); + } else if (msg instanceof Protocol.CompletionMessage) { + return new MessageHeader(MessageType.CompletionMessage, (short) 0, msg.getSerializedSize()); + } + return MessageHeader.fromMessage(msg); + } } diff --git a/sdk-testing/src/main/java/dev/restate/sdk/testing/TestRestateRuntime.java b/sdk-testing/src/main/java/dev/restate/sdk/testing/TestRestateRuntime.java index a99b79da..39360361 100644 --- a/sdk-testing/src/main/java/dev/restate/sdk/testing/TestRestateRuntime.java +++ b/sdk-testing/src/main/java/dev/restate/sdk/testing/TestRestateRuntime.java @@ -191,7 +191,7 @@ class InvocationProcessor private final String serviceName; private final String instanceKey; private final String functionInvocationId; - private final Collection elements; + private final Collection inputMessages; private Flow.Subscriber publisher; // publisher = ExceptionCatchingInvocationInputSubscriber @@ -209,11 +209,11 @@ public InvocationProcessor( String functionInvocationId, String serviceName, String instanceKey, - Collection elements) { + Collection inputMessages) { this.serviceName = serviceName; this.instanceKey = instanceKey; this.functionInvocationId = functionInvocationId; - this.elements = elements; + this.inputMessages = inputMessages; } // PUBLISHER LOGIC: to send messages to the service @@ -222,7 +222,7 @@ public InvocationProcessor( public void subscribe(Flow.Subscriber publisher) { this.publisher = publisher; this.currentJournalIndex = 0; - this.outputSubscription = new PublishSubscription(publisher, new ArrayDeque<>(elements)); + this.outputSubscription = new PublishSubscription(publisher, new ArrayDeque<>(inputMessages)); publisher.onSubscribe(this.outputSubscription); } @@ -266,7 +266,8 @@ public void onComplete() { private void routeMessage(MessageLite t) { if (t instanceof Protocol.CompletionMessage) { LOG.trace("Sending completion message"); - publisher.onNext(InvocationFlow.InvocationInput.of(MessageHeader.fromMessage(t), t)); + publisher.onNext( + InvocationFlow.InvocationInput.of(PublishSubscription.headerFromMessage(t), t)); } else if (t instanceof Protocol.PollInputStreamEntryMessage) { LOG.trace("Sending poll input stream message"); publisher.onNext(InvocationFlow.InvocationInput.of(MessageHeader.fromMessage(t), t)); diff --git a/sdk-vertx/src/test/kotlin/dev/restate/sdk/vertx/RestateHttpEndpointTest.kt b/sdk-vertx/src/test/kotlin/dev/restate/sdk/vertx/RestateHttpEndpointTest.kt index c9476de4..a5704762 100644 --- a/sdk-vertx/src/test/kotlin/dev/restate/sdk/vertx/RestateHttpEndpointTest.kt +++ b/sdk-vertx/src/test/kotlin/dev/restate/sdk/vertx/RestateHttpEndpointTest.kt @@ -73,10 +73,8 @@ internal class RestateHttpEndpointTest { request.setChunked(true).putHeader(HttpHeaders.CONTENT_TYPE, "application/restate") // Send start message and PollInputStreamEntry - request.write(MessageEncoder.encode(Buffer.buffer(), startMessage(1).build())) - request.write( - MessageEncoder.encode( - Buffer.buffer(), inputMessage(greetingRequest { name = "Francesco" }))) + request.write(encode(startMessage(1).build())) + request.write(encode(inputMessage(greetingRequest { name = "Francesco" }))) val response = request.response().await() @@ -97,7 +95,7 @@ internal class RestateHttpEndpointTest { .returns(ByteString.copyFromUtf8("counter"), GetStateEntryMessage::getKey) // Send completion - request.write(MessageEncoder.encode(Buffer.buffer(), completionMessage(1, "2"))) + request.write(encode(completionMessage(1, "2"))) // Wait for Set State Entry val setStateEntry = inputChannel.receive() @@ -145,7 +143,7 @@ internal class RestateHttpEndpointTest { // Prepare request header request.setChunked(true).putHeader(HttpHeaders.CONTENT_TYPE, "application/restate") - request.write(MessageEncoder.encode(Buffer.buffer(), startMessage(0).build())) + request.write(encode(startMessage(0).build())) val response = request.response().await() @@ -193,4 +191,14 @@ internal class RestateHttpEndpointTest { .containsExactlyInAnyOrder( "dev/restate/ext.proto", "google/protobuf/descriptor.proto", "greeter.proto") } + + fun encode(msg: MessageLite): Buffer { + val buffer = Buffer.buffer(MessageEncoder.encodeLength(msg)) + val header = headerFromMessage(msg) + + buffer.appendLong(header.encode()) + buffer.appendBytes(msg.toByteArray()) + + return buffer + } }