Skip to content

Commit

Permalink
Update lightning code to 2.0 compliant
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed May 17, 2024
1 parent 9c8c366 commit 87aa438
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 7 deletions.
3 changes: 2 additions & 1 deletion mofa/difflinker_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_parallel_devices(devices: Any) -> Any:
@staticmethod
def auto_device_count() -> int:
# Return a value for auto-device selection when `Trainer(devices="auto")`
raise NotImplementedError()
return torch.xpu.device_count()

@staticmethod
def is_available() -> bool:
Expand Down Expand Up @@ -259,6 +259,7 @@ def main(
last_checkpoint = find_last_checkpoint(checkpoints_dir)
ddpm = DDPM.load_from_checkpoint(
last_checkpoint,
map_location=args.device,
strict=False,
data_path=args.data,
train_data_prefix=args.train_data_prefix,
Expand Down
2 changes: 1 addition & 1 deletion mofa/utils/difflinker_sample_and_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def generate_animation(ddpm, chain_batch, node_mask, n_mol):
@lru_cache(maxsize=1) # Keep only one model in memory
def load_model(path, device) -> DDPM:
"""Load the DDPM model from disk"""
return DDPM.load_from_checkpoint(path, torch_device=device).eval().to(device)
return DDPM.load_from_checkpoint(path, map_location=device).eval().to(device)


def main_run(templates: list[LigandTemplate],
Expand Down
5 changes: 4 additions & 1 deletion mofa/utils/src/egnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,10 @@ def get_edges(self, n_nodes, batch_size):
for j in range(n_nodes):
rows.append(i + batch_idx * n_nodes)
cols.append(j + batch_idx * n_nodes)
edges = [torch.LongTensor(rows).to(self.device), torch.LongTensor(cols).to(self.device)]
edges = [
torch.LongTensor(rows),
torch.LongTensor(cols)
]
edges_dic_b[batch_size] = edges
return edges
else:
Expand Down
13 changes: 9 additions & 4 deletions mofa/utils/src/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ def __init__(
)
self.linker_size_sampler = DistributionNodes(LINKER_SIZE_DIST)

self.validation_step_outputs = []

def setup(self, stage: Optional[str] = None):
dataset_type = MOADDataset if '.' in self.train_data_prefix else ZincDataset
if self.dataset_override == "MOFA":
Expand Down Expand Up @@ -236,7 +238,7 @@ def validation_step(self, data, *args):
loss = vlb_loss
else:
raise NotImplementedError(self.loss_type)
return {
validation_step = {
'loss': loss,
'delta_log_px': delta_log_px,
'kl_prior': kl_prior,
Expand All @@ -247,6 +249,7 @@ def validation_step(self, data, *args):
'noise_t': noise_t,
'noise_0': noise_0
}
self.validation_step_outputs.append(validation_step)

def test_step(self, data, *args):
delta_log_px, kl_prior, loss_term_t, loss_term_0, l2_loss, noise_t, noise_0 = self.forward(data, training=False)
Expand Down Expand Up @@ -275,9 +278,9 @@ def train_epoch_end(self, training_step_outputs):
self.metrics.setdefault(f'{metric}/train', []).append(avg_metric)
self.log(f'{metric}/train', avg_metric, prog_bar=True)

def validation_epoch_end(self, validation_step_outputs):
for metric in validation_step_outputs[0].keys():
avg_metric = self.aggregate_metric(validation_step_outputs, metric)
def on_validation_epoch_end(self):
for metric in self.validation_step_outputs[0].keys():
avg_metric = self.aggregate_metric(self.validation_step_outputs, metric)
self.metrics.setdefault(f'{metric}/val', []).append(avg_metric)
self.log(f'{metric}/val', avg_metric, prog_bar=True)

Expand All @@ -293,6 +296,8 @@ def validation_epoch_end(self, validation_step_outputs):
for metric, value in best_metrics.items():
self.log(f'best_{metric}', value, prog_bar=True, batch_size=self.batch_size)

self.validation_step_outputs.clear()

def test_epoch_end(self, test_step_outputs):
for metric in test_step_outputs[0].keys():
avg_metric = self.aggregate_metric(test_step_outputs, metric)
Expand Down

0 comments on commit 87aa438

Please sign in to comment.