Skip to content

Commit

Permalink
Use different seed for num_images_per_prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
Anghellia committed Aug 22, 2024
1 parent f499380 commit 5d6544f
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
4 changes: 3 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def main(args):
else:
image = None

xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload, args.seed)
xflux_pipeline = XFluxPipeline(args.model_type, args.device, args.offload)
if args.use_lora:
print('load lora:', args.lora_local_path, args.lora_repo_id, args.lora_name)
xflux_pipeline.set_lora(args.lora_local_path, args.lora_repo_id, args.lora_name, args.lora_weight)
Expand All @@ -121,6 +121,7 @@ def main(args):
height=args.height,
guidance=args.guidance,
num_steps=args.num_steps,
seed=args.seed,
true_gs=args.true_gs,
neg_prompt=args.neg_prompt,
timestep_to_start_cfg=args.timestep_to_start_cfg,
Expand All @@ -129,6 +130,7 @@ def main(args):
os.mkdir(args.save_path)
ind = len(os.listdir(args.save_path))
result.save(os.path.join(args.save_path, f"result_{ind}.png"))
args.seed = args.seed + 1


if __name__ == "__main__":
Expand Down
12 changes: 6 additions & 6 deletions src/flux/xflux_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@


class XFluxPipeline:
def __init__(self, model_type, device, offload: bool = False, seed: int = None):
def __init__(self, model_type, device, offload: bool = False):
self.device = torch.device(device)
self.offload = offload
self.seed = seed
self.model_type = model_type

self.clip = load_clip(self.device)
Expand Down Expand Up @@ -78,6 +77,7 @@ def __call__(self,
height: int = 512,
guidance: float = 4,
num_steps: int = 50,
seed: int = 123456789,
true_gs = 3,
neg_prompt: str = '',
timestep_to_start_cfg: int = 0,
Expand All @@ -93,20 +93,20 @@ def __call__(self,
controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(torch.bfloat16).to(self.device)

return self.forward(prompt, width, height, guidance, num_steps, controlnet_image,
return self.forward(prompt, width, height, guidance, num_steps, seed, controlnet_image,
timestep_to_start_cfg=timestep_to_start_cfg, true_gs=true_gs, neg_prompt=neg_prompt)

def forward(self, prompt, width, height, guidance, num_steps, controlnet_image=None, timestep_to_start_cfg=0, true_gs=3, neg_prompt=""):
def forward(self, prompt, width, height, guidance, num_steps, seed, controlnet_image=None, timestep_to_start_cfg=0, true_gs=3, neg_prompt=""):
x = get_noise(
1, height, width, device=self.device,
dtype=torch.bfloat16, seed=self.seed
dtype=torch.bfloat16, seed=seed
)
timesteps = get_schedule(
num_steps,
(width // 8) * (height // 8) // (16 * 16),
shift=True,
)
torch.manual_seed(self.seed)
torch.manual_seed(seed)
with torch.no_grad():
if self.offload:
self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
Expand Down

0 comments on commit 5d6544f

Please sign in to comment.