Skip to content

Commit

Permalink
Make sure do_eval works without do_train (#100)
Browse files Browse the repository at this point in the history
* make sure evals work with do_train=False

* increase eval_num_samples
  • Loading branch information
farzadab authored Aug 27, 2024
1 parent 638a7a6 commit 44eebcc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion ultravox/training/configs/meta_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ max_audio_duration_secs: 16

val_num_samples: 64
val_steps: 1000
eval_num_samples: 1024
eval_num_samples: 2000
eval_max_new_tokens: 32
eval_num_procs: 16

Expand Down
25 changes: 13 additions & 12 deletions ultravox/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def main() -> None:
local_rank = int(os.environ.get("LOCAL_RANK", 0))
is_master = local_rank == 0

if args.do_train:
train(args)
train(args)

if args.do_eval and is_master:
gc.collect()
Expand Down Expand Up @@ -296,22 +295,24 @@ def train(args: config_base.TrainConfig):
),
)

# Training loop
logging.info("Starting training...")
t_start = datetime.now()
logging.info(f"train start time: {t_start}")
if args.val_steps:
trainer.evaluate()
trainer.train()
if args.do_train:
# Training loop
logging.info("Starting training...")
t_start = datetime.now()
logging.info(f"train start time: {t_start}")
if args.val_steps:
trainer.evaluate()
trainer.train()
t_end = datetime.now()
logging.info(f"train end time: {t_end}")
logging.info(f"elapsed: {t_end - t_start}")

if is_master:
# Saving the model using pipeline to ensure its code is saved
pipeline = ultravox_pipeline.UltravoxPipeline(
model, tokenizer=text_tokenizer, device=device
)
pipeline.save_pretrained(args.output_dir)
t_end = datetime.now()
logging.info(f"train end time: {t_end}")
logging.info(f"elapsed: {t_end - t_start}")


def evaluate(args: config_base.TrainConfig):
Expand Down

0 comments on commit 44eebcc

Please sign in to comment.