Update from_pretrained to make TP a first-class citizen #36335

Merged (79 commits) on Feb 26, 2025

@ArthurZucker ArthurZucker commented Feb 21, 2025

What does this PR do?

  • Match keys with loaded safe without opening the file
  • Makes sure that, even when the checkpoint is sharded, we only load our rank's slice of each tensor
  • Initializes NCCL by default for "most" use cases
  • Makes device_map="auto" equivalent to tp_plan="auto" when NCCL is initialized
  • Updates following "Load models much faster on accelerator devices!!" (#36380), because when there is more than one device you need a bit more memory
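
To make "our rank's slice" concrete, here is a minimal sketch of how a rank could compute which rows of a tensor it owns. The helper name `rank_slice` and the ceil-sized chunking are assumptions for illustration, not the code in this PR. With safetensors, bounds like these could be applied to a lazily read slice (for example via `safe_open(...).get_slice(name)`), so the full tensor need not be materialized.

```python
def rank_slice(dim_size: int, rank: int, world_size: int) -> tuple[int, int]:
    """Return the [start, stop) range this rank owns along one dimension."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    # Ceiling division: the last rank may get a shorter (or empty) chunk.
    chunk = -(-dim_size // world_size)
    start = min(rank * chunk, dim_size)
    stop = min(start + chunk, dim_size)
    return start, stop


# Example: a weight with 4096 rows split across 4 ranks.
bounds = [rank_slice(4096, r, 4) for r in range(4)]
print(bounds)
```

Each rank would then read only rows `start:stop` from the checkpoint instead of loading the whole tensor and discarding most of it.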

This follows up on #31771, which only applied when the model was not sharded across multiple checkpoint files.
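
For readers unfamiliar with sharded checkpoints: they typically ship an index JSON whose `weight_map` maps each parameter name to the shard file that stores it. The sketch below assumes that layout and groups parameter names per shard, so each shard file is visited once. The parameter names, file names, and `group_keys_by_shard` helper are hypothetical, and this is not presented as this PR's implementation.

```python
import json
from collections import defaultdict

# Hypothetical index content; a real index lists every parameter of the model.
EXAMPLE_INDEX = json.dumps({
    "metadata": {"total_size": 0},
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.mlp.up_proj.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
    },
})


def group_keys_by_shard(index_text: str) -> dict[str, list[str]]:
    """Group parameter names by the shard file that stores them."""
    weight_map = json.loads(index_text)["weight_map"]
    by_shard = defaultdict(list)
    for param_name, shard_file in weight_map.items():
        by_shard[shard_file].append(param_name)
    return dict(by_shard)


for shard_file, names in group_keys_by_shard(EXAMPLE_INDEX).items():
    print(shard_file, names)
```

Only the small index file is parsed here; no tensor data is read until a shard is actually loaded.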

Before

CUDA_VISIBLE_DEVICES="2,3" python ../test_safe_load.py 
Loading checkpoint shards: 100%|██████████| 4/4 [00:18<00:00,  4.53s/it]
Loading took 57.167264461517334 seconds

After

CUDA_VISIBLE_DEVICES="2,3" python ../test_safe_load.py 
Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.22s/it]
Model loading time: 16.91 seconds
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                          model_load        37.18%        5.197s       100.00%       13.979s       13.979s       0.000us         0.00%        1.946s        1.946s         256 b     -14.96 Gb      14.96 Gb     -37.23 Gb             1  
                            aten::to         0.02%       2.267ms        59.76%        8.354s       5.154ms       0.000us         0.00%        1.946s       1.201ms         268 b           0 b      22.27 Gb           0 b          1621  
                      aten::_to_copy         0.05%       6.416ms        59.75%        8.352s      11.121ms       0.000us         0.00%        1.946s       2.592ms         268 b           0 b      22.27 Gb           0 b           751  
                 aten::empty_strided        29.42%        4.112s        43.26%        6.048s       8.053ms       0.000us         0.00%       0.000us       0.000us         268 b         268 b      22.27 Gb      22.27 Gb           751  
                         aten::copy_         0.08%      11.731ms        16.44%        2.298s       3.060ms        1.946s       100.00%        1.946s       2.592ms           0 b           0 b           0 b           0 b           751  
                     cudaMemcpyAsync        13.51%        1.889s        13.51%        1.889s       4.142ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           456  
                          cudaMalloc        13.25%        1.852s        13.25%        1.852s     142.438ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            13  
          cudaDeviceEnablePeerAccess         2.75%     384.208ms         2.75%     384.208ms     384.208ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                            Resource         1.91%     267.407ms         1.91%     267.407ms      66.852ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             4  
                         aten::empty         0.07%       9.366ms         1.45%     202.361ms     173.106us       0.000us         0.00%       0.000us       0.000us      29.92 Gb      29.92 Gb      29.92 Gb      29.92 Gb          1169  
                            cudaFree         1.39%     194.026ms         1.39%     194.026ms      21.558ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             9  
                        aten::detach         0.04%       5.907ms         0.09%      11.928ms       3.138us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3801  
               cudaStreamSynchronize         0.08%      11.102ms         0.08%      11.102ms      37.761us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           294  
    cudaDeviceGetStreamPriorityRange         0.06%       8.988ms         0.06%       8.988ms       8.988ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                              detach         0.04%       6.021ms         0.04%       6.021ms       1.584us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3801  
                          aten::set_         0.04%       5.689ms         0.04%       5.689ms       9.775us       0.000us         0.00%       0.000us       0.000us     -14.96 Gb     -14.96 Gb           0 b           0 b           582  
                        aten::select         0.02%       3.463ms         0.03%       4.076ms      14.008us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           291  
                         aten::slice         0.02%       2.819ms         0.02%       3.018ms      10.371us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           291  
                          aten::view         0.02%       2.897ms         0.02%       2.897ms       4.978us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           582  
                       aten::reshape         0.01%     719.616us         0.01%       1.868ms       6.420us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           291  
                 cudaStreamWaitEvent         0.01%       1.337ms         0.01%       1.337ms       4.126us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           324  
                          aten::item         0.01%     797.166us         0.01%       1.258ms       4.322us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           291  
                    aten::as_strided         0.01%     812.310us         0.01%     812.310us       1.396us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           582  
                          aten::ones         0.00%     306.842us         0.01%     790.086us      12.155us       0.000us         0.00%       0.000us       0.000us     520.00 Kb           0 b           0 b           0 b            65  
                      cudaMemGetInfo         0.00%     487.182us         0.00%     487.182us     243.591us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             2  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 13.979s
Self CUDA time total: 1.946s

Loading took 16.91019082069397 seconds
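
As a quick reading of the two "Loading took" values above, the snippet below computes the speedup and the time saved. It is plain arithmetic on the numbers reported in this description; both are single runs, so treat the result as indicative rather than a benchmark.

```python
before_s = 57.167264461517334  # "Loading took" value from the Before run
after_s = 16.91019082069397    # "Loading took" value from the After run

speedup = before_s / after_s
saved_s = before_s - after_s
print(f"{speedup:.2f}x faster, {saved_s:.1f} s saved")
```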

The numbers above come from running this script:

# torchrun --master-addr 127.0.0.1 --nnodes 1 --nproc-per-node 4 /raid/arthur/test_safe_load.py
import os
import time

import torch
from torch.profiler import ProfilerActivity, profile, record_function
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"

# On main you need to init nccl manually:
# rank = int(os.environ["RANK"])
# world_size = int(os.environ["WORLD_SIZE"])
# torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
# torch.cuda.set_device(rank)

with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    start = time.time()
    # Profile CPU and CUDA activity (including memory) for the model load only.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 record_shapes=True, profile_memory=True) as prof:
        with record_function("model_load"):
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                # tp_plan="auto",
                device_map="auto",
                attn_implementation="sdpa",
            )
    end = time.time()

    print(f"Model loading time: {end - start:.2f} seconds")
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=25))

    print(f"Loading took {end - start} seconds")
    model.eval()

    # Short generation to check that the loaded model produces output.
    inputs = tokenizer(["Roses are red,"], return_tensors="pt", add_special_tokens=True).to("cuda")
    out = model.generate(**inputs, max_new_tokens=20)
    print(out)
    print(tokenizer.batch_decode(out))

With tensor parallelism, this is what we get:

Before

Loading checkpoint shards: 100%|██████████| 4/4 [00:18<00:00,  4.55s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:18<00:00,  4.55s/it]
Model loading time: 25.17 seconds
Model loading time: 26.03 seconds
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             model_load        20.13%        4.308s       100.00%       21.397s       21.397s       0.000us         0.00%        2.383s        2.383s          24 b     -14.96 Gb       7.97 Gb     -33.44 Gb             1  
                                     record_param_comms        54.43%       11.646s        65.96%       14.114s      31.156ms     121.694ms         5.17%     121.694ms     268.641us           0 b           0 b           0 b           0 b           453  
                                       c10d::broadcast_         0.00%      92.163us        57.79%       12.366s       12.366s       0.000us         0.00%       9.120us       9.120us           0 b           0 b           0 b           0 b             1  
                                               aten::to         0.01%       2.050ms        11.15%        2.387s       1.360ms       0.000us         0.00%        2.230s       1.271ms         288 b           0 b      14.96 Gb           0 b          1755  
                                         aten::_to_copy         0.03%       5.748ms        11.14%        2.384s       4.041ms       0.000us         0.00%        2.230s       3.780ms         288 b           0 b      14.96 Gb           0 b           590  
                                            aten::copy_         0.04%       9.292ms        10.83%        2.316s       3.226ms        2.231s        94.82%        2.241s       3.121ms           0 b           0 b           0 b           0 b           718  
                                        cudaMemcpyAsync        10.53%        2.253s        10.54%        2.254s       6.572ms       0.000us         0.00%       9.797ms      28.561us           0 b           0 b           0 b           0 b           343  
                                         c10d::scatter_         0.05%      11.542ms         8.22%        1.758s       7.814ms       0.000us         0.00%     121.685ms     540.823us           0 b           0 b           0 b           0 b           225  
                                           nccl:scatter         0.00%       0.000us             0        1.738s       7.725ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225  
                                           Unrecognized         5.18%        1.108s         5.18%        1.108s     277.068ms     252.096us         0.01%     252.096us      63.024us           0 b           0 b           0 b           0 b             4  
                                  cudaFuncGetAttributes         0.02%       3.946ms         5.13%        1.098s      35.404ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            31  
                                    cudaLaunchKernelExC         4.17%     892.977ms         4.17%     892.977ms       3.951ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           226  
                                             cudaMalloc         2.47%     527.767ms         2.47%     527.767ms     263.884ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             2  
                                            aten::empty         0.03%       6.956ms         2.18%     465.463ms     463.148us       0.000us         0.00%       0.000us       0.000us      29.92 Gb      29.92 Gb      19.46 Gb      19.46 Gb          1005  
                           cudaStreamCreateWithPriority         1.17%     249.387ms         1.17%     249.387ms       1.948ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           128  
                                          cudaHostAlloc         1.04%     222.252ms         1.04%     222.252ms      74.084ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             3  
                                    aten::empty_strided         0.05%      11.379ms         0.42%      89.782ms     107.139us       0.000us         0.00%       0.000us       0.000us         288 b         288 b      21.95 Gb      21.95 Gb           838  
                                           aten::detach         0.04%       9.459ms         0.33%      69.813ms      18.518us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3770  
                                                 detach         0.25%      53.100ms         0.28%      60.271ms      18.336us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3287  
                                  cudaStreamSynchronize         0.17%      35.822ms         0.17%      35.822ms     118.223us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           303  
                                       aten::contiguous         0.00%     308.805us         0.11%      23.698ms     185.144us       0.000us         0.00%      10.749ms      83.978us           0 b           0 b       4.50 Gb           0 b           128  
                                            aten::clone         0.00%     961.672us         0.11%      23.390ms     182.732us       0.000us         0.00%      10.749ms      83.978us           0 b           0 b       4.50 Gb           0 b           128  
                                       cudaLaunchKernel         0.01%       2.740ms         0.08%      17.436ms     136.217us       0.000us         0.00%     252.096us       1.970us           0 b           0 b           0 b           0 b           128  
                                         nccl:broadcast         0.00%       0.000us             0      16.122ms      16.122ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                            aten::chunk         0.01%       1.547ms         0.05%      10.381ms      46.137us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 21.397s
Self CUDA time total: 2.353s

Loading took 25.16935968399048 seconds
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             model_load        19.87%        4.302s       100.00%       21.648s       21.648s       0.000us         0.00%        2.603s        2.603s          24 b     -14.96 Gb       7.97 Gb     -33.44 Gb             1  
                                     record_param_comms        53.83%       11.654s        65.18%       14.111s      31.149ms     405.938ms        16.80%     415.691ms     917.641us           0 b           0 b           0 b           0 b           453  
                                       c10d::broadcast_         0.00%      83.247us        57.13%       12.368s       12.368s       0.000us         0.00%       9.216us       9.216us           0 b           0 b           0 b           0 b             1  
                                               aten::to         0.01%       2.100ms        11.20%        2.424s       1.381ms       0.000us         0.00%        2.177s       1.240ms         288 b           0 b      14.96 Gb           0 b          1755  
                                         aten::_to_copy         0.03%       6.124ms        11.19%        2.422s       4.105ms       0.000us         0.00%        2.177s       3.689ms         288 b           0 b      14.96 Gb           0 b           590  
                                            aten::copy_         0.05%      11.580ms        10.96%        2.372s       2.515ms        2.010s        83.20%        2.197s       2.330ms           0 b           0 b           0 b           0 b           943  
                                        cudaMemcpyAsync        10.40%        2.251s        10.41%        2.253s       3.966ms       0.000us         0.00%     186.662ms     328.631us           0 b           0 b           0 b           0 b           568  
                                         c10d::scatter_         0.05%      11.831ms         8.10%        1.753s       7.793ms       0.000us         0.00%     415.682ms       1.847ms           0 b           0 b           0 b           0 b           225  
                                           nccl:scatter         0.00%       0.000us             0        1.733s       7.704ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225  
                                           Unrecognized         5.05%        1.094s         5.05%        1.094s     273.429ms     252.674us         0.01%     252.674us      63.168us           0 b           0 b           0 b           0 b             4  
                                  cudaFuncGetAttributes         0.02%       5.347ms         5.01%        1.084s      34.983ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b            31  
                                    cudaLaunchKernelExC         4.10%     888.408ms         4.10%     888.408ms       3.931ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           226  
                                             cudaMalloc         3.43%     741.823ms         3.43%     741.823ms     370.912ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             2  
                                            aten::empty         0.03%       6.861ms         3.17%     687.056ms     683.638us       0.000us         0.00%       0.000us       0.000us      29.92 Gb      29.92 Gb      19.46 Gb      19.46 Gb          1005  
                           cudaStreamCreateWithPriority         1.15%     247.919ms         1.15%     247.919ms       1.937ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           128  
                                          cudaHostAlloc         1.02%     220.517ms         1.02%     220.517ms      73.506ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             3  
                                  cudaStreamSynchronize         0.42%      90.714ms         0.42%      90.714ms     299.385us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           303  
                                    aten::empty_strided         0.05%      11.238ms         0.38%      81.788ms      97.599us       0.000us         0.00%       0.000us       0.000us         288 b         288 b      21.95 Gb      21.95 Gb           838  
                                           aten::detach         0.04%       9.651ms         0.33%      70.432ms      18.682us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3770  
                                                 detach         0.25%      53.656ms         0.28%      60.689ms      18.463us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3287  
                                       aten::contiguous         0.00%     355.579us         0.11%      23.851ms     186.336us       0.000us         0.00%      10.751ms      83.996us           0 b           0 b       4.50 Gb           0 b           128  
                                            aten::clone         0.00%     992.811us         0.11%      23.495ms     183.558us       0.000us         0.00%      10.751ms      83.996us           0 b           0 b       4.50 Gb           0 b           128  
                                       cudaLaunchKernel         0.01%       2.832ms         0.08%      17.413ms     136.042us       0.000us         0.00%     252.674us       1.974us           0 b           0 b           0 b           0 b           128  
                                         nccl:broadcast         0.00%       0.000us             0      15.654ms      15.654ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                            aten::chunk         0.01%       1.583ms         0.05%      10.431ms      46.359us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 21.648s
Self CUDA time total: 2.416s

Loading took 26.030614852905273 seconds
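
To see where the time goes in the "Before" tables above, it helps to convert the profiler's time strings to seconds and express a few rows as a share of the total. The sketch below uses "CPU total" values copied from the first table; the helper `to_seconds` is a hypothetical convenience, not part of the PyTorch profiler API. "CPU total" is inclusive of child calls, so the shares overlap and do not sum to 100%.

```python
_UNITS = {"us": 1e-6, "ms": 1e-3, "s": 1.0}


def to_seconds(text: str) -> float:
    """Convert a profiler time string such as '384.208ms' to seconds."""
    # Check the two-letter suffixes first so "ms"/"us" are not read as "s".
    for suffix in ("us", "ms", "s"):
        if text.endswith(suffix):
            return float(text[: -len(suffix)]) * _UNITS[suffix]
    raise ValueError(f"unrecognized time string: {text!r}")


total = to_seconds("21.397s")  # "Self CPU time total" of the first table
cpu_total = {
    "record_param_comms": "14.114s",
    "c10d::broadcast_": "12.366s",
    "aten::to": "2.387s",
    "c10d::scatter_": "1.758s",
}
shares = {name: 100 * to_seconds(t) / total for name, t in cpu_total.items()}
for name, pct in shares.items():
    print(f"{name:<20} {pct:6.2f}%")
```

Read this way, the communication rows (`record_param_comms`, `c10d::broadcast_`, `c10d::scatter_`) account for most of the CPU time in the "Before" run.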

After

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.10it/s]
Model loading time: 7.50 seconds
Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 model_load        61.72%        4.177s       100.00%        6.767s        6.767s       0.000us         0.00%     889.650ms     889.650ms           8 b     -14.96 Gb       7.97 Gb     -29.92 Gb             1  
                                aten::copy_        18.34%        1.241s        51.19%        3.464s       4.838ms     889.650ms       100.00%        1.069s       1.493ms           0 b      -2.25 Gb           0 b           0 b           716  
                                   aten::to         0.03%       1.912ms        32.41%        2.193s       1.776ms       0.000us         0.00%     889.650ms     720.364us         272 b           0 b       7.97 Gb           0 b          1235  
                             aten::_to_copy         0.07%       4.951ms        32.38%        2.191s       3.726ms       0.000us         0.00%     889.650ms       1.513ms         272 b           0 b       7.97 Gb           0 b           588  
                           aten::contiguous         0.01%     343.621us        16.36%        1.107s      17.300ms       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64  
                                aten::clone         0.01%     804.528us        16.36%        1.107s      17.295ms       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64  
                            cudaMemcpyAsync        13.52%     914.712ms        13.52%     914.712ms       3.133ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           292  
                                aten::empty         0.13%       8.888ms         4.27%     289.099ms     234.658us       0.000us         0.00%       0.000us       0.000us      32.17 Gb      32.17 Gb      29.92 Gb      29.92 Gb          1232  
                                 cudaMalloc         4.17%     282.526ms         4.17%     282.526ms     141.263ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             2  
                               aten::detach         0.13%       8.513ms         0.99%      67.127ms      20.219us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3320  
                                     detach         0.75%      50.811ms         0.86%      58.530ms      20.631us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          2837  
                           _FromTorchTensor         0.33%      22.613ms         0.39%      26.480ms     117.691us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225  
                        aten::empty_strided         0.15%      10.031ms         0.32%      21.610ms      35.368us       0.000us         0.00%       0.000us       0.000us         272 b         272 b       7.97 Gb       7.97 Gb           611  
                      cudaStreamSynchronize         0.15%      10.057ms         0.15%      10.057ms      34.442us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           292  
           cudaDeviceGetStreamPriorityRange         0.14%       9.143ms         0.14%       9.143ms       9.143ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1  
                                 aten::set_         0.08%       5.325ms         0.08%       5.325ms       9.055us       0.000us         0.00%       0.000us       0.000us     -14.96 Gb     -14.96 Gb           0 b           0 b           588  
                                 aten::view         0.07%       4.851ms         0.07%       4.851ms       6.004us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           808  
                              aten::view_as         0.02%       1.478ms         0.06%       3.867ms      17.187us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225  
                                aten::slice         0.05%       3.387ms         0.06%       3.765ms       7.296us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           516  
                               aten::select         0.04%       2.869ms         0.05%       3.335ms      11.422us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           292  
                                 aten::item         0.01%     696.884us         0.04%       2.630ms       9.036us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           291  
                  aten::_local_scalar_dense         0.03%       1.933ms         0.03%       1.933ms       6.641us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           291  
                              aten::reshape         0.01%     638.532us         0.02%       1.426ms       4.899us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           291  
                           aten::as_strided         0.01%     892.183us         0.01%     892.183us       1.023us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           872  
                           aten::empty_like         0.00%     272.989us         0.01%     762.668us      11.917us       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 6.767s
Self CUDA time total: 889.650ms

Loading took 7.50086522102356 seconds
Loading checkpoint shards:  25%|███████████████████████████████████████████████████████████████████████████████████████████▊                                                                                                                                                                                                                                                                                   | 1/4 [00:01<00:03,  1.22s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.10it/s]
Model loading time: 12.28 seconds
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                 model_load        73.85%        8.324s       100.00%       11.272s       11.272s       0.000us         0.00%     878.310ms     878.310ms           8 b     -14.96 Gb       7.97 Gb     -29.92 Gb             1  
                                aten::copy_        10.59%        1.194s        29.93%        3.373s       4.712ms     878.310ms       100.00%        1.059s       1.479ms           0 b      -2.25 Gb           0 b           0 b           716  
                                   aten::to         0.02%       1.881ms        19.64%        2.213s       1.792ms       0.000us         0.00%     878.310ms     711.182us         272 b           0 b       7.97 Gb           0 b          1235  
                             aten::_to_copy         0.04%       4.709ms        19.62%        2.211s       3.761ms       0.000us         0.00%     878.310ms       1.494ms         272 b           0 b       7.97 Gb           0 b           588  
                           aten::contiguous         0.00%     322.244us         9.53%        1.074s      16.785ms       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64  
                                aten::clone         0.01%     835.340us         9.53%        1.074s      16.780ms       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64  
                            cudaMemcpyAsync         8.02%     903.717ms         8.02%     903.717ms       3.095ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           292  
                                 cudaMalloc         6.24%     703.549ms         6.24%     703.549ms     351.775ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             2  
                                aten::empty         0.08%       8.860ms         5.51%     621.190ms     504.213us       0.000us         0.00%       0.000us       0.000us      32.17 Gb      32.17 Gb      29.92 Gb      29.92 Gb          1232  
                        aten::empty_strided         0.09%       9.693ms         0.90%     101.026ms     165.346us       0.000us         0.00%       0.000us       0.000us         272 b         272 b       7.97 Gb       7.97 Gb           611  
                               aten::detach         0.07%       8.352ms         0.57%      64.729ms      19.497us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3320  
                                     detach         0.44%      49.138ms         0.50%      56.303ms      19.846us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          2837  
                           _FromTorchTensor         0.19%      21.822ms         0.23%      25.437ms     113.052us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225  
                                   Resource         0.09%       9.790ms         0.09%       9.790ms       2.447ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             4  
                      cudaStreamSynchronize         0.09%       9.690ms         0.09%       9.690ms      33.185us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           292  
                                 aten::set_         0.05%       5.471ms         0.05%       5.471ms       9.305us       0.000us         0.00%       0.000us       0.000us     -14.96 Gb     -14.96 Gb           0 b           0 b           588  
                                 aten::view         0.04%       4.568ms         0.04%       4.568ms       5.653us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           808  
                                aten::slice         0.03%       3.369ms         0.03%       3.756ms       7.278us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           516  
                              aten::view_as         0.01%       1.432ms         0.03%       3.614ms      16.063us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225  
                               aten::select         0.02%       2.765ms         0.03%       3.223ms      11.037us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           292  
                              aten::reshape         0.01%     655.310us         0.01%       1.423ms       4.891us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           291  
                                 aten::item         0.01%     708.561us         0.01%       1.216ms       4.180us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           291  
                           aten::as_strided         0.01%     895.889us         0.01%     895.889us       1.027us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           872  
                                 aten::ones         0.00%     256.346us         0.01%     781.781us      12.027us       0.000us         0.00%       0.000us       0.000us     520.00 Kb           0 b           0 b           0 b            65  
                           aten::empty_like         0.00%     264.930us         0.01%     722.534us      11.290us       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 11.272s
Self CUDA time total: 878.310ms

Loading took 12.277393817901611 seconds

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker merged commit 1603018 into main Feb 26, 2025
25 checks passed
@ArthurZucker ArthurZucker deleted the safe-tensors branch February 26, 2025 19:12
@muellerzr muellerzr mentioned this pull request Feb 27, 2025
5 tasks
ArthurZucker pushed a commit that referenced this pull request Mar 1, 2025
* fix

* style

* better allocation

* fix

* fix

* style

* revert disk

* exit

* style

* return if nothing to cache

* dtensor guard

* fix regressiion

* fix regression

* fix

* fix
garrett361 pushed a commit to garrett361/transformers that referenced this pull request Mar 4, 2025
* fix

* style

* better allocation

* fix

* fix

* style

* revert disk

* exit

* style

* return if nothing to cache

* dtensor guard

* fix regressiion

* fix regression

* fix

* fix