Skip to content

Commit

Permalink
Merge pull request #422 from psfoley/create_pytorch_objects_inference…
Browse files Browse the repository at this point in the history
…_fix

Create pytorch objects inference fix
  • Loading branch information
sarthakpati authored May 4, 2022
2 parents 6def27f + 7f3f76e commit f13fb9c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
2 changes: 1 addition & 1 deletion GANDLF/compute/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def validate_network(
label = None
if params["problem_type"] != "segmentation":
label = label_ground_truth
else:
elif "label" in patches_batch:
label = patches_batch["label"][torchio.DATA]

if label is not None:
Expand Down
29 changes: 21 additions & 8 deletions GANDLF/compute/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,16 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
"""
# initialize train and val loaders
train_loader, val_loader = None, None
headers_to_populate_train, headers_to_populate_val = None, None

if train_csv is not None:
# populate the data frames
parameters["training_data"], headers_train = parseTrainingCSV(
parameters["training_data"], headers_to_populate_train = parseTrainingCSV(
train_csv, train=True
)
parameters = populate_header_in_parameters(parameters, headers_train)
parameters = populate_header_in_parameters(
parameters, headers_to_populate_train
)
# get the train loader
train_loader = get_train_loader(parameters)
parameters["training_samples_size"] = len(train_loader)
Expand All @@ -52,7 +55,13 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
) = get_class_imbalance_weights(parameters["training_data"], parameters)

if val_csv is not None:
parameters["validation_data"], _ = parseTrainingCSV(val_csv, train=False)
parameters["validation_data"], headers_to_populate_val = parseTrainingCSV(
val_csv, train=False
)
if headers_to_populate_train is None:
parameters = populate_header_in_parameters(
parameters, headers_to_populate_val
)
# get the validation loader
val_loader = get_validation_loader(parameters)

Expand All @@ -69,12 +78,16 @@ def create_pytorch_objects(parameters, train_csv=None, val_csv=None, device="cpu
model, amp=parameters["model"]["amp"], device=device, optimizer=optimizer
)

if not ("step_size" in parameters["scheduler"]):
parameters["scheduler"]["step_size"] = (
parameters["training_samples_size"] / parameters["learning_rate"]
)
# only need to create scheduler if training
if train_csv is not None:
if not ("step_size" in parameters["scheduler"]):
parameters["scheduler"]["step_size"] = (
parameters["training_samples_size"] / parameters["learning_rate"]
)

scheduler = get_scheduler(parameters)
scheduler = get_scheduler(parameters)
else:
scheduler = None

# these keys contain generators, and are not needed beyond this point in params
generator_keys_to_remove = ["optimizer_object", "model_parameters"]
Expand Down

0 comments on commit f13fb9c

Please sign in to comment.