Skip to content

Commit

Permalink
Add ctx.stateKeys() (#219)
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper authored Feb 9, 2024
1 parent dd79547 commit 8a74cae
Show file tree
Hide file tree
Showing 16 changed files with 232 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,25 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls)
return key.serde().deserializeWrappingException(syscalls, readyResult.value!!)!!
}

override suspend fun stateKeys(): Collection<String> {
val deferred: Deferred<Collection<String>> =
suspendCancellableCoroutine { cont: CancellableContinuation<Deferred<Collection<String>>> ->
syscalls.getKeys(completingContinuation(cont))
}

if (!deferred.isCompleted) {
suspendCancellableCoroutine { cont: CancellableContinuation<Unit> ->
syscalls.resolveDeferred(deferred, completingUnitContinuation(cont))
}
}

val readyResult = deferred.toResult()!!
if (!readyResult.isSuccess) {
throw readyResult.failure!!
}
return readyResult.value!!
}

override suspend fun <T : Any> set(key: StateKey<T>, value: T) {
val serializedValue = key.serde().serializeWrappingException(syscalls, value)!!
return suspendCancellableCoroutine { cont: CancellableContinuation<Unit> ->
Expand Down
7 changes: 7 additions & 0 deletions sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,13 @@ sealed interface KeyedContext : UnkeyedContext {
*/
suspend fun <T : Any> get(key: StateKey<T>): T?

/**
* Gets all the known state keys for this service instance.
*
* @return the immutable collection of known state keys.
*/
suspend fun stateKeys(): Collection<String>

/**
* Sets the given value under the given key, serializing the value using the registered
* [dev.restate.sdk.core.serde.Serde] in the interceptor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,14 @@ class EagerStateTest : EagerStateTestSuite() {
override fun getClearAllAndGet(): BindableService {
return GetClearAllAndGet()
}

private class ListKeys : GreeterRestateKt.GreeterRestateKtImplBase() {
override suspend fun greet(context: KeyedContext, request: GreetingRequest): GreetingResponse {
return greetingResponse { message = context.stateKeys().joinToString(separator = ",") }
}
}

override fun listKeys(): BindableService {
return ListKeys()
}
}
12 changes: 12 additions & 0 deletions sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dev.restate.sdk.common.syscalls.Syscalls;
import io.grpc.MethodDescriptor;
import java.time.Duration;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
Expand All @@ -43,6 +44,17 @@ public <T> Optional<T> get(StateKey<T> key) {
.map(bs -> Util.deserializeWrappingException(syscalls, key.serde(), bs));
}

@Override
public Collection<String> stateKeys() {
Deferred<Collection<String>> deferred = Util.blockOnSyscall(syscalls::getKeys);

if (!deferred.isCompleted()) {
Util.<Void>blockOnSyscall(cb -> syscalls.resolveDeferred(deferred, cb));
}

return Util.unwrapResult(deferred.toResult());
}

@Override
public void clear(StateKey<?> key) {
Util.<Void>blockOnSyscall(cb -> syscalls.clear(key.name(), cb));
Expand Down
8 changes: 8 additions & 0 deletions sdk-api/src/main/java/dev/restate/sdk/KeyedContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import dev.restate.sdk.common.*;
import dev.restate.sdk.common.syscalls.Syscalls;
import java.util.Collection;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;
Expand All @@ -34,6 +35,13 @@ public interface KeyedContext extends UnkeyedContext {
*/
<T> Optional<T> get(StateKey<T> key);

/**
* Gets all the known state keys for this service instance.
*
* @return the immutable collection of known state keys.
*/
Collection<String> stateKeys();

/**
* Clears the state stored under key.
*
Expand Down
17 changes: 17 additions & 0 deletions sdk-api/src/test/java/dev/restate/sdk/EagerStateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@

import dev.restate.sdk.common.CoreSerdes;
import dev.restate.sdk.common.StateKey;
import dev.restate.sdk.common.TerminalException;
import dev.restate.sdk.core.EagerStateTestSuite;
import dev.restate.sdk.core.testservices.GreeterGrpc;
import dev.restate.sdk.core.testservices.GreeterRestate;
import dev.restate.sdk.core.testservices.GreetingRequest;
import dev.restate.sdk.core.testservices.GreetingResponse;
import io.grpc.BindableService;
Expand Down Expand Up @@ -119,4 +121,19 @@ public void greet(GreetingRequest request, StreamObserver<GreetingResponse> resp
protected BindableService getClearAllAndGet() {
return new GetClearAllAndGet();
}

private static class ListKeys extends GreeterRestate.GreeterRestateImplBase {
@Override
public GreetingResponse greet(KeyedContext context, GreetingRequest request)
throws TerminalException {
return GreetingResponse.newBuilder()
.setMessage(String.join(",", context.stateKeys()))
.build();
}
}

@Override
protected BindableService listKeys() {
return new ListKeys();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import io.grpc.Context;
import io.grpc.MethodDescriptor;
import java.time.Duration;
import java.util.*;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -61,6 +62,8 @@ static Syscalls current() {

void get(String name, SyscallCallback<Deferred<ByteString>> callback);

void getKeys(SyscallCallback<Deferred<Collection<String>>> callback);

void clear(String name, SyscallCallback<Void> callback);

void clearAll(SyscallCallback<Void> callback);
Expand Down
75 changes: 75 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/Entries.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@

import com.google.protobuf.ByteString;
import com.google.protobuf.Empty;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageLite;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.generated.service.protocol.Protocol.*;
import dev.restate.sdk.common.syscalls.Result;
import io.opentelemetry.api.common.Attributes;
import io.opentelemetry.api.trace.Span;
import java.util.Collection;
import java.util.function.Function;
import java.util.stream.Collectors;

final class Entries {
static final String AWAKEABLE_IDENTIFIER_PREFIX = "prom_1";
Expand Down Expand Up @@ -183,6 +186,78 @@ void updateUserStateStorageWithCompletion(
}
}

static final class GetStateKeysEntry
extends CompletableJournalEntry<GetStateKeysEntryMessage, Collection<String>> {

static final GetStateKeysEntry INSTANCE = new GetStateKeysEntry();

private GetStateKeysEntry() {}

@Override
void trace(GetStateKeysEntryMessage expected, Span span) {
span.addEvent("GetStateKeys");
}

@Override
public boolean hasResult(GetStateKeysEntryMessage actual) {
return actual.getResultCase() != GetStateKeysEntryMessage.ResultCase.RESULT_NOT_SET;
}

@Override
void checkEntryHeader(GetStateKeysEntryMessage expected, MessageLite actual)
throws ProtocolException {
if (!(actual instanceof GetStateKeysEntryMessage)) {
throw ProtocolException.entryDoesNotMatch(expected, actual);
}
}

@Override
public Result<Collection<String>> parseEntryResult(GetStateKeysEntryMessage actual) {
if (actual.getResultCase() == GetStateKeysEntryMessage.ResultCase.VALUE) {
return Result.success(
actual.getValue().getKeysList().stream()
.map(ByteString::toStringUtf8)
.collect(Collectors.toUnmodifiableList()));
} else if (actual.getResultCase() == GetStateKeysEntryMessage.ResultCase.FAILURE) {
return Result.failure(Util.toRestateException(actual.getFailure()));
} else {
throw new IllegalStateException("GetStateKeysEntryMessage has not been completed.");
}
}

@Override
public Result<Collection<String>> parseCompletionResult(CompletionMessage actual) {
if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) {
GetStateKeysEntryMessage.StateKeys stateKeys;
try {
stateKeys = GetStateKeysEntryMessage.StateKeys.parseFrom(actual.getValue());
} catch (InvalidProtocolBufferException e) {
throw new ProtocolException(
"Cannot parse get state keys completion", e, ProtocolException.PROTOCOL_VIOLATION);
}
return Result.success(
stateKeys.getKeysList().stream()
.map(ByteString::toStringUtf8)
.collect(Collectors.toUnmodifiableList()));
} else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) {
return Result.failure(Util.toRestateException(actual.getFailure()));
}
return super.parseCompletionResult(actual);
}

@Override
GetStateKeysEntryMessage tryCompleteWithUserStateStorage(
GetStateKeysEntryMessage expected, UserStateStore userStateStore) {
if (userStateStore.isComplete()) {
return expected.toBuilder()
.setValue(
GetStateKeysEntryMessage.StateKeys.newBuilder().addAllKeys(userStateStore.keys()))
.build();
}
return expected;
}
}

static final class ClearStateEntry extends JournalEntry<ClearStateEntryMessage> {

static final ClearStateEntry INSTANCE = new ClearStateEntry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dev.restate.sdk.common.syscalls.SyscallCallback;
import io.grpc.MethodDescriptor;
import java.time.Duration;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.Executor;

Expand Down Expand Up @@ -50,6 +51,11 @@ public void get(String name, SyscallCallback<Deferred<ByteString>> callback) {
syscallsExecutor.execute(() -> syscalls.get(name, callback));
}

@Override
public void getKeys(SyscallCallback<Deferred<Collection<String>>> callback) {
syscallsExecutor.execute(() -> syscalls.getKeys(callback));
}

@Override
public void clear(String name, SyscallCallback<Void> callback) {
syscallsExecutor.execute(() -> syscalls.clear(name, callback));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ public static MessageHeader fromMessage(MessageLite msg) {
return new MessageHeader(MessageType.ClearStateEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.ClearAllStateEntryMessage) {
return new MessageHeader(MessageType.ClearAllStateEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.GetStateKeysEntryMessage) {
return new MessageHeader(
MessageType.GetStateKeysEntryMessage,
((Protocol.GetStateKeysEntryMessage) msg).getResultCase()
!= Protocol.GetStateKeysEntryMessage.ResultCase.RESULT_NOT_SET
? DONE_FLAG
: 0,
msg.getSerializedSize());
} else if (msg instanceof Protocol.SleepEntryMessage) {
return new MessageHeader(
MessageType.SleepEntryMessage,
Expand Down
8 changes: 8 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public enum MessageType {
SetStateEntryMessage,
ClearStateEntryMessage,
ClearAllStateEntryMessage,
GetStateKeysEntryMessage,

// Syscalls
SleepEntryMessage,
Expand All @@ -54,6 +55,7 @@ public enum MessageType {
public static final short SET_STATE_ENTRY_MESSAGE_TYPE = 0x0801;
public static final short CLEAR_STATE_ENTRY_MESSAGE_TYPE = 0x0802;
public static final short CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE = 0x0803;
public static final short GET_STATE_KEYS_ENTRY_MESSAGE_TYPE = 0x0804;
public static final short SLEEP_ENTRY_MESSAGE_TYPE = 0x0C00;
public static final short INVOKE_ENTRY_MESSAGE_TYPE = 0x0C01;
public static final short BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE = 0x0C02;
Expand Down Expand Up @@ -88,6 +90,8 @@ public Parser<? extends MessageLite> messageParser() {
return Protocol.ClearStateEntryMessage.parser();
case ClearAllStateEntryMessage:
return Protocol.ClearAllStateEntryMessage.parser();
case GetStateKeysEntryMessage:
return Protocol.GetStateKeysEntryMessage.parser();
case SleepEntryMessage:
return Protocol.SleepEntryMessage.parser();
case InvokeEntryMessage:
Expand Down Expand Up @@ -132,6 +136,8 @@ public short encode() {
return CLEAR_STATE_ENTRY_MESSAGE_TYPE;
case ClearAllStateEntryMessage:
return CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE;
case GetStateKeysEntryMessage:
return GET_STATE_KEYS_ENTRY_MESSAGE_TYPE;
case SleepEntryMessage:
return SLEEP_ENTRY_MESSAGE_TYPE;
case InvokeEntryMessage:
Expand Down Expand Up @@ -176,6 +182,8 @@ public static MessageType decode(short value) throws ProtocolException {
return ClearStateEntryMessage;
case CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE:
return ClearAllStateEntryMessage;
case GET_STATE_KEYS_ENTRY_MESSAGE_TYPE:
return GetStateKeysEntryMessage;
case SLEEP_ENTRY_MESSAGE_TYPE:
return SleepEntryMessage;
case INVOKE_ENTRY_MESSAGE_TYPE:
Expand Down
17 changes: 14 additions & 3 deletions sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
import java.nio.ByteBuffer;
import java.time.Duration;
import java.time.Instant;
import java.util.AbstractMap;
import java.util.Base64;
import java.util.Map;
import java.util.*;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -104,6 +102,19 @@ public void get(String name, SyscallCallback<Deferred<ByteString>> callback) {
callback);
}

@Override
public void getKeys(SyscallCallback<Deferred<Collection<String>>> callback) {
wrapAndPropagateExceptions(
() -> {
LOG.trace("get keys");
this.stateMachine.processCompletableJournalEntry(
Protocol.GetStateKeysEntryMessage.newBuilder().build(),
GetStateKeysEntry.INSTANCE,
callback);
},
callback);
}

@Override
public void clear(String name, SyscallCallback<Void> callback) {
wrapAndPropagateExceptions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.google.protobuf.ByteString;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

final class UserStateStore {
Expand Down Expand Up @@ -68,4 +69,12 @@ public void clearAll() {
this.map.clear();
this.isPartial = false;
}

public boolean isComplete() {
return !isPartial;
}

public Set<ByteString> keys() {
return this.map.keySet();
}
}
3 changes: 2 additions & 1 deletion sdk-core/src/main/java/dev/restate/sdk/core/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static boolean isTerminalException(Throwable throwable) {

static void assertIsEntry(MessageLite msg) {
if (!isEntry(msg)) {
throw new IllegalStateException("Expected input to be entry");
throw new IllegalStateException("Expected input to be entry: " + msg);
}
}

Expand All @@ -110,6 +110,7 @@ static boolean isEntry(MessageLite msg) {
return msg instanceof Protocol.PollInputStreamEntryMessage
|| msg instanceof Protocol.OutputStreamEntryMessage
|| msg instanceof Protocol.GetStateEntryMessage
|| msg instanceof Protocol.GetStateKeysEntryMessage
|| msg instanceof Protocol.SetStateEntryMessage
|| msg instanceof Protocol.ClearStateEntryMessage
|| msg instanceof Protocol.ClearAllStateEntryMessage
Expand Down
Loading

0 comments on commit 8a74cae

Please sign in to comment.