[BUG] get_backend is not thread-safe #279
Comments
Hmmm, good point, does wrapping
I don't think so. Assume two threads T1 and T2, both using the same backend, cannot find the backend in the dict. A simple solution would be to introduce a lock, something like:

    def get_backend(tensor) -> 'AbstractBackend':
        """
        Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
        If needed, imports package and creates backend
        """
        for framework_name, backend in _backends.items():
            if backend.is_appropriate_type(tensor):
                return backend

        with lock:
            # Try to find backend again
            for framework_name, backend in _backends.items():
                if backend.is_appropriate_type(tensor):
                    return backend

            # Find backend subclasses recursively
            backend_subclasses = []
            backends = AbstractBackend.__subclasses__()
            while backends:
                backend = backends.pop()
                backends += backend.__subclasses__()
                backend_subclasses.append(backend)

            for BackendSubclass in backend_subclasses:
                if _debug_importing:
                    print('Testing for subclass of ', BackendSubclass)
                if BackendSubclass.framework_name not in _backends:
                    # check that module was already imported. Otherwise it can't be imported
                    if BackendSubclass.framework_name in sys.modules:
                        if _debug_importing:
                            print('Imported backend for ', BackendSubclass.framework_name)
                        backend = BackendSubclass()
                        _backends[backend.framework_name] = backend
                        if backend.is_appropriate_type(tensor):
                            return backend

        raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))
True, but not an issue: the rest of the function is idempotent, and it is no problem if a backend is created twice.
Describe the bug
Using einops with multiple threads can lead to a race condition, as the backend dictionary is updated while being iterated over in another thread.
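The failure mode described here can be illustrated without threads. This is a hedged, single-threaded sketch rather than the original reproduction: the insert inside the loop stands in for what a second thread would do while the first is still iterating, and `backends` and `iterate_while_registering` are hypothetical names.

```python
backends = {"numpy": object()}  # hypothetical stand-in for the backend dict


def iterate_while_registering(registry):
    """Insert a new key mid-iteration, as a second thread might."""
    try:
        for name in registry:
            # Emulates another thread registering a backend right now.
            registry["jax"] = object()
    except RuntimeError as exc:
        return str(exc)
    return None


message = iterate_while_registering(backends)
print(message)
```

In CPython the loop is expected to fail with a `RuntimeError` on the step after the insert, which is the same error a reader thread would hit when a writer thread adds a backend at the wrong moment.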
Reproduction steps
Rerun multiple times
Expected behavior
No race condition
Your platform
einops 0.6.1, python 3.11.3, jax v0.4.14