Skip to content

Commit

Permalink
Eager state support (#87) (#87)
Browse files Browse the repository at this point in the history
* Support MessageHeader.PARTIAL_STATE_FLAG
* Propagate the MessageHeader to the state machine
* Implement LocalStateStorage and wire it up
* Add the EagerStateTest
* Modified the `MessageHeader#fromMessage` to be safer.
  • Loading branch information
slinkydeveloper authored Jun 5, 2023
1 parent 5c9e8df commit 5ab3461
Show file tree
Hide file tree
Showing 18 changed files with 572 additions and 117 deletions.
52 changes: 52 additions & 0 deletions sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Entries.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dev.restate.sdk.core.impl;

import com.google.protobuf.ByteString;
import com.google.protobuf.Empty;
import com.google.protobuf.MessageLite;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.generated.service.protocol.Protocol.*;
Expand All @@ -17,6 +18,8 @@ abstract static class JournalEntry<E extends MessageLite> {
void checkEntryHeader(E expected, MessageLite actual) throws ProtocolException {}

abstract void trace(E expected, Span span);

void updateLocalStateStorage(E expected, LocalStateStorage localStateStorage) {}
}

abstract static class CompletableJournalEntry<E extends MessageLite, R> extends JournalEntry<E> {
Expand All @@ -28,6 +31,13 @@ ReadyResultInternal<R> parseCompletionResult(CompletionMessage actual) {
throw ProtocolException.completionDoesNotMatch(
this.getClass().getName(), actual.getResultCase());
}

E tryCompleteWithLocalStateStorage(E expected, LocalStateStorage localStateStorage) {
return expected;
}

void updateLocalStateStorageWithCompletion(
E expected, CompletionMessage actual, LocalStateStorage localStateStorage) {}
}

static final class PollInputEntry<R extends MessageLite>
Expand Down Expand Up @@ -123,6 +133,36 @@ public ReadyResultInternal<R> parseCompletionResult(CompletionMessage actual) {
}
return super.parseCompletionResult(actual);
}

@Override
void updateLocalStateStorage(
GetStateEntryMessage expected, LocalStateStorage localStateStorage) {
localStateStorage.set(expected.getKey(), expected.getValue());
}

@Override
GetStateEntryMessage tryCompleteWithLocalStateStorage(
GetStateEntryMessage expected, LocalStateStorage localStateStorage) {
LocalStateStorage.State value = localStateStorage.get(expected.getKey());
if (value instanceof LocalStateStorage.Value) {
return expected.toBuilder().setValue(((LocalStateStorage.Value) value).getValue()).build();
} else if (value instanceof LocalStateStorage.Empty) {
return expected.toBuilder().setEmpty(Empty.getDefaultInstance()).build();
}
return expected;
}

@Override
void updateLocalStateStorageWithCompletion(
GetStateEntryMessage expected,
CompletionMessage actual,
LocalStateStorage localStateStorage) {
if (actual.hasEmpty()) {
localStateStorage.clear(expected.getKey());
} else {
localStateStorage.set(expected.getKey(), actual.getValue());
}
}
}

static final class ClearStateEntry extends JournalEntry<ClearStateEntryMessage> {
Expand All @@ -142,6 +182,12 @@ void checkEntryHeader(ClearStateEntryMessage expected, MessageLite actual)
throws ProtocolException {
Util.assertEntryEquals(expected, actual);
}

@Override
void updateLocalStateStorage(
ClearStateEntryMessage expected, LocalStateStorage localStateStorage) {
localStateStorage.clear(expected.getKey());
}
}

static final class SetStateEntry extends JournalEntry<SetStateEntryMessage> {
Expand All @@ -161,6 +207,12 @@ void checkEntryHeader(SetStateEntryMessage expected, MessageLite actual)
throws ProtocolException {
Util.assertEntryEquals(expected, actual);
}

@Override
void updateLocalStateStorage(
SetStateEntryMessage expected, LocalStateStorage localStateStorage) {
localStateStorage.set(expected.getKey(), expected.getValue());
}
}

static final class SleepEntry extends CompletableJournalEntry<SleepEntryMessage, Void> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package dev.restate.sdk.core.impl;

import com.google.protobuf.MessageLite;
import java.util.concurrent.Flow;

class ExceptionCatchingInvocationInputSubscriber
Expand All @@ -24,9 +23,9 @@ public void onSubscribe(Flow.Subscription subscription) {
}

