-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
27 changed files
with
1,105 additions
and
193 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Any, Hashable | ||
|
||
|
||
class MarqoAbstractCache(ABC): | ||
"""Abstract class for Marqo cache implementations, MUST be thread-safe. | ||
The acceptable key must be Hashable. | ||
The acceptable value is Any. | ||
When a cache is full, self.__setitem__() calls self.popitem() repeatedly | ||
until there is enough room for the item to be added. | ||
The cache MUST be a thread-safe implementation. | ||
""" | ||
|
||
@abstractmethod | ||
def get(self, key: Hashable, default=None) -> Any: | ||
"""Return the value for key if key is in the cache, else default. | ||
Args: | ||
key: __description__ | ||
default: __description__ | ||
Returns: | ||
__description__ | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def set(self, key: Hashable, value: Any) -> None: | ||
"""Set the value for key in the cache. | ||
Args: | ||
key: __description__ | ||
value: __description__ | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def popitem(self) -> None: | ||
"""Remove an item from the cache according to the defined eviction policy. The item is not returned. | ||
Raises: | ||
Exception: If there is an issue in removing an item (e.g., cache is already empty). | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def __contains__(self, key: Hashable) -> bool: | ||
"""Return True if the key is in the cache, else False. | ||
Args: | ||
key: __description__ | ||
Returns: | ||
If the key is in the cache, return True. Otherwise, return False. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def __setitem__(self, key: Hashable, value: Any) -> None: | ||
"""Set the value for key in the cache if the cache is not full, else popitem() until there is enough room. | ||
Args: | ||
key: __description__ | ||
value: __description__ | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def __getitem__(self, key: Hashable) -> Any: | ||
"""Return the value for key if key is in the cache, else raise KeyError. | ||
Args: | ||
key: __description__ | ||
Raises: | ||
KeyError: If the key is not in the cache. | ||
Returns: | ||
__description__ | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def __len__(self) -> int: | ||
"""Return the number of items in the cache.""" | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def maxsize(self) -> int: | ||
"""Return the maximum size of the cache.""" | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def currsize(self) -> int: | ||
"""Return the current size of the cache.""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from enum import Enum | ||
|
||
|
||
class MarqoCacheType(str, Enum): | ||
LRU = "LRU" | ||
LFU = "LFU" |
98 changes: 98 additions & 0 deletions
98
src/marqo/inference/inference_cache/marqo_inference_cache.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from typing import List, Optional, Union, Tuple | ||
|
||
from marqo.api.exceptions import EnvVarError | ||
from marqo.inference.inference_cache.abstract_cache import MarqoAbstractCache | ||
from marqo.inference.inference_cache.enums import MarqoCacheType | ||
from marqo.inference.inference_cache.marqo_lfu_cache import MarqoLFUCache | ||
from marqo.inference.inference_cache.marqo_lru_cache import MarqoLRUCache | ||
|
||
|
||
class MarqoInferenceCache: | ||
"""MarqoInferenceCache is a thread-safe cache implementation for storing embeddings. | ||
The key is a string consisting of model_cache_key and content to identify the cache. | ||
The value is a list of floats representing the embeddings. | ||
""" | ||
|
||
_CACHE_TYPES_MAPPING = { | ||
MarqoCacheType.LRU: MarqoLRUCache, | ||
MarqoCacheType.LFU: MarqoLFUCache, | ||
} | ||
|
||
def __init__(self, cache_size: int = 0, cache_type: Union[None, str, MarqoCacheType] = MarqoCacheType.LRU): | ||
self._cache = self._build_cache(cache_size, cache_type) | ||
|
||
def _build_cache(self, cache_size: int, cache_type: MarqoCacheType) -> Optional[MarqoAbstractCache]: | ||
"""Return a cache instance based on the cache type and size. | ||
Args: | ||
cache_size: The maximum size of the cache. | ||
cache_type: The type of the cache. | ||
Returns: | ||
A cache instance based on the cache type and size. None if the cache_size is 0. | ||
Raises: | ||
EnvVarError: If the cache size or type is invalid. | ||
""" | ||
if not isinstance(cache_size, int) or cache_size < 0: | ||
raise EnvVarError(f"Invalid cache size: {cache_size}. " | ||
f"Must be a non-negative integer. " | ||
f"Please set the 'MARQO_INFERENCE_CACHE_SIZE' " | ||
f"environment variable to a non-negative integer.") | ||
elif cache_size == 0: | ||
return None | ||
elif cache_size > 0: | ||
if cache_type not in self._CACHE_TYPES_MAPPING: | ||
raise EnvVarError(f"Invalid cache type: {cache_type}. " | ||
f"Must be one of {self._CACHE_TYPES_MAPPING.keys()}." | ||
f"Please set the 'MARQO_INFERENCE_CACHE_TYPE' " | ||
f"environment variable to one of the valid cache types.") | ||
return self._CACHE_TYPES_MAPPING[cache_type](maxsize=cache_size) | ||
else: | ||
ValueError(f"Invalid cache size: {cache_size}.") | ||
|
||
def get(self, model_cache_key: str, content: str, default=None) -> Optional[List[float]]: | ||
key = self._generate_key(model_cache_key, content) | ||
return self._cache.get(key, default) | ||
|
||
def set(self, model_cache_key: str, content: str, value: List[float]) -> None: | ||
self.__setitem__(model_cache_key, content, value) | ||
|
||
def __getitem__(self, model_cache_key: str, content: str, key: str) -> List[float]: | ||
key = self._generate_key(model_cache_key, content) | ||
return self._cache[key] | ||
|
||
def __setitem__(self, model_cache_key: str, content: str, value: List[float]) -> None: | ||
key = self._generate_key(model_cache_key, content) | ||
self._cache[key] = value | ||
|
||
def __contains__(self, item: Tuple) -> bool: | ||
if len(item) != 2: | ||
raise ValueError("MarqoInferenceCache received an unsupported input for 'in' operation. " | ||
"Expected input is a tuple with 'model-cache-key' and 'content'. " | ||
"E.g., ('my-model-cache-key', 'content'). ") | ||
model_cache_key, content = item | ||
key = self._generate_key(model_cache_key, content) | ||
return key in self._cache | ||
|
||
def _generate_key(self, model_cache_key: str, content: str) -> str: | ||
if not isinstance(model_cache_key, str): | ||
raise TypeError(f"model_cache_key must be a string, not {type(model_cache_key)}") | ||
if not isinstance(content, str): | ||
raise TypeError(f"content must be a string, not {type(content)}") | ||
return f"{model_cache_key}||{content}" | ||
|
||
def is_enabled(self) -> bool: | ||
"""Return True if the cache is enabled, else False.""" | ||
return self._cache is not None | ||
|
||
@property | ||
def maxsize(self) -> int: | ||
"""Return the maximum size of the cache.""" | ||
return self._cache.maxsize | ||
|
||
@property | ||
def currsize(self) -> int: | ||
"""Return the current size of the cache.""" | ||
return self._cache.currsize |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import Hashable, Any | ||
|
||
from cachetools import LFUCache | ||
from readerwriterlock import rwlock | ||
|
||
from marqo.inference.inference_cache.abstract_cache import MarqoAbstractCache | ||
|
||
|
||
class MarqoLFUCache(MarqoAbstractCache): | ||
"""A thread-safe Least Frequently Used (LFU) cache implementation with a read-write lock. | ||
This class is currently implemented using cachetools.LFUCache, but it can be replaced with any other LFU cache. | ||
""" | ||
|
||
def __init__(self, maxsize: int): | ||
self._cache = LFUCache(maxsize=maxsize) | ||
self.lock = rwlock.RWLockFair() | ||
|
||
def get(self, key: Hashable, default=None) -> Any: | ||
with self.lock.gen_rlock(): | ||
return self._cache.get(key, default) | ||
|
||
def set(self, key: Hashable, value: Any): | ||
"""The lock is implemented in the __setitem__ method to avoid double locking when setting a value.""" | ||
self.__setitem__(key, value) | ||
|
||
def __contains__(self, key: Hashable) -> bool: | ||
with self.lock.gen_rlock(): | ||
return key in self._cache | ||
|
||
def __setitem__(self, key: Hashable, value: Any): | ||
with self.lock.gen_wlock(): | ||
self._cache[key] = value | ||
|
||
def __getitem__(self, key: Hashable) -> Any: | ||
with self.lock.gen_rlock(): | ||
return self._cache[key] | ||
|
||
def __len__(self) -> int: | ||
return len(self._cache) | ||
|
||
def popitem(self) -> None: | ||
with self.lock.gen_wlock(): | ||
self._cache.popitem() | ||
|
||
@property | ||
def maxsize(self) -> int: | ||
"""Return the maximum size of the cache.""" | ||
return int(self._cache.maxsize) | ||
|
||
@property | ||
def currsize(self) -> int: | ||
"""Return the current size of the cache.""" | ||
return int(self._cache.currsize) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from typing import Hashable, Any | ||
|
||
from cachetools import LRUCache | ||
from readerwriterlock import rwlock | ||
|
||
from marqo.inference.inference_cache.abstract_cache import MarqoAbstractCache | ||
|
||
|
||
class MarqoLRUCache(MarqoAbstractCache): | ||
"""A thread-safe Least Recently Used (LRU) cache implementation with a read-write lock. | ||
This class is currently implemented using cachetools.LRUCache, but it can be replaced with any other LRU cache. | ||
""" | ||
|
||
def __init__(self, maxsize: int): | ||
self._cache = LRUCache(maxsize=maxsize) | ||
self.lock = rwlock.RWLockFair() | ||
|
||
def get(self, key: Hashable, default=None) -> Any: | ||
with self.lock.gen_rlock(): | ||
return self._cache.get(key, default) | ||
|
||
def set(self, key: Hashable, value: Any): | ||
"""The lock is implemented in the __setitem__ method to avoid double locking when setting a value.""" | ||
self.__setitem__(key, value) | ||
|
||
def __contains__(self, key: Hashable) -> bool: | ||
with self.lock.gen_rlock(): | ||
return key in self._cache | ||
|
||
def __setitem__(self, key: Hashable, value: Any): | ||
with self.lock.gen_wlock(): | ||
self._cache[key] = value | ||
|
||
def __getitem__(self, key: Hashable) -> Any: | ||
with self.lock.gen_rlock(): | ||
return self._cache[key] | ||
|
||
def __len__(self) -> int: | ||
return len(self._cache) | ||
|
||
def popitem(self) -> None: | ||
with self.lock.gen_wlock(): | ||
self._cache.popitem() | ||
|
||
@property | ||
def maxsize(self) -> int: | ||
"""Return the maximum size of the cache.""" | ||
return int(self._cache.maxsize) | ||
|
||
@property | ||
def currsize(self) -> int: | ||
"""Return the current size of the cache.""" | ||
return int(self._cache.currsize) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.