Skip to content

Commit

Permalink
#pygrain Fix flow control for async shared memory deletion.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712025850
  • Loading branch information
aaudiber authored and copybara-github committed Jan 6, 2025
1 parent bc177d6 commit ceabb3d
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 23 deletions.
80 changes: 57 additions & 23 deletions grain/_src/python/shared_memory_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,10 @@ def close_and_unlink_shm(self) -> None:
shm.unlink()


def close_with_semaphore(
    shm: shared_memory.SharedMemory, semaphore: threading.Semaphore
) -> None:
  """Closes this process's handle on `shm` while holding `semaphore`.

  Args:
    shm: The shared memory block to close. It is not unlinked.
    semaphore: Acquired (blocking) before the close and released afterwards,
      including when `close()` raises.
  """
  semaphore.acquire()
  try:
    shm.close()
  finally:
    semaphore.release()
def _del_shm(shm: shared_memory.SharedMemory, unlink: bool) -> None:
  """Closes `shm` and, if requested, unlinks it as well.

  Args:
    shm: The shared memory block to release.
    unlink: When true, `shm.unlink()` is called after `shm.close()`.
  """
  shm.close()
  if not unlink:
    return
  shm.unlink()


class SharedMemoryArray(np.ndarray):
Expand All @@ -59,8 +58,8 @@ class SharedMemoryArray(np.ndarray):
"""

_lock: threading.Lock = threading.Lock()
_unlink_thread_pool: pool.ThreadPool | None = None
_unlink_semaphore: threading.Semaphore | None = None
_del_thread_pool: pool.ThreadPool | None = None
_outstanding_del_requests: threading.Semaphore | None = None

def __new__(
cls,
Expand Down Expand Up @@ -121,11 +120,46 @@ def __reduce_ex__(self, protocol):
return self.from_shared_memory, (self.shm, self.shape, self.dtype)

@classmethod
def enable_async_del(cls, num_threads: int = 1) -> None:
  """Creates the unlink thread pool and its semaphore on first use.

  Args:
    num_threads: Size of the thread pool and initial value of the semaphore.
      Ignored when a pool is already installed.
  """
  with cls._lock:
    if SharedMemoryArray._unlink_thread_pool:
      # Already enabled; keep the existing pool and semaphore.
      return
    SharedMemoryArray._unlink_thread_pool = pool.ThreadPool(num_threads)
    SharedMemoryArray._unlink_semaphore = threading.Semaphore(num_threads)
def enable_async_del(
    cls, num_threads: int = 1, max_outstanding_requests: int = 50
) -> None:
  """Turns on background closing/unlinking of shared memory arrays.

  The first call installs the deletion thread pool and the semaphore that
  bounds pending deletions. While a pool is installed, later calls return
  without changing anything and without validating their arguments.

  Args:
    num_threads: The number of threads in the deletion pool.
    max_outstanding_requests: The maximum number of outstanding requests to
      close/unlink shared memory. A larger value may make the __del__ method
      faster, but it may also lead to OOM errors or hitting file descriptor
      limits, since `max_outstanding_requests` shared memory objects and their
      associated file descriptors may be buffered before deletion.

  Raises:
    ValueError: If `max_outstanding_requests` is smaller than `num_threads`.
  """
  with SharedMemoryArray._lock:
    if SharedMemoryArray._del_thread_pool:
      # Already enabled; keep the existing pool and limit.
      return
    if num_threads > max_outstanding_requests:
      raise ValueError(
          "max_outstanding_requests must be at least num_threads."
      )
    SharedMemoryArray._del_thread_pool = pool.ThreadPool(num_threads)
    SharedMemoryArray._outstanding_del_requests = threading.Semaphore(
        max_outstanding_requests
    )

# For use in tests.
@classmethod
def _disable_async_del(cls) -> None:
  """Turns asynchronous deletion off and shuts down the deletion pool.

  State is cleared on `SharedMemoryArray` itself (not on `cls`), because that
  is where `enable_async_del` stores it. Any installed pool is closed and
  joined so that queued close/unlink requests finish and its worker threads
  exit instead of being leaked. This call therefore blocks until pending
  deletions have completed.
  """
  with SharedMemoryArray._lock:
    thread_pool = SharedMemoryArray._del_thread_pool
    # Clear the pool first so that deletions triggered from now on take the
    # synchronous path instead of being submitted to a closing pool.
    SharedMemoryArray._del_thread_pool = None
    if thread_pool is not None:
      thread_pool.close()
      thread_pool.join()
    # Clear the semaphore only after the pool has drained: in-flight requests
    # still need it to hand their permit back.
    SharedMemoryArray._outstanding_del_requests = None

# Mocked in tests, so be careful refactoring.
@classmethod
def close_shm_async(
    cls,
    shm: shared_memory.SharedMemory,
    unlink: bool,
) -> None:
  """Closes (and optionally unlinks) `shm`, then returns one request permit.

  Intended to run on the deletion thread pool after the caller has acquired a
  permit from `_outstanding_del_requests`.

  Args:
    shm: The shared memory block to release.
    unlink: Whether to unlink `shm` after closing it.
  """
  try:
    _del_shm(shm, unlink)
  finally:
    # Always hand the permit back, even when close/unlink raises. Otherwise
    # every failed deletion would permanently reduce the number of requests
    # that may be handled asynchronously.
    assert cls._outstanding_del_requests is not None
    cls._outstanding_del_requests.release()

def unlink_on_del(self) -> None:
"""Mark this object responsible for unlinking the shared memory."""
Expand All @@ -135,19 +169,19 @@ def __del__(self) -> None:
# Ensure that this array is not a view before closing shared memory
if not isinstance(self.base, mmap.mmap):
return
thread_pool = SharedMemoryArray._unlink_thread_pool
semaphore = SharedMemoryArray._unlink_semaphore
thread_pool = SharedMemoryArray._del_thread_pool
outstanding_del_requests = SharedMemoryArray._outstanding_del_requests
shm = self.shm
assert isinstance(shm, shared_memory.SharedMemory)
if thread_pool:
assert semaphore is not None
assert outstanding_del_requests is not None
# We use a semaphore to make sure that we don't accumulate too many
# requests to close/unlink shared memory, which could lead to OOM errors
thread_pool.apply_async(close_with_semaphore, args=(shm, semaphore))
else:
shm.close()
if self._unlink_on_del:
if thread_pool:
thread_pool.apply_async(shm.unlink)
# requests to close/unlink shared memory, which could lead to OOM errors.
if outstanding_del_requests.acquire(blocking=False):
thread_pool.apply_async(
SharedMemoryArray.close_shm_async, args=(shm, self._unlink_on_del)
)
else:
shm.unlink()
_del_shm(shm, unlink=self._unlink_on_del)
else:
_del_shm(shm, unlink=self._unlink_on_del)
103 changes: 103 additions & 0 deletions grain/_src/python/shared_memory_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.
"""Tests for shared memory array."""
from multiprocessing import shared_memory
import threading
import time
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import multiprocessing
Expand All @@ -25,6 +28,23 @@
import tensorflow as tf


def _create_and_delete_shm() -> SharedMemoryArrayMetadata:
  """Creates a small array marked for unlinking and returns its metadata.

  The array is not returned, so the last reference to it goes away when this
  function exits.
  """
  template = np.array([[1, 2], [3, 4]], dtype=np.int32)
  array = SharedMemoryArray(template.shape, template.dtype)
  array.unlink_on_del()
  return array.metadata


def _wait_for_deletion(
    metadata: SharedMemoryArrayMetadata, timeout: float = 60.0
) -> None:
  """Polls until the shared memory named by `metadata` can no longer be opened.

  Args:
    metadata: Metadata of the shared memory array; only `name` is read.
    timeout: Maximum number of seconds to keep polling before giving up.

  Raises:
    TimeoutError: If the shared memory still exists once `timeout` seconds
      have elapsed.
  """
  deadline = time.monotonic() + timeout
  while True:
    try:
      shm = shared_memory.SharedMemory(name=metadata.name, create=False)
    except FileNotFoundError:
      return
    # Still present. Drop our handle immediately so that polling does not
    # accumulate open mappings/file descriptors of its own.
    shm.close()
    if time.monotonic() >= deadline:
      # Fail the test instead of hanging forever when deletion never happens.
      raise TimeoutError(
          f"Shared memory {metadata.name!r} was not deleted within"
          f" {timeout} seconds."
      )
    time.sleep(0.1)


class SharedMemoryArrayTest(parameterized.TestCase):

@parameterized.parameters(["numpy", "tensorflow", "jax"])
Expand Down Expand Up @@ -74,6 +94,89 @@ def test_batch_dict_of_data_with_shared_memory(self, mode):
with self.assertRaises(FileNotFoundError):
_ = shared_memory.SharedMemory(name=shm_metadata.name, create=False)

def test_async_unlink_limit(self):
# Checks that with the only allowed async request still pending, a further
# deletion is carried out immediately instead of being queued.
# Reset any previously installed pool, then allow a single outstanding request.
SharedMemoryArray._disable_async_del()
SharedMemoryArray.enable_async_del(max_outstanding_requests=1)
event = threading.Event()
# Keep a handle on the real classmethod before it is patched below.
original_close_shm_async = SharedMemoryArray.close_shm_async

# Replacement for close_shm_async: waits for `event` (at most 60 seconds) and
# then delegates to the real implementation.
def _wait_for_event(shm, unlink_on_del):
event.wait(timeout=60)
original_close_shm_async(shm, unlink_on_del)

with mock.patch.object(
SharedMemoryArray, "close_shm_async", side_effect=_wait_for_event
):
metadata = _create_and_delete_shm()
# Presumably gives the pool thread time to pick up the request.
time.sleep(1)
# This should succeed, since the unlink request is async and we haven't
# yet allowed it to progress past the event.
_ = shared_memory.SharedMemory(name=metadata.name, create=False)

# All outstanding requests in use, so this should delete the shared memory
# right away.
metadata_2 = _create_and_delete_shm()
with self.assertRaises(FileNotFoundError):
_ = shared_memory.SharedMemory(name=metadata_2.name, create=False)

# Unblock the pending async request and wait until it has unlinked the first
# shared memory block.
event.set()
_wait_for_deletion(metadata)

def test_del_no_pool(self):
  # Without enable_async_del, dropping the array must release the
  # SharedMemory resource before `del` returns.
  SharedMemoryArray._disable_async_del()
  template = np.array([[1, 2], [3, 4]], dtype=np.int32)
  array = SharedMemoryArray(template.shape, template.dtype)
  array.unlink_on_del()
  metadata = array.metadata
  del array
  # Re-opening by name has to fail now that the block was unlinked.
  with self.assertRaises(FileNotFoundError):
    shared_memory.SharedMemory(name=metadata.name, create=False)

def test_del_many_async(self):
  # Deletes more arrays (50) than the outstanding-request limit (20) and
  # checks that every one of them eventually disappears.
  SharedMemoryArray._disable_async_del()
  SharedMemoryArray.enable_async_del(
      num_threads=4, max_outstanding_requests=20
  )
  deleted = []
  for _ in range(50):
    deleted.append(_create_and_delete_shm())
  for metadata in deleted:
    _wait_for_deletion(metadata)

def test_del_many_async_reuse_pool(self):
# Runs two rounds of `max_outstanding_requests` deletions against the same
# pool and counts how many of them went through close_shm_async.
max_outstanding_requests = 20
SharedMemoryArray._disable_async_del()
SharedMemoryArray.enable_async_del(
num_threads=4, max_outstanding_requests=max_outstanding_requests
)
# Keep a handle on the real classmethod before it is patched below.
original_close_shm_async = SharedMemoryArray.close_shm_async

# Pass-through replacement: the mock records each call while the real
# close/unlink still takes place.
def my_close_shm_async(shm, unlink_on_del):
original_close_shm_async(shm, unlink_on_del)

with mock.patch.object(
SharedMemoryArray, "close_shm_async", side_effect=my_close_shm_async
) as mock_close_shm_async:
with self.subTest("first_round_of_requests"):
shm_metadatas = [
_create_and_delete_shm() for _ in range(max_outstanding_requests)
]
for metadata in shm_metadatas:
_wait_for_deletion(metadata)
self.assertEqual(
max_outstanding_requests, mock_close_shm_async.call_count
)
# NOTE(review): close_shm_async releases its permit only after the unlink,
# so the second round may in principle start before the last permit is
# back - confirm this cannot make the count below flaky.
with self.subTest("second_round_of_requests"):
# Do it again to make sure the pool is reused.
shm_metadatas = [
_create_and_delete_shm() for _ in range(max_outstanding_requests)
]
for metadata in shm_metadatas:
_wait_for_deletion(metadata)
# call_count is cumulative, hence twice the per-round number.
self.assertEqual(
2 * max_outstanding_requests, mock_close_shm_async.call_count
)


# Hand control to absltest's runner when this file is executed as a script.
if __name__ == "__main__":
absltest.main()

0 comments on commit ceabb3d

Please sign in to comment.