[Flax] Add Textual Inversion #880

Merged
merged 10 commits into huggingface:main from duongna21:add-flax-textual-inversion
Oct 26, 2022

Conversation

duongna21
Contributor

@duongna21 duongna21 commented Oct 17, 2022

What does this PR do?

I think Flax examples should be available, so I've made an attempt to create them. This first PR adds textual inversion and is mainly based on the PyTorch implementation, except for the way we freeze the token embeddings (#855); a rough sketch of the freezing idea follows.
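
Roughly, the freezing works by labeling every parameter and routing the frozen ones through an optimizer that always emits zero updates (see the zero_grads snippet discussed later in this thread). A minimal sketch of the labeling step, assuming params is a plain nested dict; the "token_embedding" path check is illustrative, not necessarily the exact name used in the script:

from flax import traverse_util

def create_mask(params):
    # Label each leaf: train only the token embedding, freeze everything else.
    # "token_embedding" is an illustrative path segment; the real name depends
    # on the CLIP text encoder's module structure.
    flat = traverse_util.flatten_dict(params)
    labels = {
        path: "trainable" if "token_embedding" in path else "frozen"
        for path in flat
    }
    return traverse_util.unflatten_dict(labels)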

The script works well and the results look good (check out sd-concepts-library/flax-cat-toy-test) on a large-RAM CPU. However, on GPU (V100 16GB) and TPU (v3-8), I got OOM errors that appear to be related to constant folding. I'm not an expert in debugging XLA, so it would be great if someone could take a look at it.

How to run

export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export DATA_DIR="path-to-dir-containing-images"

python textual_inversion_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATA_DIR \
  --learnable_property="object" \
  --placeholder_token="<cat-toy>" --initializer_token="toy" \
  --resolution=512 \
  --train_batch_size=1 \
  --max_train_steps=3000 \
  --learning_rate=5.0e-04 --scale_lr \
  --output_dir="textual_inversion_cat"

Here are the logs:

On CPU (80GB RAM): Success (peak usage: 40GB RAM)

2022-10-17 02:03:10.312421: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  transpose.1257 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2022-10-17 02:03:10.597724: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 1.285499784s
Constant folding an instruction is taking > 1s:

  transpose.1257 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2022-10-17 02:06:13.405402: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module pmap_train_step] Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************

On TPU v3-8: Error

[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/message_lite.cc:484] xla.HloModuleProto exceeded maximum protobuf size of 2GB: 3732619020
2022-10-17 14:13:23.584071: F external/org_tensorflow/tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h:44] Check failed: proto.SerializeToArray(bytes, size) 
https://symbolize.stripped_domain/r/?trace=7f2759ab403b,7f2759ab40bf,7f262c4cf135,7f262c4cf5b4,7f263001de93,7f263000b689,7f2630000a52,7f262ef283b1,7f262ef3860e,7f262c82efc7,7f262c6cd4e1,7f262c6cda42,7f262c6a4edb,5f6928,906aff&map= 
*** SIGABRT received by PID 110427 (TID 110427) on cpu 25 from PID 110427; stack trace: ***
PC: @     0x7f2759ab403b  (unknown)  raise
    @     0x7f25c6594e74       1120  (unknown)
    @     0x7f2759ab40c0  162761584  (unknown)
    @     0x7f262c4cf136        432  stream_executor::tpu::SerializeProto<>()
    @     0x7f262c4cf5b5       3520  xla::(anonymous namespace)::TpuCompiler::RunHloPasses()
    @     0x7f263001de94        624  xla::Service::BuildExecutable()
    @     0x7f263000b68a       1152  xla::LocalService::CompileExecutables()
    @     0x7f2630000a53       2832  xla::LocalClient::Compile()
    @     0x7f262ef283b2        912  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f262ef3860f       1264  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f262c82efc8       1248  xla::PyClient::CompileMlir()
    @     0x7f262c6cd4e2       2080  pybind11::detail::argument_loader<>::call_impl<>()
    @     0x7f262c6cda43        224  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7f262c6a4edc        816  pybind11::cpp_function::dispatcher()
    @           0x5f6929  (unknown)  PyCFunction_Call
    @           0x906b00  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f2759ab403b,7f25c6594e73,7f2759ab40bf,7f262c4cf135,7f262c4cf5b4,7f263001de93,7f263000b689,7f2630000a52,7f262ef283b1,7f262ef3860e,7f262c82efc7,7f262c6cd4e1,7f262c6cda42,7f262c6a4edb,5f6928,906aff&map=3e96c9a48e1e2529ef5f4f875ba7cd3d:7f25b6321000-7f25c68579c0 
E1017 14:13:23.680544  110427 coredump_hook.cc:395] RAW: Remote crash data gathering hook invoked.
E1017 14:13:23.680559  110427 coredump_hook.cc:441] RAW: Skipping coredump since rlimit was 0 at process start.
E1017 14:13:23.680570  110427 client.cc:243] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E1017 14:13:23.680578  110427 coredump_hook.cc:502] RAW: Sending fingerprint to remote end.
E1017 14:13:23.680591  110427 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E1017 14:13:23.680608  110427 coredump_hook.cc:506] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E1017 14:13:23.680613  110427 coredump_hook.cc:580] RAW: Discarding core.
E1017 14:13:23.987094  110427 process_state.cc:775] RAW: Raising signal 6 with default behavior
Aborted (core dumped)

On Tesla V100 (16GB): Error

2022-10-17 15:00:23.861175: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  reverse.18070 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2022-10-17 15:00:24.029558: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:133] The operation took 1.170920914s
Constant folding an instruction is taking > 1s:

  reverse.18070 (displaying the full instruction incurs a runtime overhead. Raise your logging level to 4 or above).

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime.  XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Killed

Who can review?

cc @patrickvonplaten @patil-suraj

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Oct 17, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Contributor

Very cool idea @duongna21 - cc'ing @patil-suraj here for a review :-)

Contributor

@patil-suraj patil-suraj left a comment


Thanks a lot for adding this example, very cool!

From a first look it looks good. I'll try this on a TPU v3/v2 and then do a detailed review :)

@duongna21
Contributor Author

duongna21 commented Oct 25, 2022

@patil-suraj I found and fixed the bug related to constant folding. It turns out that I forgot to replicate the params of vae and unet across devices. Now the script works well on Tesla V100 (16GB) and TPU v3-8.
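
For anyone who hits the same wall: parameters that are merely closed over by a pmapped train step get baked into the compiled program as constants, which matches the >2GB HloModuleProto error above. Replicating them across devices first avoids this. A minimal sketch, assuming vae_params and unet_params hold the loaded weights:

from flax import jax_utils

# Replicate the frozen weights once, before the training loop, so each device
# holds its own copy instead of XLA constant-folding them into the compiled
# pmap program.
vae_params = jax_utils.replicate(vae_params)
unet_params = jax_utils.replicate(unet_params)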

Contributor

@patrickvonplaten patrickvonplaten left a comment


@patil-suraj I leave it up to you to merge the PR if you're happy with it

Contributor

@patil-suraj patil-suraj left a comment


Amazing work, just tried this on a v3-8 and it works great! I just left one comment: until the PR in transformers is merged, we could load the CLIP model directly from its repo.

Also, it would be awesome if you could add a section to the README on how to use this script. Then it should be good to merge :)

Comment on lines +462 to +470
def zero_grads():
    # from https://github.com/deepmind/optax/issues/159#issuecomment-896459491
    def init_fn(_):
        return ()

    def update_fn(updates, state, params=None):
        return jax.tree_util.tree_map(jnp.zeros_like, updates), ()

    return optax.GradientTransformation(init_fn, update_fn)
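
A hedged sketch of how zero_grads would typically be combined with optax.multi_transform to freeze the labeled parameters (create_mask and the learning rate here are illustrative, not taken from the PR; newer optax releases also ship optax.set_to_zero, which serves the same purpose as zero_grads):

import optax

# Frozen leaves receive zero updates from zero_grads(); trainable ones use AdamW.
tx = optax.multi_transform(
    {"trainable": optax.adamw(learning_rate=5e-4), "frozen": zero_grads()},
    create_mask(text_encoder_params),  # labels pytree mapping leaves to groups
)
opt_state = tx.init(text_encoder_params)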
Contributor


Very cool!

@duongna21 duongna21 changed the title from "Add Textual inversion Flax" to "[Flax] Add Textual Inversion" on Oct 26, 2022
@duongna21
Contributor Author

duongna21 commented Oct 26, 2022

Amazing work, just tried this on a v3-8 and it works great! I just left one comment: until the PR in transformers is merged, we could load the CLIP model directly from its repo.

Also, it would be awesome if you could add a section to the README on how to use this script. Then it should be good to merge :)

@patil-suraj Thanks for the review. Addressed your comments. Check it out!

@patil-suraj
Contributor

patil-suraj commented Oct 26, 2022

Thanks a lot for updating the README, and for the awesome contribution. Merging!

Let's announce it tomorrow :)

@patil-suraj patil-suraj merged commit a23ad87 into huggingface:main Oct 26, 2022
@duongna21 duongna21 deleted the add-flax-textual-inversion branch October 26, 2022 22:22
@duongna21 duongna21 mentioned this pull request Nov 3, 2022
PhaneeshB pushed a commit to nod-ai/diffusers that referenced this pull request Mar 1, 2023