Skip to content

Commit

Permalink
Update SDK to use restatedev/service-protocol#58
Browse files Browse the repository at this point in the history
This commit updates the Java SDK to allow all completable journal entries
to have a failure variant. As part of this change, we also added tests to
ensure the correct behavior.

This fixes restatedev#187.
  • Loading branch information
tillrohrmann committed Dec 18, 2023
1 parent 54f0ce3 commit f1dd357
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 19 deletions.
36 changes: 29 additions & 7 deletions sdk-core/src/main/java/dev/restate/sdk/core/Entries.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,26 @@ public void trace(PollInputStreamEntryMessage expected, Span span) {

@Override
public boolean hasResult(PollInputStreamEntryMessage actual) {
return true;
return actual.getResultCase() != PollInputStreamEntryMessage.ResultCase.RESULT_NOT_SET;
}

@Override
public ReadyResultInternal<R> parseEntryResult(PollInputStreamEntryMessage actual) {
return valueParser.apply(actual.getValue());
if (actual.getResultCase() == PollInputStreamEntryMessage.ResultCase.VALUE) {
return valueParser.apply(actual.getValue());
} else if (actual.getResultCase() == PollInputStreamEntryMessage.ResultCase.FAILURE) {
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
} else {
throw new IllegalStateException("PollInputEntry has not been completed.");
}
}

@Override
public ReadyResultInternal<R> parseCompletionResult(CompletionMessage actual) {
if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) {
return valueParser.apply(actual.getValue());
} else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) {
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
}
return super.parseCompletionResult(actual);
}
Expand Down Expand Up @@ -126,17 +134,23 @@ void checkEntryHeader(GetStateEntryMessage expected, MessageLite actual)
public ReadyResultInternal<ByteString> parseEntryResult(GetStateEntryMessage actual) {
if (actual.getResultCase() == GetStateEntryMessage.ResultCase.VALUE) {
return ReadyResults.success(actual.getValue());
} else if (actual.getResultCase() == GetStateEntryMessage.ResultCase.FAILURE) {
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
} else if (actual.getResultCase() == GetStateEntryMessage.ResultCase.EMPTY) {
return ReadyResults.empty();
} else {
throw new IllegalStateException("GetStateEntry has not been completed.");
}
return ReadyResults.empty();
}

