Skip to content

Commit

Permalink
fix: unstable image quality when generating multiple batches TencentQQGYLab#14
Browse files Browse the repository at this point in the history
  • Loading branch information
JettHu committed Apr 22, 2024
1 parent b68b9d1 commit a83bd37
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

## :star2: Changelog

- **[2024.4.22]** Fix unstable image quality when generating multiple batches.
- **[2024.4.19]** Documenting nodes
- **[2024.4.19]** Initial repo

Expand Down
Binary file added assets/multi-batch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 11 additions & 5 deletions ella.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
from typing import Dict

import folder_paths
import torch
from comfy import model_management
from comfy.conds import CONDCrossAttn
from safetensors.torch import load_model

from .model import ELLA, T5TextEmbedder
Expand Down Expand Up @@ -37,7 +39,8 @@ def __init__(self, ella, model_sampling, positive, negative) -> None:
self.ella.to(self.dtype)
for i in range(len(self.embeds)):
for k in self.embeds[i]:
self.embeds[i][k].to(device=self.load_device, dtype=self.dtype)
self.embeds[i][k].to(dtype=self.dtype)
self.embeds[i][k] = CONDCrossAttn(self.embeds[i][k])

@property
def load_device(self):
Expand All @@ -47,10 +50,13 @@ def load_device(self):
def offload_device(self):
return model_management.text_encoder_offload_device()

def process_cond(self, embeds: Dict[str, CONDCrossAttn], batch_size, **kwargs):
return {k: v.process_cond(batch_size, self.load_device, **kwargs).cond for k, v in embeds.items()}

def prepare_conds(self):
self.ella.to(self.load_device)
cond = self.ella(torch.Tensor([999]).to(torch.int64), **self.embeds[0])
uncond = self.ella(torch.Tensor([999]).to(torch.int64), **self.embeds[1])
cond = self.ella(torch.Tensor([999]).to(torch.int64), **self.process_cond(self.embeds[0], 1))
uncond = self.ella(torch.Tensor([999]).to(torch.int64), **self.process_cond(self.embeds[1], 1))
self.ella.to(self.offload_device)
return cond, uncond

Expand All @@ -65,8 +71,8 @@ def __call__(self, apply_model, kwargs: dict):
self.ella.to(device=self.load_device)
for i in cond_or_uncond:
h = self.ella(
self.model_sampling.timestep(timestep_[i]),
**self.embeds[i],
self.model_sampling.timestep(timestep_[0]),
**self.process_cond(self.embeds[i], input_x.size(0) // len(cond_or_uncond)),
)
time_aware_encoder_hidden_states.append(h)
self.ella.to(self.offload_device)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
torch
safetensors
transformers
sentencepiece

0 comments on commit a83bd37

Please sign in to comment.