diff --git a/package/CHANGELOG b/package/CHANGELOG index 740e41a51c9..78034d53680 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -104,6 +104,9 @@ Fixes * Fix syntax warning over comparison of literals using is (Issue #3066) Enhancements + * Added intra_bonds, intra_angles, intra_dihedrals etc. to return only + the connections involving atoms within the AtomGroup, instead of + including atoms outside the AtomGroup (Issue #1264, #2821, PR #3200) * Added del_TopologyAttr function (PR #3069) * Switch GNMAnalysis to AnalysisBase (Issue #3243) * Adds python 3.9 support (Issue #2974, PR #3027, #3245) diff --git a/package/MDAnalysis/coordinates/MOL2.py b/package/MDAnalysis/coordinates/MOL2.py index 3764e77758c..7b16a011e9d 100644 --- a/package/MDAnalysis/coordinates/MOL2.py +++ b/package/MDAnalysis/coordinates/MOL2.py @@ -324,7 +324,7 @@ def encode_block(self, obj): if hasattr(obj, "bonds"): # Grab only bonds between atoms in the obj # ie none that extend out of it - bondgroup = obj.bonds.atomgroup_intersection(obj, strict=True) + bondgroup = obj.intra_bonds bonds = sorted((b[0], b[1], b.order) for b in bondgroup) bond_lines = ["@BOND"] bls = ["{0:>5} {1:>5} {2:>5} {3:>2}".format(bid, diff --git a/package/MDAnalysis/coordinates/ParmEd.py b/package/MDAnalysis/coordinates/ParmEd.py index a714d5a0fcb..de8205ee386 100644 --- a/package/MDAnalysis/coordinates/ParmEd.py +++ b/package/MDAnalysis/coordinates/ParmEd.py @@ -269,8 +269,7 @@ def convert(self, obj): # bonds try: - params = ag_or_ts.bonds.atomgroup_intersection(ag_or_ts, - strict=True) + params = ag_or_ts.intra_bonds except AttributeError: pass else: diff --git a/package/MDAnalysis/core/groups.py b/package/MDAnalysis/core/groups.py index 83e7b32c690..1a634eebbef 100644 --- a/package/MDAnalysis/core/groups.py +++ b/package/MDAnalysis/core/groups.py @@ -382,6 +382,41 @@ def __getattr__(self, attr): err += 'Did you mean {match}?'.format(match=match) raise AttributeError(err) + def get_connections(self, typename, outside=True): + """ + Get bonded connections between atoms as a + :class:`~MDAnalysis.core.topologyobjects.TopologyGroup`. + + Parameters + ---------- + typename : str + group name. One of {"bonds", "angles", "dihedrals", + "impropers", "ureybradleys", "cmaps"} + outside : bool (optional) + Whether to include connections involving atoms outside + this group. + + Returns + ------- + TopologyGroup + containing the bonded group of choice, i.e. bonds, angles, + dihedrals, impropers, ureybradleys or cmaps. + + .. versionadded:: 1.1.0 + """ + # AtomGroup has handy error messages for missing attributes + ugroup = getattr(self.universe.atoms, typename) + if not ugroup: + return ugroup + func = np.any if outside else np.all + try: + indices = self.atoms.ix_array + except AttributeError: # if self is an Atom + indices = self.ix_array + seen = [np.in1d(col, indices) for col in ugroup._bix.T] + mask = func(seen, axis=0) + return ugroup[mask] + class _ImmutableBase(object): """Class used to shortcut :meth:`__new__` to :meth:`object.__new__`. diff --git a/package/MDAnalysis/core/topologyattrs.py b/package/MDAnalysis/core/topologyattrs.py index 88aaedd8f86..a3d0940544c 100644 --- a/package/MDAnalysis/core/topologyattrs.py +++ b/package/MDAnalysis/core/topologyattrs.py @@ -32,17 +32,19 @@ These are usually read by the TopologyParser. """ -import Bio.Seq -import Bio.SeqRecord from collections import defaultdict import copy import functools import itertools import numbers -import numpy as np +from inspect import signature as inspect_signature import warnings import textwrap -from inspect import signature as inspect_signature +from types import MethodType + +import Bio.Seq +import Bio.SeqRecord +import numpy as np from ..lib.util import (cached, convert_aa_code, iterable, warn_if_not_unique, unique_int_1d) @@ -2258,7 +2260,31 @@ def wrapper(self, values, *args, **kwargs): return wrapper -class _Connection(AtomAttr): +class _ConnectionTopologyAttrMeta(_TopologyAttrMeta): + """ + Specific metaclass for atom-connectivity topology attributes. + + This class adds an ``intra_{attrname}`` property to groups + to return only the connections within the atoms in the group. + """ + def __init__(cls, name, bases, classdict): + type.__init__(type, name, bases, classdict) + attrname = classdict.get('attrname') + + if attrname is not None: + def intra_connection(self, ag): + """Get connections only within this AtomGroup + """ + return ag.get_connections(attrname, outside=False) + + method = MethodType(intra_connection, cls) + prop = property(method, None, None, method.__doc__) + cls.transplants[AtomGroup].append((f"intra_{attrname}", prop)) + + super().__init__(name, bases, classdict) + + +class _Connection(AtomAttr, metaclass=_ConnectionTopologyAttrMeta): """Base class for connectivity between atoms .. versionchanged:: 1.0.0 @@ -2307,14 +2333,23 @@ def set_atoms(self, ag): return NotImplementedError("Cannot set bond information") def get_atoms(self, ag): + """ + Get connection values where the atom indices are in + the given atomgroup. + + Parameters + ---------- + ag : AtomGroup + + """ try: unique_bonds = set(itertools.chain( *[self._bondDict[a] for a in ag.ix])) except TypeError: # maybe we got passed an Atom unique_bonds = self._bondDict[ag.ix] - bond_idx, types, guessed, order = np.hsplit( - np.array(sorted(unique_bonds), dtype=object), 4) + unique_bonds = np.array(sorted(unique_bonds), dtype=object) + bond_idx, types, guessed, order = np.hsplit(unique_bonds, 4) bond_idx = np.array(bond_idx.ravel().tolist(), dtype=np.int32) types = types.ravel() guessed = guessed.ravel() diff --git a/testsuite/MDAnalysisTests/core/test_groups.py b/testsuite/MDAnalysisTests/core/test_groups.py index 83ff8724f05..aeee158074e 100644 --- a/testsuite/MDAnalysisTests/core/test_groups.py +++ b/testsuite/MDAnalysisTests/core/test_groups.py @@ -28,11 +28,12 @@ ) import pytest import operator +import warnings import MDAnalysis as mda from MDAnalysis.exceptions import NoDataError from MDAnalysisTests import make_Universe, no_deprecated_call -from MDAnalysisTests.datafiles import PSF, DCD +from MDAnalysisTests.datafiles import PSF, DCD, TPR from MDAnalysis.core import groups @@ -1461,3 +1462,153 @@ def test_decorator(self, compound, pbc, unwrap): self.dummy_funtion(compound=compound, pbc=pbc, unwrap=unwrap) else: assert_equal(self.dummy_funtion(compound=compound, pbc=pbc, unwrap=unwrap), 0) + + +@pytest.fixture() +def tpr(): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", + message="No coordinate reader found") + return mda.Universe(TPR) + +class TestGetConnectionsAtoms(object): + """Test Atom and AtomGroup.get_connections""" + + @pytest.mark.parametrize("typename", + ["bonds", "angles", "dihedrals", "impropers"]) + def test_connection_from_atom_not_outside(self, tpr, typename): + cxns = tpr.atoms[1].get_connections(typename, outside=False) + assert len(cxns) == 0 + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 1), + ("angles", 3), + ("dihedrals", 4), + ]) + def test_connection_from_atom_outside(self, tpr, typename, n_atoms): + cxns = tpr.atoms[10].get_connections(typename, outside=True) + assert len(cxns) == n_atoms + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 9), + ("angles", 15), + ("dihedrals", 12), + ]) + def test_connection_from_atoms_not_outside(self, tpr, typename, + n_atoms): + ag = tpr.atoms[:10] + cxns = ag.get_connections(typename, outside=False) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert np.all(np.in1d(indices, ag.indices)) + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 13), + ("angles", 27), + ("dihedrals", 38), + ]) + def test_connection_from_atoms_outside(self, tpr, typename, n_atoms): + ag = tpr.atoms[:10] + cxns = ag.get_connections(typename, outside=True) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert not np.all(np.in1d(indices, ag.indices)) + + def test_invalid_connection_error(self, tpr): + with pytest.raises(AttributeError, match="does not contain"): + ag = tpr.atoms[:10] + ag.get_connections("ureybradleys") + + @pytest.mark.parametrize("outside", [True, False]) + def test_get_empty_group(self, tpr, outside): + imp = tpr.impropers + ag = tpr.atoms[:10] + cxns = ag.get_connections("impropers", outside=outside) + assert len(imp) == 0 + assert len(cxns) == 0 + + +class TestGetConnectionsResidues(object): + """Test Residue and ResidueGroup.get_connections""" + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 9), + ("angles", 14), + ("dihedrals", 9), + ("impropers", 0), + ]) + def test_connection_from_res_not_outside(self, tpr, typename, n_atoms): + cxns = tpr.residues[10].get_connections(typename, outside=False) + assert len(cxns) == n_atoms + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 11), + ("angles", 22), + ("dihedrals", 27), + ("impropers", 0), + ]) + def test_connection_from_res_outside(self, tpr, typename, n_atoms): + cxns = tpr.residues[10].get_connections(typename, outside=True) + assert len(cxns) == n_atoms + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 157), + ("angles", 290), + ("dihedrals", 351), + ]) + def test_connection_from_residues_not_outside(self, tpr, typename, + n_atoms): + ag = tpr.residues[:10] + cxns = ag.get_connections(typename, outside=False) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert np.all(np.in1d(indices, ag.atoms.indices)) + + @pytest.mark.parametrize("typename, n_atoms", [ + ("bonds", 158), + ("angles", 294), + ("dihedrals", 360), + ]) + def test_connection_from_residues_outside(self, tpr, typename, n_atoms): + ag = tpr.residues[:10] + cxns = ag.get_connections(typename, outside=True) + assert len(cxns) == n_atoms + indices = np.ravel(cxns.to_indices()) + assert not np.all(np.in1d(indices, ag.atoms.indices)) + + def test_invalid_connection_error(self, tpr): + with pytest.raises(AttributeError, match="does not contain"): + ag = tpr.residues[:10] + ag.get_connections("ureybradleys") + + @pytest.mark.parametrize("outside", [True, False]) + def test_get_empty_group(self, tpr, outside): + imp = tpr.impropers + ag = tpr.residues[:10] + cxns = ag.get_connections("impropers", outside=outside) + assert len(imp) == 0 + assert len(cxns) == 0 + + +@pytest.mark.parametrize("typename, n_inside", [ + ("intra_bonds", 9), + ("intra_angles", 15), + ("intra_dihedrals", 12), +]) +def test_topologygroup_gets_connections_inside(tpr, typename, n_inside): + ag = tpr.atoms[:10] + cxns = getattr(ag, typename) + assert len(cxns) == n_inside + indices = np.ravel(cxns.to_indices()) + assert np.all(np.in1d(indices, ag.indices)) + + +@pytest.mark.parametrize("typename, n_outside", [ + ("bonds", 13), + ("angles", 27), + ("dihedrals", 38), +]) +def test_topologygroup_gets_connections_outside(tpr, typename, n_outside): + ag = tpr.atoms[:10] + cxns = getattr(ag, typename) + assert len(cxns) == n_outside