diff --git a/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/KafkaSourceConsumerFn.java b/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/KafkaSourceConsumerFn.java
index 0c9632a69926..54330d620477 100644
--- a/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/KafkaSourceConsumerFn.java
+++ b/sdks/java/io/debezium/src/main/java/org/apache/beam/io/debezium/KafkaSourceConsumerFn.java
@@ -29,13 +29,16 @@
 import java.io.Serializable;
 import java.lang.reflect.InvocationTargetException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Consumer;
+import java.util.function.Predicate;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -46,6 +49,9 @@
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.hash.Hasher;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.hash.Hashing;
 import org.apache.kafka.connect.source.SourceConnector;
 import org.apache.kafka.connect.source.SourceRecord;
 import org.apache.kafka.connect.source.SourceTask;
@@ -72,7 +78,7 @@
  *
  * <p>It might be initialized either as:
  *
- * <pre>KafkaSourceConsumerFn(connectorClass, SourceRecordMapper, maxRecords, milisecondsToRun)
+ * <pre>KafkaSourceConsumerFn(connectorClass, SourceRecordMapper, maxRecords, millisecondsToRun)
  * </pre>
  *
  * Or with a time limiter:
@@ -87,7 +93,7 @@
 public class KafkaSourceConsumerFn<T> extends DoFn<Map<String, String>, T> {
   private final Class<? extends SourceConnector> connectorClass;
   private final SourceRecordMapper<T> fn;
-  private final Long milisecondsToRun;
+  private final Long millisecondsToRun;
   private final Integer maxRecords;
 
   private static DateTime startTime;
@@ -100,17 +106,18 @@ public class KafkaSourceConsumerFn<T> extends DoFn<Map<String, String>, T> {
    * @param connectorClass Supported Debezium connector class
    * @param fn a SourceRecordMapper
    * @param maxRecords Maximum number of records to fetch before finishing.
-   * @param milisecondsToRun Maximum time to run (in milliseconds)
+   * @param millisecondsToRun Maximum time to run (in milliseconds)
    */
+  @SuppressWarnings("unchecked")
   KafkaSourceConsumerFn(
       Class<?> connectorClass,
       SourceRecordMapper<T> fn,
       Integer maxRecords,
-      Long milisecondsToRun) {
+      Long millisecondsToRun) {
     this.connectorClass = (Class<? extends SourceConnector>) connectorClass;
     this.fn = fn;
     this.maxRecords = maxRecords;
-    this.milisecondsToRun = milisecondsToRun;
+    this.millisecondsToRun = millisecondsToRun;
   }
 
   /**
@@ -128,7 +135,7 @@ public class KafkaSourceConsumerFn<T> extends DoFn<Map<String, String>, T> {
   public OffsetHolder getInitialRestriction(@Element Map<String, String> unused)
       throws IOException {
     KafkaSourceConsumerFn.startTime = new DateTime();
-    return new OffsetHolder(null, null, null, this.maxRecords, this.milisecondsToRun);
+    return new OffsetHolder(null, null, null, this.maxRecords, this.millisecondsToRun);
   }
 
   @NewTracker
@@ -275,7 +282,7 @@ public ProcessContinuation process(
     }
     long elapsedTime = System.currentTimeMillis() - KafkaSourceConsumerFn.startTime.getMillis();
-    if (milisecondsToRun != null && milisecondsToRun > 0 && elapsedTime >= milisecondsToRun) {
+    if (millisecondsToRun != null && millisecondsToRun > 0 && elapsedTime >= millisecondsToRun) {
       return ProcessContinuation.stop();
     } else {
       return ProcessContinuation.resume().withResumeDelay(Duration.standardSeconds(1));
     }
@@ -336,30 +343,66 @@ public Map<Map<String, ?>, Map<String, ?>> offsets(
   static class OffsetHolder implements Serializable {
     public @Nullable Map<String, ?> offset;
-    public @Nullable List<?> history;
+    public @Nullable List<byte[]> history;
     public @Nullable Integer fetchedRecords;
     public @Nullable Integer maxRecords;
-    public final @Nullable Long milisToRun;
+    public final @Nullable Long millisToRun;
 
     OffsetHolder(
         @Nullable Map<String, ?> offset,
-        @Nullable List<?> history,
+        @Nullable List<byte[]> history,
         @Nullable Integer fetchedRecords,
         @Nullable Integer maxRecords,
-        @Nullable Long milisToRun) {
+        @Nullable Long millisToRun) {
       this.offset = offset;
      this.history = history == null ? new ArrayList<>() : history;
      this.fetchedRecords = fetchedRecords;
      this.maxRecords = maxRecords;
-      this.milisToRun = milisToRun;
+      this.millisToRun = millisToRun;
    }
 
     OffsetHolder(
         @Nullable Map<String, ?> offset,
-        @Nullable List<?> history,
+        @Nullable List<byte[]> history,
         @Nullable Integer fetchedRecords) {
       this(offset, history, fetchedRecords, null, -1L);
     }
+
+    @Override
+    public boolean equals(Object other) {
+      if (!(other instanceof OffsetHolder)) {
+        return false;
+      }
+      OffsetHolder otherOffset = (OffsetHolder) other;
+      if (history == null) {
+        return otherOffset.history == null;
+      } else {
+        if (otherOffset.history == null) {
+          return false;
+        }
+        if (history.size() != otherOffset.history.size()) {
+          return false;
+        }
+        if (!Streams.zip(history.stream(), otherOffset.history.stream(), Arrays::equals)
+            .allMatch(Predicate.isEqual(true))) {
+          return false;
+        }
+      }
+      return Objects.equals(offset, otherOffset.offset)
+          && Objects.equals(fetchedRecords, otherOffset.fetchedRecords)
+          && Objects.equals(maxRecords, otherOffset.maxRecords)
+          && Objects.equals(millisToRun, otherOffset.millisToRun);
+    }
+
+    @Override
+    public int hashCode() {
+      Hasher hasher = Hashing.goodFastHash(32).newHasher();
+      for (byte[] h : history) {
+        hasher.putInt(h.length);
+        hasher.putBytes(h);
+      }
+      return Objects.hash(offset, hasher.hash(), fetchedRecords, maxRecords, millisToRun);
+    }
   }
 
   /** {@link RestrictionTracker} for Debezium connectors. */
@@ -395,7 +438,8 @@ public boolean tryClaim(Map<String, Object> position) {
     int fetchedRecords =
         this.restriction.fetchedRecords == null ? 0 : this.restriction.fetchedRecords + 1;
     LOG.debug("------------Fetched records {} / {}", fetchedRecords, this.restriction.maxRecords);
-    LOG.debug("-------------- Time running: {} / {}", elapsedTime, (this.restriction.milisToRun));
+    LOG.debug(
+        "-------------- Time running: {} / {}", elapsedTime, (this.restriction.millisToRun));
     this.restriction.offset = position;
     this.restriction.fetchedRecords = fetchedRecords;
     LOG.debug("-------------- History: {}", this.restriction.history);
@@ -404,9 +448,9 @@ public boolean tryClaim(Map<String, Object> position) {
     // the attempt to claim.
     // If we've reached neither, then we continue to approve the claim.
     return (this.restriction.maxRecords == null || fetchedRecords < this.restriction.maxRecords)
-        && (this.restriction.milisToRun == null
-            || this.restriction.milisToRun == -1
-            || elapsedTime < this.restriction.milisToRun);
+        && (this.restriction.millisToRun == null
+            || this.restriction.millisToRun == -1
+            || elapsedTime < this.restriction.millisToRun);
   }
 
   @Override
diff --git a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/KafkaSourceConsumerFnTest.java b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/KafkaSourceConsumerFnTest.java
index 2764b1c87d01..f5ada3033561 100644
--- a/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/KafkaSourceConsumerFnTest.java
+++ b/sdks/java/io/debezium/src/test/java/org/apache/beam/io/debezium/KafkaSourceConsumerFnTest.java
@@ -17,11 +17,15 @@
  */
 package org.apache.beam.io.debezium;
 
+import com.google.common.testing.EqualsTester;
 import java.io.Serializable;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import org.apache.beam.io.debezium.KafkaSourceConsumerFn.OffsetHolder;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.MapCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
@@ -30,6 +34,7 @@
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.apache.kafka.common.config.AbstractConfig;
@@ -50,6 +55,7 @@
 @RunWith(JUnit4.class)
 public class KafkaSourceConsumerFnTest implements Serializable {
+
   @Test
   public void testKafkaSourceConsumerFn() {
     Map<String, String> config =
@@ -105,7 +111,55 @@ public void testStoppableKafkaSourceConsumerFn() {
     pipeline.run().waitUntilFinish();
     Assert.assertEquals(1, CounterTask.getCountTasks());
   }
-}
+
+  @Test
+  public void testKafkaOffsetHolderEquality() {
+    EqualsTester tester = new EqualsTester();
+
+    HashMap<String, Integer> map = new HashMap<>();
+    map.put("a", 1);
+    map.put("b", 2);
+    ArrayList<byte[]> list = new ArrayList<>();
+    list.add("abc".getBytes(StandardCharsets.US_ASCII));
+    list.add(new byte[0]);
+    tester.addEqualityGroup(
+        new OffsetHolder(
+            ImmutableMap.of("a", 1, "b", 2),
+            ImmutableList.of("abc".getBytes(StandardCharsets.US_ASCII), new byte[0]),
+            1,
+            null,
+            -1L),
+        new OffsetHolder(map, list, 1, null, -1L),
+        new OffsetHolder(map, list, 1, null, -1L),
+        new OffsetHolder(map, list, 1));
+    tester.addEqualityGroup(new OffsetHolder(null, null, null, null, null));
+    tester.addEqualityGroup(
+        new OffsetHolder(
+            ImmutableMap.of("a", 1),
+            ImmutableList.of("abc".getBytes(StandardCharsets.US_ASCII)),
+            1));
+    tester.addEqualityGroup(
+        new OffsetHolder(
+            ImmutableMap.of("a", 1),
+            ImmutableList.of("abc".getBytes(StandardCharsets.US_ASCII)),
+            2));
+    tester.addEqualityGroup(
+        new OffsetHolder(
+            ImmutableMap.of("a", 1),
+            ImmutableList.of("abc".getBytes(StandardCharsets.US_ASCII)),
+            1,
+            2,
+            null));
+    tester.addEqualityGroup(
+        new OffsetHolder(
+            ImmutableMap.of("a", 1),
+            ImmutableList.of("abc".getBytes(StandardCharsets.US_ASCII)),
+            1,
+            3,
+            null));
+    tester.testEquals();
+  }
+}
 
 class CounterSourceConnector extends SourceConnector {
   public static class CounterSourceConnectorConfig extends AbstractConfig {