From 64b4d7ae8ba4c9ceaf1b3afbc319e6b53d781d28 Mon Sep 17 00:00:00 2001 From: Francesco Guardiani Date: Mon, 5 Feb 2024 18:17:29 +0100 Subject: [PATCH] Split RestateContext interface in KeyedContext/UnkeyedContext (#213) * [java] Split RestateContext interface in KeyedContext/UnkeyedContext * [kotlin] Split RestateContext interface in KeyedContext/UnkeyedContext * Make sure the protoc code generator injects the correct interface * Update tests and examples --- .../dev/restate/sdk/examples/Counter.java | 16 +-- .../sdk/examples/VanillaGrpcCounter.java | 10 +- .../dev/restate/sdk/examples/CounterKt.kt | 12 +- .../java/dev/restate/sdk/gen/RestateGen.java | 8 ++ .../src/main/resources/javaStub.mustache | 42 +++--- .../src/main/resources/ktStub.mustache | 24 ++-- .../{RestateContextImpl.kt => ContextImpl.kt} | 3 +- .../main/kotlin/dev/restate/sdk/kotlin/api.kt | 120 +++++++++++------- .../dev/restate/sdk/kotlin/AwaitableTest.kt | 16 +-- .../dev/restate/sdk/kotlin/AwakeableIdTest.kt | 2 +- .../dev/restate/sdk/kotlin/EagerStateTest.kt | 8 +- .../dev/restate/sdk/kotlin/RandomTest.kt | 4 +- .../restate/sdk/kotlin/RestateCodegenTest.kt | 15 +-- .../dev/restate/sdk/kotlin/SideEffectTest.kt | 8 +- .../dev/restate/sdk/kotlin/SleepTest.kt | 4 +- .../sdk/kotlin/StateMachineFailuresTest.kt | 4 +- .../dev/restate/sdk/kotlin/StateTest.kt | 4 +- .../restate/sdk/kotlin/UserFailuresTest.kt | 4 +- .../main/java/dev/restate/sdk/Awakeable.java | 2 +- ...stateContextImpl.java => ContextImpl.java} | 4 +- .../dev/restate/sdk/GrpcChannelAdapter.java | 8 +- .../java/dev/restate/sdk/KeyedContext.java | 66 ++++++++++ .../java/dev/restate/sdk/RestateRandom.java | 4 +- .../java/dev/restate/sdk/RestateService.java | 13 +- ...estateContext.java => UnkeyedContext.java} | 90 +++++-------- .../java/dev/restate/sdk/AwaitableTest.java | 16 +-- .../java/dev/restate/sdk/AwakeableIdTest.java | 2 +- .../java/dev/restate/sdk/EagerStateTest.java | 8 +- .../restate/sdk/GrpcChannelAdapterTest.java | 4 +- .../test/java/dev/restate/sdk/RandomTest.java | 4 +- .../dev/restate/sdk/RestateCodegenTest.java | 12 +- .../java/dev/restate/sdk/SideEffectTest.java | 8 +- .../test/java/dev/restate/sdk/SleepTest.java | 4 +- .../restate/sdk/StateMachineFailuresTest.java | 4 +- .../test/java/dev/restate/sdk/StateTest.java | 8 +- .../dev/restate/sdk/UserFailuresTest.java | 4 +- .../testservices/BlockingGreeterService.java | 7 +- .../restate/sdk/http/vertx/HttpVertxTests.kt | 5 +- .../vertx/testservices/GreeterKtService.kt | 7 +- .../testservices/JavaCounterService.java | 3 +- .../testservices/KotlinCounterService.kt | 3 +- .../java/dev/restate/sdk/testing/Counter.java | 10 +- 42 files changed, 329 insertions(+), 271 deletions(-) rename sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/{RestateContextImpl.kt => ContextImpl.kt} (98%) rename sdk-api/src/main/java/dev/restate/sdk/{RestateContextImpl.java => ContextImpl.java} (98%) create mode 100644 sdk-api/src/main/java/dev/restate/sdk/KeyedContext.java rename sdk-api/src/main/java/dev/restate/sdk/{RestateContext.java => UnkeyedContext.java} (80%) diff --git a/examples/src/main/java/dev/restate/sdk/examples/Counter.java b/examples/src/main/java/dev/restate/sdk/examples/Counter.java index bffd31d3..9d4a112b 100644 --- a/examples/src/main/java/dev/restate/sdk/examples/Counter.java +++ b/examples/src/main/java/dev/restate/sdk/examples/Counter.java @@ -8,7 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.examples; -import dev.restate.sdk.RestateContext; +import dev.restate.sdk.KeyedContext; import dev.restate.sdk.common.CoreSerdes; import dev.restate.sdk.common.StateKey; import dev.restate.sdk.examples.generated.*; @@ -23,30 +23,28 @@ public class Counter extends CounterRestate.CounterRestateImplBase { private static final StateKey TOTAL = StateKey.of("total", CoreSerdes.JSON_LONG); @Override - public void reset(RestateContext ctx, CounterRequest request) { - restateContext().clear(TOTAL); + public void reset(KeyedContext ctx, CounterRequest request) { + ctx.clear(TOTAL); } @Override - public void add(RestateContext ctx, CounterAddRequest request) { + public void add(KeyedContext ctx, CounterAddRequest request) { long currentValue = ctx.get(TOTAL).orElse(0L); long newValue = currentValue + request.getValue(); ctx.set(TOTAL, newValue); } @Override - public GetResponse get(RestateContext context, CounterRequest request) { - long currentValue = restateContext().get(TOTAL).orElse(0L); + public GetResponse get(KeyedContext ctx, CounterRequest request) { + long currentValue = ctx.get(TOTAL).orElse(0L); return GetResponse.newBuilder().setValue(currentValue).build(); } @Override - public CounterUpdateResult getAndAdd(RestateContext context, CounterAddRequest request) { + public CounterUpdateResult getAndAdd(KeyedContext ctx, CounterAddRequest request) { LOG.info("Invoked get and add with " + request.getValue()); - RestateContext ctx = restateContext(); - long currentValue = ctx.get(TOTAL).orElse(0L); long newValue = currentValue + request.getValue(); ctx.set(TOTAL, newValue); diff --git a/examples/src/main/java/dev/restate/sdk/examples/VanillaGrpcCounter.java b/examples/src/main/java/dev/restate/sdk/examples/VanillaGrpcCounter.java index f429dbb9..fc8611ab 100644 --- a/examples/src/main/java/dev/restate/sdk/examples/VanillaGrpcCounter.java +++ b/examples/src/main/java/dev/restate/sdk/examples/VanillaGrpcCounter.java @@ -9,7 +9,7 @@ package dev.restate.sdk.examples; import com.google.protobuf.Empty; -import dev.restate.sdk.RestateContext; +import dev.restate.sdk.KeyedContext; import dev.restate.sdk.RestateService; import dev.restate.sdk.common.CoreSerdes; import dev.restate.sdk.common.StateKey; @@ -27,7 +27,7 @@ public class VanillaGrpcCounter extends CounterGrpc.CounterImplBase implements R @Override public void reset(CounterRequest request, StreamObserver responseObserver) { - restateContext().clear(TOTAL); + KeyedContext.current().clear(TOTAL); responseObserver.onNext(Empty.getDefaultInstance()); responseObserver.onCompleted(); @@ -35,7 +35,7 @@ public void reset(CounterRequest request, StreamObserver responseObserver @Override public void add(CounterAddRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); long currentValue = ctx.get(TOTAL).orElse(0L); long newValue = currentValue + request.getValue(); @@ -47,7 +47,7 @@ public void add(CounterAddRequest request, StreamObserver responseObserve @Override public void get(CounterRequest request, StreamObserver responseObserver) { - long currentValue = restateContext().get(TOTAL).orElse(0L); + long currentValue = KeyedContext.current().get(TOTAL).orElse(0L); responseObserver.onNext(GetResponse.newBuilder().setValue(currentValue).build()); responseObserver.onCompleted(); @@ -58,7 +58,7 @@ public void getAndAdd( CounterAddRequest request, StreamObserver responseObserver) { LOG.info("Invoked get and add with " + request.getValue()); - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); long currentValue = ctx.get(TOTAL).orElse(0L); long newValue = currentValue + request.getValue(); diff --git a/examples/src/main/kotlin/dev/restate/sdk/examples/CounterKt.kt b/examples/src/main/kotlin/dev/restate/sdk/examples/CounterKt.kt index d189d1c6..9bdf3b67 100644 --- a/examples/src/main/kotlin/dev/restate/sdk/examples/CounterKt.kt +++ b/examples/src/main/kotlin/dev/restate/sdk/examples/CounterKt.kt @@ -12,7 +12,7 @@ import dev.restate.sdk.common.CoreSerdes import dev.restate.sdk.common.StateKey import dev.restate.sdk.examples.generated.* import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder -import dev.restate.sdk.kotlin.RestateContext +import dev.restate.sdk.kotlin.KeyedContext import org.apache.logging.log4j.LogManager class CounterKt : CounterRestateKt.CounterRestateKtImplBase() { @@ -21,20 +21,20 @@ class CounterKt : CounterRestateKt.CounterRestateKtImplBase() { private val TOTAL = StateKey.of("total", CoreSerdes.JSON_LONG) - override suspend fun reset(context: RestateContext, request: CounterRequest) { + override suspend fun reset(context: KeyedContext, request: CounterRequest) { context.clear(TOTAL) } - override suspend fun add(context: RestateContext, request: CounterAddRequest) { + override suspend fun add(context: KeyedContext, request: CounterAddRequest) { updateCounter(context, request.value) } - override suspend fun get(context: RestateContext, request: CounterRequest): GetResponse { + override suspend fun get(context: KeyedContext, request: CounterRequest): GetResponse { return getResponse { value = context.get(TOTAL) ?: 0L } } override suspend fun getAndAdd( - context: RestateContext, + context: KeyedContext, request: CounterAddRequest ): CounterUpdateResult { LOG.info("Invoked get and add with " + request.value) @@ -45,7 +45,7 @@ class CounterKt : CounterRestateKt.CounterRestateKtImplBase() { } } - private suspend fun updateCounter(context: RestateContext, add: Long): Pair { + private suspend fun updateCounter(context: KeyedContext, add: Long): Pair { val currentValue = context.get(TOTAL) ?: 0L val newValue = currentValue + add diff --git a/protoc-gen-restate/src/main/java/dev/restate/sdk/gen/RestateGen.java b/protoc-gen-restate/src/main/java/dev/restate/sdk/gen/RestateGen.java index 6457ded8..b76d2919 100644 --- a/protoc-gen-restate/src/main/java/dev/restate/sdk/gen/RestateGen.java +++ b/protoc-gen-restate/src/main/java/dev/restate/sdk/gen/RestateGen.java @@ -15,6 +15,7 @@ import com.salesforce.jprotoc.ProtoTypeMap; import com.salesforce.jprotoc.ProtocPlugin; import dev.restate.generated.ext.Ext; +import dev.restate.generated.ext.ServiceType; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -120,6 +121,12 @@ private ServiceContext buildServiceContext( serviceContext.serviceName = serviceProto.getName(); serviceContext.deprecated = serviceProto.getOptions().getDeprecated(); + // Resolve context type + serviceContext.contextType = + serviceProto.getOptions().getExtension(Ext.serviceType) == ServiceType.UNKEYED + ? "UnkeyedContext" + : "KeyedContext"; + // Resolve javadoc DescriptorProtos.SourceCodeInfo.Location serviceLocation = locations.stream() @@ -215,6 +222,7 @@ private static class ServiceContext { public String packageName; public String className; public String serviceName; + public String contextType; public String apidoc; public boolean deprecated; public final List methods = new ArrayList<>(); diff --git a/protoc-gen-restate/src/main/resources/javaStub.mustache b/protoc-gen-restate/src/main/resources/javaStub.mustache index d72ea327..f5ca127e 100644 --- a/protoc-gen-restate/src/main/resources/javaStub.mustache +++ b/protoc-gen-restate/src/main/resources/javaStub.mustache @@ -2,7 +2,8 @@ package {{packageName}}; {{/packageName}} -import dev.restate.sdk.RestateContext; +import dev.restate.sdk.UnkeyedContext; +import dev.restate.sdk.KeyedContext; import dev.restate.sdk.Awaitable; import dev.restate.sdk.common.syscalls.Syscalls; import java.time.Duration; @@ -15,24 +16,17 @@ public class {{className}} { private {{className}}() {} /** - * Create a new client. - */ - public static {{serviceName}}RestateClient newClient() { - return newClient(RestateContext.fromSyscalls(Syscalls.current())); - } - - /** - * Create a new client from the given {@link RestateContext}. + * Create a new client from the given {@link KeyedContext}. */ - public static {{serviceName}}RestateClient newClient(RestateContext ctx) { + public static {{serviceName}}RestateClient newClient(UnkeyedContext ctx) { return new {{serviceName}}RestateClient(ctx); } {{{apidoc}}} public static final class {{serviceName}}RestateClient { - private final RestateContext ctx; + private final UnkeyedContext ctx; - {{serviceName}}RestateClient(RestateContext ctx) { + {{serviceName}}RestateClient(UnkeyedContext ctx) { this.ctx = ctx; } @@ -57,9 +51,9 @@ public class {{className}} { } public static final class {{serviceName}}RestateOneWayClient { - private final RestateContext ctx; + private final UnkeyedContext ctx; - {{serviceName}}RestateOneWayClient(RestateContext ctx) { + {{serviceName}}RestateOneWayClient(UnkeyedContext ctx) { this.ctx = ctx; } @@ -74,10 +68,10 @@ public class {{className}} { } public static final class {{serviceName}}RestateDelayedClient { - private final RestateContext ctx; + private final UnkeyedContext ctx; private final Duration delay; - {{serviceName}}RestateDelayedClient(RestateContext ctx, Duration delay) { + {{serviceName}}RestateDelayedClient(UnkeyedContext ctx, Duration delay) { this.ctx = ctx; this.delay = delay; } @@ -100,7 +94,7 @@ public class {{className}} { @java.lang.Deprecated {{/deprecated}} {{{apidoc}}} - public {{#isOutputEmpty}}void{{/isOutputEmpty}}{{^isOutputEmpty}}{{outputType}}{{/isOutputEmpty}} {{methodName}}(RestateContext context{{^isInputEmpty}}, {{inputType}} request{{/isInputEmpty}}) throws dev.restate.sdk.common.TerminalException { + public {{#isOutputEmpty}}void{{/isOutputEmpty}}{{^isOutputEmpty}}{{outputType}}{{/isOutputEmpty}} {{methodName}}({{contextType}} context{{^isInputEmpty}}, {{inputType}} request{{/isInputEmpty}}) throws dev.restate.sdk.common.TerminalException { throw new dev.restate.sdk.common.TerminalException(dev.restate.sdk.common.TerminalException.Code.UNIMPLEMENTED); } @@ -120,34 +114,34 @@ public class {{className}} { private static final class HandlerAdapter implements io.grpc.stub.ServerCalls.UnaryMethod { - private final java.util.function.BiFunction handler; + private final java.util.function.BiFunction handler; - private HandlerAdapter(java.util.function.BiFunction handler) { + private HandlerAdapter(java.util.function.BiFunction handler) { this.handler = handler; } @Override public void invoke(Req request, io.grpc.stub.StreamObserver responseObserver) { - responseObserver.onNext(handler.apply(RestateContext.fromSyscalls(Syscalls.current()), request)); + responseObserver.onNext(handler.apply(KeyedContext.fromSyscalls(Syscalls.current()), request)); responseObserver.onCompleted(); } - private static HandlerAdapter of(java.util.function.BiFunction handler) { + private static HandlerAdapter of(java.util.function.BiFunction handler) { return new HandlerAdapter<>(handler); } - private static HandlerAdapter of(java.util.function.Function handler) { + private static HandlerAdapter of(java.util.function.Function handler) { return new HandlerAdapter<>((ctx, e) -> handler.apply(ctx)); } - private static HandlerAdapter of(java.util.function.BiConsumer handler) { + private static HandlerAdapter of(java.util.function.BiConsumer handler) { return new HandlerAdapter<>((ctx, req) -> { handler.accept(ctx, req); return com.google.protobuf.Empty.getDefaultInstance(); }); } - private static HandlerAdapter of(java.util.function.Consumer handler) { + private static HandlerAdapter of(java.util.function.Consumer handler) { return new HandlerAdapter<>((ctx, req) -> { handler.accept(ctx); return com.google.protobuf.Empty.getDefaultInstance(); diff --git a/protoc-gen-restate/src/main/resources/ktStub.mustache b/protoc-gen-restate/src/main/resources/ktStub.mustache index 5ffb27e2..b475fa05 100644 --- a/protoc-gen-restate/src/main/resources/ktStub.mustache +++ b/protoc-gen-restate/src/main/resources/ktStub.mustache @@ -2,14 +2,14 @@ package {{packageName}}; {{/packageName}} -import dev.restate.sdk.kotlin.RestateContext; +import dev.restate.sdk.kotlin.UnkeyedContext; +import dev.restate.sdk.kotlin.KeyedContext; import dev.restate.sdk.kotlin.Awaitable; import dev.restate.sdk.kotlin.RestateKtService; import dev.restate.sdk.common.syscalls.Syscalls; import io.grpc.kotlin.ClientCalls.unaryRpc import io.grpc.kotlin.ServerCalls.unaryServerMethodDefinition import {{packageName}}.{{serviceName}}Grpc.getServiceDescriptor; -import dev.restate.sdk.kotlin.restateContextFromSyscalls; import kotlin.time.Duration {{#deprecated}} @@ -18,14 +18,14 @@ import kotlin.time.Duration public object {{className}} { /** - * Create a new client from the given [RestateContext]. + * Create a new client from the given [UnkeyedContext]. */ - fun newClient(ctx: RestateContext): {{serviceName}}RestateKtClient { + fun newClient(ctx: UnkeyedContext): {{serviceName}}RestateKtClient { return {{serviceName}}RestateKtClient(ctx); } {{{javadoc}}} - public class {{serviceName}}RestateKtClient(private val ctx: RestateContext) { + public class {{serviceName}}RestateKtClient(private val ctx: UnkeyedContext) { // Create a variant of this client to execute oneWay calls. public fun oneWay(): {{serviceName}}RestateKtOneWayClient { return {{serviceName}}RestateKtOneWayClient(ctx); @@ -46,7 +46,7 @@ public object {{className}} { {{/methods}} } - public class {{serviceName}}RestateKtOneWayClient(private val ctx: RestateContext) { + public class {{serviceName}}RestateKtOneWayClient(private val ctx: UnkeyedContext) { {{#methods}} {{#deprecated}}@Deprecated{{/deprecated}} {{{javadoc}}} @@ -57,7 +57,7 @@ public object {{className}} { {{/methods}} } - public class {{serviceName}}RestateKtDelayedClient(private val ctx: RestateContext, private val delay: Duration) { + public class {{serviceName}}RestateKtDelayedClient(private val ctx: UnkeyedContext, private val delay: Duration) { {{#methods}} {{#deprecated}}@Deprecated{{/deprecated}} {{{javadoc}}} @@ -78,7 +78,7 @@ public object {{className}} { @Deprecated {{/deprecated}} {{{javadoc}}} - public open suspend fun {{methodName}}(context: RestateContext{{^isInputEmpty}}, request: {{inputType}} {{/isInputEmpty}}){{^isOutputEmpty}}: {{outputType}}{{/isOutputEmpty}} { + public open suspend fun {{methodName}}(context: {{contextType}}{{^isInputEmpty}}, request: {{inputType}} {{/isInputEmpty}}){{^isOutputEmpty}}: {{outputType}}{{/isOutputEmpty}} { throw dev.restate.sdk.common.TerminalException(dev.restate.sdk.common.TerminalException.Code.UNIMPLEMENTED); } @@ -91,18 +91,18 @@ public object {{className}} { descriptor = {{packageName}}.{{serviceName}}Grpc.{{methodDescriptorGetter}}(), implementation = { {{#isInputEmpty}}{{#isOutputEmpty}} - {{methodName}}(restateContextFromSyscalls(Syscalls.current())) + {{methodName}}(KeyedContext.fromSyscalls(Syscalls.current())) return@unaryServerMethodDefinition com.google.protobuf.Empty.getDefaultInstance() {{/isOutputEmpty}}{{/isInputEmpty}} {{#isInputEmpty}}{{^isOutputEmpty}} - return@unaryServerMethodDefinition {{methodName}}(restateContextFromSyscalls(Syscalls.current())) + return@unaryServerMethodDefinition {{methodName}}(KeyedContext.fromSyscalls(Syscalls.current())) {{/isOutputEmpty}}{{/isInputEmpty}} {{^isInputEmpty}}{{#isOutputEmpty}} - {{methodName}}(restateContextFromSyscalls(Syscalls.current()), it) + {{methodName}}(KeyedContext.fromSyscalls(Syscalls.current()), it) return@unaryServerMethodDefinition com.google.protobuf.Empty.getDefaultInstance() {{/isOutputEmpty}}{{/isInputEmpty}} {{^isInputEmpty}}{{^isOutputEmpty}} - return@unaryServerMethodDefinition {{methodName}}(restateContextFromSyscalls(Syscalls.current()), it) + return@unaryServerMethodDefinition {{methodName}}(KeyedContext.fromSyscalls(Syscalls.current()), it) {{/isOutputEmpty}}{{/isInputEmpty}} } )) diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/RestateContextImpl.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt similarity index 98% rename from sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/RestateContextImpl.kt rename to sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt index 2c353df4..8ec20672 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/RestateContextImpl.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/ContextImpl.kt @@ -24,8 +24,7 @@ import kotlin.time.Duration import kotlin.time.toJavaDuration import kotlinx.coroutines.* -internal class RestateContextImpl internal constructor(private val syscalls: Syscalls) : - RestateContext { +internal class ContextImpl internal constructor(private val syscalls: Syscalls) : KeyedContext { override suspend fun get(key: StateKey): T? { val deferred: Deferred = suspendCancellableCoroutine { cont: CancellableContinuation> -> diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt index bb5e9a8e..e03812d9 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt @@ -19,9 +19,9 @@ import kotlin.random.Random import kotlin.time.Duration /** - * This interface exposes the Restate functionalities to Restate services. It can be used to access - * the service instance key-value state storage, interact with other Restate services, record side - * effects, execute timers and synchronize with external systems. + * This interface exposes the Restate functionalities to Restate services. It can be used to + * interact with other Restate services, record side effects, execute timers and synchronize with + * external systems. * * To use it within your Restate service, implement [RestateKtService] and get an instance with * [RestateKtService.restateContext]. @@ -32,33 +32,7 @@ import kotlin.time.Duration * NOTE: This interface MUST NOT be accessed concurrently since it can lead to different orderings * of user actions, corrupting the execution of the invocation. */ -sealed interface RestateContext { - - /** - * Gets the state stored under key, deserializing the raw value using the registered - * [dev.restate.sdk.core.serde.Serde] in the interceptor. - * - * @param key identifying the state to get and its type. - * @return the value containing the stored state deserialized. - * @throws RuntimeException when the state cannot be deserialized. - */ - suspend fun get(key: StateKey): T? - - /** - * Sets the given value under the given key, serializing the value using the registered - * [dev.restate.sdk.core.serde.Serde] in the interceptor. - * - * @param key identifying the value to store and its type. - * @param value to store under the given key. - */ - suspend fun set(key: StateKey, value: T) - - /** - * Clears the state stored under key. - * - * @param key identifying the state to clear. - */ - suspend fun clear(key: StateKey<*>) +sealed interface UnkeyedContext { /** * Causes the current execution of the function invocation to sleep for the given duration. @@ -213,6 +187,73 @@ sealed interface RestateContext { * @return the [Random] instance. */ fun random(): RestateRandom + + companion object { + + /** + * Create a [UnkeyedContext]. This will look up the thread-local/async-context storage for the + * underlying context implementation, so make sure to call it always from the same context where + * the service is executed. + */ + fun current(): UnkeyedContext { + return fromSyscalls(Syscalls.current()) + } + + /** Build a context from the underlying [Syscalls] object. */ + fun fromSyscalls(syscalls: Syscalls): UnkeyedContext { + return ContextImpl(syscalls) + } + } +} + +/** + * This interface extends [UnkeyedContext] adding access to the service instance key-value state + * storage. + */ +sealed interface KeyedContext : UnkeyedContext { + + /** + * Gets the state stored under key, deserializing the raw value using the registered + * [dev.restate.sdk.core.serde.Serde] in the interceptor. + * + * @param key identifying the state to get and its type. + * @return the value containing the stored state deserialized. + * @throws RuntimeException when the state cannot be deserialized. + */ + suspend fun get(key: StateKey): T? + + /** + * Sets the given value under the given key, serializing the value using the registered + * [dev.restate.sdk.core.serde.Serde] in the interceptor. + * + * @param key identifying the value to store and its type. + * @param value to store under the given key. + */ + suspend fun set(key: StateKey, value: T) + + /** + * Clears the state stored under key. + * + * @param key identifying the state to clear. + */ + suspend fun clear(key: StateKey<*>) + + companion object { + + /** + * Create a [KeyedContext]. This will look up the thread-local/async-context storage for the + * underlying context implementation, so make sure to call it always from the same context where + * the service is executed. + */ + fun current(): KeyedContext { + return fromSyscalls(Syscalls.current()) + } + + /** Build a context from the underlying [Syscalls] object. */ + fun fromSyscalls(syscalls: Syscalls): KeyedContext { + return ContextImpl(syscalls) + } + } } class RestateRandom(seed: Long, private val syscalls: Syscalls) : Random() { @@ -357,7 +398,7 @@ sealed interface AwakeableHandle { } /** - * Marker interface for Restate services implemented using the [RestateContext] interface. + * Marker interface for Restate services. * * ## Error handling * @@ -367,19 +408,4 @@ sealed interface AwakeableHandle { * * When throwing any other type of exception, the failure is considered "non-terminal" and the * runtime will retry it, according to its configuration */ -interface RestateKtService : NonBlockingService { - /** @return an instance of the [RestateContext]. */ - fun restateContext(): RestateContext { - return RestateContextImpl(Syscalls.SYSCALLS_KEY.get()) - } -} - -/** - * Build a RestateContext from the [Syscalls] object. - * - * This method is used by code-generation, you should not use it directly but rather use - * [RestateKtService.restateContext]. - */ -fun restateContextFromSyscalls(syscalls: Syscalls): RestateContext { - return RestateContextImpl(syscalls) -} +interface RestateKtService : NonBlockingService diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwaitableTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwaitableTest.kt index 0eb51d5b..760e4f94 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwaitableTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwaitableTest.kt @@ -21,7 +21,7 @@ class AwaitableTest : DeferredTestSuite() { private class ReverseAwaitOrder : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val a1 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Francesco" }) val a2 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Till" }) val a2Res = a2.await().getMessage() @@ -38,7 +38,7 @@ class AwaitableTest : DeferredTestSuite() { private class AwaitTwiceTheSameAwaitable : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val a = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Francesco" }) return greetingResponse { message = a.await().getMessage() + "-" + a.await().getMessage() } } @@ -51,7 +51,7 @@ class AwaitableTest : DeferredTestSuite() { private class AwaitAll : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val a1 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Francesco" }) val a2 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Till" }) @@ -71,7 +71,7 @@ class AwaitableTest : DeferredTestSuite() { private class AwaitAny : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val a1 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Francesco" }) val a2 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Till" }) return Awaitable.any(a1, a2).await() as GreetingResponse @@ -81,7 +81,7 @@ class AwaitableTest : DeferredTestSuite() { private class AwaitSelect : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val a1 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Francesco" }) val a2 = ctx.callAsync(GreeterGrpcKt.greetMethod, greetingRequest { name = "Till" }) return select { @@ -98,7 +98,7 @@ class AwaitableTest : DeferredTestSuite() { private class CombineAnyWithAll : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val a1 = ctx.awakeable(CoreSerdes.JSON_STRING) val a2 = ctx.awakeable(CoreSerdes.JSON_STRING) val a3 = ctx.awakeable(CoreSerdes.JSON_STRING) @@ -121,7 +121,7 @@ class AwaitableTest : DeferredTestSuite() { private class AwaitAnyIndex : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val a1 = ctx.awakeable(CoreSerdes.JSON_STRING) val a2 = ctx.awakeable(CoreSerdes.JSON_STRING) val a3 = ctx.awakeable(CoreSerdes.JSON_STRING) @@ -140,7 +140,7 @@ class AwaitableTest : DeferredTestSuite() { private class AwaitOnAlreadyResolvedAwaitables : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val a1 = ctx.awakeable(CoreSerdes.JSON_STRING) val a2 = ctx.awakeable(CoreSerdes.JSON_STRING) val a12 = Awaitable.all(a1, a2) diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt index 25c6fea0..16659768 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/AwakeableIdTest.kt @@ -22,7 +22,7 @@ class AwakeableIdTest : AwakeableIdTestSuite() { GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val id: String = restateContext().awakeable(CoreSerdes.JSON_STRING).id + val id: String = KeyedContext.current().awakeable(CoreSerdes.JSON_STRING).id return greetingResponse { message = id } } } diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt index 64c80116..b2d9454c 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/EagerStateTest.kt @@ -23,7 +23,7 @@ class EagerStateTest : EagerStateTestSuite() { private class GetEmpty : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val stateIsEmpty = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING)) == null return greetingResponse { message = stateIsEmpty.toString() } } @@ -37,7 +37,7 @@ class EagerStateTest : EagerStateTestSuite() { GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { return greetingResponse { - message = restateContext().get(StateKey.of("STATE", CoreSerdes.JSON_STRING))!! + message = KeyedContext.current().get(StateKey.of("STATE", CoreSerdes.JSON_STRING))!! } } } @@ -49,7 +49,7 @@ class EagerStateTest : EagerStateTestSuite() { private class GetAppendAndGet : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val oldState = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING))!! ctx.set(StateKey.of("STATE", CoreSerdes.JSON_STRING), oldState + request.getName()) val newState = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING))!! @@ -64,7 +64,7 @@ class EagerStateTest : EagerStateTestSuite() { private class GetClearAndGet : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val oldState = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING))!! ctx.clear(StateKey.of("STATE", CoreSerdes.JSON_STRING)) assertThat(ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING))).isNull() diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RandomTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RandomTest.kt index 106c8c62..44220a92 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RandomTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RandomTest.kt @@ -22,7 +22,7 @@ class RandomTest : RandomTestSuite() { GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val number = restateContext().random().nextInt() + val number = KeyedContext.current().random().nextInt() return greetingResponse { message = number.toString() } } } @@ -34,7 +34,7 @@ class RandomTest : RandomTestSuite() { private class RandomInsideSideEffect : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() ctx.sideEffect { ctx.random().nextInt() } throw IllegalStateException("This should not unreachable") } diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RestateCodegenTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RestateCodegenTest.kt index cdaec2bc..0d2b8793 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RestateCodegenTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/RestateCodegenTest.kt @@ -17,10 +17,7 @@ import kotlin.time.Duration.Companion.seconds class RestateCodegenTest : RestateCodegenTestSuite() { private class GreeterWithRestateClientAndServerCodegen : GreeterRestateKtImplBase() { - override suspend fun greet( - context: RestateContext, - request: GreetingRequest - ): GreetingResponse { + override suspend fun greet(context: KeyedContext, request: GreetingRequest): GreetingResponse { val client = GreeterRestateKt.newClient(context) client.delayed(1.seconds).greet(request) client.oneWay().greet(request) @@ -33,27 +30,27 @@ class RestateCodegenTest : RestateCodegenTestSuite() { } private class Codegen : CodegenRestateKtImplBase() { - override suspend fun emptyInput(context: RestateContext): MyMessage { + override suspend fun emptyInput(context: UnkeyedContext): MyMessage { val client = CodegenRestateKt.newClient(context) return client.emptyInput().await() } - override suspend fun emptyOutput(context: RestateContext, request: MyMessage) { + override suspend fun emptyOutput(context: UnkeyedContext, request: MyMessage) { val client = CodegenRestateKt.newClient(context) client.emptyOutput(request).await() } - override suspend fun emptyInputOutput(context: RestateContext) { + override suspend fun emptyInputOutput(context: UnkeyedContext) { val client = CodegenRestateKt.newClient(context) client.emptyInputOutput().await() } - override suspend fun oneWay(context: RestateContext, request: MyMessage): MyMessage { + override suspend fun oneWay(context: UnkeyedContext, request: MyMessage): MyMessage { val client = CodegenRestateKt.newClient(context) return client._oneWay(request).await() } - override suspend fun delayed(context: RestateContext, request: MyMessage): MyMessage { + override suspend fun delayed(context: UnkeyedContext, request: MyMessage): MyMessage { val client = CodegenRestateKt.newClient(context) return client._delayed(request).await() } diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt index 67bf2f9a..ecc1cacf 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SideEffectTest.kt @@ -20,7 +20,7 @@ class SideEffectTest : SideEffectTestSuite() { private class SideEffect(private val sideEffectOutput: String) : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx: RestateContext = restateContext() + val ctx = KeyedContext.current() val result = ctx.sideEffect(CoreSerdes.JSON_STRING) { sideEffectOutput } return greetingResponse { message = "Hello $result" } } @@ -33,7 +33,7 @@ class SideEffectTest : SideEffectTestSuite() { private class ConsecutiveSideEffect(private val sideEffectOutput: String) : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx: RestateContext = restateContext() + val ctx = KeyedContext.current() val firstResult = ctx.sideEffect(CoreSerdes.JSON_STRING) { sideEffectOutput } val secondResult = ctx.sideEffect(CoreSerdes.JSON_STRING) { firstResult.uppercase(Locale.getDefault()) } @@ -52,7 +52,7 @@ class SideEffectTest : SideEffectTestSuite() { override suspend fun greet(request: GreetingRequest): GreetingResponse { val sideEffectThread = - restateContext().sideEffect(CoreSerdes.JSON_STRING) { Thread.currentThread().name } + KeyedContext.current().sideEffect(CoreSerdes.JSON_STRING) { Thread.currentThread().name } check(sideEffectThread.contains("CheckContextSwitchingTestCoroutine")) { "Side effect thread is not running within the same coroutine context of the handler method: $sideEffectThread" } @@ -67,7 +67,7 @@ class SideEffectTest : SideEffectTestSuite() { private class SideEffectGuard : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() ctx.sideEffect { ctx.oneWayCall(GreeterGrpcKt.greetMethod, greetingRequest { name = "something" }) } diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt index cedf9929..bb029ec6 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/SleepTest.kt @@ -21,7 +21,7 @@ class SleepTest : SleepTestSuite() { private class SleepGreeter : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() ctx.sleep(1000.milliseconds) return greetingResponse { message = "Hello" } } @@ -34,7 +34,7 @@ class SleepTest : SleepTestSuite() { private class ManySleeps : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val awaitables = mutableListOf>() for (i in 0..9) { awaitables.add(ctx.timer(1000.milliseconds)) diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt index e5b7d36b..5286c3ad 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateMachineFailuresTest.kt @@ -27,7 +27,7 @@ class StateMachineFailuresTest : StateMachineFailuresTestSuite() { GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { try { - restateContext().get(STATE) + KeyedContext.current().get(STATE) } catch (e: Throwable) { // A user should never catch Throwable!!! if (e !is CancellationException && e !is TerminalException) { @@ -57,7 +57,7 @@ class StateMachineFailuresTest : StateMachineFailuresTestSuite() { private class SideEffectFailure(private val serde: Serde) : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - restateContext().sideEffect(serde) { 0 } + KeyedContext.current().sideEffect(serde) { 0 } return greetingResponse { message = "Francesco" } } } diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateTest.kt index 0bc85648..667868d5 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/StateTest.kt @@ -23,7 +23,7 @@ class StateTest : StateTestSuite() { GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { val state: String = - restateContext().get(StateKey.of("STATE", CoreSerdes.JSON_STRING)) ?: "Unknown" + KeyedContext.current().get(StateKey.of("STATE", CoreSerdes.JSON_STRING)) ?: "Unknown" return greetingResponse { message = "Hello $state" } } } @@ -35,7 +35,7 @@ class StateTest : StateTestSuite() { private class GetAndSetState : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - val ctx = restateContext() + val ctx = KeyedContext.current() val state = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING))!! ctx.set(StateKey.of("STATE", CoreSerdes.JSON_STRING), request.getName()) diff --git a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt index b3c44055..129f6243 100644 --- a/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt +++ b/sdk-api-kotlin/src/test/kotlin/dev/restate/sdk/kotlin/UserFailuresTest.kt @@ -35,7 +35,7 @@ class UserFailuresTest : UserFailuresTestSuite() { ) : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { try { - restateContext().sideEffect { throw IllegalStateException("Whatever") } + KeyedContext.current().sideEffect { throw IllegalStateException("Whatever") } } catch (e: Throwable) { if (e !is CancellationException && e !is TerminalException) { nonTerminalExceptionsSeen.addAndGet(1) @@ -74,7 +74,7 @@ class UserFailuresTest : UserFailuresTestSuite() { private val message: String ) : GreeterGrpcKt.GreeterCoroutineImplBase(Dispatchers.Unconfined), RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { - restateContext().sideEffect { throw TerminalException(code, message) } + KeyedContext.current().sideEffect { throw TerminalException(code, message) } throw IllegalStateException("Not expected to reach this point") } } diff --git a/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java b/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java index 24fdba1d..ef8f380a 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java +++ b/sdk-api/src/main/java/dev/restate/sdk/Awakeable.java @@ -24,7 +24,7 @@ * *

For example, you can send a Kafka record including the {@link Awakeable#id()}, and then let * another service consume from Kafka the responses of given external system interaction by using - * {@link RestateContext#awakeableHandle(String)}. + * {@link KeyedContext#awakeableHandle(String)}. */ @NotThreadSafe public final class Awakeable extends Awaitable.MappedAwaitable { diff --git a/sdk-api/src/main/java/dev/restate/sdk/RestateContextImpl.java b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java similarity index 98% rename from sdk-api/src/main/java/dev/restate/sdk/RestateContextImpl.java rename to sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java index 59f74ace..e90b9d9f 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/RestateContextImpl.java +++ b/sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java @@ -23,11 +23,11 @@ import javax.annotation.Nonnull; import javax.annotation.Nullable; -class RestateContextImpl implements RestateContext { +class ContextImpl implements KeyedContext { private final Syscalls syscalls; - RestateContextImpl(Syscalls syscalls) { + ContextImpl(Syscalls syscalls) { this.syscalls = syscalls; } diff --git a/sdk-api/src/main/java/dev/restate/sdk/GrpcChannelAdapter.java b/sdk-api/src/main/java/dev/restate/sdk/GrpcChannelAdapter.java index 9544c018..c05b0fa7 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/GrpcChannelAdapter.java +++ b/sdk-api/src/main/java/dev/restate/sdk/GrpcChannelAdapter.java @@ -18,10 +18,10 @@ *

Keep in mind that this channel should be used only with generated blocking stubs. */ public class GrpcChannelAdapter extends Channel { - private final RestateContext restateContext; + private final UnkeyedContext ctx; - GrpcChannelAdapter(RestateContext restateContext) { - this.restateContext = restateContext; + GrpcChannelAdapter(UnkeyedContext ctx) { + this.ctx = ctx; } @Override @@ -56,7 +56,7 @@ public void halfClose() { @Override public void sendMessage(RequestT message) { - this.awaitable = restateContext.call(methodDescriptor, message); + this.awaitable = ctx.call(methodDescriptor, message); } }; } diff --git a/sdk-api/src/main/java/dev/restate/sdk/KeyedContext.java b/sdk-api/src/main/java/dev/restate/sdk/KeyedContext.java new file mode 100644 index 00000000..84ddad0c --- /dev/null +++ b/sdk-api/src/main/java/dev/restate/sdk/KeyedContext.java @@ -0,0 +1,66 @@ +// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH +// +// This file is part of the Restate Java SDK, +// which is released under the MIT license. +// +// You can find a copy of the license in file LICENSE in the root +// directory of this repository or package, or at +// https://github.com/restatedev/sdk-java/blob/main/LICENSE +package dev.restate.sdk; + +import dev.restate.sdk.common.*; +import dev.restate.sdk.common.syscalls.Syscalls; +import java.util.Optional; +import javax.annotation.Nonnull; +import javax.annotation.concurrent.NotThreadSafe; + +/** + * This interface extends {@link UnkeyedContext} adding access to the service instance key-value + * state storage + * + * @see UnkeyedContext + */ +@NotThreadSafe +public interface KeyedContext extends UnkeyedContext { + + /** + * Gets the state stored under key, deserializing the raw value using the {@link Serde} in the + * {@link StateKey}. + * + * @param key identifying the state to get and its type. + * @return an {@link Optional} containing the stored state deserialized or an empty {@link + * Optional} if not set yet. + * @throws RuntimeException when the state cannot be deserialized. + */ + Optional get(StateKey key); + + /** + * Clears the state stored under key. + * + * @param key identifying the state to clear. + */ + void clear(StateKey key); + + /** + * Sets the given value under the given key, serializing the value using the {@link Serde} in the + * {@link StateKey}. + * + * @param key identifying the value to store and its type. + * @param value to store under the given key. MUST NOT be null. + */ + void set(StateKey key, @Nonnull T value); + + /** + * Create a {@link KeyedContext}. This will look up the thread-local/async-context storage for the + * underlying context implementation, so make sure to call it always from the same context where + * the service is executed. + */ + static KeyedContext current() { + return fromSyscalls(Syscalls.current()); + } + + /** Build a RestateContext from the underlying {@link Syscalls} object. */ + static KeyedContext fromSyscalls(Syscalls syscalls) { + return new ContextImpl(syscalls); + } +} diff --git a/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java b/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java index 7cda1d15..ebc92d96 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java +++ b/sdk-api/src/main/java/dev/restate/sdk/RestateRandom.java @@ -21,9 +21,9 @@ * *

This instance is useful to generate identifiers, idempotency keys, and for uniform sampling * from a set of options. If a cryptographically secure value is needed, please generate that - * externally using {@link RestateContext#sideEffect(Serde, ThrowingSupplier)}. + * externally using {@link KeyedContext#sideEffect(Serde, ThrowingSupplier)}. * - *

You MUST NOT use this object inside a {@link RestateContext#sideEffect(Serde, + *

You MUST NOT use this object inside a {@link KeyedContext#sideEffect(Serde, * ThrowingSupplier)}. */ public class RestateRandom extends Random { diff --git a/sdk-api/src/main/java/dev/restate/sdk/RestateService.java b/sdk-api/src/main/java/dev/restate/sdk/RestateService.java index bcd14b46..546b8358 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/RestateService.java +++ b/sdk-api/src/main/java/dev/restate/sdk/RestateService.java @@ -10,10 +10,9 @@ import dev.restate.sdk.common.BlockingService; import dev.restate.sdk.common.TerminalException; -import dev.restate.sdk.common.syscalls.Syscalls; /** - * Marker interface for Restate services implemented using the {@link RestateContext} interface. + * Marker interface for Restate services. * *

* @@ -28,12 +27,4 @@ * runtime will retry it, according to its configuration * */ -public interface RestateService extends BlockingService { - - /** - * @return an instance of the {@link RestateContext}. - */ - default RestateContext restateContext() { - return RestateContext.fromSyscalls(Syscalls.current()); - } -} +public interface RestateService extends BlockingService {} diff --git a/sdk-api/src/main/java/dev/restate/sdk/RestateContext.java b/sdk-api/src/main/java/dev/restate/sdk/UnkeyedContext.java similarity index 80% rename from sdk-api/src/main/java/dev/restate/sdk/RestateContext.java rename to sdk-api/src/main/java/dev/restate/sdk/UnkeyedContext.java index af97234d..eda7d9e1 100644 --- a/sdk-api/src/main/java/dev/restate/sdk/RestateContext.java +++ b/sdk-api/src/main/java/dev/restate/sdk/UnkeyedContext.java @@ -15,17 +15,12 @@ import io.grpc.Channel; import io.grpc.MethodDescriptor; import java.time.Duration; -import java.util.Optional; -import javax.annotation.Nonnull; import javax.annotation.concurrent.NotThreadSafe; /** - * This interface exposes the Restate functionalities to Restate services. It can be used to access - * the service instance key-value state storage, interact with other Restate services, record side - * effects, execute timers and synchronize with external systems. - * - *

To use it within your Restate service, implement {@link RestateService} and get an instance - * with {@link RestateService#restateContext()}. + * This interface exposes the Restate functionalities to Restate services. It can be used to + * interact with other Restate services, record side effects, execute timers and synchronize with + * external systems. * *

All methods of this interface, and related interfaces, throws either {@link TerminalException} * or {@link AbortedExecutionException}, where the former can be caught and acted upon, while the @@ -35,51 +30,7 @@ * orderings of user actions, corrupting the execution of the invocation. */ @NotThreadSafe -public interface RestateContext { - - /** - * Gets the state stored under key, deserializing the raw value using the {@link Serde} in the - * {@link StateKey}. - * - * @param key identifying the state to get and its type. - * @return an {@link Optional} containing the stored state deserialized or an empty {@link - * Optional} if not set yet. - * @throws RuntimeException when the state cannot be deserialized. - */ - Optional get(StateKey key); - - /** - * Clears the state stored under key. - * - * @param key identifying the state to clear. - */ - void clear(StateKey key); - - /** - * Sets the given value under the given key, serializing the value using the {@link Serde} in the - * {@link StateKey}. - * - * @param key identifying the value to store and its type. - * @param value to store under the given key. MUST NOT be null. - */ - void set(StateKey key, @Nonnull T value); - - /** - * Causes the current execution of the function invocation to sleep for the given duration. - * - * @param duration for which to sleep. - */ - default void sleep(Duration duration) { - timer(duration).await(); - } - - /** - * Causes the start of a timer for the given duration. You can await on the timer end by invoking - * {@link Awaitable#await()}. - * - * @param duration for which to sleep. - */ - Awaitable timer(Duration duration); +public interface UnkeyedContext { /** * Invoke another Restate service method. @@ -128,6 +79,23 @@ default Channel grpcChannel() { */ void delayedCall(MethodDescriptor methodDescriptor, T parameter, Duration delay); + /** + * Causes the current execution of the function invocation to sleep for the given duration. + * + * @param duration for which to sleep. + */ + default void sleep(Duration duration) { + timer(duration).await(); + } + + /** + * Causes the start of a timer for the given duration. You can await on the timer end by invoking + * {@link Awaitable#await()}. + * + * @param duration for which to sleep. + */ + Awaitable timer(Duration duration); + /** * Execute a non-deterministic closure, recording the result value in the journal. The result * value will be re-played in case of re-invocation (e.g. because of failure recovery or @@ -213,12 +181,16 @@ default void sideEffect(ThrowingRunnable runnable) throws TerminalException { RestateRandom random(); /** - * Build a RestateContext from the underlying {@link Syscalls} object. - * - *

This method is used by code-generation, you should not use it directly but rather use {@link - * RestateService#restateContext()}. + * Create a {@link KeyedContext}. This will look up the thread-local/async-context storage for the + * underlying context implementation, so make sure to call it always from the same context where + * the service is executed. */ - static RestateContext fromSyscalls(Syscalls syscalls) { - return new RestateContextImpl(syscalls); + static UnkeyedContext current() { + return fromSyscalls(Syscalls.current()); + } + + /** Build a RestateContext from the underlying {@link Syscalls} object. */ + static UnkeyedContext fromSyscalls(Syscalls syscalls) { + return new ContextImpl(syscalls); } } diff --git a/sdk-api/src/test/java/dev/restate/sdk/AwaitableTest.java b/sdk-api/src/test/java/dev/restate/sdk/AwaitableTest.java index 6a0420e3..91f6142d 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/AwaitableTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/AwaitableTest.java @@ -28,7 +28,7 @@ private static class ReverseAwaitOrder extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); Awaitable a1 = ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); @@ -54,7 +54,7 @@ private static class AwaitTwiceTheSameAwaitable extends GreeterGrpc.GreeterImplB implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); Awaitable a = ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); @@ -73,7 +73,7 @@ protected BindableService awaitTwiceTheSameAwaitable() { private static class AwaitAll extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); Awaitable a1 = ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); @@ -96,7 +96,7 @@ protected BindableService awaitAll() { private static class AwaitAny extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); Awaitable a1 = ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); @@ -119,7 +119,7 @@ private static class CombineAnyWithAll extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); Awaitable a1 = ctx.awakeable(CoreSerdes.JSON_STRING); Awaitable a2 = ctx.awakeable(CoreSerdes.JSON_STRING); @@ -144,7 +144,7 @@ protected BindableService combineAnyWithAll() { private static class AwaitAnyIndex extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); Awaitable a1 = ctx.awakeable(CoreSerdes.JSON_STRING); Awaitable a2 = ctx.awakeable(CoreSerdes.JSON_STRING); @@ -167,7 +167,7 @@ private static class AwaitOnAlreadyResolvedAwaitables extends GreeterGrpc.Greete implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); Awaitable a1 = ctx.awakeable(CoreSerdes.JSON_STRING); Awaitable a2 = ctx.awakeable(CoreSerdes.JSON_STRING); @@ -193,7 +193,7 @@ private static class AwaitWithTimeout extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); Awaitable call = ctx.call(GreeterGrpc.getGreetMethod(), greetingRequest("Francesco")); diff --git a/sdk-api/src/test/java/dev/restate/sdk/AwakeableIdTest.java b/sdk-api/src/test/java/dev/restate/sdk/AwakeableIdTest.java index ed6e93e9..4cf9895d 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/AwakeableIdTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/AwakeableIdTest.java @@ -25,7 +25,7 @@ private static class ReturnAwakeableId extends GreeterGrpc.GreeterImplBase @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - String id = restateContext().awakeable(CoreSerdes.JSON_STRING).id(); + String id = KeyedContext.current().awakeable(CoreSerdes.JSON_STRING).id(); responseObserver.onNext(greetingResponse(id)); responseObserver.onCompleted(); } diff --git a/sdk-api/src/test/java/dev/restate/sdk/EagerStateTest.java b/sdk-api/src/test/java/dev/restate/sdk/EagerStateTest.java index 7d15f986..0dd20743 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/EagerStateTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/EagerStateTest.java @@ -24,7 +24,7 @@ public class EagerStateTest extends EagerStateTestSuite { private static class GetEmpty extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); boolean stateIsEmpty = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING)).isEmpty(); @@ -42,7 +42,7 @@ protected BindableService getEmpty() { private static class Get extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); String state = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING)).get(); @@ -60,7 +60,7 @@ private static class GetAppendAndGet extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); String oldState = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING)).get(); ctx.set(StateKey.of("STATE", CoreSerdes.JSON_STRING), oldState + request.getName()); @@ -81,7 +81,7 @@ private static class GetClearAndGet extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); String oldState = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING)).get(); diff --git a/sdk-api/src/test/java/dev/restate/sdk/GrpcChannelAdapterTest.java b/sdk-api/src/test/java/dev/restate/sdk/GrpcChannelAdapterTest.java index 7453feb6..ee10d486 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/GrpcChannelAdapterTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/GrpcChannelAdapterTest.java @@ -27,7 +27,7 @@ private static class InvokeUsingGeneratedClient extends GreeterGrpc.GreeterImplB implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); GreeterGrpc.GreeterBlockingStub client = GreeterGrpc.newBlockingStub(ctx.grpcChannel()); String response = client.greet(GreetingRequest.newBuilder().setName("Francesco").build()).getMessage(); @@ -41,7 +41,7 @@ private static class InvokeUsingGeneratedFutureClient extends GreeterGrpc.Greete implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); GreeterGrpc.GreeterFutureStub client = GreeterGrpc.newFutureStub(ctx.grpcChannel()); String response; try { diff --git a/sdk-api/src/test/java/dev/restate/sdk/RandomTest.java b/sdk-api/src/test/java/dev/restate/sdk/RandomTest.java index fc240f6a..9b7aa302 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/RandomTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/RandomTest.java @@ -20,7 +20,7 @@ public class RandomTest extends RandomTestSuite { private static class RandomShouldBeDeterministic extends GreeterRestate.GreeterRestateImplBase { @Override - public GreetingResponse greet(RestateContext context, GreetingRequest request) + public GreetingResponse greet(KeyedContext context, GreetingRequest request) throws TerminalException { return GreetingResponse.newBuilder() .setMessage(Integer.toString(context.random().nextInt())) @@ -35,7 +35,7 @@ protected BindableService randomShouldBeDeterministic() { private static class RandomInsideSideEffect extends GreeterRestate.GreeterRestateImplBase { @Override - public GreetingResponse greet(RestateContext context, GreetingRequest request) + public GreetingResponse greet(KeyedContext context, GreetingRequest request) throws TerminalException { context.sideEffect(() -> context.random().nextInt()); throw new IllegalStateException("This should not unreachable"); diff --git a/sdk-api/src/test/java/dev/restate/sdk/RestateCodegenTest.java b/sdk-api/src/test/java/dev/restate/sdk/RestateCodegenTest.java index 9e2845cb..3e2c4c45 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/RestateCodegenTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/RestateCodegenTest.java @@ -19,7 +19,7 @@ private static class GreeterWithRestateClientAndServerCodegen extends GreeterRestate.GreeterRestateImplBase { @Override - public GreetingResponse greet(RestateContext context, GreetingRequest request) { + public GreetingResponse greet(KeyedContext context, GreetingRequest request) { GreeterRestate.GreeterRestateClient client = GreeterRestate.newClient(context); client.delayed(Duration.ofSeconds(1)).greet(request); client.oneWay().greet(request); @@ -35,31 +35,31 @@ protected BindableService greeterWithRestateClientAndServerCodegen() { private static class Codegen extends CodegenRestate.CodegenRestateImplBase { @Override - public MyMessage emptyInput(RestateContext context) { + public MyMessage emptyInput(UnkeyedContext context) { CodegenRestate.CodegenRestateClient client = CodegenRestate.newClient(context); return client.emptyInput().await(); } @Override - public void emptyOutput(RestateContext context, MyMessage request) { + public void emptyOutput(UnkeyedContext context, MyMessage request) { CodegenRestate.CodegenRestateClient client = CodegenRestate.newClient(context); client.emptyOutput(request).await(); } @Override - public void emptyInputOutput(RestateContext context) { + public void emptyInputOutput(UnkeyedContext context) { CodegenRestate.CodegenRestateClient client = CodegenRestate.newClient(context); client.emptyInputOutput().await(); } @Override - public MyMessage oneWay(RestateContext context, MyMessage request) { + public MyMessage oneWay(UnkeyedContext context, MyMessage request) { CodegenRestate.CodegenRestateClient client = CodegenRestate.newClient(context); return client._oneWay(request).await(); } @Override - public MyMessage delayed(RestateContext context, MyMessage request) { + public MyMessage delayed(UnkeyedContext context, MyMessage request) { CodegenRestate.CodegenRestateClient client = CodegenRestate.newClient(context); return client._delayed(request).await(); } diff --git a/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java b/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java index 9cc523b0..25fd3b1a 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/SideEffectTest.java @@ -31,7 +31,7 @@ private static class SideEffect extends GreeterGrpc.GreeterImplBase implements R @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); String result = ctx.sideEffect(CoreSerdes.JSON_STRING, () -> this.sideEffectOutput); @@ -56,7 +56,7 @@ private static class ConsecutiveSideEffect extends GreeterGrpc.GreeterImplBase @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); String firstResult = ctx.sideEffect(CoreSerdes.JSON_STRING, () -> this.sideEffectOutput); String secondResult = ctx.sideEffect(CoreSerdes.JSON_STRING, firstResult::toUpperCase); @@ -80,7 +80,7 @@ public void greet(GreetingRequest request, StreamObserver resp String currentThread = Thread.currentThread().getName(); String sideEffectThread = - restateContext() + KeyedContext.current() .sideEffect(CoreSerdes.JSON_STRING, () -> Thread.currentThread().getName()); if (!Objects.equals(currentThread, sideEffectThread)) { @@ -106,7 +106,7 @@ private static class SideEffectGuard extends GreeterGrpc.GreeterImplBase @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); ctx.sideEffect( () -> ctx.oneWayCall(GreeterGrpc.getGreetMethod(), greetingRequest("something"))); diff --git a/sdk-api/src/test/java/dev/restate/sdk/SleepTest.java b/sdk-api/src/test/java/dev/restate/sdk/SleepTest.java index ccf105de..8bb0deaf 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/SleepTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/SleepTest.java @@ -24,7 +24,7 @@ private static class SleepGreeter extends GreeterGrpc.GreeterImplBase implements @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); ctx.sleep(Duration.ofSeconds(1)); @@ -42,7 +42,7 @@ private static class ManySleeps extends GreeterGrpc.GreeterImplBase implements R @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); List> collectedAwaitables = new ArrayList<>(); for (int i = 0; i < 10; i++) { diff --git a/sdk-api/src/test/java/dev/restate/sdk/StateMachineFailuresTest.java b/sdk-api/src/test/java/dev/restate/sdk/StateMachineFailuresTest.java index 66645cc1..cee08068 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/StateMachineFailuresTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/StateMachineFailuresTest.java @@ -43,7 +43,7 @@ private GetState(AtomicInteger nonTerminalExceptionsSeen) { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { try { - restateContext().get(STATE); + KeyedContext.current().get(STATE); } catch (Throwable e) { // A user should never catch Throwable!!! if (AbortedExecutionException.INSTANCE.equals(e)) { @@ -75,7 +75,7 @@ private SideEffectFailure(Serde serde) { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - restateContext().sideEffect(serde, () -> 0); + KeyedContext.current().sideEffect(serde, () -> 0); responseObserver.onNext(greetingResponse("Francesco")); responseObserver.onCompleted(); diff --git a/sdk-api/src/test/java/dev/restate/sdk/StateTest.java b/sdk-api/src/test/java/dev/restate/sdk/StateTest.java index 50a8d89e..1ac511fe 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/StateTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/StateTest.java @@ -26,7 +26,9 @@ private static class GetState extends GreeterGrpc.GreeterImplBase implements Res @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { String state = - restateContext().get(StateKey.of("STATE", CoreSerdes.JSON_STRING)).orElse("Unknown"); + KeyedContext.current() + .get(StateKey.of("STATE", CoreSerdes.JSON_STRING)) + .orElse("Unknown"); responseObserver.onNext(GreetingResponse.newBuilder().setMessage("Hello " + state).build()); responseObserver.onCompleted(); @@ -42,7 +44,7 @@ private static class GetAndSetState extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); String state = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING)).get(); @@ -61,7 +63,7 @@ protected BindableService getAndSetState() { private static class SetNullState extends GreeterGrpc.GreeterImplBase implements RestateService { @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - restateContext() + KeyedContext.current() .set( StateKey.of( "STATE", diff --git a/sdk-api/src/test/java/dev/restate/sdk/UserFailuresTest.java b/sdk-api/src/test/java/dev/restate/sdk/UserFailuresTest.java index 87751926..eb90757f 100644 --- a/sdk-api/src/test/java/dev/restate/sdk/UserFailuresTest.java +++ b/sdk-api/src/test/java/dev/restate/sdk/UserFailuresTest.java @@ -51,7 +51,7 @@ private SideEffectThrowIllegalStateException(AtomicInteger nonTerminalExceptions @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { try { - restateContext() + KeyedContext.current() .sideEffect( () -> { throw new IllegalStateException("Whatever"); @@ -110,7 +110,7 @@ private SideEffectThrowTerminalException(TerminalException.Code code, String mes @Override public void greet(GreetingRequest request, StreamObserver responseObserver) { - restateContext() + KeyedContext.current() .sideEffect( () -> { throw new TerminalException(code, message); diff --git a/sdk-http-vertx/src/test/java/dev/restate/sdk/http/vertx/testservices/BlockingGreeterService.java b/sdk-http-vertx/src/test/java/dev/restate/sdk/http/vertx/testservices/BlockingGreeterService.java index 1aaede39..61a73da4 100644 --- a/sdk-http-vertx/src/test/java/dev/restate/sdk/http/vertx/testservices/BlockingGreeterService.java +++ b/sdk-http-vertx/src/test/java/dev/restate/sdk/http/vertx/testservices/BlockingGreeterService.java @@ -8,6 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.http.vertx.testservices; +import dev.restate.sdk.KeyedContext; import dev.restate.sdk.RestateService; import dev.restate.sdk.common.CoreSerdes; import dev.restate.sdk.common.StateKey; @@ -30,10 +31,10 @@ public void greet(GreetingRequest request, StreamObserver resp LOG.info("Greet invoked!"); - var count = restateContext().get(COUNTER).orElse(0L) + 1; - restateContext().set(COUNTER, count); + var count = KeyedContext.current().get(COUNTER).orElse(0L) + 1; + KeyedContext.current().set(COUNTER, count); - restateContext().sleep(Duration.ofSeconds(1)); + KeyedContext.current().sleep(Duration.ofSeconds(1)); responseObserver.onNext( GreetingResponse.newBuilder() diff --git a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTests.kt b/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTests.kt index d3d9418e..062bcb49 100644 --- a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTests.kt +++ b/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/HttpVertxTests.kt @@ -17,6 +17,7 @@ import dev.restate.sdk.core.TestDefinitions.* import dev.restate.sdk.core.testservices.GreeterGrpc import dev.restate.sdk.core.testservices.GreetingRequest import dev.restate.sdk.core.testservices.GreetingResponse +import dev.restate.sdk.kotlin.KeyedContext import dev.restate.sdk.kotlin.KotlinCoroutinesTests import dev.restate.sdk.kotlin.RestateKtService import io.grpc.stub.StreamObserver @@ -51,7 +52,7 @@ class HttpVertxTests : dev.restate.sdk.core.TestRunner() { RestateKtService { override suspend fun greet(request: GreetingRequest): GreetingResponse { check(Vertx.currentContext().isEventLoopContext) - restateContext().sideEffect { check(Vertx.currentContext().isEventLoopContext) } + KeyedContext.current().sideEffect { check(Vertx.currentContext().isEventLoopContext) } check(Vertx.currentContext().isEventLoopContext) return GreetingResponse.getDefaultInstance() } @@ -65,7 +66,7 @@ class HttpVertxTests : dev.restate.sdk.core.TestRunner() { ) { val id = Thread.currentThread().id check(Vertx.currentContext() == null) - restateContext().sideEffect { + dev.restate.sdk.KeyedContext.current().sideEffect { check(Thread.currentThread().id == id) check(Vertx.currentContext() == null) } diff --git a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/testservices/GreeterKtService.kt b/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/testservices/GreeterKtService.kt index 3296d88b..fee9a8ae 100644 --- a/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/testservices/GreeterKtService.kt +++ b/sdk-http-vertx/src/test/kotlin/dev/restate/sdk/http/vertx/testservices/GreeterKtService.kt @@ -12,6 +12,7 @@ import dev.restate.sdk.core.testservices.GreeterGrpcKt import dev.restate.sdk.core.testservices.GreetingRequest import dev.restate.sdk.core.testservices.GreetingResponse import dev.restate.sdk.core.testservices.greetingResponse +import dev.restate.sdk.kotlin.KeyedContext import dev.restate.sdk.kotlin.RestateKtService import kotlin.coroutines.CoroutineContext import kotlin.time.Duration.Companion.seconds @@ -25,10 +26,10 @@ class GreeterKtService(coroutineContext: CoroutineContext) : override suspend fun greet(request: GreetingRequest): GreetingResponse { LOG.info("Greet invoked!") - val count = (restateContext().get(BlockingGreeterService.COUNTER) ?: 0) + 1 - restateContext().set(BlockingGreeterService.COUNTER, count) + val count = (KeyedContext.current().get(BlockingGreeterService.COUNTER) ?: 0) + 1 + KeyedContext.current().set(BlockingGreeterService.COUNTER, count) - restateContext().sleep(1.seconds) + KeyedContext.current().sleep(1.seconds) return greetingResponse { message = "Hello ${request.name}. Count: $count" } } diff --git a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/testservices/JavaCounterService.java b/sdk-lambda/src/test/java/dev/restate/sdk/lambda/testservices/JavaCounterService.java index 215fce59..6ad49672 100644 --- a/sdk-lambda/src/test/java/dev/restate/sdk/lambda/testservices/JavaCounterService.java +++ b/sdk-lambda/src/test/java/dev/restate/sdk/lambda/testservices/JavaCounterService.java @@ -8,6 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.lambda.testservices; +import dev.restate.sdk.KeyedContext; import dev.restate.sdk.RestateService; import dev.restate.sdk.common.Serde; import dev.restate.sdk.common.StateKey; @@ -26,7 +27,7 @@ public class JavaCounterService extends JavaCounterGrpc.JavaCounterImplBase @Override public void get(CounterRequest request, StreamObserver responseObserver) { - restateContext().get(COUNTER); + KeyedContext.current().get(COUNTER); throw new IllegalStateException("We shouldn't reach this point"); } diff --git a/sdk-lambda/src/test/kotlin/dev/restate/sdk/lambda/testservices/KotlinCounterService.kt b/sdk-lambda/src/test/kotlin/dev/restate/sdk/lambda/testservices/KotlinCounterService.kt index cfea46c4..3508752f 100644 --- a/sdk-lambda/src/test/kotlin/dev/restate/sdk/lambda/testservices/KotlinCounterService.kt +++ b/sdk-lambda/src/test/kotlin/dev/restate/sdk/lambda/testservices/KotlinCounterService.kt @@ -8,6 +8,7 @@ // https://github.com/restatedev/sdk-java/blob/main/LICENSE package dev.restate.sdk.lambda.testservices +import dev.restate.sdk.kotlin.KeyedContext import dev.restate.sdk.kotlin.RestateKtService import kotlinx.coroutines.Dispatchers @@ -16,7 +17,7 @@ class KotlinCounterService : RestateKtService { override suspend fun get(request: CounterRequest): GetResponse { - (restateContext().get(JavaCounterService.COUNTER) ?: 0) + 1 + (KeyedContext.current().get(JavaCounterService.COUNTER) ?: 0) + 1 throw IllegalStateException("We shouldn't reach this point") } diff --git a/sdk-testing/src/test/java/dev/restate/sdk/testing/Counter.java b/sdk-testing/src/test/java/dev/restate/sdk/testing/Counter.java index ab869161..3f186ea9 100644 --- a/sdk-testing/src/test/java/dev/restate/sdk/testing/Counter.java +++ b/sdk-testing/src/test/java/dev/restate/sdk/testing/Counter.java @@ -9,7 +9,7 @@ package dev.restate.sdk.testing; import com.google.protobuf.Empty; -import dev.restate.sdk.RestateContext; +import dev.restate.sdk.KeyedContext; import dev.restate.sdk.RestateService; import dev.restate.sdk.common.CoreSerdes; import dev.restate.sdk.common.StateKey; @@ -26,7 +26,7 @@ public class Counter extends CounterGrpc.CounterImplBase implements RestateServi @Override public void reset(CounterRequest request, StreamObserver responseObserver) { - restateContext().clear(TOTAL); + KeyedContext.current().clear(TOTAL); responseObserver.onNext(Empty.getDefaultInstance()); responseObserver.onCompleted(); @@ -34,7 +34,7 @@ public void reset(CounterRequest request, StreamObserver responseObserver @Override public void add(CounterAddRequest request, StreamObserver responseObserver) { - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); long currentValue = ctx.get(TOTAL).orElse(0L); long newValue = currentValue + request.getValue(); @@ -46,7 +46,7 @@ public void add(CounterAddRequest request, StreamObserver responseObserve @Override public void get(CounterRequest request, StreamObserver responseObserver) { - long currentValue = restateContext().get(TOTAL).orElse(0L); + long currentValue = KeyedContext.current().get(TOTAL).orElse(0L); responseObserver.onNext(GetResponse.newBuilder().setValue(currentValue).build()); responseObserver.onCompleted(); @@ -57,7 +57,7 @@ public void getAndAdd( CounterAddRequest request, StreamObserver responseObserver) { LOG.info("Invoked get and add with " + request.getValue()); - RestateContext ctx = restateContext(); + KeyedContext ctx = KeyedContext.current(); long currentValue = ctx.get(TOTAL).orElse(0L); long newValue = currentValue + request.getValue();