Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create pytorch objects inference fix #422

Merged
2 changes: 1 addition & 1 deletion GANDLF/compute/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def validate_network(
label = None
if params["problem_type"] != "segmentation":
label = label_ground_truth
else:
elif "label" in patches_batch:
label = patches_batch["label"][torchio.DATA]

if label is not None:
Expand Down
29 changes: 21 additions & 8 deletions GANDLF/compute/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
"""
# initialize train and val loaders
train_loader, val_loader = None, None
headers_to_populate_train, headers_to_populate_val = None, None

if train_csv is not None:
# populate the data frames
parameters["training_data"], headers_train = parseTrainingCSV(
parameters["training_data"], headers_to_populate_train = parseTrainingCSV(
train_csv, train=True
)
parameters = populate_header_in_parameters(parameters, headers_train)
parameters = populate_header_in_parameters(
parameters, headers_to_populate_train
)
# get the train loader
train_loader = get_train_loader(parameters)
parameters["training_samples_size"] = len(train_loader)
Expand All @@ -52,7 +55,13 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
) = get_class_imbalance_weights(parameters["training_data"], parameters)

if val_csv is not None:
parameters["validation_data"], _ = parseTrainingCSV(val_csv, train=False)
parameters["validation_data"], headers_to_populate_val = parseTrainingCSV(
val_csv, train=False
)
if headers_to_populate_train is None:
parameters = populate_header_in_parameters(
parameters, headers_to_populate_val
)
# get the validation loader
val_loader = get_validation_loader(parameters)

Expand All @@ -69,12 +78,16 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
model, amp=parameters["model"]["amp"], device=device, optimizer=optimizer
)

if not ("step_size" in parameters["scheduler"]):
parameters["scheduler"]["step_size"] = (
parameters["training_samples_size"] / parameters["learning_rate"]
)
# only need to create scheduler if training
if train_csv is not None:
if not ("step_size" in parameters["scheduler"]):
parameters["scheduler"]["step_size"] = (
parameters["training_samples_size"] / parameters["learning_rate"]
)

scheduler = get_scheduler(parameters)
scheduler = get_scheduler(parameters)
else:
scheduler = None

# these keys contain generators, and are not needed beyond this point in params
generator_keys_to_remove = ["optimizer_object", "model_parameters"]
Expand Down