Enable PyTorch's FakeTensorMode for EulerDiscreteScheduler scheduler #7151
Conversation
This fix could also be applied to other schedulers.
Force-pushed from 1228ebe to 4ef59ed
@thiagocrepaldi
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
yes, will do now. thanks
For other schedulers, we need a way to work around the absence of `torch.interp` to fix lines such as the one referenced above. Any ideas? Maybe use the following snippet instead? (from the link above)

```python
import torch

def interpolate(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
    """One-dimensional linear interpolation for monotonically increasing sample
    points.

    Returns the one-dimensional piecewise linear interpolant to a function with
    given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`.

    Args:
        x: the :math:`x`-coordinates at which to evaluate the interpolated
            values.
        xp: the :math:`x`-coordinates of the data points, must be increasing.
        fp: the :math:`y`-coordinates of the data points, same length as `xp`.

    Returns:
        the interpolated values, same size as `x`.
    """
    m = (fp[:, 1:] - fp[:, :-1]) / (xp[:, 1:] - xp[:, :-1])  # slope of each segment
    b = fp[:, :-1] - m.mul(xp[:, :-1])  # intercept of each segment
    # torch.ge: x >= xp ? true : false; counting hits gives each query's segment
    indices = torch.sum(torch.ge(x[:, :, None], xp[:, None, :]), -1) - 1
    indices = torch.clamp(indices, 0, m.shape[-1] - 1)
    # row (batch) index for each query point
    line_idx = torch.arange(indices.shape[0], device=indices.device).unsqueeze(-1)
    line_idx = line_idx.expand(indices.shape)
    return m[line_idx, indices].mul(x) + b[line_idx, indices]
```
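For reference (my own sanity check, not part of the thread): the slope/intercept-and-searchsorted idea above can be checked against `numpy.interp` on in-range, increasing data. Here is a minimal self-contained 1-D version of the same idea; `torch_interp_1d` is a hypothetical helper name, and note that unlike `numpy.interp` it extrapolates with the edge segments instead of clamping.

```python
import numpy as np
import torch

def torch_interp_1d(x: torch.Tensor, xp: torch.Tensor, fp: torch.Tensor) -> torch.Tensor:
    """Minimal torch stand-in for numpy.interp (in-range x, increasing xp)."""
    m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1])  # per-segment slope
    b = fp[:-1] - m * xp[:-1]                    # per-segment intercept
    # count how many knots each x is >= of, clamped to a valid segment index
    idx = torch.clamp(torch.sum(x[:, None] >= xp[None, :], dim=-1) - 1, 0, m.numel() - 1)
    return m[idx] * x + b[idx]

xp = torch.tensor([0.0, 1.0, 2.0, 4.0])
fp = torch.tensor([0.0, 10.0, 5.0, 5.0])
x = torch.tensor([0.5, 1.5, 3.0])
out = torch_interp_1d(x, xp, fp)
ref = np.interp(x.numpy(), xp.numpy(), fp.numpy())
print(out)  # agrees with numpy.interp for these in-range points
```

Because every step is a plain tensor op (no `.numpy()` round-trip), this also works under `FakeTensorMode`.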
Force-pushed from 2af35c9 to 56b29c9
Force-pushed from 56b29c9 to 15c9796
Hi folks, from the discussion at #7151, do you think we can merge this one?
Force-pushed from b58b3e9 to 15c9796
PyTorch's FakeTensorMode does not support `.numpy()` or `numpy.array()` calls. This PR replaces the `sigmas` numpy array with an equivalent PyTorch tensor.

Repro:

```python
with torch._subclasses.FakeTensorMode() as fake_mode, ONNXTorchPatcher():
    fake_model = DiffusionPipeline.from_pretrained(model_name, low_cpu_mem_usage=False)
```

which otherwise fails with `RuntimeError: .numpy() is not supported for tensor subclasses.`
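A smaller standalone reproduction of the underlying failure (my own sketch, not the PR's repro, which needs the full pipeline): under `FakeTensorMode`, tensors carry shape and dtype but no real storage, so any attempt to materialize data via `.numpy()` raises.

```python
import torch

# Under FakeTensorMode, factory functions produce FakeTensors: metadata only,
# no backing data. Converting one to a numpy array therefore cannot succeed.
with torch._subclasses.FakeTensorMode():
    t = torch.linspace(0.0, 1.0, 5)  # FakeTensor: shape (5,), dtype float32, no data
    raised = False
    try:
        t.numpy()
    except RuntimeError as e:
        raised = True
        print(f"RuntimeError: {e}")
```

This is why the scheduler's `sigmas` must stay a `torch.Tensor` end to end: as long as no numpy round-trip happens, the same code runs under both real and fake tensors.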
Force-pushed from 15c9796 to c5d6b68
thanks! merged
Fixes #7152