Skip to content

Commit

Permalink
Update textual_inversion.py (#6952)
Browse files Browse the repository at this point in the history
* Update textual_inversion.py

* Apply suggestions from code review

* Update textual_inversion.py

* Update textual_inversion.py

* Update textual_inversion.py

* Update textual_inversion.py

* Update examples/textual_inversion/textual_inversion.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update textual_inversion.py

* styling

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
  • Loading branch information
Bhavay-2001 and sayakpaul authored Feb 16, 2024
1 parent 104afbc commit 777063e
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions examples/textual_inversion/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available


Expand Down Expand Up @@ -84,32 +85,30 @@
logger = get_logger(__name__)


def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None):
def save_model_card(repo_id: str, images: list = None, base_model: str = None, repo_folder: str = None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"

yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- textual_inversion
inference: true
---
"""
model_card = f"""
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
model_description = f"""
# Textual inversion text2image fine-tuning - {repo_id}
These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n
{img_str}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)

tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "textual_inversion"]
model_card = populate_model_card(model_card, tags=tags)

model_card.save(os.path.join(repo_folder, "README.md"))


def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
Expand Down

0 comments on commit 777063e

Please sign in to comment.