diff --git a/package/MDAnalysis/core/groups.py b/package/MDAnalysis/core/groups.py index 69a9e6293e2..c0013a424a7 100644 --- a/package/MDAnalysis/core/groups.py +++ b/package/MDAnalysis/core/groups.py @@ -415,7 +415,11 @@ def get_connections(self, typename, outside=False): if not len(ugroup): return ugroup func = np.any if outside else np.all - seen = [np.in1d(col, self.ix_array) for col in ugroup._bix.T] + try: + indices = self.atoms.ix_array + except AttributeError: # if self is an Atom + indices = self.ix_array + seen = [np.in1d(col, indices) for col in ugroup._bix.T] mask = func(seen, axis=0) return ugroup[mask] diff --git a/testsuite/MDAnalysisTests/core/test_groups.py b/testsuite/MDAnalysisTests/core/test_groups.py index 641d1d1fa20..1897604a024 100644 --- a/testsuite/MDAnalysisTests/core/test_groups.py +++ b/testsuite/MDAnalysisTests/core/test_groups.py @@ -1473,8 +1473,8 @@ def tpr(): return mda.Universe(TPR) -class TestGetConnections(object): - """Test _MutableBase.get_connections""" +class TestGetConnectionsAtoms(object): + """Test Atom and AtomGroup.get_connections""" @pytest.mark.parametrize("typename", ["bonds", "angles", "dihedrals", "impropers"]) @@ -1530,6 +1530,67 @@ def test_get_empty_group(self, tpr, outside): assert len(cxns) == 0 +class TestGetConnectionsResidues(object): + """Test Residue and ResidueGroup.get_connections""" + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 9), + ("angles", 14), + ("dihedrals", 9), + ("impropers", 0), + ]) + def test_connection_from_res_not_outside(self, tpr, typename, n_atoms): + cxns = tpr.residues[10].get_connections(typename, outside=False) + assert len(cxns) == n_atoms + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 11), + ("angles", 22), + ("dihedrals", 27), + ("impropers", 0), + ]) + def test_connection_from_res_outside(self, tpr, typename, n_atoms): + cxns = tpr.residues[10].get_connections(typename, outside=True) + assert len(cxns) == n_atoms + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 157), + ("angles", 290), + ("dihedrals", 351), + ]) + def test_connection_from_residues_not_outside(self, tpr, typename, + n_atoms): + ag = tpr.residues[:10] + cxns = ag.get_connections(typename, outside=False) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert np.all(np.in1d(indices, ag.atoms.indices)) + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 158), + ("angles", 294), + ("dihedrals", 360), + ]) + def test_connection_from_residues_outside(self, tpr, typename, n_atoms): + ag = tpr.residues[:10] + cxns = ag.get_connections(typename, outside=True) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert not np.all(np.in1d(indices, ag.indices)) + + def test_invalid_connection_error(self, tpr): + with pytest.raises(AttributeError, match="does not contain"): + ag = tpr.residues[:10] + ag.get_connections("ureybradleys") + + @pytest.mark.parametrize("outside", [True, False]) + def test_get_empty_group(self, tpr, outside): + imp = tpr.impropers + ag = tpr.residues[:10] + cxns = ag.get_connections("impropers", outside=outside) + assert len(imp) == 0 + assert len(cxns) == 0 + @pytest.mark.parametrize("typename, n_atoms", [ ("bonds", 13), ("angles", 27),