Fix: Fix FalconMamba training issues due to incompatible kernels #33195
Conversation
import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

from einops import rearrange, repeat

try:
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_cuda = None

import selective_scan_cuda
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
All these imports should be safe, as this module is only used by FalconMamba and only if the mamba-ssm library is available. einops, causal-conv1d and selective-scan-cuda are automatically installed when mamba-ssm is installed.
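As a side note, a minimal availability check in the spirit of this comment could look like the sketch below. The helper name is made up (transformers ships its own is_mamba_ssm_available / is_causal_conv1d_available utilities); this only illustrates the idea.

```python
# Sketch only: check that the optional kernel dependencies are importable before
# touching this module. The helper name is hypothetical.
import importlib.util

def _falcon_mamba_kernel_deps_available() -> bool:
    return all(
        importlib.util.find_spec(pkg) is not None
        for pkg in ("mamba_ssm", "causal_conv1d", "einops")
    )
```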
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
All the methods above are a copy of the existing selective scan kernel: https://github.com/younesbelkada/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py
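For readers unfamiliar with the interface, a shape-level usage sketch of selective_scan_fn follows. It assumes mamba-ssm is installed and a CUDA device is available; the tensor sizes are arbitrary and not taken from the PR.

```python
import torch
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

batch, dim, seqlen, dstate = 2, 64, 16, 16
u = torch.randn(batch, dim, seqlen, device="cuda")      # input sequence
delta = torch.rand(batch, dim, seqlen, device="cuda")   # timestep (pre-softplus)
A = -torch.rand(dim, dstate, device="cuda")             # state matrix (negative real part)
B = torch.randn(batch, dstate, seqlen, device="cuda")   # input-dependent B
C = torch.randn(batch, dstate, seqlen, device="cuda")   # input-dependent C
D = torch.randn(dim, device="cuda")                     # skip connection

out = selective_scan_fn(u, delta, A, B, C, D=D, delta_softplus=True)
print(out.shape)  # (batch, dim, seqlen)
```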
if b_rms_weight is not None:
    B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
    B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if c_rms_weight is not None:
    C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
    C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if dt_rms_weight is not None:
    delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
    delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
    delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
The main changes to the kernels happen here.
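Concretely, each block flattens the batch and sequence axes so the RMS norm runs over the state (or channel) dimension, then restores the original layout. A tiny self-contained illustration of that rearrange round-trip (the shapes are made up, not taken from the model config):

```python
import torch
from einops import rearrange

b, dstate, L = 2, 16, 8
B_var = torch.randn(b, 1, dstate, L)

flat = rearrange(B_var, "b 1 dstate l -> (b l) dstate", l=L).contiguous()  # (b*L, dstate)
# ... the RMS norm is applied here, normalizing over the dstate dimension ...
restored = rearrange(flat, "(b l) dstate -> b 1 dstate l", l=L).contiguous()

assert restored.shape == B_var.shape
```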
@@ -131,6 +133,15 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int):
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
        self.register_buffer(
It seems that in Triton you cannot compute a parameter-free RMS norm, so we create dummy non-learnable, non-persistent weights here.
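For illustration, a minimal standalone example of the non-persistent buffer pattern (the module and sizes are made up; it only shows that such buffers move with the module but never end up in the checkpoint):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self, dstate: int = 16):
        super().__init__()
        # Non-learnable and non-persistent: follows .to()/.cuda() but is excluded from state_dict()
        self.register_buffer("b_c_rms", torch.ones(dstate), persistent=False)

m = Demo()
assert "b_c_rms" not in m.state_dict()
```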
Overall LGTM, I'd just add some more checks on the imports and add the new rms norm to the slow path as well.
Not too familiar with triton so I'll trust that it works :p Seems like pretty simple forwarding and ignoring the grads on the rms params.
@@ -524,3 +524,31 @@ def test_batched_generation(self):
        out = tok.batch_decode(out, skip_special_tokens=True)

        self.assertListEqual(out, EXPECTED_OUTPUT)

    def test_training_kernel(self):
Should be a slow test, no? Even with few tokens, I'd assume that it's still quite expensive.
Indeed! There is a decorator on the FalconMambaIntegrationTests class that gets propagated to all its tests, including this one, so all should be good :)
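The pattern being referred to is roughly the following sketch; only @slow and the class name come from the discussion above, the test body is omitted.

```python
import unittest
from transformers.testing_utils import slow

@slow  # skips every test in the class unless RUN_SLOW=1 is set in the environment
class FalconMambaIntegrationTests(unittest.TestCase):
    def test_training_kernel(self):
        ...  # the actual training test added in this PR
```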
Ah, that makes sense! I overlooked that ^^
B = rms_forward(B, variance_epsilon=self.rms_eps)
C = rms_forward(C, variance_epsilon=self.rms_eps)
time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
Should be propagated to the slow path as well.
Done!
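For context, a parameter-free RMS norm on the slow (pure PyTorch) path typically looks like the sketch below. The helper name rms_forward and its variance_epsilon argument come from the diff above; the body is an assumption about what such a helper usually computes.

```python
import torch

def rms_forward(hidden_states: torch.Tensor, variance_epsilon: float = 1e-6) -> torch.Tensor:
    # Parameter-free RMS norm: normalize over the last dimension, no learnable scale.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)
```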
 if is_mamba_ssm_available():
-    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
+    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
     from mamba_ssm.ops.triton.selective_state_update import selective_state_update

+    from ...kernels.falcon_mamba import mamba_inner_fn
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
I think we should check for causal-conv1d here as well. It's been optional at some point, and the mamba_inner_fn depends on it being available.
Nvm, I forgot that there is an all(...) check later in the file, so it's not necessary.
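The check being referred to is the usual fast-path guard in the mamba-style modeling files, roughly like the sketch below. It builds on the names from the import block above; the causal-conv1d entry points are assumptions based on the causal_conv1d package and are not shown in this diff.

```python
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

# If any of the kernel entry points resolved to None (missing optional dependency),
# the model falls back to the slow, pure-PyTorch path instead of the fused kernels.
is_fast_path_available = all(
    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
```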
    causal_conv1d_cuda = None

import selective_scan_cuda
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
For BC: in mamba-ssm v2+ the file has been renamed to layer_norm.py, but for versions below it was layernorm.py.
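One way to handle both module names would be a guarded import along these lines (a sketch of the idea, not necessarily the exact code that ended up in the PR):

```python
try:
    # mamba-ssm >= 2.x
    from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
except ImportError:
    # older mamba-ssm versions
    from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd
```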
Thanks a lot! Just updated it.
LGTM! Thx for iterating 😄
LGTM, but TBH we'd rather wait a little bit for the upstream PR to be merged (let's avoid maintaining the kernels ourselves)! If it's not done by the end of the week, we can merge!
Thanks again @vasqu 😉 good input!
LGTM, can you make sure the new slow test passes? Good to merge otherwise.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Let's maybe add a comment saying all of this comes from xxx.com with the link to the file!
Done!
"b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False | ||
) | ||
self.register_buffer( | ||
"dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False | ||
) |
names are a bit horrendous but it's not really your fault!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the PR 🤗 Glad to see you around in the OS world!
…uggingface#33195)
* fix FM training kernels
* fix copies
* fix copies
* propagate to slow path
* make it BC
* add comment
* fix test
What does this PR do?
Fixes training instability issues that users are facing with the FalconMamba architecture. This is due to the fact that we did not upstream the modified kernels into the mamba-ssm library.
To reproduce, you can try out this snippet:
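(The snippet itself is not included in this extraction; a minimal reproduction sketch follows, with the checkpoint name and training setup being assumptions rather than the original code.)

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b"  # assumed checkpoint, not necessarily the one from the original snippet
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")
model.train()

inputs = tokenizer("Hello today I am going to", return_tensors="pt").to("cuda")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()  # on main, gradients flow through kernels that ignore the B/C/dt RMS norms
```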
That will output complete gibberish on the main branch. This is because we use kernels that do not take into account the layer normalization of the B, C and dt layers of FalconMamba.
There is an ongoing PR: state-spaces/mamba#543. I also propose to upstream the changes here in case that PR does not get merged, as it would require users to install mamba-ssm from a certain version.
The added kernel is a copy-paste of the existing kernel that depends on triton, causal-conv1d and einops, and it is not imported by transformers at all unless users explicitly import it.
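For illustration, the kernel code only gets loaded through an explicit import such as the one below (assuming mamba-ssm and its dependencies are installed; the path follows the ...kernels.falcon_mamba import shown in the diff):

```python
# A plain `import transformers` does not load the kernel module; it is only pulled in
# explicitly, or by the FalconMamba modeling code when mamba-ssm is available.
from transformers.kernels.falcon_mamba import mamba_inner_fn
```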
Fixes #33234
cc @ArthurZucker @molbap @vasqu