Skip to content

Commit

Permalink
Split RestateContext interface in KeyedContext/UnkeyedContext (#213)
Browse files Browse the repository at this point in the history
* [java] Split RestateContext interface in KeyedContext/UnkeyedContext
* [kotlin] Split RestateContext interface in KeyedContext/UnkeyedContext
* Make sure the protoc code generator injects the correct interface
* Update tests and examples
  • Loading branch information
slinkydeveloper authored Feb 5, 2024
1 parent 008868c commit 64b4d7a
Show file tree
Hide file tree
Showing 42 changed files with 329 additions and 271 deletions.
16 changes: 7 additions & 9 deletions examples/src/main/java/dev/restate/sdk/examples/Counter.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.examples;

import dev.restate.sdk.RestateContext;
import dev.restate.sdk.KeyedContext;
import dev.restate.sdk.common.CoreSerdes;
import dev.restate.sdk.common.StateKey;
import dev.restate.sdk.examples.generated.*;
Expand All @@ -23,30 +23,28 @@ public class Counter extends CounterRestate.CounterRestateImplBase {
private static final StateKey<Long> TOTAL = StateKey.of("total", CoreSerdes.JSON_LONG);

@Override
public void reset(RestateContext ctx, CounterRequest request) {
restateContext().clear(TOTAL);
public void reset(KeyedContext ctx, CounterRequest request) {
ctx.clear(TOTAL);
}

@Override
public void add(RestateContext ctx, CounterAddRequest request) {
public void add(KeyedContext ctx, CounterAddRequest request) {
long currentValue = ctx.get(TOTAL).orElse(0L);
long newValue = currentValue + request.getValue();
ctx.set(TOTAL, newValue);
}

@Override
public GetResponse get(RestateContext context, CounterRequest request) {
long currentValue = restateContext().get(TOTAL).orElse(0L);
public GetResponse get(KeyedContext ctx, CounterRequest request) {
long currentValue = ctx.get(TOTAL).orElse(0L);

return GetResponse.newBuilder().setValue(currentValue).build();
}

@Override
public CounterUpdateResult getAndAdd(RestateContext context, CounterAddRequest request) {
public CounterUpdateResult getAndAdd(KeyedContext ctx, CounterAddRequest request) {
LOG.info("Invoked get and add with " + request.getValue());

RestateContext ctx = restateContext();

long currentValue = ctx.get(TOTAL).orElse(0L);
long newValue = currentValue + request.getValue();
ctx.set(TOTAL, newValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
package dev.restate.sdk.examples;

import com.google.protobuf.Empty;
import dev.restate.sdk.RestateContext;
import dev.restate.sdk.KeyedContext;
import dev.restate.sdk.RestateService;
import dev.restate.sdk.common.CoreSerdes;
import dev.restate.sdk.common.StateKey;
Expand All @@ -27,15 +27,15 @@ public class VanillaGrpcCounter extends CounterGrpc.CounterImplBase implements R

@Override
public void reset(CounterRequest request, StreamObserver<Empty> responseObserver) {
restateContext().clear(TOTAL);
KeyedContext.current().clear(TOTAL);

responseObserver.onNext(Empty.getDefaultInstance());
responseObserver.onCompleted();
}

@Override
public void add(CounterAddRequest request, StreamObserver<Empty> responseObserver) {
RestateContext ctx = restateContext();
KeyedContext ctx = KeyedContext.current();

long currentValue = ctx.get(TOTAL).orElse(0L);
long newValue = currentValue + request.getValue();
Expand All @@ -47,7 +47,7 @@ public void add(CounterAddRequest request, StreamObserver<Empty> responseObserve

@Override
public void get(CounterRequest request, StreamObserver<GetResponse> responseObserver) {
long currentValue = restateContext().get(TOTAL).orElse(0L);
long currentValue = KeyedContext.current().get(TOTAL).orElse(0L);

responseObserver.onNext(GetResponse.newBuilder().setValue(currentValue).build());
responseObserver.onCompleted();
Expand All @@ -58,7 +58,7 @@ public void getAndAdd(
CounterAddRequest request, StreamObserver<CounterUpdateResult> responseObserver) {
LOG.info("Invoked get and add with " + request.getValue());

RestateContext ctx = restateContext();
KeyedContext ctx = KeyedContext.current();

long currentValue = ctx.get(TOTAL).orElse(0L);
long newValue = currentValue + request.getValue();
Expand Down
12 changes: 6 additions & 6 deletions examples/src/main/kotlin/dev/restate/sdk/examples/CounterKt.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import dev.restate.sdk.common.CoreSerdes
import dev.restate.sdk.common.StateKey
import dev.restate.sdk.examples.generated.*
import dev.restate.sdk.http.vertx.RestateHttpEndpointBuilder
import dev.restate.sdk.kotlin.RestateContext
import dev.restate.sdk.kotlin.KeyedContext
import org.apache.logging.log4j.LogManager

class CounterKt : CounterRestateKt.CounterRestateKtImplBase() {
Expand All @@ -21,20 +21,20 @@ class CounterKt : CounterRestateKt.CounterRestateKtImplBase() {

private val TOTAL = StateKey.of("total", CoreSerdes.JSON_LONG)

override suspend fun reset(context: RestateContext, request: CounterRequest) {
override suspend fun reset(context: KeyedContext, request: CounterRequest) {
context.clear(TOTAL)
}

override suspend fun add(context: RestateContext, request: CounterAddRequest) {
override suspend fun add(context: KeyedContext, request: CounterAddRequest) {
updateCounter(context, request.value)
}

override suspend fun get(context: RestateContext, request: CounterRequest): GetResponse {
override suspend fun get(context: KeyedContext, request: CounterRequest): GetResponse {
return getResponse { value = context.get(TOTAL) ?: 0L }
}

override suspend fun getAndAdd(
context: RestateContext,
context: KeyedContext,
request: CounterAddRequest
): CounterUpdateResult {
LOG.info("Invoked get and add with " + request.value)
Expand All @@ -45,7 +45,7 @@ class CounterKt : CounterRestateKt.CounterRestateKtImplBase() {
}
}

private suspend fun updateCounter(context: RestateContext, add: Long): Pair<Long, Long> {
private suspend fun updateCounter(context: KeyedContext, add: Long): Pair<Long, Long> {
val currentValue = context.get(TOTAL) ?: 0L
val newValue = currentValue + add

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import com.salesforce.jprotoc.ProtoTypeMap;
import com.salesforce.jprotoc.ProtocPlugin;
import dev.restate.generated.ext.Ext;
import dev.restate.generated.ext.ServiceType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -120,6 +121,12 @@ private ServiceContext buildServiceContext(
serviceContext.serviceName = serviceProto.getName();
serviceContext.deprecated = serviceProto.getOptions().getDeprecated();

// Resolve context type
serviceContext.contextType =
serviceProto.getOptions().getExtension(Ext.serviceType) == ServiceType.UNKEYED
? "UnkeyedContext"
: "KeyedContext";

// Resolve javadoc
DescriptorProtos.SourceCodeInfo.Location serviceLocation =
locations.stream()
Expand Down Expand Up @@ -215,6 +222,7 @@ private static class ServiceContext {
public String packageName;
public String className;
public String serviceName;
public String contextType;
public String apidoc;
public boolean deprecated;
public final List<MethodContext> methods = new ArrayList<>();
Expand Down
42 changes: 18 additions & 24 deletions protoc-gen-restate/src/main/resources/javaStub.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
package {{packageName}};
{{/packageName}}

import dev.restate.sdk.RestateContext;
import dev.restate.sdk.UnkeyedContext;
import dev.restate.sdk.KeyedContext;
import dev.restate.sdk.Awaitable;
import dev.restate.sdk.common.syscalls.Syscalls;
import java.time.Duration;
Expand All @@ -15,24 +16,17 @@ public class {{className}} {
private {{className}}() {}

/**
* Create a new client.
*/
public static {{serviceName}}RestateClient newClient() {
return newClient(RestateContext.fromSyscalls(Syscalls.current()));
}

/**
* Create a new client from the given {@link RestateContext}.
* Create a new client from the given {@link KeyedContext}.
*/
public static {{serviceName}}RestateClient newClient(RestateContext ctx) {
public static {{serviceName}}RestateClient newClient(UnkeyedContext ctx) {
return new {{serviceName}}RestateClient(ctx);
}

{{{apidoc}}}
public static final class {{serviceName}}RestateClient {
private final RestateContext ctx;
private final UnkeyedContext ctx;
{{serviceName}}RestateClient(RestateContext ctx) {
{{serviceName}}RestateClient(UnkeyedContext ctx) {
this.ctx = ctx;
}

Expand All @@ -57,9 +51,9 @@ public class {{className}} {
}

public static final class {{serviceName}}RestateOneWayClient {
private final RestateContext ctx;
private final UnkeyedContext ctx;
{{serviceName}}RestateOneWayClient(RestateContext ctx) {
{{serviceName}}RestateOneWayClient(UnkeyedContext ctx) {
this.ctx = ctx;
}

Expand All @@ -74,10 +68,10 @@ public class {{className}} {
}

public static final class {{serviceName}}RestateDelayedClient {
private final RestateContext ctx;
private final UnkeyedContext ctx;
private final Duration delay;
{{serviceName}}RestateDelayedClient(RestateContext ctx, Duration delay) {
{{serviceName}}RestateDelayedClient(UnkeyedContext ctx, Duration delay) {
this.ctx = ctx;
this.delay = delay;
}
Expand All @@ -100,7 +94,7 @@ public class {{className}} {
@java.lang.Deprecated
{{/deprecated}}
{{{apidoc}}}
public {{#isOutputEmpty}}void{{/isOutputEmpty}}{{^isOutputEmpty}}{{outputType}}{{/isOutputEmpty}} {{methodName}}(RestateContext context{{^isInputEmpty}}, {{inputType}} request{{/isInputEmpty}}) throws dev.restate.sdk.common.TerminalException {
public {{#isOutputEmpty}}void{{/isOutputEmpty}}{{^isOutputEmpty}}{{outputType}}{{/isOutputEmpty}} {{methodName}}({{contextType}} context{{^isInputEmpty}}, {{inputType}} request{{/isInputEmpty}}) throws dev.restate.sdk.common.TerminalException {
throw new dev.restate.sdk.common.TerminalException(dev.restate.sdk.common.TerminalException.Code.UNIMPLEMENTED);
}

Expand All @@ -120,34 +114,34 @@ public class {{className}} {
private static final class HandlerAdapter<Req, Resp> implements
io.grpc.stub.ServerCalls.UnaryMethod<Req, Resp> {
private final java.util.function.BiFunction<RestateContext, Req, Resp> handler;
private final java.util.function.BiFunction<KeyedContext, Req, Resp> handler;
private HandlerAdapter(java.util.function.BiFunction<RestateContext, Req, Resp> handler) {
private HandlerAdapter(java.util.function.BiFunction<KeyedContext, Req, Resp> handler) {
this.handler = handler;
}

@Override
public void invoke(Req request, io.grpc.stub.StreamObserver<Resp> responseObserver) {
responseObserver.onNext(handler.apply(RestateContext.fromSyscalls(Syscalls.current()), request));
responseObserver.onNext(handler.apply(KeyedContext.fromSyscalls(Syscalls.current()), request));
responseObserver.onCompleted();
}

private static <Req, Resp> HandlerAdapter<Req, Resp> of(java.util.function.BiFunction<RestateContext, Req, Resp> handler) {
private static <Req, Resp> HandlerAdapter<Req, Resp> of(java.util.function.BiFunction<KeyedContext, Req, Resp> handler) {
return new HandlerAdapter<>(handler);
}

private static <Resp> HandlerAdapter<com.google.protobuf.Empty, Resp> of(java.util.function.Function<RestateContext, Resp> handler) {
private static <Resp> HandlerAdapter<com.google.protobuf.Empty, Resp> of(java.util.function.Function<KeyedContext, Resp> handler) {
return new HandlerAdapter<>((ctx, e) -> handler.apply(ctx));
}

private static <Req> HandlerAdapter<Req, com.google.protobuf.Empty> of(java.util.function.BiConsumer<RestateContext, Req> handler) {
private static <Req> HandlerAdapter<Req, com.google.protobuf.Empty> of(java.util.function.BiConsumer<KeyedContext, Req> handler) {
return new HandlerAdapter<>((ctx, req) -> {
handler.accept(ctx, req);
return com.google.protobuf.Empty.getDefaultInstance();
});
}

private static HandlerAdapter<com.google.protobuf.Empty, com.google.protobuf.Empty> of(java.util.function.Consumer<RestateContext> handler) {
private static HandlerAdapter<com.google.protobuf.Empty, com.google.protobuf.Empty> of(java.util.function.Consumer<KeyedContext> handler) {
return new HandlerAdapter<>((ctx, req) -> {
handler.accept(ctx);
return com.google.protobuf.Empty.getDefaultInstance();
Expand Down
24 changes: 12 additions & 12 deletions protoc-gen-restate/src/main/resources/ktStub.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
package {{packageName}};
{{/packageName}}

import dev.restate.sdk.kotlin.RestateContext;
import dev.restate.sdk.kotlin.UnkeyedContext;
import dev.restate.sdk.kotlin.KeyedContext;
import dev.restate.sdk.kotlin.Awaitable;
import dev.restate.sdk.kotlin.RestateKtService;
import dev.restate.sdk.common.syscalls.Syscalls;
import io.grpc.kotlin.ClientCalls.unaryRpc
import io.grpc.kotlin.ServerCalls.unaryServerMethodDefinition
import {{packageName}}.{{serviceName}}Grpc.getServiceDescriptor;
import dev.restate.sdk.kotlin.restateContextFromSyscalls;
import kotlin.time.Duration

{{#deprecated}}
Expand All @@ -18,14 +18,14 @@ import kotlin.time.Duration
public object {{className}} {
/**
* Create a new client from the given [RestateContext].
* Create a new client from the given [UnkeyedContext].
*/
fun newClient(ctx: RestateContext): {{serviceName}}RestateKtClient {
fun newClient(ctx: UnkeyedContext): {{serviceName}}RestateKtClient {
return {{serviceName}}RestateKtClient(ctx);
}

{{{javadoc}}}
public class {{serviceName}}RestateKtClient(private val ctx: RestateContext) {
public class {{serviceName}}RestateKtClient(private val ctx: UnkeyedContext) {
// Create a variant of this client to execute oneWay calls.
public fun oneWay(): {{serviceName}}RestateKtOneWayClient {
return {{serviceName}}RestateKtOneWayClient(ctx);
Expand All @@ -46,7 +46,7 @@ public object {{className}} {
{{/methods}}
}

public class {{serviceName}}RestateKtOneWayClient(private val ctx: RestateContext) {
public class {{serviceName}}RestateKtOneWayClient(private val ctx: UnkeyedContext) {
{{#methods}}
{{#deprecated}}@Deprecated{{/deprecated}}
{{{javadoc}}}
Expand All @@ -57,7 +57,7 @@ public object {{className}} {
{{/methods}}
}

public class {{serviceName}}RestateKtDelayedClient(private val ctx: RestateContext, private val delay: Duration) {
public class {{serviceName}}RestateKtDelayedClient(private val ctx: UnkeyedContext, private val delay: Duration) {
{{#methods}}
{{#deprecated}}@Deprecated{{/deprecated}}
{{{javadoc}}}
Expand All @@ -78,7 +78,7 @@ public object {{className}} {
@Deprecated
{{/deprecated}}
{{{javadoc}}}
public open suspend fun {{methodName}}(context: RestateContext{{^isInputEmpty}}, request: {{inputType}} {{/isInputEmpty}}){{^isOutputEmpty}}: {{outputType}}{{/isOutputEmpty}} {
public open suspend fun {{methodName}}(context: {{contextType}}{{^isInputEmpty}}, request: {{inputType}} {{/isInputEmpty}}){{^isOutputEmpty}}: {{outputType}}{{/isOutputEmpty}} {
throw dev.restate.sdk.common.TerminalException(dev.restate.sdk.common.TerminalException.Code.UNIMPLEMENTED);
}

Expand All @@ -91,18 +91,18 @@ public object {{className}} {
descriptor = {{packageName}}.{{serviceName}}Grpc.{{methodDescriptorGetter}}(),
implementation = {
{{#isInputEmpty}}{{#isOutputEmpty}}
{{methodName}}(restateContextFromSyscalls(Syscalls.current()))
{{methodName}}(KeyedContext.fromSyscalls(Syscalls.current()))
return@unaryServerMethodDefinition com.google.protobuf.Empty.getDefaultInstance()
{{/isOutputEmpty}}{{/isInputEmpty}}
{{#isInputEmpty}}{{^isOutputEmpty}}
return@unaryServerMethodDefinition {{methodName}}(restateContextFromSyscalls(Syscalls.current()))
return@unaryServerMethodDefinition {{methodName}}(KeyedContext.fromSyscalls(Syscalls.current()))
{{/isOutputEmpty}}{{/isInputEmpty}}
{{^isInputEmpty}}{{#isOutputEmpty}}
{{methodName}}(restateContextFromSyscalls(Syscalls.current()), it)
{{methodName}}(KeyedContext.fromSyscalls(Syscalls.current()), it)
return@unaryServerMethodDefinition com.google.protobuf.Empty.getDefaultInstance()
{{/isOutputEmpty}}{{/isInputEmpty}}
{{^isInputEmpty}}{{^isOutputEmpty}}
return@unaryServerMethodDefinition {{methodName}}(restateContextFromSyscalls(Syscalls.current()), it)
return@unaryServerMethodDefinition {{methodName}}(KeyedContext.fromSyscalls(Syscalls.current()), it)
{{/isOutputEmpty}}{{/isInputEmpty}}
}
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ import kotlin.time.Duration
import kotlin.time.toJavaDuration
import kotlinx.coroutines.*

internal class RestateContextImpl internal constructor(private val syscalls: Syscalls) :
RestateContext {
internal class ContextImpl internal constructor(private val syscalls: Syscalls) : KeyedContext {
override suspend fun <T : Any> get(key: StateKey<T>): T? {
val deferred: Deferred<ByteString> =
suspendCancellableCoroutine { cont: CancellableContinuation<Deferred<ByteString>> ->
Expand Down
Loading

0 comments on commit 64b4d7a

Please sign in to comment.