Skip to content

Commit

Permalink
switch all optionals to check_module_enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-rakowski committed Nov 23, 2023
1 parent 2822602 commit f786229
Showing 1 changed file with 43 additions and 19 deletions.
62 changes: 43 additions & 19 deletions pylops/utils/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
]

import os
from importlib import import_module, util
from importlib import import_module # , util
from typing import Optional

# check package availability
# cupy_enabled = (
# util.find_spec("cupy") is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1
# )

# cusignal_enabled = (
# util.find_spec("cusignal") is not None and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1
# )
# try:
# import_module("cupy")
# # if can succesfully import cupy, check envrionment
Expand All @@ -28,40 +30,62 @@
# cupy_enabled = False
# except Exception as e:
# raise UserWarning("Unexpceted Exception when importing cupy") from e
# devito_enabled = util.find_spec("devito") is not None
# numba_enabled = util.find_spec("numba") is not None
# pyfftw_enabled = util.find_spec("pyfftw") is not None
# pywt_enabled = util.find_spec("pywt") is not None
# skfmm_enabled = util.find_spec("skfmm") is not None
# spgl1_enabled = util.find_spec("spgl1") is not None
# sympy_enabled = util.find_spec("sympy") is not None
# torch_enabled = util.find_spec("torch") is not None


def check_module_enabled(
module: str,
envrionment_str: Optional[str] = None,
envrionment_val: Optional[str] = "1",
envrionment_val: Optional[int] = 1,
) -> bool:

"""
Checks whether a specific module can be imported in the current Python environment.
Args:
module (str): The name of the module to check import state for.
envrionment_str (Optional[str]): An optional environment variable name to check for. If provided,
the function will return True only if the environment variable is set to the specified value.
Defaults to None.
envrionment_val (Optional[str]): The value to compare the environment variable against. Defaults to "1".
Returns:
bool: True if the module is available, False otherwise.
"""
# try to import the module
try:
import_module(module)

# run envrionment check if needed
if envrionment_str is not None:
return os.environ[envrionment_str] == envrionment_val
# return True if the value matches expected value
return int(os.getenv(envrionment_str, envrionment_val)) == envrionment_val
# if no environment check return True as import_module worked
else:
return True
# if cannot import and provides expected Exceptions, return False
except (ImportError, ModuleNotFoundError):
return False
# raise warning if anyother exception raised in import
except Exception as e:
raise UserWarning("Unexpceted Exception when importing cupy") from e
raise UserWarning(f"Unexpceted Exception when importing {module}") from e


cupy_enabled = check_module_enabled("cupy", "CUPY_PYLOPS")

cusignal_enabled = (
util.find_spec("cusignal") is not None and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1
)
devito_enabled = util.find_spec("devito") is not None
numba_enabled = util.find_spec("numba") is not None
pyfftw_enabled = util.find_spec("pyfftw") is not None
pywt_enabled = util.find_spec("pywt") is not None
skfmm_enabled = util.find_spec("skfmm") is not None
spgl1_enabled = util.find_spec("spgl1") is not None
sympy_enabled = util.find_spec("sympy") is not None
torch_enabled = util.find_spec("torch") is not None
cusignal_enabled = check_module_enabled("cusignal", "CUSIGNAL_PYLOPS")
devito_enabled = check_module_enabled("devito")
numba_enabled = check_module_enabled("numba")
pyfftw_enabled = check_module_enabled("pyfftw")
pywt_enabled = check_module_enabled("pywt")
skfmm_enabled = check_module_enabled("skfmm")
spgl1_enabled = check_module_enabled("spgl1")
sympy_enabled = check_module_enabled("sympy")
torch_enabled = check_module_enabled("torch")


# error message at import of available package
Expand Down

0 comments on commit f786229

Please sign in to comment.