Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create pytorch objects inference fix #422

Merged
2 changes: 1 addition & 1 deletion GANDLF/compute/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def validate_network(
label = None
if params["problem_type"] != "segmentation":
label = label_ground_truth
else:
elif "label" in patches_batch:
label = patches_batch["label"][torchio.DATA]

if label is not None:
Expand Down
30 changes: 22 additions & 8 deletions GANDLF/compute/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
"""
# initialize train and val loaders
train_loader, val_loader = None, None
headers_to_populate_train, headers_to_populate_val = None, None

if train_csv is not None:
# populate the data frames
parameters["training_data"], headers_train = parseTrainingCSV(
parameters["training_data"], headers_to_populate_train = parseTrainingCSV(
train_csv, train=True
)
parameters = populate_header_in_parameters(parameters, headers_train)
# get the train loader
train_loader = get_train_loader(parameters)
parameters["training_samples_size"] = len(train_loader)
Expand All @@ -52,10 +52,20 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
) = get_class_imbalance_weights(parameters["training_data"], parameters)

if val_csv is not None:
parameters["validation_data"], _ = parseTrainingCSV(val_csv, train=False)
parameters["validation_data"], headers_to_populate_val = parseTrainingCSV(
val_csv, train=False
)
# get the validation loader
val_loader = get_validation_loader(parameters)

# populate required headers
headers_to_populate = headers_to_populate_train
if headers_to_populate is None:
if headers_to_populate_val is not None:
headers_to_populate = headers_to_populate_val
if headers_to_populate is not None:
parameters = populate_header_in_parameters(parameters, headers_to_populate)

psfoley marked this conversation as resolved.
Show resolved Hide resolved
# get the model
model = get_model(parameters)
parameters["model_parameters"] = model.parameters()
Expand All @@ -69,12 +79,16 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
model, amp=parameters["model"]["amp"], device=device, optimizer=optimizer
)

if not ("step_size" in parameters["scheduler"]):
parameters["scheduler"]["step_size"] = (
parameters["training_samples_size"] / parameters["learning_rate"]
)
# only need to create scheduler if training
if train_csv is not None:
if not ("step_size" in parameters["scheduler"]):
parameters["scheduler"]["step_size"] = (
parameters["training_samples_size"] / parameters["learning_rate"]
)

scheduler = get_scheduler(parameters)
scheduler = get_scheduler(parameters)
else:
scheduler = None

# these keys contain generators, and are not needed beyond this point in params
generator_keys_to_remove = ["optimizer_object", "model_parameters"]
Expand Down