Commit 036213b

Fix sft trainer when args is None (#1295)
* fix sft trainer when args is None
* add test
* fix

1 parent 6042596 commit 036213b

File tree

2 files changed: +20 -1 lines changed


tests/test_sft_trainer.py (+19)

@@ -709,6 +709,25 @@ def test_sft_trainer_with_model_neftune(self):
         _ = trainer.model(torch.LongTensor([[1, 0, 1]]).to(device))
         self.assertTrue(len(trainer.model.get_input_embeddings()._forward_hooks) == 0)
 
+    @require_peft
+    def test_peft_sft_trainer_str(self):
+        peft_config = LoraConfig(
+            r=16,
+            lora_alpha=32,
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+
+        _ = SFTTrainer(
+            model=self.model_id,
+            args=None,
+            train_dataset=self.train_dataset,
+            eval_dataset=self.eval_dataset,
+            peft_config=peft_config,
+            packing=True,
+        )
+
     @require_peft
     def test_peft_sft_trainer(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
trl/trainer/sft_trainer.py (+1 -1)

@@ -209,7 +209,7 @@ def make_inputs_require_grad(module, input, output):
                     model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
             model = get_peft_model(model, peft_config)
-            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+            if args is not None and args.bf16 and getattr(model, "is_loaded_in_4bit", False):
                 peft_module_casting_to_bf16(model)
 
         if tokenizer is None: