Llama 3.1 8b f16 sharded TP8 compiles but fails to run #19428
I took a quick look at this. It looks like data is being copied between devices by simply calling device.queue_copy (or similar) with a reference to a buffer created on one device, passed to another device that performs the copy. (This is based on observed behavior; I haven't looked at the output IR.) I think this is what is happening because if I enable peering across all the devices that are created, the problem goes away (since device pointers are then valid across devices). However, from my understanding of the API, this is not really how we expect it to be used. Instead of 8 logical devices (each backed by a single physical device), we should be using a single logical device backed by 8 physical devices (where each physical device is represented by a queue on the logical device). This allows us to set up peering explicitly between the requested devices. @benvanik Is my understanding of this correct?
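A minimal C model of the failure mode described in the comment above, under stated assumptions: `device_buffer_t`, `peer_state_t`, and `raw_pointer_usable` are hypothetical names, not IREE or HIP APIs. The model only captures the rule that a raw device pointer handed to another device's queue is safe to use there if that device owns the allocation or has been granted peer access to the owner, which would explain why enabling peering everywhere makes the failure disappear.

```c
#include <assert.h>
#include <stdbool.h>

#define MAX_DEVICES 8

/* Hypothetical buffer handle that records which device allocated it. */
typedef struct {
  int owner_device;
} device_buffer_t;

/* Hypothetical model of driver peer state: peer[a][b] is true once device
 * `a` has been granted access to memory owned by device `b`. */
typedef struct {
  bool peer[MAX_DEVICES][MAX_DEVICES];
} peer_state_t;

/* Returns true if `executing_device` may dereference `buf` through a raw
 * device pointer: either it owns the buffer or it is peered with the owner. */
static bool raw_pointer_usable(const peer_state_t *state, int executing_device,
                               const device_buffer_t *buf) {
  assert(executing_device >= 0 && executing_device < MAX_DEVICES);
  assert(buf->owner_device >= 0 && buf->owner_device < MAX_DEVICES);
  if (executing_device == buf->owner_device) return true;
  return state->peer[executing_device][buf->owner_device];
}
```

In this model a copy issued on device 0's queue against a buffer owned by device 3 is only valid after `peer[0][3]` is set, mirroring the "peer everything" workaround.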
https://gist.github.com/AWoloszyn/630716214cf64e936c4c6cca077c2752
We want to be using multiple physical devices via one logical HAL device in most cases, but all existing frameworks do things the other way around today, so we'll likely have to support both well enough :( I believe the API we want to use is cuMemcpyPeerAsync, which supports device-to-device transfers of pointers that are local to each device. #19160 (which I need to land!) adds tracking of which device and queue affinity a buffer was allocated for, and that could be used to detect peer transfers and route them to cuMemcpyPeerAsync when appropriate. (This support would be a new feature rather than a bug fix, so it's something that can happen after the bigger outstanding rewrite PR.)
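A hypothetical sketch of the routing idea in the comment above: if each buffer records the device and queue it was allocated for (as #19160 is described as adding), a cross-device copy can be detected by comparing the two devices and routed to a peer copy. `buffer_affinity_t`, `copy_route_t`, and `choose_copy_route` are illustrative names only, not IREE's actual types, and no driver calls are made.

```c
#include <assert.h>

/* Hypothetical record of where a buffer was allocated. */
typedef struct {
  int device;
  int queue;
} buffer_affinity_t;

typedef enum {
  COPY_LOCAL, /* source and target live on the same device */
  COPY_PEER,  /* cross-device: route to a peer copy such as cuMemcpyPeerAsync */
} copy_route_t;

/* Only the device decides the route in this sketch: two queues on the same
 * device are assumed to share an address space, so that is a local copy. */
static copy_route_t choose_copy_route(buffer_affinity_t src,
                                      buffer_affinity_t dst) {
  assert(src.device >= 0 && dst.device >= 0);
  return src.device == dst.device ? COPY_LOCAL : COPY_PEER;
}
```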
I dug through HIP to figure out where the peer copy (hipMemcpyPeerAsync, the HIP counterpart of cuMemcpyPeerAsync) bottoms out, and it looks like it's here: https://github.com/ROCm/clr/blob/7c9c7a6332f69f740f59aaaeb83ad6eeff4598d6/rocclr/device/rocm/rocvirtual.cpp#L2240 If the devices are peered, it just does a normal copy, which explains why your patch works! Beyond the memory allocation overhead, I'm honestly not sure what the risk is of enabling peering between devices we may never use as peers, since the virtual address spaces are unified (though that may be the case anyway). So if we find that peering has no cost when it isn't used, we could just peer all the devices.
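If we do end up peering all the devices, note that in HIP peer access is directional: with device `a` current (`hipSetDevice(a)`), `hipDeviceEnablePeerAccess(b, 0)` lets `a` access `b`'s allocations. A full mesh over N devices therefore needs N × (N − 1) enable calls, i.e. 56 for the 8 GPUs in this issue. The sketch below only models that bookkeeping; `peer_mesh_t` and `enable_full_mesh` are hypothetical, and no HIP calls are actually made.

```c
#include <assert.h>
#include <stdbool.h>

#define MAX_DEVICES 8

/* access[a][b]: device `a` may access allocations owned by device `b`.
 * A real runtime would set this via hipSetDevice(a) followed by
 * hipDeviceEnablePeerAccess(b, 0); here it is just a flag. */
typedef struct {
  bool access[MAX_DEVICES][MAX_DEVICES];
} peer_mesh_t;

/* Enables every directed pair among the first `device_count` devices and
 * returns how many enable calls were needed (already-enabled pairs are
 * skipped, so calling it twice returns 0 the second time). */
static int enable_full_mesh(peer_mesh_t *mesh, int device_count) {
  assert(device_count >= 0 && device_count <= MAX_DEVICES);
  int calls = 0;
  for (int a = 0; a < device_count; ++a) {
    for (int b = 0; b < device_count; ++b) {
      if (a == b || mesh->access[a][b]) continue;
      mesh->access[a][b] = true;
      ++calls;
    }
  }
  return calls;
}
```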
What happened?
Compiling and running the Llama 3.1 8B f16 TP8 IR with iree-run-module, we get this error:
Steps to reproduce your issue
```shell
../iree-build-no-trace/tools/iree-compile 8b_f16_bs4_tp8_nondecomposed_prefill.mlir \
  --iree-hip-target=gfx942 \
  -o=prefill_8b_tp8_12_5.vmfb \
  --iree-hal-target-device=hip[0] \
  --iree-hal-target-device=hip[1] \
  --iree-hal-target-device=hip[2] \
  --iree-hal-target-device=hip[3] \
  --iree-hal-target-device=hip[4] \
  --iree-hal-target-device=hip[5] \
  --iree-hal-target-device=hip[6] \
  --iree-hal-target-device=hip[7] \
  --iree-dispatch-creation-enable-aggressive-fusion=true \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-aggressively-propagate-transposes=true \
  --iree-opt-data-tiling=false \
  --iree-preprocessing-pass-pipeline='builtin.module(util.func(iree-preprocessing-generalize-linalg-matmul-experimental))' \
  --iree-hal-indirect-command-buffers=true \
  --iree-stream-resource-memory-model=discrete \
  --iree-hip-legacy-sync=false \
  --iree-hal-memoization=true \
  --iree-opt-strip-assertions

ROCR_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ../iree-build-no-trace/tools/iree-run-module \
  --hip_use_streams=true \
  --device_allocator=caching \
  --module=prefill_8b_tp8_12_5.vmfb \
  --parameters=model=llama3.1_8b_fp16_tp8_parameters.irpa \
  --parameters=model=llama3.1_8b_fp16_tp8_parameters.rank0.irpa \
  --parameters=model=llama3.1_8b_fp16_tp8_parameters.rank1.irpa \
  --parameters=model=llama3.1_8b_fp16_tp8_parameters.rank2.irpa \
  --parameters=model=llama3.1_8b_fp16_tp8_parameters.rank3.irpa \
  --parameters=model=llama3.1_8b_fp16_tp8_parameters.rank4.irpa \
  --parameters=model=llama3.1_8b_fp16_tp8_parameters.rank5.irpa \
  --parameters=model=llama3.1_8b_fp16_tp8_parameters.rank6.irpa \
  --parameters=model=llama3.1_8b_fp16_tp8_parameters.rank7.irpa \
  --device=hip://0 \
  --device=hip://1 \
  --device=hip://2 \
  --device=hip://3 \
  --device=hip://4 \
  --device=hip://5 \
  --device=hip://6 \
  --device=hip://7 \
  --function=prefill_bs4 \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/random_tokens.npy \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/seq_lens.npy \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/seq_block_ids.npy \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/cs_f16_shard_0.npy \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/cs_f16_shard_1.npy \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/cs_f16_shard_2.npy \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/cs_f16_shard_3.npy \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/cs_f16_shard_4.npy \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/cs_f16_shard_5.npy \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/cs_f16_shard_6.npy \
  --input=@/data/llama3.1/weights/8b/prefill_args_bs4_128/cs_f16_shard_7.npy
```
What component(s) does this issue relate to?
Runtime
Version information
7dd6fa6 + commits from #18790