[tune] PyTorch Lightning 1.7 with Ray Tune hangs #28197

Closed · marcmk6 opened this issue Aug 31, 2022 · 5 comments · Fixed by #28335
Assignees: krfricke
Labels: bug (Something that is supposed to be working; but isn't), P1 (Issue that should be fixed within a few weeks), tune (Tune-related issues)

Comments
marcmk6 commented Aug 31, 2022

What happened + What you expected to happen

Tune stops launching new trials even though all computational resources are free. The program appears to be dead.

The RUNNING trial does not seem to be making progress either: no metrics have been updated for a very long time (much longer than it takes to complete a full trial).

== Status ==
Current time: 2022-08-31 16:10:46 (running for 00:16:33.95)
Memory usage on this node: 5.1/92.2 GiB
Using AsyncHyperBand: num_stopped=9
Bracket: Iter 8.000: -0.1386803835630417 | Iter 4.000: -0.12189790233969688 | Iter 2.000: -0.13475658744573593 | Iter 1.000: -0.1955309510231018
Resources requested: 1.0/48 CPUs, 0/0 GPUs, 0.0/54.8 GiB heap, 0.0/27.4 GiB objects
Current best trial: 9942d_00009 with loss=0.13544251024723053 and parameters={'layer_1_size': 64, 'layer_2_size': 64, 'lr': 0.002441834755046206, 'batch_size': 32}
Result logdir: /root/ray_results/tune_mnist_asha
Number of trials: 10/10 (1 RUNNING, 9 TERMINATED)
+------------------------------+------------+----------------------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------+
| Trial name                   | status     | loc                  |   layer_1_size |   layer_2_size |          lr |   batch_size |     loss |   mean_accuracy |   training_iteration |
|------------------------------+------------+----------------------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------|
| train_mnist_tune_9942d_00008 | RUNNING    | 192.168.129.251:6984 |             64 |             64 | 0.00179496  |           32 |          |                 |                      |
| train_mnist_tune_9942d_00000 | TERMINATED | 192.168.129.251:6936 |            128 |            128 | 0.000983265 |           64 | 0.143191 |        0.973497 |                   10 |
| train_mnist_tune_9942d_00001 | TERMINATED | 192.168.129.251:6966 |             32 |             64 | 0.00461906  |          128 | 0.17555  |        0.949609 |                    1 |
| train_mnist_tune_9942d_00002 | TERMINATED | 192.168.129.251:6968 |            128 |             64 | 0.000460251 |           32 | 0.140191 |        0.967158 |                    8 |
| train_mnist_tune_9942d_00003 | TERMINATED | 192.168.129.251:6970 |             32 |            128 | 0.000137986 |           64 | 0.329587 |        0.903877 |                    1 |
| train_mnist_tune_9942d_00004 | TERMINATED | 192.168.129.251:6972 |            128 |             64 | 0.000414967 |          128 | 0.224804 |        0.931641 |                    1 |
| train_mnist_tune_9942d_00005 | TERMINATED | 192.168.129.251:6978 |            128 |            256 | 0.0140905   |           64 | 0.40053  |        0.868078 |                    1 |
| train_mnist_tune_9942d_00006 | TERMINATED | 192.168.129.251:6980 |             64 |            128 | 0.00934191  |           32 | 0.276521 |        0.924164 |                    1 |
| train_mnist_tune_9942d_00007 | TERMINATED | 192.168.129.251:6982 |             32 |             64 | 0.00101897  |           32 | 0.150974 |        0.953225 |                    2 |
| train_mnist_tune_9942d_00009 | TERMINATED | 192.168.129.251:6986 |             64 |             64 | 0.00244183  |           32 | 0.135443 |        0.960589 |                    2 |
+------------------------------+------------+----------------------+----------------+----------------+-------------+--------------+----------+-----------------+----------------------+

Versions / Dependencies

pytorch-lightning       1.7.3
torch                   1.10.0
ray                     2.0.0

Python 3.7.11

Ubuntu 18.04

Reproduction script

Just run the official example https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/mnist_pytorch_lightning.py:

python mnist_pytorch_lightning.py

Issue Severity

High: It blocks me from completing my task.

@marcmk6 marcmk6 added bug Something that is supposed to be working; but isn't triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Aug 31, 2022
@marcmk6 marcmk6 changed the title [tune] tune stops to start new trials [tune] tune stops to run trials Aug 31, 2022
@marcmk6 marcmk6 changed the title [tune] tune stops to run trials [tune] tune stops to run trials in Pytorch-lightning example Aug 31, 2022
@marcmk6 marcmk6 changed the title [tune] tune stops to run trials in Pytorch-lightning example [tune] tune stops to start new trials Sep 2, 2022
@krfricke krfricke self-assigned this Sep 6, 2022
@krfricke krfricke added P1 Issue that should be fixed within a few weeks tune Tune-related issues and removed triage Needs triage (eg: priority, bug/not-bug, and owning component) labels Sep 6, 2022
krfricke (Contributor) commented Sep 6, 2022

Does this always happen for you? Can you use py-spy to see where exactly the program is hanging?
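For reference, one standard way to capture this (assuming py-spy is installed via pip and <trial PID> is taken from the loc column of the status table):

pip install py-spy
# print the current stack of every thread in the hanging trial process
py-spy dump --pid <trial PID>
# or record a flame graph over ~30 seconds
py-spy record --pid <trial PID> --output flamegraph.svg --duration 30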

marcmk6 (Author) commented Sep 7, 2022

Hi Kai,

Please download this py-spy flame graph for an interactive view.

I tried the 10-trial tuning several times but wasn't able to reproduce it there. However, the problem occurred on every 50-trial run that I started:

    if args.smoke_test:
        tune_mnist_asha(num_samples=1, num_epochs=6, gpus_per_trial=0, data_dir=data_dir)
        tune_mnist_pbt(num_samples=1, num_epochs=6, gpus_per_trial=0, data_dir=data_dir)
    else:
        # ASHA scheduler
        tune_mnist_asha(num_samples=50, num_epochs=10, gpus_per_trial=0, data_dir=data_dir)
        # Population based training
        # tune_mnist_pbt(num_samples=10, num_epochs=10, gpus_per_trial=0, data_dir=data_dir)

krfricke (Contributor) commented Sep 7, 2022

I can reproduce the error (sometimes...).

The main training thread of a hanging process looks like this:

Thread 46278 (idle): "Thread-19"
    poll (multiprocessing/popen_fork.py:27)
    wait (multiprocessing/popen_fork.py:43)
    join (multiprocessing/process.py:149)
    _terminate_pool (multiprocessing/pool.py:729)
    __call__ (multiprocessing/util.py:224)
    terminate (multiprocessing/pool.py:654)
    __exit__ (multiprocessing/pool.py:736)
    num_cuda_devices (pytorch_lightning/utilities/device_parser.py:348)
    is_available (pytorch_lightning/accelerators/cuda.py:91)
    _log_device_info (pytorch_lightning/trainer/trainer.py:1740)
    _setup_on_init (pytorch_lightning/trainer/trainer.py:619)
    __init__ (pytorch_lightning/trainer/trainer.py:534)
    insert_env_defaults (pytorch_lightning/utilities/argparse.py:345)
    train_mnist_tune_checkpoint (3580892420.py:30)
    inner (ray/tune/trainable/util.py:359)
    _trainable_func (ray/tune/trainable/function_trainable.py:684)
    _resume_span (ray/util/tracing/tracing_helper.py:466)
    entrypoint (ray/tune/trainable/function_trainable.py:362)
    run (ray/tune/trainable/function_trainable.py:289)
    _bootstrap_inner (threading.py:954)
    _bootstrap (threading.py:912)

This hangs in PyTorch Lightning's device lookup here: https://github.com/Lightning-AI/lightning/blob/master/src/pytorch_lightning/utilities/device_parser.py#L334

The respective process looks like this:

Thread 46190 (idle): "Thread-11"
    _blocking (grpc/_channel.py:933)
    __call__ (grpc/_channel.py:944)
    close (ray/_private/gcs_pubsub.py:281)
    disconnect (ray/_private/worker.py:2118)
    shutdown (ray/_private/worker.py:1581)
    wrapper (ray/_private/client_mode_hook.py:105)
    sigterm_handler (ray/_private/worker.py:752)
    __getattr__ (ray/_private/utils.py:433)
    _flush_std_streams (multiprocessing/util.py:435)
    _bootstrap (multiprocessing/process.py:335)
    _launch (multiprocessing/popen_fork.py:71)
    __init__ (multiprocessing/popen_fork.py:19)
    _Popen (multiprocessing/context.py:277)
    start (multiprocessing/process.py:121)
    _repopulate_pool_static (multiprocessing/pool.py:326)
    _repopulate_pool (multiprocessing/pool.py:303)
    __init__ (multiprocessing/pool.py:212)
    Pool (multiprocessing/context.py:119)
    num_cuda_devices (pytorch_lightning/utilities/device_parser.py:347)
    is_available (pytorch_lightning/accelerators/cuda.py:91)
    _log_device_info (pytorch_lightning/trainer/trainer.py:1740)
    _setup_on_init (pytorch_lightning/trainer/trainer.py:619)
    __init__ (pytorch_lightning/trainer/trainer.py:534)
    insert_env_defaults (pytorch_lightning/utilities/argparse.py:345)
    train_mnist_tune_checkpoint (3580892420.py:30)
    inner (ray/tune/trainable/util.py:359)
    _trainable_func (ray/tune/trainable/function_trainable.py:684)
    _resume_span (ray/util/tracing/tracing_helper.py:466)
    entrypoint (ray/tune/trainable/function_trainable.py:362)
    run (ray/tune/trainable/function_trainable.py:289)
    _bootstrap_inner (threading.py:954)
    _bootstrap (threading.py:912)

Generally, it seems that the multiprocessing pool interacts with Ray (i.e. the forked child triggers Ray's worker exit path), which then blocks a GCS update forever because the main process does not exit. I'll try to create a minimal repro for this; a rough sketch of the suspected interaction follows.
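The sketch below is only an illustration of the suspected failure mode, not the actual reproduction script used for the fix; its names and structure are assumptions. It mimics what PyTorch Lightning 1.7 does in num_cuda_devices(): it creates a fork-based multiprocessing pool inside a Ray worker, whose forked children inherit the Ray worker state and may run Ray's shutdown/signal handlers on exit.

import multiprocessing
import os

import ray


@ray.remote
def uses_fork_pool():
    # PL 1.7 queries CUDA device counts via a fork-based pool. The forked
    # child inherits the Ray worker's state, and on exit it can end up in
    # Ray's SIGTERM/shutdown path (see the second py-spy stack above),
    # which is where the hang appears to originate.
    with multiprocessing.get_context("fork").Pool(1) as pool:
        return pool.apply(os.getpid)


if __name__ == "__main__":
    ray.init(num_cpus=2)
    # Run the task repeatedly; the hang is intermittent.
    print(ray.get([uses_fork_pool.remote() for _ in range(20)]))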

krfricke (Contributor) commented Sep 7, 2022

In the meantime, you should be able to use this as a workaround:

import ray
ray.init(runtime_env={"env_vars": {"PL_DISABLE_FORK": "1"}})

(add somewhere at the top)
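A note on why this goes through runtime_env rather than plain os.environ (my understanding, not stated in the thread): the trials run in separate Ray worker processes, so setting the variable only in the driver would not reach them. Passing it via runtime_env ensures every Tune trial sees PL_DISABLE_FORK=1, which should make PyTorch Lightning avoid the fork-based CUDA device check that hangs here.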

marcmk6 (Author) commented Sep 14, 2022

In the meantime, you should be able to use this as a workaround:

import ray
ray.init(runtime_env={"env_vars": {"PL_DISABLE_FORK": "1"}})

(add somewhere at the top)

Yes it works!
