-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implementation of masked autoencoder as a monai network #7598
Conversation
Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
Taking a look |
Please forgive the amount of time that it has taken to come back to you. We needed to review the situation with regards to licensing. Unfortunately, we can't accept this PR in its present form. As the code has been adapted from github.com/facebookresearch/mae and that code is released under the CC BY-NC 4.0 license, it would cause significant complications for MONAI, which is released under Apache 2.0. Although we appreciate the time and effort that has gone into this PR, any network implementation either needs to be licensed under a compatible license, or needs to be a clean-sheet implementation. If you are still interested in implementing and submitting a masked autoencoder, either from clean-sheet or as an adaptation from a compatible license, please submit a new PR. |
…8152) This follows a previous PR (#7598). In the previous PR, the official implementation was under a non-compatible license. This is a clean-sheet implementation I developed. The code is fairly straightforward, involving a transformer, encoder, and decoder. The primary changes are in how masks are selected and how patches are organized as they pass through the model. In the official masked autoencoder implementation, noise is first generated and then sorted twice using `torch.argsort`. This rearranges the tokens and identifies which ones are retained, ultimately selecting only a subset of the shuffled indices. In our implementation, we use `torch.multinomial` to generate mask indices, followed by simple boolean indexing to manage the sub-selection of patches for encoding and the reordering with mask tokens in the decoder. **Let me know if you need a detailed, line-by-line explanation of the new code, including how it works and how it differs from the previous version.** ### Description Implementation of the Masked Autoencoder as described in the paper: [Masked Autoencoders Are Scalable Vision Learners](https://arxiv.org/pdf/2111.06377.pdf) from Kaiming et al. Its effectiveness has already been demonstrated in the literature for medical tasks in the paper [Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation](https://arxiv.org/abs/2203.05573). The PR contains the architecture and associated unit tests. **Note:** The output includes the prediction, which is a tensor of size: ($BS$, $N_{tokens}$, $D$), and the associated mask ($BS$, $N_{tokens}$). The mask is used to apply loss only to masked patches, but I'm not sure it's the “best” output format, what do you think? ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr> Signed-off-by: Lucas Robinet <luca.robinet@gmail.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Description
Implementation of the Masked Autoencoder as described in the paper: Masked Autoencoders Are Scalable Vision Learners from Kaiming et al.
Masked autoencoders are a type of deep learning model that learn to reconstruct input data from partially masked or obscured versions of that data. They are structured to first "mask" or remove parts of the input image, and then attempt to reconstruct the missing pieces based solely on the available (unmasked) information.
This allows us to be much faster with fewer tokens being passed through the encoder. In addition, most representation learning methods are based on augmentations or different views of the same image and the quality of the representations depends heavily on the augmentations in question, which can be restrictive in the context of medical imaging. Masked autoencoders go some way towards overcoming this problem.
Its effectiveness has already been demonstrated in the literature for medical tasks in the paper Self Pre-training with Masked Autoencoders for Medical Image Classification and Segmentation.
The PR concerns the implementation of this method and the associated tests (note currently the tests take 48seconds to pass that might be a bit long tell me).
So far 2D training on CIFAR data
went well
and 3d training on BraTS2021 data yields this
Types of changes
./runtests.sh -f -u --net --coverage
../runtests.sh --quick --unittests --disttests
.make html
command in thedocs/
folder.