diff --git a/mofa/scoring/cgcnn_inference.py b/mofa/scoring/cgcnn_inference.py index f3ff8a9c..1a0cede8 100644 --- a/mofa/scoring/cgcnn_inference.py +++ b/mofa/scoring/cgcnn_inference.py @@ -19,6 +19,8 @@ import numpy as np +_atom_init_dir = Path(__file__).parent / "files" + class Opt: def __init__(self, **entries): self.__dict__.update(entries) @@ -521,7 +523,7 @@ def run_cgcnn_pred_wrapper_serial(mofs: list[ase.Atoms], run_name="some_random_s "data_norm": False, "dataset": 'cifdata', "train_frac": 1, - "data_dir_crystal": 'cme575-sp24-cgcnn/train_data', + "data_dir_crystal": _atom_init_dir, "pin_memory": False, "save_to_pickle": None, "num_oversample": 0,