[Pipeline] Wuerstchen v3 aka Stable Cascade pipeline #6487

Merged: 123 commits, Mar 6, 2024
Changes shown below are from 37 commits.

Commits (123)
6185da3
initial diffNext v3
kashif Oct 22, 2023
6fd8639
move to v3 folder
kashif Nov 11, 2023
86e2bcd
imports
kashif Nov 11, 2023
e77db10
dry up the unets
kashif Nov 12, 2023
644dc5d
Merge branch 'main' into wuerstchen-v3
kashif Jan 8, 2024
1380b95
no switch_level
kashif Jan 9, 2024
2bca122
fix init
kashif Jan 9, 2024
d4d0bc1
add switch_level tp config
kashif Jan 9, 2024
0db9e4d
Fixed some things
dome272 Jan 11, 2024
87e5577
Added pooled text embeddings
dome272 Jan 11, 2024
38f9f35
Initial work on adding image encoder
dome272 Jan 12, 2024
dc3f47e
changes from @dome272
kashif Jan 15, 2024
5c6635f
Stuff for the image encoder processing and variable naming in decoder
dome272 Jan 16, 2024
3d41b2a
fix arg name
kashif Jan 19, 2024
2012e71
inference fixes
dome272 Jan 21, 2024
add164a
inference fixes
dome272 Jan 21, 2024
f6035c6
Merge branch 'main' into wuerstchen-v3
kashif Jan 29, 2024
edbd76b
default TimestepBlock without conds
kashif Jan 29, 2024
c5326fa
c_skip=0 by default
kashif Jan 29, 2024
228f98c
fix bfloat16 to cpu
kashif Jan 29, 2024
b1e6db3
use config
kashif Jan 29, 2024
0fb4bf8
undo temp change
kashif Jan 30, 2024
834baba
fix gen_c_embeddings args
kashif Jan 30, 2024
fc361d2
change text encoding
dome272 Feb 2, 2024
7632707
text encoding
dome272 Feb 2, 2024
bef887a
undo print
kashif Feb 2, 2024
0816469
Merge branch 'main' into wuerstchen-v3
kashif Feb 2, 2024
b1413d5
undo .gitignore change
kashif Feb 2, 2024
ae5967b
Allow WuerstchenV3PriorPipeline to use the base DDPM & DDIM schedulers
pabloppp Feb 2, 2024
979ea12
use WuerstchenV3Unet in both pipelines
kashif Feb 4, 2024
dc24bb4
Merge branch 'main' into wuerstchen-v3
kashif Feb 4, 2024
966cdbc
fix imports
kashif Feb 4, 2024
d9a71df
initial failing tests
kashif Feb 4, 2024
af02b68
cleanup
kashif Feb 4, 2024
e962671
use scheduler.timesterps
kashif Feb 4, 2024
a1ecef2
some fixes to the tests, still not fully working
pabloppp Feb 7, 2024
331d0d3
fix tests
kashif Feb 7, 2024
7452985
fix prior tests
kashif Feb 7, 2024
c0bb4ca
add dropout to the model_kwargs
kashif Feb 8, 2024
e01bc49
more tests passing
kashif Feb 10, 2024
17fed8c
update expected_slice
kashif Feb 10, 2024
733ec02
initial rename
kashif Feb 10, 2024
021c3e2
rename tests
kashif Feb 10, 2024
b2c615f
rename class names
kashif Feb 10, 2024
3d5328e
make fix-copies
kashif Feb 10, 2024
33a1af8
initial docs
kashif Feb 10, 2024
a7040a2
autodocs
kashif Feb 10, 2024
8882633
typos
kashif Feb 10, 2024
e63a312
fix arg docs
kashif Feb 10, 2024
cdeb5da
add text_encoder info
kashif Feb 10, 2024
72b87e7
combined pipeline has optional image arg
kashif Feb 10, 2024
d929cdf
Merge branch 'main' into wuerstchen-v3
sayakpaul Feb 12, 2024
c883cb2
fix documentation
sayakpaul Feb 12, 2024
66a17e1
Merge branch 'main' into wuerstchen-v3
kashif Feb 12, 2024
a3dc213
Merge branch 'main' into wuerstchen-v3
kashif Feb 13, 2024
33b70f4
Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade…
kashif Feb 15, 2024
cc10c29
Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade…
kashif Feb 15, 2024
6f5ed3d
Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade…
kashif Feb 15, 2024
bf3a972
Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade…
kashif Feb 15, 2024
5634ef3
Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
kashif Feb 15, 2024
3cf4c1b
Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade…
kashif Feb 15, 2024
b5e2ca9
use self.config
kashif Feb 15, 2024
9b525fd
Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade…
kashif Feb 15, 2024
60efc49
c_in -> in_channels
kashif Feb 15, 2024
cbd0775
removed kwargs from unet's forward
kashif Feb 15, 2024
c1f72e3
Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
kashif Feb 15, 2024
7cb3838
Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
kashif Feb 15, 2024
b3a80f7
remove older callback api
kashif Feb 15, 2024
519805f
removed kwargs and fixed decoder guidance > 1
kashif Feb 18, 2024
7698bf6
decoder takes emeds
kashif Feb 18, 2024
88633a9
check and use image_embeds
kashif Feb 18, 2024
d68207b
fixed all but one decoder test
kashif Feb 19, 2024
143df09
fix decoder tests
kashif Feb 19, 2024
2483df2
Merge branch 'main' into wuerstchen-v3
kashif Feb 21, 2024
f4a788b
Merge branch 'main' into wuerstchen-v3
kashif Feb 22, 2024
169db20
update callback api
kashif Feb 22, 2024
3cb0ec1
fix some more combined tests
kashif Feb 22, 2024
84f4f3d
push combined pipeline
kashif Feb 22, 2024
4f69a51
initial docs
kashif Feb 23, 2024
7dcbdc6
fix doc_string
kashif Feb 23, 2024
4f5dffb
update combined api
kashif Feb 23, 2024
4753b99
no test_callback_inputs test for combined pipeline
kashif Feb 26, 2024
adec75f
add optional components
kashif Feb 26, 2024
2e877d2
fix ordering of components
kashif Feb 26, 2024
e956f3e
fix combined tests
kashif Feb 26, 2024
f18ff23
update convert script
kashif Feb 26, 2024
3ff5120
Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade…
kashif Feb 27, 2024
979fed0
Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade…
kashif Feb 27, 2024
72cf605
Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade…
kashif Feb 27, 2024
b2e0f06
fix imports
kashif Feb 27, 2024
4c33b8a
move effnet out of deniosing loop
kashif Feb 27, 2024
9785210
prompt_embeds_pooled only when doing guidance
kashif Feb 27, 2024
25ecc81
Fix repeat shape
99991 Feb 29, 2024
1b171b6
Merge pull request #2 from 99991/wuerstchen-v3
kashif Feb 29, 2024
4914e04
move StableCascadeUnet to models/unets/
kashif Feb 29, 2024
8c2e479
more descriptive names
kashif Feb 29, 2024
2d1f438
Merge branch 'main' into wuerstchen-v3
kashif Feb 29, 2024
871387e
converted when numpy()
kashif Feb 29, 2024
85fb15c
StableCascadePriorPipelineOutput docs
kashif Feb 29, 2024
72249af
rename StableCascadeUNet
kashif Mar 1, 2024
6767b29
add slow tests
kashif Mar 2, 2024
cb7f47c
fix slow tests
kashif Mar 2, 2024
7ff8828
Merge branch 'main' into wuerstchen-v3
kashif Mar 4, 2024
748ab08
update
DN6 Mar 5, 2024
3ad7516
update
DN6 Mar 5, 2024
e7434ff
updated model_path
kashif Mar 5, 2024
ac716ab
add args for weights
kashif Mar 5, 2024
13e9812
set push_to_hub to false
kashif Mar 5, 2024
b6d3b6f
update
DN6 Mar 5, 2024
a07623f
update
DN6 Mar 5, 2024
a487d16
Merge branch 'wuerstchen-v3' of https://github.com/kashif/diffusers i…
DN6 Mar 5, 2024
c6a5537
update
DN6 Mar 5, 2024
e505de1
update
DN6 Mar 5, 2024
a2a5060
update
DN6 Mar 5, 2024
3326dee
update
DN6 Mar 5, 2024
11eac5f
update
DN6 Mar 5, 2024
2c226cd
update
DN6 Mar 5, 2024
d3e8cef
update
DN6 Mar 6, 2024
df5ed03
update
DN6 Mar 6, 2024
8e74e09
update
DN6 Mar 6, 2024
c1cd769
update
DN6 Mar 6, 2024
ceedcc4
update
DN6 Mar 6, 2024
8dd88f0
update
DN6 Mar 6, 2024
176 changes: 176 additions & 0 deletions scripts/convert_wuerstchen3.py
@@ -0,0 +1,176 @@
# Run this script to convert the Wuerstchen V3 model weights to a diffusers pipeline.

import torch
from transformers import (
    AutoTokenizer,
    CLIPConfig,
    CLIPImageProcessor,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

# from vqgan import VQModel
from diffusers import (
    DDPMWuerstchenScheduler,
    WuerstchenV3DecoderPipeline,
    WuerstchenV3PriorPipeline,
)
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.pipelines.wuerstchen3 import WuerstchenV3Unet


device = "cpu"

# set paths to model weights
model_path = "../Wuerstchen"
prior_checkpoint_path = f"{model_path}/v1.pt"
decoder_checkpoint_path = f"{model_path}/base_120k.pt"


# CLIP text encoder and tokenizer
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# CLIP image processor and image encoder
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# Prior
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
state_dict = {}
# split each fused attention projection (in_proj) into diffusers' separate
# to_q/to_k/to_v weights and rename the output projection
for key in orig_state_dict.keys():
    if key.endswith("in_proj_weight"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
        state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
        state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
    elif key.endswith("in_proj_bias"):
        weights = orig_state_dict[key].chunk(3, 0)
        state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
        state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
        state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
    elif key.endswith("out_proj.weight"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
    elif key.endswith("out_proj.bias"):
        weights = orig_state_dict[key]
        state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
    else:
        state_dict[key] = orig_state_dict[key]
prior_model = WuerstchenV3Unet(
    c_in=16,
    c_out=16,
    c_r=64,
    patch_size=1,
    c_cond=2048,
    c_hidden=[2048, 2048],
    nhead=[32, 32],
    blocks=[[8, 24], [24, 8]],
    block_repeat=[[1, 1], [1, 1]],
    level_config=["CTA", "CTA"],
    c_clip_text=1280,
    c_clip_text_pooled=1280,
    c_clip_img=768,
    c_clip_seq=4,
    kernel_size=3,
    dropout=[0.1, 0.1],
    self_attn=True,
    t_conds=["sca", "crp"],
    switch_level=[False],
).to(device)
prior_model.load_state_dict(state_dict)

# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()

# Prior pipeline
prior_pipeline = WuerstchenV3PriorPipeline(
    prior=prior_model,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    scheduler=scheduler,
    feature_extractor=feature_extractor,
)
prior_pipeline.save_pretrained("wuerstchenV3-prior")

# Decoder
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
decoder = WuerstchenV3Unet(
    c_in=4,
    c_out=4,
    c_r=64,
    patch_size=2,
    c_cond=1280,
    c_hidden=[320, 640, 1280, 1280],
    nhead=[-1, -1, 20, 20],
    blocks=[[2, 6, 28, 6], [6, 28, 6, 2]],
    block_repeat=[[1, 1, 1, 1], [3, 3, 2, 2]],
    level_config=["CT", "CT", "CTA", "CTA"],
    c_clip_text_pooled=1280,
    c_clip_seq=4,
    c_effnet=16,
    c_pixels=3,
    kernel_size=3,
    dropout=[0, 0, 0.1, 0.1],
    self_attn=True,
    t_conds=["sca"],
).to(device)
decoder.load_state_dict(state_dict)

# VQGAN from V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

# Decoder pipeline
decoder_pipeline = WuerstchenV3DecoderPipeline(
    decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.save_pretrained("wuerstchenV3")

# TODO
# # Wuerstchen pipeline
# wuerstchen_pipeline = WuerstchenCombinedPipeline(
#     # Decoder
#     text_encoder=gen_text_encoder,
#     tokenizer=gen_tokenizer,
#     decoder=decoder,
#     scheduler=scheduler,
#     vqgan=vqmodel,
#     # Prior
#     prior_tokenizer=tokenizer,
#     prior_text_encoder=text_encoder,
#     prior=prior_model,
#     prior_scheduler=scheduler,
# )
# wuerstchen_pipeline.save_pretrained("warp-ai/WuerstchenCombinedPipeline")
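
A quick way to sanity-check the conversion is to reload the two folders saved above and run a prompt end to end. The sketch below is not part of this PR and assumes the v2-style Wuerstchen interface, i.e. that the prior output exposes image_embeddings and the decoder accepts them under that name; argument and output names may differ in the final v3 API.

# Hypothetical smoke test for the converted pipelines (not part of this PR);
# assumes v2-style argument/output names.
from diffusers import WuerstchenV3DecoderPipeline, WuerstchenV3PriorPipeline

prior = WuerstchenV3PriorPipeline.from_pretrained("wuerstchenV3-prior")
decoder = WuerstchenV3DecoderPipeline.from_pretrained("wuerstchenV3")

prompt = "an astronaut riding a horse, photorealistic"
prior_output = prior(prompt=prompt, num_inference_steps=20)
image = decoder(
    image_embeddings=prior_output.image_embeddings,  # assumed v2-style output name
    prompt=prompt,
    num_inference_steps=10,
    output_type="pil",
).images[0]
image.save("wuerstchen_v3_test.png")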
6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
@@ -307,6 +307,9 @@
         "WuerstchenCombinedPipeline",
         "WuerstchenDecoderPipeline",
         "WuerstchenPriorPipeline",
+        "WuerstchenV3CombinedPipeline",
+        "WuerstchenV3DecoderPipeline",
+        "WuerstchenV3PriorPipeline",
     ]
 )

@@ -670,6 +673,9 @@
         WuerstchenCombinedPipeline,
         WuerstchenDecoderPipeline,
         WuerstchenPriorPipeline,
+        WuerstchenV3CombinedPipeline,
+        WuerstchenV3DecoderPipeline,
+        WuerstchenV3PriorPipeline,
     )
 
     try:
10 changes: 10 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -233,6 +233,11 @@
         "WuerstchenDecoderPipeline",
         "WuerstchenPriorPipeline",
     ]
+    _import_structure["wuerstchen3"] = [
+        "WuerstchenV3CombinedPipeline",
+        "WuerstchenV3DecoderPipeline",
+        "WuerstchenV3PriorPipeline",
+    ]
 try:
     if not is_onnx_available():
         raise OptionalDependencyNotAvailable()
@@ -471,6 +476,11 @@
         WuerstchenDecoderPipeline,
         WuerstchenPriorPipeline,
     )
+    from .wuerstchen3 import (
+        WuerstchenV3CombinedPipeline,
+        WuerstchenV3DecoderPipeline,
+        WuerstchenV3PriorPipeline,
+    )
 
     try:
         if not is_onnx_available():
19 changes: 15 additions & 4 deletions src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
@@ -32,13 +32,20 @@ def forward(self, x):
 
 
 class TimestepBlock(nn.Module):
-    def __init__(self, c, c_timestep):
+    def __init__(self, c, c_timestep, conds=[]):
         super().__init__()
         linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
         self.mapper = linear_cls(c_timestep, c * 2)
+        self.conds = conds
+        for cname in conds:
+            setattr(self, f"mapper_{cname}", linear_cls(c_timestep, c * 2))
 
     def forward(self, x, t):
-        a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
+        t = t.chunk(len(self.conds) + 1, dim=1)
+        a, b = self.mapper(t[0])[:, :, None, None].chunk(2, dim=1)
+        for i, c in enumerate(self.conds):
+            ac, bc = getattr(self, f"mapper_{c}")(t[i + 1])[:, :, None, None].chunk(2, dim=1)
+            a, b = a + ac, b + bc
         return x * (1 + a) + b

@@ -49,10 +56,14 @@ def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
         conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
         linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
 
-        self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+        self.depthwise = conv_cls(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
+            linear_cls(c + c_skip, c * 4),
+            nn.GELU(),
+            GlobalResponseNorm(c * 4),
+            nn.Dropout(dropout),
+            linear_cls(c * 4, c),
         )
 
     def forward(self, x, x_skip=None):
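
The reworked TimestepBlock lets the prior inject extra scalar conditions (the t_conds=["sca", "crp"] in the convert script) alongside the timestep: the caller concatenates one embedding per condition onto t, the block chunks them apart, and each chunk contributes an additive scale/shift pair. A standalone sketch of that logic, using plain nn.Linear in place of the LoRA-compatible layers:

import torch
import torch.nn as nn

c, c_timestep, batch = 8, 4, 2
conds = ["sca", "crp"]  # extra conditions, as passed via t_conds in the prior

# the caller packs the base timestep embedding plus one embedding per condition
t = torch.randn(batch, c_timestep * (len(conds) + 1))
chunks = t.chunk(len(conds) + 1, dim=1)

mappers = {name: nn.Linear(c_timestep, c * 2) for name in ("base", *conds)}
a, b = mappers["base"](chunks[0])[:, :, None, None].chunk(2, dim=1)
for i, name in enumerate(conds):
    ac, bc = mappers[name](chunks[i + 1])[:, :, None, None].chunk(2, dim=1)
    a, b = a + ac, b + bc

x = torch.randn(batch, c, 16, 16)
out = x * (1 + a) + b  # FiLM-style scale/shift, as in TimestepBlock.forward
print(out.shape)  # torch.Size([2, 8, 16, 16])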
@@ -233,7 +233,7 @@ def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=Tr
 
 
 class ResBlockStageB(nn.Module):
-    def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0):
+    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
         super().__init__()
         self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
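
The default changing from c_skip=None to c_skip=0 matters because c_skip feeds channel arithmetic (as in the ResBlock above, where the skip width is added to the channelwise input width): with None the width computation raises, while 0 makes "no skip connection" a well-defined no-op. A minimal illustration:

c, c_skip = 320, None
try:
    width = c + c_skip  # what a None default would hit in the channel math
except TypeError as e:
    print(e)  # unsupported operand type(s) for +: 'int' and 'NoneType'
print(c + 0)  # 320 -- a zero-width skip is a safe default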
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
@@ -423,9 +423,9 @@ def __call__(
         latents = self.vqgan.config.scale_factor * latents
         images = self.vqgan.decode(latents).sample.clamp(0, 1)
         if output_type == "np":
-            images = images.permute(0, 2, 3, 1).cpu().numpy()
+            images = images.permute(0, 2, 3, 1).float().cpu().numpy()
         elif output_type == "pil":
-            images = images.permute(0, 2, 3, 1).cpu().numpy()
+            images = images.permute(0, 2, 3, 1).float().cpu().numpy()
             images = self.numpy_to_pil(images)
         else:
             images = latents

Reviewer comment on the `.float()` change: "isn't that done automatically when converting to numpy?"
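
To answer the reviewer question: not for bfloat16. NumPy has no bfloat16 dtype, so calling .numpy() on a bfloat16 tensor raises a TypeError (float16 would convert fine), which is why the pipeline upcasts with .float() first; see also the "fix bfloat16 to cpu" commit in this PR. A minimal repro:

import torch

x = torch.randn(1, 3, 4, 4, dtype=torch.bfloat16)
try:
    x.cpu().numpy()  # NumPy has no bfloat16 dtype
except TypeError as e:
    print(e)  # e.g. "Got unsupported ScalarType BFloat16"

arr = x.permute(0, 2, 3, 1).float().cpu().numpy()  # upcast first, then convert
print(arr.dtype)  # float32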
52 changes: 52 additions & 0 deletions src/diffusers/pipelines/wuerstchen3/__init__.py
@@ -0,0 +1,52 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modeling_wuerstchen3_common"] = ["WuerstchenV3Unet"]
    _import_structure["pipeline_wuerstchen3"] = ["WuerstchenV3DecoderPipeline"]
    _import_structure["pipeline_wuerstchen3_combined"] = ["WuerstchenV3CombinedPipeline"]
    _import_structure["pipeline_wuerstchen3_prior"] = ["WuerstchenV3PriorPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modeling_wuerstchen3_common import WuerstchenV3Unet
        from .pipeline_wuerstchen3 import WuerstchenV3DecoderPipeline
        from .pipeline_wuerstchen3_combined import WuerstchenV3CombinedPipeline
        from .pipeline_wuerstchen3_prior import WuerstchenV3PriorPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
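
The new subpackage follows diffusers' lazy-import convention: importing it only registers _import_structure, and the heavy torch/transformers-backed modules load on first attribute access. A small sketch of that behavior (hypothetical check; the wuerstchen3 subpackage exists only on this PR branch, which later commits rename to stable_cascade):

import sys

from diffusers.pipelines import wuerstchen3  # cheap: installs the _LazyModule wrapper

print(type(sys.modules["diffusers.pipelines.wuerstchen3"]).__name__)  # _LazyModule

pipe_cls = wuerstchen3.WuerstchenV3PriorPipeline  # first access triggers the real import
print(pipe_cls.__name__)  # WuerstchenV3PriorPipeline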