diff --git a/fipy/solvers/__init__.py b/fipy/solvers/__init__.py index c18fbf61f6..d21ed4ca0d 100644 --- a/fipy/solvers/__init__.py +++ b/fipy/solvers/__init__.py @@ -7,138 +7,108 @@ from future.utils import text_to_native_str __all__ = [text_to_native_str(n) for n in __all__] -solver = _parseSolver() +_desired_solver = _parseSolver() -def _envSolver(solver): - import os - if solver is None and 'FIPY_SOLVERS' in os.environ: - solver = os.environ['FIPY_SOLVERS'].lower() - return solver - -solver = _envSolver(solver) +import os +if _desired_solver is None and 'FIPY_SOLVERS' in os.environ: + _desired_solver = os.environ['FIPY_SOLVERS'].lower() + +exceptions = [] class SerialSolverError(Exception): def __init__(self, solver): super(SerialSolverError, self).__init__(solver + ' does not run in parallel') -if solver == "pysparse": - if _parallelComm.Nproc > 1: - raise SerialSolverError('pysparse') - from fipy.solvers.pysparse import * - __all__.extend(pysparse.__all__) - from fipy.matrices.pysparseMatrix import _PysparseMeshMatrix - _MeshMatrix = _PysparseMeshMatrix - -elif solver == "petsc": - from fipy.solvers.petsc import * - __all__.extend(petsc.__all__) - - from fipy.matrices.petscMatrix import _PETScMeshMatrix - _MeshMatrix = _PETScMeshMatrix - -elif solver == "trilinos": - from fipy.solvers.trilinos import * - __all__.extend(trilinos.__all__) - - try: - from fipy.matrices.pysparseMatrix import _PysparseMeshMatrix - _MeshMatrix = _PysparseMeshMatrix - except ImportError: - from fipy.matrices.trilinosMatrix import _TrilinosMeshMatrix - _MeshMatrix = _TrilinosMeshMatrix - -elif solver == "scipy": - if _parallelComm.Nproc > 1: - raise SerialSolverError('scipy') - from fipy.solvers.scipy import * - __all__.extend(scipy.__all__) - from fipy.matrices.scipyMatrix import _ScipyMeshMatrix - _MeshMatrix = _ScipyMeshMatrix - -elif solver == "pyamg": - if _parallelComm.Nproc > 1: - raise SerialSolverError('pyamg') - from fipy.solvers.pyAMG import * - __all__.extend(pyAMG.__all__) - from fipy.matrices.scipyMatrix import _ScipyMeshMatrix - _MeshMatrix = _ScipyMeshMatrix - -elif solver == "pyamgx": - if _parallelComm.Nproc > 1: - raise SerialSolverError('pyamgx') - from fipy.solvers.pyamgx import * - __all__.extend(pyamgx.__all__) - from fipy.matrices.scipyMatrix import _ScipyMeshMatrix - _MeshMatrix = _ScipyMeshMatrix - -elif solver == "no-pysparse": - from fipy.solvers.trilinos import * - __all__.extend(trilinos.__all__) - from fipy.matrices.trilinosMatrix import _TrilinosMeshMatrix - _MeshMatrix = _TrilinosMeshMatrix - -elif solver is None: - # If no argument or environment variable, try importing them and seeing - # what works - - exceptions = [] +solver = None +if solver is None and _desired_solver in ["pysparse", None]: try: if _parallelComm.Nproc > 1: raise SerialSolverError('pysparse') from fipy.solvers.pysparse import * __all__.extend(pysparse.__all__) - solver = "pysparse" from fipy.matrices.pysparseMatrix import _PysparseMeshMatrix _MeshMatrix = _PysparseMeshMatrix + solver = "pysparse" + except Exception as inst: + exceptions.append(inst) - except (ImportError, SerialSolverError) as inst: +if solver is None and _desired_solver in ["petsc", None]: + try: + from fipy.solvers.petsc import * + __all__.extend(petsc.__all__) + + from fipy.matrices.petscMatrix import _PETScMeshMatrix + _MeshMatrix = _PETScMeshMatrix + solver = "petsc" + except Exception as inst: exceptions.append(inst) - try: - from fipy.solvers.trilinos import * - __all__.extend(trilinos.__all__) - solver = "trilinos" +if solver is None and _desired_solver in ["trilinos", "no-pysparse", None]: + try: + from fipy.solvers.trilinos import * + __all__.extend(trilinos.__all__) + + if _desired_solver != "no-pysparse": try: from fipy.matrices.pysparseMatrix import _PysparseMeshMatrix _MeshMatrix = _PysparseMeshMatrix + solver = "trilinos" except ImportError: - solver = "no-pysparse" - from fipy.matrices.trilinosMatrix import _TrilinosMeshMatrix - _MeshMatrix = _TrilinosMeshMatrix - except ImportError as inst: - exceptions.append(inst) + pass + + if solver is None: + # no-pysparse requested or pysparseMatrix failed to import + from fipy.matrices.trilinosMatrix import _TrilinosMeshMatrix + _MeshMatrix = _TrilinosMeshMatrix + solver = "no-pysparse" + except Exception as inst: + exceptions.append(inst) - try: - if _parallelComm.Nproc > 1: - raise SerialSolverError('pyamg') - from fipy.solvers.pyAMG import * - __all__.extend(pyAMG.__all__) - solver = "pyamg" - from fipy.matrices.scipyMatrix import _ScipyMeshMatrix - _MeshMatrix = _ScipyMeshMatrix - except (ImportError, SerialSolverError) as inst: - exceptions.append(inst) - - try: - if _parallelComm.Nproc > 1: - raise SerialSolverError('scipy') - from fipy.solvers.scipy import * - __all__.extend(scipy.__all__) - solver = "scipy" - from fipy.matrices.scipyMatrix import _ScipyMeshMatrix - _MeshMatrix = _ScipyMeshMatrix - except (ImportError, SerialSolverError) as inst: - exceptions.append(inst) - import warnings - warnings.warn("Could not import any solver package. If you are using Trilinos, make sure you have all of the necessary Trilinos packages installed - Epetra, EpetraExt, AztecOO, Amesos, ML, and IFPACK.") - for inst in exceptions: - warnings.warn(inst.__class__.__name__ + ': ' + inst.message) - - -else: - raise ImportError('Unknown solver package %s' % solver) +if solver is None and _desired_solver in ["scipy", None]: + try: + if _parallelComm.Nproc > 1: + raise SerialSolverError('scipy') + from fipy.solvers.scipy import * + __all__.extend(scipy.__all__) + from fipy.matrices.scipyMatrix import _ScipyMeshMatrix + _MeshMatrix = _ScipyMeshMatrix + solver = "scipy" + except Exception as inst: + exceptions.append(inst) + +if solver is None and _desired_solver in ["pyamg", None]: + try: + if _parallelComm.Nproc > 1: + raise SerialSolverError('pyamg') + from fipy.solvers.pyAMG import * + __all__.extend(pyAMG.__all__) + from fipy.matrices.scipyMatrix import _ScipyMeshMatrix + _MeshMatrix = _ScipyMeshMatrix + solver = "pyamg" + except Exception as inst: + exceptions.append(inst) + +if solver is None and _desired_solver in ["pyamgx", None]: + try: + if _parallelComm.Nproc > 1: + raise SerialSolverError('pyamgx') + from fipy.solvers.pyamgx import * + __all__.extend(pyamgx.__all__) + from fipy.matrices.scipyMatrix import _ScipyMeshMatrix + _MeshMatrix = _ScipyMeshMatrix + solver = "pyamgx" + except Exception as inst: + exceptions.append(inst) +if solver is None: + if _desired_solver is None: + raise ImportError('Unable to load a solver: %s' % [str(e) for e in exceptions]) + else: + if len(exceptions) > 0: + raise ImportError('Unable to load solver %s: %s' % (_desired_solver, [str(e) for e in exceptions])) + else: + raise ImportError('Unknown solver package %s' % _desired_solver) from fipy.tests.doctestPlus import register_skipper