Skip to content

Commit

Permalink
Clean up solver choosing
Browse files Browse the repository at this point in the history
There was a lot of redundant code between the specified solver cases and
the guessing solver cases. These treatements have been consolidated and
simplified. Easier to rearrange order of solvers to try and to add in new
solvers.

Addresses usnistgov#644
  • Loading branch information
guyer committed Jan 27, 2020
1 parent 6bd3686 commit 9f31d13
Showing 1 changed file with 79 additions and 109 deletions.
188 changes: 79 additions & 109 deletions fipy/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,138 +7,108 @@
from future.utils import text_to_native_str
__all__ = [text_to_native_str(n) for n in __all__]

solver = _parseSolver()
_desired_solver = _parseSolver()

def _envSolver(solver):
import os
if solver is None and 'FIPY_SOLVERS' in os.environ:
solver = os.environ['FIPY_SOLVERS'].lower()
return solver

solver = _envSolver(solver)
import os
if _desired_solver is None and 'FIPY_SOLVERS' in os.environ:
_desired_solver = os.environ['FIPY_SOLVERS'].lower()

exceptions = []

class SerialSolverError(Exception):
def __init__(self, solver):
super(SerialSolverError, self).__init__(solver + ' does not run in parallel')

if solver == "pysparse":
if _parallelComm.Nproc > 1:
raise SerialSolverError('pysparse')
from fipy.solvers.pysparse import *
__all__.extend(pysparse.__all__)
from fipy.matrices.pysparseMatrix import _PysparseMeshMatrix
_MeshMatrix = _PysparseMeshMatrix

elif solver == "petsc":
from fipy.solvers.petsc import *
__all__.extend(petsc.__all__)

from fipy.matrices.petscMatrix import _PETScMeshMatrix
_MeshMatrix = _PETScMeshMatrix

elif solver == "trilinos":
from fipy.solvers.trilinos import *
__all__.extend(trilinos.__all__)

try:
from fipy.matrices.pysparseMatrix import _PysparseMeshMatrix
_MeshMatrix = _PysparseMeshMatrix
except ImportError:
from fipy.matrices.trilinosMatrix import _TrilinosMeshMatrix
_MeshMatrix = _TrilinosMeshMatrix

elif solver == "scipy":
if _parallelComm.Nproc > 1:
raise SerialSolverError('scipy')
from fipy.solvers.scipy import *
__all__.extend(scipy.__all__)
from fipy.matrices.scipyMatrix import _ScipyMeshMatrix
_MeshMatrix = _ScipyMeshMatrix

elif solver == "pyamg":
if _parallelComm.Nproc > 1:
raise SerialSolverError('pyamg')
from fipy.solvers.pyAMG import *
__all__.extend(pyAMG.__all__)
from fipy.matrices.scipyMatrix import _ScipyMeshMatrix
_MeshMatrix = _ScipyMeshMatrix

elif solver == "pyamgx":
if _parallelComm.Nproc > 1:
raise SerialSolverError('pyamgx')
from fipy.solvers.pyamgx import *
__all__.extend(pyamgx.__all__)
from fipy.matrices.scipyMatrix import _ScipyMeshMatrix
_MeshMatrix = _ScipyMeshMatrix

elif solver == "no-pysparse":
from fipy.solvers.trilinos import *
__all__.extend(trilinos.__all__)
from fipy.matrices.trilinosMatrix import _TrilinosMeshMatrix
_MeshMatrix = _TrilinosMeshMatrix

elif solver is None:
# If no argument or environment variable, try importing them and seeing
# what works

exceptions = []
solver = None

if solver is None and _desired_solver in ["pysparse", None]:
try:
if _parallelComm.Nproc > 1:
raise SerialSolverError('pysparse')
from fipy.solvers.pysparse import *
__all__.extend(pysparse.__all__)
solver = "pysparse"
from fipy.matrices.pysparseMatrix import _PysparseMeshMatrix
_MeshMatrix = _PysparseMeshMatrix
solver = "pysparse"
except Exception as inst:
exceptions.append(inst)

except (ImportError, SerialSolverError) as inst:
if solver is None and _desired_solver in ["petsc", None]:
try:
from fipy.solvers.petsc import *
__all__.extend(petsc.__all__)

