Skip to content

Commit

Permalink
update/extend the requires wrapper (#146)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Jul 19, 2023
1 parent b8a9e90 commit 84cafbc
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

-
- Updated/Extended the `requires` wrapper ([#146](https://github.com/Lightning-AI/utilities/pull/146))

### Fixed

Expand Down
19 changes: 15 additions & 4 deletions src/lightning_utilities/core/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import functools
import importlib
import os
import warnings
from functools import lru_cache
from importlib.util import find_spec
Expand Down Expand Up @@ -266,9 +267,13 @@ def lazy_import(module_name: str, callback: Optional[Callable] = None) -> LazyMo
return LazyModule(module_name, callback=callback)


def requires(*module_path: str, raise_exception: bool = True) -> Callable[[Callable[P, T]], Callable[P, T]]:
def requires(*module_path_version: str, raise_exception: bool = True) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Wrap early import failure with some nice exception message.
Args:
module_path_version: pythin package path (e.g. `torch.cuda`) or pip like requiremsnt (e.g. `torch>=2.0.0`)
raise_exception: how strict the check shall be if exit the code or just warn user
Example:
>>> @requires("libpath", raise_exception=bool(int(os.getenv("LIGHTING_TESTING", "0"))))
... def my_cwd():
Expand All @@ -284,11 +289,17 @@ def requires(*module_path: str, raise_exception: bool = True) -> Callable[[Calla
"""

def decorator(func: Callable[P, T]) -> Callable[P, T]:
reqs = [
ModuleAvailableCache(mod_ver) if "." in mod_ver else RequirementCache(mod_ver)
for mod_ver in module_path_version
]
available = all(map(bool, reqs))

@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
unavailable_modules = [module for module in module_path if not module_available(module)]
if any(unavailable_modules):
msg = f"Required dependencies not available. Please run `pip install {' '.join(unavailable_modules)}`"
if not available:
missing = os.linesep.join([repr(r) for r in reqs if not bool(r)])
msg = f"Required dependencies not available: \n{missing}"
if raise_exception:
raise ModuleNotFoundError(msg)
warnings.warn(msg, stacklevel=2)
Expand Down
9 changes: 5 additions & 4 deletions tests/unittests/core/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def callback_fcn():
assert os.getcwd()


@requires("torch")
@requires("torch.unknown.subpackage")
def my_torch_func(i: int) -> int:
import torch # noqa

Expand All @@ -105,7 +105,8 @@ def my_torch_func(i: int) -> int:

def test_torch_func_raised():
with pytest.raises(
ModuleNotFoundError, match="Required dependencies not available. Please run `pip install torch`"
ModuleNotFoundError,
match="Required dependencies not available: \nModule not found: 'torch.unknown.subpackage'. ",
):
my_torch_func(42)

Expand All @@ -122,7 +123,7 @@ def test_rand_func_passed():


class MyTorchClass:
@requires("torch", "random")
@requires("torch>99.0", "random")
def __init__(self):
from random import randint

Expand All @@ -133,7 +134,7 @@ def __init__(self):

def test_torch_class_raised():
with pytest.raises(
ModuleNotFoundError, match="Required dependencies not available. Please run `pip install torch`"
ModuleNotFoundError, match="Required dependencies not available: \nModule not found: 'torch>99.0'."
):
MyTorchClass()

Expand Down

0 comments on commit 84cafbc

Please sign in to comment.