Skip to content

Commit

Permalink
Small fixes and changes to PyTorch trainer (#348)
Browse files Browse the repository at this point in the history
- Measure more timings
- Delete pretrained model in case of crash
- Load weights to CPU first, then move to GPU (if weights were stored on
gpu 0 and should be loaded on gpu 1, we otherwise load them directly on
gpu 0, which is problematic)
  • Loading branch information
MaxiBoether authored Jan 8, 2024
1 parent 3dd39f0 commit 65c2348
Showing 1 changed file with 40 additions and 14 deletions.
54 changes: 40 additions & 14 deletions modyn/trainer_server/internal/trainer/pytorch_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,9 @@ def save_state(self, destination: Union[pathlib.Path, io.BytesIO], iteration: Op
def load_state_if_given(self, path: pathlib.Path, load_optimizer_state: bool = False) -> None:
assert path.exists(), "Cannot load state from non-existing file"
self._info(f"Loading model state from {path}")
# We load the weights on the CPU, and `load_state_dict` moves them to GPU
with open(path, "rb") as state_file:
checkpoint = torch.load(io.BytesIO(state_file.read()))
checkpoint = torch.load(io.BytesIO(state_file.read()), map_location=torch.device("cpu"))

assert "model" in checkpoint
self._model.model.load_state_dict(checkpoint["model"])
Expand Down Expand Up @@ -383,8 +384,8 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches
self.update_queue(AvailableQueues.TRAINING, batch_number, self._num_samples, training_active=True)

stopw.start("PreprocessBatch", resume=True)
sample_ids, target, data = self.preprocess_batch(batch)
stopw.stop()
sample_ids, target, data = self.preprocess_batch(batch, stopw)
stopw.stop("PreprocessBatch")

if retrieve_weights_from_dataloader:
# model output is a torch.FloatTensor but weights is a torch.DoubleTensor.
Expand Down Expand Up @@ -458,17 +459,24 @@ def train(self) -> None: # pylint: disable=too-many-locals, too-many-branches
self._log["epochs"][epoch]["BatchTimings"] = batch_timings

# mypy cannot handle np.min and np.max
batch_timings = np.array(batch_timings)
self._log["epochs"][epoch]["MinFetchBatch"] = np.min(batch_timings).item() # type: ignore
self._log["epochs"][epoch]["MaxFetchBatch"] = np.max(batch_timings).item() # type: ignore
self._log["epochs"][epoch]["AvgFetchBatch"] = np.mean(batch_timings).item()
self._log["epochs"][epoch]["MedianFetchBatch"] = np.median(batch_timings).item()
self._log["epochs"][epoch]["StdFetchBatch"] = np.std(batch_timings).item()
del batch_timings
if len(batch_timings) > 0:
batch_timings = np.array(batch_timings)
self._log["epochs"][epoch]["MinFetchBatch"] = np.min(batch_timings).item() # type: ignore
self._log["epochs"][epoch]["MaxFetchBatch"] = np.max(batch_timings).item() # type: ignore
self._log["epochs"][epoch]["AvgFetchBatch"] = np.mean(batch_timings).item()
self._log["epochs"][epoch]["MedianFetchBatch"] = np.median(batch_timings).item()
self._log["epochs"][epoch]["StdFetchBatch"] = np.std(batch_timings).item()
del batch_timings
else:
self._error("Got zero batch timings, cannot get minimum.")

self._log["epochs"][epoch]["TotalFetchBatch"] = stopw.measurements.get("FetchBatch", 0)
self._log["epochs"][epoch]["OnBatchBeginCallbacks"] = stopw.measurements.get("OnBatchBeginCallbacks", 0)
self._log["epochs"][epoch]["PreprocessBatch"] = stopw.measurements.get("PreprocessBatch", 0)
self._log["epochs"][epoch]["PreprocSampleIDs"] = stopw.measurements.get("PreprocSampleIDs", 0)
self._log["epochs"][epoch]["LabelTransform"] = stopw.measurements.get("LabelTransform", 0)
self._log["epochs"][epoch]["MoveLabelToGPU"] = stopw.measurements.get("MoveLabelToGPU", 0)
self._log["epochs"][epoch]["MoveDataToGPU"] = stopw.measurements.get("MoveDataToGPU", 0)
self._log["epochs"][epoch]["DownsampleBTS"] = stopw.measurements.get("DownsampleBTS", 0)
self._log["epochs"][epoch]["DownsampleSTB"] = stopw.measurements.get("DownsampleSTB", 0)
self._log["epochs"][epoch]["Forward"] = stopw.measurements.get("Forward", 0)
Expand Down Expand Up @@ -566,20 +574,34 @@ def update_queue(
except queue.Empty:
pass

def preprocess_batch(self, batch: tuple) -> tuple[list, torch.Tensor, Union[torch.Tensor, dict]]:
def preprocess_batch(
self, batch: tuple, stopw: Optional[Stopwatch] = None
) -> tuple[list, torch.Tensor, Union[torch.Tensor, dict]]:
if stopw is None:
stopw = Stopwatch()

stopw.start("PreprocSampleIDs", resume=True)
sample_ids = batch[0]
if isinstance(sample_ids, torch.Tensor):
sample_ids = sample_ids.tolist()
elif isinstance(sample_ids, tuple):
sample_ids = list(sample_ids)

assert isinstance(sample_ids, list), "Cannot parse result from DataLoader"
stopw.stop("PreprocSampleIDs")

if self._label_tranformer_function is None:
target = batch[2].to(self._device)
stopw.start("LabelTransform", resume=True)
if self._label_tranformer_function is not None:
target = self._label_tranformer_function(batch[2])
else:
target = self._label_tranformer_function(batch[2]).to(self._device)
target = batch[2]
stopw.stop("LabelTransform")

stopw.start("MoveLabelToGPU", resume=True)
target = target.to(self._device)
stopw.stop("MoveLabelToGPU")

stopw.start("MoveDataToGPU", resume=True)
data: Union[torch.Tensor, dict]
if isinstance(batch[1], torch.Tensor):
data = batch[1].to(self._device)
Expand All @@ -592,6 +614,7 @@ def preprocess_batch(self, batch: tuple) -> tuple[list, torch.Tensor, Union[torc
"The format of the data provided is not supported in modyn. "
"Please use either torch tensors or dict[str, torch.Tensor]"
)
stopw.stop("MoveDataToGPU")

return sample_ids, target, data

Expand Down Expand Up @@ -798,3 +821,6 @@ def train(
exception_msg = traceback.format_exc()
logger.error(exception_msg)
exception_queue.put(exception_msg)
pretrained_path = training_info.pretrained_model_path
if pretrained_path is not None and pretrained_path.exists():
pretrained_path.unlink()

0 comments on commit 65c2348

Please sign in to comment.