[Flax] Add Textual Inversion #880
Conversation
The documentation is not available anymore as the PR was closed or merged.
Very cool idea @duongna21 - cc'ing @patil-suraj here for a review :-)
Thanks a lot for adding this example, very cool!
From a first look it looks good; I'll try this on a TPU v3/v2 and then do a detailed review :)
@patil-suraj I found and fixed the bug related to constant folding. It turns out that I forgot to replicate the params of vae and unet across devices. Now the script works well on Tesla V100 (16GB) and TPU v3-8.
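For readers following along, here is a minimal sketch of the fix described above, assuming the usual Flax data-parallel setup; the variable names are illustrative, not necessarily those used in the script:

```python
from flax.jax_utils import replicate

# Replicate the frozen vae/unet params and the train state across all local
# devices, so the pmapped train step receives per-device copies instead of
# XLA constant-folding large host-side arrays into the compiled program.
vae_params = replicate(vae_params)    # illustrative name
unet_params = replicate(unet_params)  # illustrative name
state = replicate(state)              # train state for the text encoder
```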
@patil-suraj I leave it up to you to merge the PR if you're happy with it
Amazing work, just tried this on a v3-8 and it works great! I just left one comment: until the PR in transformers is merged, we could load the clip model directly from its repo.
Also, it would be awesome if you could add a section in the readme on how to use this script. Then it should be good to merge :)
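As a concrete illustration of that suggestion (not an excerpt from this PR), loading the tokenizer and text encoder straight from the CLIP repo could look roughly like this, assuming the openai/clip-vit-large-patch14 checkpoint:

```python
from transformers import CLIPTokenizer, FlaxCLIPTextModel

# Load directly from the CLIP repo instead of the Stable Diffusion
# checkpoint's tokenizer/ and text_encoder/ subfolders.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
```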
import jax
import jax.numpy as jnp
import optax


def zero_grads():
    # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
    def init_fn(_):
        return ()

    def update_fn(updates, state, params=None):
        return jax.tree_util.tree_map(jnp.zeros_like, updates), ()

    return optax.GradientTransformation(init_fn, update_fn)
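As a hedged usage sketch (not taken from the diff), zero_grads() is the kind of transformation that can be paired with optax.multi_transform so that only the token-embedding table gets real updates while every other parameter stays frozen; the label function and the "token_embedding" path check below are assumptions about the text-encoder param tree:

```python
import optax
from flax import traverse_util
from flax.core import unfreeze


def create_labels(params):
    # Mark the token-embedding table as "trainable" and everything else as
    # "frozen" (the path check is an assumption about the param tree layout).
    flat = traverse_util.flatten_dict(params)
    labels = {
        path: "trainable" if "token_embedding" in path else "frozen"
        for path in flat
    }
    return traverse_util.unflatten_dict(labels)


params = unfreeze(text_encoder.params)  # text_encoder as loaded above (assumed)
optimizer = optax.multi_transform(
    {"trainable": optax.adam(learning_rate=1e-4), "frozen": zero_grads()},
    create_labels(params),
)
opt_state = optimizer.init(params)
```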
Very cool!
Co-authored-by: Suraj Patil <surajp815@gmail.com>
@patil-suraj Thanks for the review. Addressed your comments. Check it out!
Thanks a lot for updating the readme and the awesome contribution, merging! Let's announce it tomorrow :)
What does this PR do?
I suppose Flax examples should be available, so I have made an attempt to create them. This first PR is on textual inversion, which is mainly based on the PyTorch implementation except for the way we freeze the token embeddings (#855).
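To make the textual-inversion setup concrete, here is a rough sketch of the placeholder-token part; the token strings and the nested param path are assumptions for illustration, not code lifted from the script:

```python
import jax.numpy as jnp
from flax.core import unfreeze
from transformers import CLIPTokenizer, FlaxCLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Register the placeholder token (example strings) and find the ids involved.
tokenizer.add_tokens("<cat-toy>")
placeholder_id = tokenizer.convert_tokens_to_ids("<cat-toy>")
initializer_id = tokenizer.encode("toy", add_special_tokens=False)[0]

# Grow the token-embedding matrix by one row and initialize the new row from
# the initializer token's embedding (the path assumes the FlaxCLIPTextModel
# param tree layout).
params = unfreeze(text_encoder.params)
emb = params["text_model"]["embeddings"]["token_embedding"]["embedding"]
params["text_model"]["embeddings"]["token_embedding"]["embedding"] = jnp.concatenate(
    [emb, emb[initializer_id][None, :]], axis=0
)
text_encoder.params = params  # the new row (placeholder_id) is what gets trained
```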
The script works well and the results look good (check out sd-concepts-library/flax-cat-toy-test) on a large-RAM CPU. However, on GPU (V100 16GB) and TPU (v3-8), I got OOM errors related to constant folding. I'm not an expert in debugging XLA, so it would be great if someone could take a look at it.
How to run
Here are the logs:
On CPU (80GB RAM): Success (peak usage: 40GB RAM)
On TPU v3-8: Error
On Tesla V100 (16GB): Error
Who can review?
cc @patrickvonplaten @patil-suraj