AnalysisRunner
hainm committed Jun 18, 2024
1 parent cfdccc9 commit ab78b66
Showing 1 changed file with 55 additions and 60 deletions: pytraj/all_actions.py
@@ -85,6 +85,28 @@ class DatasetType(StrEnum):
     XYMESH = 'xymesh'
     MATRIX3x3 = 'matrix3x3'
 
+class AnalysisRunner:
+    def __init__(self, analysis_class):
+        self.datasets = CpptrajDatasetList()
+        self.analysis = analysis_class()
+
+    def add_dataset(self, dataset_type, dataset_name, data):
+        if dataset_type == DatasetType.COORDS:
+            crdname = '_DEFAULTCRD_'
+            self.datasets.add(dataset_type.value, name=crdname)
+            self.datasets[0].top = data.top
+            for frame in data:
+                self.datasets[0].append(frame)
+        else:
+            self.datasets.add(dataset_type, dataset_name)
+            if dataset_type == DatasetType.XYMESH:
+                self.datasets[-1]._append_from_array(data.T)
+            else:
+                self.datasets[-1].data = np.asarray(data).astype('f8')
+
+    def run_analysis(self, command):
+        self.analysis(command, dslist=self.datasets)
+        return self.datasets
 
 
 def _assert_mutable(trajiter):
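For orientation, this is the call pattern the refactored functions below all follow, shown as a minimal sketch rather than code from the commit; it assumes the module-level names np, c_analysis, and DatasetType that all_actions.py already imports, and array-like vec0/vec1 inputs as in timecorr:

runner = AnalysisRunner(c_analysis.Analysis_Timecorr)   # inject the cpptraj analysis class
runner.add_dataset(DatasetType.VECTOR, "_vec0", vec0)   # each input gets a type, a name, and its raw data
runner.add_dataset(DatasetType.VECTOR, "_vec1", vec1)
datasets = runner.run_analysis("vec1 _vec0 vec2 _vec1 order 2 tstep 1.0 tcorr 10000.0")
# run_analysis returns the shared CpptrajDatasetList; result sets are appended after
# the inputs, which is why timecorr reads runner.datasets[2:] below.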
@@ -1935,18 +1957,14 @@ def timecorr(vec0, vec1, order=2, tstep=1., tcorr=10000., norm=False, dtype='nda
     norm : bool, default False
     dtype : str, default 'ndarray'
     """
-    time_correlation_action = c_analysis.Analysis_Timecorr()
-    action_datasets = CpptrajDatasetList()
-
-    action_datasets.add(DatasetType.VECTOR, "_vec0")
-    action_datasets.add(DatasetType.VECTOR, "_vec1")
-    action_datasets[0].data = np.asarray(vec0).astype('f8')
-    action_datasets[1].data = np.asarray(vec1).astype('f8')
+    runner = AnalysisRunner(c_analysis.Analysis_Timecorr)
+    runner.add_dataset(DatasetType.VECTOR, "_vec0", vec0)
+    runner.add_dataset(DatasetType.VECTOR, "_vec1", vec1)
 
     command = f"vec1 _vec0 vec2 _vec1 order {order} tstep {tstep} tcorr {tcorr} {'norm' if norm else ''}"
-    time_correlation_action(command, dslist=action_datasets)
+    runner.run_analysis(command)
 
-    return get_data_from_dtype(action_datasets[2:], dtype=dtype)
+    return get_data_from_dtype(runner.datasets[2:], dtype=dtype)
 
 
 @super_dispatch()
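For reference, with the default arguments the f-string above expands to the cpptraj command shown in the comment below; the conditional contributes nothing when norm is False, leaving a harmless trailing space:

order, tstep, tcorr, norm = 2, 1.0, 10000.0, False
command = f"vec1 _vec0 vec2 _vec1 order {order} tstep {tstep} tcorr {tcorr} {'norm' if norm else ''}"
print(command)   # "vec1 _vec0 vec2 _vec1 order 2 tstep 1.0 tcorr 10000.0 " (note the trailing space)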
@@ -2079,17 +2097,13 @@ def crank(data0, data1, mode='distance', dtype='ndarray'):
     -----
     Same as `crank` in cpptraj
     """
-    action_datasets = CpptrajDatasetList()
-    action_datasets.add(DatasetType.DOUBLE, "d0")
-    action_datasets.add(DatasetType.DOUBLE, "d1")
-
-    action_datasets[0].data = np.asarray(data0)
-    action_datasets[1].data = np.asarray(data1)
+    runner = AnalysisRunner(c_analysis.Analysis_CrankShaft)
+    runner.add_dataset(DatasetType.DOUBLE, "d0", data0)
+    runner.add_dataset(DatasetType.DOUBLE, "d1", data1)
 
-    act = c_analysis.Analysis_CrankShaft()
     command = ' '.join((mode, 'd0', 'd1'))
     with capture_stdout() as (out, err):
-        act(command, dslist=action_datasets)
+        runner.run_analysis(command)
     return out.read()
 
 
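crank returns whatever cpptraj printed rather than a dataset, hence the capture_stdout context manager wrapped around the run. A hypothetical pure-Python stand-in sketching the interface the call site relies on (the real pytraj helper may capture at the file-descriptor level so that C++ output is caught as well):

import io
from contextlib import contextmanager, redirect_stdout

@contextmanager
def capture_stdout_sketch():
    # Yields the same (out, err) pair that crank unpacks; err is unused here.
    buf = io.StringIO()
    with redirect_stdout(buf):
        yield buf, None
    buf.seek(0)   # rewind so out.read() after the with-block returns the captured text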
@@ -2768,14 +2782,12 @@ def lowestcurve(data, points=10, step=0.2):
 
     data = np.asarray(data).T
 
-    action_datasets = CpptrajDatasetList()
-    action_datasets.add(DatasetType.XYMESH, 'mydata')
-    action_datasets[0]._append_from_array(data)
+    runner = AnalysisRunner(c_analysis.Analysis_LowestCurve)
+    runner.add_dataset(DatasetType.XYMESH, 'mydata', data)
 
-    analysis_lowest_curve = c_analysis.Analysis_LowestCurve()
-    analysis_lowest_curve(command, dslist=action_datasets)
+    runner.run_analysis(command)
 
-    return np.array([action_datasets[-1]._xcrd(), np.array(action_datasets[-1].values)])
+    return np.array([runner.datasets[-1]._xcrd(), np.array(runner.datasets[-1].values)])
 
 
 def acorr(data, dtype='ndarray', option=''):
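The value lowestcurve returns above stacks what is presumably the mesh's x coordinates (_xcrd()) on top of the lowest-curve values, giving a 2 x n_points array; a tiny illustration with made-up numbers:

import numpy as np

x = [0.0, 0.2, 0.4, 0.6]    # mesh x coordinates (_xcrd())
y = [1.8, 1.2, 0.9, 1.1]    # lowest-curve values (.values)
curve = np.array([x, y])
print(curve.shape)          # (2, 4): row 0 is x, row 1 is y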
@@ -2793,15 +2805,13 @@ def acorr(data, dtype='ndarray', option=''):
     -----
     Same as `autocorr` in cpptraj
     """
-    action_datasets = CpptrajDatasetList()
-    action_datasets.add(DatasetType.DOUBLE, "d0")
-
-    action_datasets[0].data = np.asarray(data)
+    runner = AnalysisRunner(c_analysis.Analysis_AutoCorr)
+    runner.add_dataset(DatasetType.DOUBLE, "d0", np.asarray(data))
 
-    act = c_analysis.Analysis_AutoCorr()
     command = "d0 out _tmp.out"
-    act(command, dslist=action_datasets)
-    return get_data_from_dtype(action_datasets[1:], dtype=dtype)
+    runner.run_analysis(command)
+
+    return get_data_from_dtype(runner.datasets[1:], dtype=dtype)
 
 
 auto_correlation_function = acorr
@@ -2820,17 +2830,14 @@ def xcorr(data0, data1, dtype='ndarray'):
     -----
     Same as `corr` in cpptraj
     """
-    action_datasets = CpptrajDatasetList()
-    action_datasets.add(DatasetType.DOUBLE, "d0")
-    action_datasets.add(DatasetType.DOUBLE, "d1")
-
-    action_datasets[0].data = np.asarray(data0)
-    action_datasets[1].data = np.asarray(data1)
+    runner = AnalysisRunner(c_analysis.Analysis_Corr)
+    runner.add_dataset(DatasetType.DOUBLE, "d0", np.asarray(data0))
+    runner.add_dataset(DatasetType.DOUBLE, "d1", np.asarray(data1))
 
-    act = c_analysis.Analysis_Corr()
-    act("d0 d1 out _tmp.out", dslist=action_datasets)
-    return get_data_from_dtype(action_datasets[2:3], dtype=dtype)
+    command = "d0 d1 out _tmp.out"
+    runner.run_analysis(command)
+
+    return get_data_from_dtype(runner.datasets[2:3], dtype=dtype)
 
 
 cross_correlation_function = xcorr
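A hedged usage sketch for xcorr, assuming the function is re-exported at the pytraj top level as the pt. prefix used in other docstrings in this file suggests:

import numpy as np
import pytraj as pt

data0 = np.sin(np.linspace(0.0, 4.0 * np.pi, 100))
data1 = np.cos(np.linspace(0.0, 4.0 * np.pi, 100))
result = pt.xcorr(data0, data1)   # same as cpptraj's `corr`; only the correlation set
print(result)                     # (runner.datasets[2:3], after the two inputs) is returned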
@@ -2900,7 +2907,6 @@ def strip(obj, mask):
 
 def rotdif(matrices, command):
     """
     Parameters
     ----------
     matrices : 3D array, shape=(n_frames, 3, 3)
@@ -2917,18 +2923,15 @@ def rotdif(matrices, command):
     -----
     This method interface will be changed.
     """
     # TODO: update this method if cpptraj dumps data to CpptrajDatasetList
     matrices = np.asarray(matrices)
 
-    action_datasets = CpptrajDatasetList()
-    action_datasets.add(DatasetType.MATRIX3x3, name='myR0')
-    action_datasets[-1].aspect = "RM"
-    action_datasets[-1]._append_from_array(matrices)
+    runner = AnalysisRunner(c_analysis.Analysis_Rotdif)
+    runner.add_dataset(DatasetType.MATRIX3x3, "myR0", matrices)
 
     command = 'rmatrix myR0[RM] ' + command
-    act = c_analysis.Analysis_Rotdif()
     with capture_stdout() as (out, _):
-        act(command, dslist=action_datasets)
+        runner.run_analysis(command)
 
     return out.read()
 
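The 'rmatrix myR0[RM]' command refers to the dataset by its RM aspect, which the removed lines set explicitly via .aspect = "RM". A hypothetical helper that reproduces that setup outside the runner, sketched under the assumption that the private aspect/_append_from_array API behaves as in the removed code (np and DatasetType come from this module):

def add_rotation_matrices(datasets, matrices, name='myR0'):
    # Hypothetical: mirrors the pre-refactor rotdif body so the
    # 'rmatrix myR0[RM]' command can resolve the [RM] aspect.
    datasets.add(DatasetType.MATRIX3x3.value, name=name)
    datasets[-1].aspect = "RM"
    datasets[-1]._append_from_array(np.asarray(matrices))
    return datasets[-1]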
@@ -2964,19 +2967,11 @@ def wavelet(traj, command):
     >>> command = ' '.join((c0, c1))
     >>> wavelet_dict = pt.wavelet(traj, command)
     """
-
-    action_datasets = CpptrajDatasetList()
-    crdname = '_DEFAULTCRD_'
-    action_datasets.add(DatasetType.COORDS.value, name=crdname)
-    action_datasets[0].top = traj.top
-
-    for frame in traj:
-        action_datasets[0].append(frame)
-
-    act = c_analysis.Analysis_Wavelet()
-    act(command, dslist=action_datasets)
-    action_datasets.remove_set(action_datasets[crdname])
-    return get_data_from_dtype(action_datasets, dtype='dict')
+    runner = AnalysisRunner(c_analysis.Analysis_Wavelet)
+    runner.add_dataset(DatasetType.COORDS, "_DEFAULTCRD_", traj)
+    runner.run_analysis(command)
+    runner.datasets.remove_set(runner.datasets["_DEFAULTCRD_"])
+    return get_data_from_dtype(runner.datasets, dtype='dict')
 
 
 def atom_map(traj, ref, rmsfit=False):