@Override
public ReadyResultInternal<ByteString> parseCompletionResult(CompletionMessage actual) {
if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) {
return ReadyResults.success(actual.getValue());
}
if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) {
} else if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) {
return ReadyResults.empty();
} else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) {
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
}
return super.parseCompletionResult(actual);
}
Expand Down Expand Up @@ -239,18 +253,26 @@ void trace(SleepEntryMessage expected, Span span) {

@Override
public boolean hasResult(SleepEntryMessage actual) {
return actual.hasResult();
return actual.getResultCase() != Protocol.SleepEntryMessage.ResultCase.RESULT_NOT_SET;
}

@Override
public ReadyResultInternal<Void> parseEntryResult(SleepEntryMessage actual) {
return ReadyResults.empty();
if (actual.getResultCase() == SleepEntryMessage.ResultCase.FAILURE) {
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
} else if (actual.getResultCase() == SleepEntryMessage.ResultCase.EMPTY) {
return ReadyResults.empty();
} else {
throw new IllegalStateException("SleepEntry has not been completed.");
}
}

@Override
public ReadyResultInternal<Void> parseCompletionResult(CompletionMessage actual) {
if (actual.getResultCase() == CompletionMessage.ResultCase.EMPTY) {
return ReadyResults.empty();
} else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) {
return ReadyResults.failure(Util.toRestateException(actual.getFailure()));
}
return super.parseCompletionResult(actual);
}
Expand Down
13 changes: 11 additions & 2 deletions sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,13 @@ public static MessageHeader fromMessage(MessageLite msg) {
} else if (msg instanceof Protocol.EntryAckMessage) {
return new MessageHeader(MessageType.EntryAckMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.PollInputStreamEntryMessage) {
return new MessageHeader(MessageType.PollInputStreamEntryMessage, 0, msg.getSerializedSize());
return new MessageHeader(
MessageType.PollInputStreamEntryMessage,
((Protocol.PollInputStreamEntryMessage) msg).getResultCase()
!= Protocol.PollInputStreamEntryMessage.ResultCase.RESULT_NOT_SET
? DONE_FLAG
: 0,
msg.getSerializedSize());
} else if (msg instanceof Protocol.OutputStreamEntryMessage) {
return new MessageHeader(MessageType.OutputStreamEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.GetStateEntryMessage) {
Expand All @@ -81,7 +87,10 @@ public static MessageHeader fromMessage(MessageLite msg) {
} else if (msg instanceof Protocol.SleepEntryMessage) {
return new MessageHeader(
MessageType.SleepEntryMessage,
((Protocol.SleepEntryMessage) msg).hasResult() ? DONE_FLAG : 0,
((Protocol.SleepEntryMessage) msg).getResultCase()
!= Protocol.SleepEntryMessage.ResultCase.RESULT_NOT_SET
? DONE_FLAG
: 0,
msg.getSerializedSize());
} else if (msg instanceof Protocol.InvokeEntryMessage) {
return new MessageHeader(
Expand Down
21 changes: 16 additions & 5 deletions sdk-core/src/main/java/dev/restate/sdk/core/RestateServerCall.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import com.google.protobuf.MessageLite;
import dev.restate.sdk.common.TerminalException;
import dev.restate.sdk.common.syscalls.ReadyResult;
import dev.restate.sdk.common.syscalls.SyscallCallback;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
Expand Down Expand Up @@ -161,11 +162,21 @@ private void pollInput() {
() -> {
Objects.requireNonNull(listener);

// PollInput can only be result
MessageLite message = deferredValue.toReadyResult().getResult();

LOG.trace("Read input message:\n{}", message);
listener.invoke(message);
final ReadyResult<MessageLite> pollInputReadyResult =
deferredValue.toReadyResult();

if (pollInputReadyResult.isSuccess()) {
final MessageLite message = pollInputReadyResult.getResult();
LOG.trace("Read input message:\n{}", message);
listener.invoke(message);
} else {
final TerminalException failure = pollInputReadyResult.getFailure();
this.close(
Status.UNKNOWN
.withDescription(failure.getMessage())
.withCause(failure),
new Metadata());
}
},
this::onError)),
this::onError));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static dev.restate.sdk.core.TestDefinitions.TestDefinition;
import static dev.restate.sdk.core.TestDefinitions.testInvocation;

import dev.restate.sdk.common.TerminalException;
import dev.restate.sdk.core.TestDefinitions.TestSuite;
import dev.restate.sdk.core.testservices.GreeterGrpc;
import dev.restate.sdk.core.testservices.GreetingRequest;
Expand All @@ -30,7 +31,12 @@ public Stream<TestDefinition> definitions() {
.withInput(
startMessage(1), inputMessage(GreetingRequest.newBuilder().setName("Francesco")))
.expectingOutput(
outputMessage(
GreetingResponse.newBuilder().setMessage("Hello Francesco").build())));
outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco").build())),
testInvocation(this::noSyscallsGreeter, GreeterGrpc.getGreetMethod())
.withInput(
startMessage(1),
inputMessage(new TerminalException(TerminalException.Code.CANCELLED)))
.expectingOutput(
outputMessage(new TerminalException(TerminalException.Code.CANCELLED))));
}
}
10 changes: 10 additions & 0 deletions sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ public static Protocol.PollInputStreamEntryMessage inputMessage(MessageLiteOrBui
.build();
}

public static Protocol.PollInputStreamEntryMessage inputMessage(Throwable error) {
return Protocol.PollInputStreamEntryMessage.newBuilder()
.setFailure(Util.toProtocolFailure(error))
.build();
}

public static Protocol.OutputStreamEntryMessage outputMessage(MessageLiteOrBuilder value) {
return Protocol.OutputStreamEntryMessage.newBuilder()
.setValue(build(value).toByteString())
Expand All @@ -121,6 +127,10 @@ public static Protocol.GetStateEntryMessage.Builder getStateMessage(String key)
return Protocol.GetStateEntryMessage.newBuilder().setKey(ByteString.copyFromUtf8(key));
}

public static Protocol.GetStateEntryMessage.Builder getStateMessage(String key, Throwable error) {
return getStateMessage(key).setFailure(Util.toProtocolFailure(error));
}

