Skip to content

Commit

Permalink
added tests for residues
Browse files Browse the repository at this point in the history
  • Loading branch information
lilyminium committed Mar 15, 2021
1 parent 3244844 commit a17359e
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
6 changes: 5 additions & 1 deletion package/MDAnalysis/core/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,11 @@ def get_connections(self, typename, outside=False):
if not len(ugroup):
return ugroup
func = np.any if outside else np.all
seen = [np.in1d(col, self.ix_array) for col in ugroup._bix.T]
try:
indices = self.atoms.ix_array
except AttributeError: # if self is an Atom
indices = self.ix_array
seen = [np.in1d(col, indices) for col in ugroup._bix.T]
mask = func(seen, axis=0)
return ugroup[mask]

Expand Down
65 changes: 63 additions & 2 deletions testsuite/MDAnalysisTests/core/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,8 +1473,8 @@ def tpr():
return mda.Universe(TPR)


class TestGetConnections(object):
"""Test _MutableBase.get_connections"""
class TestGetConnectionsAtoms(object):
"""Test Atom and AtomGroup.get_connections"""

@pytest.mark.parametrize("typename",
["bonds", "angles", "dihedrals", "impropers"])
Expand Down Expand Up @@ -1530,6 +1530,67 @@ def test_get_empty_group(self, tpr, outside):
assert len(cxns) == 0


class TestGetConnectionsResidues(object):
"""Test Residue and ResidueGroup.get_connections"""

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 9),
("angles", 14),
("dihedrals", 9),
("impropers", 0),
])
def test_connection_from_res_not_outside(self, tpr, typename, n_atoms):
cxns = tpr.residues[10].get_connections(typename, outside=False)
assert len(cxns) == n_atoms

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 11),
("angles", 22),
("dihedrals", 27),
("impropers", 0),
])
def test_connection_from_res_outside(self, tpr, typename, n_atoms):
cxns = tpr.residues[10].get_connections(typename, outside=True)
assert len(cxns) == n_atoms

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 157),
("angles", 290),
("dihedrals", 351),
])
def test_connection_from_residues_not_outside(self, tpr, typename,
n_atoms):
ag = tpr.residues[:10]
cxns = ag.get_connections(typename, outside=False)
assert len(cxns) == n_atoms
indices = np.ravel(cxns.to_indices())
assert np.all(np.in1d(indices, ag.atoms.indices))

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 158),
("angles", 294),
("dihedrals", 360),
])
def test_connection_from_residues_outside(self, tpr, typename, n_atoms):
ag = tpr.residues[:10]
cxns = ag.get_connections(typename, outside=True)
assert len(cxns) == n_atoms
indices = np.ravel(cxns.to_indices())
assert not np.all(np.in1d(indices, ag.indices))

def test_invalid_connection_error(self, tpr):
with pytest.raises(AttributeError, match="does not contain"):
ag = tpr.residues[:10]
ag.get_connections("ureybradleys")

@pytest.mark.parametrize("outside", [True, False])
def test_get_empty_group(self, tpr, outside):
imp = tpr.impropers
ag = tpr.residues[:10]
cxns = ag.get_connections("impropers", outside=outside)
assert len(imp) == 0
assert len(cxns) == 0

@pytest.mark.parametrize("typename, n_atoms", [
("bonds", 13),
("angles", 27),
Expand Down

0 comments on commit a17359e

Please sign in to comment.