From 314d3de42ba3fef37c950e76d2bf837287b4b39d Mon Sep 17 00:00:00 2001 From: Qiang Zhang Date: Fri, 16 Feb 2024 12:05:33 -0800 Subject: [PATCH] implement MaMMUT (#520) Summary: Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/520 Implement MaMMUT, mostly based on current CoCa code as well as https://github.com/lucidrains/MaMMUT-pytorch. Reviewed By: ebsmothers Differential Revision: D52823194 fbshipit-source-id: f5cbe59188c50def1f2fb40b9bd7c9d7d359b864 --- torchmultimodal/modules/layers/transformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py index e2eea49e..f377879d 100644 --- a/torchmultimodal/modules/layers/transformer.py +++ b/torchmultimodal/modules/layers/transformer.py @@ -553,6 +553,8 @@ class TransformerDecoder(nn.Module): If None, K and V are assumed to have dimension d_model. Defaults to None. final_layer_norm_eps (Optional[float]): epsilon used in final layer norm. Defaults to None (no final layer norm). + cross_attention_interval (int): interval of layers at which cross attention is applied. Defaults to 1. Not used if + use_cross_attention = False """ def __init__( @@ -568,6 +570,7 @@ def __init__( use_cross_attention: bool = True, dim_kv: Optional[int] = None, final_layer_norm_eps: Optional[float] = None, + cross_attention_interval: int = 1, ): super().__init__() self.layer = nn.ModuleList( @@ -580,7 +583,7 @@ def __init__( activation, layer_norm_eps, norm_first, - use_cross_attention, + use_cross_attention and (i % cross_attention_interval == 0), dim_kv, ) for i in range(n_layer)