Skip to content

Commit

Permalink
CholeskySolve inherits from BaseSolve
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Aug 25, 2023
1 parent a62f9aa commit a3869c6
Showing 3 changed files with 51 additions and 101 deletions.
144 changes: 47 additions & 97 deletions pytensor/tensor/slinalg.py
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@ class Cholesky(Op):

__props__ = ("lower", "destructive", "on_error")

def __init__(self, lower=True, on_error="raise"):
def __init__(self, *, lower=True, on_error="raise"):
self.lower = lower
self.destructive = False
if on_error not in ("raise", "nan"):
@@ -125,77 +125,8 @@ def conjugate_solve_triangular(outer, inner):
return [grad]


cholesky = Cholesky()


class CholeskySolve(Op):
__props__ = ("lower", "check_finite")

def __init__(
self,
lower=True,
check_finite=True,
):
self.lower = lower
self.check_finite = check_finite

def __repr__(self):
return "CholeskySolve{%s}" % str(self._props())

def make_node(self, C, b):
C = as_tensor_variable(C)
b = as_tensor_variable(b)
assert C.ndim == 2
assert b.ndim in (1, 2)

# infer dtype by solving the most simple
# case with (1, 1) matrices
o_dtype = scipy.linalg.solve(
np.eye(1).astype(C.dtype), np.eye(1).astype(b.dtype)
).dtype
x = tensor(dtype=o_dtype, shape=b.type.shape)
return Apply(self, [C, b], [x])

def perform(self, node, inputs, output_storage):
C, b = inputs
rval = scipy.linalg.cho_solve(
(C, self.lower),
b,
check_finite=self.check_finite,
)

output_storage[0][0] = rval

def infer_shape(self, fgraph, node, shapes):
Cshape, Bshape = shapes
rows = Cshape[1]
if len(Bshape) == 1: # b is a Vector
return [(rows,)]
else:
cols = Bshape[1] # b is a Matrix
return [(rows, cols)]


cho_solve = CholeskySolve()


def cho_solve(c_and_lower, b, check_finite=True):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
Parameters
----------
(c, lower) : tuple, (array, bool)
Cholesky factorization of a, as given by cho_factor
b : array
Right-hand side
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
"""

A, lower = c_and_lower
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)
def cholesky(x, lower=True, on_error="raise"):
return Cholesky(lower=lower, on_error=on_error)(x)


class SolveBase(Op):
@@ -208,6 +139,7 @@ class SolveBase(Op):

def __init__(
self,
*,
lower=False,
check_finite=True,
):
@@ -274,28 +206,56 @@ def L_op(self, inputs, outputs, output_gradients):

return [A_bar, b_bar]

def __repr__(self):
return f"{type(self).__name__}{self._props()}"

class CholeskySolve(SolveBase):
def __init__(self, **kwargs):
kwargs.setdefault("lower", True)
super().__init__(**kwargs)

def perform(self, node, inputs, output_storage):
C, b = inputs
rval = scipy.linalg.cho_solve(
(C, self.lower),
b,
check_finite=self.check_finite,
)

output_storage[0][0] = rval

def L_op(self, *args, **kwargs):
raise NotImplementedError()


def cho_solve(c_and_lower, b, *, check_finite=True):
"""Solve the linear equations A x = b, given the Cholesky factorization of A.
Parameters
----------
(c, lower) : tuple, (array, bool)
Cholesky factorization of a, as given by cho_factor
b : array
Right-hand side
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
(crashes, non-termination) if the inputs do contain infinities or NaNs.
"""
A, lower = c_and_lower
return CholeskySolve(lower=lower, check_finite=check_finite)(A, b)


class SolveTriangular(SolveBase):
"""Solve a system of linear equations."""

__props__ = (
"lower",
"trans",
"unit_diagonal",
"lower",
"check_finite",
)

def __init__(
self,
trans=0,
lower=False,
unit_diagonal=False,
check_finite=True,
):
super().__init__(lower=lower, check_finite=check_finite)
def __init__(self, *, trans=0, unit_diagonal=False, **kwargs):
super().__init__(**kwargs)
self.trans = trans
self.unit_diagonal = unit_diagonal

@@ -321,12 +281,10 @@ def L_op(self, inputs, outputs, output_gradients):
return res


solvetriangular = SolveTriangular()


def solve_triangular(
a: TensorVariable,
b: TensorVariable,
*,
trans: Union[int, str] = 0,
lower: bool = False,
unit_diagonal: bool = False,
@@ -374,16 +332,11 @@ class Solve(SolveBase):
"check_finite",
)

def __init__(
self,
assume_a="gen",
lower=False,
check_finite=True,
):
def __init__(self, *, assume_a="gen", **kwargs):
if assume_a not in ("gen", "sym", "her", "pos"):
raise ValueError(f"{assume_a} is not a recognized matrix structure")

super().__init__(lower=lower, check_finite=check_finite)
super().__init__(**kwargs)
self.assume_a = assume_a

def perform(self, node, inputs, outputs):
@@ -397,10 +350,7 @@ def perform(self, node, inputs, outputs):
)


solve = Solve()


def solve(a, b, assume_a="gen", lower=False, check_finite=True):
def solve(a, b, *, assume_a="gen", lower=False, check_finite=True):
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
If the data matrix is known to be a particular type then supplying the
6 changes: 3 additions & 3 deletions tests/link/numba/test_nlinalg.py
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@
],
)
def test_Cholesky(x, lower, exc):
g = slinalg.Cholesky(lower)(x)
g = slinalg.Cholesky(lower=lower)(x)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
@@ -91,7 +91,7 @@ def test_Cholesky(x, lower, exc):
],
)
def test_Solve(A, x, lower, exc):
g = slinalg.Solve(lower)(A, x)
g = slinalg.Solve(lower=lower)(A, x)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
@@ -125,7 +125,7 @@ def test_Solve(A, x, lower, exc):
],
)
def test_SolveTriangular(A, x, lower, exc):
g = slinalg.SolveTriangular(lower)(A, x)
g = slinalg.SolveTriangular(lower=lower)(A, x)

if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
2 changes: 1 addition & 1 deletion tests/tensor/test_slinalg.py
Original file line number Diff line number Diff line change
@@ -360,7 +360,7 @@ def setup_method(self):
super().setup_method()

def test_repr(self):
assert repr(CholeskySolve()) == "CholeskySolve{(True, True)}"
assert repr(CholeskySolve()) == "CholeskySolve(lower=True,check_finite=True)"

def test_infer_shape(self):
rng = np.random.default_rng(utt.fetch_seed())

0 comments on commit a3869c6

Please sign in to comment.