Skip to content

Commit

Permalink
Clean the NeighborSearch and allow uniformed output (#2907)
Browse files Browse the repository at this point in the history
* test edit

* fix AttributeError

* Update test and docs

* fix test

* Update package/MDAnalysis/lib/NeighborSearch.py

Co-authored-by: Oliver Beckstein <orbeckst@gmail.com>

* make it compile PEP8

Co-authored-by: zhiyiwu <zhiyi.wu@gtc.ox.ac.uk>
Co-authored-by: Oliver Beckstein <orbeckst@gmail.com>
  • Loading branch information
3 people authored Aug 14, 2020
1 parent e091633 commit a3d10b0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 27 deletions.
3 changes: 2 additions & 1 deletion package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ The rules for this file:
------------------------------------------------------------------------------
??/??/?? tylerjereddy, richardjgowers, IAlibay, hmacdope, orbeckst, cbouy,
lilyminium, daveminh, jbarnoud, yuxuanzhuang, VOD555, ianmkenney,
calcraven
calcraven, xiki-tempula

* 2.0.0

Expand Down Expand Up @@ -95,6 +95,7 @@ Changes
* Changes the minimal NumPy version to 1.16.0 (Issue #2827, PR #2831)
* Sets the minimal RDKit version for CI to 2020.03.1 (Issue #2827, PR #2831)
* Removes deprecated waterdynamics.HydrogenBondLifetimes (PR #2842)
* Make NeighborSearch return empty atomgroup, residue, segments instead of list (Issue #2892, PR #2907)

Deprecations

Expand Down
57 changes: 31 additions & 26 deletions package/MDAnalysis/lib/NeighborSearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,16 @@
from MDAnalysis.lib.distances import capped_distance
from MDAnalysis.lib.util import unique_int_1d

from MDAnalysis.core.groups import AtomGroup, Atom


class AtomNeighborSearch(object):
"""This class can be used to find all atoms/residues/segments within the
radius of a given query position.
For the neighbor search, this class uses the BioPython KDTree and its
wrapper PeriodicKDTree for non-periodic and periodic systems, respectively.
For the neighbor search, this class is a wrapper around
:class:`~MDAnalysis.lib.distances.capped_distance`.
"""

def __init__(self, atom_group, box=None, bucket_size=10):
def __init__(self, atom_group, box=None):
"""
Parameters
Expand All @@ -55,16 +53,10 @@ def __init__(self, atom_group, box=None, bucket_size=10):
:attr:`MDAnalysis.trajectory.base.Timestep.dimensions` when
periodic boundary conditions should be taken into account for
the calculation of contacts.
bucket_size : int
Number of entries in leafs of the KDTree. If you suffer poor
performance you can play around with this number. Increasing the
`bucket_size` will speed up the construction of the KDTree but
slow down the search.
"""
self.atom_group = atom_group
self._u = atom_group.universe
self._box = box
#self.kdtree = PeriodicKDTree(box=box, leafsize=bucket_size)

def search(self, atoms, radius, level='A'):
"""
Expand All @@ -73,21 +65,37 @@ def search(self, atoms, radius, level='A'):
Parameters
----------
atoms : AtomGroup, MDAnalysis.core.groups.Atom
list of atoms
atoms : AtomGroup, MDAnalysis.core.groups.AtomGroup
AtomGroup object
radius : float
Radius for search in Angstrom.
level : str
char (A, R, S). Return atoms(A), residues(R) or segments(S) within
*radius* of *atoms*.
Returns
-------
AtomGroup : :class:`~MDAnalysis.core.groups.AtomGroup`
When ``level='A'``, AtomGroup is being returned.
ResidueGroup : :class:`~MDAnalysis.core.groups.ResidueGroup`
When ``level='R'``, ResidueGroup is being returned.
SegmentGroup : :class:`~MDAnalysis.core.groups.SegmentGroup`
When ``level='S'``, SegmentGroup is being returned.
.. versionchanged:: 2.0.0
Now returns :class:`AtomGroup` (when empty this is now an empty
:class:`AtomGroup` instead of an empty list), :class:`ResidueGroup`,
or a :class:`SegmentGroup`
"""
unique_idx = []
if isinstance(atoms, Atom):
positions = atoms.position.reshape(1, 3)
else:
positions = atoms.positions

pairs = capped_distance(positions, self.atom_group.positions,
try:
# For atom groups, take the positions attribute
position = atoms.positions
except AttributeError:
# For atom, take the position attribute
position = atoms.position
pairs = capped_distance(position, self.atom_group.positions,
radius, box=self._box, return_distances=False)

if pairs.size > 0:
Expand All @@ -106,15 +114,12 @@ def _index2level(self, indices, level):
char (A, R, S). Return atoms(A), residues(R) or segments(S) within
*radius* of *atoms*.
"""
n_atom_list = self.atom_group[indices]
atomgroup = self.atom_group[indices]
if level == 'A':
if not n_atom_list:
return []
else:
return n_atom_list
return atomgroup
elif level == 'R':
return list({a.residue for a in n_atom_list})
return atomgroup.residues
elif level == 'S':
return list(set([a.segment for a in n_atom_list]))
return atomgroup.segments
else:
raise NotImplementedError('{0}: level not implemented'.format(level))
11 changes: 11 additions & 0 deletions testsuite/MDAnalysisTests/lib/test_neighborsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,14 @@ def test_search(universe):
ns_res = ns.search(universe.atoms[20], 20)
pns_res = pns.search(universe.atoms[20], 20)
assert_equal(ns_res, pns_res)


def test_zero(universe):
"""Check if empty atomgroup, residue, segments are returned"""
ns = NeighborSearch.AtomNeighborSearch(universe.atoms[:10])
ns_res = ns.search(universe.atoms[20], 0.1, level='A')
assert ns_res == universe.atoms[[]]
ns_res = ns.search(universe.atoms[20], 0.1, level='R')
assert ns_res == universe.atoms[[]].residues
ns_res = ns.search(universe.atoms[20], 0.1, level='S')
assert ns_res == universe.atoms[[]].segments

0 comments on commit a3d10b0

Please sign in to comment.