add equals/hashCode to debezium offsetHolder to avoid coder warning (a…
scwhittle authored Feb 25, 2025
1 parent f0cc9b6 commit 87f0ed3
Showing 2 changed files with 116 additions and 18 deletions.
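
The "coder warning" in the title is presumably the one Beam's SerializableCoder logs when an element type does not override equals(), since a Java-serialization-based coder cannot otherwise promise consistent structural equality across runners. The subtlety in this fix is that OffsetHolder.history holds byte[] entries, and Java arrays inherit reference-based equals() and hashCode() from Object, so Objects.equals and List.equals do not compare array contents. A minimal standalone sketch of that pitfall (ByteArrayEqualityDemo is an illustrative name, not part of this commit):

import java.util.Arrays;
import java.util.Objects;

public class ByteArrayEqualityDemo {
  public static void main(String[] args) {
    byte[] a = {1, 2, 3};
    byte[] b = {1, 2, 3};
    // Arrays inherit Object.equals, i.e. reference equality:
    System.out.println(a.equals(b));          // false
    System.out.println(Objects.equals(a, b)); // false
    // Content comparison needs Arrays.equals:
    System.out.println(Arrays.equals(a, b));  // true
  }
}

This is why the equals() added below zips the two history lists with Arrays::equals instead of simply calling history.equals(otherOffset.history).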
KafkaSourceConsumerFn.java
@@ -29,13 +29,16 @@
 import java.io.Serializable;
 import java.lang.reflect.InvocationTargetException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Consumer;
+import java.util.function.Predicate;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -46,6 +49,9 @@
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Streams;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.hash.Hasher;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.hash.Hashing;
 import org.apache.kafka.connect.source.SourceConnector;
 import org.apache.kafka.connect.source.SourceRecord;
 import org.apache.kafka.connect.source.SourceTask;
@@ -72,7 +78,7 @@
  *
  * <p>It might be initialized either as:
  *
- * <pre>KafkaSourceConsumerFn(connectorClass, SourceRecordMapper, maxRecords, milisecondsToRun)
+ * <pre>KafkaSourceConsumerFn(connectorClass, SourceRecordMapper, maxRecords, millisecondsToRun)
  * </pre>
  *
  * Or with a time limiter:
Expand All @@ -87,7 +93,7 @@ public class KafkaSourceConsumerFn<T> extends DoFn<Map<String, String>, T> {
private final Class<? extends SourceConnector> connectorClass;
private final SourceRecordMapper<T> fn;

private final Long milisecondsToRun;
private final Long millisecondsToRun;
private final Integer maxRecords;

private static DateTime startTime;
@@ -100,17 +106,18 @@ public class KafkaSourceConsumerFn<T> extends DoFn<Map<String, String>, T> {
    * @param connectorClass Supported Debezium connector class
    * @param fn a SourceRecordMapper
    * @param maxRecords Maximum number of records to fetch before finishing.
-   * @param milisecondsToRun Maximum time to run (in milliseconds)
+   * @param millisecondsToRun Maximum time to run (in milliseconds)
    */
+  @SuppressWarnings("unchecked")
   KafkaSourceConsumerFn(
       Class<?> connectorClass,
       SourceRecordMapper<T> fn,
       Integer maxRecords,
-      Long milisecondsToRun) {
+      Long millisecondsToRun) {
     this.connectorClass = (Class<? extends SourceConnector>) connectorClass;
     this.fn = fn;
     this.maxRecords = maxRecords;
-    this.milisecondsToRun = milisecondsToRun;
+    this.millisecondsToRun = millisecondsToRun;
   }
 
   /**
@@ -128,7 +135,7 @@ public class KafkaSourceConsumerFn<T> extends DoFn<Map<String, String>, T> {
   public OffsetHolder getInitialRestriction(@Element Map<String, String> unused)
       throws IOException {
     KafkaSourceConsumerFn.startTime = new DateTime();
-    return new OffsetHolder(null, null, null, this.maxRecords, this.milisecondsToRun);
+    return new OffsetHolder(null, null, null, this.maxRecords, this.millisecondsToRun);
   }
 
   @NewTracker
@@ -275,7 +282,7 @@ public ProcessContinuation process(
     }
 
     long elapsedTime = System.currentTimeMillis() - KafkaSourceConsumerFn.startTime.getMillis();
-    if (milisecondsToRun != null && milisecondsToRun > 0 && elapsedTime >= milisecondsToRun) {
+    if (millisecondsToRun != null && millisecondsToRun > 0 && elapsedTime >= millisecondsToRun) {
       return ProcessContinuation.stop();
     } else {
       return ProcessContinuation.resume().withResumeDelay(Duration.standardSeconds(1));
@@ -336,30 +343,66 @@ public <T> Map<Map<String, T>, Map<String, Object>> offsets(
 
   static class OffsetHolder implements Serializable {
     public @Nullable Map<String, ?> offset;
-    public @Nullable List<?> history;
+    public @Nullable List<byte[]> history;
     public @Nullable Integer fetchedRecords;
     public @Nullable Integer maxRecords;
-    public final @Nullable Long milisToRun;
+    public final @Nullable Long millisToRun;
 
     OffsetHolder(
         @Nullable Map<String, ?> offset,
-        @Nullable List<?> history,
+        @Nullable List<byte[]> history,
         @Nullable Integer fetchedRecords,
         @Nullable Integer maxRecords,
-        @Nullable Long milisToRun) {
+        @Nullable Long millisToRun) {
       this.offset = offset;
       this.history = history == null ? new ArrayList<>() : history;
       this.fetchedRecords = fetchedRecords;
       this.maxRecords = maxRecords;
-      this.milisToRun = milisToRun;
+      this.millisToRun = millisToRun;
     }
 
     OffsetHolder(
         @Nullable Map<String, ?> offset,
-        @Nullable List<?> history,
+        @Nullable List<byte[]> history,
        @Nullable Integer fetchedRecords) {
      this(offset, history, fetchedRecords, null, -1L);
    }
 
+    @Override
+    public boolean equals(Object other) {
+      if (!(other instanceof OffsetHolder)) {
+        return false;
+      }
+      OffsetHolder otherOffset = (OffsetHolder) other;
+      if (history == null) {
+        return otherOffset.history == null;
+      } else {
+        if (otherOffset.history == null) {
+          return false;
+        }
+        if (history.size() != otherOffset.history.size()) {
+          return false;
+        }
+        if (!Streams.zip(history.stream(), otherOffset.history.stream(), Arrays::equals)
+            .allMatch(Predicate.isEqual(true))) {
+          return false;
+        }
+      }
+      return Objects.equals(offset, otherOffset.offset)
+          && Objects.equals(fetchedRecords, otherOffset.fetchedRecords)
+          && Objects.equals(maxRecords, otherOffset.maxRecords)
+          && Objects.equals(millisToRun, otherOffset.millisToRun);
+    }
+
+    @Override
+    public int hashCode() {
+      Hasher hasher = Hashing.goodFastHash(32).newHasher();
+      for (byte[] h : history) {
+        hasher.putInt(h.length);
+        hasher.putBytes(h);
+      }
+      return Objects.hash(offset, hasher.hash(), fetchedRecords, maxRecords, millisToRun);
+    }
   }
 
   /** {@link RestrictionTracker} for Debezium connectors. */
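
Note that the new hashCode() length-prefixes each history entry (putInt(h.length)) before feeding its bytes to the hasher; without the prefix, differently chunked histories with the same concatenated bytes would hash identically. A standalone sketch of that design choice — it uses plain Guava coordinates rather than Beam's vendored package, and LengthPrefixDemo/hashOf are illustrative names, not part of the commit:

import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;

public class LengthPrefixDemo {
  static int hashOf(byte[][] chunks) {
    // goodFastHash is seeded per JVM, so values are only stable within one process.
    Hasher hasher = Hashing.goodFastHash(32).newHasher();
    for (byte[] c : chunks) {
      hasher.putInt(c.length); // the length prefix marks the chunk boundary
      hasher.putBytes(c);
    }
    return hasher.hash().asInt();
  }

  public static void main(String[] args) {
    byte[][] abThenC = {{'a', 'b'}, {'c'}};
    byte[][] aThenBc = {{'a'}, {'b', 'c'}};
    // Same concatenated bytes, different chunking: the length prefixes
    // make the hasher inputs differ.
    System.out.println(hashOf(abThenC) == hashOf(aThenBc)); // false (with high probability)
  }
}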
@@ -395,7 +438,8 @@ public boolean tryClaim(Map<String, Object> position) {
       int fetchedRecords =
           this.restriction.fetchedRecords == null ? 0 : this.restriction.fetchedRecords + 1;
       LOG.debug("------------Fetched records {} / {}", fetchedRecords, this.restriction.maxRecords);
-      LOG.debug("-------------- Time running: {} / {}", elapsedTime, (this.restriction.milisToRun));
+      LOG.debug(
+          "-------------- Time running: {} / {}", elapsedTime, (this.restriction.millisToRun));
       this.restriction.offset = position;
       this.restriction.fetchedRecords = fetchedRecords;
       LOG.debug("-------------- History: {}", this.restriction.history);
@@ -404,9 +448,9 @@ public boolean tryClaim(Map<String, Object> position) {
       // the attempt to claim.
       // If we've reached neither, then we continue approve the claim.
       return (this.restriction.maxRecords == null || fetchedRecords < this.restriction.maxRecords)
-          && (this.restriction.milisToRun == null
-              || this.restriction.milisToRun == -1
-              || elapsedTime < this.restriction.milisToRun);
+          && (this.restriction.millisToRun == null
+              || this.restriction.millisToRun == -1
+              || elapsedTime < this.restriction.millisToRun);
     }
 
     @Override
KafkaSourceConsumerFnTest.java
@@ -17,11 +17,15 @@
  */
 package org.apache.beam.io.debezium;
 
+import com.google.common.testing.EqualsTester;
 import java.io.Serializable;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import org.apache.beam.io.debezium.KafkaSourceConsumerFn.OffsetHolder;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.MapCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
@@ -30,6 +34,7 @@
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.apache.kafka.common.config.AbstractConfig;
@@ -50,6 +55,7 @@
 
 @RunWith(JUnit4.class)
 public class KafkaSourceConsumerFnTest implements Serializable {
+
   @Test
   public void testKafkaSourceConsumerFn() {
     Map<String, String> config =
@@ -105,7 +111,55 @@ public void testStoppableKafkaSourceConsumerFn() {
     pipeline.run().waitUntilFinish();
     Assert.assertEquals(1, CounterTask.getCountTasks());
   }
-}
+
+  @Test
+  public void testKafkaOffsetHolderEquality() {
+    EqualsTester tester = new EqualsTester();
+
+    HashMap<String, Integer> map = new HashMap<>();
+    map.put("a", 1);
+    map.put("b", 2);
+    ArrayList<byte[]> list = new ArrayList<>();
+    list.add("abc".getBytes(StandardCharsets.US_ASCII));
+    list.add(new byte[0]);
+    tester.addEqualityGroup(
+        new OffsetHolder(
+            ImmutableMap.of("a", 1, "b", 2),
+            ImmutableList.of("abc".getBytes(StandardCharsets.US_ASCII), new byte[0]),
+            1,
+            null,
+            -1L),
+        new OffsetHolder(map, list, 1, null, -1L),
+        new OffsetHolder(map, list, 1, null, -1L),
+        new OffsetHolder(map, list, 1));
+    tester.addEqualityGroup(new OffsetHolder(null, null, null, null, null));
+    tester.addEqualityGroup(
+        new OffsetHolder(
+            ImmutableMap.of("a", 1),
+            ImmutableList.of("abc".getBytes(StandardCharsets.US_ASCII)),
+            1));
+    tester.addEqualityGroup(
+        new OffsetHolder(
+            ImmutableMap.of("a", 1),
+            ImmutableList.of("abc".getBytes(StandardCharsets.US_ASCII)),
+            2));
+    tester.addEqualityGroup(
+        new OffsetHolder(
+            ImmutableMap.of("a", 1),
+            ImmutableList.of("abc".getBytes(StandardCharsets.US_ASCII)),
+            1,
+            2,
+            null));
+    tester.addEqualityGroup(
+        new OffsetHolder(
+            ImmutableMap.of("a", 1),
+            ImmutableList.of("abc".getBytes(StandardCharsets.US_ASCII)),
+            1,
+            3,
+            null));
+    tester.testEquals();
+  }
+};
 
 class CounterSourceConnector extends SourceConnector {
   public static class CounterSourceConnectorConfig extends AbstractConfig {
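
For reference, EqualsTester (from Guava's guava-testlib) asserts that every member of an addEqualityGroup call is equal to every other member of the same group with matching hashCodes, and unequal to members of all other groups. A minimal sketch of the pattern, with an illustrative class name:

import com.google.common.testing.EqualsTester;

public class EqualsTesterDemo {
  public static void main(String[] args) {
    // testEquals() throws if any member violates the equals/hashCode contract.
    new EqualsTester()
        .addEqualityGroup("hello", new String("hello"))
        .addEqualityGroup("world")
        .testEquals();
  }
}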