Rework Wuerstchen notebook implementation (#1433)
Signed-off-by: Ilya Trushkin <ilya.trushkin@intel.com>
itrushkin authored Nov 7, 2023
1 parent 541a11c commit 854ae3f
Showing 1 changed file with 35 additions and 171 deletions.
@@ -266,53 +266,26 @@
"## Convert the model to OpenVINO IR [$\\Uparrow$](#Table-of-content:)\n",
"Main model components:\n",
"- Prior stage: create low-dimensional latent space representation of the image using text-conditional LDM\n",
"- Decoder stage: using representation from Stage C, produce a latent image in latent space of higher dimensionality using another LDM and using VQGAN-decoder, decode the latent image to yield a full-resolution output image\n",
"- Decoder stage: using representation from Prior Stage, produce a latent image in latent space of higher dimensionality using another LDM and using VQGAN-decoder, decode the latent image to yield a full-resolution output image\n",
"\n",
"First, let's extract required pipeline components (PyTorch modules, tokenizers and schedulers) to free up the memory taken by the loaded pipeline. The pipeline consists of 2 sub-pipelines: Prior pipeline accessed by `prior_pipe` property, and Decoder Pipeline accessed by `decoder_pipe` property."
"The pipeline consists of 2 sub-pipelines: Prior pipeline accessed by `prior_pipe` property, and Decoder Pipeline accessed by `decoder_pipe` property."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2392ac9a-ae49-4a21-85a3-1656157f07a4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2239"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# Prior pipeline\n",
"prior_text_encoder = pipeline.prior_text_encoder\n",
"prior_text_encoder.eval()\n",
"prior_prior = pipeline.prior_prior\n",
"prior_prior.eval()\n",
"pipeline.prior_text_encoder.eval()\n",
"pipeline.prior_prior.eval()\n",
"\n",
"# Decoder pipeline\n",
"decoder = pipeline.decoder\n",
"decoder.eval()\n",
"text_encoder = pipeline.text_encoder\n",
"text_encoder.eval()\n",
"vqgan = pipeline.vqgan\n",
"vqgan.eval()\n",
"\n",
"# Pipelines tokenizers\n",
"tokenizer = pipeline.tokenizer\n",
"prior_tokenizer = pipeline.prior_tokenizer\n",
"\n",
"# Pipelines schedulers\n",
"scheduler = pipeline.scheduler\n",
"prior_scheduler = pipeline.prior_scheduler\n",
"\n",
"del pipeline\n",
"gc.collect()"
"pipeline.decoder.eval()\n",
"pipeline.text_encoder.eval()\n",
"pipeline.vqgan.eval();"
]
},
{
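For context, the `pipeline` object referenced throughout this diff is created in an earlier cell that the commit does not touch. It is presumably loaded roughly as follows; the model id, dtype and the two helper variable names are assumptions for illustration, not taken from the commit:

import torch
from diffusers import AutoPipelineForText2Image

# Load the Wuerstchen text-to-image pipeline; the combined pipeline exposes
# the two sub-pipelines referenced in the notebook text above.
pipeline = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float32)
prior = pipeline.prior_pipe      # Prior stage: text-conditional LDM
decoder = pipeline.decoder_pipe  # Decoder stage: LDM + VQGAN decoder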
@@ -376,15 +349,16 @@
],
"source": [
"convert(\n",
" prior_text_encoder,\n",
" pipeline.prior_text_encoder,\n",
" PRIOR_TEXT_ENCODER_PATH,\n",
" example_input={\n",
" \"input_ids\": torch.zeros(1, 77, dtype=torch.int32),\n",
" \"attention_mask\": torch.zeros(1, 77),\n",
" },\n",
" input={\"input_ids\": ((1, 77),), \"attention_mask\": ((1, 77),)},\n",
")\n",
"del prior_text_encoder\n",
"del pipeline.prior_text_encoder\n",
"del pipeline.prior_pipe.text_encoder\n",
"gc.collect()"
]
},
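The convert helper called in this cell is defined earlier in the notebook, outside this diff. A minimal sketch of what such a helper typically does with the OpenVINO model conversion API is shown below; the skip-if-already-converted check and the exact signature are assumptions:

import gc
from pathlib import Path

import openvino as ov
import torch

def convert(model: torch.nn.Module, xml_path, example_input=None, input=None):
    # Convert a PyTorch module to OpenVINO IR once and cache it on disk.
    if not Path(xml_path).exists():
        with torch.no_grad():
            ov_model = ov.convert_model(model, example_input=example_input, input=input)
        ov.save_model(ov_model, xml_path)
        del ov_model
    gc.collect()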
@@ -415,12 +389,13 @@
],
"source": [
"convert(\n",
" prior_prior,\n",
" pipeline.prior_prior,\n",
" PRIOR_PRIOR_PATH,\n",
" example_input=[torch.zeros(2, 16, 24, 24), torch.zeros(2), torch.zeros(2, 77, 1280)],\n",
" input=[((2, 16, 24, 24),), ((2),), ((2, 77, 1280),)],\n",
")\n",
"del prior_prior\n",
"del pipeline.prior_prior\n",
"del pipeline.prior_pipe.prior\n",
"gc.collect()"
]
},
@@ -458,7 +433,7 @@
],
"source": [
"convert(\n",
" decoder,\n",
" pipeline.decoder,\n",
" DECODER_PATH,\n",
" example_input={\n",
" \"x\": torch.zeros(1, 4, 256, 256),\n",
@@ -473,7 +448,8 @@
" \"clip\": ((1, 77, 1024),),\n",
" },\n",
")\n",
"del decoder\n",
"del pipeline.decoder\n",
"del pipeline.decoder_pipe.decoder\n",
"gc.collect()"
]
},
@@ -504,15 +480,16 @@
],
"source": [
"convert(\n",
" text_encoder,\n",
" pipeline.text_encoder,\n",
" TEXT_ENCODER_PATH,\n",
" example_input={\n",
" \"input_ids\": torch.zeros(1, 77, dtype=torch.int32),\n",
" \"attention_mask\": torch.zeros(1, 77),\n",
" },\n",
" input={\"input_ids\": ((1, 77),), \"attention_mask\": ((1, 77),)},\n",
")\n",
"del text_encoder\n",
"del pipeline.text_encoder\n",
"del pipeline.decoder_pipe.text_encoder\n",
"gc.collect()"
]
},
@@ -559,12 +536,13 @@
],
"source": [
"convert(\n",
" VqganDecoderWrapper(vqgan),\n",
" VqganDecoderWrapper(pipeline.vqgan),\n",
" VQGAN_PATH,\n",
" VQGAN_PATH,\n",
" example_input=torch.zeros(1, 4, 256, 256),\n",
" input=(1, 4, 256, 256),\n",
")\n",
"del vqgan\n",
"del pipeline.decoder_pipe.vqgan\n",
"gc.collect()"
]
},
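The ov_prior_text_encoder, ov_prior_prior, ov_decoder, ov_text_encoder and ov_vqgan objects used a few hunks below are produced by compilation cells that the diff leaves untouched. They are presumably obtained along these lines; the device string (and any device-selection widget) is an assumption:

import openvino as ov

core = ov.Core()
device = "CPU"  # in the notebook this is likely picked from a device-selection widget

# Compile the converted IR files for inference on the selected device.
ov_prior_text_encoder = core.compile_model(PRIOR_TEXT_ENCODER_PATH, device)
ov_prior_prior = core.compile_model(PRIOR_PRIOR_PATH, device)
ov_decoder = core.compile_model(DECODER_PATH, device)
ov_text_encoder = core.compile_model(TEXT_ENCODER_PATH, device)
ov_vqgan = core.compile_model(VQGAN_PATH, device)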
@@ -754,7 +732,7 @@
"id": "395897d0-568c-4387-878b-50afc9d5c2d4",
"metadata": {},
"source": [
"And define sub-pipeline classes for both Prior and Decoder pipelines."
"And insert wrappers instances in the pipeline:"
]
},
{
@@ -764,106 +742,12 @@
"metadata": {},
"outputs": [],
"source": [
"class OVWuerstchenPriorPipeline:\n",
" _execution_device = torch.device(\"cpu\") # accessed in original workflow\n",
" config = namedtuple(\n",
" \"OVWuerstchenPriorPipelineConfig\", [\"resolution_multiple\", \"latent_mean\", \"latent_std\"]\n",
" )(42.67, 42, 1) # accessed in the original workflow\n",
"\n",
" def __init__(self, text_encoder, prior, tokenizer, scheduler):\n",
" self.text_encoder = TextEncoderWrapper(text_encoder)\n",
" self.prior = PriorPriorWrapper(prior)\n",
" self.tokenizer = tokenizer\n",
" self.scheduler = scheduler\n",
"\n",
"\n",
"OVWuerstchenPriorPipeline.__call__ = diffusers.pipelines.WuerstchenPriorPipeline.__call__\n",
"\n",
"# Methods below accessed in __call__\n",
"OVWuerstchenPriorPipeline.check_inputs = diffusers.pipelines.WuerstchenPriorPipeline.check_inputs\n",
"OVWuerstchenPriorPipeline.encode_prompt = diffusers.pipelines.WuerstchenPriorPipeline.encode_prompt\n",
"OVWuerstchenPriorPipeline.prepare_latents = (\n",
" diffusers.pipelines.WuerstchenPriorPipeline.prepare_latents\n",
")\n",
"OVWuerstchenPriorPipeline.progress_bar = diffusers.pipelines.WuerstchenPriorPipeline.progress_bar"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "6da9d343-286d-4448-b3f4-40cacb7d4eb6",
"metadata": {},
"outputs": [],
"source": [
"class OVWuerstchenDecoderPipeline:\n",
" _execution_device = torch.device(\"cpu\") # accessed in the original workflow\n",
" config = namedtuple(\"OVWuerstchenDecoderPipelineConfig\", \"latent_dim_scale\")(10.67) # accessed in the original workflow\n",
"\n",
" def __init__(self, decoder, text_encoder, vqgan, tokenizer, scheduler):\n",
" self.decoder = DecoderWrapper(decoder)\n",
" self.text_encoder = TextEncoderWrapper(text_encoder)\n",
" self.vqgan = VqganWrapper(vqgan)\n",
" self.tokenizer = tokenizer\n",
" self.scheduler = scheduler\n",
"\n",
"\n",
"OVWuerstchenDecoderPipeline.__call__ = diffusers.pipelines.WuerstchenDecoderPipeline.__call__\n",
"\n",
"# Methods below accessed in __call__\n",
"OVWuerstchenDecoderPipeline.check_inputs = (\n",
" diffusers.pipelines.WuerstchenDecoderPipeline.check_inputs\n",
")\n",
"OVWuerstchenDecoderPipeline.encode_prompt = (\n",
" diffusers.pipelines.WuerstchenDecoderPipeline.encode_prompt\n",
")\n",
"OVWuerstchenDecoderPipeline.prepare_latents = (\n",
" diffusers.pipelines.WuerstchenDecoderPipeline.prepare_latents\n",
")\n",
"OVWuerstchenDecoderPipeline.progress_bar = (\n",
" diffusers.pipelines.WuerstchenDecoderPipeline.progress_bar\n",
")\n",
"OVWuerstchenDecoderPipeline.numpy_to_pil = (\n",
" lambda _, images: diffusers.pipelines.WuerstchenDecoderPipeline.numpy_to_pil(images)\n",
")"
]
},
{
"cell_type": "markdown",
"id": "75e4bc8f-5fcd-4814-b124-65d5767f90c5",
"metadata": {},
"source": [
"Finally, combine pipelines together."
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "9b9d5a46-0668-4a4b-8696-ad3c85a1ad79",
"metadata": {},
"outputs": [],
"source": [
"class OVWuerstchenCombinedPipeline:\n",
" def __init__(\n",
" self,\n",
" prior_text_encoder,\n",
" prior_prior,\n",
" decoder,\n",
" text_encoder,\n",
" vqgan,\n",
" tokenizer,\n",
" scheduler,\n",
" prior_tokenizer,\n",
" prior_scheduler,\n",
" ):\n",
" self.prior_pipe = OVWuerstchenPriorPipeline(\n",
" prior_text_encoder, prior_prior, prior_tokenizer, prior_scheduler\n",
" )\n",
" self.decoder_pipe = OVWuerstchenDecoderPipeline(\n",
" decoder, text_encoder, vqgan, tokenizer, scheduler\n",
" )\n",
"\n",
"pipeline.prior_pipe.text_encoder = TextEncoderWrapper(ov_prior_text_encoder)\n",
"pipeline.prior_pipe.prior = PriorPriorWrapper(ov_prior_prior)\n",
"\n",
"OVWuerstchenCombinedPipeline.__call__ = diffusers.pipelines.WuerstchenCombinedPipeline.__call__"
"pipeline.decoder_pipe.decoder = DecoderWrapper(ov_decoder)\n",
"pipeline.decoder_pipe.text_encoder = TextEncoderWrapper(ov_text_encoder)\n",
"pipeline.decoder_pipe.vqgan = VqganWrapper(ov_vqgan)"
]
},
{
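The wrapper classes (TextEncoderWrapper, PriorPriorWrapper, DecoderWrapper, VqganWrapper) are defined in earlier cells that this hunk does not show. As a rough illustration of the pattern, a text-encoder wrapper would adapt a compiled OpenVINO model to the call signature and return type the diffusers pipeline expects; the dtype attribute and the output field name below are assumptions:

from collections import namedtuple

import torch

class TextEncoderWrapper:
    dtype = torch.float32  # queried by the pipeline when preparing inputs

    def __init__(self, compiled_model):
        self.compiled_model = compiled_model

    def __call__(self, input_ids, attention_mask):
        # Run the compiled OpenVINO model and re-wrap the result as a torch tensor.
        outputs = self.compiled_model({"input_ids": input_ids, "attention_mask": attention_mask})
        last_hidden_state = torch.from_numpy(outputs[0])
        return namedtuple("EncoderOutput", "last_hidden_state")(last_hidden_state)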
@@ -876,27 +760,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"id": "b2526fc1-203b-414a-bc69-fb7559ed11c0",
"metadata": {},
"outputs": [],
"source": [
"ov_pipe = OVWuerstchenCombinedPipeline(\n",
" ov_prior_text_encoder,\n",
" ov_prior_prior,\n",
" ov_decoder,\n",
" ov_text_encoder,\n",
" ov_vqgan,\n",
" tokenizer,\n",
" scheduler,\n",
" prior_tokenizer,\n",
" prior_scheduler,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 28,
"id": "9f6836ec-fca5-45b5-ba43-c6a5b5249628",
"metadata": {},
"outputs": [
@@ -934,7 +798,7 @@
"negative_prompt = \"\"\n",
"num_images_per_prompt = 1\n",
"\n",
"output = ov_pipe(\n",
"output = pipeline(\n",
" prompt=caption,\n",
" height=1024,\n",
" width=1024,\n",
@@ -948,7 +812,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 29,
"id": "6c4bfba0-9277-4e40-88a8-616e96bd5646",
"metadata": {},
"outputs": [
@@ -981,14 +845,14 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 30,
"id": "b6f8114e-9e7c-4a1f-a486-0f7eb26c9ded",
"metadata": {},
"outputs": [],
"source": [
"def generate(caption, negative_prompt, prior_guidance_scale, seed):\n",
" generator = torch.Generator().manual_seed(seed)\n",
" image = ov_pipe(\n",
" image = pipeline(\n",
" prompt=caption,\n",
" height=1024,\n",
" width=1024,\n",
@@ -1004,7 +868,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 31,
"id": "e9139efe-ab86-455a-ad84-eb7ce5e40ed7",
"metadata": {},
"outputs": [],
