RuntimeError: 'weight' must be 2-D while training Flan-T5 models with stage 3 #2746

Closed
smitanannaware opened this issue Jan 25, 2023 · 25 comments
Labels: bug (Something isn't working), training

@smitanannaware

I am using the Hugging Face Seq2SeqTrainer to train a Flan-T5-XL model with DeepSpeed ZeRO stage 3.

trainer = Seq2SeqTrainer(
    # model_init=self.model_init,
    model=self.model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=self.tokenizer,
    data_collator=self.data_collator,
    compute_metrics=self.compute_metrics,
)

trainer.train()

I am stuck on the error below:

  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/trainer.py", line 1527, in train
    return inner_training_loop(
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/trainer.py", line 1773, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/trainer.py", line 2523, in training_step
    loss = self.compute_loss(model, inputs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/trainer.py", line 2555, in compute_loss
    outputs = model(**inputs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1158, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1111, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 1611, in forward
    encoder_outputs = self.encoder(
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 941, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/users/snannawa/.conda/envs/sn_torch/lib/python3.10/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D

The code works with the ZeRO-2 config but not with ZeRO-3. I have tried a couple of settings, but no luck.

{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": true
    },

    "gradient_accumulation_steps": 8,
    "gradient_clipping": "auto",
    "steps_per_print": 10,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

Any help would be appreciated.

@williamberman

Looks like the same error popped up in diffusers with ZeRO stage 3 :) huggingface/diffusers#1865

@dumitrescustefan

Don't know if this helps, but I get the same 2-D error with stage 3 in a weird way: I use the datasets map function with a method of a class that contains a SentenceTransformer model. Basically, I want to augment my dataset before training, and when used with DeepSpeed it gives the 2-D error in the SentenceTransformer, which has nothing to do with the model I'm actually training. Stage 2 seems to work okay.
I'm just getting started with DeepSpeed and probably don't understand how to use it fully, but maybe this helps with the issue.

@tohtana
Contributor

tohtana commented Feb 8, 2023

Hello @smitanannaware, thank you for reporting.

According to the Hugging Face documentation, you need to pass your DeepSpeed config file to TrainingArguments. Can you try that setting?

training_args = Seq2SeqTrainingArguments(
...
    deepspeed="ds_config.json"
)

I tried to train Flan-T5 using the code from this article.
The training diverged with FP16, as noted in the article, but I didn't see the error with stage 3.

@djaym7

djaym7 commented Mar 10, 2023

+1, getting same error

@tohtana
Contributor

tohtana commented Mar 10, 2023

@djaym7 Thank you for your report!

Can you give us more details? Did you pass the deepspeed argument to Seq2SeqTrainingArguments as shown in my comment?
Is it possible to share the entire code?

@djaym7

djaym7 commented Mar 10, 2023

The config is loaded from https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/configs/ds_flan_t5_z3_config.json

training_args = TrainingArguments(
    output_dir=f"./results/{question_name}_{output_dir_suffix}",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # auto_find_batch_size=True,
    num_train_epochs=epochs,
    weight_decay=0.02,
    warmup_steps=warmup_steps,  # 1 epoch = 1530/16 -- 95 steps
    lr_scheduler_type='linear',
    optim='adamw_torch',
    evaluation_strategy='epoch',
    # save_strategy='epoch', save_steps=eval_steps,
    logging_steps=eval_steps,
    eval_steps=eval_steps,
    gradient_checkpointing=gradient_checkpointing,
    # do_eval=False,
    save_total_limit=2,
    # load_best_model_at_end=True,
    fp16=fp16,
    # metric_for_best_model='f1',
    gradient_accumulation_steps=gradient_accumulation_steps,
    dataloader_num_workers=dataloader_num_workers,
    sharded_ddp=sharded_ddp,
)

if deepspeed:
    training_args.deepspeed = deepspeed_dict

    from transformers.deepspeed import HfTrainerDeepSpeedConfig
    training_args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(deepspeed_dict)
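
For reference, the constructor form that @tohtana suggested earlier would look roughly like this. This is only a sketch with placeholder values; deepspeed= accepts either a dict or a path to the JSON config.

from transformers import TrainingArguments

# Sketch with placeholder values: pass the DeepSpeed config when constructing
# TrainingArguments so the Trainer wires up its DeepSpeed integration itself,
# rather than assigning training_args.deepspeed / hf_deepspeed_config afterwards.
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,
    num_train_epochs=3,
    fp16=True,
    deepspeed="ds_flan_t5_z3_config.json",  # the deepspeed_dict from above also works here
)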

@woodyx218

+1, getting same error

@tohtana
Contributor

tohtana commented Mar 21, 2023

@djaym7 @woodyx218
Can you try the complete example on philschmid's blog?
He showed the complete code for training Flan-T5 with DeepSpeed, and it worked successfully in my environment.

@djaym7

djaym7 commented Mar 21, 2023

The error occurs when using PEFT with Flan-T5.

@yezifeiafei

yezifeiafei commented Mar 27, 2023

The error occurs when using PEFT with Flan-T5.

Hi @djaym7,
I have the same problem. How did you fix it?

@djaym7

djaym7 commented Mar 27, 2023

Haven't yet; I'm using regular inference without DeepSpeed for now.

@tohtana
Contributor

tohtana commented Apr 12, 2023

Hi @djaym7,
I apologize for the delayed response.

I have tried to reproduce the problem using both deepspeed and PEFT (prefix tuning) but haven't seen the same error.
My code is available at https://github.com/tohtana/ds_repro_2746
You can set up the dataset using prepare_dataset.py and then run run_t5_ds_peft.sh.

I came across the error that you mentioned at huggingface/peft#168.
However, I found that the error happened regardless of whether I used DeepSpeed or not.
I could resolve it by setting both args.gradient_checkpointing and use_cache to False, as you mentioned in that issue's thread (a minimal sketch of those two settings is included after the version list below).

I didn't see an error after making these changes.
Can you let me know if I missed something?

The versions of peft, transformers, deepspeed were:

  • peft 0.3.0.dev0
  • deepspeed 0.8.3
  • transformers 4.28.0.dev0
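
For reference, here is a minimal sketch of that workaround. It assumes a small Flan-T5 checkpoint and a ZeRO-3 config file named ds_config.json; both are placeholders rather than values taken from this thread.

from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments

# Sketch (placeholder names/paths): disable the KV cache and gradient checkpointing
# before training with PEFT + ZeRO-3, as described above.
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
model.config.use_cache = False  # keep use_cache off during training

training_args = Seq2SeqTrainingArguments(
    output_dir="./out",
    per_device_train_batch_size=8,
    num_train_epochs=1,
    gradient_checkpointing=False,  # keep checkpointing off together with use_cache=False
    deepspeed="ds_config.json",    # the ZeRO-3 config is passed through the training arguments
)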

@djaym7

djaym7 commented Apr 12, 2023

(Quoting @tohtana's reply above.)

There's no error in training; the error is in inference. Add the following after training and the error will appear:

for batch in tqdm(data_loader):
    # need to push the data to device
    with torch.no_grad():
        outs = model.generate(input_ids=batch['input_ids'].to(device),
                              attention_mask=batch['attention_mask'].to(device),
                              max_new_tokens=128)  # num_beams=8, early_stopping=True

@tohtana
Contributor

tohtana commented Apr 20, 2023

Hi @djaym7,

I added the following code after trainer.train() in this example but didn't see the error.
Is it possible to share your code?

    device = torch.device("cuda")
    loader = torch.utils.data.DataLoader(eval_dataset, batch_size=args.per_device_eval_batch_size, shuffle=False, collate_fn=data_collator)
    for batch in loader:
        with torch.no_grad():
            outputs = model.generate(input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device), max_new_tokens=128) # num_beams=8, early_stopping=True)
            print(tokenizer.batch_decode(outputs, skip_special_tokens=True))

@djaym7

djaym7 commented Apr 20, 2023

Actually, it comes from using BERTScore:
[screenshot of the error omitted]

# Imports assumed from the surrounding script (not shown in the original comment);
# class_wise_metrics is defined elsewhere in the author's code.
import numpy as np
import torch
from collections import Counter
from tqdm import tqdm
from datasets import load_metric  # used below for bertscore / rouge

def evaluate(data_loader, model, tokenizer, print_samples=False, metric='bertscore_simple', device=None, **kwargs):
    """
    Compute scores given the predictions and gold labels
    """

    if device is not None:
        model = model.to(device)

    inputs,outputs, targets = [], [], []
    
    inputs_dat,outputs_dat = [], []

    for batch in tqdm(data_loader):
        # need to push the data to device
        if device is not None:
            batch['input_ids']=batch['input_ids'].to(model.device)
            batch['attention_mask']=batch['attention_mask'].to(model.device)
        
        with torch.no_grad():
            outs = model.generate(input_ids=batch['input_ids'], #
                                            attention_mask=batch['attention_mask'], 
                                            max_new_tokens=128,**kwargs)  # num_beams=8, early_stopping=True)


           
        dec = [tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]
        labels = batch['labels']
        labels[labels==-100] = tokenizer.pad_token_id
        target = [tokenizer.decode(ids, skip_special_tokens=True) for ids in batch["labels"]]
        inp = [tokenizer.decode(ids, skip_special_tokens=True) for ids in batch["input_ids"]]

        inputs.extend(inp)
        outputs.extend(dec)
        targets.extend(target)
    
    if print_samples:
        print("\nPrint some results to check the sanity of generation method:", '\n', '-'*30)
        for i in [1, 5, 25, 42, 50, 4, 10, 35]:
            try:
                print(f'>>Input     : {inputs[i]}')
                print(f'>>Target    : {targets[i]}')
                print(f'>>Generation: {outputs[i]}\n\n')
            except UnicodeEncodeError:
                print('Unable to print due to the coding error')

        if 'input_ids_dat' in batch:
            print('\n\n On TARGET DOMAIN')

            for i in [1, 5, 25, 42, 50, 4, 10, 35]:
                try:
                    print(f'>>Input     : {inputs_dat[i]}')
                    print(f'>>Generation: {outputs_dat[i]}\n\n')
                except UnicodeEncodeError:
                    print('Unable to print due to the coding error')
            print()

    scores, all_labels, all_preds = compute_scores(outputs, targets,metric=metric)
    # results = {'scores': scores, 'labels': all_labels, 'preds': all_preds}
    
    scores['refs'] = all_labels
    scores['preds'] = all_preds
    
    
    scores['exact_match_metrics'] = compute_f1_scores(outputs,targets) 
    return scores  #, all_labels, all_preds


def compute_f1_scores(pred_pt, gold_pt):
    """
    Function to compute F1 scores with pred and gold quads
    The input needs to be already processed
    """
    # number of true postive, gold standard, predictions
    accuracies = []
    for p,r in zip(pred_pt,gold_pt):
        if p==r:
            accuracies.append(1)
        else:
            accuracies.append(0)
        
        
    return {'accuracy':np.mean(accuracies)}


def compute_scores(pred_seqs, gold_seqs, metric='bertscore_simple'):
    """
    Compute model performance
    """
    scores={}
    assert len(pred_seqs) == len(gold_seqs)
    if 'bertscore' in metric and 'complex' in metric:
        bert_score = load_metric('bertscore')
        scores.update(bert_score.compute(predictions=pred_seqs, references=gold_seqs,model_type='bert-base-uncased' ))
        for sim in [0.5,0.6,0.7,0.8,0.9]:
            scores['accuracy_'+ str(sim)] = [1 if i>=sim else 0 for i in scores['f1'] ]
            scores['accuracy_'+ str(sim)+'_mean'] = np.mean(scores['accuracy_'+ str(sim)])

        scores['class_metrics'] = class_wise_metrics(scores,pred_seqs,gold_seqs)
        scores['class_length'] = Counter(gold_seqs)
        new_scores = {}
    if 'bertscore' in metric and 'simple' in metric:
        bert_score = load_metric('bertscore')
        scores.update(bert_score.compute(predictions=pred_seqs, references=gold_seqs,model_type='bert-base-uncased' ))
        for sim in [0.5,0.6,0.7,0.8,0.9]:
            scores['accuracy_'+ str(sim)] = [1 if i>=sim else 0 for i in scores['f1'] ]
            scores['accuracy_'+ str(sim)+'_mean'] = np.mean(scores['accuracy_'+ str(sim)])
    if 'rouge' in metric:
        bert_score = load_metric('rouge')
        scores.update(bert_score.compute(predictions=pred_seqs, references=gold_seqs ))
    
    
    return scores, gold_seqs, pred_seqs

@tohtana
Contributor

tohtana commented Apr 20, 2023

@djaym7 Can we clarify what errors you have now? I see several different errors regarding this issue.

It would be helpful if you could give us the entire reproducing code.

@djaym7

djaym7 commented Apr 20, 2023

The error you mentioned at huggingface/peft#168 is about both training and inference. Do you still have the errors? YES
Your latest error is in computing metrics. Do you have no issue with training and inference (generation) now? YES

To reproduce, add the evaluate function shared above after training the model. The error is posted above as well.

@tohtana
Contributor

tohtana commented Apr 21, 2023

@djaym7

The error you mentioned at huggingface/peft#168 is about both training and inference. Do you still have the errors? YES

I am a bit confused. You wrote "There's no error in training; the error is in inference" at #2746 (comment). Do you have an error with training or not?

I wrote an example of training/generation using DeepSpeed and PEFT. I didn't fully test it, but at least it didn't throw the error. What is different in your code?

To reproduce, add the evaluate function shared above after training the model. Error is posted above as well.

I think we need to make sure that we are doing the same for training/generation before further investigation.

@tohtana
Contributor

tohtana commented May 12, 2023

Closing because we have no additional information.
Please feel free to reopen if the problem still exists.

tohtana closed this as completed on May 12, 2023
@xiangxu-google

Faced the same issue when running inference with a model loaded via T5ForConditionalGeneration.from_pretrained().

Solution: use trainer.save_model() instead of model.save_pretrained() to save the trained model.
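
For context, a rough sketch of that suggestion; trainer is assumed to be the Seq2SeqTrainer from the original post, and the output path is a placeholder.

# Saving through the Trainer should let DeepSpeed gather the ZeRO-3 partitioned
# weights into an ordinary (2-D) state dict before writing it, provided the config
# enables stage3_gather_fp16_weights_on_model_save (as in the config posted above).
trainer.save_model("./flan-t5-xl-finetuned")  # placeholder path

# Reload the consolidated checkpoint later for plain inference:
from transformers import AutoTokenizer, T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("./flan-t5-xl-finetuned")
tokenizer = AutoTokenizer.from_pretrained("./flan-t5-xl-finetuned")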

@z7ye

z7ye commented May 22, 2023

Facing the same issue when using torch.jit.trace(model, example_inputs=dummy, strict=False) to save the cerebras/Cerebras-GPT-111M pretrained model from Hugging Face. I don't see this error when not using DeepSpeed.

@nikolakopoulos

I had the same issue. Thankfully, it went away when I upgraded to Deepspeed 0.9.5.

@allanj

allanj commented Feb 27, 2024

Facing the same issue with DeepSpeed 0.13.4.

Training with PEFT (QLoRA) + DeepSpeed ZeRO stage 3, offloading parameters and optimizer to CPU.
Model: LLaMA-2

Training is fine.

After training, we call merge_and_unload() on the model and run inference; as soon as we do, we get this error:

  File "/root/code_sft/sft_main.py", line 461, in main
    test_pfm = evaluate(args, test_dataloader, model, mix_precision=mix_precision, tokenizer=tokenizer,
  File "/root/code_sft/sft_main.py", line 223, in evaluate
    generated_ids = module.generate(input_ids=feature["input_ids"],
  File "/miniconda/envs/py310/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/miniconda/envs/py310/lib/python3.10/site-packages/transformers/generation/utils.py", line 1474, in generate
  File "/miniconda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/miniconda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda/envs/py310/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
  File "/miniconda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/miniconda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda/envs/py310/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1027, in forward
    inputs_embeds = self.embed_tokens(input_ids)
  File "/miniconda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/miniconda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/miniconda/envs/py310/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 162, in forward
    return F.embedding(
  File "/miniconda/envs/py310/lib/python3.10/site-packages/torch/nn/functional.py", line 2233, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D
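
The error itself is consistent with ZeRO-3 partitioning: outside DeepSpeed's gathering hooks, the embedding weight is a partitioned placeholder rather than the full 2-D matrix, so F.embedding rejects it. A hypothetical workaround sketch, not verified in this thread, is to gather the parameters explicitly around generation; model and feature are assumed to match the evaluate code in the traceback above, and gathering every parameter temporarily costs the memory ZeRO-3 normally saves.

import torch
import deepspeed

# Hypothetical sketch: gather the ZeRO-3 partitioned parameters so generate() sees
# full 2-D weights; modifier_rank=None means no rank modifies them in place.
with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=None):
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=feature["input_ids"],
            max_new_tokens=128,
        )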

@AfrinDange

I am facing the same issue with the DPR and GPT-2 models. I am using the latest torch version with FullyShardedDataParallel for distributed training.

Training works fine regardless of the number of devices I use.
Inference only works when the world size is 1; otherwise, I get the same error.
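
For what it's worth, a hypothetical FSDP-side sketch of a pattern sometimes used for multi-rank generation: unshard the parameters around generate() with summon_full_params. This is an assumption on my part, not something confirmed in this thread; fsdp_model and input_ids are placeholders.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Hypothetical sketch: gather full parameters on every rank before generation;
# writeback=False since generation does not modify the weights.
with FSDP.summon_full_params(fsdp_model, writeback=False):
    with torch.no_grad():
        output_ids = fsdp_model.module.generate(
            input_ids=input_ids,   # placeholder batch, already on the local device
            max_new_tokens=128,
        )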

@ggsdeath

(Quoting @allanj's comment above.)

Have you solved this problem? The solution mentioned above doesn't seem to work for this issue.