from fipy.matrices.petscMatrix import _PETScMeshMatrix
_MeshMatrix = _PETScMeshMatrix
solver = "petsc"
except Exception as inst:
exceptions.append(inst)

try:
from fipy.solvers.trilinos import *
__all__.extend(trilinos.__all__)
solver = "trilinos"
if solver is None and _desired_solver in ["trilinos", "no-pysparse", None]:
try:
from fipy.solvers.trilinos import *
__all__.extend(trilinos.__all__)

if _desired_solver != "no-pysparse":
try:
from fipy.matrices.pysparseMatrix import _PysparseMeshMatrix
_MeshMatrix = _PysparseMeshMatrix
solver = "trilinos"
except ImportError:
solver = "no-pysparse"
from fipy.matrices.trilinosMatrix import _TrilinosMeshMatrix
_MeshMatrix = _TrilinosMeshMatrix
except ImportError as inst:
exceptions.append(inst)
pass

if solver is None:
# no-pysparse requested or pysparseMatrix failed to import
from fipy.matrices.trilinosMatrix import _TrilinosMeshMatrix
_MeshMatrix = _TrilinosMeshMatrix
solver = "no-pysparse"
except Exception as inst:
exceptions.append(inst)

try:
if _parallelComm.Nproc > 1:
raise SerialSolverError('pyamg')
from fipy.solvers.pyAMG import *
__all__.extend(pyAMG.__all__)
solver = "pyamg"
from fipy.matrices.scipyMatrix import _ScipyMeshMatrix
_MeshMatrix = _ScipyMeshMatrix
except (ImportError, SerialSolverError) as inst:
exceptions.append(inst)

try:
if _parallelComm.Nproc > 1:
raise SerialSolverError('scipy')
from fipy.solvers.scipy import *
__all__.extend(scipy.__all__)
solver = "scipy"
from fipy.matrices.scipyMatrix import _ScipyMeshMatrix
_MeshMatrix = _ScipyMeshMatrix
except (ImportError, SerialSolverError) as inst:
exceptions.append(inst)
import warnings
warnings.warn("Could not import any solver package. If you are using Trilinos, make sure you have all of the necessary Trilinos packages installed - Epetra, EpetraExt, AztecOO, Amesos, ML, and IFPACK.")
for inst in exceptions:
warnings.warn(inst.__class__.__name__ + ': ' + inst.message)


else:
raise ImportError('Unknown solver package %s' % solver)
if solver is None and _desired_solver in ["scipy", None]:
try:
if _parallelComm.Nproc > 1:
raise SerialSolverError('scipy')
from fipy.solvers.scipy import *
__all__.extend(scipy.__all__)
from fipy.matrices.scipyMatrix import _ScipyMeshMatrix
_MeshMatrix = _ScipyMeshMatrix
solver = "scipy"
except Exception as inst:
exceptions.append(inst)

if solver is None and _desired_solver in ["pyamg", None]:
try:
if _parallelComm.Nproc > 1:
raise SerialSolverError('pyamg')
from fipy.solvers.pyAMG import *
__all__.extend(pyAMG.__all__)
from fipy.matrices.scipyMatrix import _ScipyMeshMatrix
_MeshMatrix = _ScipyMeshMatrix
solver = "pyamg"
except Exception as inst:
exceptions.append(inst)

if solver is None and _desired_solver in ["pyamgx", None]:
try:
if _parallelComm.Nproc > 1:
raise SerialSolverError('pyamgx')
from fipy.solvers.pyamgx import *
__all__.extend(pyamgx.__all__)
from fipy.matrices.scipyMatrix import _ScipyMeshMatrix
_MeshMatrix = _ScipyMeshMatrix
solver = "pyamgx"
except Exception as inst:
exceptions.append(inst)

if solver is None:
if _desired_solver is None:
raise ImportError('Unable to load a solver: %s' % [str(e) for e in exceptions])
else:
if len(exceptions) > 0:
raise ImportError('Unable to load solver %s: %s' % (_desired_solver, [str(e) for e in exceptions]))
else:
raise ImportError('Unknown solver package %s' % _desired_solver)

from fipy.tests.doctestPlus import register_skipper

Expand Down

0 comments on commit 9f31d13

Please sign in to comment.