Skip to content

Commit

Permalink
Add the EagerStateTest
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper committed May 25, 2023
1 parent 7e7929b commit 7ca190b
Show file tree
Hide file tree
Showing 3 changed files with 268 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import io.grpc.MethodDescriptor;
import io.grpc.ServerServiceDefinition;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -58,7 +59,7 @@ void executeTest(
String testName,
ServerServiceDefinition svc,
String method,
List<MessageLite> input,
List<InvocationInput> input,
ThreadingModel threadingModel,
BiConsumer<FutureSubscriber<MessageLite>, Duration> outputAssert) {
Executor syscallsExecutor =
Expand Down Expand Up @@ -97,11 +98,8 @@ void executeTest(
});

// Pipe entries
for (MessageLite inputEntry : input) {
syscallsExecutor.execute(
() ->
inputPublisher.push(
InvocationInput.of(MessageHeader.fromMessage(inputEntry), inputEntry)));
for (InvocationInput inputEntry : input) {
syscallsExecutor.execute(() -> inputPublisher.push(inputEntry));
}
// Complete the input publisher
syscallsExecutor.execute(inputPublisher::close);
Expand All @@ -111,13 +109,7 @@ void executeTest(
assertThat(inputPublisher.isSubscriptionCancelled()).isTrue();
} else {
// Create publisher
BufferedMockPublisher<InvocationInput> inputPublisher =
new BufferedMockPublisher<>(
input.stream()
.map(
inputEntry ->
InvocationInput.of(MessageHeader.fromMessage(inputEntry), inputEntry))
.collect(Collectors.toList()));
BufferedMockPublisher<InvocationInput> inputPublisher = new BufferedMockPublisher<>(input);

// Wire invocation
handler.output().subscribe(outputSubscriber);
Expand All @@ -142,7 +134,7 @@ interface TestDefinition {

String getMethod();

List<MessageLite> getInput();
List<InvocationInput> getInput();

HashSet<ThreadingModel> getThreadingModels();

Expand All @@ -164,33 +156,66 @@ static TestInvocationBuilder testInvocation(
}

static class TestInvocationBuilder {
private final BindableService svc;
private final String method;
protected final BindableService svc;
protected final String method;

TestInvocationBuilder(BindableService svc, String method) {
this.svc = svc;
this.method = method;
}

WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) {
MessageLite msg = ProtoUtils.build(msgOrBuilder);
return new WithInputBuilder(
svc,
method,
List.of(InvocationInput.of(MessageHeader.fromMessage(msg).copyWithFlags(flags), msg)));
}

WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
return new WithInputBuilder(svc, method, Arrays.asList(messages));
return new WithInputBuilder(
svc,
method,
Arrays.stream(messages)
.map(
msgOrBuilder -> {
MessageLite msg = ProtoUtils.build(msgOrBuilder);
return InvocationInput.of(MessageHeader.fromMessage(msg), msg);
})
.collect(Collectors.toList()));
}
}

static class WithInputBuilder {
private final BindableService svc;
private final String method;
private final List<MessageLiteOrBuilder> input;
static class WithInputBuilder extends TestInvocationBuilder {
private final List<InvocationInput> input;

WithInputBuilder(BindableService svc, String method, List<MessageLiteOrBuilder> input) {
this.svc = svc;
this.method = method;
this.input = input;
WithInputBuilder(BindableService svc, String method, List<InvocationInput> input) {
super(svc, method);
this.input = new ArrayList<>(input);
}

WithInputBuilder withInput(short flags, MessageLiteOrBuilder msgOrBuilder) {
MessageLite msg = ProtoUtils.build(msgOrBuilder);
this.input.add(
InvocationInput.of(MessageHeader.fromMessage(msg).copyWithFlags(flags), msg));
return this;
}

WithInputBuilder withInput(MessageLiteOrBuilder... messages) {
this.input.addAll(
Arrays.stream(messages)
.map(
msgOrBuilder -> {
MessageLite msg = ProtoUtils.build(msgOrBuilder);
return InvocationInput.of(MessageHeader.fromMessage(msg), msg);
})
.collect(Collectors.toList()));
return this;
}

UsingThreadingModelsBuilder usingThreadingModels(ThreadingModel... threadingModels) {
return new UsingThreadingModelsBuilder(
svc, method, input, new HashSet<>(Arrays.asList(threadingModels)));
this.svc, this.method, input, new HashSet<>(Arrays.asList(threadingModels)));
}

UsingThreadingModelsBuilder usingAllThreadingModels() {
Expand All @@ -201,13 +226,13 @@ UsingThreadingModelsBuilder usingAllThreadingModels() {
static class UsingThreadingModelsBuilder {
private final BindableService svc;
private final String method;
private final List<MessageLiteOrBuilder> input;
private final List<InvocationInput> input;
private final HashSet<ThreadingModel> threadingModels;

UsingThreadingModelsBuilder(
BindableService svc,
String method,
List<MessageLiteOrBuilder> input,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels) {
this.svc = svc;
this.method = method;
Expand All @@ -221,10 +246,6 @@ ExpectingOutputMessages expectingOutput(MessageLiteOrBuilder... messages) {
return assertingOutput(actual -> assertThat(actual).asList().isEqualTo(builtMessages));
}

ExpectingOutputMessages expectingNoOutput() {
return assertingOutput(messages -> assertThat(messages).asList().isEmpty());
}

ExpectingOutputMessages assertingOutput(Consumer<List<MessageLite>> messages) {
return new ExpectingOutputMessages(svc, method, input, threadingModels, messages);
}
Expand All @@ -241,22 +262,22 @@ ExpectingFailure assertingFailure(Consumer<Throwable> assertFailure) {
public abstract static class BaseTestDefinition implements TestDefinition {
protected final BindableService svc;
protected final String method;
protected final List<MessageLiteOrBuilder> input;
protected final List<InvocationInput> input;
protected final HashSet<ThreadingModel> threadingModels;
protected final String named;

public BaseTestDefinition(
BindableService svc,
String method,
List<MessageLiteOrBuilder> input,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels) {
this(svc, method, input, threadingModels, svc.getClass().getSimpleName());
}

public BaseTestDefinition(
BindableService svc,
String method,
List<MessageLiteOrBuilder> input,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels,
String named) {
this.svc = svc;
Expand All @@ -277,8 +298,8 @@ public String getMethod() {
}

@Override
public List<MessageLite> getInput() {
return input.stream().map(ProtoUtils::build).collect(Collectors.toList());
public List<InvocationInput> getInput() {
return input;
}

@Override
Expand All @@ -298,7 +319,7 @@ static class ExpectingOutputMessages extends BaseTestDefinition {
ExpectingOutputMessages(
BindableService svc,
String method,
List<MessageLiteOrBuilder> input,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels,
Consumer<List<MessageLite>> messagesAssert) {
super(svc, method, input, threadingModels);
Expand All @@ -308,7 +329,7 @@ static class ExpectingOutputMessages extends BaseTestDefinition {
ExpectingOutputMessages(
BindableService svc,
String method,
List<MessageLiteOrBuilder> input,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels,
Consumer<List<MessageLite>> messagesAssert,
String named) {
Expand Down Expand Up @@ -351,7 +372,7 @@ static class ExpectingFailure extends BaseTestDefinition {
ExpectingFailure(
BindableService svc,
String method,
List<MessageLiteOrBuilder> input,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels,
Consumer<Throwable> throwableAssert) {
super(svc, method, input, threadingModels);
Expand All @@ -361,7 +382,7 @@ static class ExpectingFailure extends BaseTestDefinition {
ExpectingFailure(
BindableService svc,
String method,
List<MessageLiteOrBuilder> input,
List<InvocationInput> input,
HashSet<ThreadingModel> threadingModels,
Consumer<Throwable> throwableAssert,
String named) {
Expand Down
Loading

0 comments on commit 7ca190b

Please sign in to comment.