Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Flax] added broadcast_to_shape_from_left helper and Scheduler tests #864

Merged
merged 15 commits into from
Oct 25, 2022

Conversation

kashif
Copy link
Contributor

@kashif kashif commented Oct 17, 2022

instead of the while loop

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 17, 2022

The documentation is not available anymore as the PR was closed or merged.

Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me - could we add one test?
Also did we make sure that the Flax pipeline still works (e.g. did we run the slow Flax tests once? )

@kashif
Copy link
Contributor Author

kashif commented Oct 17, 2022

yes let me add flax scheduler tests! good idea!

@kashif
Copy link
Contributor Author

kashif commented Oct 17, 2022

I will integrate #580 here

@kashif kashif changed the title [Flax] added broadcast_to_shape_from_left helper [Flax] added broadcast_to_shape_from_left helper and Scheduler tests Oct 19, 2022
@kashif
Copy link
Contributor Author

kashif commented Oct 19, 2022

@anton-l i have added some flax scheduler tests... whenever you get a chance i would not mind a review. Thanks!

@patrickvonplaten
Copy link
Contributor

Did we test the changes on a TPUv3-8 or TPUv2-8?

@patrickvonplaten
Copy link
Contributor

Happy to merge once confirmed everything works on TPUv3-8 and fast tests pass. Think the current failure of the fast tests is unrelated and has been fixed by Suraj here: #928

@kashif
Copy link
Contributor Author

kashif commented Oct 20, 2022

@patrickvonplaten thanks let me update my branch!

@pcuenca
Copy link
Member

pcuenca commented Oct 21, 2022

The following test passes in a TPU v3-8:

def test_stable_diffusion_v1_4(self):

But the others in the same class don't. For example, we are now getting (8, 1, 128, 128, 3) here:

assert images.shape == (8, 1, 64, 64, 3)
. The same thing happens in main, but it works in 7c22626.

I'm a bit puzzled, I'll take another look when I've had some rest.

@patrickvonplaten
Copy link
Contributor

Ok as this PR gives the same results as main merging this for now

@patrickvonplaten patrickvonplaten merged commit 240abdd into huggingface:main Oct 25, 2022
@kashif kashif deleted the jax-broadcast branch October 25, 2022 11:48
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
…uggingface#864)

* added broadcast_to_shape_from_left helper

* initial tests

* fixed pndm tests

* shape required for pndm

* added require_flax

* fix style

* fix more imports

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants