New mp_ctx defaults in 4.3.0+ (#6218) for M1 Macs could break parallel sampling for M1 Mac users using JAX #6362
Comments
Other users couldn't sample in parallel with the older default method. Can you sample parallel with any of the If not, you may be stuck with sequential sampling (cores=1) or relying on |
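For anyone checking which start methods their machine supports, the standard library can list them and build an explicit context object. This is a minimal sketch: only the multiprocessing calls are exercised, and the commented pm.sample line assumes the cores and mp_ctx arguments discussed in this thread.

```python
import multiprocessing

# Start methods supported on this platform, e.g. "fork", "spawn", "forkserver".
available = multiprocessing.get_all_start_methods()
print(available)

# Prefer "forkserver" where it exists (it is not available on Windows),
# otherwise fall back to "spawn", which every platform supports.
method = "forkserver" if "forkserver" in available else "spawn"
ctx = multiprocessing.get_context(method)
print(ctx.get_start_method())

# Passing the context object explicitly, as described in the issue:
# idata = pm.sample(cores=2, mp_ctx=ctx)
```

Building the context up front makes the chosen start method visible in user code instead of leaving it to platform-dependent defaults.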
I am not asking to revert to the older defaults. I know my case is a fringe one, but I think some debugging messages or warnings would be helpful, because the current code is a bit opaque about two things:
So I think adding some warnings or debug logs would be helpful here. Something like:

    if mp_ctx is None or isinstance(mp_ctx, str):
        # Closes issue https://github.com/pymc-devs/pymc/issues/3849
        # Related issue https://github.com/pymc-devs/pymc/issues/5339
        if isinstance(mp_ctx, str):
            logger.warning("A str was passed to mp_ctx. We recommend passing a multiprocessing context using multiprocessing.get_context() method.")
        if platform.system() == "Darwin":
            if platform.processor() == "arm":
                mp_ctx = "fork"
                logger.debug("mp_ctx is set to 'fork' for MacOS with ARM architecture. This might cause unexpected behavior with JAX, which is inherently multithreaded.")
            else:
                mp_ctx = "forkserver"
        mp_ctx = multiprocessing.get_context(mp_ctx)

Happy to submit a PR if you think this is a good idea. BTW, setting |
Oh, we shouldn't override the user specified |
Something like:

    if mp_ctx is None or isinstance(mp_ctx, str):
        if mp_ctx is None and platform.system() == "Darwin":
            # Closes issue https://github.com/pymc-devs/pymc/issues/3849
            # Related issue https://github.com/pymc-devs/pymc/issues/5339
            if platform.processor() == "arm":
                mp_ctx = "fork"
            else:
                mp_ctx = "forkserver"
        mp_ctx = multiprocessing.get_context(mp_ctx)
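To make the intended behavior easy to check off a Mac, the same idea can be sketched as a standalone function that takes the platform values as arguments. The names choose_start_method and resolve_mp_ctx are hypothetical helpers for illustration, not PyMC functions, and the debug message wording is an assumption.

```python
import logging
import multiprocessing
import platform
from typing import Optional

logger = logging.getLogger(__name__)


def choose_start_method(
    mp_ctx: Optional[str], system: str, processor: str
) -> Optional[str]:
    """Return the start-method name to use; None means the platform default."""
    if mp_ctx is not None:
        # An explicit string from the user is respected as-is.
        return mp_ctx
    if system == "Darwin":
        # Only the default is platform-specific.
        method = "fork" if processor == "arm" else "forkserver"
        logger.debug("mp_ctx not given; defaulting to %r on macOS (%s).", method, processor)
        return method
    return None


def resolve_mp_ctx(mp_ctx=None):
    """Turn None or a string into a context; pass context objects through."""
    if mp_ctx is None or isinstance(mp_ctx, str):
        method = choose_start_method(mp_ctx, platform.system(), platform.processor())
        return multiprocessing.get_context(method)
    return mp_ctx


print(choose_start_method(None, "Darwin", "arm"))          # fork
print(choose_start_method(None, "Darwin", "i386"))         # forkserver
print(choose_start_method("forkserver", "Darwin", "arm"))  # forkserver
print(choose_start_method(None, "Linux", "x86_64"))        # None
```

Separating the name selection from the get_context call keeps the platform branch testable without an M1 machine, and a user-supplied string is never replaced.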
That sounds good to me too 👍. I still think adding some debugging log to the default of |
Yeah debugging logs sound fine, since this logic is all pretty hacky anyway |
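Debug-level records are not shown by default, so a user would need to opt in to see such a message. A minimal sketch using only the standard logging module, assuming the library logs under the logger name "pymc" (an assumption worth checking against the installed version):

```python
import logging

# Assumption: the library emits its records under the logger name "pymc".
pymc_logger = logging.getLogger("pymc")
pymc_logger.setLevel(logging.DEBUG)

# Attach a handler so DEBUG records are actually printed.
handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
handler.setFormatter(logging.Formatter("%(name)s %(levelname)s: %(message)s"))
pymc_logger.addHandler(handler)

pymc_logger.debug("debug logging is enabled")
```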
Great! I opened a PR #6363 |
Describe the issue:
Problem:
In #6218, mp_ctx is set to fork by default if the user's system is an M1 Mac. However, I was running some JAX functions wrapped in Aesara ops, following this tutorial. When cores is set to 2, sampling gets stuck at 0.00%. Setting cores to 1, setting mp_ctx to multiprocessing.get_context("forkserver"), or downgrading to pymc 4.2.2 solves this issue.
Cause:
According to this post, JAX is internally multithreaded and does not work with the fork strategy in multiprocessing. So JAX functions wrapped in Aesara ops do not work in parallel sampling from 4.3.0 onward on M1 Macs, because mp_ctx is set to fork and does not change unless a context object is passed. Even simply passing forkserver as mp_ctx does not work, because a string argument does not change mp_ctx at all if the system is an M1 Mac. This makes the code very difficult to debug.
Solution:
mp_ctx so at least it does not internally force the context to fork if a str is passed on M1 Macs. Maybe users can be warned if their system is an M1 Mac.
Reproducible code example:
Error message:
PyMC version information:
aesara==2.8.7, aesara==2.8.6
pymc==4.3.0, pymc==4.4.0
pymc==4.2.2 does not have the same issue.
Context for the issue:
No response