Bug: Row Split Mode - Segmentation fault after model load on ROCm multi-gpu #9761

Closed
thamwangjun opened this issue Oct 6, 2024 · 15 comments
Labels: bug-unconfirmed, critical severity, stale

@thamwangjun

What happened?

I am running on ROCm with 4 x Instinct MI100.
Only when using --split-mode row do I get an address boundary error.
llama.cpp was working when I had an XGMI GPU bridge connecting the 4 cards, but the bridge is now broken and I am trying to run only via PCIe.
My setup currently passes the ROCm validation suite.

llama-server --host 0.0.0.0 --p…' terminated by signal SIGSEGV (Address boundary error)

Name and Version

llama-cli --version
version: 3889 (b6d6c52)
built with cc (GCC) 14.2.1 20240910 for x86_64-pc-linux-gnu

llama-server --version
version: 3889 (b6d6c52)
built with cc (GCC) 14.2.1 20240910 for x86_64-pc-linux-gnu

What operating system are you seeing the problem on?

Linux

Relevant log output

ggml_cuda_init: GGML_CUDA_FORCE_MMQ:    no
ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no
ggml_cuda_init: found 4 ROCm devices:
  Device 0: AMD Instinct MI100, compute capability 9.0, VMM: no
  Device 1: AMD Instinct MI100, compute capability 9.0, VMM: no
  Device 2: AMD Instinct MI100, compute capability 9.0, VMM: no
  Device 3: AMD Instinct MI100, compute capability 9.0, VMM: no
build: 3889 (b6d6c528) with cc (GCC) 14.2.1 20240910 for x86_64-pc-linux-gnu
system info: n_threads = 16, n_threads_batch = 16, total_threads = 32

system_info: n_threads = 16 (n_threads_batch = 16) / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | SVE = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | RISCV_VECT = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 

main: HTTP server is listening, hostname: 0.0.0.0, port: 40480, http threads: 31
main: loading model
llama_model_loader: loaded meta data with 33 key-value pairs and 724 tensors from models/Meta-Llama-3.1-70B-Instruct-Q4_K_L.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Meta Llama 3.1 70B Instruct
llama_model_loader: - kv   3:                           general.finetune str              = Instruct
llama_model_loader: - kv   4:                           general.basename str              = Meta-Llama-3.1
llama_model_loader: - kv   5:                         general.size_label str              = 70B
llama_model_loader: - kv   6:                            general.license str              = llama3.1
llama_model_loader: - kv   7:                               general.tags arr[str,6]       = ["facebook", "meta", "pytorch", "llam...
llama_model_loader: - kv   8:                          general.languages arr[str,8]       = ["en", "de", "fr", "it", "pt", "hi", ...
llama_model_loader: - kv   9:                          llama.block_count u32              = 80
llama_model_loader: - kv  10:                       llama.context_length u32              = 131072
llama_model_loader: - kv  11:                     llama.embedding_length u32              = 8192
llama_model_loader: - kv  12:                  llama.feed_forward_length u32              = 28672
llama_model_loader: - kv  13:                 llama.attention.head_count u32              = 64
llama_model_loader: - kv  14:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv  15:                       llama.rope.freq_base f32              = 500000.000000
llama_model_loader: - kv  16:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  17:                          general.file_type u32              = 15
llama_model_loader: - kv  18:                           llama.vocab_size u32              = 128256
llama_model_loader: - kv  19:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv  20:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  21:                         tokenizer.ggml.pre str              = llama-bpe
llama_model_loader: - kv  22:                      tokenizer.ggml.tokens arr[str,128256]  = ["!", "\"", "#", "$", "%", "&", "'", ...
llama_model_loader: - kv  23:                  tokenizer.ggml.token_type arr[i32,128256]  = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  24:                      tokenizer.ggml.merges arr[str,280147]  = ["Ġ Ġ", "Ġ ĠĠĠ", "ĠĠ ĠĠ", "...
llama_model_loader: - kv  25:                tokenizer.ggml.bos_token_id u32              = 128000
llama_model_loader: - kv  26:                tokenizer.ggml.eos_token_id u32              = 128009
llama_model_loader: - kv  27:                    tokenizer.chat_template str              = {{- bos_token }}\n{%- if custom_tools ...
llama_model_loader: - kv  28:               general.quantization_version u32              = 2
llama_model_loader: - kv  29:                      quantize.imatrix.file str              = /models_out/Meta-Llama-3.1-70B-Instru...
llama_model_loader: - kv  30:                   quantize.imatrix.dataset str              = /training_dir/calibration_datav3.txt
llama_model_loader: - kv  31:             quantize.imatrix.entries_count i32              = 560
llama_model_loader: - kv  32:              quantize.imatrix.chunks_count i32              = 125
llama_model_loader: - type  f32:  162 tensors
llama_model_loader: - type q8_0:    2 tensors
llama_model_loader: - type q4_K:  440 tensors
llama_model_loader: - type q5_K:   40 tensors
llama_model_loader: - type q6_K:   80 tensors
llm_load_vocab: special tokens cache size = 256
llm_load_vocab: token to piece cache size = 0.7999 MB
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 128256
llm_load_print_meta: n_merges         = 280147
llm_load_print_meta: vocab_only       = 0
llm_load_print_meta: n_ctx_train      = 131072
llm_load_print_meta: n_embd           = 8192
llm_load_print_meta: n_layer          = 80
llm_load_print_meta: n_head           = 64
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_swa            = 0
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 8
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 28672
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 500000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_ctx_orig_yarn  = 131072
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 0
llm_load_print_meta: ssm_d_inner      = 0
llm_load_print_meta: ssm_d_state      = 0
llm_load_print_meta: ssm_dt_rank      = 0
llm_load_print_meta: ssm_dt_b_c_rms   = 0
llm_load_print_meta: model type       = 70B
llm_load_print_meta: model ftype      = Q4_K - Medium
llm_load_print_meta: model params     = 70.55 B
llm_load_print_meta: model size       = 40.32 GiB (4.91 BPW) 
llm_load_print_meta: general.name     = Meta Llama 3.1 70B Instruct
llm_load_print_meta: BOS token        = 128000 '<|begin_of_text|>'
llm_load_print_meta: EOS token        = 128009 '<|eot_id|>'
llm_load_print_meta: LF token         = 128 'Ä'
llm_load_print_meta: EOT token        = 128009 '<|eot_id|>'
llm_load_print_meta: EOM token        = 128008 '<|eom_id|>'
llm_load_print_meta: EOG token        = 128008 '<|eom_id|>'
llm_load_print_meta: EOG token        = 128009 '<|eot_id|>'
llm_load_print_meta: max token length = 256
llm_load_tensors: ggml ctx size =    1.02 MiB
llm_load_tensors: offloading 80 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 81/81 layers to GPU
llm_load_tensors: ROCm_Split buffer size = 40217.12 MiB
llm_load_tensors:      ROCm0 buffer size =     5.05 MiB
llm_load_tensors:        CPU buffer size =  1064.62 MiB
.................................................................................................
llama_new_context_with_model: n_ctx      = 131072
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 1
llama_new_context_with_model: freq_base  = 500000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      ROCm0 KV buffer size = 11520.00 MiB
llama_new_context_with_model: KV self size  = 11520.00 MiB, K (q4_0): 5760.00 MiB, V (q4_0): 5760.00 MiB
llama_new_context_with_model:  ROCm_Host  output buffer size =     0.98 MiB
llama_new_context_with_model:      ROCm0 compute buffer size =   448.00 MiB
llama_new_context_with_model:  ROCm_Host compute buffer size =   272.01 MiB
llama_new_context_with_model: graph nodes  = 2247
llama_new_context_with_model: graph splits = 2
llama_init_from_gpt_params: warming up the model with an empty run - please wait ... (--no-warmup to disable)
fish: Job 1, 'llama-server --host 0.0.0.0 --p…' terminated by signal SIGSEGV (Address boundary error)
@thamwangjun added the bug-unconfirmed and critical severity labels on Oct 6, 2024
@JohannesGaessler (Collaborator)

Did this work with a previous llama.cpp version? If so, with which commit did it stop working?

@thamwangjun (Author)

@JohannesGaessler This is a recent machine, so I can't say. At b3870 it was working fine, but only while AMD's XGMI GPU interconnect (which allows peer-to-peer GPU communication) was in place. Now that I have removed the bridge, b3870 no longer works either.
Note that this segfault only happens when the split mode is row. In layer mode it works fine (but with much lower performance).
I will try to gather more data, since it is very reproducible on my end.
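
For reference, the only difference between the failing and the working run is the split mode flag. The full command line is truncated in the crash message above, so the other arguments below are illustrative placeholders:

```sh
# Crashes with SIGSEGV during warmup on this setup (arguments other than the
# split mode are placeholders; the real command line is truncated above):
llama-server --host 0.0.0.0 -m models/Meta-Llama-3.1-70B-Instruct-Q4_K_L.gguf \
    --n-gpu-layers 99 --split-mode row

# Works on the same setup, but with much lower performance:
llama-server --host 0.0.0.0 -m models/Meta-Llama-3.1-70B-Instruct-Q4_K_L.gguf \
    --n-gpu-layers 99 --split-mode layer
```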

@thamwangjun (Author)

@JohannesGaessler I have a core dump, would it help?

@JohannesGaessler (Collaborator)

A core dump would probably not be of much use. If it worked with the physical link, the problem likely has to do with peer access getting automatically enabled/disabled based on the HIP implementation of cudaDeviceCanAccessPeer, and depending on that state there is likely a segmentation fault during one of the memcpys between devices.

For debugging I would like you to try the following two edits to ggml_cuda_set_peer_access in ggml/src/ggml-cuda.cu:

  1. Remove everything in the #ifdef NDEBUG block.
  2. Add a small delay when toggling peer access, like in this PR (a sketch of this kind of change follows after this comment).

In principle there could also be issues if multiple threads were to enter that function at the same time, but to my knowledge that shouldn't be happening (@slaren correct me if I'm wrong).
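
A minimal sketch of the kind of change being suggested, written against the HIP API directly (the function name, loop structure, and delay length are illustrative, not the actual ggml_cuda_set_peer_access code):

```cpp
#include <hip/hip_runtime.h>
#include <unistd.h> // usleep

// Toggle peer access between the main device and every other device,
// pausing briefly after each enable/disable so the runtime can settle
// before the next cross-device memcpy.
static void set_peer_access_with_delay(int n_devices, int main_device, bool enable) {
    for (int id = 0; id < n_devices; ++id) {
        hipSetDevice(id);
        for (int id_other = 0; id_other < n_devices; ++id_other) {
            // only consider pairs that involve the main device
            if (id == id_other || (id != main_device && id_other != main_device)) {
                continue;
            }
            int can_access_peer = 0;
            hipDeviceCanAccessPeer(&can_access_peer, id, id_other);
            if (!can_access_peer) {
                continue;
            }
            if (enable) {
                // hipErrorPeerAccessAlreadyEnabled is harmless here
                hipDeviceEnablePeerAccess(id_other, 0);
            } else {
                hipDeviceDisablePeerAccess(id_other);
            }
            usleep(1000); // the "small delay" after toggling, as suggested above
        }
    }
}
```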

thamwangjun added a commit to thamwangjun/ggerganov-llama.cpp that referenced this issue Oct 8, 2024
@thamwangjun (Author)

@JohannesGaessler The proposed fix did not work; I forked and tried it on a fresh repo. I also tried --threads 1 and it hit the same address boundary error. Is there anything else I can try?

@JohannesGaessler (Collaborator)

I don't know anything else to try. You can upload a core dump, but realistically I don't think this will be fixed in the foreseeable future. AMD hardware is, truth be told, quite poorly supported.

@thamwangjun (Author)

@JohannesGaessler I got a backtrace; I'm not sure if you can interpret it:

This is on a build at dca1d4b.

Thread 1 "llama-server" received signal SIGSEGV, Segmentation fault.
0x00007fff9fb241c2 in ?? () from /opt/rocm/lib/libamdhip64.so.6
(gdb) backtrace
#0  0x00007fff9fb241c2 in ?? () from /opt/rocm/lib/libamdhip64.so.6
#1  0x00007fff9fb2d3ff in ?? () from /opt/rocm/lib/libamdhip64.so.6
#2  0x00007fff9fb2ea07 in ?? () from /opt/rocm/lib/libamdhip64.so.6
#3  0x00007fff9fb0b484 in ?? () from /opt/rocm/lib/libamdhip64.so.6
#4  0x00007fff9f9c0618 in ?? () from /opt/rocm/lib/libamdhip64.so.6
#5  0x00007fff9f9c092d in ?? () from /opt/rocm/lib/libamdhip64.so.6
#6  0x00007fff9f9ca60d in hipMemcpy2DAsync () from /opt/rocm/lib/libamdhip64.so.6
#7  0x00007ffff77b1100 in ggml_cuda_Memcpy2DPeerAsync (dst=0x0, dpitch=<optimized out>, src=<optimized out>, srcDevice=<optimized out>, spitch=<optimized out>, width=<optimized out>, height=<optimized out>, stream=<optimized out>,
    dstDevice=<optimized out>) at /home/thamw/development/github/llama.cpp/ggml/src/ggml-cuda.cu:1354
#8  ggml_cuda_op_mul_mat (ctx=..., src0=<optimized out>, src1=<optimized out>, dst=<optimized out>, op=<optimized out>, quantize_src1=<optimized out>) at /home/thamw/development/github/llama.cpp/ggml/src/ggml-cuda.cu:1625
#9  ggml_cuda_mul_mat (ctx=..., src0=<optimized out>, src0@entry=0x55557f5befc0, src1=<optimized out>, dst=<optimized out>, dst@entry=0x55557be5ae10) at /home/thamw/development/github/llama.cpp/ggml/src/ggml-cuda.cu:1941
#10 0x00007ffff77acbd2 in ggml_cuda_compute_forward (ctx=..., dst=0x55557be5ae10) at /home/thamw/development/github/llama.cpp/ggml/src/ggml-cuda.cu:2259
#11 ggml_backend_cuda_graph_compute (backend=<optimized out>, cgraph=0x5555809e3538) at /home/thamw/development/github/llama.cpp/ggml/src/ggml-cuda.cu:2667
#12 0x00007ffff771d6d9 in ggml_backend_graph_compute_async (backend=0x55557b146840, cgraph=0x5555809e3538) at /home/thamw/development/github/llama.cpp/ggml/src/ggml-backend.cpp:315
#13 0x00007ffff77229d0 in ggml_backend_sched_compute_splits (sched=0x555580315b30) at /home/thamw/development/github/llama.cpp/ggml/src/ggml-backend.cpp:2139
#14 0x00007ffff772364d in ggml_backend_sched_graph_compute_async (sched=0x555580315b30, graph=0x55557be19ab0) at /home/thamw/development/github/llama.cpp/ggml/src/ggml-backend.cpp:2327
#15 0x00007ffff7c7e96b in llama_graph_compute (lctx=..., gf=0x55557be19ab0, n_threads=16, threadpool=0x0) at /home/thamw/development/github/llama.cpp/src/llama.cpp:17056
#16 0x00007ffff7c7f4ee in llama_decode_internal (lctx=..., batch_all=...) at /home/thamw/development/github/llama.cpp/src/llama.cpp:17243
#17 0x00007ffff7c8da0d in llama_decode (ctx=0x5555805dd520, batch=...) at /home/thamw/development/github/llama.cpp/src/llama.cpp:21141
#18 0x00005555557109f9 in llama_init_from_gpt_params (params=...) at /home/thamw/development/github/llama.cpp/common/common.cpp:951
#19 0x0000555555609cd3 in server_context::load_model (this=0x7fffffffcf00, params_=...) at /home/thamw/development/github/llama.cpp/examples/server/server.cpp:671
#20 0x00005555555e093d in main (argc=24, argv=0x7fffffffe548) at /home/thamw/development/github/llama.cpp/examples/server/server.cpp:3339

@thamwangjun (Author)

Which brings us to this line here:

return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);

🤔
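
For what it's worth, the failing pattern can be isolated outside llama.cpp. This is a hypothetical standalone sketch of the same kind of call, a 2D async copy between buffers on two different GPUs; whether it actually faults inside libamdhip64 (as in the backtrace above) instead of returning an error presumably depends on the kernel's P2P support:

```cpp
#include <hip/hip_runtime.h>

int main() {
    const size_t pitch = 4096, height = 128;
    void *src = nullptr, *dst = nullptr;

    hipSetDevice(0);
    hipMalloc(&src, pitch * height);

    hipSetDevice(1);
    hipMalloc(&dst, pitch * height);

    hipStream_t stream;
    hipStreamCreate(&stream);

    // Same shape as the cudaMemcpy2DAsync call in the backtrace: a
    // device-to-device copy whose source lives on a different GPU, which
    // relies on PCIe P2P DMA when no XGMI link is present.
    hipMemcpy2DAsync(dst, pitch, src, pitch, pitch, height,
                     hipMemcpyDeviceToDevice, stream);
    hipStreamSynchronize(stream);

    hipFree(src);
    hipFree(dst);
    return 0;
}
```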

@hjc4869 commented Oct 9, 2024

I've encountered this issue before with multiple Radeon GPUs on Debian. That was caused by the lack of CONFIG_PCI_P2PDMA and CONFIG_HSA_AMD_P2P in kernel config.

It's worth checking with something like rocm-bandwidth-test to see if PCIe P2P DMA is working properly.
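
If rocm-bandwidth-test is not at hand, a small HIP program (a diagnostic sketch, not part of llama.cpp) can also print what the runtime reports for peer access between every pair of devices:

```cpp
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
    int n = 0;
    if (hipGetDeviceCount(&n) != hipSuccess) {
        fprintf(stderr, "failed to enumerate HIP devices\n");
        return 1;
    }
    // For every ordered device pair, report whether peer access is claimed.
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            if (i == j) continue;
            int can = 0;
            hipDeviceCanAccessPeer(&can, i, j);
            printf("device %d -> device %d: peer access %s\n", i, j, can ? "yes" : "NO");
        }
    }
    return 0;
}
```

Build it with something like `hipcc p2p_check.cpp -o p2p_check` and compare the output against what rocm-bandwidth-test shows.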

@Andryusz

@hjc4869 Thanks for the hint - it resolved the issue, at least for me.

By the way, after updating from ROCm 6.0 to 6.2.1 I no longer had a crash in llama.cpp, but the models would just produce garbage in row split mode. Anyway, a custom kernel build with CONFIG_HSA_AMD_P2P enabled resolved the issue completely.

@hjc4869 commented Oct 11, 2024

> @hjc4869 Thanks for the hint - it resolved the issue, at least for me.
>
> By the way, after updating from ROCm 6.0 to 6.2.1 I no longer had a crash in llama.cpp, but the models would just produce garbage in row split mode. Anyway, a custom kernel build with CONFIG_HSA_AMD_P2P enabled resolved the issue completely.

I can repro that locally with my 2x W7900DS setup. It sounds like newer ROCm versions have problems here; I guess the runtime is either not checking P2P DMA availability or not handling the lack of P2P correctly, which causes data corruption.

Maybe we can document these row-split hiccups on ROCm somewhere to save others some headaches.

@thamwangjun (Author)

@hjc4869 thanks as well, I am now building a new kernel with the configs you described. I will report back - I hope this will work! 🤞

@thamwangjun (Author)

Thanks @hjc4869, I have verified that it is working now.
I needed both CONFIG_HSA_AMD_P2P and CONFIG_DMABUF_MOVE_NOTIFY enabled in the newly built Linux kernel for it to work. CONFIG_DMABUF_MOVE_NOTIFY is an experimental flag and is a dependency of CONFIG_HSA_AMD_P2P.
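
For anyone hitting the same problem, the kernel options mentioned in this thread add up to the following .config fragment (collected from the comments above; exact dependencies may differ between kernel versions):

```
CONFIG_PCI_P2PDMA=y
CONFIG_DMABUF_MOVE_NOTIFY=y
CONFIG_HSA_AMD_P2P=y
```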

@hjc4869 Do you know how a PR to document this should be done? Would a new HIP.md in llama.cpp/docs/backend/ be appropriate?

@hjc4869 commented Oct 13, 2024

Glad to hear it's working. I haven't contributed to this project before, so I'm not sure; maybe we can find some examples in already merged PRs.

@github-actions bot added the stale label on Nov 13, 2024
@github-actions (bot)

This issue was closed because it has been inactive for 14 days since being marked as stale.
