Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch from psycopg v2 to v3 #164

Merged
merged 17 commits into from
Mar 17, 2024
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -508,15 +508,15 @@ bucket = SQLiteBucket(rates, conn, table)

#### PostgresBucket

Postgres is supported, but you have to install `psycopg2` or `asyncpg` either as an extra or as a separate package.
Postgres is supported, but you have to install `psycopg[pool]` either as an extra or as a separate package.

You can use Postgres's built-in **CURRENT_TIMESTAMP** as the time source with `PostgresClock`, or use an external custom time source.

```python
from pyrate_limiter import PostgresBucket, Rate, PostgresClock
from psycopg2.pool import ThreadedConnectionPool
from psycopg_pool import ConnectionPool

connection_pool = ThreadedConnectionPool(5, 10, 'postgresql://postgres:postgres@localhost:5432')
connection_pool = ConnectionPool('postgresql://postgres:postgres@localhost:5432')

clock = PostgresClock(connection_pool)
rates = [Rate(3, 1000), Rate(4, 1500)]
Expand Down
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Reuse virtualenv created by poetry instead of creating new ones
nox.options.reuse_existing_virtualenvs = True

PYTEST_ARGS = ["--verbose", "--maxfail=1", "--numprocesses=8"]
PYTEST_ARGS = ["--verbose", "--maxfail=1", "--numprocesses=auto"]
COVERAGE_ARGS = ["--cov", "--cov-report=term", "--cov-report=xml", "--cov-report=html"]


Expand Down
97 changes: 78 additions & 19 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pyrate-limiter"
version = "3.5.1"
version = "3.6.0"
description = "Python Rate-Limiter using Leaky-Bucket Algorithm"
authors = ["vutr <me@vutr.io>"]
license = "MIT"
Expand Down Expand Up @@ -29,7 +29,7 @@ python = "^3.8"
# Optional backend dependencies
filelock = {optional=true, version=">=3.0"}
redis = {optional=true, version="^5.0.0"}
psycopg2 = {version = "^2.9.9", optional = true}
psycopg = {extras = ["pool"], version = "^3.1.18", optional = true}

# Documentation dependencies needed for Readthedocs builds
furo = {optional=true, version="^2022.3.4"}
Expand All @@ -40,7 +40,7 @@ sphinx-copybutton = {optional=true, version=">=0.5"}
sphinxcontrib-apidoc = {optional=true, version="^0.3"}

[tool.poetry.extras]
all = ["filelock", "redis", "psycopg2"]
all = ["filelock", "redis", "psycopg"]
docs = ["furo", "myst-parser", "sphinx", "sphinx-autodoc-typehints",
"sphinx-copybutton", "sphinxcontrib-apidoc"]

Expand All @@ -58,7 +58,7 @@ coverage = "6"
[tool.poetry.group.dev.dependencies]
pytest = "^8.1.1"
pytest-asyncio = "^0.23.5.post1"
psycopg2 = "^2.9.9"
psycopg = {extras = ["pool"], version = "^3.1.18"}

[tool.black]
line-length = 120
Expand Down
7 changes: 6 additions & 1 deletion pyrate_limiter/abstracts/bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(self, leak_interval: int):
self.async_buckets = defaultdict()
self.clocks = defaultdict()
self.leak_interval = leak_interval
self._task = None
super().__init__()

def register(self, bucket: AbstractBucket, clock: AbstractClock):
Expand Down Expand Up @@ -171,7 +172,7 @@ async def _leak(self, sync=True) -> None:
def leak_async(self):
if self.async_buckets and not self.is_async_leak_started:
self.is_async_leak_started = True
asyncio.create_task(self._leak(sync=False))
self._task = asyncio.create_task(self._leak(sync=False))

def run(self) -> None:
assert self.sync_buckets
Expand All @@ -181,6 +182,10 @@ def start(self) -> None:
if self.sync_buckets and not self.is_alive():
super().start()

def cancel(self) -> None:
if self._task:
self._task.cancel()


class BucketFactory(ABC):
"""Asbtract BucketFactory class.
Expand Down
56 changes: 25 additions & 31 deletions pyrate_limiter/buckets/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..abstracts import RateItem

if TYPE_CHECKING:
from psycopg2.pool import AbstractConnectionPool
from psycopg_pool import ConnectionPool


class Queries:
Expand Down Expand Up @@ -54,9 +54,9 @@ class Queries:

class PostgresBucket(AbstractBucket):
table: str
pool: AbstractConnectionPool
pool: ConnectionPool

def __init__(self, pool: AbstractConnectionPool, table: str, rates: List[Rate]):
def __init__(self, pool: ConnectionPool, table: str, rates: List[Rate]):
self.table = table.lower()
self.pool = pool
assert rates
Expand All @@ -65,21 +65,15 @@ def __init__(self, pool: AbstractConnectionPool, table: str, rates: List[Rate]):
self._create_table()

@contextmanager
def _get_conn(self, autocommit=False):
with self.pool._getconn() as conn:
with conn.cursor() as cur:
yield cur

if autocommit:
conn.commit()

self.pool._putconn(conn)
def _get_conn(self):
with self.pool.connection() as conn:
yield conn

def _create_table(self):
with self._get_conn(autocommit=True) as cur:
cur.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
with self._get_conn() as conn:
conn.execute(Queries.CREATE_BUCKET_TABLE.format(table=self._full_tbl))
index_name = f'timestampIndex_{self.table}'
cur.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))
conn.execute(Queries.CREATE_INDEX_ON_TIMESTAMP.format(table=self._full_tbl, index=index_name))

def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
"""Put an item (typically the current time) in the bucket
Expand All @@ -88,12 +82,12 @@ def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:
if item.weight == 0:
return True

with self._get_conn(autocommit=True) as cur:
with self._get_conn() as conn:
for rate in self.rates:
bound = f"SELECT TO_TIMESTAMP({item.timestamp / 1000}) - INTERVAL '{rate.interval} milliseconds'"
query = f'SELECT COUNT(*) FROM {self._full_tbl} WHERE item_timestamp >= ({bound})'
cur.execute(query)
count = int(cur.fetchone()[0])
conn = conn.execute(query)
count = int(conn.fetchone()[0])

if rate.limit - count < item.weight:
self.failing_rate = rate
Expand All @@ -103,7 +97,7 @@ def put(self, item: RateItem) -> Union[bool, Awaitable[bool]]:

query = Queries.PUT.format(table=self._full_tbl)
arguments = [(item.name, item.weight, item.timestamp / 1000)] * item.weight
cur.executemany(query, tuple(arguments))
conn.executemany(query, tuple(arguments))

return True

Expand All @@ -120,12 +114,12 @@ def leak(

count = 0

with self._get_conn(autocommit=True) as cur:
cur.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000))
result = cur.fetchone()
with self._get_conn() as conn:
conn = conn.execute(Queries.LEAK_COUNT.format(table=self._full_tbl, timestamp=lower_bound / 1000))
result = conn.fetchone()

if result:
cur.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000))
conn.execute(Queries.LEAK.format(table=self._full_tbl, timestamp=lower_bound / 1000))
count = int(result[0])

return count
Expand All @@ -134,18 +128,18 @@ def flush(self) -> Union[None, Awaitable[None]]:
"""Flush the whole bucket
- Must remove `failing-rate` after flushing
"""
with self._get_conn(autocommit=True) as cur:
cur.execute(Queries.FLUSH.format(table=self._full_tbl))
with self._get_conn() as conn:
conn.execute(Queries.FLUSH.format(table=self._full_tbl))
self.failing_rate = None

return None

def count(self) -> Union[int, Awaitable[int]]:
"""Count number of items in the bucket"""
count = 0
with self._get_conn() as cur:
cur.execute(Queries.COUNT.format(table=self._full_tbl))
result = cur.fetchone()
with self._get_conn() as conn:
conn = conn.execute(Queries.COUNT.format(table=self._full_tbl))
result = conn.fetchone()
assert result
count = int(result[0])

Expand All @@ -158,9 +152,9 @@ def peek(self, index: int) -> Union[Optional[RateItem], Awaitable[Optional[RateI
"""
item = None

with self._get_conn() as cur:
cur.execute(Queries.PEEK.format(table=self._full_tbl, offset=index))
result = cur.fetchone()
with self._get_conn() as conn:
conn = conn.execute(Queries.PEEK.format(table=self._full_tbl, offset=index))
result = conn.fetchone()
if result:
name, weight, timestamp = result[0], int(result[1]), int(result[2])
item = RateItem(name=name, weight=weight, timestamp=timestamp)
Expand Down
8 changes: 3 additions & 5 deletions pyrate_limiter/clocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .utils import dedicated_sqlite_clock_connection

if TYPE_CHECKING:
from psycopg2.pool import AbstractConnectionPool
from psycopg_pool import ConnectionPool


class MonotonicClock(AbstractClock):
Expand Down Expand Up @@ -57,13 +57,13 @@ def now(self) -> int:
class PostgresClock(AbstractClock):
"""Get timestamp using Postgres as remote clock backend"""

def __init__(self, pool: 'AbstractConnectionPool'):
def __init__(self, pool: 'ConnectionPool'):
self.pool = pool

def now(self) -> int:
value = 0

with self.pool._getconn() as conn:
with self.pool.connection() as conn:
with conn.cursor() as cur:
cur.execute("SELECT EXTRACT(epoch FROM current_timestamp) * 1000")
result = cur.fetchone()
Expand All @@ -73,6 +73,4 @@ def now(self) -> int:

value = int(result[0])

self.pool._putconn(conn)

return value
Loading
Loading