public static Protocol.GetStateEntryMessage getStateEmptyMessage(String key) {
return Protocol.GetStateEntryMessage.newBuilder()
.setKey(ByteString.copyFromUtf8(key))
Expand Down
31 changes: 28 additions & 3 deletions sdk-core/src/test/java/dev/restate/sdk/core/SleepTestSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.google.protobuf.Empty;
import com.google.protobuf.MessageLiteOrBuilder;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.sdk.common.TerminalException;
import dev.restate.sdk.core.testservices.GreeterGrpc;
import dev.restate.sdk.core.testservices.GreetingRequest;
import dev.restate.sdk.core.testservices.GreetingResponse;
Expand Down Expand Up @@ -55,7 +56,7 @@ public Stream<TestDefinitions.TestDefinition> definitions() {
inputMessage(GreetingRequest.newBuilder().setName("Till")),
Protocol.SleepEntryMessage.newBuilder()
.setWakeUpTime(Instant.now().toEpochMilli())
.setResult(Empty.getDefaultInstance())
.setEmpty(Empty.getDefaultInstance())
.build())
.expectingOutput(
outputMessage(GreetingResponse.newBuilder().setMessage("Hello").build()))
Expand All @@ -81,13 +82,37 @@ public Stream<TestDefinitions.TestDefinition> definitions() {
(i % 3 == 0)
? Protocol.SleepEntryMessage.newBuilder()
.setWakeUpTime(Instant.now().toEpochMilli())
.setResult(Empty.getDefaultInstance())
.setEmpty(Empty.getDefaultInstance())
.build()
: Protocol.SleepEntryMessage.newBuilder()
.setWakeUpTime(Instant.now().toEpochMilli())
.build()))
.toArray(MessageLiteOrBuilder[]::new))
.expectingOutput(suspensionMessage(1, 2, 4, 5, 7, 8, 10))
.named("Sleep 1000 ms sleep completed"));
.named("Sleep 1000 ms sleep completed"),
testInvocation(this::sleepGreeter, GreeterGrpc.getGreetMethod())
.withInput(
startMessage(2),
inputMessage(GreetingRequest.newBuilder().setName("Till")),
Protocol.SleepEntryMessage.newBuilder()
.setWakeUpTime(Instant.now().toEpochMilli())
.setFailure(
Util.toProtocolFailure(TerminalException.Code.CANCELLED, "canceled"))
.build())
.expectingOutput(outputMessage(TerminalException.Code.CANCELLED, "canceled"))
.named("Failed sleep"),
testInvocation(this::sleepGreeter, GreeterGrpc.getGreetMethod())
.withInput(
startMessage(1),
inputMessage(GreetingRequest.newBuilder().setName("Till")),
completionMessage(
1, new TerminalException(TerminalException.Code.CANCELLED, "canceled")))
.assertingOutput(
messageLites -> {
assertThat(messageLites.get(0)).isInstanceOf(Protocol.SleepEntryMessage.class);
assertThat(messageLites.get(1))
.isEqualTo(outputMessage(TerminalException.Code.CANCELLED, "canceled"));
})
.named("Failing sleep"));
}
}
23 changes: 23 additions & 0 deletions sdk-core/src/test/java/dev/restate/sdk/core/StateTestSuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@

import static dev.restate.sdk.core.ProtoUtils.*;
import static dev.restate.sdk.core.TestDefinitions.testInvocation;
import static org.assertj.core.api.Assertions.assertThat;

import com.google.protobuf.Empty;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.sdk.common.TerminalException;
import dev.restate.sdk.core.testservices.GreeterGrpc;
import dev.restate.sdk.core.testservices.GreetingRequest;
import dev.restate.sdk.core.testservices.GreetingResponse;
Expand Down Expand Up @@ -76,6 +79,26 @@ public Stream<TestDefinitions.TestDefinition> definitions() {
getStateMessage("STATE"),
outputMessage(GreetingResponse.newBuilder().setMessage("Hello Francesco")))
.named("Without GetStateEntry and completed with later CompletionFrame"),
testInvocation(this::getState, GreeterGrpc.getGreetMethod())
.withInput(
startMessage(2),
inputMessage(GreetingRequest.newBuilder().setName("Till")),
getStateMessage("STATE", new TerminalException(TerminalException.Code.CANCELLED)))
.expectingOutput(outputMessage(new TerminalException(TerminalException.Code.CANCELLED)))
.named("Failed GetStateEntry"),
testInvocation(this::getState, GreeterGrpc.getGreetMethod())
.withInput(
startMessage(1),
inputMessage(GreetingRequest.newBuilder().setName("Till")),
completionMessage(1, new TerminalException(TerminalException.Code.CANCELLED)))
.assertingOutput(
messageLites -> {
assertThat(messageLites.get(0)).isInstanceOf(Protocol.GetStateEntryMessage.class);
assertThat(messageLites.get(1))
.isEqualTo(
outputMessage(new TerminalException(TerminalException.Code.CANCELLED)));
})
.named("Failing GetStateEntry"),
testInvocation(this::getAndSetState, GreeterGrpc.getGreetMethod())
.withInput(
startMessage(3),
Expand Down

0 comments on commit f1dd357

Please sign in to comment.