Skip to content

Commit

Permalink
Simplify the sync SocketBuffer, add type hints (#2543)
Browse files Browse the repository at this point in the history
  • Loading branch information
kristjanvalur authored Jan 22, 2023
1 parent 5e258a1 commit e39c7ba
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 39 deletions.
1 change: 1 addition & 0 deletions CHANGES
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
* Simplify synchronous SocketBuffer state management
* Fix string cleanse in Redis Graph
* Make PythonParser resumable in case of error (#2510)
* Add `timeout=None` in `SentinelConnectionManager.read_response`
Expand Down
83 changes: 44 additions & 39 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import socket
import threading
import weakref
from io import SEEK_END
from itertools import chain
from queue import Empty, Full, LifoQueue
from time import time
from typing import Optional
from typing import Optional, Union
from urllib.parse import parse_qs, unquote, urlparse

from redis.backoff import NoBackoff
Expand Down Expand Up @@ -163,39 +164,47 @@ def parse_error(self, response):


class SocketBuffer:
def __init__(self, socket, socket_read_size, socket_timeout):
def __init__(
self, socket: socket.socket, socket_read_size: int, socket_timeout: float
):
self._sock = socket
self.socket_read_size = socket_read_size
self.socket_timeout = socket_timeout
self._buffer = io.BytesIO()
# number of bytes written to the buffer from the socket
self.bytes_written = 0
# number of bytes read from the buffer
self.bytes_read = 0

@property
def length(self):
return self.bytes_written - self.bytes_read
def unread_bytes(self) -> int:
"""
Remaining unread length of buffer
"""
pos = self._buffer.tell()
end = self._buffer.seek(0, SEEK_END)
self._buffer.seek(pos)
return end - pos

def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True):
def _read_from_socket(
self,
length: Optional[int] = None,
timeout: Union[float, object] = SENTINEL,
raise_on_timeout: Optional[bool] = True,
) -> bool:
sock = self._sock
socket_read_size = self.socket_read_size
buf = self._buffer
buf.seek(self.bytes_written)
marker = 0
custom_timeout = timeout is not SENTINEL

buf = self._buffer
current_pos = buf.tell()
buf.seek(0, SEEK_END)
if custom_timeout:
sock.settimeout(timeout)
try:
if custom_timeout:
sock.settimeout(timeout)
while True:
data = self._sock.recv(socket_read_size)
# an empty bytes result indicates the server shut down the socket
if isinstance(data, bytes) and len(data) == 0:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
buf.write(data)
data_length = len(data)
self.bytes_written += data_length
marker += data_length

if length is not None and length > marker:
Expand All @@ -215,55 +224,53 @@ def _read_from_socket(self, length=None, timeout=SENTINEL, raise_on_timeout=True
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")
finally:
buf.seek(current_pos)
if custom_timeout:
sock.settimeout(self.socket_timeout)

def can_read(self, timeout):
return bool(self.length) or self._read_from_socket(
def can_read(self, timeout: float) -> bool:
return bool(self.unread_bytes()) or self._read_from_socket(
timeout=timeout, raise_on_timeout=False
)

def read(self, length):
def read(self, length: int) -> bytes:
length = length + 2 # make sure to read the \r\n terminator
# make sure we've read enough data from the socket
if length > self.length:
self._read_from_socket(length - self.length)

self._buffer.seek(self.bytes_read)
# BytesIO.read() will return less than requested if the buffer is short
data = self._buffer.read(length)
self.bytes_read += len(data)
missing = length - len(data)
if missing:
# fill up the buffer and read the remainder
self._read_from_socket(missing)
data += self._buffer.read(missing)
return data[:-2]

def readline(self):
def readline(self) -> bytes:
buf = self._buffer
buf.seek(self.bytes_read)
data = buf.readline()
while not data.endswith(SYM_CRLF):
# there's more data in the socket that we need
self._read_from_socket()
buf.seek(self.bytes_read)
data = buf.readline()
data += buf.readline()

self.bytes_read += len(data)
return data[:-2]

def get_pos(self):
def get_pos(self) -> int:
"""
Get current read position
"""
return self.bytes_read
return self._buffer.tell()

def rewind(self, pos):
def rewind(self, pos: int) -> None:
"""
Rewind the buffer to a specific position, to re-start reading
"""
self.bytes_read = pos
self._buffer.seek(pos)

def purge(self):
def purge(self) -> None:
"""
After a successful read, purge the read part of buffer
"""
unread = self.bytes_written - self.bytes_read
unread = self.unread_bytes()

# Only if we have read all of the buffer do we truncate, to
# reduce the amount of memory thrashing. This heuristic
Expand All @@ -276,13 +283,10 @@ def purge(self):
view = self._buffer.getbuffer()
view[:unread] = view[-unread:]
self._buffer.truncate(unread)
self.bytes_written = unread
self.bytes_read = 0
self._buffer.seek(0)

def close(self):
def close(self) -> None:
try:
self.bytes_written = self.bytes_read = 0
self._buffer.close()
except Exception:
# issue #633 suggests the purge/close somehow raised a
Expand Down Expand Up @@ -498,6 +502,7 @@ def read_response(self, disable_decoding=False):
return response


DefaultParser: BaseParser
if HIREDIS_AVAILABLE:
DefaultParser = HiredisParser
else:
Expand Down

0 comments on commit e39c7ba

Please sign in to comment.