Skip to content

Commit

Permalink
Wrap consumer.poll() for KafkaConsumer iteration (#1902)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpkp authored Sep 29, 2019
1 parent a9f513c commit 5d1d424
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 11 deletions.
10 changes: 6 additions & 4 deletions kafka/consumer/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def _retrieve_offsets(self, timestamps, timeout_ms=float("inf")):
raise Errors.KafkaTimeoutError(
"Failed to get offsets by timestamps in %s ms" % (timeout_ms,))

def fetched_records(self, max_records=None):
def fetched_records(self, max_records=None, update_offsets=True):
"""Returns previously fetched records and updates consumed offsets.
Arguments:
Expand Down Expand Up @@ -330,10 +330,11 @@ def fetched_records(self, max_records=None):
else:
records_remaining -= self._append(drained,
self._next_partition_records,
records_remaining)
records_remaining,
update_offsets)
return dict(drained), bool(self._completed_fetches)

def _append(self, drained, part, max_records):
def _append(self, drained, part, max_records, update_offsets):
if not part:
return 0

Expand Down Expand Up @@ -366,7 +367,8 @@ def _append(self, drained, part, max_records):
for record in part_records:
drained[tp].append(record)

self._subscriptions.assignment[tp].position = next_offset
if update_offsets:
self._subscriptions.assignment[tp].position = next_offset
return len(part_records)

else:
Expand Down
69 changes: 63 additions & 6 deletions kafka/consumer/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ class KafkaConsumer(six.Iterator):
'sasl_plain_password': None,
'sasl_kerberos_service_name': 'kafka',
'sasl_kerberos_domain_name': None,
'sasl_oauth_token_provider': None
'sasl_oauth_token_provider': None,
'legacy_iterator': False, # enable to revert to < 1.4.7 iterator
}
DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000

Expand Down Expand Up @@ -597,7 +598,7 @@ def partitions_for_topic(self, topic):
partitions = cluster.partitions_for_topic(topic)
return partitions

def poll(self, timeout_ms=0, max_records=None):
def poll(self, timeout_ms=0, max_records=None, update_offsets=True):
"""Fetch data from assigned topics / partitions.
Records are fetched and returned in batches by topic-partition.
Expand All @@ -621,6 +622,12 @@ def poll(self, timeout_ms=0, max_records=None):
dict: Topic to list of records since the last fetch for the
subscribed list of topics and partitions.
"""
        # Note: update_offsets is an internal-use only argument. It is used to
        # support the python iterator interface, which wraps consumer.poll()
        # and requires that the partition offsets tracked by the fetcher are
        # not updated until the iterator returns each record to the user.
        # As such, the argument is not documented and may change or be removed
        # without notice; library users should not rely on it.
assert timeout_ms >= 0, 'Timeout must not be negative'
if max_records is None:
max_records = self.config['max_poll_records']
Expand All @@ -631,7 +638,7 @@ def poll(self, timeout_ms=0, max_records=None):
start = time.time()
remaining = timeout_ms
while True:
records = self._poll_once(remaining, max_records)
records = self._poll_once(remaining, max_records, update_offsets=update_offsets)
if records:
return records

Expand All @@ -641,7 +648,7 @@ def poll(self, timeout_ms=0, max_records=None):
if remaining <= 0:
return {}

def _poll_once(self, timeout_ms, max_records):
def _poll_once(self, timeout_ms, max_records, update_offsets=True):
"""Do one round of polling. In addition to checking for new data, this does
any needed heart-beating, auto-commits, and offset updates.
Expand All @@ -660,7 +667,7 @@ def _poll_once(self, timeout_ms, max_records):

# If data is available already, e.g. from a previous network client
# poll() call to commit, then just return it immediately
records, partial = self._fetcher.fetched_records(max_records)
records, partial = self._fetcher.fetched_records(max_records, update_offsets=update_offsets)
if records:
# Before returning the fetched records, we can send off the
# next round of fetches and avoid block waiting for their
Expand All @@ -680,7 +687,7 @@ def _poll_once(self, timeout_ms, max_records):
if self._coordinator.need_rejoin():
return {}

records, _ = self._fetcher.fetched_records(max_records)
records, _ = self._fetcher.fetched_records(max_records, update_offsets=update_offsets)
return records

def position(self, partition):
Expand Down Expand Up @@ -743,6 +750,9 @@ def pause(self, *partitions):
for partition in partitions:
log.debug("Pausing partition %s", partition)
self._subscription.pause(partition)
# Because the iterator checks is_fetchable() on each iteration
# we expect pauses to get handled automatically and therefore
# we do not need to reset the full iterator (forcing a full refetch)

def paused(self):
"""Get the partitions that were previously paused using
Expand Down Expand Up @@ -790,6 +800,8 @@ def seek(self, partition, offset):
assert partition in self._subscription.assigned_partitions(), 'Unassigned partition'
log.debug("Seeking to offset %s for partition %s", offset, partition)
self._subscription.assignment[partition].seek(offset)
if not self.config['legacy_iterator']:
self._iterator = None

def seek_to_beginning(self, *partitions):
"""Seek to the oldest available offset for partitions.
Expand All @@ -814,6 +826,8 @@ def seek_to_beginning(self, *partitions):
for tp in partitions:
log.debug("Seeking to beginning of partition %s", tp)
self._subscription.need_offset_reset(tp, OffsetResetStrategy.EARLIEST)
if not self.config['legacy_iterator']:
self._iterator = None

def seek_to_end(self, *partitions):
"""Seek to the most recent available offset for partitions.
Expand All @@ -838,6 +852,8 @@ def seek_to_end(self, *partitions):
for tp in partitions:
log.debug("Seeking to end of partition %s", tp)
self._subscription.need_offset_reset(tp, OffsetResetStrategy.LATEST)
if not self.config['legacy_iterator']:
self._iterator = None

def subscribe(self, topics=(), pattern=None, listener=None):
"""Subscribe to a list of topics, or a topic regex pattern.
Expand Down Expand Up @@ -913,6 +929,8 @@ def unsubscribe(self):
self._client.cluster.need_all_topic_metadata = False
self._client.set_topics([])
log.debug("Unsubscribed all topics or patterns and assigned partitions")
if not self.config['legacy_iterator']:
self._iterator = None

def metrics(self, raw=False):
"""Get metrics on consumer performance.
Expand Down Expand Up @@ -1075,6 +1093,25 @@ def _update_fetch_positions(self, partitions):
# Then, do any offset lookups in case some positions are not known
self._fetcher.update_fetch_positions(partitions)

def _message_generator_v2(self):
timeout_ms = 1000 * (self._consumer_timeout - time.time())
record_map = self.poll(timeout_ms=timeout_ms, update_offsets=False)
for tp, records in six.iteritems(record_map):
# Generators are stateful, and it is possible that the tp / records
# here may become stale during iteration -- i.e., we seek to a
# different offset, pause consumption, or lose assignment.
for record in records:
# is_fetchable(tp) should handle assignment changes and offset
# resets; for all other changes (e.g., seeks) we'll rely on the
# outer function destroying the existing iterator/generator
# via self._iterator = None
if not self._subscription.is_fetchable(tp):
log.debug("Not returning fetched records for partition %s"
" since it is no longer fetchable", tp)
break
self._subscription.assignment[tp].position = record.offset + 1
yield record

def _message_generator(self):
assert self.assignment() or self.subscription() is not None, 'No topic subscription or manual partition assignment'
while time.time() < self._consumer_timeout:
Expand Down Expand Up @@ -1127,6 +1164,26 @@ def __iter__(self): # pylint: disable=non-iterator-returned
return self

def __next__(self):
# Now that the heartbeat thread runs in the background
# there should be no reason to maintain a separate iterator
# but we'll keep it available for a few releases just in case
if self.config['legacy_iterator']:
return self.next_v1()
else:
return self.next_v2()

def next_v2(self):
self._set_consumer_timeout()
while time.time() < self._consumer_timeout:
if not self._iterator:
self._iterator = self._message_generator_v2()
try:
return next(self._iterator)
except StopIteration:
self._iterator = None
raise StopIteration()

def next_v1(self):
if not self._iterator:
self._iterator = self._message_generator()

Expand Down
6 changes: 5 additions & 1 deletion kafka/coordinator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,14 @@ def poll_heartbeat(self):
self.heartbeat.poll()

def time_to_next_heartbeat(self):
"""Returns seconds (float) remaining before next heartbeat should be sent
Note: Returns infinite if group is not joined
"""
with self._lock:
# if we have not joined the group, we don't need to send heartbeats
if self.state is MemberState.UNJOINED:
return sys.maxsize
return float('inf')
return self.heartbeat.time_to_next_heartbeat()

def _handle_join_success(self, member_assignment_bytes):
Expand Down

0 comments on commit 5d1d424

Please sign in to comment.