diff --git a/pytraj/all_actions.py b/pytraj/all_actions.py index 9e45ab133..32d046e7e 100644 --- a/pytraj/all_actions.py +++ b/pytraj/all_actions.py @@ -85,6 +85,28 @@ class DatasetType(StrEnum): XYMESH = 'xymesh' MATRIX3x3 = 'matrix3x3' +class AnalysisRunner: + def __init__(self, analysis_class): + self.datasets = CpptrajDatasetList() + self.analysis = analysis_class() + + def add_dataset(self, dataset_type, dataset_name, data): + if dataset_type == DatasetType.COORDS: + crdname = '_DEFAULTCRD_' + self.datasets.add(dataset_type.value, name=crdname) + self.datasets[0].top = data.top + for frame in data: + self.datasets[0].append(frame) + else: + self.datasets.add(dataset_type, dataset_name) + if dataset_type == DatasetType.XYMESH: + self.datasets[-1]._append_from_array(data.T) + else: + self.datasets[-1].data = np.asarray(data).astype('f8') + + def run_analysis(self, command): + self.analysis(command, dslist=self.datasets) + return self.datasets def _assert_mutable(trajiter): @@ -1935,18 +1957,14 @@ def timecorr(vec0, vec1, order=2, tstep=1., tcorr=10000., norm=False, dtype='nda norm : bool, default False dtype : str, default 'ndarray' """ - time_correlation_action = c_analysis.Analysis_Timecorr() - action_datasets = CpptrajDatasetList() - - action_datasets.add(DatasetType.VECTOR, "_vec0") - action_datasets.add(DatasetType.VECTOR, "_vec1") - action_datasets[0].data = np.asarray(vec0).astype('f8') - action_datasets[1].data = np.asarray(vec1).astype('f8') + runner = AnalysisRunner(c_analysis.Analysis_Timecorr) + runner.add_dataset(DatasetType.VECTOR, "_vec0", vec0) + runner.add_dataset(DatasetType.VECTOR, "_vec1", vec1) command = f"vec1 _vec0 vec2 _vec1 order {order} tstep {tstep} tcorr {tcorr} {'norm' if norm else ''}" - time_correlation_action(command, dslist=action_datasets) + runner.run_analysis(command) - return get_data_from_dtype(action_datasets[2:], dtype=dtype) + return get_data_from_dtype(runner.datasets[2:], dtype=dtype) @super_dispatch() @@ -2079,17 +2097,13 @@ def crank(data0, data1, mode='distance', dtype='ndarray'): ----- Same as `crank` in cpptraj """ - action_datasets = CpptrajDatasetList() - action_datasets.add(DatasetType.DOUBLE, "d0") - action_datasets.add(DatasetType.DOUBLE, "d1") - - action_datasets[0].data = np.asarray(data0) - action_datasets[1].data = np.asarray(data1) + runner = AnalysisRunner(c_analysis.Analysis_CrankShaft) + runner.add_dataset(DatasetType.DOUBLE, "d0", data0) + runner.add_dataset(DatasetType.DOUBLE, "d1", data1) - act = c_analysis.Analysis_CrankShaft() command = ' '.join((mode, 'd0', 'd1')) with capture_stdout() as (out, err): - act(command, dslist=action_datasets) + runner.run_analysis(command) return out.read() @@ -2768,14 +2782,12 @@ def lowestcurve(data, points=10, step=0.2): data = np.asarray(data).T - action_datasets = CpptrajDatasetList() - action_datasets.add(DatasetType.XYMESH, 'mydata') - action_datasets[0]._append_from_array(data) + runner = AnalysisRunner(c_analysis.Analysis_LowestCurve) + runner.add_dataset(DatasetType.XYMESH, 'mydata', data) - analysis_lowest_curve = c_analysis.Analysis_LowestCurve() - analysis_lowest_curve(command, dslist=action_datasets) + runner.run_analysis(command) - return np.array([action_datasets[-1]._xcrd(), np.array(action_datasets[-1].values)]) + return np.array([runner.datasets[-1]._xcrd(), np.array(runner.datasets[-1].values)]) def acorr(data, dtype='ndarray', option=''): @@ -2793,15 +2805,13 @@ def acorr(data, dtype='ndarray', option=''): ----- Same as `autocorr` in cpptraj """ - action_datasets = CpptrajDatasetList() - action_datasets.add(DatasetType.DOUBLE, "d0") - - action_datasets[0].data = np.asarray(data) + runner = AnalysisRunner(c_analysis.Analysis_AutoCorr) + runner.add_dataset(DatasetType.DOUBLE, "d0", np.asarray(data)) - act = c_analysis.Analysis_AutoCorr() command = "d0 out _tmp.out" - act(command, dslist=action_datasets) - return get_data_from_dtype(action_datasets[1:], dtype=dtype) + runner.run_analysis(command) + + return get_data_from_dtype(runner.datasets[1:], dtype=dtype) auto_correlation_function = acorr @@ -2820,17 +2830,14 @@ def xcorr(data0, data1, dtype='ndarray'): ----- Same as `corr` in cpptraj """ + runner = AnalysisRunner(c_analysis.Analysis_Corr) + runner.add_dataset(DatasetType.DOUBLE, "d0", np.asarray(data0)) + runner.add_dataset(DatasetType.DOUBLE, "d1", np.asarray(data1)) - action_datasets = CpptrajDatasetList() - action_datasets.add(DatasetType.DOUBLE, "d0") - action_datasets.add(DatasetType.DOUBLE, "d1") - - action_datasets[0].data = np.asarray(data0) - action_datasets[1].data = np.asarray(data1) + command = "d0 d1 out _tmp.out" + runner.run_analysis(command) - act = c_analysis.Analysis_Corr() - act("d0 d1 out _tmp.out", dslist=action_datasets) - return get_data_from_dtype(action_datasets[2:3], dtype=dtype) + return get_data_from_dtype(runner.datasets[2:3], dtype=dtype) cross_correlation_function = xcorr @@ -2900,7 +2907,6 @@ def strip(obj, mask): def rotdif(matrices, command): """ - Parameters ---------- matrices : 3D array, shape=(n_frames, 3, 3) @@ -2917,18 +2923,15 @@ def rotdif(matrices, command): ----- This method interface will be changed. """ - # TODO: update this method if cpptraj dumps data to CpptrajDatasetList matrices = np.asarray(matrices) - action_datasets = CpptrajDatasetList() - action_datasets.add(DatasetType.MATRIX3x3, name='myR0') - action_datasets[-1].aspect = "RM" - action_datasets[-1]._append_from_array(matrices) + runner = AnalysisRunner(c_analysis.Analysis_Rotdif) + runner.add_dataset(DatasetType.MATRIX3x3, "myR0", matrices) command = 'rmatrix myR0[RM] ' + command - act = c_analysis.Analysis_Rotdif() with capture_stdout() as (out, _): - act(command, dslist=action_datasets) + runner.run_analysis(command) + return out.read() @@ -2964,19 +2967,11 @@ def wavelet(traj, command): >>> command = ' '.join((c0, c1)) >>> wavelet_dict = pt.wavelet(traj, command) """ - - action_datasets = CpptrajDatasetList() - crdname = '_DEFAULTCRD_' - action_datasets.add(DatasetType.COORDS.value, name=crdname) - action_datasets[0].top = traj.top - - for frame in traj: - action_datasets[0].append(frame) - - act = c_analysis.Analysis_Wavelet() - act(command, dslist=action_datasets) - action_datasets.remove_set(action_datasets[crdname]) - return get_data_from_dtype(action_datasets, dtype='dict') + runner = AnalysisRunner(c_analysis.Analysis_Wavelet) + runner.add_dataset(DatasetType.COORDS, "_DEFAULTCRD_", traj) + runner.run_analysis(command) + runner.datasets.remove_set(runner.datasets["_DEFAULTCRD_"]) + return get_data_from_dtype(runner.datasets, dtype='dict') def atom_map(traj, ref, rmsfit=False):