Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#pygrain Fix flow control for async shared memory deletion. #680

Merged
merged 1 commit into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 57 additions & 23 deletions grain/_src/python/shared_memory_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,10 @@ def close_and_unlink_shm(self) -> None:
shm.unlink()


def close_with_semaphore(
shm: shared_memory.SharedMemory, semaphore: threading.Semaphore
) -> None:
with semaphore:
shm.close()
def _del_shm(shm: shared_memory.SharedMemory, unlink: bool) -> None:
shm.close()
if unlink:
shm.unlink()


class SharedMemoryArray(np.ndarray):
Expand All @@ -59,8 +58,8 @@ class SharedMemoryArray(np.ndarray):
"""

_lock: threading.Lock = threading.Lock()
_unlink_thread_pool: pool.ThreadPool | None = None
_unlink_semaphore: threading.Semaphore | None = None
_del_thread_pool: pool.ThreadPool | None = None
_outstanding_del_requests: threading.Semaphore | None = None

def __new__(
cls,
Expand Down Expand Up @@ -121,11 +120,46 @@ def __reduce_ex__(self, protocol):
return self.from_shared_memory, (self.shm, self.shape, self.dtype)

@classmethod
def enable_async_del(cls, num_threads: int = 1) -> None:
with cls._lock:
if not SharedMemoryArray._unlink_thread_pool:
SharedMemoryArray._unlink_thread_pool = pool.ThreadPool(num_threads)
SharedMemoryArray._unlink_semaphore = threading.Semaphore(num_threads)
def enable_async_del(
cls, num_threads: int = 1, max_outstanding_requests: int = 50
) -> None:
"""Enables asynchronous deletion of shared memory arrays.

