From 327d8e5a6a716f41c085e9cd772f1b20f86a68e1 Mon Sep 17 00:00:00 2001 From: DN6 Date: Thu, 7 Mar 2024 16:38:16 +0530 Subject: [PATCH] update --- .../pipelines/stable_cascade/pipeline_stable_cascade.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py index 583ef3ee02c2..8a8d5b65e31a 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -19,7 +19,7 @@ from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler -from ...utils import logging, replace_example_docstring +from ...utils import is_torch_version, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel @@ -361,6 +361,8 @@ def __call__( device = self._execution_device dtype = self.decoder.dtype self._guidance_scale = guidance_scale + if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16: + raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.") # 1. Check inputs. Raise error if not correct self.check_inputs(