using the fused dropout in the FusedMLP
blefaudeux committed Nov 9, 2021
1 parent 5843434 commit cf69993
Showing 3 changed files with 21 additions and 8 deletions.
xformers/components/feedforward/fused_mlp.py (9 additions, 7 deletions)
@@ -19,7 +19,7 @@
 
 if torch.cuda.is_available():
     try:
-        from xformers.triton import FusedLinear
+        from xformers.triton import FusedDropoutBias, FusedLinear
 
         @dataclass
         class FusedMlpConfig(FeedforwardConfig):
@@ -39,8 +39,8 @@ def __init__(
                 dropout: float,
                 activation: Activation,
                 hidden_layer_multiplier: int,
-                *args,
-                **kwargs,
+                *_,
+                **__,
             ):
                 super().__init__()
 
@@ -51,11 +51,13 @@ def __init__(
                         in_features=dim_model,
                         out_features=hidden_layer_multiplier * dim_model,
                         activation=activation,
-                        bias=True,
+                        bias=False,
                     ),
-                    nn.Dropout(dropout),
-                    nn.Linear(hidden_layer_multiplier * dim_model, dim_model),
-                    nn.Dropout(dropout),
+                    FusedDropoutBias(dropout, hidden_layer_multiplier * dim_model),
+                    nn.Linear(
+                        hidden_layer_multiplier * dim_model, dim_model, bias=False
+                    ),
+                    FusedDropoutBias(dropout, hidden_layer_multiplier * dim_model),
                 )
                 self.requires_cuda = True
 
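The gist of the change above: both linear layers drop their bias (bias=False), and the bias add is folded into the Triton dropout kernel via FusedDropoutBias, so each bias + dropout pair runs as one fused op. Below is a rough plain-PyTorch sketch of the computation the new Sequential performs, for orientation only: the class and parameter names are illustrative, GELU stands in for whatever Activation is configured, and the second bias is given width dim_model so it broadcasts against the second Linear's output.

    import torch
    import torch.nn as nn

    class UnfusedEquivalent(nn.Module):
        # Illustrative stand-in for the new Sequential: bias-free matmuls, with each
        # bias added right before dropout (the pair that the Triton kernel fuses).
        def __init__(self, dim_model: int, hidden_layer_multiplier: int, p: float):
            super().__init__()
            hidden = hidden_layer_multiplier * dim_model
            self.w1 = nn.Linear(dim_model, hidden, bias=False)
            self.b1 = nn.Parameter(torch.zeros(hidden))
            self.w2 = nn.Linear(hidden, dim_model, bias=False)
            self.b2 = nn.Parameter(torch.zeros(dim_model))
            self.act = nn.GELU()       # stand-in for the configured Activation
            self.drop = nn.Dropout(p)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.drop(self.act(self.w1(x)) + self.b1)  # FusedLinear -> FusedDropoutBias
            x = self.drop(self.w2(x) + self.b2)            # nn.Linear   -> FusedDropoutBias
            return x

    y = UnfusedEquivalent(dim_model=64, hidden_layer_multiplier=4, p=0.1)(torch.randn(2, 16, 64))
    assert y.shape == (2, 16, 64)

The fusion itself is a performance detail: on the Triton path the bias add and the dropout mask are applied in one kernel instead of two separate passes over the activation.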
xformers/triton/__init__.py (2 additions, 1 deletion)
@@ -9,7 +9,7 @@
 _triton_available = torch.cuda.is_available()
 if _triton_available:
     try:
-        from .dropout import dropout  # noqa
+        from .dropout import FusedDropoutBias, dropout  # noqa
         from .fused_linear_layer import FusedLinear  # noqa
         from .layer_norm import FusedLayerNorm, layer_norm  # noqa
         from .softmax import log_softmax, softmax  # noqa
@@ -18,6 +18,7 @@
             "dropout",
             "softmax",
             "log_softmax",
+            "FusedDropoutBias",
             "FusedLinear",
             "FusedLayerNorm",
             "layer_norm",
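With the re-export above, the fused pieces are reachable from xformers.triton whenever the guarded import succeeds (CUDA available plus a working Triton install). A small hedged usage sketch of the functional form, with the keyword names taken from the dropout signature in the next file and the shapes chosen only for illustration:

    import torch
    from xformers.triton import dropout  # only exported when the guarded import above succeeds

    x = torch.randn(8, 128, 2048, device="cuda")
    bias = torch.zeros(2048, device="cuda")
    y = dropout(x, p=0.1, bias=bias)  # bias add + dropout handled on the fused Triton path
    assert y.shape == x.shape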
xformers/triton/dropout.py (10 additions, 0 deletions)
@@ -102,3 +102,13 @@ def dropout(x: torch.Tensor, p: float, bias: Optional[torch.Tensor] = None):
         return _dropout.apply(x, p, bias)
 
     return x + bias if bias is not None else x
+
+
+class FusedDropoutBias(torch.nn.Module):
+    def __init__(self, p: float, bias_shape: Optional[int]) -> None:
+        super().__init__()
+        self.p = p
+        self.bias = torch.zeros(bias_shape) if bias_shape is not None else None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return dropout(x, self.p, self.bias)

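The new FusedDropoutBias module is a thin stateful wrapper around the functional dropout above: it holds a zero-initialised bias of width bias_shape and calls dropout(x, self.p, self.bias) in forward. A small sketch of that equivalence; p=0.0 and a CPU tensor should keep both calls on the plain x + bias return visible above (no Triton kernel launch), though importing the names still assumes the CUDA + Triton environment required by xformers/triton/__init__.py:

    import torch
    from xformers.triton import FusedDropoutBias, dropout

    x = torch.randn(2, 8, 512)
    module_out = FusedDropoutBias(p=0.0, bias_shape=512)(x)
    functional_out = dropout(x, p=0.0, bias=torch.zeros(512))
    # With p=0.0 both should reduce to x + bias (a zero bias here), so the results match.
    assert torch.equal(module_out, functional_out)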