Added mamba.py backend #30139

Merged Jul 23, 2024 · 19 commits · Changes from 17 commits
4 changes: 4 additions & 0 deletions src/transformers/models/mamba/configuration_mamba.py
@@ -79,6 +79,8 @@ class MambaConfig(PretrainedConfig):
Whether or not to rescale `out_proj` weights when initializing.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the cache should be used.
use_mambapy (`bool`, *optional*, defaults to `True`):
Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.


Example:
@@ -123,6 +125,7 @@ def __init__(
time_step_floor=1e-4,
rescale_prenorm_residual=False,
use_cache=True,
use_mambapy=True,
**kwargs,
):
self.vocab_size = vocab_size
@@ -149,5 +152,6 @@ def __init__(
self.rescale_prenorm_residual = rescale_prenorm_residual
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
self.use_mambapy = use_mambapy

super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
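
A minimal usage sketch of the new flag (not part of the diff), assuming the standard `MambaConfig`/`MambaForCausalLM` classes; the hyperparameter values are illustrative only:

```python
from transformers import MambaConfig, MambaForCausalLM

# Without the CUDA fast path, training falls back to mamba.py when
# use_mambapy=True (the default). Setting it to False forces the naive,
# slower sequential scan, which can help when memory is limited.
config = MambaConfig(hidden_size=768, num_hidden_layers=4, use_mambapy=False)
model = MambaForCausalLM(config)
```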
58 changes: 42 additions & 16 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -33,12 +33,17 @@
add_start_docstrings_to_model_forward,
logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
from .configuration_mamba import MambaConfig


logger = logging.get_logger(__name__)

if is_mambapy_available():
from mambapy.pscan import pscan
else:
pscan = None

if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
@@ -87,6 +92,8 @@ def __init__(self, config: MambaConfig, layer_idx: int):
self.activation = config.hidden_act
self.act = ACT2FN[config.hidden_act]

self.use_mambapy = config.use_mambapy

# projection of the input hidden states
self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
# selective projection used to make dt, B and C input dependant
@@ -105,11 +112,23 @@ def __init__(self, config: MambaConfig, layer_idx: int):
self.use_bias = config.use_bias

if not is_fast_path_available:
logger.warning_once(
"The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
if self.use_mambapy:
if is_mambapy_available():
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d"
)
else:
raise ImportError(
"use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
)
else:
logger.warning_once(
"The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
" is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
" https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
)

def cuda_kernels_forward(
self,
@@ -257,17 +276,24 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca
deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

# 3.c perform the recurrence y ← SSM(A, B, C)(x)
scan_outputs = []
for i in range(seq_len):
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))
if self.use_mambapy and self.training and cache_params is None:
hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, seq_len, intermediate_size, ssm_state_size]

if cache_params is not None:
cache_params.update_ssm_state(self.layer_idx, ssm_state)
scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len]
scan_output = scan_output + hidden_states * self.D[None, :, None]
scan_output = scan_output * self.act(gate)
else:
scan_outputs = []
for i in range(seq_len):
ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state]
scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1]
scan_outputs.append(scan_output[:, :, 0])
scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len]
scan_output = scan_output + (hidden_states * self.D[None, :, None])
scan_output = (scan_output * self.act(gate))

if cache_params is not None:
cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size]
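
For reference, a small standalone sketch (plain PyTorch, shapes as annotated in the diff; not part of the diff) of the sequential recurrence that `mambapy.pscan` computes in parallel. The mamba.py branch transposes `discrete_A` and `deltaB_u` to `[batch, seq_len, intermediate_size, ssm_state_size]` before calling `pscan` and gets back the hidden state for every timestep, which is then contracted with `C`:

```python
import torch


def sequential_scan(discrete_A: torch.Tensor, deltaB_u: torch.Tensor) -> torch.Tensor:
    # Reference recurrence: h_t = A_t * h_{t-1} + (delta * B * x)_t
    # discrete_A, deltaB_u: [batch, intermediate_size, seq_len, ssm_state_size]
    batch, dim, seq_len, state = discrete_A.shape
    h = torch.zeros(batch, dim, state, dtype=discrete_A.dtype, device=discrete_A.device)
    hs = []
    for i in range(seq_len):
        h = discrete_A[:, :, i, :] * h + deltaB_u[:, :, i, :]
        hs.append(h)
    # [batch, intermediate_size, seq_len, ssm_state_size]
    return torch.stack(hs, dim=2)
```

`pscan` produces the same per-timestep states with a parallel scan over `seq_len`, which is why the mamba.py branch above is only taken during training with no cache: it trades extra memory for parallelism across the sequence.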
6 changes: 6 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -394,6 +394,12 @@ def is_causal_conv1d_available():
return False


def is_mambapy_available():
if is_torch_available():
return _is_package_available("mambapy")
return False


def is_torch_mps_available():
if is_torch_available():
import torch
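
A small illustrative sketch (not part of the diff) of how the new availability check can be used to pick the training fallback up front, mirroring the check the layer performs when the CUDA fast path is missing:

```python
from transformers import MambaConfig
from transformers.utils.import_utils import is_mambapy_available

# Prefer the mamba.py backend only when the mambapy package is installed;
# otherwise fall back to the naive sequential implementation.
config = MambaConfig(use_mambapy=is_mambapy_available())
```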