Skip to content

Commit

Permalink
Changing how optional dependencies are import state are checked (#551)
Browse files Browse the repository at this point in the history
* adding initial check

* adding check_module_enabled func

* silly typo "module" instead of module

* switch all optionals to check_module_enabled

* changing to c/longdouble for tests

* cleaning up old code and adding noqa

* Revert "changing to c/longdouble for tests"

This reverts commit 06534b1.

* numpydoc style docstring

* adding cupy/cusignal import tests and how enables are set

* adding message to be optional

* fixing logic in cupy/cusignal_enabled

* cleaning up old code snippets and typos

* small format changes

* more formatting changes

* changing import to import_module for pylint

* minor: restyling deps.py

* minor: added alex-rakowski to contributors

---------

Co-authored-by: mrava87 <matteoravasi@gmail.com>
  • Loading branch information
alex-rakowski and mrava87 authored Dec 17, 2023
1 parent d6e484c commit 77eca0f
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 32 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,4 @@ A list of video tutorials to learn more about PyLops:
* Rohan Babbar, rohanbabbar04
* Wei Zhang, ZhangWeiGeo
* Fedor Goncharov, fedor-goncharov
* Alex Rakowski, alex-rakowski
3 changes: 2 additions & 1 deletion docs/source/credits.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ Contributors
* `Aniket Singh Rawat <https://github.com/dikwickley>`_, dikwickley
* `Rohan Babbar <https://github.com/rohanbabbar04>`_, rohanbabbar04
* `Wei Zhang <https://github.com/ZhangWeiGeo>`_, ZhangWeiGeo
* `Fedor Goncharov <https://github.com/fedor-goncharov>`_, fedor-goncharov
* `Fedor Goncharov <https://github.com/fedor-goncharov>`_, fedor-goncharov
* `Alex Rakowski <https://github.com/alex-rakowski>`_, alex-rakowski
135 changes: 104 additions & 31 deletions pylops/utils/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,78 @@
]

import os
from importlib import util

# check package availability
cupy_enabled = (
util.find_spec("cupy") is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1
)
cusignal_enabled = (
util.find_spec("cusignal") is not None and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1
)
devito_enabled = util.find_spec("devito") is not None
numba_enabled = util.find_spec("numba") is not None
pyfftw_enabled = util.find_spec("pyfftw") is not None
pywt_enabled = util.find_spec("pywt") is not None
skfmm_enabled = util.find_spec("skfmm") is not None
spgl1_enabled = util.find_spec("spgl1") is not None
sympy_enabled = util.find_spec("sympy") is not None
torch_enabled = util.find_spec("torch") is not None
from importlib import import_module, util
from typing import Optional


# error message at import of available package
def devito_import(message):
def cupy_import(message: Optional[str] = None) -> str:
# detect if cupy is available and the user is expecting to be used
cupy_test = (
util.find_spec("cupy") is not None and int(os.getenv("CUPY_PYLOPS", 1)) == 1
)
# if cupy should be importable
if cupy_test:
# try importing it
try:
import_module("cupy") # noqa: F401

# if successful set the message to None.
cupy_message = None
# if unable to import but the package is installed
except (ImportError, ModuleNotFoundError) as e:
cupy_message = (
f"Failed to import cupy, Falling back to CPU (error: {e}). "
"Please ensure your CUDA environment is set up correctly "
"for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
)
print(UserWarning(cupy_message))
# if cupy_test is False, it means not installed or environment variable set to 0
else:
cupy_message = (
"Cupy package not installed or os.getenv('CUPY_PYLOPS') == 0. "
f"In order to be able to use {message} "
"ensure 'os.getenv('CUPY_PYLOPS') == 1' and run "
"'pip install cupy'; "
"for more details visit 'https://docs.cupy.dev/en/stable/install.html'"
)

return cupy_message


def cusignal_import(message: Optional[str] = None) -> str:
cusignal_test = (
util.find_spec("cusignal") is not None
and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1
)
if cusignal_test:
try:
import_module("cusignal") # noqa: F401

