Skip to content

Commit

Permalink
fix: c_crossattn was placed on the wrong device in `fp32-text-enc + f…
Browse files Browse the repository at this point in the history
…orce-fp32`

[issues TencentQQGYLab#10](TencentQQGYLab#10)
  • Loading branch information
JettHu committed Apr 20, 2024
1 parent a9f23c5 commit b68b9d1
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion ella.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __call__(self, apply_model, kwargs: dict):
timestep_ = kwargs["timestep"]
c = kwargs["c"]
cond_or_uncond = kwargs["cond_or_uncond"] # [0|1]
_device = c["c_crossattn"].device

time_aware_encoder_hidden_states = []
self.ella.to(device=self.load_device)
Expand All @@ -70,7 +71,7 @@ def __call__(self, apply_model, kwargs: dict):
time_aware_encoder_hidden_states.append(h)
self.ella.to(self.offload_device)

c["c_crossattn"] = torch.cat(time_aware_encoder_hidden_states, dim=0)
c["c_crossattn"] = torch.cat(time_aware_encoder_hidden_states, dim=0).to(_device)

return apply_model(input_x, timestep_, **c)

Expand Down

0 comments on commit b68b9d1

Please sign in to comment.