update memory monitor (#1940)
* update memory monitor

* add round

* fix memory

* add rounding

* add rounding

* update memory monitor

* reformat

* round only if non-zero

* Update composer/callbacks/memory_monitor.py

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>

---------

Co-authored-by: Daniel King <43149077+dakinggg@users.noreply.github.com>
mvpatel2000 and dakinggg authored Feb 8, 2023
1 parent a918b2c commit 6a9d088
Showing 1 changed file with 43 additions and 32 deletions.
75 changes: 43 additions & 32 deletions composer/callbacks/memory_monitor.py
@@ -3,8 +3,9 @@

"""Log memory usage during training."""
import logging
import math
import warnings
from typing import Dict, Union
from typing import Dict, Optional, Union

import torch.cuda

@@ -50,31 +51,32 @@ class MemoryMonitor(Callback):
     The following statistics are recorded:
 
-    +----------------+--------------------------------------------------------------------------------+
-    | Statistic      | Description                                                                    |
-    +================+================================================================================+
-    | alloc_requests | Number of memory allocation requests received by the memory allocator.        |
-    +----------------+--------------------------------------------------------------------------------+
-    | free_requests  | Number of memory free requests received by the memory allocator.              |
-    +----------------+--------------------------------------------------------------------------------+
-    | allocated_mem  | Amount of allocated memory in bytes.                                           |
-    +----------------+--------------------------------------------------------------------------------+
-    | active_mem     | Amount of active memory in bytes at the time of recording.                     |
-    +----------------+--------------------------------------------------------------------------------+
-    | inactive_mem   | Amount of inactive, non-releaseable memory in bytes at the time of recording.  |
-    +----------------+--------------------------------------------------------------------------------+
-    | reserved_mem   | Amount of reserved memory in bytes at the time of recording.                   |
-    +----------------+--------------------------------------------------------------------------------+
-    | alloc_retries  | Number of failed cudaMalloc calls that result in a cache flush and retry.      |
-    +----------------+--------------------------------------------------------------------------------+
+    +----------------+-----------------------------------------------------------------------------------+
+    | Statistic      | Description                                                                       |
+    +================+===================================================================================+
+    | allocated_mem  | Amount of allocated memory in gigabytes.                                          |
+    +----------------+-----------------------------------------------------------------------------------+
+    | active_mem     | Amount of active memory in gigabytes at the time of recording.                    |
+    +----------------+-----------------------------------------------------------------------------------+
+    | inactive_mem   | Amount of inactive, non-releaseable memory in gigabytes at the time of recording. |
+    +----------------+-----------------------------------------------------------------------------------+
+    | reserved_mem   | Amount of reserved memory in gigabytes at the time of recording.                  |
+    +----------------+-----------------------------------------------------------------------------------+
+    | alloc_retries  | Number of failed cudaMalloc calls that result in a cache flush and retry.         |
+    +----------------+-----------------------------------------------------------------------------------+
 
     .. note::
         Memory usage monitoring is only supported for GPU devices.
 
+    Args:
+        memory_keys (Dict[str, str], optional): A dict specifying memory statistics to log. Keys
+            are the names of memory statistics to log from `torch.cuda.memory_stats()`, and values
+            are the names they will be logged under. If not provided, the above statistics are
+            logged. Defaults to None.
     """
 
-    def __init__(self) -> None:
-        # Memory monitor takes no args
-        pass
+    def __init__(self, memory_keys: Optional[Dict[str, str]] = None) -> None:
+        self.memory_keys = memory_keys
 
     def init(self, state: State, logger: Logger) -> None:
         # Not relying on `torch.cuda.is_available()` since the model could be on CPU.
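
A minimal usage sketch of the updated callback, assuming the standard composer.Trainer entry point; the custom memory_keys mapping and the model/dataloader names below are illustrative, not part of this commit:

# Minimal usage sketch (illustrative, not part of this commit).
# Assumes `model` is a ComposerModel on a CUDA device and
# `train_dataloader` is a standard PyTorch DataLoader.
from composer import Trainer
from composer.callbacks import MemoryMonitor

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='1ep',
    callbacks=[
        # Log only two statistics, under custom (hypothetical) names.
        MemoryMonitor(memory_keys={
            'allocated_bytes.all.current': 'allocated_gb',
            'reserved_bytes.all.current': 'reserved_gb',
        }),
    ],
)
trainer.fit()

Each statistic is logged after every train batch under memory/<name>, with byte counts converted to gigabytes and rounded to five significant digits.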
@@ -90,28 +92,37 @@ def after_train_batch(self, state: State, logger: Logger):
         if model_device.type != 'cuda':
             return
 
-        memory_report = _get_memory_report()
+        memory_report = _get_memory_report(self.memory_keys)
 
         logger.log_metrics({f'memory/{mem_stat}': val for (mem_stat, val) in memory_report.items()})
 
 
-_MEMORY_STATS = {
-    'allocation.all.allocated': 'alloc_requests',
-    'allocation.all.freed': 'free_requests',
-    'allocated_bytes.all.allocated': 'allocated_mem',
+_MEMORY_KEYS = {
+    'allocated_bytes.all.current': 'allocated_mem',
     'active_bytes.all.current': 'active_mem',
     'inactive_split_bytes.all.current': 'inactive_mem',
     'reserved_bytes.all.current': 'reserved_mem',
     'num_alloc_retries': 'alloc_retries',
 }
 
 
-def _get_memory_report() -> Dict[str, Union[int, float]]:
+def _get_memory_report(memory_keys: Optional[Dict[str, str]] = None) -> Dict[str, Union[int, float]]:
     memory_stats = torch.cuda.memory_stats()
 
-    # simplify the memory_stats
-    memory_report = {
-        name: memory_stats[torch_name] for (torch_name, name) in _MEMORY_STATS.items() if torch_name in memory_stats
-    }
+    memory_keys = memory_keys or _MEMORY_KEYS
+
+    # simplify and reformat the memory_stats
+    memory_report = {}
+    for (torch_name, name) in memory_keys.items():
+        if torch_name in memory_stats:
+            # Convert to gigabytes
+            if 'bytes' in torch_name:
+                gigabytes = memory_stats[torch_name] / 1.0e9
+                # Round to preserve 5 significant digits
+                if gigabytes != 0:
+                    order_of_magnitude = int(math.floor(math.log10(abs(gigabytes))))
+                    gigabytes = round(gigabytes, -order_of_magnitude + 4)
+                memory_report[name.replace('bytes', 'gigabytes')] = gigabytes
+            else:
+                memory_report[name] = memory_stats[torch_name]
 
     return memory_report
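
A standalone sketch of the significant-digit rounding added above; the helper name round_to_5_sig_figs is hypothetical. For a value below 1 GB the order of magnitude is negative, so round() keeps more decimal places and still preserves five significant digits:

# Standalone sketch of the rounding used in _get_memory_report.
import math

def round_to_5_sig_figs(gigabytes: float) -> float:
    if gigabytes == 0:
        return gigabytes
    order_of_magnitude = int(math.floor(math.log10(abs(gigabytes))))
    # e.g. 12.3456789    -> order  1 -> round(x, 3) -> 12.346
    # e.g. 0.0123456789  -> order -2 -> round(x, 6) -> 0.012346
    return round(gigabytes, -order_of_magnitude + 4)

assert round_to_5_sig_figs(12.3456789) == 12.346
assert round_to_5_sig_figs(0.0123456789) == 0.012346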
