[Falcon Mamba] Unexpected model output with use_cache=False and model.train() #33234
Comments
#33195 should resolve this; the original kernels in the mamba-ssm package are missing some additional RMS normalization.
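(For context, the "additional RMS normalization" is essentially a weightless RMS normalization of intermediate projections. The sketch below is illustrative only; which tensors it is applied to and the epsilon value are assumptions, not taken from the PR.)

```python
import torch

def rms_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Weightless RMS normalization over the last dimension (illustrative only).
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps)

# Hypothetical usage: apply to the dt/B/C projections before the selective
# scan; whether the kernel path does exactly this is an assumption here.
dt, B, C = (torch.randn(2, 16, 8) for _ in range(3))
dt, B, C = rms_normalize(dt), rms_normalize(B), rms_normalize(C)
```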
Thanks for the link. I tried with that PR; the output changes/improves, but it is still not identical to the other three cases:
@jploski the output is not exactly the same, but it looks coherent compared to the previous results:

token id  logit  decoded token
204       118.5  [ ]
627       118.0  [ new]
1842      118.0  [ type]
4307      117.5  [ gift]
241       117.0  [ a]
822       117.0  [ good]
914       117.0  [ great]
986       117.0  [ high]
1733      117.0  [ special]
829       116.5  [ very]

vs
I'd suspect there are some numerical differences between the kernel and the non-kernel path which accumulate and ultimately end up not predicting exactly the same tokens.
@younesbelkada A valid point, but perplexity over WikiText seems significantly (~16%) worse for the train=True case. I evaluated using the following script:
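(The evaluation script itself is not reproduced in this capture. Below is a minimal sketch of a sliding-window WikiText perplexity evaluation along these lines; the dataset config, window size, dtype, and model id are assumptions, not the reporter's actual settings.)

```python
# Minimal sketch (assumptions: WikiText-2 test split, 1024-token windows,
# bf16, and falcon-mamba-7b-instruct as in the reported numbers).
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# Toggle this to compare the two code paths from the issue.
train_mode = True
model.train(train_mode)

text = "\n\n".join(load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"])
encodings = tokenizer(text, return_tensors="pt")

max_length = 1024
stride = 1024
nlls = []
with torch.no_grad():
    for begin in range(0, encodings.input_ids.size(1) - 1, stride):
        end = min(begin + max_length, encodings.input_ids.size(1))
        input_ids = encodings.input_ids[:, begin:end].to(model.device)
        # use_cache=False matches the training-time configuration.
        outputs = model(input_ids, labels=input_ids.clone(), use_cache=False)
        nlls.append(outputs.loss * (end - begin))

ppl = torch.exp(torch.stack(nlls).sum() / encodings.input_ids.size(1))
print(f"train={train_mode}  PPL={ppl.item():.4f}")
```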
With falcon-mamba-7b-instruct I'm getting PPL=7.3546 for train=False and PPL=8.5525 for train=True. With mamba-130m-hf there is only a slight difference: PPL=27.1124 for train=False and PPL=27.1174 for train=True.
System Info
transformers version: 4.45.0.dev0

Who can help?
@ArthurZucker
Information
Tasks
examples folder (such as GLUE/SQuAD, ...)

Reproduction
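(The reproduction script is not captured here. A minimal sketch of a comparison along the lines described under "Expected behavior" could look like the following; the prompt, dtype, and top-k size are assumptions.)

```python
# Minimal sketch: compare last-token logits across train/eval and cache settings.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

prompt = "This is a"  # hypothetical prompt, not the one from the report
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

for train_mode in (False, True):
    for use_cache in (True, False):
        model.train(train_mode)
        with torch.no_grad():
            logits = model(input_ids, use_cache=use_cache).logits[0, -1]
        top = torch.topk(logits, 5)
        tokens = [tokenizer.decode(t) for t in top.indices]
        print(f"train={train_mode} use_cache={use_cache} top-5: {tokens}")
```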
Observed (unexpected) output:
Expected behavior
I would expect the logits produced by a forward pass to be the same regardless of whether model.train(True) has been invoked or use_cache is False or True. This holds true for mamba-130m-hf. However, when running the provided test script with falcon-mamba-7b-instruct, the output for the case model.train(True) and use_cache=False differs from the other outputs.
Practical relevance: this was discovered during ORPO training, where model.training is True and use_cache=False. The outputs and loss calculation during training do not match the outputs observed after training in evaluation mode.