Skip to content

Commit

Permalink
Merge pull request #62 from Fung-Lab/no_data_option
Browse files Browse the repository at this point in the history
Allow for no datasets
  • Loading branch information
vxfung authored Jan 25, 2024
2 parents 470c825 + e992766 commit 3cb3695
Showing 1 changed file with 38 additions and 30 deletions.
68 changes: 38 additions & 30 deletions matdeeplearn/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,18 @@ def __init__(
logging.info(
f"GPU is available: {torch.cuda.is_available()}, Quantity: {os.environ.get('LOCAL_WORLD_SIZE', None)}"
)
logging.info("Dataset(s) used:")
for key in self.dataset:
logging.info(f"Dataset length: {key, len(self.dataset[key])}")
if self.dataset.get("train"):
logging.debug(self.dataset["train"][0])
logging.debug(self.dataset["train"][0].z[0])
logging.debug(self.dataset["train"][0].y[0])
else:
logging.debug(self.dataset[list(self.dataset.keys())[0]][0])
logging.debug(self.dataset[list(self.dataset.keys())[0]][0].x[0])
logging.debug(self.dataset[list(self.dataset.keys())[0]][0].y[0])
if not (self.dataset is None):
logging.info("Dataset(s) used:")
for key in self.dataset:
logging.info(f"Dataset length: {key, len(self.dataset[key])}")
if self.dataset.get("train"):
logging.debug(self.dataset["train"][0])
logging.debug(self.dataset["train"][0].z[0])
logging.debug(self.dataset["train"][0].y[0])
else:
logging.debug(self.dataset[list(self.dataset.keys())[0]][0])
logging.debug(self.dataset[list(self.dataset.keys())[0]][0].x[0])
logging.debug(self.dataset[list(self.dataset.keys())[0]][0].y[0])

if str(self.rank) not in ("cpu", "cuda"):
logging.debug(self.model[0].module)
Expand Down Expand Up @@ -144,18 +145,18 @@ def from_config(cls, config):
else:
rank = torch.device("cuda" if torch.cuda.is_available() else "cpu")
local_world_size = 1
dataset = cls._load_dataset(config["dataset"], config["task"]["run_mode"])
dataset = cls._load_dataset(config["dataset"], config["task"]["run_mode"]) if hasattr(config["dataset"], "src") else None
model = cls._load_model(config["model"], config["dataset"]["preprocess_params"], dataset, local_world_size, rank)
optimizer = cls._load_optimizer(config["optim"], model, local_world_size)
sampler = cls._load_sampler(config["optim"], dataset, local_world_size, rank)
sampler = cls._load_sampler(config["optim"], dataset, local_world_size, rank) if hasattr(config["dataset"], "src") else None
data_loader = cls._load_dataloader(
config["optim"],
config["dataset"],
dataset,
sampler,
config["task"]["run_mode"],
config["model"]
)
) if hasattr(config["dataset"], "src") else None

scheduler = cls._load_scheduler(config["optim"]["scheduler"], optimizer)
loss = cls._load_loss(config["optim"]["loss"])
Expand Down Expand Up @@ -270,10 +271,11 @@ def _load_dataset(dataset_config, task):
def _load_model(model_config, graph_config, dataset, world_size, rank):
"""Loads the model if from a config file."""

if dataset.get("train"):
dataset = dataset["train"]
else:
dataset = dataset[list(dataset.keys())[0]]
if not (dataset is None):
if dataset.get("train"):
dataset = dataset["train"]
else:
dataset = dataset[list(dataset.keys())[0]]

if isinstance(dataset, torch.utils.data.Subset):
dataset = dataset.dataset
Expand All @@ -293,22 +295,28 @@ def _load_model(model_config, graph_config, dataset, world_size, rank):
if graph_config["node_dim"]:
node_dim = graph_config["node_dim"]
else:
node_dim = dataset.num_features
edge_dim = graph_config["edge_dim"]
if dataset[0]["y"].ndim == 0:
output_dim = 1
node_dim = dataset.num_features
edge_dim = graph_config["edge_dim"]
if not (dataset is None):
if dataset[0]["y"].ndim == 0:
output_dim = 1
else:
output_dim = dataset[0]["y"].shape[1]
else:
output_dim = dataset[0]["y"].shape[1]
output_dim = graph_config["output_dim"]

# Determine if this is a node or graph level model
if dataset[0]["y"].shape[0] == dataset[0]["z"].shape[0]:
model_config["prediction_level"] = "node"
elif dataset[0]["y"].shape[0] == 1:
model_config["prediction_level"] = "graph"
if not (dataset is None):
if dataset[0]["y"].shape[0] == dataset[0]["z"].shape[0]:
model_config["prediction_level"] = "node"
elif dataset[0]["y"].shape[0] == 1:
model_config["prediction_level"] = "graph"
else:
raise ValueError(
"Target labels do not have the correct dimensions for node or graph-level prediction."
)
else:
raise ValueError(
"Target labels do not have the correct dimensions for node or graph-level prediction."
)
model_config["prediction_level"] = graph_config["prediction_level"]

model_cls = registry.get_model_class(model_config["name"])
model = model_cls(
Expand Down

0 comments on commit 3cb3695

Please sign in to comment.