@@ -17,12 +17,15 @@ class InProjContainer(torch.nn.Module):
17
17
"""
18
18
19
19
def __init__ (
20
- self , query_proj : nn .Module , key_proj : nn .Module , value_proj : nn .Module
20
+ self ,
21
+ query_proj : nn .Module ,
22
+ key_proj : Optional [nn .Module ],
23
+ value_proj : Optional [nn .Module ],
21
24
):
22
25
super ().__init__ ()
23
26
self .query_proj = query_proj
24
- self .key_proj = key_proj
25
- self .value_proj = value_proj
27
+ self .key_proj = key_proj if key_proj is not None else query_proj
28
+ self .value_proj = value_proj if value_proj is not None else query_proj
26
29
27
30
def forward (
28
31
self ,
@@ -42,6 +45,7 @@ class MultiHeadDispatchConfig:
42
45
dim_key : Optional [int ]
43
46
dim_value : Optional [int ]
44
47
in_proj_container : Optional [InProjContainer ]
48
+ use_separate_proj_weight : Optional [bool ]
45
49
out_proj : Optional [nn .Module ]
46
50
47
51
@@ -68,6 +72,7 @@ def __init__(
68
72
dim_key : Optional [int ] = None ,
69
73
dim_value : Optional [int ] = None ,
70
74
in_proj_container : Optional [InProjContainer ] = None ,
75
+ use_separate_proj_weight : Optional [bool ] = False ,
71
76
out_proj : Optional [nn .Module ] = None ,
72
77
* args ,
73
78
** kwargs ,
@@ -97,8 +102,12 @@ def __init__(
97
102
query_proj = nn .Linear (
98
103
dim_model , dim_key , bias = False
99
104
), # NOTE: optional bias ?
100
- key_proj = nn .Linear (dim_model , dim_key , bias = False ),
101
- value_proj = nn .Linear (dim_model , dim_value , bias = False ),
105
+ key_proj = nn .Linear (dim_model , dim_key , bias = False )
106
+ if use_separate_proj_weight
107
+ else None ,
108
+ value_proj = nn .Linear (dim_model , dim_value , bias = False )
109
+ if use_separate_proj_weight
110
+ else None ,
102
111
)
103
112
)
104
113
0 commit comments