(Redis Cluster) - Fixes for using redis cluster + pipeline (BerriAI#8442)

* update RedisCluster creation

* update RedisClusterCache

* add redis ClusterCache

* update async_set_cache_pipeline

* cleanup redis cluster usage

* fix redis pipeline

* test_init_async_client_returns_same_instance

* fix redis cluster

* update mypy_path

* fix init_redis_cluster

* remove stub

* test redis commit

* ClusterPipeline

* fix import

* RedisCluster import

* fix redis cluster

* Potential fix for code scanning alert no. 2129: Clear-text logging of sensitive information

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>

* fix naming of redis cluster integration

* test_redis_caching_ttl_pipeline

* fix async_set_cache_pipeline

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2 people authored and abhijitherekar committed Feb 20, 2025
1 parent de354ab commit 8ac6b97
Showing 7 changed files with 112 additions and 27 deletions.
6 changes: 4 additions & 2 deletions litellm/_redis.py
@@ -183,7 +183,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
     )
 
     verbose_logger.debug(
-        "init_redis_cluster: startup nodes: ", redis_kwargs["startup_nodes"]
+        "init_redis_cluster: startup nodes are being initialized."
     )
     from redis.cluster import ClusterNode
@@ -266,7 +266,9 @@ def get_redis_client(**env_overrides):
     return redis.Redis(**redis_kwargs)
 
 
-def get_redis_async_client(**env_overrides) -> async_redis.Redis:
+def get_redis_async_client(
+    **env_overrides,
+) -> async_redis.Redis:
     redis_kwargs = _get_redis_client_logic(**env_overrides)
     if "url" in redis_kwargs and redis_kwargs["url"] is not None:
         args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
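
For context on the cluster client this module builds, here is a minimal, hypothetical sketch of creating an async RedisCluster from startup nodes with redis-py's public API; the node addresses are placeholders, not values from this patch:

```python
# Minimal sketch, assuming redis-py >= 5 and its asyncio cluster API.
# The node addresses below are placeholders.
import asyncio

from redis.asyncio.cluster import ClusterNode, RedisCluster


async def main():
    startup_nodes = [
        ClusterNode("127.0.0.1", 7000),
        ClusterNode("127.0.0.1", 7001),
    ]
    client = RedisCluster(startup_nodes=startup_nodes, decode_responses=True)
    await client.set("greeting", "hello")  # routed to the owning hash slot
    print(await client.get("greeting"))
    await client.aclose()  # use close() on redis-py < 5


asyncio.run(main())
```
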
1 change: 1 addition & 0 deletions litellm/caching/__init__.py
@@ -4,5 +4,6 @@
 from .in_memory_cache import InMemoryCache
 from .qdrant_semantic_cache import QdrantSemanticCache
 from .redis_cache import RedisCache
+from .redis_cluster_cache import RedisClusterCache
 from .redis_semantic_cache import RedisSemanticCache
 from .s3_cache import S3Cache
26 changes: 18 additions & 8 deletions litellm/caching/caching.py
@@ -41,6 +41,7 @@
 from .in_memory_cache import InMemoryCache
 from .qdrant_semantic_cache import QdrantSemanticCache
 from .redis_cache import RedisCache
+from .redis_cluster_cache import RedisClusterCache
 from .redis_semantic_cache import RedisSemanticCache
 from .s3_cache import S3Cache
 
@@ -158,14 +159,23 @@ def __init__(
             None. Cache is set as a litellm param
         """
         if type == LiteLLMCacheType.REDIS:
-            self.cache: BaseCache = RedisCache(
-                host=host,
-                port=port,
-                password=password,
-                redis_flush_size=redis_flush_size,
-                startup_nodes=redis_startup_nodes,
-                **kwargs,
-            )
+            if redis_startup_nodes:
+                self.cache: BaseCache = RedisClusterCache(
+                    host=host,
+                    port=port,
+                    password=password,
+                    redis_flush_size=redis_flush_size,
+                    startup_nodes=redis_startup_nodes,
+                    **kwargs,
+                )
+            else:
+                self.cache = RedisCache(
+                    host=host,
+                    port=port,
+                    password=password,
+                    redis_flush_size=redis_flush_size,
+                    **kwargs,
+                )
         elif type == LiteLLMCacheType.REDIS_SEMANTIC:
             self.cache = RedisSemanticCache(
                 host=host,
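
A hypothetical caller-side view of the new branch (the kwargs mirror the constructor shown above; node addresses and the password are placeholders):

```python
# Illustrative sketch: with redis_startup_nodes supplied, Cache now selects
# RedisClusterCache; without it, the single-node RedisCache path is used.
from litellm.caching.caching import Cache

cluster_cache = Cache(
    type="redis",
    redis_startup_nodes=[{"host": "127.0.0.1", "port": "7001"}],  # placeholders
)

single_node_cache = Cache(
    type="redis",
    host="127.0.0.1",
    port="6379",
    password="my-password",  # placeholder
)
```
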
49 changes: 35 additions & 14 deletions litellm/caching/redis_cache.py
@@ -14,7 +14,7 @@
 import json
 import time
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 
 import litellm
 from litellm._logging import print_verbose, verbose_logger
@@ -26,15 +26,20 @@
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
-    from redis.asyncio import Redis
+    from redis.asyncio import Redis, RedisCluster
     from redis.asyncio.client import Pipeline
+    from redis.asyncio.cluster import ClusterPipeline
 
     pipeline = Pipeline
+    cluster_pipeline = ClusterPipeline
     async_redis_client = Redis
+    async_redis_cluster_client = RedisCluster
     Span = _Span
 else:
     pipeline = Any
+    cluster_pipeline = Any
     async_redis_client = Any
+    async_redis_cluster_client = Any
     Span = Any
 
 
@@ -122,7 +127,9 @@ def __init__(
         else:
             super().__init__()  # defaults to 60s
 
-    def init_async_client(self):
+    def init_async_client(
+        self,
+    ) -> Union[async_redis_client, async_redis_cluster_client]:
         from .._redis import get_redis_async_client
 
         return get_redis_async_client(
@@ -345,8 +352,14 @@ async def async_set_cache(self, key, value, **kwargs):
         )
 
     async def _pipeline_helper(
-        self, pipe: pipeline, cache_list: List[Tuple[Any, Any]], ttl: Optional[float]
+        self,
+        pipe: Union[pipeline, cluster_pipeline],
+        cache_list: List[Tuple[Any, Any]],
+        ttl: Optional[float],
     ) -> List:
+        """
+        Helper function for executing a pipeline of set operations on Redis
+        """
         ttl = self.get_ttl(ttl=ttl)
         # Iterate through each key-value pair in the cache_list and set them in the pipeline.
         for cache_key, cache_value in cache_list:
@@ -359,7 +372,11 @@ async def _pipeline_helper(
             _td: Optional[timedelta] = None
             if ttl is not None:
                 _td = timedelta(seconds=ttl)
-            pipe.set(cache_key, json_cache_value, ex=_td)
+            pipe.set(  # type: ignore
+                name=cache_key,
+                value=json_cache_value,
+                ex=_td,
+            )
         # Execute the pipeline and return the results.
         results = await pipe.execute()
         return results
@@ -373,9 +390,8 @@ async def async_set_cache_pipeline(
         # don't waste a network request if there's nothing to set
         if len(cache_list) == 0:
             return
-        from redis.asyncio import Redis
 
-        _redis_client: Redis = self.init_async_client()  # type: ignore
+        _redis_client = self.init_async_client()
         start_time = time.time()
 
         print_verbose(
@@ -384,7 +400,7 @@
         cache_value: Any = None
         try:
             async with _redis_client as redis_client:
-                async with redis_client.pipeline(transaction=True) as pipe:
+                async with redis_client.pipeline(transaction=False) as pipe:
                     results = await self._pipeline_helper(pipe, cache_list, ttl)
 
                 print_verbose(f"pipeline results: {results}")
@@ -730,7 +746,8 @@ async def async_batch_get_cache(
         """
         Use Redis for bulk read operations
         """
-        _redis_client = await self.init_async_client()
+        # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `mget`
+        _redis_client: Any = self.init_async_client()
         key_value_dict = {}
         start_time = time.time()
         try:
@@ -822,7 +839,8 @@ def sync_ping(self) -> bool:
             raise e
 
     async def ping(self) -> bool:
-        _redis_client = self.init_async_client()
+        # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ping`
+        _redis_client: Any = self.init_async_client()
         start_time = time.time()
         async with _redis_client as redis_client:
             print_verbose("Pinging Async Redis Cache")
@@ -858,7 +876,8 @@ async def ping(self) -> bool:
             raise e
 
     async def delete_cache_keys(self, keys):
-        _redis_client = self.init_async_client()
+        # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
+        _redis_client: Any = self.init_async_client()
         # keys is a list, unpack it so it gets passed as individual elements to delete
         async with _redis_client as redis_client:
             await redis_client.delete(*keys)
@@ -881,7 +900,8 @@ async def disconnect(self):
         await self.async_redis_conn_pool.disconnect(inuse_connections=True)
 
     async def async_delete_cache(self, key: str):
-        _redis_client = self.init_async_client()
+        # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `delete`
+        _redis_client: Any = self.init_async_client()
         # keys is str
         async with _redis_client as redis_client:
             await redis_client.delete(key)
@@ -936,7 +956,7 @@ async def async_increment_pipeline(
 
         try:
             async with _redis_client as redis_client:
-                async with redis_client.pipeline(transaction=True) as pipe:
+                async with redis_client.pipeline(transaction=False) as pipe:
                     results = await self._pipeline_increment_helper(
                         pipe, increment_list
                     )
@@ -991,7 +1011,8 @@ async def async_get_ttl(self, key: str) -> Optional[int]:
         Redis ref: https://redis.io/docs/latest/commands/ttl/
         """
         try:
-            _redis_client = await self.init_async_client()
+            # typed as Any, redis python lib has incomplete type stubs for RedisCluster and does not include `ttl`
+            _redis_client: Any = self.init_async_client()
             async with _redis_client as redis_client:
                 ttl = await redis_client.ttl(key)
                 if ttl <= -1:  # -1 means the key does not exist, -2 key does not exist
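
The two `transaction=True` to `transaction=False` flips above are the core of the pipeline fix: Redis Cluster cannot wrap a pipeline in MULTI/EXEC when keys hash to different slots, so the pipeline must run as a plain command batch. A minimal sketch against the standard redis-py asyncio API (keys and host are placeholders):

```python
# Minimal sketch, assuming redis-py >= 5. A non-transactional pipeline
# batches commands without MULTI/EXEC, which is what cluster mode requires.
import asyncio
from datetime import timedelta

from redis.asyncio import Redis


async def set_many(client: Redis, items: dict, ttl: timedelta) -> list:
    async with client.pipeline(transaction=False) as pipe:
        for key, value in items.items():
            pipe.set(name=key, value=value, ex=ttl)  # buffered, not sent yet
        return await pipe.execute()  # one round trip for the whole batch


async def main():
    client = Redis(host="127.0.0.1", port=6379)  # placeholder address
    print(await set_many(client, {"k1": "v1", "k2": "v2"}, timedelta(seconds=60)))
    await client.aclose()  # use close() on redis-py < 5


asyncio.run(main())
```
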
44 changes: 44 additions & 0 deletions litellm/caching/redis_cluster_cache.py
@@ -0,0 +1,44 @@
"""
Redis Cluster Cache implementation
Key differences:
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
"""

from typing import TYPE_CHECKING, Any, Optional

from litellm.caching.redis_cache import RedisCache

if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from redis.asyncio import Redis, RedisCluster
from redis.asyncio.client import Pipeline

pipeline = Pipeline
async_redis_client = Redis
Span = _Span
else:
pipeline = Any
async_redis_client = Any
Span = Any


class RedisClusterCache(RedisCache):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.redis_cluster_client: Optional[RedisCluster] = None

def init_async_client(self):
from redis.asyncio import RedisCluster

from .._redis import get_redis_async_client

if self.redis_cluster_client:
return self.redis_cluster_client

_redis_client = get_redis_async_client(
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
)
if isinstance(_redis_client, RedisCluster):
self.redis_cluster_client = _redis_client
return _redis_client
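
The memoization above is what the `test_init_async_client_returns_same_instance` commit message refers to. A hypothetical usage sketch (placeholder node address, and it assumes a reachable cluster):

```python
# Illustrative sketch: the cluster client is created once and re-used,
# avoiding the ~3s penalty the docstring attributes to re-creation.
from litellm.caching.redis_cluster_cache import RedisClusterCache

cache = RedisClusterCache(
    startup_nodes=[{"host": "127.0.0.1", "port": "7001"}],  # placeholder node
)

client_a = cache.init_async_client()
client_b = cache.init_async_client()
assert client_a is client_b  # same memoized RedisCluster instance
```
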
1 change: 1 addition & 0 deletions mypy.ini
@@ -1,6 +1,7 @@
 [mypy]
 warn_return_any = False
 ignore_missing_imports = True
+mypy_path = litellm/stubs
 
 [mypy-google.*]
 ignore_missing_imports = True
12 changes: 9 additions & 3 deletions tests/local_testing/test_caching.py
@@ -21,7 +21,8 @@
 import litellm
 from litellm import aembedding, completion, embedding
 from litellm.caching.caching import Cache
-
+from redis.asyncio import RedisCluster
+from litellm.caching.redis_cluster_cache import RedisClusterCache
 from unittest.mock import AsyncMock, patch, MagicMock, call
 import datetime
 from datetime import timedelta
@@ -2328,8 +2329,12 @@ async def test_redis_caching_ttl_pipeline():
     # Verify that the set method was called on the mock Redis instance
     mock_set.assert_has_calls(
         [
-            call.set("test_key1", '"test_value1"', ex=expected_timedelta),
-            call.set("test_key2", '"test_value2"', ex=expected_timedelta),
+            call.set(
+                name="test_key1", value='"test_value1"', ex=expected_timedelta
+            ),
+            call.set(
+                name="test_key2", value='"test_value2"', ex=expected_timedelta
+            ),
         ]
     )
 

Expand Down Expand Up @@ -2388,6 +2393,7 @@ async def test_redis_increment_pipeline():
from litellm.caching.redis_cache import RedisCache

litellm.set_verbose = True
litellm._turn_on_debug()
redis_cache = RedisCache(
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
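
The reworked assertions in `test_redis_caching_ttl_pipeline` follow from the keyword-argument `pipe.set(...)` call introduced earlier: `assert_has_calls` distinguishes positional from keyword arguments. A hypothetical standalone illustration, not taken from the patch:

```python
# Hypothetical illustration: mock call matching is exact about
# positional vs. keyword arguments, which is why the test changed.
from unittest.mock import MagicMock, call

pipe = MagicMock()
pipe.set(name="test_key1", value='"test_value1"', ex=60)

# Matches: the expectation uses the same keyword arguments.
pipe.assert_has_calls([call.set(name="test_key1", value='"test_value1"', ex=60)])

try:
    # Fails: the old positional-style expectation no longer matches.
    pipe.assert_has_calls([call.set("test_key1", '"test_value1"', ex=60)])
except AssertionError:
    print("positional expectation does not match keyword call")
```
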
