Skip to content

Commit

Permalink
refactor(cache): Break out cache handler retrieval method.
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikBavenstrand committed Jan 26, 2024
1 parent 454a861 commit aba9e41
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions mleko/cache/cache_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,20 @@ def extract_number(file_path: Path) -> int:
return tuple(output_data) if len(output_data) > 1 else output_data[0]
return None

def _get_handler(self, cache_handlers: CacheHandler | list[CacheHandler], index: int = 0) -> CacheHandler:
"""Gets the cache handler at the given index.
Args:
cache_handlers: A CacheHandler instance or a list of CacheHandler instances.
index: The index of the cache handler to get.
Returns:
Handler at the given index. If a single CacheHandler instance is provided, it will be returned.
"""
if isinstance(cache_handlers, list):
return cache_handlers[index]
return cache_handlers

def _save_to_cache(
self,
cache_key: str,
Expand All @@ -281,16 +295,14 @@ def _save_to_cache(
provided, each CacheHandler instance will be used for each cache file.
"""
if isinstance(output, Sequence):
for i in range(len(output)):
writer = cache_handlers[i].writer if isinstance(cache_handlers, list) else cache_handlers.writer
suffix = cache_handlers[i].suffix if isinstance(cache_handlers, list) else cache_handlers.suffix
cache_file_path = self._cache_directory / f"{cache_key}_{i}.{suffix}"
writer(cache_file_path, output[i])
for i, output_item in enumerate(output):
handler = self._get_handler(cache_handlers, i)
cache_file_path = self._cache_directory / f"{cache_key}_{i}.{handler.suffix}"
handler.writer(cache_file_path, output_item)
else:
writer = cache_handlers[0].writer if isinstance(cache_handlers, list) else cache_handlers.writer
suffix = cache_handlers[0].suffix if isinstance(cache_handlers, list) else cache_handlers.suffix
cache_file_path = self._cache_directory / f"{cache_key}.{suffix}"
writer(cache_file_path, output)
handler = self._get_handler(cache_handlers)
cache_file_path = self._cache_directory / f"{cache_key}.{handler.suffix}"
handler.writer(cache_file_path, output)

def _find_cache_type_name(self, cls: type) -> str | None:
"""Recursively searches the class hierarchy for the name of the class that inherits from `CacheMixin`.
Expand Down

0 comments on commit aba9e41

Please sign in to comment.