Skip to content

Commit

Permalink
refactor: simplify forceflat logic
Browse files Browse the repository at this point in the history
  • Loading branch information
cako committed May 8, 2023
1 parent 5ee83c4 commit 7b97451
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions pylops/linearoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,10 @@ class LinearOperator(_LinearOperator):
forceflat : :obj:`bool`, optional
.. versionadded:: 2.2.0
Force an array to be flattened after matvec/rmatvec if the input is ambiguous (i.e., is a 1D array both when
operating with ND arrays and with 1D arrays. Defaults to ``None`` for operators that have no ambiguity (and
to ``False`` for those with ambiguity)
Force an array to be flattened after matvec/rmatvec if the input is ambiguous
(i.e., is a 1D array both when operating with ND arrays and with 1D arrays).
Defaults to ``None`` for operators that have no ambiguity or to ``False``
for those with ambiguity.
name : :obj:`str`, optional
.. versionadded:: 2.0.0
Expand Down Expand Up @@ -361,25 +362,26 @@ def __add__(self, x: LinearOperator) -> LinearOperator:
)
Op.clinear = Op.clinear and Opx.clinear
Op.explicit = False
# Define forceflat (if differing, raise error)
if isinstance(self.forceflat, bool) and isinstance(Opx.forceflat, bool):
if self.forceflat is None and Opx.forceflat is None:
Op.forceflat = None
elif self.forceflat is not None and Opx.forceflat is not None:
# Define forceflat only if differing, otherwise raise error
if self.forceflat != Opx.forceflat:
raise ValueError(
f"the two operators have contrasting forceflat {Op.forceflat}-{Opx.forceflat}"
f"Operators have conflicting forceflat {Op.forceflat} != {Opx.forceflat}"
)
else:
Op.forceflat = self.forceflat
if isinstance(self.forceflat, bool) or isinstance(Opx.forceflat, bool):
Op.forceflat = self.forceflat
else: # Only one of them is None
Op.forceflat = (
self.forceflat if self.forceflat is not None else Opx.forceflat
)
else:
Op.forceflat = None

# Replace if shape-like
if len(self.dims) == 1:
Op.dims = Opx.dims
if len(self.dimsd) == 1:
Op.dimsd = Opx.dimsd

return Op
else:
return NotImplemented
Expand Down Expand Up @@ -627,20 +629,19 @@ def dot(self, x: NDArray) -> NDArray:
self._copy_attributes(Op, exclude=["dims", "explicit", "forceflat", "name"])
Op.clinear = Op.clinear and Opx.clinear
Op.explicit = False
# Define forceflat (if differing, raise error)
if isinstance(self.forceflat, bool) and isinstance(Opx.forceflat, bool):
if self.forceflat is None and Opx.forceflat is None:
Op.forceflat = None
elif self.forceflat is not None and Opx.forceflat is not None:
# Define forceflat only if differing, otherwise raise error
if self.forceflat != Opx.forceflat:
raise ValueError(
f"the two operators have contrasting forceflat {Op.forceflat}-{Opx.forceflat}"
f"Operators have conflicting forceflat {Op.forceflat} != {Opx.forceflat}"
)
else:
Op.forceflat = self.forceflat
if isinstance(self.forceflat, bool) or isinstance(Opx.forceflat, bool):
Op.forceflat = self.forceflat
else: # Only one of them is None
Op.forceflat = (
self.forceflat if self.forceflat is not None else Opx.forceflat
)
else:
Op.forceflat = None
Op.dims = Opx.dims
return Op
elif np.isscalar(x):
Expand Down Expand Up @@ -840,7 +841,6 @@ def tosparse(self) -> NDArray:

# loop through columns of self
for i in range(n):

# make a unit vector for the ith column
unit_i = np.zeros(n)
unit_i[i] = 1
Expand Down

0 comments on commit 7b97451

Please sign in to comment.