
ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task. #8

Closed
SefaZeng opened this issue Feb 12, 2025 · 8 comments

@SefaZeng

SefaZeng commented Feb 12, 2025

I tried to run the run_deepscaler_1.5b_8k.sh script following the steps in the README, but I always encounter this error.
I thought it was due to CUDA OOM, so I tried reducing the length from 8192 to 1024, but the error still persists (the reduced-length overrides are sketched below).
I'm wondering if you've ever encountered this issue and how you resolved it. I am running this code on 8x H20 (96G).
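
For reference, a minimal sketch of the reduced-length invocation, assuming the script wraps verl's main_ppo entry point (as in the traceback further down) and keeping every other override from the error output unchanged:

```bash
# Sketch only: reduced-length run. The entry point and override names are taken
# from the Hydra command visible in the error output below; values reflect my local setup.
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    actor_rollout_ref.model.path=/models/DeepSeek-R1-Distill-Qwen-1.5B \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.95 \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1
    # ...plus the remaining overrides listed in the error output below
```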
Here are the logs:

Model config after override: Qwen2Config {
(WorkerDict pid=440680)   "_name_or_path": "/models/DeepSeek-R1-Distill-Qwen-1.5B",
(WorkerDict pid=440680)   "architectures": [
(WorkerDict pid=440680)     "Qwen2ForCausalLM"
(WorkerDict pid=440680)   ],
(WorkerDict pid=440680)   "attention_dropout": 0.0,
(WorkerDict pid=440680)   "bos_token_id": 151646,
(WorkerDict pid=440680)   "eos_token_id": 151643,
(WorkerDict pid=440680)   "hidden_act": "silu",
(WorkerDict pid=440680)   "hidden_size": 1536,
(WorkerDict pid=440680)   "initializer_range": 0.02,
(WorkerDict pid=440680)   "intermediate_size": 8960,
(WorkerDict pid=440680)   "max_position_embeddings": 131072,
(WorkerDict pid=440680)   "max_window_layers": 21,
(WorkerDict pid=440680)   "model_type": "qwen2",
(WorkerDict pid=440680)   "num_attention_heads": 12,
(WorkerDict pid=440680)   "num_hidden_layers": 28,
(WorkerDict pid=440680)   "num_key_value_heads": 2,
(WorkerDict pid=440680)   "pad_token_id": 151643,
(WorkerDict pid=440680)   "rms_norm_eps": 1e-06,
(WorkerDict pid=440680)   "rope_scaling": null,
(WorkerDict pid=440680)   "rope_theta": 10000,
(WorkerDict pid=440680)   "sliding_window": null,
(WorkerDict pid=440680)   "tie_word_embeddings": false,
(WorkerDict pid=440680)   "torch_dtype": "bfloat16",
(WorkerDict pid=440680)   "transformers_version": "4.47.1",
(WorkerDict pid=440680)   "use_cache": true,
(WorkerDict pid=440680)   "use_mrope": false,
(WorkerDict pid=440680)   "use_sliding_window": false,
(WorkerDict pid=440680)   "vocab_size": 151936
(WorkerDict pid=440680) }
(WorkerDict pid=440680) 
(WorkerDict pid=440680) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
(pid=440907) /usr/local/lib/python3.10/dist-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash: [repeated 6x across cluster]
(pid=440907) No module named 'vllm._version' [repeated 6x across cluster]
(pid=440907)   from vllm.version import __version__ as VLLM_VERSION [repeated 6x across cluster]
(WorkerDict pid=440912) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`. [repeated 7x across cluster]
(WorkerDict pid=440680) Qwen2ForCausalLM contains 1.78B parameters
(WorkerDict pid=440907) Total steps: 9420, num_warmup_steps: 0
(WorkerDict pid=440911) wrap_policy: functools.partial(<function transformer_auto_wrap_policy at 0x7f59b37825f0>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>}) [repeated 15x across cluster]
(WorkerDict pid=440680) Before building vllm rollout, memory allocated (GB): 0.8275494575500488, memory reserved (GB): 6.662109375
(WorkerDict pid=440913) INFO 02-12 03:20:39 config.py:1005] Chunked prefill is enabled with max_num_batched_tokens=8192.
(WorkerDict pid=440913) WARNING 02-12 03:20:39 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
(WorkerDict pid=440680) Actor use_remove_padding=True [repeated 15x across cluster]
(WorkerDict pid=440680) Total steps: 9420, num_warmup_steps: 0 [repeated 7x across cluster]
(WorkerDict pid=440913) local rank 0
(WorkerDict pid=440909) NCCL version 2.20.5+cuda12.4
(WorkerDict pid=440680) before init cache memory allocated: 4.515368448GB, reserved: 4.664066048GB
(WorkerDict pid=440680) after init cache memory allocated: 91.069025792GB, reserved: 91.217723392GB
(WorkerDict pid=440912) INFO 02-12 03:20:39 config.py:1005] Chunked prefill is enabled with max_num_batched_tokens=8192. [repeated 7x across cluster]
(WorkerDict pid=440912) WARNING 02-12 03:20:39 config.py:380] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used [repeated 7x across cluster]
(WorkerDict pid=440912) local rank 0 [repeated 7x across cluster]
(WorkerDict pid=440912) NCCL version 2.20.5+cuda12.4 [repeated 6x across cluster]
(WorkerDict pid=440680) kwargs: {'n': 8, 'logprobs': 1, 'max_tokens': 1024, 'detokenize': False, 'temperature': 0.6, 'top_k': -1, 'top_p': 1, 'ignore_eos': False}
(WorkerDict pid=440680) After building vllm rollout, memory allocated (GB): 81.49942827224731, memory reserved (GB): 84.953125
(WorkerDict pid=440680) After building sharding manager, memory allocated (GB): 81.49942827224731, memory reserved (GB): 84.953125
(WorkerDict pid=440680) /usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=440680)   warnings.warn(
(WorkerDict pid=440911) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)` [repeated 7x across cluster]
(main_task pid=403722) Using LocalLogger is deprecated. The constructor API will change 
(main_task pid=403722) Validation: Generation end.
(WorkerDict pid=440913) kwargs: {'n': 8, 'logprobs': 1, 'max_tokens': 1024, 'detokenize': False, 'temperature': 0.6, 'top_k': -1, 'top_p': 1, 'ignore_eos': False} [repeated 7x across cluster]
(main_task pid=403722) "Initial validation metrics: {'val/test_score/': 0.0}"
(main_task pid=403722) step:0 - val/test_score/:0.000
(WorkerDict pid=440907) *** SIGFPE received at time=1739330472 on cpu 160 ***
(WorkerDict pid=440907) PC: @     0x7eff790ff921  (unknown)  (unknown)
(WorkerDict pid=440907)     @     0x7f2f84ec3520  (unknown)  (unknown)
(WorkerDict pid=440907) [2025-02-12 03:21:12,544 E 440907 440907] logging.cc:440: *** SIGFPE received at time=1739330472 on cpu 160 ***
(WorkerDict pid=440907) [2025-02-12 03:21:12,544 E 440907 440907] logging.cc:440: PC: @     0x7eff790ff921  (unknown)  (unknown)
(WorkerDict pid=440907) [2025-02-12 03:21:12,544 E 440907 440907] logging.cc:440:     @     0x7f2f84ec3520  (unknown)  (unknown)
(WorkerDict pid=440907) Fatal Python error: Floating point exception
(WorkerDict pid=440907) 
(WorkerDict pid=440907) Stack (most recent call first):
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/vocab_parallel_embedding.py", line 40 in apply
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/logits_processor.py", line 83 in _get_logits
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/logits_processor.py", line 61 in forward
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562 in _call_impl
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553 in _wrapped_call_impl
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/qwen2.py", line 424 in compute_logits
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 1698 in execute_model
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner_base.py", line 116 in _wrapper
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
(WorkerDict pid=440907)   File "/rl/deepscaler/verl/verl/third_party/vllm/vllm_v_0_6_3/worker.py", line 267 in execute_model
(WorkerDict pid=440907)   File "/rl/deepscaler/verl/verl/third_party/vllm/vllm_v_0_6_3/spmd_gpu_executor.py", line 163 in execute_model
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 1386 in step
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py", line 879 in _run_engine
(WorkerDict pid=440907)   File "/rl/deepscaler/verl/verl/third_party/vllm/vllm_v_0_6_3/llm.py", line 166 in _run_engine
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py", line 353 in generate
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/vllm/utils.py", line 1063 in inner
(WorkerDict pid=440907)   File "/rl/deepscaler/verl/verl/workers/rollout/vllm_rollout/vllm_rollout.py", line 205 in generate_sequences
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116 in decorate_context
(WorkerDict pid=440907)   File "/rl/deepscaler/verl/verl/workers/fsdp_workers.py", line 443 in generate_sequences
(WorkerDict pid=440907)   File "/rl/deepscaler/verl/verl/single_controller/base/decorator.py", line 404 in inner
(WorkerDict pid=440907)   File "/rl/deepscaler/verl/verl/single_controller/ray/base.py", line 399 in func
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/ray/util/tracing/tracing_helper.py", line 467 in _resume_span
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/ray/_private/function_manager.py", line 696 in actor_method_executor
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 917 in main_loop
(WorkerDict pid=440907)   File "/usr/local/lib/python3.10/dist-packages/ray/_private/workers/default_worker.py", line 289 in <module>
(WorkerDict pid=440907) 
(WorkerDict pid=440907) Extension modules: msgpack._cmsgpack, google._upb._message, psutil._psutil_linux, psutil._psutil_posix, setproctitle, yaml._yaml, zstandard.backend_c, charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, uvloop.loop, ray._raylet, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, markupsafe._speedups, PIL._imaging, sentencepiece._sentencepiece, sklearn.__check_build._check_build, scipy._lib._ccallback_c, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg.cython_blas, scipy.linalg._matfuncs_expm, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.spatial._ckdtree, scipy._lib.messagestream, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.spatial.transform._rotation, scipy.optimize._group_columns, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize._highs.cython.src._highs_wrapper, scipy.optimize._highs._highs_wrapper, scipy.optimize._highs.cython.src._highs_constants, scipy.optimize._highs._highs_constants, scipy.linalg._interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, scipy.optimize._direct, scipy.integrate._odepack, scipy.integrate._quadpack, scipy.integrate._vode, scipy.integrate._dop, scipy.integrate._lsoda, scipy.interpolate._fitpack, scipy.interpolate._dfitpack, scipy.interpolate._bspl, scipy.interpolate._ppoly, scipy.interpolate.interpnd, scipy.interpolate._rbfinterp_pythran, scipy.interpolate._rgi_cython, scipy.special.cython_special, scipy.stats._stats, scipy.stats._biasedurn, scipy.stats._levy_stable.levyst, scipy.stats._stats_pythran, scipy._lib._uarray._uarray, scipy.stats._ansari_swilk_statistics, scipy.stats._sobol, scipy.stats._qmc_cy, scipy.stats._mvn, scipy.stats._rcont.rcont, scipy.stats._unuran.unuran_wrapper, scipy.ndimage._nd_image, _ni_label, scipy.ndimage._ni_label, pyarrow.lib, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, 
pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pyarrow._compute, pandas._libs.ops, pandas._libs.hashing, pandas._libs.arrays, pandas._libs.tslib, pandas._libs.sparse, pandas._libs.internals, pandas._libs.indexing, pandas._libs.index, pandas._libs.writers, pandas._libs.join, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.json, pandas._libs.parsers, pandas._libs.testing, sklearn.utils._isfinite, sklearn.utils.sparsefuncs_fast, sklearn.utils.murmurhash, sklearn.utils._openmp_helpers, sklearn.metrics.cluster._expected_mutual_info_fast, sklearn.preprocessing._csr_polynomial_expansion, sklearn.preprocessing._target_encoder_fast, sklearn.metrics._dist_metrics, sklearn.metrics._pairwise_distances_reduction._datasets_pair, sklearn.utils._cython_blas, sklearn.metrics._pairwise_distances_reduction._base, sklearn.metrics._pairwise_distances_reduction._middle_term_computer, sklearn.utils._heap, sklearn.utils._sorting, sklearn.metrics._pairwise_distances_reduction._argkmin, sklearn.metrics._pairwise_distances_reduction._argkmin_classmode, sklearn.utils._vector_sentinel, sklearn.metrics._pairwise_distances_reduction._radius_neighbors, sklearn.metrics._pairwise_distances_reduction._radius_neighbors_classmode, sklearn.metrics._pairwise_fast, PIL._imagingft, msgspec._core, regex._regex, multidict._multidict, yarl._quoting_c, aiohttp._helpers, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket, frozenlist._frozenlist, pyarrow._json, zmq.backend.cython._zmq (total: 194)
(WorkerDict pid=440913) /usr/local/lib/python3.10/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html . [repeated 7x across cluster]
(WorkerDict pid=440913)   warnings.warn( [repeated 7x across cluster]
(WorkerDict pid=440910) 
(WorkerDict pid=440910) 
(WorkerDict pid=440912) 
(WorkerDict pid=440912) 
(WorkerDict pid=440913) 
(WorkerDict pid=440913) 
(WorkerDict pid=440680) 
(WorkerDict pid=440680) 
(WorkerDict pid=440908) 
(WorkerDict pid=440908) 
(WorkerDict pid=440909) 
(WorkerDict pid=440909) 
(WorkerDict pid=440911) 
(WorkerDict pid=440911) 
(raylet) A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffffb8e3ef059749d0078e996a7501000000 Worker ID: cc53fbe02f29b55d8c2331c76864b2d1485cc4d45e93b48833d1c336 Node ID: 0a41023dfda59bd7d3f79884ab515bd66dc2bbcb947c6685aa354158 Worker IP address: 29.77.196.151 Worker port: 45215 Worker PID: 440907 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
Error executing job with overrides: ['algorithm.adv_estimator=grpo', 'data.train_files=/rl/deepscaler/deepscaler/data/train.parquet', 'data.val_files=/rl/deepscaler/deepscaler/data/aime.parquet', 'data.train_batch_size=128', 'data.val_batch_size=512', 'data.max_prompt_length=1024', 'data.max_response_length=1024', 'actor_rollout_ref.model.path=/models/DeepSeek-R1-Distill-Qwen-1.5B', 'actor_rollout_ref.actor.optim.lr=1e-6', 'actor_rollout_ref.model.use_remove_padding=True', 'actor_rollout_ref.actor.ppo_mini_batch_size=64', 'actor_rollout_ref.actor.use_dynamic_bsz=True', 'actor_rollout_ref.actor.ppo_max_token_len_per_gpu=32768', 'actor_rollout_ref.actor.use_kl_loss=True', 'actor_rollout_ref.actor.kl_loss_coef=0.001', 'actor_rollout_ref.actor.kl_loss_type=low_var_kl', 'actor_rollout_ref.actor.ulysses_sequence_parallel_size=1', 'actor_rollout_ref.model.enable_gradient_checkpointing=True', 'actor_rollout_ref.actor.fsdp_config.param_offload=False', 'actor_rollout_ref.actor.fsdp_config.grad_offload=False', 'actor_rollout_ref.actor.fsdp_config.optimizer_offload=False', 'actor_rollout_ref.rollout.tensor_model_parallel_size=1', 'actor_rollout_ref.rollout.name=vllm', 'actor_rollout_ref.rollout.temperature=0.6', 'actor_rollout_ref.rollout.val_temperature=0.6', 'actor_rollout_ref.rollout.gpu_memory_utilization=0.95', 'actor_rollout_ref.rollout.n=8', 'actor_rollout_ref.rollout.n_val=8', 'actor_rollout_ref.ref.fsdp_config.param_offload=True', 'algorithm.kl_ctrl.kl_coef=0.001', 'trainer.critic_warmup=0', 'trainer.logger=[console]', 'trainer.project_name=deepscaler', 'trainer.experiment_name=deepscaler-1.5b-8k', '+trainer.val_before_train=True', 'trainer.n_gpus_per_node=8', 'trainer.nnodes=1', 'trainer.save_freq=20', 'trainer.test_freq=20', 'trainer.default_hdfs_dir=null', 'trainer.total_epochs=30']
Traceback (most recent call last):
  File "/rl/deepscaler/verl/verl/trainer/main_ppo.py", line 114, in main
    ray.get(main_task.remote(config))
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2745, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 901, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ActorDiedError): ray::main_task() (pid=403722, ip=29.77.196.151)
  File "/rl/deepscaler/verl/verl/trainer/main_ppo.py", line 200, in main_task
    trainer.fit()
  File "/rl/deepscaler/verl/verl/trainer/ppo/ray_trainer.py", line 604, in fit
    gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
  File "/rl/deepscaler/verl/verl/single_controller/ray/base.py", line 42, in func
    output = ray.get(output)
ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task.
        class_name: create_colocated_worker_cls.<locals>.WorkerDict
        actor_id: b8e3ef059749d0078e996a7501000000
        pid: 440907
        name: 0ds8OmWorkerDict_0:1
        namespace: 4064ec1d-2d21-4b8a-a01f-7e1e188025c3
        ip: 29.77.196.151
The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
@michaelzhiluo
Contributor

Try Ray version 2.41 or 2.42?
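
A minimal upgrade sketch (assuming a pip-managed environment; the 2.42.0 pin is just an example, 2.41.x works the same way):

```bash
# Sketch: pin Ray to a newer release in the current environment.
pip install -U "ray[default]==2.42.0"
```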

@SefaZeng
Author

> Try Ray version 2.41 or 2.42?

2.38

@michaelzhiluo
Contributor

Upgrade!

@SefaZeng
Author

> Upgrade!

Thanks! I tried both ray 2.41 and 2.42, but I'm still getting the same error.

@michaelzhiluo
Contributor

What version of vLLM are you using?

@SefaZeng
Author

> What version of vLLM are you using?

0.6.3

@michaelzhiluo
Contributor

You have the exact same version I'm using atm...

Apologies, I'm not sure how to resolve your bug, but if you do, let me know and I can put the fix in a PR.

@SefaZeng
Author

> You have the exact same version I'm using atm...
>
> Apologies, I'm not sure how to resolve your bug, but if you do, let me know and I can put the fix in a PR.

I've found that it's related to the H20 GPUs. It can be fixed by upgrading nvidia-cublas-cu12, following the method in this issue: vllm-project/vllm#4392.

```bash
pip install nvidia-cublas-cu12==12.3.4.1
```
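
A quick sanity check that the pinned wheel is what actually ends up installed (not part of the original fix, just verification):

```bash
# Verify the installed nvidia-cublas-cu12 version after pinning.
pip show nvidia-cublas-cu12 | grep -i '^version'
# or query the package metadata from Python:
python3 -c "import importlib.metadata as m; print(m.version('nvidia-cublas-cu12'))"
```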
