diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt index f56ad3ea..9c98bcaa 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt @@ -130,11 +130,16 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls) } } - override suspend fun runBlock(serde: Serde, block: suspend () -> T): T { + override suspend fun runBlock( + serde: Serde, + name: String, + block: suspend () -> T + ): T { val exitResult = suspendCancellableCoroutine { cont: CancellableContinuation> -> syscalls.enterSideEffectBlock( + name, object : EnterSideEffectSyscallCallback { override fun onSuccess(t: ByteString?) { val deferred: CompletableDeferred = CompletableDeferred() diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index f5d38328..a5d50eba 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -106,6 +106,9 @@ sealed interface Context { * suspension point) without re-executing the closure. Use this feature if you want to perform * non-deterministic operations. * + * You can name this closure using the `name` parameter. This name will be available in the + * observability tools. + * *

The closure should tolerate retries, that is Restate might re-execute the closure multiple * times until it records a result. * @@ -138,11 +141,12 @@ sealed interface Context { * To propagate failures to the run call-site, make sure to wrap them in [TerminalException]. * * @param serde the type tag of the return value, used to serialize/deserialize it. + * @param name the name of the side effect. * @param block closure to execute. * @param T type of the return value. * @return value of the runBlock operation. */ - suspend fun runBlock(serde: Serde, block: suspend () -> T): T + suspend fun runBlock(serde: Serde, name: String = "", block: suspend () -> T): T /** * Create an [Awakeable], addressable through [Awakeable.id]. @@ -221,8 +225,11 @@ sealed interface Context { * @param T type of the return value. * @return value of the runBlock operation. */ -suspend inline fun Context.runBlock(noinline block: suspend () -> T): T { - return this.runBlock(KtSerdes.json(), block) +suspend inline fun Context.runBlock( + name: String = "", + noinline block: suspend () -> T +): T { + return this.runBlock(KtSerdes.json(), name, block) } /** diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt index a3d6f70b..a3afc544 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt @@ -26,6 +26,12 @@ class SideEffectTest : SideEffectTestSuite() { "Hello $result" } + override fun namedSideEffect(name: String, sideEffectOutput: String): TestInvocationBuilder = + testDefinitionForService("SideEffect") { ctx, _: Unit -> + val result = ctx.runBlock(name) { sideEffectOutput } + "Hello $result" + } + override fun consecutiveSideEffect(sideEffectOutput: String): TestInvocationBuilder = testDefinitionForService("ConsecutiveSideEffect") { ctx, _: Unit -> val firstResult = ctx.runBlock { sideEffectOutput } @@ -54,4 +60,9 @@ class SideEffectTest : SideEffectTestSuite() { ctx.runBlock { ctx.send(GREETER_SERVICE_TARGET, KtSerdes.json(), "something") } throw IllegalStateException("This point should not be reached") } + + override fun failingSideEffect(name: String, reason: String): TestInvocationBuilder = + testDefinitionForService("FailingSideEffect") { ctx, _: Unit -> + ctx.runBlock(name) { throw IllegalStateException(reason) } + } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/Context.java b/sdk-api/src/main/java/dev/restate/sdk/Context.java index 171b5d3e..dd01b59b 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Context.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Context.java @@ -100,6 +100,9 @@ default void sleep(Duration duration) { * suspension point) without re-executing the closure. Use this feature if you want to perform * non-deterministic operations. * + *

You can name this closure using the {@code name} parameter. This name will be available in + * the observability tools. + * *

The closure should tolerate retries, that is Restate might re-execute the closure multiple * times until it records a result. * @@ -133,16 +136,18 @@ default void sleep(Duration duration) { * To propagate run failures to the call-site, make sure to wrap them in {@link * TerminalException}. * + * @param name name of the side effect. * @param serde the type tag of the return value, used to serialize/deserialize it. * @param action closure to execute. * @param type of the return value. * @return value of the run operation. */ - T run(Serde serde, ThrowingSupplier action) throws TerminalException; + T run(String name, Serde serde, ThrowingSupplier action) throws TerminalException; - /** Like {@link #run(Serde, ThrowingSupplier)}, but without returning a value. */ - default void run(ThrowingRunnable runnable) throws TerminalException { + /** Like {@link #run(String, Serde, ThrowingSupplier)}, but without returning a value. */ + default void run(String name, ThrowingRunnable runnable) throws TerminalException { run( + name, CoreSerdes.VOID, () -> { runnable.run(); @@ -150,6 +155,16 @@ default void run(ThrowingRunnable runnable) throws TerminalException { }); } + /** Like {@link #run(String, Serde, ThrowingSupplier)}, but without a name. */ + default T run(Serde serde, ThrowingSupplier action) throws TerminalException { + return run(null, serde, action); + } + + /** Like {@link #run(String, ThrowingRunnable)}, but without a name. */ + default void run(ThrowingRunnable runnable) throws TerminalException { + run(null, runnable); + } + /** * Create an {@link Awakeable}, addressable through {@link Awakeable#id()}. * diff --git a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java index 1704f8f6..89d1d084 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java @@ -110,9 +110,10 @@ public void send(Target target, Serde inputSerde, T parameter, Duration d } @Override - public T run(Serde serde, ThrowingSupplier action) { + public T run(String name, Serde serde, ThrowingSupplier action) { CompletableFuture> enterFut = new CompletableFuture<>(); syscalls.enterSideEffectBlock( + name, new EnterSideEffectSyscallCallback() { @Override public void onNotExecuted() { diff --git a/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java b/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java index 38c30210..a34c7e6e 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java @@ -18,6 +18,7 @@ public class SideEffectTest extends SideEffectTestSuite { + @Override protected TestInvocationBuilder sideEffect(String sideEffectOutput) { return testDefinitionForService( "SideEffect", @@ -29,6 +30,19 @@ protected TestInvocationBuilder sideEffect(String sideEffectOutput) { }); } + @Override + protected TestInvocationBuilder namedSideEffect(String name, String sideEffectOutput) { + return testDefinitionForService( + "SideEffect", + CoreSerdes.VOID, + CoreSerdes.JSON_STRING, + (ctx, unused) -> { + String result = ctx.run(name, CoreSerdes.JSON_STRING, () -> sideEffectOutput); + return "Hello " + result; + }); + } + + @Override protected TestInvocationBuilder consecutiveSideEffect(String sideEffectOutput) { return testDefinitionForService( "ConsecutiveSideEffect", @@ -42,6 +56,7 @@ protected TestInvocationBuilder consecutiveSideEffect(String sideEffectOutput) { }); } + @Override protected TestInvocationBuilder checkContextSwitching() { return testDefinitionForService( "CheckContextSwitching", @@ -65,6 +80,7 @@ protected TestInvocationBuilder checkContextSwitching() { }); } + @Override protected TestInvocationBuilder sideEffectGuard() { return testDefinitionForService( "SideEffectGuard", @@ -75,4 +91,20 @@ protected TestInvocationBuilder sideEffectGuard() { throw new IllegalStateException("This point should not be reached"); }); } + + @Override + protected TestInvocationBuilder failingSideEffect(String name, String reason) { + return testDefinitionForService( + "FailingSideEffect", + CoreSerdes.VOID, + CoreSerdes.JSON_STRING, + (ctx, unused) -> { + ctx.run( + name, + () -> { + throw new IllegalStateException(reason); + }); + return null; + }); + } } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java index a4d5a627..495481b6 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/syscalls/Syscalls.java @@ -68,7 +68,7 @@ void send( @Nullable Duration delay, SyscallCallback requestCallback); - void enterSideEffectBlock(EnterSideEffectSyscallCallback callback); + void enterSideEffectBlock(@Nullable String name, EnterSideEffectSyscallCallback callback); void exitSideEffectBlock(ByteString toWrite, ExitSideEffectSyscallCallback callback); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java b/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java index 3e0a193c..2e7f7706 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/Entries.java @@ -26,6 +26,8 @@ final class Entries { private Entries() {} abstract static class JournalEntry { + abstract String getName(E expected); + void checkEntryHeader(E expected, MessageLite actual) throws ProtocolException {} abstract void trace(E expected, Span span); @@ -57,6 +59,11 @@ static final class OutputEntry extends JournalEntry { private OutputEntry() {} + @Override + String getName(OutputEntryMessage expected) { + return expected.getName(); + } + @Override public void trace(OutputEntryMessage expected, Span span) { span.addEvent("Output"); @@ -81,6 +88,11 @@ public boolean hasResult(GetStateEntryMessage actual) { return actual.getResultCase() != GetStateEntryMessage.ResultCase.RESULT_NOT_SET; } + @Override + String getName(GetStateEntryMessage expected) { + return expected.getName(); + } + @Override void checkEntryHeader(GetStateEntryMessage expected, MessageLite actual) throws ProtocolException { @@ -163,6 +175,11 @@ public boolean hasResult(GetStateKeysEntryMessage actual) { return actual.getResultCase() != GetStateKeysEntryMessage.ResultCase.RESULT_NOT_SET; } + @Override + String getName(GetStateKeysEntryMessage expected) { + return expected.getName(); + } + @Override void checkEntryHeader(GetStateKeysEntryMessage expected, MessageLite actual) throws ProtocolException { @@ -232,6 +249,11 @@ public void trace(ClearStateEntryMessage expected, Span span) { "ClearState", Attributes.of(Tracing.RESTATE_STATE_KEY, expected.getKey().toString())); } + @Override + String getName(ClearStateEntryMessage expected) { + return expected.getName(); + } + @Override void checkEntryHeader(ClearStateEntryMessage expected, MessageLite actual) throws ProtocolException { @@ -256,6 +278,11 @@ public void trace(ClearAllStateEntryMessage expected, Span span) { span.addEvent("ClearAllState"); } + @Override + String getName(ClearAllStateEntryMessage expected) { + return expected.getName(); + } + @Override void checkEntryHeader(ClearAllStateEntryMessage expected, MessageLite actual) throws ProtocolException { @@ -281,6 +308,11 @@ public void trace(SetStateEntryMessage expected, Span span) { "SetState", Attributes.of(Tracing.RESTATE_STATE_KEY, expected.getKey().toString())); } + @Override + String getName(SetStateEntryMessage expected) { + return expected.getName(); + } + @Override void checkEntryHeader(SetStateEntryMessage expected, MessageLite actual) throws ProtocolException { @@ -305,6 +337,11 @@ static final class SleepEntry extends CompletableJournalEntry syscalls.enterSideEffectBlock(callback)); + public void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback callback) { + syscallsExecutor.execute(() -> syscalls.enterSideEffectBlock(name, callback)); } @Override diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java index 59dc3dd9..30ffa44e 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/InvocationStateMachine.java @@ -21,7 +21,6 @@ import io.opentelemetry.api.trace.Span; import java.util.*; import java.util.concurrent.Flow; -import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.stream.Collectors; import org.apache.logging.log4j.LogManager; @@ -43,12 +42,15 @@ class InvocationStateMachine implements InvocationFlow.InvocationProcessor { // Obtained after WAITING_START private ByteString id; + private String debugId; private String key; private int entriesToReplay; private UserStateStore userStateStore; - // Index tracking progress in the journal - private int currentJournalIndex; + // Those values track the progress in the journal + private int currentJournalEntryIndex = -1; + private String currentJournalEntryName = null; + private MessageType currentJournalEntryType = null; // Buffering of messages and completions private final IncomingEntriesStateMachine incomingEntriesStateMachine; @@ -176,6 +178,7 @@ void onStartMessage(MessageLite msg) { // Unpack the StartMessage Protocol.StartMessage startMessage = (Protocol.StartMessage) msg; this.id = startMessage.getId(); + this.debugId = startMessage.getDebugId(); InvocationId invocationId = new InvocationIdImpl(startMessage.getDebugId()); this.key = startMessage.getKey(); this.entriesToReplay = startMessage.getKnownEntries(); @@ -212,8 +215,9 @@ void onStartMessage(MessageLite msg) { this.inputSubscription.request(Long.MAX_VALUE); // Now wait input entry + this.nextJournalEntry(null, MessageType.InputEntryMessage); this.readEntry( - (i, inputMsg) -> { + inputMsg -> { if (!(inputMsg instanceof Protocol.InputEntryMessage)) { throw ProtocolException.unexpectedMessage(Protocol.InputEntryMessage.class, inputMsg); } @@ -251,17 +255,13 @@ void suspend(Collection suspensionIndexes) { void fail(Throwable cause) { LOG.warn("Invocation failed", cause); - Protocol.ErrorMessage msg; - if (cause instanceof ProtocolException) { - msg = ((ProtocolException) cause).toErrorMessage(); - } else { - msg = - Protocol.ErrorMessage.newBuilder() - .setCode(TerminalException.INTERNAL_SERVER_ERROR_CODE) - .setMessage(cause.toString()) - .build(); - } - this.closeWithMessage(msg, cause); + this.closeWithMessage( + Util.toErrorMessage( + cause, + this.currentJournalEntryIndex, + this.currentJournalEntryName, + this.currentJournalEntryType), + cause); } private void closeWithMessage(MessageLite closeMessage, Throwable cause) { @@ -295,12 +295,15 @@ void processCompletableJournalEntry( Entries.CompletableJournalEntry journalEntry, SyscallCallback> callback) { checkInsideSideEffectGuard(); + this.nextJournalEntry( + journalEntry.getName(expectedEntryMessage), MessageType.fromMessage(expectedEntryMessage)); + if (this.invocationState == InvocationState.CLOSED) { callback.onCancel(AbortedExecutionException.INSTANCE); } else if (this.invocationState == InvocationState.REPLAYING) { // Retrieve the entry this.readEntry( - (entryIndex, actualEntryMessage) -> { + actualEntryMessage -> { journalEntry.checkEntryHeader(expectedEntryMessage, actualEntryMessage); if (journalEntry.hasResult((E) actualEntryMessage)) { @@ -308,17 +311,19 @@ void processCompletableJournalEntry( journalEntry.updateUserStateStoreWithEntry( (E) actualEntryMessage, this.userStateStore); Result readyResultInternal = journalEntry.parseEntryResult((E) actualEntryMessage); - callback.onSuccess(DeferredResults.completedSingle(entryIndex, readyResultInternal)); + callback.onSuccess( + DeferredResults.completedSingle( + this.currentJournalEntryIndex, readyResultInternal)); } else { // Entry is not completed yet this.readyResultStateMachine.offerCompletionParser( - entryIndex, + this.currentJournalEntryIndex, completionMessage -> { journalEntry.updateUserStateStorageWithCompletion( (E) actualEntryMessage, completionMessage, this.userStateStore); return journalEntry.parseCompletionResult(completionMessage); }); - callback.onSuccess(DeferredResults.single(entryIndex)); + callback.onSuccess(DeferredResults.single(this.currentJournalEntryIndex)); } }, callback::onCancel); @@ -331,9 +336,6 @@ void processCompletableJournalEntry( journalEntry.trace(entryToWrite, span); } - // Retrieve the index - int entryIndex = this.currentJournalIndex; - // Write out the input entry this.writeEntry(entryToWrite); @@ -341,11 +343,11 @@ void processCompletableJournalEntry( // Complete it with the result, as we already have it callback.onSuccess( DeferredResults.completedSingle( - entryIndex, journalEntry.parseEntryResult(entryToWrite))); + this.currentJournalEntryIndex, journalEntry.parseEntryResult(entryToWrite))); } else { // Register the completion parser this.readyResultStateMachine.offerCompletionParser( - entryIndex, + this.currentJournalEntryIndex, completionMessage -> { journalEntry.updateUserStateStorageWithCompletion( entryToWrite, completionMessage, this.userStateStore); @@ -353,7 +355,7 @@ void processCompletableJournalEntry( }); // Call the onSuccess - callback.onSuccess(DeferredResults.single(entryIndex)); + callback.onSuccess(DeferredResults.single(this.currentJournalEntryIndex)); } } else { throw new IllegalStateException( @@ -367,12 +369,15 @@ void processJournalEntry( Entries.JournalEntry journalEntry, SyscallCallback callback) { checkInsideSideEffectGuard(); + this.nextJournalEntry( + journalEntry.getName(expectedEntryMessage), MessageType.fromMessage(expectedEntryMessage)); + if (this.invocationState == InvocationState.CLOSED) { callback.onCancel(AbortedExecutionException.INSTANCE); } else if (this.invocationState == InvocationState.REPLAYING) { // Retrieve the entry this.readEntry( - (entryIndex, actualEntryMessage) -> { + actualEntryMessage -> { journalEntry.checkEntryHeader(expectedEntryMessage, actualEntryMessage); journalEntry.updateUserStateStoreWithEntry((E) actualEntryMessage, this.userStateStore); callback.onSuccess(null); @@ -397,14 +402,16 @@ void processJournalEntry( } } - void enterSideEffectBlock(EnterSideEffectSyscallCallback callback) { + void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback callback) { checkInsideSideEffectGuard(); + this.nextJournalEntry(name, MessageType.SideEffectEntryMessage); + if (this.invocationState == InvocationState.CLOSED) { callback.onCancel(AbortedExecutionException.INSTANCE); } else if (this.invocationState == InvocationState.REPLAYING) { // Retrieve the entry this.readEntry( - (entryIndex, msg) -> { + msg -> { Util.assertEntryClass(Java.SideEffectEntryMessage.class, msg); // We have a result already, complete the callback @@ -437,16 +444,22 @@ void exitSideEffectBlock( span.addEvent("Exit SideEffect"); } + // For side effects, let's write out the name too, if available + if (this.currentJournalEntryName != null) { + sideEffectEntry = sideEffectEntry.toBuilder().setName(this.currentJournalEntryName).build(); + } + // Write new entry - this.sideEffectAckStateMachine.registerExecutedSideEffect(this.currentJournalIndex); + this.sideEffectAckStateMachine.registerExecutedSideEffect(this.currentJournalEntryIndex); this.writeEntry(sideEffectEntry); // Wait for entry to be acked + Java.SideEffectEntryMessage finalSideEffectEntry = sideEffectEntry; this.sideEffectAckStateMachine.waitLastSideEffectAck( new SideEffectAckStateMachine.SideEffectAckCallback() { @Override public void onLastSideEffectAck() { - completeSideEffectCallbackWithEntry(sideEffectEntry, callback); + completeSideEffectCallbackWithEntry(finalSideEffectEntry, callback); } @Override @@ -570,10 +583,12 @@ private void resolveCombinatorDeferred( // Calling .await() on a combinator deferred within a side effect is not allowed // as resolving it creates or read a journal entry. checkInsideSideEffectGuard(); + this.nextJournalEntry(null, MessageType.CombinatorAwaitableEntryMessage); + if (Objects.equals(this.invocationState, InvocationState.REPLAYING)) { // Retrieve the CombinatorAwaitableEntryMessage this.readEntry( - (entryIndex, actualMsg) -> { + actualMsg -> { Util.assertEntryClass(Java.CombinatorAwaitableEntryMessage.class, actualMsg); if (!rootDeferred.tryResolve( @@ -688,16 +703,14 @@ private void transitionState(InvocationState newInvocationState) { // Cannot move out of the closed state return; } - LOG.debug("Transitioning {} to {}", this, newInvocationState); + LOG.debug("Transitioning state machine to {}", newInvocationState); this.invocationState = newInvocationState; this.loggingContextSetter.set( RestateEndpoint.LoggingContextSetter.INVOCATION_STATUS_KEY, newInvocationState.toString()); } - private void incrementCurrentIndex() { - this.currentJournalIndex++; - - if (currentJournalIndex >= entriesToReplay + private void tryTransitionProcessing() { + if (currentJournalEntryIndex == entriesToReplay - 1 && this.invocationState == InvocationState.REPLAYING) { if (!this.incomingEntriesStateMachine.isEmpty()) { throw new IllegalStateException("Entries queue should be empty at this point"); @@ -706,19 +719,25 @@ private void incrementCurrentIndex() { } } + private void nextJournalEntry(String entryName, MessageType entryType) { + this.currentJournalEntryIndex++; + this.currentJournalEntryName = entryName; + this.currentJournalEntryType = entryType; + } + private void checkInsideSideEffectGuard() { if (this.insideSideEffect) { throw ProtocolException.invalidSideEffectCall(); } } - void readEntry(BiConsumer msgCallback, Consumer errorCallback) { + void readEntry(Consumer msgCallback, Consumer errorCallback) { this.incomingEntriesStateMachine.read( new IncomingEntriesStateMachine.OnEntryCallback() { @Override public void onEntry(MessageLite msg) { - incrementCurrentIndex(); - msgCallback.accept(currentJournalIndex - 1, msg); + tryTransitionProcessing(); + msgCallback.accept(msg); } @Override @@ -737,11 +756,10 @@ public void onError(Throwable e) { private void writeEntry(MessageLite message) { LOG.trace("Writing to output message {} {}", message.getClass(), message); Objects.requireNonNull(this.outputSubscriber).onNext(message); - this.incrementCurrentIndex(); } @Override public String toString() { - return "InvocationStateMachine{id=" + id + '}'; + return "InvocationStateMachine[" + debugId + ']'; } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java index 99cc9e63..4ee1163a 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java @@ -9,7 +9,6 @@ package dev.restate.sdk.core; import com.google.protobuf.MessageLite; -import dev.restate.generated.sdk.java.Java; import dev.restate.generated.service.protocol.Protocol; public class MessageHeader { @@ -55,19 +54,7 @@ public static MessageHeader parse(long encoded) throws ProtocolException { } public static MessageHeader fromMessage(MessageLite msg) { - if (msg instanceof Protocol.SuspensionMessage) { - return new MessageHeader(MessageType.SuspensionMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.ErrorMessage) { - return new MessageHeader(MessageType.ErrorMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.EndMessage) { - return new MessageHeader(MessageType.EndMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.EntryAckMessage) { - return new MessageHeader(MessageType.EntryAckMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.InputEntryMessage) { - return new MessageHeader(MessageType.InputEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.OutputEntryMessage) { - return new MessageHeader(MessageType.OutputEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.GetStateEntryMessage) { + if (msg instanceof Protocol.GetStateEntryMessage) { return new MessageHeader( MessageType.GetStateEntryMessage, ((Protocol.GetStateEntryMessage) msg).getResultCase() @@ -75,12 +62,6 @@ public static MessageHeader fromMessage(MessageLite msg) { ? DONE_FLAG : 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.SetStateEntryMessage) { - return new MessageHeader(MessageType.SetStateEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.ClearStateEntryMessage) { - return new MessageHeader(MessageType.ClearStateEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.ClearAllStateEntryMessage) { - return new MessageHeader(MessageType.ClearAllStateEntryMessage, 0, msg.getSerializedSize()); } else if (msg instanceof Protocol.GetStateKeysEntryMessage) { return new MessageHeader( MessageType.GetStateKeysEntryMessage, @@ -105,9 +86,6 @@ public static MessageHeader fromMessage(MessageLite msg) { ? DONE_FLAG : 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.BackgroundInvokeEntryMessage) { - return new MessageHeader( - MessageType.BackgroundInvokeEntryMessage, 0, msg.getSerializedSize()); } else if (msg instanceof Protocol.AwakeableEntryMessage) { return new MessageHeader( MessageType.AwakeableEntryMessage, @@ -116,19 +94,9 @@ public static MessageHeader fromMessage(MessageLite msg) { ? DONE_FLAG : 0, msg.getSerializedSize()); - } else if (msg instanceof Protocol.CompleteAwakeableEntryMessage) { - return new MessageHeader( - MessageType.CompleteAwakeableEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Java.CombinatorAwaitableEntryMessage) { - return new MessageHeader( - MessageType.CombinatorAwaitableEntryMessage, 0, msg.getSerializedSize()); - } else if (msg instanceof Java.SideEffectEntryMessage) { - return new MessageHeader( - MessageType.SideEffectEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize()); - } else if (msg instanceof Protocol.CompletionMessage) { - throw new IllegalArgumentException("SDK should never send a CompletionMessage"); } - throw new IllegalStateException(); + // Messages with no flags + return new MessageHeader(MessageType.fromMessage(msg), 0, msg.getSerializedSize()); } public static void checkProtocolVersion(MessageHeader header) { diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java b/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java index 2b79cd6b..4f013a05 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java @@ -201,4 +201,47 @@ public static MessageType decode(short value) throws ProtocolException { } throw ProtocolException.unknownMessageType(value); } + + public static MessageType fromMessage(MessageLite msg) { + if (msg instanceof Protocol.SuspensionMessage) { + return MessageType.SuspensionMessage; + } else if (msg instanceof Protocol.ErrorMessage) { + return MessageType.ErrorMessage; + } else if (msg instanceof Protocol.EndMessage) { + return MessageType.EndMessage; + } else if (msg instanceof Protocol.EntryAckMessage) { + return MessageType.EntryAckMessage; + } else if (msg instanceof Protocol.InputEntryMessage) { + return MessageType.InputEntryMessage; + } else if (msg instanceof Protocol.OutputEntryMessage) { + return MessageType.OutputEntryMessage; + } else if (msg instanceof Protocol.GetStateEntryMessage) { + return MessageType.GetStateEntryMessage; + } else if (msg instanceof Protocol.SetStateEntryMessage) { + return MessageType.SetStateEntryMessage; + } else if (msg instanceof Protocol.ClearStateEntryMessage) { + return MessageType.ClearStateEntryMessage; + } else if (msg instanceof Protocol.ClearAllStateEntryMessage) { + return MessageType.ClearAllStateEntryMessage; + } else if (msg instanceof Protocol.GetStateKeysEntryMessage) { + return MessageType.GetStateKeysEntryMessage; + } else if (msg instanceof Protocol.SleepEntryMessage) { + return MessageType.SleepEntryMessage; + } else if (msg instanceof Protocol.InvokeEntryMessage) { + return MessageType.InvokeEntryMessage; + } else if (msg instanceof Protocol.BackgroundInvokeEntryMessage) { + return MessageType.BackgroundInvokeEntryMessage; + } else if (msg instanceof Protocol.AwakeableEntryMessage) { + return MessageType.AwakeableEntryMessage; + } else if (msg instanceof Protocol.CompleteAwakeableEntryMessage) { + return MessageType.CompleteAwakeableEntryMessage; + } else if (msg instanceof Java.CombinatorAwaitableEntryMessage) { + return MessageType.CombinatorAwaitableEntryMessage; + } else if (msg instanceof Java.SideEffectEntryMessage) { + return MessageType.SideEffectEntryMessage; + } else if (msg instanceof Protocol.CompletionMessage) { + throw new IllegalArgumentException("SDK should never send a CompletionMessage"); + } + throw new IllegalStateException(); + } } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java index 25d5089f..588eeba8 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ProtocolException.java @@ -11,8 +11,6 @@ import com.google.protobuf.MessageLite; import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.common.TerminalException; -import java.io.PrintWriter; -import java.io.StringWriter; public class ProtocolException extends RuntimeException { @@ -42,20 +40,6 @@ public int getCode() { return code; } - public Protocol.ErrorMessage toErrorMessage() { - // Convert stacktrace to string - StringWriter sw = new StringWriter(); - PrintWriter pw = new PrintWriter(sw); - pw.println("Stacktrace:"); - this.printStackTrace(pw); - - return Protocol.ErrorMessage.newBuilder() - .setCode(code) - .setMessage(this.toString()) - .setDescription(sw.toString()) - .build(); - } - static ProtocolException unexpectedMessage( Class expected, MessageLite actual) { return new ProtocolException( diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java index fe261225..80036c2e 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java @@ -219,11 +219,11 @@ public void send( } @Override - public void enterSideEffectBlock(EnterSideEffectSyscallCallback callback) { + public void enterSideEffectBlock(String name, EnterSideEffectSyscallCallback callback) { wrapAndPropagateExceptions( () -> { LOG.trace("enterSideEffectBlock"); - this.stateMachine.enterSideEffectBlock(callback); + this.stateMachine.enterSideEffectBlock(name, callback); }, callback); } diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/Util.java b/sdk-core/src/main/java/dev/restate/sdk/core/Util.java index 0acebcf1..2d668398 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/Util.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/Util.java @@ -13,9 +13,12 @@ import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.common.AbortedExecutionException; import dev.restate.sdk.common.TerminalException; +import java.io.PrintWriter; +import java.io.StringWriter; import java.util.Objects; import java.util.Optional; import java.util.function.Predicate; +import org.jspecify.annotations.Nullable; public final class Util { private Util() {} @@ -76,6 +79,41 @@ static Protocol.Failure toProtocolFailure(Throwable throwable) { return toProtocolFailure(TerminalException.INTERNAL_SERVER_ERROR_CODE, throwable.toString()); } + static Protocol.ErrorMessage toErrorMessage( + Throwable throwable, + int currentJournalIndex, + @Nullable String currentJournalEntryName, + @Nullable MessageType currentJournalEntryType) { + Protocol.ErrorMessage.Builder msg = + Protocol.ErrorMessage.newBuilder().setMessage(throwable.toString()); + + if (throwable instanceof ProtocolException) { + msg.setCode(((ProtocolException) throwable).getCode()); + } else { + msg.setCode(TerminalException.INTERNAL_SERVER_ERROR_CODE); + } + + // Convert stacktrace to string + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + pw.println("Stacktrace:"); + throwable.printStackTrace(pw); + msg.setDescription(sw.toString()); + + // Add journal entry info + if (currentJournalIndex >= 0) { + msg.setRelatedEntryIndex(currentJournalIndex); + } + if (currentJournalEntryName != null) { + msg.setRelatedEntryName(currentJournalEntryName); + } + if (currentJournalEntryType != null) { + msg.setRelatedEntryType(currentJournalEntryType.encode()); + } + + return msg.build(); + } + static TerminalException toRestateException(Protocol.Failure failure) { return new TerminalException(failure.getCode(), failure.getMessage()); } diff --git a/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto b/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto index 0e31615e..71e2a0fe 100644 --- a/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto +++ b/sdk-core/src/main/sdk-proto/dev/restate/sdk/java.proto @@ -18,6 +18,9 @@ option java_package = "dev.restate.generated.sdk.java"; // Type: 0xFC00 + 0 message CombinatorAwaitableEntryMessage { repeated uint32 entry_index = 1; + + // Entry name + string name = 12; } // Type: 0xFC00 + 1 @@ -27,4 +30,7 @@ message SideEffectEntryMessage { bytes value = 14; dev.restate.service.protocol.Failure failure = 15; }; + + // Entry name + string name = 12; } diff --git a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java index 48e58f3f..e475db22 100644 --- a/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java +++ b/sdk-core/src/test/java/dev/restate/sdk/core/SideEffectTestSuite.java @@ -8,26 +8,33 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.core; -import static dev.restate.sdk.core.AssertUtils.containsOnlyExactErrorMessage; +import static dev.restate.sdk.core.AssertUtils.*; import static dev.restate.sdk.core.ProtoUtils.*; import static dev.restate.sdk.core.TestDefinitions.TestInvocationBuilder; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.InstanceOfAssertFactories.STRING; import static org.assertj.core.api.InstanceOfAssertFactories.type; import dev.restate.generated.sdk.java.Java; +import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.common.CoreSerdes; +import dev.restate.sdk.common.TerminalException; import java.util.stream.Stream; public abstract class SideEffectTestSuite implements TestDefinitions.TestSuite { protected abstract TestInvocationBuilder sideEffect(String sideEffectOutput); + protected abstract TestInvocationBuilder namedSideEffect(String name, String sideEffectOutput); + protected abstract TestInvocationBuilder consecutiveSideEffect(String sideEffectOutput); protected abstract TestInvocationBuilder checkContextSwitching(); protected abstract TestInvocationBuilder sideEffectGuard(); + protected abstract TestInvocationBuilder failingSideEffect(String name, String reason); + @Override public Stream definitions() { return Stream.of( @@ -46,6 +53,13 @@ public Stream definitions() { outputMessage("Hello Francesco"), END_MESSAGE) .named("Without optimization and with acks returns"), + this.namedSideEffect("get-my-name", "Francesco") + .withInput(startMessage(1), inputMessage("Till")) + .expectingOutput( + Java.SideEffectEntryMessage.newBuilder() + .setName("get-my-name") + .setValue(CoreSerdes.JSON_STRING.serializeToByteString("Francesco")), + suspensionMessage(1)), this.consecutiveSideEffect("Francesco") .withInput(startMessage(1), inputMessage("Till")) .expectingOutput( @@ -74,6 +88,25 @@ public Stream definitions() { outputMessage("Hello FRANCESCO"), END_MESSAGE) .named("With optimization and ack on first and second side effect will resume"), + this.failingSideEffect("my-side-effect", "some failure") + .withInput(startMessage(1), inputMessage()) + .onlyUnbuffered() + .assertingOutput( + containsOnly( + errorMessage( + errorMessage -> + assertThat(errorMessage) + .returns( + TerminalException.INTERNAL_SERVER_ERROR_CODE, + Protocol.ErrorMessage::getCode) + .returns(1, Protocol.ErrorMessage::getRelatedEntryIndex) + .returns( + (int) MessageType.SideEffectEntryMessage.encode(), + Protocol.ErrorMessage::getRelatedEntryType) + .returns( + "my-side-effect", Protocol.ErrorMessage::getRelatedEntryName) + .extracting(Protocol.ErrorMessage::getMessage, STRING) + .contains("some failure")))), // --- Other tests this.checkContextSwitching() diff --git a/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowContextImpl.java b/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowContextImpl.java index 6d67b694..7ea9523b 100644 --- a/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowContextImpl.java +++ b/sdk-workflow-api/src/main/java/dev/restate/sdk/workflow/impl/WorkflowContextImpl.java @@ -209,8 +209,9 @@ public void send(Target target, Serde inputSerde, T parameter, Duration d } @Override - public T run(Serde serde, ThrowingSupplier action) throws TerminalException { - return ctx.run(serde, action); + public T run(String name, Serde serde, ThrowingSupplier action) + throws TerminalException { + return ctx.run(name, serde, action); } @Override