Skip to content

Commit

Permalink
Fix bf16 support
Browse files Browse the repository at this point in the history
Fix scheduler.step call to latest
  • Loading branch information
nngokhale committed Jan 12, 2024
1 parent 6b4c5f3 commit e0dab26
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion examples/controlnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ python text_to_image_generation_canny.py \
--image_save_dir /tmp/controlnet_images \
--use_habana \
--use_hpu_graphs \
--gaudi_config Habana/stable-diffusion
--gaudi_config Habana/stable-diffusion \
--bf16
```

3 changes: 2 additions & 1 deletion examples/controlnet/text_to_image_generation_canny.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def main():

# Initialize the scheduler and the generation pipeline
scheduler = GaudiDDIMScheduler.from_pretrained(args.model_name_or_path, subfolder="scheduler")
controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
model_dtype = torch.bfloat16 if args.bf16 else None
controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, torch_dtype=model_dtype)
kwargs = {
"scheduler": scheduler,
"use_habana": args.use_habana,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ def __call__(

# compute the previous noisy sample x_t -> x_t-1
latents_batch = self.scheduler.step(
noise_pred, latents_batch, **extra_step_kwargs, return_dict=False
noise_pred, timestep, latents_batch, **extra_step_kwargs, return_dict=False
)[0]

if not self.use_hpu_graphs:
Expand Down

0 comments on commit e0dab26

Please sign in to comment.