
Fix: Fix FalconMamba training issues due to incompatible kernels #33195

Merged: 8 commits merged into huggingface:main from fix-fm-training on Sep 5, 2024

Conversation

younesbelkada (Contributor) commented Aug 29, 2024

What does this PR do?

Fixes the training instability issues users are facing with the FalconMamba architecture. They occur because we have not upstreamed the modified kernels into the mamba-ssm library.

To reproduce, you can try out this snippet:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FalconMambaForCausalLM

model_id = "tiiuae/falcon-mamba-7b"
text = "Hello today"

model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
tok = AutoTokenizer.from_pretrained(model_id)

inputs = tok(text, return_tensors="pt").to(0)

with torch.no_grad():
    logits = torch.argmax(model(**inputs).logits, dim=-1)

print(tok.batch_decode(logits))

model.train()
lm_logits = model(**inputs).logits
next_token = torch.argmax(lm_logits, dim=-1)

print(tok.batch_decode(next_token))
loss = (1 - lm_logits).mean()
loss.backward()
```

On the main branch this prints complete gibberish, because the kernels we currently call do not account for the extra layer normalization FalconMamba applies to the B, C and dt projections of Mamba.
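For reference, here is a minimal sketch (not the exact transformers implementation) of the parameter-free RMS normalization that FalconMamba applies to B, C and dt and that the stock mamba-ssm kernels skip; the function name rms_forward mirrors the slow-path code quoted later in this thread, and the tensor shapes are purely illustrative:

```python
import torch


def rms_forward(hidden_states, variance_epsilon=1e-6):
    # Parameter-free RMS norm over the last dimension (no learnable weight).
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return (hidden_states * torch.rsqrt(variance + variance_epsilon)).to(input_dtype)


# Illustrative shapes: (batch, seq_len, ssm_state_size) for B/C, (batch, seq_len, intermediate_size) for dt
B = torch.randn(1, 8, 16)
C = torch.randn(1, 8, 16)
time_step = torch.randn(1, 8, 64)

B = rms_forward(B)
C = rms_forward(C)
time_step = rms_forward(time_step)
```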

There is an ongoing PR: state-spaces/mamba#543. I also propose to add the changes here in case that PR does not get merged, since relying on it would require users to install mamba-ssm from a specific version.

The added kernel is a copy-paste of the existing kernel; it depends on triton, causal-conv1d and einops, and is not imported by transformers at all unless users explicitly import it.

Fixes #33234

cc @ArthurZucker @molbap @vasqu

Comment on lines 15 to 28

```python
import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

from einops import rearrange, repeat

try:
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_cuda = None

import selective_scan_cuda
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
```
younesbelkada (Author):

All these imports should be safe, since this module is only used by FalconMamba when the mamba-ssm library is available. einops, causal-conv1d and selective_scan_cuda are installed automatically alongside mamba-ssm.
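As an illustration of that guard, a rough sketch of how the vendored kernel can be pulled in explicitly; the transformers.kernels.falcon_mamba path follows the relative import in the diff quoted further down, so treat the exact path as an assumption:

```python
from transformers.utils.import_utils import is_mamba_ssm_available

if is_mamba_ssm_available():
    # Only resolved when mamba-ssm (and with it einops, causal-conv1d,
    # selective_scan_cuda) is installed, so the heavy imports stay optional.
    from transformers.kernels.falcon_mamba import mamba_inner_fn
else:
    mamba_inner_fn = None
```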

Comment on lines 111 to 120
```python
def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)


def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
```
younesbelkada (Author) commented Aug 29, 2024:

All the methods above are a copy of the existing selective scan kernel: https://github.com/younesbelkada/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py

Comment on lines +255 to +266
```python
if b_rms_weight is not None:
    B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
    B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if c_rms_weight is not None:
    C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
    C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
if dt_rms_weight is not None:
    delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
    delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
    delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
```
younesbelkada (Author):

The main changes to the kernels happen here.

@@ -131,6 +133,15 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int):
```python
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
        self.register_buffer(
```
younesbelkada (Author):

It seems you cannot compute a parameter-free RMS norm in Triton, therefore we create dummy non-learnable weights here and register them as non-persistent buffers.
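A minimal, self-contained sketch of that pattern (the sizes are made up; the real module uses self.ssm_state_size and self.intermediate_size): registering all-ones tensors as non-learnable, non-persistent buffers keeps them out of the checkpoint while still giving the Triton kernel a weight to read.

```python
import torch
import torch.nn as nn


class DummyRMSWeights(nn.Module):
    def __init__(self, ssm_state_size=16, intermediate_size=64):
        super().__init__()
        # Non-persistent buffers: available at runtime, never serialized.
        self.register_buffer("b_c_rms", torch.ones(ssm_state_size), persistent=False)
        self.register_buffer("dt_rms", torch.ones(intermediate_size), persistent=False)


module = DummyRMSWeights()
print("b_c_rms" in module.state_dict())  # False: non-persistent buffers are not saved
print(module.b_c_rms.requires_grad)      # False: plain tensors never require grad
```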

vasqu (Contributor) left a comment:

Overall LGTM, I'd just add some more checks on the imports and add the new rms norm to the slow path as well.

Not too familiar with triton so I'll trust that it works :p Seems like pretty simple forwarding and ignoring the grads on the rms params.

@@ -524,3 +524,31 @@ def test_batched_generation(self):
```python
        out = tok.batch_decode(out, skip_special_tokens=True)

        self.assertListEqual(out, EXPECTED_OUTPUT)

    def test_training_kernel(self):
```
vasqu (Contributor):

Should this be a slow test? Even with only a few tokens, I'd assume it's still quite expensive.

younesbelkada (Author):

Indeed! There is a decorator on the FalconMambaIntegrationTests class that is propagated to all of its tests, including this one, so all should be good :)
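A small sketch of the pattern being described, assuming the class carries a @slow decorator (the exact decorators on FalconMambaIntegrationTests may differ): a decorator applied at class level is inherited by every test method, so the new test needs nothing extra.

```python
import unittest

from transformers.testing_utils import slow


@slow  # applies to every test in the class, including the new one
class FalconMambaIntegrationTests(unittest.TestCase):
    def test_training_kernel(self):
        # Skipped unless slow tests are enabled (RUN_SLOW=1), thanks to the class-level decorator.
        ...
```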

vasqu (Contributor):

Ah, that makes sense! I overlooked that ^^

Comment on lines +232 to +234
```python
B = rms_forward(B, variance_epsilon=self.rms_eps)
C = rms_forward(C, variance_epsilon=self.rms_eps)
time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
```
vasqu (Contributor):

This should be propagated to the slow path as well.

younesbelkada (Author):

Done !

Comment on lines 47 to 53
```diff
 if is_mamba_ssm_available():
-    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
+    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
     from mamba_ssm.ops.triton.selective_state_update import selective_state_update

+    from ...kernels.falcon_mamba import mamba_inner_fn
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
```
vasqu (Contributor):

I think we should check for causal-conv1d here as well. It has been optional at some point, and mamba_inner_fn depends on it being available.

vasqu (Contributor):

Never mind, I forgot that there is an all(...) availability check later in the file, so it's not necessary.
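For readers following along, the gate being referred to looks roughly like the sketch below (names follow the quoted imports; the exact tuple in the modeling file may differ). The fast path is only taken when every optional kernel import actually resolved:

```python
# Optional causal-conv1d kernels; fall back to None when the package is missing.
try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

# Placeholders for the mamba-ssm imports guarded earlier in the file.
selective_state_update = selective_scan_fn = mamba_inner_fn = None

is_fast_path_available = all(
    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)
```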

```python
causal_conv1d_cuda = None

import selective_scan_cuda
from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
```
vasqu (Contributor):

For backward compatibility: in mamba-ssm v2+ the file has been renamed to layer_norm.py, but for earlier versions it was layernorm.py.
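A small sketch of the backward-compatible import this suggests, assuming only the module rename differs between versions:

```python
try:
    from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd  # mamba-ssm v2+
except ImportError:
    from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd  # earlier mamba-ssm releases
```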

younesbelkada (Author):

Thanks a lot! Just updated it.

vasqu (Contributor) left a comment:

LGTM! Thx for iterating 😄

ArthurZucker (Collaborator) left a comment:

LGTM, but TBH we'd rather wait a little for the upstream PR to be merged (let's avoid maintaining the kernels ourselves)! If it's not done by the end of the week we can merge.
Thanks again @vasqu 😉 good input!

ArthurZucker (Collaborator) left a comment:

LGTM, can you make sure the new slow test passes? Good to merge otherwise.

```python
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```

ArthurZucker (Collaborator):

Let's maybe add a comment saying all of this comes from xxx.com, with a link to the file!

younesbelkada (Author):

Done !

Comment on lines +138 to +142
"b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
)
self.register_buffer(
"dt_rms", torch.nn.Parameter(torch.ones(self.intermediate_size), requires_grad=False), persistent=False
)
ArthurZucker (Collaborator):

names are a bit horrendous but it's not really your fault!

younesbelkada (Author):

I can confirm the new slow tests pass on my end (on a single GPU they were failing due to OOM; I changed the test to multi-GPU and that fixed it).

[Screenshot: slow test run passing, 2024-09-05 11:25 AM]

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

ArthurZucker merged commit 47b0964 into huggingface:main on Sep 5, 2024. 16 checks passed.
ArthurZucker (Collaborator):

Thanks for the PR 🤗 Glad to see you around in the OS world!

younesbelkada deleted the fix-fm-training branch on September 5, 2024 at 10:08.
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request on Dec 5, 2024 (…uggingface#33195):

* fix FM training kernels

* fix copies

* fix copies

* propagate to slow path

* make it BC

* add comment

* fix test
Successfully merging this pull request may close these issues.

[Falcon Mamba] Unexpected model output with use_cache=False and model.train()