Add Circle Loss Function for Similarity/Metric Learning Tasks. #20452
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #20452 +/- ##
==========================================
+ Coverage 82.01% 82.02% +0.01%
==========================================
Files 514 515 +1
Lines 47239 47326 +87
Branches 7413 7424 +11
==========================================
+ Hits 38741 38821 +80
- Misses 6704 6706 +2
- Partials 1794 1799 +5
Flags with carried forward coverage won't be shown.
Thanks for the PR! Please take a look at the numpy test failure.
circle_loss = ops.softplus(p_loss + n_loss)
backend.set_keras_mask(circle_loss, circle_loss > 0)
Depending on how the loss is used, the mask might not be taken into account. How critical is it?
This masking behaviour is used to mask samples in the batch that have no negative/positive pairs (a solo class with nothing to compare to). When this happens, you want to eliminate the sample from the loss: it has a loss value of 0, and not masking it can skew the sum_over_batch_size reduction (making the loss lower than it really is). In PyTorch metric learning, the AvgNonZeroReducer handles this by taking the mean only over values that are not zero (see the sketch after the list below).
I wouldn't say it is critical because:
- Using a data sampler for pair generation like TFDataSampler or any similar sampler solves this (which is standard practice in metric learning).
- Using a larger batch size makes it much less likely.
- It is not an issue if it only happens every now and then.
- The tensorflow-similarity version applies no masking at all.
- If it happens all the time, the data-feeding sampler is the problem, not the loss function itself.
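For illustration, a minimal numpy sketch (not from the PR) of how zero-loss samples dilute a plain batch mean, and what an AvgNonZeroReducer-style reduction recovers:

```python
import numpy as np

# Per-sample losses; two samples had no positive/negative pairs and
# therefore contribute exactly 0 to the batch.
losses = np.array([1.2, 0.0, 0.8, 0.0])

plain_mean = losses.mean()  # 0.5 -- diluted by the zero-loss samples
nonzero = losses[losses != 0.0]
masked_mean = nonzero.mean() if nonzero.size else 0.0  # 1.0 -- what the mask recovers
```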
Thanks for the update! To fix the code format check, run `sh shell/format.sh`. You may need to shorten text lines by hand in the docstrings.
@fchollet looks good now, please hold till I test it in a real model training scenario.
Tested and looks fine: Kaggle Notebook
LGTM, thank you for the contribution!
…-team#20452)
* update keras/src/losses/__init__.py, losses.py, losses_test.py and numerical_utils.py
* ruff fixes
* hotfix for logsumexp numerical instability with -inf values
* actual fix for logsumexp -inf instability
* Add tests, fix numpy logsumexp, and update Circle Loss docstrings.
* run api_gen.sh
Overview
This PR adds the implementation of the Circle Loss function to Keras, addressing issue #20421. Circle Loss is designed for metric learning tasks, aiming to minimize within-class distances and maximize between-class distances in the embedding space.
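For context, a minimal numpy sketch of the per-sample Circle Loss from the paper (Sun et al., 2020); names like `circle_loss_single` are illustrative only and are not the PR's implementation:

```python
import numpy as np
from scipy.special import logsumexp

def circle_loss_single(sp, sn, gamma=80.0, margin=0.4):
    # sp: positive-pair similarities, sn: negative-pair similarities (1-D).
    op, on = 1.0 + margin, -margin   # similarity optima O_p, O_n
    dp, dn = 1.0 - margin, margin    # decision margins Delta_p, Delta_n
    ap = np.maximum(op - sp, 0.0)    # adaptive positive weights alpha_p
    an = np.maximum(sn - on, 0.0)    # adaptive negative weights alpha_n
    logit_p = -gamma * ap * (sp - dp)
    logit_n = gamma * an * (sn - dn)
    # softplus(LSE(logit_n) + LSE(logit_p)) == log(1 + sum_n sum_p exp(...)),
    # matching the softplus(p_loss + n_loss) form in the diff above.
    return np.logaddexp(0.0, logsumexp(logit_n) + logsumexp(logit_p))

# Example: two positive and three negative cosine similarities.
print(circle_loss_single(np.array([0.9, 0.7]), np.array([0.4, 0.3, 0.1])))
```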
Support for Cross-Batch Memory (XBM) Training
Added `ref_embeddings` and `ref_labels` parameters to the loss function, which default to y_pred and y_true if not passed.
Fixes to `logsumexp` Function for Torch/JAX
The existing op mishandles `-inf` values (e.g., when a sample has no positive or negative pairs), resulting in NaN outputs instead of the expected `-inf`. In native JAX/Torch, this operation returns the expected `-inf`, but keras.ops does not use the native backend implementation. While you may have a valid reason for this that I am not aware of, I have modified the ops in the last commit to use the native operation, since all tests pass. The test still fails on the numpy backend, which is left as is while waiting for the reviewer's comment on which path to follow; the numpy backend could be replaced with `scipy.special.logsumexp`.
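A minimal numpy repro of this failure mode (an assumed illustration of the issue, not the keras.ops code itself):

```python
import numpy as np
from scipy.special import logsumexp

x = np.array([-np.inf, -np.inf])  # e.g., a row with no positive pairs

# Naive max-shifted logsumexp: max(x) is -inf, so x - max(x) is
# (-inf) - (-inf) = nan, and the whole result becomes NaN.
m = np.max(x)
naive = m + np.log(np.sum(np.exp(x - m)))

print(naive)         # nan
print(logsumexp(x))  # -inf, the mathematically expected value
```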
Testing
Notes:
The `remove_diagonal` argument is kept for advanced users who want to train across a completely different memory bank, e.g., two different datasets where one is used as the query set and the other as the reference set. In all other cases, this should always be `True`.
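A hypothetical usage sketch, assuming the class and parameter names described in this PR (the merged API may differ):

```python
import keras

# Standard in-batch usage: y_pred doubles as the reference set, so
# remove_diagonal=True stops each sample from being paired with itself.
loss_fn = keras.losses.Circle(gamma=80.0, margin=0.4, remove_diagonal=True)

# XBM-style usage per the PR description: compare the current batch against
# a separate memory bank. ref_embeddings/ref_labels and this functional call
# follow the parameters described above and are assumptions, not confirmed API.
# losses = keras.losses.circle(
#     y_true, y_pred,
#     ref_labels=memory_labels,
#     ref_embeddings=memory_embeddings,
#     remove_diagonal=False,  # diagonal no longer means "self" across banks
# )
```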