Skip to content

Commit

Permalink
fix: Use a cursor client so we can close the client (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
palmere-google authored Jan 6, 2022
1 parent 1bff893 commit 7de246d
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public UserCodeClassLoader getUserCodeClassLoader() {
});
return new PubsubLiteSourceReader<>(
new PubsubLiteRecordEmitter<>(),
settings.getCursorCommitter(),
settings.getCursorClient(),
settings.getSplitReaderSupplier(),
new Configuration(),
readerContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import static com.google.cloud.pubsublite.internal.wire.ServiceClients.addDefaultMetadata;
import static com.google.cloud.pubsublite.internal.wire.ServiceClients.addDefaultSettings;

import com.google.api.core.ApiFutureCallback;
import com.google.api.core.ApiFutures;
import com.google.api.gax.rpc.ApiException;
import com.google.auto.value.AutoValue;
import com.google.cloud.pubsublite.AdminClient;
Expand Down Expand Up @@ -50,9 +48,7 @@
import com.google.cloud.pubsublite.proto.SeekRequest;
import com.google.cloud.pubsublite.v1.SubscriberServiceClient;
import com.google.cloud.pubsublite.v1.SubscriberServiceSettings;
import com.google.common.util.concurrent.MoreExecutors;
import java.io.Serializable;
import java.util.function.Consumer;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.flink.api.connector.source.Boundedness;
Expand Down Expand Up @@ -182,24 +178,6 @@ Supplier<SplitReader<Record<OutputT>, SubscriptionPartitionSplit>> getSplitReade
timestampSelector());
}

Consumer<SubscriptionPartitionSplit> getCursorCommitter() {
CursorClient client = getCursorClient();
return (SubscriptionPartitionSplit split) -> {
ApiFutures.addCallback(
client.commitCursor(split.subscriptionPath(), split.partition(), split.start()),
new ApiFutureCallback<Void>() {
@Override
public void onFailure(Throwable throwable) {
LOG.error("Failed to commit cursor to Pub/Sub Lite ", throwable);
}

@Override
public void onSuccess(Void unused) {}
},
MoreExecutors.directExecutor());
};
}

