diff --git a/ella.py b/ella.py index e6abeed..0843c7c 100644 --- a/ella.py +++ b/ella.py @@ -59,6 +59,7 @@ def __call__(self, apply_model, kwargs: dict): timestep_ = kwargs["timestep"] c = kwargs["c"] cond_or_uncond = kwargs["cond_or_uncond"] # [0|1] + _device = c["c_crossattn"].device time_aware_encoder_hidden_states = [] self.ella.to(device=self.load_device) @@ -70,7 +71,7 @@ def __call__(self, apply_model, kwargs: dict): time_aware_encoder_hidden_states.append(h) self.ella.to(self.offload_device) - c["c_crossattn"] = torch.cat(time_aware_encoder_hidden_states, dim=0) + c["c_crossattn"] = torch.cat(time_aware_encoder_hidden_states, dim=0).to(_device) return apply_model(input_x, timestep_, **c)