From 157e923c697f99508305b60ed416271507121f09 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Mon, 8 Apr 2024 17:28:48 +0200 Subject: [PATCH 01/19] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 2a2830fb1001..287503a58138 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,7 @@ limitations under the License. +[fork for mamba.py backend] 🤗 Transformers provides thousands of pretrained models to perform tasks on different modalities such as text, vision, and audio. These models can be applied on: From 336e796660aa715d6bb999bb60448935650a6bf2 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Mon, 8 Apr 2024 19:47:21 +0200 Subject: [PATCH 02/19] tests: forward ok --- mamba_tests.ipynb | 120 ++++++++++ .../models/mamba/configuration_mamba.py | 2 + .../models/mamba/modeling_mamba.py | 29 ++- src/transformers/models/mamba/pscan.py | 224 ++++++++++++++++++ 4 files changed, 367 insertions(+), 8 deletions(-) create mode 100644 mamba_tests.ipynb create mode 100644 src/transformers/models/mamba/pscan.py diff --git a/mamba_tests.ipynb b/mamba_tests.ipynb new file mode 100644 index 000000000000..08aee6a5d702 --- /dev/null +++ b/mamba_tests.ipynb @@ -0,0 +1,120 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/alexandretorres/miniconda3/envs/torch23/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], + "source": [ + "import torch\n", + "from src.transformers.models.mamba.configuration_mamba import MambaConfig\n", + "from src.transformers.models.mamba.modeling_mamba import MambaMixer" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d\n" + ] + } + ], + "source": [ + "torch.manual_seed(23456)\n", + "config1 = MambaConfig(use_mambapy=False, vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", + "mixer1 = MambaMixer(config1, 0)\n", + "\n", + "torch.manual_seed(23456)\n", + "config2 = MambaConfig(use_mambapy=True, vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", + "mixer2 = MambaMixer(config2, 0)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(2, 12, 64)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.allclose(mixer1(x), mixer2(x))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO : CHECK BACKWARD!!!!!!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torch23", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 460c1f3b32ac..93a401dc9399 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -100,6 +100,7 @@ class MambaConfig(PretrainedConfig): def __init__( self, + use_mambapy, vocab_size=50280, hidden_size=768, state_size=16, @@ -125,6 +126,7 @@ def __init__( use_cache=True, **kwargs, ): + self.usemambapy = use_mambapy self.vocab_size = vocab_size self.hidden_size = hidden_size self.state_size = state_size diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 5edb28ad7416..1bf0fa6f684d 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -36,6 +36,7 @@ from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_mamba import MambaConfig +from .pscan import pscan logger = logging.get_logger(__name__) @@ -68,6 +69,7 @@ class MambaMixer(nn.Module): def __init__(self, config: MambaConfig, layer_idx: int): super().__init__() + self.usemambapy = config.usemambapy self.hidden_size = config.hidden_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel @@ -257,14 +259,25 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) - scan_outputs = [] - for i in range(seq_len): - ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediate_size, ssm_state] - scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediate_size, 1] - scan_outputs.append(scan_output[:, :, 0]) - scan_output = torch.stack(scan_outputs, dim=-1) # [batch, intermediate_size, seq_len] - scan_output = scan_output + (hidden_states * self.D[None, :, None]) - scan_output = (scan_output * self.act(gate)) + if self.usemambapy: + hs = pscan(discrete_A, deltaB_u) # [batch, intermediate_size, seq_len, ssm_state_size] + + scan_output = (hs.transpose(1, 2) @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] + scan_output = scan_output + hidden_states * self.D[None, :, None] + scan_output = scan_output * self.act(gate) + + # pas sur, todo + ssm_state = hs[:, -1] + + else: + scan_outputs = [] + for i in range(seq_len): + ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state] + scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1] + scan_outputs.append(scan_output[:, :, 0]) + scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size] + scan_output = 
scan_output + (hidden_states * self.D[None, :, None]) + scan_output = (scan_output * self.act(gate)) if cache_params is not None: cache_params.update_ssm_state(self.layer_idx, ssm_state) diff --git a/src/transformers/models/mamba/pscan.py b/src/transformers/models/mamba/pscan.py new file mode 100644 index 000000000000..e3c817360f54 --- /dev/null +++ b/src/transformers/models/mamba/pscan.py @@ -0,0 +1,224 @@ +import math + +import torch +import torch.nn.functional as F + +""" + +This file comes from https://github.com/alxndrTL/mamba.py. +It is an implementation of the parallel scan operation in PyTorch (Blelloch version). +Please see alxndrTL/mamba.py/docs/pscan.ipynb for a detailed explanation of what happens here. + +It has been slightly modified : here the input of pscan is supposed to be (B, D, L, N) to avoid +a transpose given the shapes in modeling_mamba.py (as opposed to (B, L, D, N) in mamba.py). + +""" + +def npo2(len): + """ + Returns the next power of 2 above len + """ + + return 2 ** math.ceil(math.log2(len)) + +def pad_npo2(X): + """ + Pads input length dim to the next power of 2 + + Args: + X : (B, D, L, N) + + Returns: + Y : (B, D, npo2(L), N) + """ + + len_npo2 = npo2(X.size(2)) + pad_tuple = (0, 0, 0, len_npo2 - X.size(2), 0, 0) + return F.pad(X, pad_tuple, "constant", 0) + +class PScan(torch.autograd.Function): + @staticmethod + def pscan(A, X): + # A : (B, D, L, N) + # X : (B, D, L, N) + + # modifies X in place by doing a parallel scan. + # more formally, X will be populated by these values : + # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 + # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps) + + # only supports L that is a power of two (mainly for a clearer code) + + B, D, L, _ = A.size() + num_steps = int(math.log2(L)) + + # up sweep (last 2 steps unfolded) + Aa = A + Xa = X + for _ in range(num_steps-2): + T = Xa.size(2) + Aa = Aa.view(B, D, T//2, 2, -1) + Xa = Xa.view(B, D, T//2, 2, -1) + + Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0])) + Aa[:, :, :, 1].mul_(Aa[:, :, :, 0]) + + Aa = Aa[:, :, :, 1] + Xa = Xa[:, :, :, 1] + + # we have only 4, 2 or 1 nodes left + if Xa.size(2) == 4: + Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) + Aa[:, :, 1].mul_(Aa[:, :, 0]) + + Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1]))) + elif Xa.size(2) == 2: + Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) + return + else: + return + + # down sweep (first 2 steps unfolded) + Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)] + Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)] + Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1])) + Aa[:, :, 2].mul_(Aa[:, :, 1]) + + for k in range(num_steps-3, -1, -1): + Aa = A[:, :, 2**k-1:L:2**k] + Xa = X[:, :, 2**k-1:L:2**k] + + T = Xa.size(2) + Aa = Aa.view(B, D, T//2, 2, -1) + Xa = Xa.view(B, D, T//2, 2, -1) + + Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1])) + Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1]) + + @staticmethod + def pscan_rev(A, X): + # A : (B, D, L, N) + # X : (B, D, L, N) + + # the same function as above, but in reverse + # (if you flip the input, call pscan, then flip the output, you get what this function outputs) + # it is used in the backward pass + + # only supports L that is a power of two (mainly for a clearer code) + + B, D, L, _ = A.size() + num_steps = int(math.log2(L)) + + # up sweep (last 2 steps unfolded) + Aa = A + Xa = X + for _ in range(num_steps-2): + T = Xa.size(2) + Aa = Aa.view(B, D, T//2, 2, -1) + Xa = Xa.view(B, D, T//2, 2, -1) + 
+ Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1])) + Aa[:, :, :, 0].mul_(Aa[:, :, :, 1]) + + Aa = Aa[:, :, :, 0] + Xa = Xa[:, :, :, 0] + + # we have only 4, 2 or 1 nodes left + if Xa.size(2) == 4: + Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3])) + Aa[:, :, 2].mul_(Aa[:, :, 3]) + + Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2])))) + elif Xa.size(2) == 2: + Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1])) + return + else: + return + + # down sweep (first 2 steps unfolded) + Aa = A[:, :, 0:L:2**(num_steps-2)] + Xa = X[:, :, 0:L:2**(num_steps-2)] + Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2])) + Aa[:, :, 1].mul_(Aa[:, :, 2]) + + for k in range(num_steps-3, -1, -1): + Aa = A[:, :, 0:L:2**k] + Xa = X[:, :, 0:L:2**k] + + T = Xa.size(2) + Aa = Aa.view(B, D, T//2, 2, -1) + Xa = Xa.view(B, D, T//2, 2, -1) + + Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0])) + Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0]) + + @staticmethod + def forward(ctx, A_in, X_in): + """ + Applies the parallel scan operation, as defined above. Returns a new tensor. + If you can, privilege sequence lengths that are powers of two. + + Args: + A_in : (B, D, L, N) + X_in : (B, D, L, N) + + Returns: + H : (B, D, L, N) + """ + + L = X_in.size(2) + + # cloning is requiered because of the in-place ops + if L == npo2(L): + A = A_in.clone() + X = X_in.clone() + else: + # pad tensors (and clone btw) + A = pad_npo2(A_in) # (B, D, npo2(L), N) + X = pad_npo2(X_in) # (B, D, npo2(L), N) + + # parallel scan (modifies X in-place) + PScan.pscan(A, X) + + ctx.save_for_backward(A_in, X) + + # slice [:, :, :L] (cut if there was padding) + return X[:, :, :L] + + @staticmethod + def backward(ctx, grad_output_in): + """ + Flows the gradient from the output to the input. Returns two new tensors. + + Args: + ctx : A_in : (B, D, L, N), X : (B, D, L, N) + grad_output_in : (B, D, L, N) + + Returns: + gradA : (B, D, L, N), gradX : (B, D, L, N) + """ + + A_in, X = ctx.saved_tensors + + L = grad_output_in.size(2) + + # cloning is requiered because of the in-place ops + if L == npo2(L): + grad_output = grad_output_in.clone() + # the next padding will clone A_in + else: + grad_output = pad_npo2(grad_output_in) # (B, D, npo2(L), N) + A_in = pad_npo2(A_in) # (B, D, npo2(L), N) + + # prepare tensors + A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation) + + # reverse parallel scan (modifies grad_output in-place) + PScan.pscan_rev(A, grad_output) + + Q = torch.zeros_like(X) + Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:]) + + return Q[:, :, :L], grad_output[:, :, :L] + +pscan = PScan.apply From 5f8b1150f4237e7807218361c3b176b7c109a4ad Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Mon, 8 Apr 2024 21:18:03 +0200 Subject: [PATCH 03/19] backward test done --- mamba_tests.ipynb | 65 +++++++++++++++++++++++++---------------------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/mamba_tests.ipynb b/mamba_tests.ipynb index 08aee6a5d702..d09ad239a364 100644 --- a/mamba_tests.ipynb +++ b/mamba_tests.ipynb @@ -4,16 +4,7 @@ "cell_type": "code", "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/alexandretorres/miniconda3/envs/torch23/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "from src.transformers.models.mamba.configuration_mamba import MambaConfig\n", @@ -45,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -54,40 +45,52 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 19, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] } ], "source": [ - "torch.allclose(mixer1(x), mixer2(x))" + "y1 = mixer1(x)\n", + "J_1 = y1.sum()\n", + "J_1.backward()\n", + "\n", + "y2 = mixer2(x)\n", + "J_2 = y2.sum()\n", + "J_2.backward()\n", + "\n", + "print(torch.allclose(y1, y2))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], "source": [ - "# TODO : CHECK BACKWARD!!!!!!" + "gradients_same = True\n", + "for param1, param2 in zip(mixer1.parameters(), mixer2.parameters()):\n", + " if not torch.allclose(param1.grad, param2.grad, rtol=0.01):\n", + " gradients_same = False\n", + " break\n", + "\n", + "print(gradients_same)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, From 09335d7f291aaf4c293c95380c788f5709206d42 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Tue, 9 Apr 2024 10:04:50 +0200 Subject: [PATCH 04/19] done testing --- mamba_tests.ipynb | 24 ++-- mamba_tests_2.ipynb | 122 ++++++++++++++++++ .../models/mamba/configuration_mamba.py | 2 - .../models/mamba/modeling_mamba.py | 13 +- src/transformers/models/mamba/pscan.py | 2 +- 5 files changed, 137 insertions(+), 26 deletions(-) create mode 100644 mamba_tests_2.ipynb diff --git a/mamba_tests.ipynb b/mamba_tests.ipynb index d09ad239a364..32d035863a9b 100644 --- a/mamba_tests.ipynb +++ b/mamba_tests.ipynb @@ -8,7 +8,9 @@ "source": [ "import torch\n", "from src.transformers.models.mamba.configuration_mamba import MambaConfig\n", - "from src.transformers.models.mamba.modeling_mamba import MambaMixer" + "from src.transformers.models.mamba.modeling_mamba import MambaMixer\n", + "\n", + "from transformers.models.mamba.modeling_mamba import MambaMixer as MambaMixer_orig" ] }, { @@ -20,23 +22,24 @@ "name": "stderr", "output_type": "stream", "text": [ + "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d\n", "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. 
To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d\n" ] } ], "source": [ "torch.manual_seed(23456)\n", - "config1 = MambaConfig(use_mambapy=False, vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", - "mixer1 = MambaMixer(config1, 0)\n", + "config1 = MambaConfig(vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", + "mixer1 = MambaMixer_orig(config1, 0)\n", "\n", "torch.manual_seed(23456)\n", - "config2 = MambaConfig(use_mambapy=True, vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", + "config2 = MambaConfig(vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", "mixer2 = MambaMixer(config2, 0)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -45,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -70,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -90,13 +93,6 @@ "\n", "print(gradients_same)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/mamba_tests_2.ipynb b/mamba_tests_2.ipynb new file mode 100644 index 000000000000..f65db3993ae8 --- /dev/null +++ b/mamba_tests_2.ipynb @@ -0,0 +1,122 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from transformers.models.mamba import MambaForCausalLM as MambaForCausalLM_orig, MambaConfig as MambaConfig_orig\n", + "\n", + "from src.transformers.models.mamba import MambaForCausalLM, MambaConfig" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d\n", + "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the mamba.py implementation for training. 
To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d\n" + ] + } + ], + "source": [ + "torch.manual_seed(34567)\n", + "config = MambaConfig_orig(vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", + "model_orig = MambaForCausalLM_orig(config)\n", + "\n", + "torch.manual_seed(34567)\n", + "config = MambaConfig(vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", + "model = MambaForCausalLM(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randint(low=0, high=config.vocab_size-1, size=(16, 12))\n", + "\n", + "y_orig = model_orig(x)\n", + "J_orig = y_orig.logits.sum()\n", + "J_orig.backward()\n", + "\n", + "y = model(x)\n", + "J = y.logits.sum()\n", + "J.backward()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.allclose(y_orig.logits, y.logits, rtol=0.001)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], + "source": [ + "gradients_same = True\n", + "for param1, param2 in zip(model_orig.parameters(), model.parameters()):\n", + " if not torch.allclose(param1.grad, param2.grad, rtol=0.01):\n", + " gradients_same = False\n", + " break\n", + "\n", + "print(gradients_same)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "torch23", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 93a401dc9399..460c1f3b32ac 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -100,7 +100,6 @@ class MambaConfig(PretrainedConfig): def __init__( self, - use_mambapy, vocab_size=50280, hidden_size=768, state_size=16, @@ -126,7 +125,6 @@ def __init__( use_cache=True, **kwargs, ): - self.usemambapy = use_mambapy self.vocab_size = vocab_size self.hidden_size = hidden_size self.state_size = state_size diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 1bf0fa6f684d..64eeae5bc4a5 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -69,7 +69,6 @@ class MambaMixer(nn.Module): def __init__(self, config: MambaConfig, layer_idx: int): super().__init__() - self.usemambapy = config.usemambapy self.hidden_size = config.hidden_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel @@ -109,7 +108,7 @@ def __init__(self, config: MambaConfig, layer_idx: int): if not is_fast_path_available: logger.warning_once( "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" - " is None. Falling back to the naive implementation. 
To install follow https://github.com/state-spaces/mamba/#installation and" + " is None. Falling back to the mamba.py implementation for training. To install follow https://github.com/state-spaces/mamba/#installation and" " https://github.com/Dao-AILab/causal-conv1d" ) @@ -259,16 +258,12 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) - if self.usemambapy: + if self.training and cache_params is None: hs = pscan(discrete_A, deltaB_u) # [batch, intermediate_size, seq_len, ssm_state_size] scan_output = (hs.transpose(1, 2) @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] scan_output = scan_output + hidden_states * self.D[None, :, None] scan_output = scan_output * self.act(gate) - - # pas sur, todo - ssm_state = hs[:, -1] - else: scan_outputs = [] for i in range(seq_len): @@ -279,8 +274,8 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca scan_output = scan_output + (hidden_states * self.D[None, :, None]) scan_output = (scan_output * self.act(gate)) - if cache_params is not None: - cache_params.update_ssm_state(self.layer_idx, ssm_state) + if cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] diff --git a/src/transformers/models/mamba/pscan.py b/src/transformers/models/mamba/pscan.py index e3c817360f54..56019276a127 100644 --- a/src/transformers/models/mamba/pscan.py +++ b/src/transformers/models/mamba/pscan.py @@ -7,7 +7,7 @@ This file comes from https://github.com/alxndrTL/mamba.py. It is an implementation of the parallel scan operation in PyTorch (Blelloch version). -Please see alxndrTL/mamba.py/docs/pscan.ipynb for a detailed explanation of what happens here. +Please see https://github.com/alxndrTL/mamba.py/blob/main/docs/pscan.ipynb for a detailed explanation of what happens here. It has been slightly modified : here the input of pscan is supposed to be (B, D, L, N) to avoid a transpose given the shapes in modeling_mamba.py (as opposed to (B, L, D, N) in mamba.py). From e3bbe67c8c054aa7751e581d0f71d75b335b6ed7 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Tue, 9 Apr 2024 10:09:40 +0200 Subject: [PATCH 05/19] removed check. 
scripts --- mamba_tests.ipynb | 119 ------------------------------------------ mamba_tests_2.ipynb | 122 -------------------------------------------- 2 files changed, 241 deletions(-) delete mode 100644 mamba_tests.ipynb delete mode 100644 mamba_tests_2.ipynb diff --git a/mamba_tests.ipynb b/mamba_tests.ipynb deleted file mode 100644 index 32d035863a9b..000000000000 --- a/mamba_tests.ipynb +++ /dev/null @@ -1,119 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from src.transformers.models.mamba.configuration_mamba import MambaConfig\n", - "from src.transformers.models.mamba.modeling_mamba import MambaMixer\n", - "\n", - "from transformers.models.mamba.modeling_mamba import MambaMixer as MambaMixer_orig" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d\n", - "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d\n" - ] - } - ], - "source": [ - "torch.manual_seed(23456)\n", - "config1 = MambaConfig(vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", - "mixer1 = MambaMixer_orig(config1, 0)\n", - "\n", - "torch.manual_seed(23456)\n", - "config2 = MambaConfig(vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", - "mixer2 = MambaMixer(config2, 0)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.randn(2, 12, 64)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] - } - ], - "source": [ - "y1 = mixer1(x)\n", - "J_1 = y1.sum()\n", - "J_1.backward()\n", - "\n", - "y2 = mixer2(x)\n", - "J_2 = y2.sum()\n", - "J_2.backward()\n", - "\n", - "print(torch.allclose(y1, y2))" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] - } - ], - "source": [ - "gradients_same = True\n", - "for param1, param2 in zip(mixer1.parameters(), mixer2.parameters()):\n", - " if not torch.allclose(param1.grad, param2.grad, rtol=0.01):\n", - " gradients_same = False\n", - " break\n", - "\n", - "print(gradients_same)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "torch23", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/mamba_tests_2.ipynb b/mamba_tests_2.ipynb deleted file mode 100644 index f65db3993ae8..000000000000 --- a/mamba_tests_2.ipynb +++ /dev/null @@ -1,122 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", 
- "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "\n", - "from transformers.models.mamba import MambaForCausalLM as MambaForCausalLM_orig, MambaConfig as MambaConfig_orig\n", - "\n", - "from src.transformers.models.mamba import MambaForCausalLM, MambaConfig" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d\n", - "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the mamba.py implementation for training. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d\n" - ] - } - ], - "source": [ - "torch.manual_seed(34567)\n", - "config = MambaConfig_orig(vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", - "model_orig = MambaForCausalLM_orig(config)\n", - "\n", - "torch.manual_seed(34567)\n", - "config = MambaConfig(vocab_size=60, hidden_size=64, num_hidden_layers=4)\n", - "model = MambaForCausalLM(config)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.randint(low=0, high=config.vocab_size-1, size=(16, 12))\n", - "\n", - "y_orig = model_orig(x)\n", - "J_orig = y_orig.logits.sum()\n", - "J_orig.backward()\n", - "\n", - "y = model(x)\n", - "J = y.logits.sum()\n", - "J.backward()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.allclose(y_orig.logits, y.logits, rtol=0.001)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n" - ] - } - ], - "source": [ - "gradients_same = True\n", - "for param1, param2 in zip(model_orig.parameters(), model.parameters()):\n", - " if not torch.allclose(param1.grad, param2.grad, rtol=0.01):\n", - " gradients_same = False\n", - " break\n", - "\n", - "print(gradients_same)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "torch23", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From db8f353c152bcaf28167da25186628c79fabbbb9 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Tue, 9 Apr 2024 10:10:10 +0200 Subject: [PATCH 06/19] Update README.md --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 287503a58138..2a2830fb1001 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,6 @@ limitations under the License. -[fork for mamba.py backend] 🤗 Transformers provides thousands of pretrained models to perform tasks on different modalities such as text, vision, and audio. 
These models can be applied on: From 4596849365b00db40d76ffd84be6d4a4a77d9276 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Tue, 9 Apr 2024 10:37:44 +0200 Subject: [PATCH 07/19] added use_mambapy arg --- src/transformers/models/mamba/configuration_mamba.py | 6 +++++- src/transformers/models/mamba/modeling_mamba.py | 6 ++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 460c1f3b32ac..52359bd8d7ec 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -79,8 +79,10 @@ class MambaConfig(PretrainedConfig): Whether or not to rescale `out_proj` weights when initializing. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. + use_mambapy (`bool`, *optional*, defaults to `True`): + Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited. - + Example: ```python @@ -123,6 +125,7 @@ def __init__( time_step_floor=1e-4, rescale_prenorm_residual=False, use_cache=True, + use_mambapy=True, **kwargs, ): self.vocab_size = vocab_size @@ -149,5 +152,6 @@ def __init__( self.rescale_prenorm_residual = rescale_prenorm_residual self.residual_in_fp32 = residual_in_fp32 self.use_cache = use_cache + self.use_mambapy = use_mambapy super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 64eeae5bc4a5..20ffe0e81b79 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -88,6 +88,8 @@ def __init__(self, config: MambaConfig, layer_idx: int): self.activation = config.hidden_act self.act = ACT2FN[config.hidden_act] + self.use_mambapy = config.use_mambapy + # projection of the input hidden states self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias) # selective projection used to make dt, B and C input dependant @@ -108,7 +110,7 @@ def __init__(self, config: MambaConfig, layer_idx: int): if not is_fast_path_available: logger.warning_once( "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" - " is None. Falling back to the mamba.py implementation for training. To install follow https://github.com/state-spaces/mamba/#installation and" + " is None. Falling back to the implementation determined by the argument config `use_mambapy` for training.
To install follow https://github.com/state-spaces/mamba/#installation and" " https://github.com/Dao-AILab/causal-conv1d" ) @@ -258,7 +260,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) - if self.training and cache_params is None: + if self.use_mambapy and self.training and cache_params is None: hs = pscan(discrete_A, deltaB_u) # [batch, intermediate_size, seq_len, ssm_state_size] scan_output = (hs.transpose(1, 2) @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] From b3caa0cf6bab6713ca7806974115ef4300b4168b Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Tue, 9 Apr 2024 10:39:29 +0200 Subject: [PATCH 08/19] fixed typo in warning --- src/transformers/models/mamba/modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 20ffe0e81b79..b4b8ad0ec398 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -109,7 +109,7 @@ def __init__(self, config: MambaConfig, layer_idx: int): if not is_fast_path_available: logger.warning_once( - "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" " is None. Falling back to the implementation determined by the argument config `use_mambapy` for training. To install follow https://github.com/state-spaces/mamba/#installation and" " https://github.com/Dao-AILab/causal-conv1d" ) From 722d57d453e80b41f4cd515906d3d13717941a35 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Thu, 27 Jun 2024 14:27:00 +0200 Subject: [PATCH 09/19] protected imports w/ mambapy package --- .../models/mamba/modeling_mamba.py | 23 ++++++++++++------- src/transformers/utils/import_utils.py | 6 +++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index b4b8ad0ec398..8eb7b74a2f13 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -33,11 +33,9 @@ add_start_docstrings_to_model_forward, logging, ) -from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available from .configuration_mamba import MambaConfig -from .pscan import pscan - logger = logging.get_logger(__name__) if is_mamba_ssm_available(): @@ -55,6 +53,11 @@ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) +if is_mambapy_available(): + from mambapy.pscan import pscan +else: + pscan = None + _CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m-hf" _CONFIG_FOR_DOC = "MambaConfig" @@ -108,11 +111,15 @@ def __init__(self, config: MambaConfig, layer_idx: int): self.use_bias = config.use_bias if not is_fast_path_available: - logger.warning_once( - "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" - " is None. 
Falling back to the implementation determined by the argument config `use_mambapy` for training. To install follow https://github.com/state-spaces/mamba/#installation and" - " https://github.com/Dao-AILab/causal-conv1d" - ) + if self.use_mambapy: + assert is_mambapy_available, "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py." + logger.warning_once("The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d") + else: + logger.warning_once("The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.") def cuda_kernels_forward( self, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index bd14dd8cd753..e7256ce317e4 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -394,6 +394,12 @@ def is_causal_conv1d_available(): return False +def is_mambapy_available(): + if is_torch_available(): + return _is_package_available("mambapy") + return False + + def is_torch_mps_available(): if is_torch_available(): import torch From 51dcc1e271d3ed625b8c6f392673a412e4e1ae04 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Thu, 18 Jul 2024 11:20:59 +0200 Subject: [PATCH 10/19] delete pscan.py + raise rather than assert --- .../models/mamba/modeling_mamba.py | 10 +- src/transformers/models/mamba/pscan.py | 224 ------------------ 2 files changed, 6 insertions(+), 228 deletions(-) delete mode 100644 src/transformers/models/mamba/pscan.py diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 8eb7b74a2f13..ef27606f744f 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -112,10 +112,12 @@ def __init__(self, config: MambaConfig, layer_idx: int): if not is_fast_path_available: if self.use_mambapy: - assert is_mambapy_available, "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py." - logger.warning_once("The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" - " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and" - " https://github.com/Dao-AILab/causal-conv1d") + if is_mambapy_available(): + logger.warning_once("The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d") + else: + raise ImportError("use_mambapy is set to True but the mambapy package is not installed. 
To install it follow https://github.com/alxndrTL/mamba.py.") else: logger.warning_once("The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and" diff --git a/src/transformers/models/mamba/pscan.py b/src/transformers/models/mamba/pscan.py deleted file mode 100644 index 56019276a127..000000000000 --- a/src/transformers/models/mamba/pscan.py +++ /dev/null @@ -1,224 +0,0 @@ -import math - -import torch -import torch.nn.functional as F - -""" - -This file comes from https://github.com/alxndrTL/mamba.py. -It is an implementation of the parallel scan operation in PyTorch (Blelloch version). -Please see https://github.com/alxndrTL/mamba.py/blob/main/docs/pscan.ipynb for a detailed explanation of what happens here. - -It has been slightly modified : here the input of pscan is supposed to be (B, D, L, N) to avoid -a transpose given the shapes in modeling_mamba.py (as opposed to (B, L, D, N) in mamba.py). - -""" - -def npo2(len): - """ - Returns the next power of 2 above len - """ - - return 2 ** math.ceil(math.log2(len)) - -def pad_npo2(X): - """ - Pads input length dim to the next power of 2 - - Args: - X : (B, D, L, N) - - Returns: - Y : (B, D, npo2(L), N) - """ - - len_npo2 = npo2(X.size(2)) - pad_tuple = (0, 0, 0, len_npo2 - X.size(2), 0, 0) - return F.pad(X, pad_tuple, "constant", 0) - -class PScan(torch.autograd.Function): - @staticmethod - def pscan(A, X): - # A : (B, D, L, N) - # X : (B, D, L, N) - - # modifies X in place by doing a parallel scan. - # more formally, X will be populated by these values : - # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0 - # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps) - - # only supports L that is a power of two (mainly for a clearer code) - - B, D, L, _ = A.size() - num_steps = int(math.log2(L)) - - # up sweep (last 2 steps unfolded) - Aa = A - Xa = X - for _ in range(num_steps-2): - T = Xa.size(2) - Aa = Aa.view(B, D, T//2, 2, -1) - Xa = Xa.view(B, D, T//2, 2, -1) - - Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0])) - Aa[:, :, :, 1].mul_(Aa[:, :, :, 0]) - - Aa = Aa[:, :, :, 1] - Xa = Xa[:, :, :, 1] - - # we have only 4, 2 or 1 nodes left - if Xa.size(2) == 4: - Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) - Aa[:, :, 1].mul_(Aa[:, :, 0]) - - Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1]))) - elif Xa.size(2) == 2: - Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0])) - return - else: - return - - # down sweep (first 2 steps unfolded) - Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)] - Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)] - Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1])) - Aa[:, :, 2].mul_(Aa[:, :, 1]) - - for k in range(num_steps-3, -1, -1): - Aa = A[:, :, 2**k-1:L:2**k] - Xa = X[:, :, 2**k-1:L:2**k] - - T = Xa.size(2) - Aa = Aa.view(B, D, T//2, 2, -1) - Xa = Xa.view(B, D, T//2, 2, -1) - - Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1])) - Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1]) - - @staticmethod - def pscan_rev(A, X): - # A : (B, D, L, N) - # X : (B, D, L, N) - - # the same function as above, but in reverse - # (if you flip the input, call pscan, then flip the output, you get what this function outputs) - # it is used in the backward pass - - # only supports L that is a power of two 
(mainly for a clearer code) - - B, D, L, _ = A.size() - num_steps = int(math.log2(L)) - - # up sweep (last 2 steps unfolded) - Aa = A - Xa = X - for _ in range(num_steps-2): - T = Xa.size(2) - Aa = Aa.view(B, D, T//2, 2, -1) - Xa = Xa.view(B, D, T//2, 2, -1) - - Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1])) - Aa[:, :, :, 0].mul_(Aa[:, :, :, 1]) - - Aa = Aa[:, :, :, 0] - Xa = Xa[:, :, :, 0] - - # we have only 4, 2 or 1 nodes left - if Xa.size(2) == 4: - Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3])) - Aa[:, :, 2].mul_(Aa[:, :, 3]) - - Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2])))) - elif Xa.size(2) == 2: - Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1])) - return - else: - return - - # down sweep (first 2 steps unfolded) - Aa = A[:, :, 0:L:2**(num_steps-2)] - Xa = X[:, :, 0:L:2**(num_steps-2)] - Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2])) - Aa[:, :, 1].mul_(Aa[:, :, 2]) - - for k in range(num_steps-3, -1, -1): - Aa = A[:, :, 0:L:2**k] - Xa = X[:, :, 0:L:2**k] - - T = Xa.size(2) - Aa = Aa.view(B, D, T//2, 2, -1) - Xa = Xa.view(B, D, T//2, 2, -1) - - Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0])) - Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0]) - - @staticmethod - def forward(ctx, A_in, X_in): - """ - Applies the parallel scan operation, as defined above. Returns a new tensor. - If you can, privilege sequence lengths that are powers of two. - - Args: - A_in : (B, D, L, N) - X_in : (B, D, L, N) - - Returns: - H : (B, D, L, N) - """ - - L = X_in.size(2) - - # cloning is requiered because of the in-place ops - if L == npo2(L): - A = A_in.clone() - X = X_in.clone() - else: - # pad tensors (and clone btw) - A = pad_npo2(A_in) # (B, D, npo2(L), N) - X = pad_npo2(X_in) # (B, D, npo2(L), N) - - # parallel scan (modifies X in-place) - PScan.pscan(A, X) - - ctx.save_for_backward(A_in, X) - - # slice [:, :, :L] (cut if there was padding) - return X[:, :, :L] - - @staticmethod - def backward(ctx, grad_output_in): - """ - Flows the gradient from the output to the input. Returns two new tensors. 
- - Args: - ctx : A_in : (B, D, L, N), X : (B, D, L, N) - grad_output_in : (B, D, L, N) - - Returns: - gradA : (B, D, L, N), gradX : (B, D, L, N) - """ - - A_in, X = ctx.saved_tensors - - L = grad_output_in.size(2) - - # cloning is requiered because of the in-place ops - if L == npo2(L): - grad_output = grad_output_in.clone() - # the next padding will clone A_in - else: - grad_output = pad_npo2(grad_output_in) # (B, D, npo2(L), N) - A_in = pad_npo2(A_in) # (B, D, npo2(L), N) - - # prepare tensors - A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation) - - # reverse parallel scan (modifies grad_output in-place) - PScan.pscan_rev(A, grad_output) - - Q = torch.zeros_like(X) - Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:]) - - return Q[:, :, :L], grad_output[:, :, :L] - -pscan = PScan.apply From 4529d8ef98d238406038a7f0faf27c99c7a3c98c Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Thu, 18 Jul 2024 13:56:17 +0200 Subject: [PATCH 11/19] Update import_utils.py --- src/transformers/utils/import_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index e7256ce317e4..b00f3c6f54a0 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -396,6 +396,8 @@ def is_causal_conv1d_available(): def is_mambapy_available(): if is_torch_available(): + import torch + return _is_package_available("mambapy") return False From 8a81408d70c2cedcc1b8a2025e69c79af1a954ab Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Fri, 19 Jul 2024 18:20:31 +0200 Subject: [PATCH 12/19] fix whitespaces and unused import --- src/transformers/models/mamba/configuration_mamba.py | 2 +- src/transformers/utils/import_utils.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 52359bd8d7ec..d6735134c2df 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -82,7 +82,7 @@ class MambaConfig(PretrainedConfig): use_mambapy (`bool`, *optional*, defaults to `True`): Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
- + Example: ```python diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index b00f3c6f54a0..c38655cba572 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -395,9 +395,7 @@ def is_causal_conv1d_available(): def is_mambapy_available(): - if is_torch_available(): - import torch - + if is_torch_available(): return _is_package_available("mambapy") return False From d0c809e1b6fa5108bf01635485b3b370ccec119d Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Fri, 19 Jul 2024 18:22:46 +0200 Subject: [PATCH 13/19] trailing whitespace + import block unformatted --- src/transformers/models/mamba/modeling_mamba.py | 10 +++++----- src/transformers/utils/import_utils.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index ef27606f744f..805fb3c29541 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -38,6 +38,11 @@ logger = logging.get_logger(__name__) +if is_mambapy_available(): + from mambapy.pscan import pscan +else: + pscan = None + if is_mamba_ssm_available(): from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update @@ -53,11 +58,6 @@ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) -if is_mambapy_available(): - from mambapy.pscan import pscan -else: - pscan = None - _CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m-hf" _CONFIG_FOR_DOC = "MambaConfig" diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index c38655cba572..e7256ce317e4 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -395,7 +395,7 @@ def is_causal_conv1d_available(): def is_mambapy_available(): - if is_torch_available(): + if is_torch_available(): return _is_package_available("mambapy") return False From 0bf2cc03e960227213deebc817c6c3fb47f4a024 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Fri, 19 Jul 2024 18:26:19 +0200 Subject: [PATCH 14/19] Update modeling_mamba.py --- src/transformers/models/mamba/modeling_mamba.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 805fb3c29541..285ddf180b25 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -36,6 +36,7 @@ from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available from .configuration_mamba import MambaConfig + logger = logging.get_logger(__name__) if is_mambapy_available(): From 4d29292ebb9de79d6490cd45bc8e3fba95fec849 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Mon, 22 Jul 2024 09:21:27 +0200 Subject: [PATCH 15/19] transpose before pscan --- src/transformers/models/mamba/modeling_mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 285ddf180b25..09ca0803d2b0 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -271,9 +271,9 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca # 3.c perform the recurrence y ← SSM(A, 
B, C)(x) if self.use_mambapy and self.training and cache_params is None: - hs = pscan(discrete_A, deltaB_u) # [batch, intermediate_size, seq_len, ssm_state_size] + hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, intermediate_size, seq_len, ssm_state_size] - scan_output = (hs.transpose(1, 2) @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] + scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] scan_output = scan_output + hidden_states * self.D[None, :, None] scan_output = scan_output * self.act(gate) else: From f3281a484e9c7b1235388efa8f6f7b4668466e7d Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Mon, 22 Jul 2024 09:22:01 +0200 Subject: [PATCH 16/19] shape comment --- src/transformers/models/mamba/modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 09ca0803d2b0..180027fba585 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -271,7 +271,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None, ca # 3.c perform the recurrence y ← SSM(A, B, C)(x) if self.use_mambapy and self.training and cache_params is None: - hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, intermediate_size, seq_len, ssm_state_size] + hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2)) # [batch, seq_len, intermediate_size, ssm_state_size] scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2) # [batch, intermediate_size, seq_len] scan_output = scan_output + hidden_states * self.D[None, :, None] From a9957aa1fd8fefd5588cf2896eada7e7eab65e00 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Tue, 23 Jul 2024 10:11:30 +0200 Subject: [PATCH 17/19] ran make style --- .../models/mamba/modeling_mamba.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 180027fba585..50c0f9ebe4a5 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -114,15 +114,21 @@ def __init__(self, config: MambaConfig, layer_idx: int): if not is_fast_path_available: if self.use_mambapy: if is_mambapy_available(): - logger.warning_once("The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" - " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and" - " https://github.com/Dao-AILab/causal-conv1d") + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) else: - raise ImportError("use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py.") + raise ImportError( + "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py." 
+ ) else: - logger.warning_once("The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" - " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and" - " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.") + logger.warning_once( + "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py." + ) def cuda_kernels_forward( self, From 7181b1948a5213f36645a2922d1640112c1aeca1 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Tue, 23 Jul 2024 11:36:19 +0200 Subject: [PATCH 18/19] use_mambapy=False by default Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/models/mamba/configuration_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index d6735134c2df..03a7f992d36f 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -125,7 +125,7 @@ def __init__( time_step_floor=1e-4, rescale_prenorm_residual=False, use_cache=True, - use_mambapy=True, + use_mambapy=False, **kwargs, ): self.vocab_size = vocab_size From 00d4173cce95279c0280fefdc087b3e647670d97 Mon Sep 17 00:00:00 2001 From: Alexandre TL Date: Tue, 23 Jul 2024 11:50:45 +0200 Subject: [PATCH 19/19] ran make fix-copies --- src/transformers/models/mamba/configuration_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 03a7f992d36f..89f08dd3cd32 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -79,7 +79,7 @@ class MambaConfig(PretrainedConfig): Whether or not to rescale `out_proj` weights when initializing. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. - use_mambapy (`bool`, *optional*, defaults to `True`): + use_mambapy (`bool`, *optional*, defaults to `False`): Determines the fallback strategy during training if the CUDA-based official implementation of Mamba is not available. If `True`, the mamba.py implementation is used. If `False`, the naive and slower implementation is used. Consider switching to the naive version if memory is limited.
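
The `mambapy.pscan` kernel that PATCH 09 switches to computes the same recurrence as the per-timestep loop in `slow_forward`, namely H[t] = A[t] * H[t-1] + X[t] starting from a zero hidden state, in roughly 2*log2(L) sequential steps instead of L (per the docstring of the `pscan.py` file added in PATCH 02 and deleted in PATCH 10). A minimal sanity-check sketch, assuming the `mambapy` package is installed and that `pscan` takes `(batch, seq_len, dim, state)` tensors, as the `transpose(1, 2)` calls added in PATCH 15 imply:

```python
import torch
from mambapy.pscan import pscan  # same import that PATCH 09 introduces

def sequential_scan(A, X):
    # reference recurrence from the pscan docstring:
    # H[t] = A[t] * H[t-1] + X[t], starting from a zero hidden state
    H = torch.zeros_like(X[:, 0])
    outputs = []
    for t in range(X.size(1)):
        H = A[:, t] * H + X[:, t]
        outputs.append(H)
    return torch.stack(outputs, dim=1)

batch, seq_len, dim, state = 2, 12, 8, 4
A = torch.rand(batch, seq_len, dim, state)   # values in (0, 1), like exp(dt * A) for negative A
X = torch.randn(batch, seq_len, dim, state)  # stands in for deltaB_u

print(torch.allclose(pscan(A, X), sequential_scan(A, X), atol=1e-5))  # expected: True
```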
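
A condensed version of the forward/backward equivalence check from the notebooks deleted in PATCH 05, updated for the final API in which the fallback backend is selected with the `use_mambapy` config flag. This is a sketch assuming a transformers build that contains these patches, with `mambapy` installed and the CUDA fast path unavailable (otherwise the flag has no effect); note the pscan branch is only taken in training mode with no cache:

```python
import torch
from transformers.models.mamba import MambaConfig, MambaForCausalLM

def build(use_mambapy):
    torch.manual_seed(34567)
    config = MambaConfig(vocab_size=60, hidden_size=64, num_hidden_layers=4, use_mambapy=use_mambapy)
    return MambaForCausalLM(config)  # freshly constructed modules are already in training mode

model_loop = build(use_mambapy=False)   # naive sequential scan
model_pscan = build(use_mambapy=True)   # mamba.py parallel scan

x = torch.randint(low=0, high=60, size=(16, 12))
y_loop, y_pscan = model_loop(x).logits, model_pscan(x).logits
y_loop.sum().backward()
y_pscan.sum().backward()

print(torch.allclose(y_loop, y_pscan, rtol=0.001))  # expected: True
print(all(
    torch.allclose(p1.grad, p2.grad, rtol=0.01)
    for p1, p2 in zip(model_loop.parameters(), model_pscan.parameters())
))  # expected: True
```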