Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing how optional dependencies are import state are checked #551

Merged
merged 18 commits into from
Dec 17, 2023
Merged
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 92 additions & 17 deletions pylops/utils/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,8 @@
]

import os
from importlib import util

# check package availability
cupy_enabled = (
util.find_spec("cupy") is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1
)
cusignal_enabled = (
util.find_spec("cusignal") is not None and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1
)
devito_enabled = util.find_spec("devito") is not None
numba_enabled = util.find_spec("numba") is not None
pyfftw_enabled = util.find_spec("pyfftw") is not None
pywt_enabled = util.find_spec("pywt") is not None
skfmm_enabled = util.find_spec("skfmm") is not None
spgl1_enabled = util.find_spec("spgl1") is not None
sympy_enabled = util.find_spec("sympy") is not None
torch_enabled = util.find_spec("torch") is not None
from importlib import import_module, util
from typing import Optional


# error message at import of available package
Expand Down Expand Up @@ -156,3 +141,93 @@ def sympy_import(message):
f'"pip install sympy".'
)
return sympy_message


def cupy_import(message: Optional[str] = None):
# detect if cupy should be importable
cupy_test = (
util.find_spec("cupy") is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1
)
# if cupy should be importable
if cupy_test:
# try importing it
try:
import_module("cupy") # noqa: F401

# if successful set the message to None.
cupy_message = None
# if unable to import but it is installed
except (ImportError, ModuleNotFoundError) as e:
cupy_message = (
f"Failed to import cupy, Falling back to CPU (error: {e}). "
f""
"Please ensure your CUDA envrionment is set up correctly "
"for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
)
print(UserWarning(cupy_message))
# if cupy_test is False it means not installed or envrionment variable set to 0
else:
cupy_message = (
f"cupy package not installed or os.getenv('CUPY_PYLOPS') == 0. In order to be able to use "
f"{message} "
"ensure 'os.getenv('CUPY_PYLOPS') == 1' and run "
"'pip install cupy'. "
"for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
)

return cupy_message


def cusignal_import(message: Optional[str] = None):
# detect if cupy should be importable
cusignal_test = (
util.find_spec("cusignal") is not None
and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1
)
# if cupy should be importable
if cusignal_test:
# try importing it
try:
import_module("cusignal") # noqa: F401

# if successful set the message to None.
cusignal_message = None
# if unable to import but it is installed
except (ImportError, ModuleNotFoundError) as e:
cusignal_message = (
f"Failed to import cusignal. Falling back to CPU (error: {e}) . "
f""
"Please ensure your CUDA envrionment is set up correctly "
"for more details visit 'https://github.com/rapidsai/cusignal#installation'"
)
print(UserWarning(cusignal_message))
# if cupy_test is False it means not installed or envrionment variable set to 0
else:
cusignal_message = (
f"cusignal package not installed or os.getenv('CUSIGNAL_PYLOPS') == 0. In order to be able to use "
f"{message} "
"ensure 'os.getenv('CUSIGNAL_PYLOPS') == 1' and run "
"'conda install cusignal'. "
"for more details visit ''https://github.com/rapidsai/cusignal#installation''"
)

return cusignal_message


# Set package avlaiblity booleans
cupy_enabled: bool = (
True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False
)
cusignal_enabled: bool = (
True
if (cusignal_import() is None and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1)
else False
)
devito_enabled = util.find_spec("devito") is not None
numba_enabled = util.find_spec("numba") is not None
pyfftw_enabled = util.find_spec("pyfftw") is not None
pywt_enabled = util.find_spec("pywt") is not None
skfmm_enabled = util.find_spec("skfmm") is not None
spgl1_enabled = util.find_spec("spgl1") is not None
sympy_enabled = util.find_spec("sympy") is not None
torch_enabled = util.find_spec("torch") is not None