Skip to content

Commit

Permalink
Expose res guidance scale, and save mem #24
Browse files Browse the repository at this point in the history
  • Loading branch information
dilithjay committed Apr 25, 2023
1 parent 50d7336 commit 3c13472
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
10 changes: 7 additions & 3 deletions img_styler/image_prompt/control_net/canny2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def get_image_samples(
eta=0.0,
low_threshold=100,
high_threshold=200,
save_memory=True,
):
input_image = cv2.imread(input_img_path)
apply_canny = CannyDetector()
Expand All @@ -58,7 +59,8 @@ def get_image_samples(
seed = random.randint(0, 65535)
seed_everything(seed)

model.low_vram_shift(is_diffusing=False)
if save_memory:
model.low_vram_shift(is_diffusing=False)

cond = {
"c_concat": [control],
Expand All @@ -70,7 +72,8 @@ def get_image_samples(
}
shape = (4, H // 8, W // 8)

model.low_vram_shift(is_diffusing=True)
if save_memory:
model.low_vram_shift(is_diffusing=False)

model.control_scales = (
[strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
Expand All @@ -86,7 +89,8 @@ def get_image_samples(
unconditional_conditioning=un_cond,
)

model.low_vram_shift(is_diffusing=False)
if save_memory:
model.low_vram_shift(is_diffusing=False)

x_samples = model.decode_first_stage(samples)
x_samples = (
Expand Down
12 changes: 8 additions & 4 deletions img_styler/image_prompt/control_net/scribble2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def get_image_samples(
strength=1.0,
scale=5.0,
eta=0.0,
save_memory=True,
):
input_image = cv2.imread(input_img_path)

Expand All @@ -53,7 +54,8 @@ def get_image_samples(
seed = random.randint(0, 65535)
seed_everything(seed)

model.low_vram_shift(is_diffusing=False)
if save_memory:
model.low_vram_shift(is_diffusing=False)

cond = {
"c_concat": [control],
Expand All @@ -65,12 +67,13 @@ def get_image_samples(
}
shape = (4, H // 8, W // 8)

model.low_vram_shift(is_diffusing=True)
if save_memory:
model.low_vram_shift(is_diffusing=False)

model.control_scales = (
[strength * (0.825 ** float(12 - i)) for i in range(13)] if guess_mode else ([strength] * 13)
) # Magic number. IDK why. Perhaps because 0.825**12<0.01 but 0.826**12>0.01
samples, intermediates = ddim_sampler.sample(
samples, _ = ddim_sampler.sample(
ddim_steps,
num_samples,
shape,
Expand All @@ -81,7 +84,8 @@ def get_image_samples(
unconditional_conditioning=un_cond,
)

model.low_vram_shift(is_diffusing=False)
if save_memory:
model.low_vram_shift(is_diffusing=False)

x_samples = model.decode_first_stage(samples)
x_samples = (
Expand Down
23 changes: 22 additions & 1 deletion img_styler/ui/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,34 @@ async def update_processed_face(q: Q, save=False):
value=q.client.no_images,
tooltip="Number of image samples to generate.",
),
ui.slider(
name="prompt_resolution",
label="Image Resolution",
min=256,
max=768,
step=64,
value=q.client.prompt_resolution or 512,
),
ui.slider(
name="prompt_strength",
label="Control Strength",
min=0.0,
max=2.0,
step=0.01,
value=q.client.prompt_strength or 1.0,
),
ui.slider(
name="prompt_scale",
label="Guidance Scale",
min=0.1,
max=30.0,
step=0.1,
value=q.client.prompt_strength,
value=q.client.prompt_scale or 9.0,
),
ui.checkbox(
name="prompt_save_memory",
label="Save Memory",
value=q.client.prompt_save_memory or True
),
ui.textbox(
name="prompt_a",
Expand Down

0 comments on commit 3c13472

Please sign in to comment.