Skip to content

Commit

Permalink
Fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
itsgt authored Feb 13, 2024
1 parent 149e7bf commit 36b5e95
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions matdeeplearn/trainers/property_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,13 @@ def validate(self, split="val"):
def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True, vmap_pred = False):
for mod in self.model:
mod.eval()
if vmap_pred:
params, buffers = stack_module_state(self.model)
if vmap_pred:
params, buffers = stack_module_state(self.model)
base_model = copy.deepcopy(self.model[0])
base_model = base_model.to('meta')
# TODO: Allow to work with pos_grad and cell_grad
def fmodel(params, buffers, x):
return functional_call(base_model, (params, buffers), (x,))['output']
# TODO: Allow to work with pos_grad and cell_grad
def fmodel(params, buffers, x):
return functional_call(base_model, (params, buffers), (x,))['output']

# assert isinstance(loader, torch.utils.data.dataloader.DataLoader)

Expand All @@ -264,29 +264,29 @@ def fmodel(params, buffers, x):
loader_iter = iter(loader)
for i in range(0, len(loader_iter)):
batch = next(loader_iter).to(self.rank)
out = {}
out_stack={}
if not vmap_pred:
out_list = self._forward([batch])
for key in out_list[0].keys():
temp = [o[key] for o in out_list]
if temp[0] is not None:
out_stack[key] = torch.stack(temp)
out[key] = torch.mean(out_stack[key], dim=0)
out[key+"_std"] = torch.std(out_stack[key], dim=0)
else:
out[key] = None
out[key+"_std"] = None
batch_p = [o["output"].data.cpu().numpy() for o in out_list]
out = {}
out_stack={}
if not vmap_pred:
out_list = self._forward([batch])
for key in out_list[0].keys():
temp = [o[key] for o in out_list]
if temp[0] is not None:
out_stack[key] = torch.stack(temp)
out[key] = torch.mean(out_stack[key], dim=0)
out[key+"_std"] = torch.std(out_stack[key], dim=0)
else:
out[key] = None
out[key+"_std"] = None
batch_p = [o["output"].data.cpu().numpy() for o in out_list]

else:
out_list = vmap(fmodel, in_dims = (0, 0, None))(self.params, self.buffers, batch)
out["output"] = torch.mean(out_list, dim = 0)
out["output_std"] = torch.std(out_list, dim = 0)
batch_p = [out_list[o].cpu().numpy() for o in range(out_list.size()[0])]
else:
out_list = vmap(fmodel, in_dims = (0, 0, None))(self.params, self.buffers, batch)
out["output"] = torch.mean(out_list, dim = 0)
out["output_std"] = torch.std(out_list, dim = 0)
batch_p = [out_list[o].cpu().numpy() for o in range(out_list.size()[0])]
batch_p_mean = out["output"].cpu().numpy()
batch_stds = out["output_std"].cpu().numpy()
batch_stds = out["output_std"].cpu().numpy()
batch_ids = batch.structure_id

if labels == True:
Expand Down

0 comments on commit 36b5e95

Please sign in to comment.