diff --git a/tests/scoring/test_cgcnn_inference.py b/tests/scoring/test_cgcnn_inference.py new file mode 100644 index 00000000..f3cb4bb6 --- /dev/null +++ b/tests/scoring/test_cgcnn_inference.py @@ -0,0 +1,12 @@ +from pytest import mark +from mofa.scoring.raspa import RASPARunner +from ase.io import read +from pathlib import Path + + +@mark.parametrize('extxyz_name', ['test-zn']) +def test_run_cgcnn_pred_wrapper_serial(extxyz_name, cif_dir, tmpdir): + my_ase_mofs = [read(Path(datadir) / x, format="cif") for x in os.listdir(datadir) if x.endswith(".cif")] + pred, std = run_cgcnn_pred_wrapper_serial(my_ase_mofs, manual_batch_size=7, ncpus_to_load_data=1) + assert len(pred) == len(my_ase_mofs) + assert len(std) == len(my_ase_mofs)