From e39c7babdd80ecdd930bed9d201e8fb9187309ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kristj=C3=A1n=20Valur=20J=C3=B3nsson?= Date: Sun, 22 Jan 2023 18:49:12 +0000 Subject: [PATCH] Simplify the sync SocketBuffer, add type hints (#2543) --- CHANGES | 1 + redis/connection.py | 83 ++++++++++++++++++++++++--------------------- 2 files changed, 45 insertions(+), 39 deletions(-) diff --git a/CHANGES b/CHANGES index 02daf5ee4c..d89079ba6f 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Simplify synchronous SocketBuffer state management * Fix string cleanse in Redis Graph * Make PythonParser resumable in case of error (#2510) * Add `timeout=None` in `SentinelConnectionManager.read_response` diff --git a/redis/connection.py b/redis/connection.py index 126ea5db32..57f0a3a81e 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -5,10 +5,11 @@ import socket import threading import weakref +from io import SEEK_END from itertools import chain from queue import Empty, Full, LifoQueue from time import time -from typing import Optional +from typing import Optional, Union from urllib.parse import parse_qs, unquote, urlparse from redis.backoff import NoBackoff @@ -163,31 +164,40 @@ def parse_error(self, response): class SocketBuffer: - def __init__(self, socket, socket_read_size, socket_timeout): + def __init__( + self, socket: socket.socket, socket_read_size: int, socket_timeout: float + ): self._sock = socket self.socket_read_size = socket_read_size self.socket_timeout = socket_timeout self._buffer = io.BytesIO() - # number of bytes written to the buffer from the socket - self.bytes_written = 0 - # number of bytes read from the buffer - self.bytes_read = 0 - @property - def length(self): - return self.bytes_written - self.bytes_read + def unread_bytes(self) -> int: + """ + Remaining unread length of buffer + """ + pos = self._buffer.tell() + end = self._buffer.seek(0, SEEK_END) + self._buffer.seek(pos) + return end - pos - def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True): + def _read_from_socket( + self, + length: Optional[int] = None, + timeout: Union[float, object] = SENTINEL, + raise_on_timeout: Optional[bool] = True, + ) -> bool: sock = self._sock socket_read_size = self.socket_read_size - buf = self._buffer - buf.seek(self.bytes_written) marker = 0 custom_timeout = timeout is not SENTINEL + buf = self._buffer + current_pos = buf.tell() + buf.seek(0, SEEK_END) + if custom_timeout: + sock.settimeout(timeout) try: - if custom_timeout: - sock.settimeout(timeout) while True: data = self._sock.recv(socket_read_size) # an empty string indicates the server shutdown the socket @@ -195,7 +205,6 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) buf.write(data) data_length = len(data) - self.bytes_written += data_length marker += data_length if length is not None and length > marker: @@ -215,55 +224,53 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True return False raise ConnectionError(f"Error while reading from socket: {ex.args}") finally: + buf.seek(current_pos) if custom_timeout: sock.settimeout(self.socket_timeout) - def can_read(self, timeout): - return bool(self.length) or self._read_from_socket( + def can_read(self, timeout: float) -> bool: + return bool(self.unread_bytes()) or self._read_from_socket( timeout=timeout, raise_on_timeout=False ) - def read(self, length): + def read(self, length: int) -> bytes: length = length + 2 # make sure to read the \r\n terminator - # make sure we've read enough data from the socket - if length > self.length: - self._read_from_socket(length - self.length) - - self._buffer.seek(self.bytes_read) + # BufferIO will return less than requested if buffer is short data = self._buffer.read(length) - self.bytes_read += len(data) + missing = length - len(data) + if missing: + # fill up the buffer and read the remainder + self._read_from_socket(missing) + data += self._buffer.read(missing) return data[:-2] - def readline(self): + def readline(self) -> bytes: buf = self._buffer - buf.seek(self.bytes_read) data = buf.readline() while not data.endswith(SYM_CRLF): # there's more data in the socket that we need self._read_from_socket() - buf.seek(self.bytes_read) - data = buf.readline() + data += buf.readline() - self.bytes_read += len(data) return data[:-2] - def get_pos(self): + def get_pos(self) -> int: """ Get current read position """ - return self.bytes_read + return self._buffer.tell() - def rewind(self, pos): + def rewind(self, pos: int) -> None: """ Rewind the buffer to a specific position, to re-start reading """ - self.bytes_read = pos + self._buffer.seek(pos) - def purge(self): + def purge(self) -> None: """ After a successful read, purge the read part of buffer """ - unread = self.bytes_written - self.bytes_read + unread = self.unread_bytes() # Only if we have read all of the buffer do we truncate, to # reduce the amount of memory thrashing. This heuristic @@ -276,13 +283,10 @@ def purge(self): view = self._buffer.getbuffer() view[:unread] = view[-unread:] self._buffer.truncate(unread) - self.bytes_written = unread - self.bytes_read = 0 self._buffer.seek(0) - def close(self): + def close(self) -> None: try: - self.bytes_written = self.bytes_read = 0 self._buffer.close() except Exception: # issue #633 suggests the purge/close somehow raised a @@ -498,6 +502,7 @@ def read_response(self, disable_decoding=False): return response +DefaultParser: BaseParser if HIREDIS_AVAILABLE: DefaultParser = HiredisParser else: