diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 3fe52a498044..134b45a73b08 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -217,7 +217,16 @@ def step( prev_sample = sample + derivative * dt device = model_output.device if torch.is_tensor(model_output) else "cpu" - noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + if str(device) == "mps": + # randn does not work reproducibly on mps + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to( + device + ) + else: + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to( + device + ) + prev_sample = prev_sample + noise * sigma_up if not return_dict: diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 93aeb8cc3865..6425072ac3c9 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -214,7 +214,16 @@ def step( gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0 device = model_output.device if torch.is_tensor(model_output) else "cpu" - noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device) + if str(device) == "mps": + # randn does not work reproducibly on mps + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device="cpu", generator=generator).to( + device + ) + else: + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to( + device + ) + eps = noise * s_noise sigma_hat = sigma * (gamma + 1)