Skip to content

Commit

Permalink
Merge pull request #1168 from gesen2egee/save_state_on_train_end
Browse files Browse the repository at this point in the history
Save state on train end
  • Loading branch information
kohya-ss authored Mar 20, 2024
2 parents 3b0db0f + d282c45 commit bf6cd4b
Show file tree
Hide file tree
Showing 9 changed files with 13 additions and 8 deletions.
2 changes: 1 addition & 1 deletion fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

accelerator.end_training()

- if args.save_state and is_main_process:
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)

del accelerator # この後メモリを使うのでこれは消す
Expand Down
5 changes: 5 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2938,6 +2938,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true",
help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する",
)
+ parser.add_argument(
+ "--save_state_on_train_end",
+ action="store_true",
+ help="save training state additionally (including optimizer states etc.) on train end / optimizerなど学習状態も含めたstateを追加で保存する",
+ )
parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")

parser.add_argument("--train_batch_size", type=int, default=1, help="batch size for training / 学習時のバッチサイズ")
Expand Down
2 changes: 1 addition & 1 deletion sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

accelerator.end_training()

- if args.save_state: # and is_main_process:
+ if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)

del accelerator # この後メモリを使うのでこれは消す
Expand Down
2 changes: 1 addition & 1 deletion sdxl_train_control_net_lllite.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def remove_model(old_ckpt_name):

accelerator.end_training()

- if is_main_process and args.save_state:
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)

if is_main_process:
Expand Down
2 changes: 1 addition & 1 deletion train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ def remove_model(old_ckpt_name):

accelerator.end_training()

- if is_main_process and args.save_state:
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)

# del accelerator # この後メモリを使うのでこれは消す→printで使うので消さずにおく
Expand Down
2 changes: 1 addition & 1 deletion train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def train(args):

accelerator.end_training()

- if args.save_state and is_main_process:
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)

del accelerator # この後メモリを使うのでこれは消す
Expand Down
2 changes: 1 addition & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,7 +940,7 @@ def remove_model(old_ckpt_name):

accelerator.end_training()

- if is_main_process and args.save_state:
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)

if is_main_process:
Expand Down
2 changes: 1 addition & 1 deletion train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ def remove_model(old_ckpt_name):

accelerator.end_training()

- if args.save_state and is_main_process:
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)

if is_main_process:
Expand Down
2 changes: 1 addition & 1 deletion train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def remove_model(old_ckpt_name):

accelerator.end_training()

- if args.save_state and is_main_process:
+ if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)

updated_embs = text_encoder.get_input_embeddings().weight[token_ids_XTI].data.detach().clone()
Expand Down

0 comments on commit bf6cd4b

Please sign in to comment.