@Override
public void onNext(MessageLite messageLite) {
public void onNext(InvocationFlow.InvocationInput invocationInput) {
try {
invocationInputSubscriber.onNext(messageLite);
invocationInputSubscriber.onNext(invocationInput);
} catch (Throwable throwable) {
invocationInputSubscriber.onError(throwable);
throw throwable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,41 @@

public interface InvocationFlow {

interface InvocationInputPublisher extends Flow.Publisher<MessageLite> {}
interface InvocationInput {
MessageHeader header();

MessageLite message();

static InvocationInput of(MessageHeader header, MessageLite message) {
return new InvocationInput() {
@Override
public MessageHeader header() {
return header;
}

@Override
public MessageLite message() {
return message;
}

@Override
public String toString() {
return header.toString() + " " + message.toString();
}
};
}
}

interface InvocationInputPublisher extends Flow.Publisher<InvocationInput> {}

interface InvocationOutputPublisher extends Flow.Publisher<MessageLite> {}

interface InvocationInputSubscriber extends Flow.Subscriber<MessageLite> {}
interface InvocationInputSubscriber extends Flow.Subscriber<InvocationInput> {}

interface InvocationOutputSubscriber extends Flow.Subscriber<MessageLite> {}

interface InvocationProcessor
extends Flow.Processor<MessageLite, MessageLite>,
extends Flow.Processor<InvocationInput, MessageLite>,
InvocationInputSubscriber,
InvocationOutputPublisher {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ private enum State {
private ByteString instanceKey;
private ByteString invocationId;
private int entriesToReplay;
private LocalStateStorage localStateStorage;

// Index tracking progress in the journal
private int currentJournalIndex;
Expand Down Expand Up @@ -109,10 +110,12 @@ public void onSubscribe(Flow.Subscription subscription) {
}

@Override
public void onNext(MessageLite msg) {
public void onNext(InvocationFlow.InvocationInput invocationInput) {
MessageHeader header = invocationInput.header();
MessageLite msg = invocationInput.message();
LOG.trace("Received input message {} {}", msg.getClass(), msg);
if (this.state == State.WAITING_START) {
this.onStart(msg);
this.onStart(header, msg);
} else if (msg instanceof Protocol.CompletionMessage) {
// We check the instance rather than the state, because the user code might still be
// replaying, but the network layer is already past it and is receiving completions from the
Expand Down Expand Up @@ -151,7 +154,7 @@ void start(Consumer<InvocationId> afterStartCallback) {
this.inputSubscription.request(1);
}

void onStart(MessageLite msg) {
void onStart(MessageHeader header, MessageLite msg) {
if (!(msg instanceof Protocol.StartMessage)) {
this.fail(ProtocolException.unexpectedMessage(Protocol.StartMessage.class, msg));
return;
Expand All @@ -163,6 +166,16 @@ void onStart(MessageLite msg) {
this.invocationId = startMessage.getInvocationId();
this.entriesToReplay = startMessage.getKnownEntries();

// Set up the state cache
this.localStateStorage =
new LocalStateStorage(
header.hasFlag(MessageHeader.PARTIAL_STATE_FLAG),
startMessage.getStateMapList().stream()
.collect(
Collectors.toMap(
Protocol.StartMessage.StateEntry::getKey,
Protocol.StartMessage.StateEntry::getValue)));

if (this.span.isRecording()) {
span.addEvent(
"Start", Attributes.of(Tracing.RESTATE_INVOCATION_ID, this.invocationId.toStringUtf8()));
Expand Down Expand Up @@ -236,41 +249,64 @@ <E extends MessageLite, T> void processCompletableJournalEntry(
journalEntry.checkEntryHeader(expectedEntryMessage, actualEntryMessage);

if (journalEntry.hasResult((E) actualEntryMessage)) {
journalEntry.updateLocalStateStorage((E) actualEntryMessage, this.localStateStorage);
ReadyResultInternal<T> readyResultInternal =
journalEntry.parseEntryResult((E) actualEntryMessage);
callback.onSuccess(DeferredResults.completedSingle(entryIndex, readyResultInternal));
} else {
this.readyResultPublisher.offerCompletionParser(
entryIndex, journalEntry::parseCompletionResult);
entryIndex,
completionMessage -> {
journalEntry.updateLocalStateStorageWithCompletion(
(E) actualEntryMessage, completionMessage, this.localStateStorage);
return journalEntry.parseCompletionResult(completionMessage);
});
callback.onSuccess(DeferredResults.single(entryIndex));
}
},
callback::onCancel);
} else if (this.state == State.PROCESSING) {
// Try complete with local storage
E entryToWrite =
journalEntry.tryCompleteWithLocalStateStorage(expectedEntryMessage, localStateStorage);

if (span.isRecording()) {
journalEntry.trace(expectedEntryMessage, span);
journalEntry.trace(entryToWrite, span);
}

// Retrieve the index
int entryIndex = this.currentJournalIndex;

// Write out the input entry
this.writeEntry(expectedEntryMessage);

// Register the completion parser
this.readyResultPublisher.offerCompletionParser(
entryIndex, journalEntry::parseCompletionResult);
this.writeEntry(entryToWrite);

// Call the onSuccess
callback.onSuccess(DeferredResults.single(entryIndex));
if (journalEntry.hasResult(entryToWrite)) {
// Complete it with the result, as we already have it
callback.onSuccess(
DeferredResults.completedSingle(
entryIndex, journalEntry.parseEntryResult(entryToWrite)));
} else {
// Register the completion parser
this.readyResultPublisher.offerCompletionParser(
entryIndex,
completionMessage -> {
journalEntry.updateLocalStateStorageWithCompletion(
entryToWrite, completionMessage, this.localStateStorage);
return journalEntry.parseCompletionResult(completionMessage);
});

// Call the onSuccess
callback.onSuccess(DeferredResults.single(entryIndex));
}
} else {
throw new IllegalStateException(
"This method was invoked when the state machine is not ready to process user code. This is probably an SDK bug");
}
}

<T extends MessageLite> void processJournalEntryWithoutWaitingAck(
T expectedEntryMessage, JournalEntry<T> journalEntry, SyscallCallback<Void> callback) {
@SuppressWarnings("unchecked")
<E extends MessageLite> void processJournalEntry(
E expectedEntryMessage, JournalEntry<E> journalEntry, SyscallCallback<Void> callback) {
checkInsideSideEffectGuard();
if (this.state == State.CLOSED) {
callback.onCancel(SuspendedException.INSTANCE);
Expand All @@ -279,7 +315,7 @@ <T extends MessageLite> void processJournalEntryWithoutWaitingAck(
this.readEntry(
(entryIndex, actualEntryMessage) -> {
journalEntry.checkEntryHeader(expectedEntryMessage, actualEntryMessage);

journalEntry.updateLocalStateStorage((E) actualEntryMessage, this.localStateStorage);
callback.onSuccess(null);
},
callback::onCancel);
Expand All @@ -291,6 +327,9 @@ <T extends MessageLite> void processJournalEntryWithoutWaitingAck(
// Write new entry
this.writeEntry(expectedEntryMessage);

// Update local storage
journalEntry.updateLocalStateStorage(expectedEntryMessage, this.localStateStorage);

// Invoke the ok callback
callback.onSuccess(null);
} else {
Expand All @@ -299,7 +338,7 @@ <T extends MessageLite> void processJournalEntryWithoutWaitingAck(
}
}

void enterSideEffectJournalEntry(
void enterSideEffectBlock(
Consumer<Span> traceFn,
Consumer<Java.SideEffectEntryMessage> entryCallback,
Runnable noEntryCallback,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package dev.restate.sdk.core.impl;

import com.google.protobuf.ByteString;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

final class LocalStateStorage {

interface State {}

static final class Unknown implements State {
private static final Unknown INSTANCE = new Unknown();

private Unknown() {}
}

static final class Empty implements State {
private static final Empty INSTANCE = new Empty();

private Empty() {}
}

static final class Value implements State {
private final ByteString value;

private Value(ByteString value) {
this.value = value;
}

public ByteString getValue() {
return value;
}
}

private final boolean isPartial;
private final HashMap<ByteString, State> map;

LocalStateStorage(boolean isPartial, Map<ByteString, ByteString> map) {
this.isPartial = isPartial;
this.map =
new HashMap<>(
map.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> new Value(e.getValue()))));
}

public State get(ByteString key) {
return this.map.getOrDefault(key, isPartial ? Unknown.INSTANCE : Empty.INSTANCE);
}

public void set(ByteString key, ByteString value) {
this.map.put(key, new Value(value));
}

public void clear(ByteString key) {
this.map.put(key, Empty.INSTANCE);
}
}
Loading

0 comments on commit 5ab3461

Please sign in to comment.