Skip to content

Commit

Permalink
Update cgcnn_inference.py
Browse files Browse the repository at this point in the history
  • Loading branch information
williamyxl authored Apr 10, 2024
1 parent 50694e3 commit 5833b4b
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions mofa/scoring/cgcnn_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import ase
from ase.io import read
from pymatgen.io.ase import AseAtomsAdaptor
from train.dist_utils import get_local_rank, init_distributed
import torch.nn as nn
Expand All @@ -12,7 +11,6 @@
from typing import Union
from pathlib import Path
from models import cgcnn, CrystalGraphConvNet
import os
import abc
import json
import torch
Expand All @@ -23,6 +21,14 @@

_atom_init_dir = Path(__file__).parent / "files"
_cgcnn_models_dir = (Path(__file__).parent / ".." / ".." / "models" / "cgcnn-hmof-0.1bar-300k").resolve()
BACKBONES = {
"cgcnn": cgcnn.CrystalGraphConvNet
}
BACKBONE_KWARGS = {
"cgcnn": dict(orig_atom_fea_len=92, nbr_fea_len=41,
atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1,
classification=False, learnable=False, explain=False)
}


class Opt:
Expand Down Expand Up @@ -310,12 +316,14 @@ def find_class(self, module, name):
0) > 0,
'batch_size': opt.batch_size}
if self.opt.dataset not in ["cifdata"]:
self.ds_train, self.ds_val, self.ds_test = torch.utils.data.random_split(full_dataset, _get_split_sizes(self.opt.train_frac, full_dataset),
self.ds_train, self.ds_val, self.ds_test = torch.utils.data.random_split(full_dataset,
_get_split_sizes(self.opt.train_frac, full_dataset),
generator=torch.Generator().manual_seed(0))
else:

self.ds_train, self.ds_val, self.ds_test = torch.utils.data.random_split(full_dataset, _get_split_sizes(self.opt.train_frac, full_dataset),
generator=torch.Generator().manual_seed(0))
self.ds_train, self.ds_val, self.ds_test = torch.utils.data.random_split(full_dataset,
_get_split_sizes(self.opt.train_frac, full_dataset),
generator=torch.Generator().manual_seed(0))

self._mean = None
self._std = None
Expand Down Expand Up @@ -380,12 +388,16 @@ def infer_for_crystal(opt, dataloader, model):
y = data_batch.y

if opt.ensemble_names is not None:
df_list = df_list + [pd.DataFrame(data=np.concatenate([np.array(data_names).reshape(-1, 1), energies.detach().cpu().numpy().reshape(-1, 1),
stds.detach().cpu().numpy().reshape(-1, 1), y.detach().cpu().numpy().reshape(-1, 1)], axis=1),
df_list = df_list + [pd.DataFrame(data=np.concatenate([np.array(data_names).reshape(-1, 1),
energies.detach().cpu().numpy().reshape(-1, 1),
stds.detach().cpu().numpy().reshape(-1, 1),
y.detach().cpu().numpy().reshape(-1, 1)], axis=1),
columns=["name", "pred", "std", "real"])]
else:
df_list = df_list + [pd.DataFrame(data=np.concatenate([np.array(data_names).reshape(-1, 1), energies.detach().cpu().numpy().reshape(-1, 1),
y.detach().cpu().numpy().reshape(-1, 1)], axis=1), columns=["name", "pred", "real"])]
df_list = df_list + [pd.DataFrame(data=np.concatenate([np.array(data_names).reshape(-1, 1),
energies.detach().cpu().numpy().reshape(-1, 1),
y.detach().cpu().numpy().reshape(-1, 1)], axis=1),
columns=["name", "pred", "real"])]

df = pd.concat(df_list, axis=0, ignore_index=True)

Expand Down Expand Up @@ -466,6 +478,7 @@ def infer(ase_mofs, mof_names, opt=None):
opt.name = name
model = call_model(opt, mean, std)
models.append(model)

def model(*inp):
return (torch.cat([models[0](*inp), models[1](*inp), models[2](*inp)], dim=-1).mean(dim=-1),
torch.cat([models[0](*inp), models[1](*inp), models[2](*inp)], dim=-1).std(dim=-1))
Expand All @@ -474,17 +487,6 @@ def model(*inp):
return df


BACKBONES = {
"cgcnn": cgcnn.CrystalGraphConvNet
}

BACKBONE_KWARGS = {
"cgcnn": dict(orig_atom_fea_len=92, nbr_fea_len=41,
atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1,
classification=False, learnable=False, explain=False)
}


def run_cgcnn_pred(mofs: list[ase.Atoms], names: list[str], opt: Opt):
df = infer(mofs, names, opt)
return df
Expand Down

0 comments on commit 5833b4b

Please sign in to comment.