Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate Java SDK with new service protocol negotiation mechanism #318

Merged
merged 7 commits into from
May 19, 2024
Merged
3 changes: 2 additions & 1 deletion sdk-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies {
// We need this for the manifest
implementation(platform(jacksonLibs.jackson.bom))
implementation(jacksonLibs.jackson.annotations)
implementation(jacksonLibs.jackson.databind)

// We don't want a hard-dependency on it
compileOnly(coreLibs.log4j.core)
Expand All @@ -43,7 +44,7 @@ sourceSets {

// Configure jsonSchema2Pojo
jsonSchema2Pojo {
setSource(files("$projectDir/src/main/service-protocol/deployment_manifest_schema.json"))
setSource(files("$projectDir/src/main/service-protocol/endpoint_manifest_schema.json"))
targetPackage = "dev.restate.sdk.core.manifest"
targetDirectory = generatedJ2SPDir.get().asFile

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.core;

import static dev.restate.sdk.core.ServiceProtocol.MAX_SERVICE_PROTOCOL_VERSION;
import static dev.restate.sdk.core.ServiceProtocol.MIN_SERVICE_PROTOCOL_VERSION;

import dev.restate.sdk.common.HandlerType;
import dev.restate.sdk.common.ServiceType;
import dev.restate.sdk.common.syscalls.HandlerDefinition;
Expand All @@ -17,18 +20,19 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

final class DeploymentManifest {
final class EndpointManifest {

private static final Input EMPTY_INPUT = new Input();
private static final Output EMPTY_OUTPUT = new Output().withSetContentTypeIfEmpty(false);

private final DeploymentManifestSchema manifest;
private final EndpointManifestSchema manifest;

public DeploymentManifest(
DeploymentManifestSchema.ProtocolMode protocolMode, Stream<ServiceDefinition<?>> components) {
public EndpointManifest(
EndpointManifestSchema.ProtocolMode protocolMode, Stream<ServiceDefinition<?>> components) {
this.manifest =
new DeploymentManifestSchema()
.withMinProtocolVersion(1)
.withMaxProtocolVersion(1)
new EndpointManifestSchema()
.withMinProtocolVersion(MIN_SERVICE_PROTOCOL_VERSION.getNumber())
.withMaxProtocolVersion(MAX_SERVICE_PROTOCOL_VERSION.getNumber())
.withProtocolMode(protocolMode)
.withServices(
components
Expand All @@ -39,12 +43,12 @@ public DeploymentManifest(
.withTy(convertServiceType(svc.getServiceType()))
.withHandlers(
svc.getHandlers().stream()
.map(DeploymentManifest::convertHandler)
.map(EndpointManifest::convertHandler)
.collect(Collectors.toList())))
.collect(Collectors.toList()));
}

public DeploymentManifestSchema manifest() {
public EndpointManifestSchema manifest() {
return this.manifest;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) {
MessageLite msg = invocationInput.message();
LOG.trace("Received input message {} {}", msg.getClass(), msg);
if (this.invocationState == InvocationState.WAITING_START) {
MessageHeader.checkProtocolVersion(invocationInput.header());
this.onStartMessage(msg);
} else if (msg instanceof Protocol.CompletionMessage) {
// We check the instance rather than the state, because the user code might still be
Expand Down
19 changes: 0 additions & 19 deletions sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@

public class MessageHeader {

static final short SUPPORTED_PROTOCOL_VERSION = 2;

static final short VERSION_MASK = 0x03FF;
static final short DONE_FLAG = 0x0001;
static final int REQUIRES_ACK_FLAG = 0x8000;

Expand Down Expand Up @@ -101,20 +98,4 @@ public static MessageHeader fromMessage(MessageLite msg) {
// Messages with no flags
return new MessageHeader(MessageType.fromMessage(msg), 0, msg.getSerializedSize());
}

public static void checkProtocolVersion(MessageHeader header) {
if (header.type != MessageType.StartMessage) {
throw new IllegalStateException("Expected StartMessage, got " + header.type);
}

short version = (short) (header.flags & VERSION_MASK);
if (version != SUPPORTED_PROTOCOL_VERSION) {
throw new IllegalStateException(
"Unsupported protocol version "
+ version
+ ", only version "
+ SUPPORTED_PROTOCOL_VERSION
+ " is supported");
}
}
}
18 changes: 9 additions & 9 deletions sdk-core/src/main/java/dev/restate/sdk/core/RestateEndpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import dev.restate.sdk.common.BindableServiceFactory;
import dev.restate.sdk.common.syscalls.HandlerDefinition;
import dev.restate.sdk.common.syscalls.ServiceDefinition;
import dev.restate.sdk.core.manifest.DeploymentManifestSchema;
import dev.restate.sdk.core.manifest.EndpointManifestSchema;
import dev.restate.sdk.core.manifest.Service;
import io.opentelemetry.api.OpenTelemetry;
import io.opentelemetry.api.trace.Span;
Expand All @@ -34,18 +34,18 @@ public class RestateEndpoint {
private final Map<String, ServiceAndOptions<?>> services;
private final Tracer tracer;
private final RequestIdentityVerifier requestIdentityVerifier;
private final DeploymentManifest deploymentManifest;
private final EndpointManifest deploymentManifest;

private RestateEndpoint(
DeploymentManifestSchema.ProtocolMode protocolMode,
EndpointManifestSchema.ProtocolMode protocolMode,
Map<String, ServiceAndOptions<?>> services,
Tracer tracer,
RequestIdentityVerifier requestIdentityVerifier) {
this.services = services;
this.tracer = tracer;
this.requestIdentityVerifier = requestIdentityVerifier;
this.deploymentManifest =
new DeploymentManifest(protocolMode, services.values().stream().map(c -> c.service));
new EndpointManifest(protocolMode, services.values().stream().map(c -> c.service));

this.logCreation();
}
Expand Down Expand Up @@ -99,8 +99,8 @@ public ResolvedEndpointHandler resolve(
return new ResolvedEndpointHandlerImpl(stateMachine, handler, svc.options, syscallExecutor);
}

public DeploymentManifestSchema handleDiscoveryRequest() {
DeploymentManifestSchema response = this.deploymentManifest.manifest();
public EndpointManifestSchema handleDiscoveryRequest() {
EndpointManifestSchema response = this.deploymentManifest.manifest();
LOG.info(
"Replying to discovery request with services [{}]",
response.getServices().stream().map(Service::getName).collect(Collectors.joining(",")));
Expand All @@ -113,18 +113,18 @@ private void logCreation() {

// -- Builder

public static Builder newBuilder(DeploymentManifestSchema.ProtocolMode protocolMode) {
public static Builder newBuilder(EndpointManifestSchema.ProtocolMode protocolMode) {
return new Builder(protocolMode);
}

public static class Builder {

private final List<ServiceAndOptions<?>> services = new ArrayList<>();
private final DeploymentManifestSchema.ProtocolMode protocolMode;
private final EndpointManifestSchema.ProtocolMode protocolMode;
private RequestIdentityVerifier requestIdentityVerifier;
private Tracer tracer = OpenTelemetry.noop().getTracer("NOOP");

public Builder(DeploymentManifestSchema.ProtocolMode protocolMode) {
public Builder(EndpointManifestSchema.ProtocolMode protocolMode) {
this.protocolMode = protocolMode;
}

Expand Down
139 changes: 139 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.core;

import com.fasterxml.jackson.databind.ObjectMapper;
import dev.restate.generated.service.discovery.Discovery;
import dev.restate.generated.service.protocol.Protocol;
import dev.restate.sdk.core.manifest.EndpointManifestSchema;
import java.util.Objects;
import java.util.Optional;

public class ServiceProtocol {
public static final Protocol.ServiceProtocolVersion MIN_SERVICE_PROTOCOL_VERSION =
Protocol.ServiceProtocolVersion.V1;
public static final Protocol.ServiceProtocolVersion MAX_SERVICE_PROTOCOL_VERSION =
Protocol.ServiceProtocolVersion.V1;

public static final Discovery.ServiceDiscoveryProtocolVersion
MIN_SERVICE_DISCOVERY_PROTOCOL_VERSION = Discovery.ServiceDiscoveryProtocolVersion.V1;
public static final Discovery.ServiceDiscoveryProtocolVersion
MAX_SERVICE_DISCOVERY_PROTOCOL_VERSION = Discovery.ServiceDiscoveryProtocolVersion.V1;

public static Protocol.ServiceProtocolVersion parseServiceProtocolVersion(String version) {
tillrohrmann marked this conversation as resolved.
Show resolved Hide resolved
version = version.trim();

if (version.equals("application/vnd.restate.invocation.v1")) {
return Protocol.ServiceProtocolVersion.V1;
}
return Protocol.ServiceProtocolVersion.SERVICE_PROTOCOL_VERSION_UNSPECIFIED;
}

public static String serviceProtocolVersionToHeaderValue(
Protocol.ServiceProtocolVersion version) {
if (Objects.requireNonNull(version) == Protocol.ServiceProtocolVersion.V1) {
return "application/vnd.restate.invocation.v1";
}
throw new IllegalArgumentException(
String.format("Service protocol version '%s' has no header value", version.getNumber()));
}

public static boolean is_supported(Protocol.ServiceProtocolVersion serviceProtocolVersion) {
return MIN_SERVICE_PROTOCOL_VERSION.getNumber() <= serviceProtocolVersion.getNumber()
&& serviceProtocolVersion.getNumber() <= MAX_SERVICE_PROTOCOL_VERSION.getNumber();
}

public static boolean is_supported(
Discovery.ServiceDiscoveryProtocolVersion serviceDiscoveryProtocolVersion) {
return MIN_SERVICE_DISCOVERY_PROTOCOL_VERSION.getNumber()
<= serviceDiscoveryProtocolVersion.getNumber()
&& serviceDiscoveryProtocolVersion.getNumber()
<= MAX_SERVICE_DISCOVERY_PROTOCOL_VERSION.getNumber();
}

/**
* Selects the highest supported service protocol version from a list of supported versions.
*
* @param acceptedVersionsString A comma-separated list of accepted service protocol versions.
* @return The highest supported service protocol version, otherwise
* Protocol.ServiceProtocolVersion.SERVICE_PROTOCOL_VERSION_UNSPECIFIED
*/
public static Discovery.ServiceDiscoveryProtocolVersion
selectSupportedServiceDiscoveryProtocolVersion(String acceptedVersionsString) {
// assume V1 in case nothing was set
if (acceptedVersionsString == null || acceptedVersionsString.isEmpty()) {
return Discovery.ServiceDiscoveryProtocolVersion.V1;
}

final String[] supportedVersions = acceptedVersionsString.split(",");

Discovery.ServiceDiscoveryProtocolVersion maxVersion =
Discovery.ServiceDiscoveryProtocolVersion.SERVICE_DISCOVERY_PROTOCOL_VERSION_UNSPECIFIED;

for (String versionString : supportedVersions) {
final Optional<Discovery.ServiceDiscoveryProtocolVersion> optionalVersion =
parseServiceDiscoveryProtocolVersion(versionString);

if (optionalVersion.isPresent()) {
final Discovery.ServiceDiscoveryProtocolVersion version = optionalVersion.get();
if (is_supported(version) && version.getNumber() > maxVersion.getNumber()) {
maxVersion = version;
}
}
}

return maxVersion;
}

public static Optional<Discovery.ServiceDiscoveryProtocolVersion>
parseServiceDiscoveryProtocolVersion(String versionString) {
versionString = versionString.trim();

if (versionString.equals("application/vnd.restate.endpointmanifest.v1+json")) {
return Optional.of(Discovery.ServiceDiscoveryProtocolVersion.V1);
}
return Optional.empty();
}

public static String serviceDiscoveryProtocolVersionToHeaderValue(
Discovery.ServiceDiscoveryProtocolVersion version) {
if (Objects.requireNonNull(version) == Discovery.ServiceDiscoveryProtocolVersion.V1) {
return "application/vnd.restate.endpointmanifest.v1+json";
}
throw new IllegalArgumentException(
String.format(
"Service discovery protocol version '%s' has no header value", version.getNumber()));
}

public static class DiscoveryResponseSerializer {
private static final ObjectMapper MANIFEST_OBJECT_MAPPER = new ObjectMapper();

private final Discovery.ServiceDiscoveryProtocolVersion serviceDiscoveryProtocolVersion;

public DiscoveryResponseSerializer(
Discovery.ServiceDiscoveryProtocolVersion serviceDiscoveryProtocolVersion) {
if (!is_supported(serviceDiscoveryProtocolVersion)) {
throw new IllegalArgumentException("Unsupported service discovery protocol version");
}

this.serviceDiscoveryProtocolVersion = serviceDiscoveryProtocolVersion;
}

public byte[] serialize(EndpointManifestSchema response) throws Exception {
if (this.serviceDiscoveryProtocolVersion == Discovery.ServiceDiscoveryProtocolVersion.V1) {
return MANIFEST_OBJECT_MAPPER.writeValueAsBytes(response);
}

throw new IllegalStateException(
String.format(
"DiscoveryResponseSerializer does not support service discovery protocol '%s'",
this.serviceDiscoveryProtocolVersion.getNumber()));
}
}
}
2 changes: 1 addition & 1 deletion sdk-core/src/main/service-protocol/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ This repo contains specification documents and Protobuf schemas of the Restate S

To format the spec document:

```
```shell
npx prettier -w service-invocation-protocol.md
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate service protocol, which is
// released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/service-protocol/blob/main/LICENSE

syntax = "proto3";

package dev.restate.service.discovery;

option java_package = "dev.restate.generated.service.discovery";
option go_package = "restate.dev/sdk-go/pb/service/discovery";

// Service discovery protocol version.
enum ServiceDiscoveryProtocolVersion {
SERVICE_DISCOVERY_PROTOCOL_VERSION_UNSPECIFIED = 0;
// initial service discovery protocol version using endpoint_manifest_schema.json
V1 = 1;
}
Loading
Loading