Commit
Inference Cache (#802)
wanliAlex authored Apr 23, 2024
1 parent 97370d4 commit 460ac27
Showing 27 changed files with 1,105 additions and 193 deletions.
2 changes: 2 additions & 0 deletions requirements.dev.txt
@@ -35,6 +35,8 @@ opencv-python-headless==4.6.0.66
psutil==5.9.4
multilingual-clip==1.0.10
redis==4.4.2
readerwriterlock==1.0.9
cachetools==5.3.1

# pin specific packages (last working 0.0.19 image)
# to fix ARM64 build scikit-learn error
1 change: 1 addition & 0 deletions requirements.txt
@@ -13,4 +13,5 @@ pydantic==1.10.11
httpx==0.25.0
semver==3.0.2
memory-profiler==0.61.0
cachetools==5.3.1
pynvml==11.5.0 # For cuda utilization
2 changes: 2 additions & 0 deletions setup.py
@@ -33,6 +33,8 @@
"opencv-python-headless",
"psutil",
"multilingual_clip",
"readerwriterlock==1.0.9",
"cachetools==5.3.1"
"pynvml==11.5.0"
],
name="marqo-engine",
Empty file.
98 changes: 98 additions & 0 deletions src/marqo/inference/inference_cache/abstract_cache.py
@@ -0,0 +1,98 @@
from abc import ABC, abstractmethod
from typing import Any, Hashable


class MarqoAbstractCache(ABC):
    """Abstract base class for Marqo cache implementations. Implementations MUST be thread-safe.
    Keys must be Hashable; values may be of any type.
    When the cache is full, self.__setitem__() calls self.popitem() repeatedly
    until there is enough room for the new item.
    """

    @abstractmethod
    def get(self, key: Hashable, default=None) -> Any:
        """Return the value for key if key is in the cache, else default.
        Args:
            key: The key to look up.
            default: The value to return if the key is not in the cache.
        Returns:
            The cached value for key, or default if the key is not present.
        """
        pass

    @abstractmethod
    def set(self, key: Hashable, value: Any) -> None:
        """Set the value for key in the cache.
        Args:
            key: The key to store the value under.
            value: The value to cache.
        """
        pass

    @abstractmethod
    def popitem(self) -> None:
        """Remove an item from the cache according to the defined eviction policy. The item is not returned.
        Raises:
            Exception: If an item cannot be removed (e.g., the cache is already empty).
        """
        pass

    @abstractmethod
    def __contains__(self, key: Hashable) -> bool:
        """Return True if the key is in the cache, else False.
        Args:
            key: The key to look up.
        Returns:
            True if the key is in the cache, otherwise False.
        """
        pass

    @abstractmethod
    def __setitem__(self, key: Hashable, value: Any) -> None:
        """Set the value for key in the cache if the cache is not full, else popitem() until there is enough room.
        Args:
            key: The key to store the value under.
            value: The value to cache.
        """
        pass

    @abstractmethod
    def __getitem__(self, key: Hashable) -> Any:
        """Return the value for key if key is in the cache, else raise KeyError.
        Args:
            key: The key to look up.
        Raises:
            KeyError: If the key is not in the cache.
        Returns:
            The cached value for key.
        """
        pass

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of items in the cache."""
        pass

    @property
    @abstractmethod
    def maxsize(self) -> int:
        """Return the maximum size of the cache."""
        pass

    @property
    @abstractmethod
    def currsize(self) -> int:
        """Return the current size of the cache."""
        pass
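Editor's note: a minimal sketch (not part of this commit) of the eviction contract above, using a FIFO policy. It omits the locking that real implementations MUST provide, and every name in it is illustrative only.

from collections import OrderedDict
from typing import Any, Hashable


class SketchFIFOCache(MarqoAbstractCache):
    """Illustrative FIFO cache. NOT thread-safe; for demonstration only."""

    def __init__(self, maxsize: int):
        self._data: "OrderedDict[Hashable, Any]" = OrderedDict()
        self._maxsize = maxsize

    def get(self, key: Hashable, default=None) -> Any:
        return self._data.get(key, default)

    def set(self, key: Hashable, value: Any) -> None:
        self[key] = value

    def popitem(self) -> None:
        self._data.popitem(last=False)  # evict the oldest entry; nothing is returned

    def __contains__(self, key: Hashable) -> bool:
        return key in self._data

    def __setitem__(self, key: Hashable, value: Any) -> None:
        # The contract from the class docstring: pop items until there is room.
        while len(self._data) >= self._maxsize and key not in self._data:
            self.popitem()
        self._data[key] = value

    def __getitem__(self, key: Hashable) -> Any:
        return self._data[key]

    def __len__(self) -> int:
        return len(self._data)

    @property
    def maxsize(self) -> int:
        return self._maxsize

    @property
    def currsize(self) -> int:
        return len(self._data)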
6 changes: 6 additions & 0 deletions src/marqo/inference/inference_cache/enums.py
@@ -0,0 +1,6 @@
from enum import Enum


class MarqoCacheType(str, Enum):
    LRU = "LRU"
    LFU = "LFU"
98 changes: 98 additions & 0 deletions src/marqo/inference/inference_cache/marqo_inference_cache.py
@@ -0,0 +1,98 @@
from typing import List, Optional, Union, Tuple

from marqo.api.exceptions import EnvVarError
from marqo.inference.inference_cache.abstract_cache import MarqoAbstractCache
from marqo.inference.inference_cache.enums import MarqoCacheType
from marqo.inference.inference_cache.marqo_lfu_cache import MarqoLFUCache
from marqo.inference.inference_cache.marqo_lru_cache import MarqoLRUCache


class MarqoInferenceCache:
    """MarqoInferenceCache is a thread-safe cache implementation for storing embeddings.
    The key is a string combining a model_cache_key and the content the embeddings were generated from.
    The value is a list of floats representing the embeddings.
    """

    _CACHE_TYPES_MAPPING = {
        MarqoCacheType.LRU: MarqoLRUCache,
        MarqoCacheType.LFU: MarqoLFUCache,
    }

    def __init__(self, cache_size: int = 0, cache_type: Union[None, str, MarqoCacheType] = MarqoCacheType.LRU):
        self._cache = self._build_cache(cache_size, cache_type)

    def _build_cache(self, cache_size: int, cache_type: MarqoCacheType) -> Optional[MarqoAbstractCache]:
        """Return a cache instance based on the cache type and size.
        Args:
            cache_size: The maximum size of the cache.
            cache_type: The type of the cache.
        Returns:
            A cache instance based on the cache type and size, or None if cache_size is 0.
        Raises:
            EnvVarError: If the cache size or type is invalid.
        """
        if not isinstance(cache_size, int) or cache_size < 0:
            raise EnvVarError(f"Invalid cache size: {cache_size}. "
                              f"Must be a non-negative integer. "
                              f"Please set the 'MARQO_INFERENCE_CACHE_SIZE' "
                              f"environment variable to a non-negative integer.")
        elif cache_size == 0:
            return None
        elif cache_size > 0:
            if cache_type not in self._CACHE_TYPES_MAPPING:
                raise EnvVarError(f"Invalid cache type: {cache_type}. "
                                  f"Must be one of {list(self._CACHE_TYPES_MAPPING.keys())}. "
                                  f"Please set the 'MARQO_INFERENCE_CACHE_TYPE' "
                                  f"environment variable to one of the valid cache types.")
            return self._CACHE_TYPES_MAPPING[cache_type](maxsize=cache_size)
        else:
            raise ValueError(f"Invalid cache size: {cache_size}.")

    def get(self, model_cache_key: str, content: str, default=None) -> Optional[List[float]]:
        key = self._generate_key(model_cache_key, content)
        return self._cache.get(key, default)

    def set(self, model_cache_key: str, content: str, value: List[float]) -> None:
        self.__setitem__(model_cache_key, content, value)

    def __getitem__(self, model_cache_key: str, content: str) -> List[float]:
        key = self._generate_key(model_cache_key, content)
        return self._cache[key]

    def __setitem__(self, model_cache_key: str, content: str, value: List[float]) -> None:
        key = self._generate_key(model_cache_key, content)
        self._cache[key] = value

    def __contains__(self, item: Tuple) -> bool:
        if len(item) != 2:
            raise ValueError("MarqoInferenceCache received an unsupported input for 'in' operation. "
                             "Expected input is a tuple of 'model_cache_key' and 'content'. "
                             "E.g., ('my-model-cache-key', 'content').")
        model_cache_key, content = item
        key = self._generate_key(model_cache_key, content)
        return key in self._cache

    def _generate_key(self, model_cache_key: str, content: str) -> str:
        if not isinstance(model_cache_key, str):
            raise TypeError(f"model_cache_key must be a string, not {type(model_cache_key)}")
        if not isinstance(content, str):
            raise TypeError(f"content must be a string, not {type(content)}")
        return f"{model_cache_key}||{content}"

    def is_enabled(self) -> bool:
        """Return True if the cache is enabled, else False."""
        return self._cache is not None

    @property
    def maxsize(self) -> int:
        """Return the maximum size of the cache."""
        return self._cache.maxsize

    @property
    def currsize(self) -> int:
        """Return the current size of the cache."""
        return self._cache.currsize
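Editor's note: a hedged usage sketch of MarqoInferenceCache (not part of this commit); the key, content, and embedding values are placeholders:

cache = MarqoInferenceCache(cache_size=1024, cache_type=MarqoCacheType.LRU)

if cache.is_enabled():
    embedding = cache.get("my-model-cache-key", "hello world")
    if embedding is None:
        embedding = [0.1, 0.2, 0.3]  # stand-in for a real model inference call
        cache.set("my-model-cache-key", "hello world", embedding)
    # Membership checks take a (model_cache_key, content) tuple.
    assert ("my-model-cache-key", "hello world") in cache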
54 changes: 54 additions & 0 deletions src/marqo/inference/inference_cache/marqo_lfu_cache.py
@@ -0,0 +1,54 @@
from typing import Hashable, Any

from cachetools import LFUCache
from readerwriterlock import rwlock

from marqo.inference.inference_cache.abstract_cache import MarqoAbstractCache


class MarqoLFUCache(MarqoAbstractCache):
    """A thread-safe Least Frequently Used (LFU) cache implementation with a read-write lock.
    This class is currently implemented using cachetools.LFUCache, but it can be replaced with any other LFU cache.
    """

    def __init__(self, maxsize: int):
        self._cache = LFUCache(maxsize=maxsize)
        self.lock = rwlock.RWLockFair()

    def get(self, key: Hashable, default=None) -> Any:
        with self.lock.gen_rlock():
            return self._cache.get(key, default)

    def set(self, key: Hashable, value: Any):
        """Locking is delegated to __setitem__ to avoid acquiring the write lock twice."""
        self.__setitem__(key, value)

    def __contains__(self, key: Hashable) -> bool:
        with self.lock.gen_rlock():
            return key in self._cache

    def __setitem__(self, key: Hashable, value: Any):
        with self.lock.gen_wlock():
            self._cache[key] = value

    def __getitem__(self, key: Hashable) -> Any:
        with self.lock.gen_rlock():
            return self._cache[key]

    def __len__(self) -> int:
        return len(self._cache)

    def popitem(self) -> None:
        with self.lock.gen_wlock():
            self._cache.popitem()

    @property
    def maxsize(self) -> int:
        """Return the maximum size of the cache."""
        return int(self._cache.maxsize)

    @property
    def currsize(self) -> int:
        """Return the current size of the cache."""
        return int(self._cache.currsize)
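Editor's note: a hedged sketch (not part of this commit) exercising MarqoLFUCache from several threads. RWLockFair lets readers proceed concurrently while each writer gets exclusive access:

import threading

cache = MarqoLFUCache(maxsize=128)

def worker(i: int) -> None:
    cache.set(f"key-{i % 16}", [float(i)])  # write path acquires the write lock
    _ = cache.get(f"key-{i % 16}")          # read path acquires a shared read lock

threads = [threading.Thread(target=worker, args=(i,)) for i in range(32)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert cache.currsize <= cache.maxsize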
54 changes: 54 additions & 0 deletions src/marqo/inference/inference_cache/marqo_lru_cache.py
@@ -0,0 +1,54 @@
from typing import Hashable, Any

from cachetools import LRUCache
from readerwriterlock import rwlock

from marqo.inference.inference_cache.abstract_cache import MarqoAbstractCache


class MarqoLRUCache(MarqoAbstractCache):
    """A thread-safe Least Recently Used (LRU) cache implementation with a read-write lock.
    This class is currently implemented using cachetools.LRUCache, but it can be replaced with any other LRU cache.
    """

    def __init__(self, maxsize: int):
        self._cache = LRUCache(maxsize=maxsize)
        self.lock = rwlock.RWLockFair()

    def get(self, key: Hashable, default=None) -> Any:
        with self.lock.gen_rlock():
            return self._cache.get(key, default)

    def set(self, key: Hashable, value: Any):
        """Locking is delegated to __setitem__ to avoid acquiring the write lock twice."""
        self.__setitem__(key, value)

    def __contains__(self, key: Hashable) -> bool:
        with self.lock.gen_rlock():
            return key in self._cache

    def __setitem__(self, key: Hashable, value: Any):
        with self.lock.gen_wlock():
            self._cache[key] = value

    def __getitem__(self, key: Hashable) -> Any:
        with self.lock.gen_rlock():
            return self._cache[key]

    def __len__(self) -> int:
        return len(self._cache)

    def popitem(self) -> None:
        with self.lock.gen_wlock():
            self._cache.popitem()

    @property
    def maxsize(self) -> int:
        """Return the maximum size of the cache."""
        return int(self._cache.maxsize)

    @property
    def currsize(self) -> int:
        """Return the current size of the cache."""
        return int(self._cache.currsize)
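Editor's note: a small sketch (not part of this commit) contrasting the two eviction policies. cachetools counts one unit per item by default, so maxsize=2 holds two entries; LRU evicts the least recently accessed key, while LFU evicts the least frequently accessed one:

lru = MarqoLRUCache(maxsize=2)
lru.set("a", [1.0])
lru.set("b", [2.0])
_ = lru.get("a")     # "a" becomes the most recently used entry
lru.set("c", [3.0])  # evicts "b", the least recently used entry
assert "a" in lru and "b" not in lru and "c" in lru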
9 changes: 4 additions & 5 deletions src/marqo/s2_inference/processing/image.py
@@ -5,8 +5,7 @@
 import numpy as np
 import torch
 import torchvision
-from marqo.s2_inference.s2_inference import available_models,_create_model_cache_key
-from marqo.s2_inference.s2_inference import get_logger
+from marqo.s2_inference.s2_inference import get_available_models, _create_model_cache_key, get_logger
 from marqo.s2_inference.types import Dict, List, Union, ImageType, Tuple, ndarray, Literal
 from marqo.s2_inference.clip_utils import format_and_load_CLIP_image
 from marqo.s2_inference.errors import ChunkerError
@@ -207,7 +206,7 @@ def _load_and_cache_model(self):
         model_type = (self.model_name, self.device)
         model_cache_key = _create_model_cache_key(self.model_name, self.device)

-        if model_cache_key not in available_models:
+        if model_cache_key not in get_available_models():
             logger.info(f"loading model {model_type}")
             if model_type[0] in self.allowed_model_types:
                 func = self.model_load_function
@@ -216,11 +215,11 @@ def _load_and_cache_model(self):

             self.model, self.preprocess = func(self.model_name, self.device)

-            available_models[model_cache_key] = {AvailableModelsKey.model: (self.model, self.preprocess),
+            get_available_models()[model_cache_key] = {AvailableModelsKey.model: (self.model, self.preprocess),
                                                        AvailableModelsKey.most_recently_used_time: datetime.datetime.now()}

         else:
-            self.model, self.preprocess = available_models[model_cache_key][AvailableModelsKey.model]
+            self.model, self.preprocess = get_available_models()[model_cache_key][AvailableModelsKey.model]

     def _load_image(self, image):
         self.image, self.image_pt, self.original_size = load_rcnn_image(image, size=self.size)