Skip to content

Commit 281e734

Browse files
authored
[feat][MHA] Expose sharing the projections; default to sharing them (same as PyTorch) (facebookresearch#201)
1 parent 285a10a commit 281e734

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

xformers/components/multi_head_dispatch.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@ class InProjContainer(torch.nn.Module):
1717
"""
1818

1919
def __init__(
20-
self, query_proj: nn.Module, key_proj: nn.Module, value_proj: nn.Module
20+
self,
21+
query_proj: nn.Module,
22+
key_proj: Optional[nn.Module],
23+
value_proj: Optional[nn.Module],
2124
):
2225
super().__init__()
2326
self.query_proj = query_proj
24-
self.key_proj = key_proj
25-
self.value_proj = value_proj
27+
self.key_proj = key_proj if key_proj is not None else query_proj
28+
self.value_proj = value_proj if value_proj is not None else query_proj
2629

2730
def forward(
2831
self,
@@ -42,6 +45,7 @@ class MultiHeadDispatchConfig:
4245
dim_key: Optional[int]
4346
dim_value: Optional[int]
4447
in_proj_container: Optional[InProjContainer]
48+
use_separate_proj_weight: Optional[bool]
4549
out_proj: Optional[nn.Module]
4650

4751

@@ -68,6 +72,7 @@ def __init__(
6872
dim_key: Optional[int] = None,
6973
dim_value: Optional[int] = None,
7074
in_proj_container: Optional[InProjContainer] = None,
75+
use_separate_proj_weight: Optional[bool] = False,
7176
out_proj: Optional[nn.Module] = None,
7277
*args,
7378
**kwargs,
@@ -97,8 +102,12 @@ def __init__(
97102
query_proj=nn.Linear(
98103
dim_model, dim_key, bias=False
99104
), # NOTE: optional bias ?
100-
key_proj=nn.Linear(dim_model, dim_key, bias=False),
101-
value_proj=nn.Linear(dim_model, dim_value, bias=False),
105+
key_proj=nn.Linear(dim_model, dim_key, bias=False)
106+
if use_separate_proj_weight
107+
else None,
108+
value_proj=nn.Linear(dim_model, dim_value, bias=False)
109+
if use_separate_proj_weight
110+
else None,
102111
)
103112
)
104113

0 commit comments

Comments (0)