
enable tp on CPU #36299 (Draft)

jiqing-feng wants to merge 7 commits into main
Conversation

@jiqing-feng (Contributor) commented Feb 20, 2025

A CPU device cannot use an index.

If we pass an index for the CPU device, the check will never pass.
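
A minimal illustration of the point above (a sketch for this discussion, not code from the PR; it assumes the check in question compares torch.device objects):

import torch

# A CPU torch.device never reports an index, so a device constructed with an
# index can never compare equal to the plain "cpu" device of a CPU tensor.
print(torch.device("cpu").index)                      # None
print(torch.device("cpu", 0) == torch.device("cpu"))  # False
print(torch.rand(1).device == torch.device("cpu"))    # True -- CPU tensors live on the index-less "cpu" device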

@Rocketknight1 (Member)

Is there a reason we want to support TP on CPU? I assumed it would mainly be useful for multi-GPU nodes.

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
@jiqing-feng (Contributor, Author) commented Feb 21, 2025

Is there a reason we want to support TP on CPU? I assumed it would mainly be useful for multi-GPU nodes.

Intel Xeon CPUs have multiple NUMA nodes, which means we can run a TP model with each shard on a separate NUMA node. Currently we can enable this functionality and select the gloo backend to run a TP model on CPU.

Besides, we should always make sure that a CPU device cannot be assigned an index.
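
For context, a minimal sketch (not part of this PR) of how a CPU-only process group could be initialized with the gloo backend before loading a model with tp_plan="auto"; the ranks would be pinned to separate NUMA nodes via numactl/torchrun, as in the command shown further below:

import torch.distributed as dist

# gloo ships with stock PyTorch and supports CPU tensors, so it can back
# a CPU-only tensor-parallel run (one rank per NUMA node).
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()
print(f"rank {rank}/{world_size} initialized on CPU")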

@Rocketknight1 (Member)

In that case this change makes sense to me, but maybe we should just raise an error saying that TP on CPU is not supported yet, rather than setting index to None? cc @ArthurZucker @Cyrilvallez

@jiqing-feng (Contributor, Author)

In that case this change makes sense to me, but maybe we should just raise an error saying that TP on CPU is not supported yet, rather than setting index to None? cc @ArthurZucker @Cyrilvallez

Actually, the TP functionality is already working on CPU; it can be run with the following command and script:

CMD: OMP_NUM_THREADS=56 numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 tp_hf.py & OMP_NUM_THREADS=56 numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 tp_hf.py & wait

import os
import time

import torch
import torch.distributed as dist
from transformers import AutoModel, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"


def main(is_tp, rank, world_size) -> None:
    backend = "ccl"  # oneCCL backend; the gloo backend mentioned above is an alternative for CPU
    print(is_tp)
    if is_tp:
        dist.init_process_group(backend)

    model_kwargs = dict(torch_dtype=torch.bfloat16)
    if is_tp:
        model_kwargs["tp_plan"] = "auto"
    else:
        model_kwargs["device_map"] = "cpu"

    # Retrieve the (tensor-parallel) model
    model = AutoModel.from_pretrained(model_id, **model_kwargs)
    print(model.dtype)

    # Prepare input tokens (truncation=True so that max_length actually takes effect)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    prompt = "Can I help" * 200
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).input_ids.to(model.device)
    print(f"input shape is {inputs.shape}")

    # model = torch.compile(model)
    # warm-up
    if is_tp:
        dist.barrier()
    for i in range(5):
        outputs = model(inputs)

    # timed runs
    if is_tp:
        dist.barrier()
    for i in range(5):
        with torch.no_grad():
            start = time.time()
            outputs = model(inputs)
            end = time.time()
            print(f"time cost {(end - start) * 1000} ms")

    print(outputs)


if __name__ == "__main__":
    rank = int(os.environ["RANK"]) if "RANK" in os.environ else 0
    world_size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    is_tp = "RANK" in os.environ
    main(is_tp, rank, world_size)

@Cyrilvallez (Member)

Hey @jiqing-feng! The TP code is going to change quite a bit in the near future as we work to improve loading efficiency, so it would be best to put this issue on hold for now and revisit afterwards 🤗
As a side note, did you experiment with it on your setup? Is it truly worth it/faster compared to having the model on cpu as usual? 🤔

@jiqing-feng (Contributor, Author) commented Feb 25, 2025

Hey @jiqing-feng! The TP code is going to change quite a bit in the near future as we work to improve loading efficiency, so it would be best to put this issue on hold for now and revisit afterwards 🤗 As a side note, did you experiment with it on your setup? Is it truly worth it/faster compared to having the model on cpu as usual? 🤔

OK, but I think this change is small enough not to affect the refactor. Still, it's okay to wait for your refactor.
For now, the performance is not as good as non-TP, but the functionality is ready. We'd like to enable the functionality first and then resolve the performance issue. Thanks.

@jiqing-feng (Contributor, Author) commented Feb 26, 2025

Hi @SunMarc @Rocketknight1 @Cyrilvallez. As this change is really small, and the rule that a CPU device cannot be assigned an index is reasonable, could we merge this PR? We will optimize TP performance on CPU as our next step.

@SunMarc (Member) left a comment


Yeah, I think we can merge this without impacting the refactor. cc @Cyrilvallez

@SunMarc requested a review from @Cyrilvallez on February 26, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment


Super nice! it's missing:

Otherwise much welcome 🤗

@jiqing-feng (Contributor, Author)

Super nice! it's missing:

Otherwise much welcome 🤗

Hi @ArthurZucker,

  1. I will add the CPU TP docs after we fix the performance issue.
  2. I'd like to enable the CPU tests here, but the test hung when I ran it on CUDA. My command is: PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py. Do you have any more detailed guidance for running the test?
  3. Done.

@jiqing-feng marked this pull request as draft on March 6, 2025
@jiqing-feng (Contributor, Author)

Converted to draft because of a new regression:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/jiqingfe/tp_hf.py", line 100, in <module>
[rank0]:     main(is_tp, rank, world_size)
[rank0]:   File "/home/jiqingfe/tp_hf.py", line 56, in main
[rank0]:     outputs = model(inputs)
[rank0]:   File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/jiqingfe/transformers/src/transformers/models/llama/modeling_llama.py", line 571, in forward
[rank0]:     position_embeddings = self.rotary_emb(hidden_states, position_ids)
[rank0]:   File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/jiqingfe/transformers/src/transformers/models/llama/modeling_llama.py", line 131, in forward
[rank0]:     with torch.autocast(device_type=device_type, enabled=False):
[rank0]:   File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 230, in __init__
[rank0]:     dtype = torch.get_autocast_dtype(device_type)
[rank0]: RuntimeError: unsupported scalarType

@ArthurZucker (Collaborator)

Sounds great! 🤗
