Skip to content

Commit

Permalink
support custom argument reducers (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
quintenroets authored Jan 19, 2025
1 parent e2f63eb commit 4e22021
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.2.1
current_version = 0.2.2
commit = False
allow_dirty = True

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "persistent-function-cache"
version = "0.2.1"
version = "0.2.2"
description = "Persistent cache for expensive functions"
authors = [{name = "Quinten Roets", email = "qdr2104@columbia.edu"}]
license = {text = "MIT"}
Expand Down
12 changes: 11 additions & 1 deletion src/persistent_cache/main/cache_slot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable, Iterable, Iterator
from dataclasses import dataclass
from functools import cached_property
from inspect import BoundArguments
from typing import Any

from persistent_cache.models import Path
Expand All @@ -18,6 +19,7 @@ class CacheSlot:
kwargs: dict[str, Any]
directory: Path
key_arguments: Iterable[str] | str | None
argument_reducers: dict[str, Callable[[Any], Any]] | None
extra_keys: Any
key_reducer: type[Reducer] | None
deep_learning: bool
Expand Down Expand Up @@ -60,17 +62,25 @@ def keys(self) -> Iterator[Any]:

@property
def argument_values(self) -> Iterator[Any]:
if self.key_arguments is None:
if self.key_arguments is None and self.argument_reducers is None:
yield from self.args
yield from self.kwargs.values()
else:
arguments = inspect.signature(self.function).bind(*self.args, **self.kwargs)
arguments.apply_defaults()
yield from self.extract_argument_values(arguments)

def extract_argument_values(self, arguments: BoundArguments) -> Iterator[Any]:
if self.key_arguments is not None:
if isinstance(self.key_arguments, str):
yield arguments.arguments.get(self.key_arguments)
else:
for name in self.key_arguments:
yield arguments.arguments.get(name)
if self.argument_reducers is not None:
for argument_name, reducer in self.argument_reducers.items():
argument = arguments.arguments.get(argument_name)
yield reducer(argument)

@property
def reducer(self) -> type[Reducer]:
Expand Down
4 changes: 4 additions & 0 deletions src/persistent_cache/main/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def cache(
*,
cache_directory: Path = Path.cache,
cache_key_arguments: Iterable[str] | str | None = None,
argument_reducers: dict[str, Callable[[Any], Any]] | None = None,
extra_cache_keys: Iterable[Any] | None = None,
key_reducer: type[Reducer] = Reducer,
deep_learning: bool = False,
Expand All @@ -29,6 +30,7 @@ def cache(
*,
cache_directory: Path = Path.cache,
cache_key_arguments: Iterable[str] | str | None = None,
argument_reducers: dict[str, Callable[[Any], Any]] | None = None,
extra_cache_keys: Any = None,
key_reducer: type[Reducer] = Reducer,
deep_learning: bool = False,
Expand All @@ -41,6 +43,7 @@ def cache( # noqa: PLR0913
*,
cache_directory: Path = Path.cache,
cache_key_arguments: Iterable[str] | str | None = None,
argument_reducers: dict[str, Callable[[Any], Any]] | None = None,
extra_cache_keys: Any = None,
key_reducer: type[Reducer] | None = None,
deep_learning: bool = False,
Expand Down Expand Up @@ -68,6 +71,7 @@ def wrapped_function(*args: Any, **kwargs: Any) -> Any:
kwargs,
cache_directory,
cache_key_arguments,
argument_reducers,
extra_cache_keys,
key_reducer,
deep_learning,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,14 @@ def test_cache_with_reducer(
key_reducer=Reducer,
)
cached_function("test")


@pytest.mark.parametrize("cache_decorator", caches)
def test_cache_with_argument_reducer(
cache_decorator: Callable[..., Any],
) -> None:
cached_function = cache_decorator(
calculate_with_name,
argument_reducers={"value": lambda x: x.strip()},
)
cached_function("test")

0 comments on commit 4e22021

Please sign in to comment.