Skip to content

Commit

Permalink
Merge pull request #65 from Fung-Lab/vxfung-patch-4
Browse files Browse the repository at this point in the history
Vxfung patch 4
  • Loading branch information
vxfung authored Jan 26, 2024
2 parents 3cb3695 + a650461 commit d07afb0
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 5 deletions.
6 changes: 3 additions & 3 deletions matdeeplearn/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,18 +145,18 @@ def from_config(cls, config):
else:
rank = torch.device("cuda" if torch.cuda.is_available() else "cpu")
local_world_size = 1
dataset = cls._load_dataset(config["dataset"], config["task"]["run_mode"]) if hasattr(config["dataset"], "src") else None
dataset = cls._load_dataset(config["dataset"], config["task"]["run_mode"]) if "src" in config["dataset"] else None
model = cls._load_model(config["model"], config["dataset"]["preprocess_params"], dataset, local_world_size, rank)
optimizer = cls._load_optimizer(config["optim"], model, local_world_size)
sampler = cls._load_sampler(config["optim"], dataset, local_world_size, rank) if hasattr(config["dataset"], "src") else None
sampler = cls._load_sampler(config["optim"], dataset, local_world_size, rank) if "src" in config["dataset"] else None
data_loader = cls._load_dataloader(
config["optim"],
config["dataset"],
dataset,
sampler,
config["task"]["run_mode"],
config["model"]
) if hasattr(config["dataset"], "src") else None
) if "src" in config["dataset"] else None

scheduler = cls._load_scheduler(config["optim"]["scheduler"], optimizer)
loss = cls._load_loss(config["optim"]["loss"])
Expand Down
1 change: 1 addition & 0 deletions test/configs/cpu/test_predict.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ model:
otf_node_attr: False
# compute gradients w.r.t to positions and cell, requires otf_edge=True
gradient: False
model_ensemble: 1

optim:
max_epochs: 5
Expand Down
1 change: 1 addition & 0 deletions test/configs/cpu/test_training.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ model:
otf_node_attr: False
# compute gradients w.r.t to positions and cell, requires otf_edge=True
gradient: False
model_ensemble: 1

optim:
max_epochs: 5
Expand Down
4 changes: 2 additions & 2 deletions test/scripts/cpu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ def trainer_property(config, train: bool):

def assert_valid_predictions(trainer, load: str):
try:
out = trainer.predict(loader=trainer.data_loader[load], split="predict", write_output=False)
out = trainer.predict(loader=trainer.data_loader[0][load], split="predict", write_output=False)
assert isinstance(out["predict"][0][0], (floating, float, integer, int))
assert isinstance(out["ids"][0][0], str)
if load != "predict_loader":
assert isinstance(out["target"][0][0], (floating, float, integer, int))
except:
assert False


0 comments on commit d07afb0

Please sign in to comment.