cusignal_message = None
except (ImportError, ModuleNotFoundError) as e:
cusignal_message = (
f"Failed to import cusignal. Falling back to CPU (error: {e}) . "
"Please ensure your CUDA environment is set up correctly; "
"for more details visit 'https://github.com/rapidsai/cusignal#installation'"
)
print(UserWarning(cusignal_message))
else:
cusignal_message = (
"Cusignal not installed or os.getenv('CUSIGNAL_PYLOPS') == 0. "
f"In order to be able to use {message} "
"ensure 'os.getenv('CUSIGNAL_PYLOPS') == 1' and run "
"'conda install cusignal'; "
"for more details visit ''https://github.com/rapidsai/cusignal#installation''"
)

return cusignal_message


def devito_import(message: Optional[str] = None) -> str:
if devito_enabled:
try:
import devito # noqa: F401
import_module("devito") # noqa: F401

devito_message = None
except Exception as e:
Expand All @@ -49,10 +97,10 @@ def devito_import(message):
return devito_message


def numba_import(message):
def numba_import(message: Optional[str] = None) -> str:
if numba_enabled:
try:
import numba # noqa: F401
import_module("numba") # noqa: F401

numba_message = None
except Exception as e:
Expand All @@ -68,10 +116,10 @@ def numba_import(message):
return numba_message


def pyfftw_import(message):
def pyfftw_import(message: Optional[str] = None) -> str:
if pyfftw_enabled:
try:
import pyfftw # noqa: F401
import_module("pyfftw") # noqa: F401

pyfftw_message = None
except Exception as e:
Expand All @@ -87,10 +135,10 @@ def pyfftw_import(message):
return pyfftw_message


def pywt_import(message):
def pywt_import(message: Optional[str] = None) -> str:
if pywt_enabled:
try:
import pywt # noqa: F401
import_module("pywt") # noqa: F401

pywt_message = None
except Exception as e:
Expand All @@ -106,10 +154,10 @@ def pywt_import(message):
return pywt_message


def skfmm_import(message):
def skfmm_import(message: Optional[str] = None) -> str:
if skfmm_enabled:
try:
import skfmm # noqa: F401
import_module("skfmm") # noqa: F401

skfmm_message = None
except Exception as e:
Expand All @@ -124,10 +172,10 @@ def skfmm_import(message):
return skfmm_message


def spgl1_import(message):
def spgl1_import(message: Optional[str] = None) -> str:
if spgl1_enabled:
try:
import spgl1 # noqa: F401
import_module("spgl1") # noqa: F401

spgl1_message = None
except Exception as e:
Expand All @@ -141,10 +189,10 @@ def spgl1_import(message):
return spgl1_message


def sympy_import(message):
def sympy_import(message: Optional[str] = None) -> str:
if sympy_enabled:
try:
import sympy # noqa: F401
import_module("sympy") # noqa: F401

sympy_message = None
except Exception as e:
Expand All @@ -156,3 +204,28 @@ def sympy_import(message):
f'"pip install sympy".'
)
return sympy_message


# Set package availability booleans
# cupy and cusignal: the package is imported to check everything is working correctly,
# if not the package is disabled. We do this here as both libraries are used as drop-in
# replacement for many numpy and scipy routines when cupy arrays are provided.
# all other libraries: we simply check if the package is available and postpone its import
# to check everything is working correctly when a user tries to create an operator that requires
# such a package
cupy_enabled: bool = (
True if (cupy_import() is None and int(os.getenv("CUPY_PYLOPS", 1)) == 1) else False
)
cusignal_enabled: bool = (
True
if (cusignal_import() is None and int(os.getenv("CUSIGNAL_PYLOPS", 1)) == 1)
else False
)
devito_enabled = util.find_spec("devito") is not None
numba_enabled = util.find_spec("numba") is not None
pyfftw_enabled = util.find_spec("pyfftw") is not None
pywt_enabled = util.find_spec("pywt") is not None
skfmm_enabled = util.find_spec("skfmm") is not None
spgl1_enabled = util.find_spec("spgl1") is not None
sympy_enabled = util.find_spec("sympy") is not None
torch_enabled = util.find_spec("torch") is not None

0 comments on commit 77eca0f

Please sign in to comment.