diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py index 583ef3ee02c2..8a8d5b65e31a 100644 --- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py @@ -19,7 +19,7 @@ from ...models import StableCascadeUNet from ...schedulers import DDPMWuerstchenScheduler -from ...utils import logging, replace_example_docstring +from ...utils import is_torch_version, logging, replace_example_docstring from ...utils.torch_utils import randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel @@ -361,6 +361,8 @@ def __call__( device = self._execution_device dtype = self.decoder.dtype self._guidance_scale = guidance_scale + if is_torch_version("<", "2.2.0") and dtype == torch.bfloat16: + raise ValueError("`StableCascadeDecoderPipeline` requires torch>=2.2.0 when using `torch.bfloat16` dtype.") # 1. Check inputs. Raise error if not correct self.check_inputs(