removed collect_results for efficiency
rsexton2 committed Feb 9, 2025
1 parent 23a732b commit b9585c2
Showing 1 changed file with 21 additions and 40 deletions.
basicrta/cluster.py: 61 changes (21 additions & 40 deletions)
@@ -28,36 +28,38 @@ class ProcessProtein(object):
:type cutoff: float
"""

def __init__(self, niter, prot, cutoff, gskip):
def __init__(self, niter, prot, cutoff, gskip, taus=None, bars=None):
self.residues = {}
self.niter = niter
self.prot = prot
self.cutoff = cutoff
self.gskip = gskip
self.taus = taus
self.bars = bars

def __getitem__(self, item):
return getattr(self, item)

def _single_residue(self, adir, process=False):
if os.path.exists(f'{adir}/gibbs_{self.niter}.pkl'):
result = f'{adir}/gibbs_{self.niter}.pkl'
try:
result = f'{adir}/gibbs_{self.niter}.pkl'
g = Gibbs().load(result)
if process:
g.gskip = self.gskip
g.process_gibbs()
tau = g.estimate_tau()
except ValueError:
result = None
tau = [0, 0, 0]
else:
print(f'results for {adir} do not exist')
result = None
return result
tau = [0, 0, 0]

def _single_result(self, adir):
result = self._single_residue(adir)
residue = adir.split('/')[-1]
self.residues[residue] = result

setattr(self.residues, 'f{residue}', result)
setattr(self.residues.residue, 'tau', tau)

def reprocess(self, nproc=1):
"""Rerun processing and clustering on :class:`Gibbs` data.
@@ -83,30 +83,6 @@ def reprocess(self, nproc=1):
             except KeyboardInterrupt:
                 pass
 
-    def collect_results(self, nproc=1):
-        """Collect names of results for each residue in the `basicrta-{cutoff}`
-        directory in a dictionary stored in :attr:`ProcessProtein.results`.
-        """
-        from glob import glob
-
-        dirs = np.array(glob(f'basicrta-{self.cutoff}/?[0-9]*'))
-        sorted_inds = (np.array([int(adir.split('/')[-1][1:]) for adir in dirs])
-                       .argsort())
-        dirs = dirs[sorted_inds]
-        with (Pool(nproc, initializer=tqdm.set_lock,
-                   initargs=(Lock(),)) as p):
-            try:
-                for _ in tqdm(p.imap(self._single_result, dirs),
-                              total=len(dirs), position=0,
-                              desc='overall progress'):
-                    pass
-                #for adir in tqdm(dirs, desc='collecting results'):
-                #    result = self._single_residue(adir)
-                #    residue = adir.split('/')[-1]
-                #    self.residues[residue] = result
-            except KeyboardInterrupt:
-                pass
-
     def get_taus(self):
         r"""Get :math:`\tau` and 95\% confidence interval bounds for the slowest
         process for each residue.
Expand All @@ -120,17 +98,20 @@ def get_taus(self):

         taus = []
         for res in tqdm(self.residues, total=len(self.residues)):
-            if self.residues[res] is None:
-                result = [0, 0, 0]
-            else:
-                try:
-                    gib = Gibbs().load(self.residues[res])
-                    result = gib.estimate_tau()
-                except AttributeError:
-                    result = [0, 0, 0]
-            taus.append(result)
+            taus.append(res.tau)
+            #if self.residues[res] is None:
+            #    result = [0, 0, 0]
+            #else:
+            #    try:
+            #        gib = Gibbs().load(self.residues[res])
+            #        result = gib.estimate_tau()
+            #    except AttributeError:
+            #        result = [0, 0, 0]
+            #taus.append(result)
         taus = np.array(taus)
         bars = get_bars(taus)
+        setattr(self, 'taus', taus[:, 1])
+        setattr(self, 'bars', bars)
         return taus[:, 1], bars
 
     def write_data(self, fname='tausout'):
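
For orientation, a minimal usage sketch of the reworked ProcessProtein workflow, assuming only the signatures visible in this diff (the constructor, reprocess, get_taus and write_data); the driver script itself and all argument values are hypothetical placeholders, not defaults from the package:

    # hypothetical driver script; argument values below are placeholders
    from basicrta.cluster import ProcessProtein

    pp = ProcessProtein(niter=110000, prot='my_protein', cutoff=7.0, gskip=100)

    # rerun processing/clustering of the per-residue Gibbs results in parallel
    pp.reprocess(nproc=4)

    # collect tau and 95% CI bounds for the slowest process of each residue;
    # after this commit they are also stored on the instance as pp.taus / pp.bars
    taus, bars = pp.get_taus()

    # write the collected estimates to disk
    pp.write_data(fname='tausout')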