Args:
num_threads: The number of threads to use for deletion.
max_outstanding_requests: The maximum number of outstanding requests to
close/unlink shared memory. A larger value may make the __del__ method
faster, but it may also lead to OOM errors or hitting file descriptor
limits, since `max_outstanding_requests` shared memory objects and their
associated file descriptors may be buffered before deletion.
"""
with SharedMemoryArray._lock:
if not SharedMemoryArray._del_thread_pool:
if max_outstanding_requests < num_threads:
raise ValueError(
"max_outstanding_requests must be at least num_threads."
)
SharedMemoryArray._del_thread_pool = pool.ThreadPool(num_threads)
SharedMemoryArray._outstanding_del_requests = threading.Semaphore(
max_outstanding_requests
)

  # For use in tests.
  @classmethod
  def _disable_async_del(cls) -> None:
    # Drops the class-level pool and semaphore so a later enable_async_del()
    # call recreates them with fresh settings. The existing pool is not shut
    # down here; __del__ captures its own references, so requests already
    # queued on the old pool still complete.
    cls._del_thread_pool = None
    cls._outstanding_del_requests = None

# Mocked in tests, so be careful refactoring.
@classmethod
def close_shm_async(
cls,
shm: shared_memory.SharedMemory,
unlink: bool,
) -> None:
_del_shm(shm, unlink)
assert cls._outstanding_del_requests is not None
cls._outstanding_del_requests.release()

def unlink_on_del(self) -> None:
"""Mark this object responsible for unlinking the shared memory."""
Expand All @@ -135,19 +169,19 @@ def __del__(self) -> None:
# Ensure that this array is not a view before closing shared memory
if not isinstance(self.base, mmap.mmap):
return
thread_pool = SharedMemoryArray._unlink_thread_pool
semaphore = SharedMemoryArray._unlink_semaphore
thread_pool = SharedMemoryArray._del_thread_pool
outstanding_del_requests = SharedMemoryArray._outstanding_del_requests
shm = self.shm
assert isinstance(shm, shared_memory.SharedMemory)
if thread_pool:
assert semaphore is not None
assert outstanding_del_requests is not None
# We use a semaphore to make sure that we don't accumulate too many
# requests to close/unlink shared memory, which could lead to OOM errors
thread_pool.apply_async(close_with_semaphore, args=(shm, semaphore))
else:
shm.close()
if self._unlink_on_del:
if thread_pool:
thread_pool.apply_async(shm.unlink)
# requests to close/unlink shared memory, which could lead to OOM errors.
if outstanding_del_requests.acquire(blocking=False):
thread_pool.apply_async(
SharedMemoryArray.close_shm_async, args=(shm, self._unlink_on_del)
)
else:
shm.unlink()
_del_shm(shm, unlink=self._unlink_on_del)
else:
_del_shm(shm, unlink=self._unlink_on_del)
103 changes: 103 additions & 0 deletions grain/_src/python/shared_memory_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.
"""Tests for shared memory array."""
from multiprocessing import shared_memory
import threading
import time
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
import multiprocessing
Expand All @@ -25,6 +28,23 @@
import tensorflow as tf


def _create_and_delete_shm() -> "SharedMemoryArrayMetadata":
  """Creates a SharedMemoryArray marked for unlink and lets it go out of scope.

  Returns:
    Metadata of the shared memory block; deletion is triggered when the local
    array is garbage collected on return.
  """
  values = np.array([[1, 2], [3, 4]], dtype=np.int32)
  arr = SharedMemoryArray(values.shape, values.dtype)
  arr.unlink_on_del()
  return arr.metadata


def _wait_for_deletion(metadata: SharedMemoryArrayMetadata) -> None:
while True:
try:
_ = shared_memory.SharedMemory(name=metadata.name, create=False)
time.sleep(0.1)
except FileNotFoundError:
break


class SharedMemoryArrayTest(parameterized.TestCase):

@parameterized.parameters(["numpy", "tensorflow", "jax"])
Expand Down Expand Up @@ -74,6 +94,89 @@ def test_batch_dict_of_data_with_shared_memory(self, mode):
with self.assertRaises(FileNotFoundError):
_ = shared_memory.SharedMemory(name=shm_metadata.name, create=False)

  def test_async_unlink_limit(self):
    """Deletion falls back to synchronous when all request slots are taken."""
    SharedMemoryArray._disable_async_del()
    # A single slot: only one close/unlink request may be outstanding at once.
    SharedMemoryArray.enable_async_del(max_outstanding_requests=1)
    event = threading.Event()
    original_close_shm_async = SharedMemoryArray.close_shm_async

    def _wait_for_event(shm, unlink_on_del):
      # Holds the sole outstanding request open until the test fires `event`.
      event.wait(timeout=60)
      original_close_shm_async(shm, unlink_on_del)

    with mock.patch.object(
        SharedMemoryArray, "close_shm_async", side_effect=_wait_for_event
    ):
      metadata = _create_and_delete_shm()
      time.sleep(1)
      # This should succeed, since the unlink request is async and we haven't
      # yet allowed it to progress past the event.
      _ = shared_memory.SharedMemory(name=metadata.name, create=False)

      # All outstanding requests in use, so this should delete the shared memory
      # right away.
      metadata_2 = _create_and_delete_shm()
      with self.assertRaises(FileNotFoundError):
        _ = shared_memory.SharedMemory(name=metadata_2.name, create=False)

      # Unblock the queued request and wait for the first block to disappear.
      event.set()
      _wait_for_deletion(metadata)

def test_del_no_pool(self):
SharedMemoryArray._disable_async_del()
# Tests deletion of SharedMemory resource when enable_async_del is not
# called.
data = np.array([[1, 2], [3, 4]], dtype=np.int32)
shm_array = SharedMemoryArray(data.shape, data.dtype)
shm_array.unlink_on_del()
metadata = shm_array.metadata
del shm_array
with self.assertRaises(FileNotFoundError):
_ = shared_memory.SharedMemory(name=metadata.name, create=False)

def test_del_many_async(self):
SharedMemoryArray._disable_async_del()
SharedMemoryArray.enable_async_del(
num_threads=4, max_outstanding_requests=20
)
shm_metadatas = [_create_and_delete_shm() for _ in range(50)]
for metadata in shm_metadatas:
_wait_for_deletion(metadata)

  def test_del_many_async_reuse_pool(self):
    """The deletion thread pool keeps serving requests across rounds."""
    max_outstanding_requests = 20
    SharedMemoryArray._disable_async_del()
    SharedMemoryArray.enable_async_del(
        num_threads=4, max_outstanding_requests=max_outstanding_requests
    )
    original_close_shm_async = SharedMemoryArray.close_shm_async

    def my_close_shm_async(shm, unlink_on_del):
      # Delegates to the real implementation; the mock only counts calls.
      original_close_shm_async(shm, unlink_on_del)

    with mock.patch.object(
        SharedMemoryArray, "close_shm_async", side_effect=my_close_shm_async
    ) as mock_close_shm_async:
      with self.subTest("first_round_of_requests"):
        # Exactly fills the outstanding-request budget; every deletion should
        # go through the async path.
        shm_metadatas = [
            _create_and_delete_shm() for _ in range(max_outstanding_requests)
        ]
        for metadata in shm_metadatas:
          _wait_for_deletion(metadata)
        self.assertEqual(
            max_outstanding_requests, mock_close_shm_async.call_count
        )
      with self.subTest("second_round_of_requests"):
        # Do it again to make sure the pool is reused.
        shm_metadatas = [
            _create_and_delete_shm() for _ in range(max_outstanding_requests)
        ]
        for metadata in shm_metadatas:
          _wait_for_deletion(metadata)
        self.assertEqual(
            2 * max_outstanding_requests, mock_close_shm_async.call_count
        )


if __name__ == "__main__":
  # absltest discovers and runs all test methods in this module.
  absltest.main()
Loading