@AutoValue.Builder
public abstract static class Builder<OutputT> {
// Required
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
*/
package com.google.cloud.pubsublite.flink.internal.reader;

import com.google.api.core.ApiFutureCallback;
import com.google.api.core.ApiFutures;
import com.google.cloud.pubsublite.flink.internal.split.SubscriptionPartitionSplit;
import com.google.cloud.pubsublite.internal.CursorClient;
import com.google.cloud.pubsublite.internal.wire.SystemExecutors;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.concurrent.GuardedBy;
import java.util.Collection;
Expand All @@ -25,19 +29,22 @@
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class CheckpointCursorCommitter implements AutoCloseable {
private static final Logger LOG = LoggerFactory.getLogger(CheckpointCursorCommitter.class);

public class CheckpointCursorCommitter {
@GuardedBy("this")
private final Set<SubscriptionPartitionSplit> finished = new HashSet<>();

@GuardedBy("this")
private final LinkedHashMap<Long, List<SubscriptionPartitionSplit>> checkpoints =
new LinkedHashMap<>();

private final Consumer<SubscriptionPartitionSplit> cursorCommitter;
private final CursorClient cursorCommitter;

public CheckpointCursorCommitter(Consumer<SubscriptionPartitionSplit> cursorCommitter) {
public CheckpointCursorCommitter(CursorClient cursorCommitter) {
this.cursorCommitter = cursorCommitter;
}

Expand All @@ -51,13 +58,28 @@ public synchronized void addCheckpoint(
.build());
}

private void commitCursor(SubscriptionPartitionSplit split) {
ApiFutures.addCallback(
cursorCommitter.commitCursor(split.subscriptionPath(), split.partition(), split.start()),
new ApiFutureCallback<Void>() {
@Override
public void onFailure(Throwable throwable) {
LOG.error("Failed to commit cursor to Pub/Sub Lite ", throwable);
}

@Override
public void onSuccess(Void unused) {}
},
SystemExecutors.getAlarmExecutor());
}

public synchronized void notifyCheckpointComplete(long checkpointId) {
if (!checkpoints.containsKey(checkpointId)) {
return;
}
// Commit offsets corresponding to this checkpoint.
List<SubscriptionPartitionSplit> splits = checkpoints.get(checkpointId);
splits.forEach(cursorCommitter);
splits.forEach(this::commitCursor);
// Prune all checkpoints created before the one we just committed.
Iterator<Entry<Long, List<SubscriptionPartitionSplit>>> iter =
checkpoints.entrySet().iterator();
Expand All @@ -73,4 +95,9 @@ public synchronized void notifyCheckpointComplete(long checkpointId) {
public synchronized void notifySplitFinished(Collection<SubscriptionPartitionSplit> splits) {
finished.addAll(splits);
}

@Override
public void close() throws Exception {
cursorCommitter.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import com.google.cloud.pubsublite.flink.internal.split.SubscriptionPartitionSplit;
import com.google.cloud.pubsublite.flink.internal.split.SubscriptionPartitionSplitState;
import com.google.cloud.pubsublite.internal.CursorClient;
import java.util.*;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.flink.api.connector.source.SourceReaderContext;
Expand All @@ -34,7 +34,7 @@ public class PubsubLiteSourceReader<T>

public PubsubLiteSourceReader(
RecordEmitter<Record<T>, T, SubscriptionPartitionSplitState> recordEmitter,
Consumer<SubscriptionPartitionSplit> cursorCommitter,
CursorClient cursorCommitter,
Supplier<SplitReader<Record<T>, SubscriptionPartitionSplit>> splitReaderSupplier,
Configuration config,
SourceReaderContext context) {
Expand Down Expand Up @@ -78,4 +78,13 @@ protected void onSplitFinished(Map<String, SubscriptionPartitionSplitState> map)
.map(SubscriptionPartitionSplitState::toSplit)
.collect(Collectors.toList()));
}

@Override
public void close() throws Exception {
try {
checkpointCursorCommitter.close();
} finally {
super.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@

import static com.google.cloud.pubsublite.internal.testing.UnitTestExamples.examplePartition;
import static com.google.cloud.pubsublite.internal.testing.UnitTestExamples.exampleSubscriptionPath;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.api.core.ApiFutures;
import com.google.cloud.pubsublite.Offset;
import com.google.cloud.pubsublite.Partition;
import com.google.cloud.pubsublite.flink.internal.split.SubscriptionPartitionSplit;
import com.google.cloud.pubsublite.internal.CursorClient;
import com.google.common.collect.ImmutableList;
import java.util.function.Consumer;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -34,12 +37,14 @@

@RunWith(MockitoJUnitRunner.class)
public class CheckpointCursorCommitterTest {
@Mock Consumer<SubscriptionPartitionSplit> mockConsumer;
@Mock CursorClient mockCursorClient;
CheckpointCursorCommitter cursorCommitter;

@Before
public void setUp() {
cursorCommitter = new CheckpointCursorCommitter(mockConsumer);
cursorCommitter = new CheckpointCursorCommitter(mockCursorClient);
when(mockCursorClient.commitCursor(any(), any(), any()))
.thenReturn(ApiFutures.immediateFuture(null));
}

public static SubscriptionPartitionSplit splitFromPartition(Partition partition) {
Expand All @@ -52,23 +57,25 @@ public void testFinishedSplits() {
cursorCommitter.notifySplitFinished(ImmutableList.of(split));
cursorCommitter.addCheckpoint(1, ImmutableList.of());
cursorCommitter.notifyCheckpointComplete(1);
verify(mockConsumer).accept(split);
verify(mockCursorClient)
.commitCursor(split.subscriptionPath(), split.partition(), split.start());
}

@Test
public void testCheckpointCommitted() {
SubscriptionPartitionSplit split = splitFromPartition(examplePartition());
cursorCommitter.addCheckpoint(1, ImmutableList.of(split));
cursorCommitter.notifyCheckpointComplete(1);
verify(mockConsumer).accept(split);
verify(mockCursorClient)
.commitCursor(split.subscriptionPath(), split.partition(), split.start());
}

@Test
public void testUnknownCheckpoint() {
SubscriptionPartitionSplit split = splitFromPartition(examplePartition());
cursorCommitter.addCheckpoint(1, ImmutableList.of(split));
cursorCommitter.notifyCheckpointComplete(4);
verifyNoInteractions(mockConsumer);
verifyNoInteractions(mockCursorClient);
}

@Test
Expand All @@ -81,8 +88,15 @@ public void testIntermediateCheckpointSkipped() {

// Checkpoint 1 is committed, removing checkpoint 2
cursorCommitter.notifyCheckpointComplete(1);
verify(mockConsumer).accept(split1);
verify(mockCursorClient)
.commitCursor(split1.subscriptionPath(), split1.partition(), split1.start());
cursorCommitter.notifyCheckpointComplete(2);
verifyNoMoreInteractions(mockConsumer);
verifyNoMoreInteractions(mockCursorClient);
}

@Test
public void testClose() throws Exception {
cursorCommitter.close();
verify(mockCursorClient).close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,23 @@

import static com.google.cloud.pubsublite.internal.testing.UnitTestExamples.exampleSubscriptionPath;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.api.core.ApiFutures;
import com.google.cloud.pubsublite.Offset;
import com.google.cloud.pubsublite.Partition;
import com.google.cloud.pubsublite.SequencedMessage;
import com.google.cloud.pubsublite.flink.MessageTimestampExtractor;
import com.google.cloud.pubsublite.flink.PubsubLiteDeserializationSchema;
import com.google.cloud.pubsublite.flink.internal.split.SubscriptionPartitionSplit;
import com.google.cloud.pubsublite.internal.CursorClient;
import com.google.cloud.pubsublite.proto.Cursor;
import com.google.cloud.pubsublite.proto.PubSubMessage;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.util.Optional;
import java.util.function.Consumer;
import org.apache.flink.api.common.serialization.SimpleStringSchema;
import org.apache.flink.api.connector.source.*;
import org.apache.flink.configuration.Configuration;
Expand All @@ -46,7 +48,7 @@
@RunWith(MockitoJUnitRunner.class)
public class PubsubLiteSourceReaderTest {
@Mock CompletablePullSubscriber.Factory mockFactory;
@Mock Consumer<SubscriptionPartitionSplit> mockCursorCommitter;
@Mock CursorClient mockCursorClient;
@Mock SourceReaderContext mockContext;
TestingReaderOutput<String> output = new TestingReaderOutput<>();
SourceReader<String, SubscriptionPartitionSplit> reader;
Expand Down Expand Up @@ -75,14 +77,16 @@ public void setUp() {
reader =
new PubsubLiteSourceReader<>(
new PubsubLiteRecordEmitter<>(),
mockCursorCommitter,
mockCursorClient,
() ->
new DeserializingSplitReader<>(
new MessageSplitReader(mockFactory),
PubsubLiteDeserializationSchema.dataOnly(new SimpleStringSchema()),
MessageTimestampExtractor.publishTimeExtractor()),
new Configuration(),
mockContext);
when(mockCursorClient.commitCursor(any(), any(), any()))
.thenReturn(ApiFutures.immediateFuture(null));
}

@Test
Expand All @@ -100,8 +104,8 @@ public void testReader() throws Exception {

reader.snapshotState(1);
reader.notifyCheckpointComplete(1);
verify(mockCursorCommitter).accept(makeSplit(Partition.of(0), Offset.of(2)));
verify(mockCursorCommitter).accept(makeSplit(Partition.of(0), Offset.of(2)));
verify(mockCursorClient).commitCursor(exampleSubscriptionPath(), Partition.of(0), Offset.of(2));
verify(mockCursorClient).commitCursor(exampleSubscriptionPath(), Partition.of(1), Offset.of(3));

while (output.getEmittedRecords().size() < 6) {
reader.pollNext(output);
Expand Down

0 comments on commit 7de246d

Please sign in to comment.