From 36b5e954e7867e4afb1f26b4e38027b2ad523ec1 Mon Sep 17 00:00:00 2001
From: Ian Slagle
Date: Mon, 12 Feb 2024 22:43:18 -0500
Subject: [PATCH] Fix formatting

---
 matdeeplearn/trainers/property_trainer.py | 52 +++++++++++------------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/matdeeplearn/trainers/property_trainer.py b/matdeeplearn/trainers/property_trainer.py
index 509b27d3..374dabf9 100644
--- a/matdeeplearn/trainers/property_trainer.py
+++ b/matdeeplearn/trainers/property_trainer.py
@@ -236,13 +236,13 @@ def validate(self, split="val"):
     def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True, vmap_pred = False):
         for mod in self.model:
             mod.eval()
-        if vmap_pred:
-            params, buffers = stack_module_state(self.model)
+        if vmap_pred:
+            params, buffers = stack_module_state(self.model)
             base_model = copy.deepcopy(self.model[0])
             base_model = base_model.to('meta')
-            # TODO: Allow to work with pos_grad and cell_grad
-            def fmodel(params, buffers, x):
-                return functional_call(base_model, (params, buffers), (x,))['output']
+            # TODO: Allow to work with pos_grad and cell_grad
+            def fmodel(params, buffers, x):
+                return functional_call(base_model, (params, buffers), (x,))['output']
 
         # assert isinstance(loader, torch.utils.data.dataloader.DataLoader)
 
@@ -264,29 +264,29 @@ def fmodel(params, buffers, x):
             loader_iter = iter(loader)
         for i in range(0, len(loader_iter)):
             batch = next(loader_iter).to(self.rank)
-            out = {}
-            out_stack={}
-            if not vmap_pred:
-                out_list = self._forward([batch])
-                for key in out_list[0].keys():
-                    temp = [o[key] for o in out_list]
-                    if temp[0] is not None:
-                        out_stack[key] = torch.stack(temp)
-                        out[key] = torch.mean(out_stack[key], dim=0)
-                        out[key+"_std"] = torch.std(out_stack[key], dim=0)
-                    else:
-                        out[key] = None
-                        out[key+"_std"] = None
-                batch_p = [o["output"].data.cpu().numpy() for o in out_list]
+            out = {}
+            out_stack={}
+            if not vmap_pred:
+                out_list = self._forward([batch])
+                for key in out_list[0].keys():
+                    temp = [o[key] for o in out_list]
+                    if temp[0] is not None:
+                        out_stack[key] = torch.stack(temp)
+                        out[key] = torch.mean(out_stack[key], dim=0)
+                        out[key+"_std"] = torch.std(out_stack[key], dim=0)
+                    else:
+                        out[key] = None
+                        out[key+"_std"] = None
+                batch_p = [o["output"].data.cpu().numpy() for o in out_list]
-            else:
-                out_list = vmap(fmodel, in_dims = (0, 0, None))(self.params, self.buffers, batch)
-                out["output"] = torch.mean(out_list, dim = 0)
-                out["output_std"] = torch.std(out_list, dim = 0)
-                batch_p = [out_list[o].cpu().numpy() for o in range(out_list.size()[0])]
-
+            else:
+                out_list = vmap(fmodel, in_dims = (0, 0, None))(self.params, self.buffers, batch)
+                out["output"] = torch.mean(out_list, dim = 0)
+                out["output_std"] = torch.std(out_list, dim = 0)
+                batch_p = [out_list[o].cpu().numpy() for o in range(out_list.size()[0])]
+
             batch_p_mean = out["output"].cpu().numpy()
-            batch_stds = out["output_std"].cpu().numpy()
+            batch_stds = out["output_std"].cpu().numpy()
             batch_ids = batch.structure_id
             if labels == True:
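
Note: the vmap_pred branch re-indented above follows the torch.func model-ensembling pattern: stack_module_state stacks every ensemble member's weights along a new leading dimension, a meta-device copy of one member supplies the shared architecture, and vmap runs functional_call over that leading dimension in a single batched forward pass. The sketch below is a minimal, self-contained illustration of that pattern only; the toy nn.Linear ensemble, tensor shapes, and ensemble size are assumptions for illustration, not the trainer's actual models (which return a dict, hence the ['output'] indexing in the patch).

import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

# Toy ensemble of identically structured models (stand-in for self.model).
models = [nn.Linear(8, 1) for _ in range(4)]

# Stack each member's parameters/buffers along a new leading "ensemble" dim.
params, buffers = stack_module_state(models)

# A weightless skeleton on the meta device provides only the architecture.
base_model = copy.deepcopy(models[0]).to("meta")

def fmodel(params, buffers, x):
    # Run one member's stacked weights through the shared architecture.
    return functional_call(base_model, (params, buffers), (x,))

x = torch.randn(16, 8)

# vmap maps over dim 0 of params and buffers while broadcasting x unchanged,
# so one call yields predictions of shape (num_models, batch, out_features).
preds = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)

mean = preds.mean(dim=0)  # analogous to out["output"]
std = preds.std(dim=0)    # analogous to out["output_std"]
print(preds.shape, mean.shape, std.shape)

Compared with the non-vmap branch, which loops over the ensemble members and stacks their outputs per key, this computes the ensemble mean and standard deviation in one vectorized call.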