
Add Circle Loss Function for Similarity/Metric Learning Tasks. #20452

Merged

merged 7 commits into keras-team:master from circle_loss on Nov 6, 2024

Conversation

@ma7555 ma7555 (Contributor) commented Nov 5, 2024

Overview

This PR adds the implementation of the Circle Loss function to Keras, addressing issue #20421. Circle Loss is designed for metric learning tasks, aiming to minimize within-class distances and maximize between-class distances in the embedding space.

  1. Normal within-batch negative/positive pair mining
    • Use the Circle class directly, as with all other losses:
import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

# Small convolutional embedding model.
model = keras.Sequential()
model.add(keras.layers.InputLayer(shape=(32, 32, 3)))
model.add(keras.layers.Rescaling(1.0 / 255, offset=-1))
for i in range(3):
    model.add(
        keras.layers.Conv2D(
            32, (3, 3), padding="valid", activation="relu", name=f"conv_{i}"
        )
    )
    model.add(keras.layers.MaxPooling2D((2, 2)))
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(64, activation=None))
# Circle loss expects L2-normalized embeddings (see Notes below).
model.add(keras.layers.UnitNormalization())

model.compile(
    optimizer="adam",
    # Private path used here; exported as keras.losses.Circle once merged.
    loss=keras.src.losses.losses.Circle(),
)
model.fit(
    x_train, y_train, validation_data=(x_test, y_test), epochs=5, batch_size=2048
)
Epoch 1/5
25/25 ━━━━━━━━━━━━━━━━━━━━ 10s 362ms/step - loss: 54.0521 - val_loss: 47.4749
Epoch 2/5
25/25 ━━━━━━━━━━━━━━━━━━━━ 9s 370ms/step - loss: 46.7501 - val_loss: 45.2312
Epoch 3/5
25/25 ━━━━━━━━━━━━━━━━━━━━ 9s 365ms/step - loss: 44.8289 - val_loss: 44.0905
Epoch 4/5
25/25 ━━━━━━━━━━━━━━━━━━━━ 9s 362ms/step - loss: 43.7766 - val_loss: 43.6335
Epoch 5/5
25/25 ━━━━━━━━━━━━━━━━━━━━ 9s 347ms/step - loss: 43.2757 - val_loss: 43.0075
  2. Support for Cross-Batch Memory (XBM) training

    • Introduces ref_embeddings and ref_labels parameters to the loss function, which default to y_pred and y_true when not passed.
    • Allows the use of embeddings and labels from a memory bank or from previous batches (a minimal sketch follows this list).
    • I wrote an example of usage with XBM.
  3. Fixes to the logsumexp function for Torch/JAX

    • keras.ops.logsumexp is numerically unstable on the JAX and PyTorch backends when an entire row contains -inf values (e.g., when a sample has no positive or negative pairs), producing NaN outputs instead of the expected -inf. Native JAX/Torch return the expected -inf, but keras.ops does not use the native backend implementation. While there may be a valid reason for this that I am not aware of, I have modified the ops in the last commit to use the native operation, since all tests pass. The test still fails on the NumPy backend, which is left as-is pending the reviewer's comment on which path to follow; the NumPy backend could be switched to scipy.special.logsumexp (a small demonstration follows this list).
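A minimal sketch of the XBM usage referenced above. It assumes the functional form keras.losses.circle(y_true, y_pred, ref_labels=..., ref_embeddings=...) added by this PR; the memory-bank handling itself (the bank_* names and the FIFO policy) is illustrative, not part of the PR:

import keras
from keras import ops

def circle_loss_with_xbm(y_true, y_pred, bank_labels, bank_embeddings):
    # Compare the current batch against itself plus a bank of
    # L2-normalized embeddings/labels retained from earlier batches.
    ref_embeddings = ops.concatenate([y_pred, bank_embeddings], axis=0)
    ref_labels = ops.concatenate([y_true, bank_labels], axis=0)
    return keras.losses.circle(
        y_true,
        y_pred,
        ref_labels=ref_labels,
        ref_embeddings=ref_embeddings,
    )

# Hypothetical bank from earlier batches (would normally be a FIFO queue).
bank_embeddings = keras.random.normal((256, 64))
bank_embeddings = bank_embeddings / ops.sqrt(
    ops.sum(bank_embeddings**2, axis=1, keepdims=True)
)
bank_labels = keras.random.randint((256,), 0, 100)

And a small reproduction of the logsumexp instability: scipy.special.logsumexp returns the correct -inf for an all -inf row, while the naive max-shifted formulation propagates NaN:

import numpy as np
from scipy.special import logsumexp

# A row that is entirely -inf, e.g. a sample with no positive pairs.
row = np.array([-np.inf, -np.inf])

print(logsumexp(row))  # -inf, the expected result

# Naive max-subtraction trick: -inf - (-inf) = nan, which propagates.
m = np.max(row)
print(m + np.log(np.sum(np.exp(row - m))))  # nan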

Testing

  • Correctness testing against the PyTorch implementation can be reviewed in Colab.

Notes:

  • This is a draft PR; I will still test the loss function in a real large-model training scenario after the draft is approved.
  • This loss function requires training on L2-normalized embeddings (like CosineSimilarity). CosineSimilarity handles the normalization internally, but I did not follow that approach; instead, the docs state that the last layer should be UnitNormalization.
  • The remove_diagonal argument is kept for advanced users who train against a completely different memory bank, e.g. two different datasets where one is used as the query set and the other as the reference set (see the sketch below). In all other cases it should always be True.
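A hedged illustration of that cross-bank case, again assuming the functional signature sketched earlier; the query_*/gallery_* data below are made up:

import keras
import numpy as np

# Hypothetical L2-normalized query/gallery embeddings from two datasets.
rng = np.random.default_rng(0)
query_embeddings = rng.normal(size=(8, 64)).astype("float32")
query_embeddings /= np.linalg.norm(query_embeddings, axis=1, keepdims=True)
gallery_embeddings = rng.normal(size=(32, 64)).astype("float32")
gallery_embeddings /= np.linalg.norm(gallery_embeddings, axis=1, keepdims=True)
query_labels = rng.integers(0, 4, size=(8,))
gallery_labels = rng.integers(0, 4, size=(32,))

# No self-comparisons exist between the two banks, so the diagonal
# must not be removed here.
loss = keras.losses.circle(
    query_labels,
    query_embeddings,
    ref_labels=gallery_labels,
    ref_embeddings=gallery_embeddings,
    remove_diagonal=False,
)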

@codecov-commenter codecov-commenter commented Nov 5, 2024

Codecov Report

Attention: Patch coverage is 96.82540% with 2 lines in your changes missing coverage. Please review.

Project coverage is 82.02%. Comparing base (272bb90) to head (192e6f7).
Report is 3 commits behind head on master.

Files with missing lines             Patch %   Lines
keras/src/utils/numerical_utils.py   81.81%    0 Missing, 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20452      +/-   ##
==========================================
+ Coverage   82.01%   82.02%   +0.01%     
==========================================
  Files         514      515       +1     
  Lines       47239    47326      +87     
  Branches     7413     7424      +11     
==========================================
+ Hits        38741    38821      +80     
- Misses       6704     6706       +2     
- Partials     1794     1799       +5     
Flag               Coverage Δ
keras              81.88% <96.82%> (+0.01%) ⬆️
keras-jax          64.95% <92.06%> (+0.04%) ⬆️
keras-numpy        59.91% <92.06%> (+0.05%) ⬆️
keras-tensorflow   65.96% <90.47%> (+0.05%) ⬆️
keras-torch        64.87% <93.65%> (+0.02%) ⬆️

Flags with carried forward coverage won't be shown.


@fchollet fchollet (Collaborator) left a comment

Thanks for the PR! Please take a look at the numpy test failure.

keras/src/losses/losses.py — 3 review comments (outdated, resolved)
)

circle_loss = ops.softplus(p_loss + n_loss)
backend.set_keras_mask(circle_loss, circle_loss > 0)
Collaborator commented:

Depending on how the loss is used, the mask might not be taken into account. How critical is it?

@ma7555 ma7555 (Contributor, Author) commented Nov 5, 2024

This masking behaviour is used to mask samples in the batch that have no negative/positive pairs (a solo class with nothing to compare against). When this happens, you want to eliminate the sample from the loss: it has a loss value of exactly 0, and not masking it can affect the sum_over_batch_size reduction (making the loss lower than it really is; see the toy example after the list below). In PyTorch, they use the AvgNonZeroReducer, taking the mean only over values that are not zero.

I won't say it is critical, because:

  1. Using a data sampler for pair generation, like TFDataSampler or any similar sampler, solves this (and is standard practice in metric learning).
  2. So does using a larger batch size.
  3. It is not an issue if it only happens every now and then.
  4. In the tensorflow-similarity version, no masking is applied at all.
  5. If it happens all the time, the data-feeding sampler is the problem, not the loss function itself.
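A toy illustration of the reduction effect described above (the numbers are made up):

import numpy as np

# Per-sample circle losses; the third sample had no positive/negative
# pairs, so its loss is exactly 0.
per_sample = np.array([2.0, 3.0, 0.0, 1.0])

print(per_sample.mean())                  # 1.5 -- unmasked sum_over_batch_size
print(per_sample[per_sample > 0].mean())  # 2.0 -- AvgNonZeroReducer-style mean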

@fchollet fchollet (Collaborator) left a comment

Thanks for the update! To fix the code format check, run sh shell/format.sh. You may need to shorten text lines by hand in the docstrings.

@ma7555 ma7555 (Contributor, Author) commented Nov 6, 2024

@fchollet Looks good now. Please hold until I test it in a real model-training scenario.

@ma7555 ma7555 marked this pull request as draft November 6, 2024 09:22
@ma7555 ma7555 marked this pull request as ready for review November 6, 2024 16:31
@ma7555 ma7555 (Contributor, Author) commented Nov 6, 2024

Tested and looks fine; see the Kaggle Notebook.
Ready to pull.

@fchollet fchollet (Collaborator) left a comment

LGTM, thank you for the contribution!

@google-ml-butler google-ml-butler bot added the kokoro:force-run and ready to pull (Ready to be merged into the codebase) labels Nov 6, 2024
@fchollet fchollet merged commit 57c3589 into keras-team:master Nov 6, 2024
6 checks passed
@google-ml-butler google-ml-butler bot removed the ready to pull (Ready to be merged into the codebase) and kokoro:force-run labels Nov 6, 2024
@ma7555 ma7555 deleted the circle_loss branch November 6, 2024 20:12
wang-xianghao pushed a commit to wang-xianghao/keras-dev that referenced this pull request Nov 20, 2024
Add Circle Loss Function for Similarity/Metric Learning Tasks (keras-team#20452)

* update keras/src/losses/__init__.py, losses.py, losses_test.py and numerical_utils.py

* ruff fixes

* hotfix for logsumexp numerical unstability with -inf values

* actual fix for logsumexp -inf unstability

* Add tests, fix numpy logsumexp, and update Circle Loss docstrings.

* run api_gen.sh