Skip to content

Commit

Permalink
Merge pull request #392 from rikardn/immutablematrix
Browse files Browse the repository at this point in the history
Better compatibility with sympy ImmutableMatrix (fixes #363)
  • Loading branch information
isuruf authored Mar 20, 2023
2 parents 0ca50eb + b586d41 commit 6c9a264
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
10 changes: 8 additions & 2 deletions symengine/lib/symengine_wrapper.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3895,7 +3895,7 @@ cdef class DenseMatrixBase(MatrixBase):
l.append(c2py(A.get(i, j))._sympy_())
s.append(l)
import sympy
return sympy.Matrix(s)
return sympy.ImmutableMatrix(s)

def _sage_(self):
s = []
Expand All @@ -3906,7 +3906,7 @@ cdef class DenseMatrixBase(MatrixBase):
l.append(c2py(A.get(i, j))._sage_())
s.append(l)
import sage.all as sage
return sage.Matrix(s)
return sage.Matrix(s, immutable=True)

def dump_real(self, double[::1] out):
cdef size_t ri, ci, nr, nc
Expand Down Expand Up @@ -4046,6 +4046,12 @@ cdef class ImmutableDenseMatrix(DenseMatrixBase):
def __setitem__(self, key, value):
raise TypeError("Cannot set values of {}".format(self.__class__))

def _applyfunc(self, f):
res = DenseMatrix(self)
res._applyfunc(f)
return ImmutableDenseMatrix(res)


ImmutableMatrix = ImmutableDenseMatrix


Expand Down
28 changes: 20 additions & 8 deletions symengine/tests/test_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,23 @@
Rational, function_symbol, I, NonSquareMatrixError, ShapeError, zeros,
ones, eye, ImmutableMatrix)
from symengine.test_utilities import raises
import unittest


try:
import numpy as np
HAVE_NUMPY = True
have_numpy = True
except ImportError:
HAVE_NUMPY = False
have_numpy = False

try:
import sympy
from sympy.core.cache import clear_cache
import atexit
atexit.register(clear_cache)
have_sympy = True
except ImportError:
have_sympy = False


def test_init():
Expand Down Expand Up @@ -520,21 +530,18 @@ def test_reshape():
assert C != A


# @pytest.mark.skipif(not HAVE_NUMPY, reason='requires numpy')
@unittest.skipIf(not have_numpy, 'requires numpy')
def test_dump_real():
if not HAVE_NUMPY: # nosetests work-around
return
ref = [1, 2, 3, 4]
A = DenseMatrix(2, 2, ref)
out = np.empty(4)
A.dump_real(out)
assert np.allclose(out, ref)


# @pytest.mark.skipif(not HAVE_NUMPY, reason='requires numpy')

@unittest.skipIf(not have_numpy, 'requires numpy')
def test_dump_complex():
if not HAVE_NUMPY: # nosetests work-around
return
ref = [1j, 2j, 3j, 4j]
A = DenseMatrix(2, 2, ref)
out = np.empty(4, dtype=np.complex128)
Expand Down Expand Up @@ -741,3 +748,8 @@ def test_repr_latex():
latex_string = testmat._repr_latex_()
assert isinstance(latex_string, str)
init_printing(False)

@unittest.skipIf(not have_sympy, "SymPy not installed")
def test_simplify():
A = ImmutableMatrix([1])
assert type(A.simplify()) == type(A)

0 comments on commit 6c9a264

Please sign in to comment.