[Critical] Very high loss rate at first few tokens (classifier free guidance not working) #80
Comments
This sounds critical indeed. Hopefully it's an easy fix. |
I think I've resolved this issue by tokenizing the text, inserting it at the start of the codes, and adding a special token to indicate the start of the mesh tokens. I tested it on a smaller dataset but it seems to be working! |
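For readers following along, a minimal sketch of this kind of fix, assuming the text tokenizer's ids and the mesh codes are merged into one shared vocabulary; the names and vocab sizes below are illustrative, not the actual patch:

```python
import torch

# hypothetical ids -- the real vocab layout depends on the model config
TEXT_VOCAB_SIZE = 30522            # e.g. the text tokenizer's vocab size
MESH_START_ID   = TEXT_VOCAB_SIZE  # special token marking the start of mesh tokens

def build_sequence(text_token_ids: torch.Tensor, mesh_codes: torch.Tensor) -> torch.Tensor:
    """Prepend the tokenized text to the mesh codes, separated by a special
    start-of-mesh token, so both live in one shared vocabulary.
    text_token_ids: (batch, text_len) long, mesh_codes: (batch, mesh_len) long."""
    # shift the mesh codes past the text vocab and the special token
    mesh_token_ids = mesh_codes + TEXT_VOCAB_SIZE + 1
    start = torch.full((text_token_ids.shape[0], 1), MESH_START_ID, dtype=torch.long)
    return torch.cat([text_token_ids, start, mesh_token_ids], dim=1)
```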
@MarcusLoppe That is fantastic! Have you posted the fix somewhere? |
Not yet, my current way is a bit hacky and requires a bit of a rewrite to properly implement. I'm currently verifying the solution on a somewhat bigger dataset and will hammer out all the possible bugs. |
@MarcusLoppe hey Marcus, thanks for identifying this issue. have you tried turning off CFG? if you haven't, one thing to try is simply turning off CFG for the first few tokens. i think i've come across papers where they studied at which steps CFG is even effective. also try turning off CFG and doing greedy sampling and see what happens. if that doesn't work, there is some big issue |
With CFG you mean classifier-free guidance? Not sure how I would go about that; do you mean setting cond_drop_prob to 0.0? The issue is when the transformer has an empty sequence and only the text embedding to go from. The text embedding doesn't seem to help very much, so it doesn't know which token to pick, hence the huge loss at the start. |
@MarcusLoppe oh, maybe it is already turned off then; CFG is only turned on by explicitly setting it at generation time, so if you haven't been using that it isn't the cause |
@MarcusLoppe oh crap, do i only append the start token for the fine transformer?? 🤦 yes you are correct, it is never conditioned then for the first set of fine tokens |
thank you, this is a real issue then. i'll add cross attention to the fine transformer later today. edit: another cheap solution would be to project the text embeddings, pool them, and just use that as the initial sos token |
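A rough sketch of the "project, pool, and use as the sos token" idea being proposed here; the module name and shapes are assumptions for illustration, not the repo's implementation:

```python
import torch
from torch import nn

class TextConditionedSOS(nn.Module):
    """Project pooled text embeddings into the decoder dimension and use the
    result as the start-of-sequence token, so the very first prediction is
    already text conditioned."""
    def __init__(self, text_dim: int, model_dim: int):
        super().__init__()
        self.to_sos = nn.Linear(text_dim, model_dim)

    def forward(self, text_embeds: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # text_embeds: (b, n, text_dim), mask: (b, n) with True for real tokens
        mask = mask.float()
        denom = mask.sum(dim=1, keepdim=True).clamp(min=1)
        pooled = (text_embeds * mask.unsqueeze(-1)).sum(dim=1) / denom  # mean pool
        return self.to_sos(pooled).unsqueeze(1)  # (b, 1, model_dim) sos token
```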
Awesome! My 'fix' does seem to be working, however. The downside is that it needs a bigger vocab, which slows the training a bit, but the stronger relationship between the mesh tokens and the text seems to be working :)
I had some issues with providing the context to the fine decoder since the vector changes shape, but you might be able to solve it. However, I tried removing the gateloop and fine decoder so the main decoder is the last layer, and unfortunately it had the same issue. |
@MarcusLoppe yup, your way is also legit 😄 you have a bright future Marcus! finding this issue, the analysis, coming up with your own solution; none of that is trivial |
Thank you very much 😄 Although it took a while I think I've learned one or two things on the way 😄
I don't think the cross-attention will be enough; as per my last reply, I removed the fine decoder and gateloop and had the same issue. If you think about multimodal generative models, they never start from token 0. For example, the vision models have a prompt with a specific request from the user. I think projecting the text embeddings might be the better way in this case. |
@MarcusLoppe yup i went with the pooled text embedding summed to the sos token for now let me know if that fixes things (or not) 🤞 |
this was really my fault for designing the initial architecture incorrectly the sos token should be on the coarse transformer |
Awesome! I'll check it out 🚀 However, with the last x-transformers update I'm getting the error below, and the dim_head in meshgpt isn't being passed under the name it should be ("attn_dim_head"): 1057 assert len(kwargs) == 0, f'unrecognized kwargs passed in {kwargs.keys()}' AssertionError: unrecognized kwargs passed in dict_keys(['dim_head', 'num_mem_kv']) |
@MarcusLoppe ah yes, those should have the attn_ prefix |
@MarcusLoppe ok, i threw in cross attention conditioning for the fine transformer in 1.2.3. if that doesn't work, i'll just spend some time refactoring so the start token is in the coarse transformer. that would work for sure, as it is equivalent to your solution, except the text tokens do not undergo self attention |
@MarcusLoppe thanks for running the experiments! |
The loss improved much faster over the epochs, however it had some downsides. Unfortunately it did not work :( |
@MarcusLoppe ah, thank you. ok, final try; will have to save this for late next week if it doesn't work |
It worked better. Here is the result of training it on 39 models with unique labels; however, you can still see a spike at the start of the sequence, meaning it might not be resolved. Using my method I managed to get the results below, and it manages to generate quite complex objects. I've also experimented with using 3 tokens per triangle, and the autoencoder seems to be working, however it makes the training progression for the transformer slower. But considering that the VRAM requirement for training on 800-triangle meshes would go from 22GB to 9GB and the generation time would be halved, I think that is something worth exploring. However, I think the autoencoder could also benefit from getting the text embeddings. I tried to pass them as the context in the linear attention layer, but since it requires the same shape as the quantized input it won't accept them, nor do I think it would be very VRAM friendly to duplicate the text embedding to the number of faces. |
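On the point about duplicating the text embedding per face: one sketch of conditioning per-face features on a pooled text embedding without materializing copies, since expand() only creates a broadcast view. The module and names are illustrative, not the library's API:

```python
import torch
from torch import nn

class FaceTextConditioner(nn.Module):
    """Add a projected, pooled text embedding to every face feature.
    expand() returns a view, so the text embedding is not physically
    duplicated across the (possibly hundreds of) faces."""
    def __init__(self, text_dim: int, face_dim: int):
        super().__init__()
        self.proj = nn.Linear(text_dim, face_dim)

    def forward(self, face_feats: torch.Tensor, pooled_text: torch.Tensor) -> torch.Tensor:
        # face_feats: (b, num_faces, face_dim), pooled_text: (b, text_dim)
        cond = self.proj(pooled_text).unsqueeze(1)               # (b, 1, face_dim)
        return face_feats + cond.expand(-1, face_feats.shape[1], -1)
```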
@MarcusLoppe that is much better (an order of magnitude)! thank you for the experiments Marcus! i know how to improve it (can add multiple sos tokens to give the attention more surface area) |
@MarcusLoppe i'll get back to this later this week 🙏 |
@MarcusLoppe oh, the sos token has already been moved to the coarse transformer in the latest commit. that's where the improvement you are seeing is coming from |
Oh awesome! However, the loss only got very low (0.006) for these results; for the bigger datasets the loss gets to about 0.01 and would need something like 1000 epochs to reach a similar loss. So some further improvements would be nice! 😄 |
@MarcusLoppe yup, we can try multiple sos tokens, and if that doesn't work, i'll build in the option to use prepended text embeddings (so like the solution you came up with, with the additional sos excised or pooled before the fine transformer). and yes, a text embedding aware autoencoder is achievable! in fact, the original soundstream paper did this |
I don't have the figures for them, but I tried 16 and got bad results. I know that setting up the sos tokens before the decoder and then inserting them after the cross attention will create some sort of learnable relationship, and I assume the tokens change with the loss. So is it possible to reshape (with any nn) the text embeddings to the dim size, insert them at the start of the sequence, and then add a special token? |
ok, I'll try one more thing, but if that doesn't work, let's go for your text prefix solution |
A little bit off topic, but I trained a 1-quantizer autoencoder and transformer and got good results. The progression was a little slower, but I got to about 0.03 loss with the transformer. So that is a big win: halving the sequence length and reducing the VRAM requirement from 22 GB to 8 GB in training (800 faces) |
Hi again. Here are some failed results:
I was wondering if the main decoder's cross-attention layer could handle it alone, but with just the decoder layer it couldn't handle any part of the sequence. The best result I got was with the commit below, however it may just be luck and not consistent behaviour. The linear attention method had similar results, but without the slowness of adding cross-attention to the fine decoder. Trained for many, many epochs using "add cross attention based text conditioning for fine transformer too" plus a linear layer with 4 sos tokens
Token 1: Correct = 237, Incorrect = 118 |
@MarcusLoppe thank you Marcus! 🙏 will get a few solutions in soon i also realized i wasn't caching the cross attention key / values correctly 🤦 will also get that fixed this morning |
Awesome! 😄 It works kind of well when the dataset is small (<500); I don't think it's the model size, since it can remember 10k models if it's prompted with a few tokens. Btw, let me know if I'm doing something wrong, but during my testing I just call forward_on_codes, take the logits, and pick the token by argmax. |
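For reference, the kind of greedy check being described: take the logits for the sequence so far and argmax the last position. This assumes logits shaped (batch, seq_len, vocab); the exact call to obtain them is omitted since the signature may vary by version:

```python
import torch

@torch.no_grad()
def greedy_next_token(logits: torch.Tensor) -> torch.Tensor:
    """Given logits of shape (batch, seq_len, vocab) for the sequence so far,
    pick the highest-probability token at the last position (greedy sampling)."""
    return logits[:, -1].argmax(dim=-1)   # (batch,)

# usage sketch -- logits obtained e.g. from the transformer's forward pass on the
# codes generated so far
# next_token = greedy_next_token(logits)
```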
Hey again. So I've noticed some strange behaviour with the cross attention num_mem_kv that might help you resolve the issue. Using the commit with the fine-decoder cross-attention I found the results below. This made it possible to generate a mesh from token 0, since it seems to be hitting the correct tokens; as you can see, the mesh is hardly smooth, but at least it's selecting the correct first token! I'm currently training to see if using x5 augmentation of the same dataset will yield any better results, since it might be more robust. I also tested setting the fine depth to either 4 or 8, but both worsened the performance; the same goes for increasing attn_num_mem_kv to 16. I also tested using 16 cross_attn_num_mem_kv on all the other solutions you've posted, but there were no noticeable changes. Commit: 5ef6cbf
|
Hey, @lucidrains
Plus a few other tricks. I wouldn't say this issue is resolved, since with a dataset of 1k unique labels the generation will steer towards the most average mesh model according to the text embeddings; you can see this averaging effect in the second image (cond scale helps sometimes, but setting it too high will turn the mesh into a blob).
Possible issue / accidental feature
I'm not sure if it's a problem, but since I add the sos_token before the main decoder and then add the text embedding pooling afterwards, 2 tokens with 'value' are added, and with the padding it becomes 12 tokens. The result is that 1 token is replaced/lost due to the right shift, since 2 tokens are added but only the sos_token is removed. This is just a guess, but since the output then covers a longer window (12 tokens into the future instead of 6), it might help with inference, since during training it outputs what it thinks might come after the EOS token. However, this output is cut off and doesn't affect the loss, so I'm not sure if it matters (I also increased the padding to 18 tokens, but the performance degraded).
Multi-token prediction
I've been trying to understand how the transformer trains: at the end there is always 1 extra face (6 tokens), and then the sequence is cut off so 5 tokens remain. I'm guessing this is done for the autoregression and the EOS token. Here are some samples after training 15 epochs on the first 12 tokens of 2.8k meshes with 1000 labels:
500 labels with 10 models for each label: 2k codebook, number of quantizers: 2
1000 labels with 5 models for each label: 2k codebook, number of quantizers: 2
100 labels with 25 models for each label: 16k codebook, number of quantizers: 1 |
@MarcusLoppe thanks Marcus for the detailed report and glad to hear you found a solution! i'll keep chipping away at it to strengthen the conditioning. next up is probably to add adaptive layer/rms normalization |
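For readers following the thread, a minimal sketch of what adaptive (FiLM-style) normalization conditioning generally looks like: a learned scale and shift predicted from the conditioning embedding, applied after normalization. This is an illustration of the idea under assumed shapes, not the repo's exact implementation:

```python
import torch
from torch import nn

class AdaptiveRMSNorm(nn.Module):
    """RMS-normalize the hidden states, then modulate them with a scale and
    shift predicted from the conditioning (e.g. pooled text) embedding."""
    def __init__(self, dim: int, cond_dim: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.to_gamma = nn.Linear(cond_dim, dim)
        self.to_beta = nn.Linear(cond_dim, dim)
        # initialize to an identity modulation so training starts stable
        for lin in (self.to_gamma, self.to_beta):
            nn.init.zeros_(lin.weight)
            nn.init.zeros_(lin.bias)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (b, n, dim), cond: (b, cond_dim)
        normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        gamma = self.to_gamma(cond).unsqueeze(1)   # (b, 1, dim)
        beta = self.to_beta(cond).unsqueeze(1)
        return normed * (1 + gamma) + beta
```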
@MarcusLoppe you should see a slight speed bump on inference now, as i now cache the key / values correctly for cross attention hope i didn't break anything! |
Lovely :) I'll test the FiLM normalization method and let you know. However, the sos_token isn't quite how I implemented it; I've had more success just leaving the sos_token in without unpacking it.
I tried explaining it before with my tests, but I might not have been clear enough. Here is the implementation I've used
I'll give it a go :) So changing the below made the 1-quantizer generation work. |
thanks for reporting the rounding down issue! and yes, i can clean up the multiple sos tokens code if not needed. however, by setting just 1 sos token, it should be equivalent to what you deem the best working commit |
Just tested it on the first 12 tokens, and using FiLM + mean has worse performance, plus it's giving me nan loss. Although I'm not sure it would help, since the issue might be the mean pooling. Say there are 1000s of text embeddings which are all unique: the cross-attention receives them in their original state, but the fine decoder only gets the average of each embedding as an additional token. |
@MarcusLoppe thanks for running it, made a few changes. yea, can definitely try attention pooling, which is a step up from mean pooling |
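A sketch of what attention pooling over the text embeddings could look like: a single learned query attends over all text token embeddings instead of averaging them, so informative tokens can be weighted more heavily. Illustrative only, with assumed shapes:

```python
import torch
from torch import nn
import torch.nn.functional as F

class AttentionPool(nn.Module):
    """Pool a sequence of text embeddings into a single vector with a learned
    query, as an alternative to mean pooling."""
    def __init__(self, dim: int):
        super().__init__()
        self.query = nn.Parameter(torch.randn(dim))
        self.to_kv = nn.Linear(dim, dim * 2, bias=False)

    def forward(self, text_embeds: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # text_embeds: (b, n, dim), mask: (b, n) bool with True for real tokens
        k, v = self.to_kv(text_embeds).chunk(2, dim=-1)
        scores = torch.einsum('d,bnd->bn', self.query, k) / k.shape[-1] ** 0.5
        scores = scores.masked_fill(~mask, float('-inf'))
        attn = F.softmax(scores, dim=-1)                 # (b, n)
        return torch.einsum('bn,bnd->bd', attn, v)       # (b, dim) pooled embedding
```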
Okay, some updates.
CLIP seems to work better than T5 & BGE on longer sentences and contains more nuanced information.
So after making these changes I tried again and had much better success :) The results are very good; however, there are some issues, e.g. in the test using x5 models per label I have a very hard time generating 3 distinct rows of the same furniture labels. I trained on a dataset using 775 labels with 5x examples each (3.8k meshes); the first tests were trained on only 60 tokens total, then for the latest one I trained on the full 1500-token sequences. I also tested using x10 examples, but that training run requires more time to get an accurate picture of its performance.
Using FiLM
T5, trained on 60 tokens:
BGE, trained on 60 tokens:
BGE trained on 1500 tokens:
Without FiLM
T5, trained on 60 tokens:
BGE, trained on 60 tokens:
BGE trained on 1500 tokens:
775 labels with 10x examples each (7.7k meshes),
Renders: with FiLM |
@MarcusLoppe awesome, i think after adding the adaptive layernorms that will surely be enough |
@lucidrains I have this idea: since the cross-attention is to a text embedding, which has a many-to-many relationship with the tokens, if it instead used cross-attention to the sequence itself it would have more or less a one-to-many relationship. I was wondering if using something like the below would work? The multiple sos tokens are just kept for the main decoder but not for the rest of the network, and there is no Q-Former architecture that takes the text embeddings and encodes the information into them. |
@MarcusLoppe oh, i don't even know what the q-former architecture is haha. i'll have to read it later this week, but it sounds like just a cross attention based recompression, similar to the perceiver resampler. just got the adaptive layernorm conditioning into the repo! i think we can safely close this issue |
@MarcusLoppe we can chat about q-former in the discussions tab |
@MarcusLoppe oh yes, the qformer architecture is in vogue. bytedance recently used it for their vqvae to compress images even further. will explore this in the coming month for sure! |
Would it be possible to explore this sooner? :) Or maybe provide a hint on how to do it? However, I'm still a bit unsure whether the cross-attention is the best way, since I had some trouble using it for 10k labels. |
@MarcusLoppe is he/she a phd or ms student? if so, you two should be able to work together and implement it; it could even make for a short paper. or i can take a look at it, but probably not for another month or two |
I think he's a PhD student; he applied for the compute a while ago and was granted it. It's not for a thesis or graded paper, but perhaps a technical report. I'm happy and interested to implement it myself, but like many SOTA things it might be above my head. |
@MarcusLoppe ok, if he's a phd student, you two should definitely be able to work it out from the code already available |
@MarcusLoppe i'm not sure without spending a day reading the paper, but it looks to me like they are simply using appended "query" tokens, which are similar to memory / register tokens in the literature. they simply concat them to the sequence, attend to everything, and then slice them out. it is similar to the sos tokens we've been playing around with, except it isn't autoregressive |
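For concreteness, a rough sketch of the "appended query / register tokens" idea described above: learned tokens are concatenated to the sequence, attend together with it through an encoder, and are then sliced back out as a compressed summary. This is an illustration of the general pattern, not the Q-Former or the repo's code:

```python
import torch
from torch import nn

class QueryTokenPooler(nn.Module):
    """Concatenate a small set of learned query tokens to the input sequence,
    let them attend to everything through an encoder, then slice them out as a
    compressed summary of the sequence."""
    def __init__(self, dim: int, num_queries: int = 32, depth: int = 2, heads: int = 8):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, dim) * 0.02)
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        # seq: (b, n, dim)  ->  summary: (b, num_queries, dim)
        b = seq.shape[0]
        q = self.queries.unsqueeze(0).expand(b, -1, -1)
        out = self.encoder(torch.cat([q, seq], dim=1))
        return out[:, : self.queries.shape[0]]   # keep only the query slots
```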
@MarcusLoppe ask your collaborator! he should know if he is in the field |
I've read a bit further; you might be right and I may not be understanding your terminology. My understanding is that they train an autoencoder (tokenizer) and only use 32 tokens to represent the image. I'm not quite sure if it's applicable to this project. I played around a bit with using the sos tokens, however I got worse results. I was thinking that maybe the issue isn't that the text embeddings are too weak, but that the cross-attention messes it up a bit. |
@MarcusLoppe you should def chat with your collaborator (instead of me) since you'll be training the next model together he will probably be more up-to-date with mesh research too, as he is following it full time |
@lucidrains
This is an issue I've been having for a while: the cross-attention is very weak at the start of the sequence.
When the transformer starts with no tokens it has to rely on the cross-attention, but unfortunately the cross-attention doesn't work for the first token(s).
Proof
To prove this I trained on a dataset of 500 models that have unique text embeddings and no augmentations, and I only took the first 6 tokens of each mesh and trained on that.
After training for 8hrs, it's still stuck at 1.03 loss.
Without fixing this issue, autoregressive generation without a prompt of tokens will never work.
This problem has been ongoing for a while, but I thought it was a training issue and that using a model trained on the first few tokens would resolve it. However, that isn't the case.
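A quick way to see this behaviour (a diagnostic sketch, not the training code): compute the cross-entropy separately per token position and watch the first positions stay high while the rest drop. Shapes are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def loss_per_position(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Cross-entropy broken out per sequence position.
    logits: (batch, seq_len, vocab), targets: (batch, seq_len) token ids.
    Returns a (seq_len,) tensor -- a spike at the first positions indicates the
    model cannot predict the first tokens from the text conditioning alone."""
    b, n, v = logits.shape
    losses = F.cross_entropy(logits.reshape(b * n, v), targets.reshape(b * n), reduction='none')
    return losses.reshape(b, n).mean(dim=0)
```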
Real-life example
To highlight the issue, I trained a model on the 13k dataset then removed all the augmentation copies and removed models with duplicate labels.
If I provide it with the first 2 tokens as a prompt it will autocomplete with no problem and no visual issues; however, if I provide it with 1 or 0 tokens it fails completely.
Checked the logits
I investigated this further and checked the logits when it generated the first token: the correct token was only the 9th most probable.
I tried to implement a beam search with a beam width of 5, but since the first token has such a low probability it would require a lot of beams. That would probably work, but it is a brute-force solution and not a good one.
It may work to do a beam search of 20 and then kill off the candidates with low probability/entropy, but this seems like a band-aid solution that might not work when scaling up meshgpt.
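A small sketch of the logits check described above, computing at which rank the correct first token appears (the "9th most probable" observation); the function name and shapes are illustrative:

```python
import torch

@torch.no_grad()
def rank_of_correct_token(first_token_logits: torch.Tensor, correct_id: int) -> int:
    """Return the 1-based rank of the ground-truth token among the logits for
    the first generated position (1 = the model's top choice).
    first_token_logits: (vocab,) tensor of logits."""
    order = first_token_logits.argsort(descending=True)       # token ids, best first
    return int((order == correct_id).nonzero(as_tuple=True)[0].item()) + 1
```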
Why is this a problem?
The first tokens are very important for the generation since it's a domino effect: if it picks the incorrect token at the start, the generation will fail, because it relies too much on the sequence to auto-correct.
It's like if the sentence is "Dog can be a happy animal" and it predicts "Human" as the first token: it won't be able to auto-correct, since the sentence is already messed up, and the chance it will auto-correct to "Human got a dog which can be a happy animal" is extremely low.
Possible solution
Since the cross-attention is used only on the "big" decoder, could it also be implemented for the fine decoder?
Attempts to fix:
This has been a problem for a long time and I've mentioned it as a note in other issue threads, so I'm creating an issue for it since it really prevents me from releasing fine-tuned models.
I have a model ready to go that can predict 13k models, but since the first tokens make autoregressive generation impossible, I've not released it yet.
Here are some images of the loss: