diff --git a/scripts/prompt_travel.py b/scripts/prompt_travel.py index bf758a9..0bbf518 100644 --- a/scripts/prompt_travel.py +++ b/scripts/prompt_travel.py @@ -31,14 +31,25 @@ from modules.processing import Processed, StableDiffusionProcessing as Processing, StableDiffusionProcessingTxt2Img as ProcessingTxt2Img, StableDiffusionProcessingImg2Img as ProcessingImg2Img from modules.images import resize_image from modules.sd_samplers_common import single_sample_to_image -from modules.prompt_parser import DictWithShape - -''' -DictWithShape { - 'crossattn': Tensor, - 'vector': Tensor, -} -''' + +try: + from modules.prompt_parser import DictWithShape +except ImportError: + ''' + DictWithShape { + 'crossattn': Tensor, + 'vector': Tensor, + } + ''' + class DictWithShape(dict): + def __init__(self, x, shape): + super().__init__() + self.update(x) + + @property + def shape(self): + return self["crossattn"].shape + Cond = Union[Tensor, DictWithShape] class Mode(Enum):