[feat][minor] Simplicial embeddings (#261)
* Initial implementation, needs testing

Adding the option in the encoder

Adding it to the microViT example

* Code review, pulling in triton softmax if possible
blefaudeux authored Apr 8, 2022
1 parent 23aaa58 commit fb7bbcb
Showing 7 changed files with 180 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix some torchscriptability [#246]
- Fix FourierMix being compatible with AMP [#258]

### Added
- Simplicial Embeddings [#259]

## [0.0.10] - 2022-03-14
### Fixed
- Expose bias flag for feedforwards, same default as Timm [#220]
3 changes: 3 additions & 0 deletions HOWTO.md
@@ -509,6 +509,9 @@ my_config = [
"activation": "relu",
"hidden_layer_multiplier": 4,
},
# Optional Simplicial Embeddings on the last encoder layer
# the temperature parameter is itself optional
"simplicial_embeddings": {"L": 4, "temperature": 0.5}
},
{
"reversible": False, # Optionally make these layers reversible, to save memory
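
For context, a minimal sketch of what this optional entry expands to at runtime: the block factory splats the dict into the new SimplicialEmbedding module added by this commit (xformers/components/simplicial_embedding.py, shown further down) and applies it to the output of the last encoder layer. The tensor shapes and the se_config name below are illustrative, and the sketch assumes the pure-PyTorch softmax branch (no triton picked up).

import torch
from xformers.components.simplicial_embedding import SimplicialEmbedding

se_config = {"L": 4, "temperature": 0.5}  # same keys as the config entry above
simplicial = SimplicialEmbedding(**se_config)

x = torch.rand(2, 16, 64)  # (batch, sequence, embedding); 64 must be divisible by L=4
y = simplicial(x)          # softmax over each of the 4 chunks of 16 features
assert y.shape == x.shape
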
1 change: 1 addition & 0 deletions README.md
@@ -163,6 +163,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
- [Sine](xformers/components/positional_embedding/sine.py)
- [Vocabulary](xformers/components/positional_embedding/vocab.py)
- [Rotary](xformers/components/positional_embedding/rotary.py)
- [Simplicial](xformers/components/simplicial_embedding.py)

</p></details>

1 change: 1 addition & 0 deletions examples/microViT.py
@@ -86,6 +86,7 @@ def __init__(
"activation": "gelu",
"hidden_layer_multiplier": hidden_layer_multiplier,
},
# "simplicial_embeddings": {"L": n_head, "temperature": 2.0},
}
]

82 changes: 82 additions & 0 deletions tests/test_block_factory.py
@@ -310,3 +310,85 @@ def test_embedding_projection():
input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
input_mask[input_mask < 0.0] = -float("inf")
_ = block(inputs, input_mask=input_mask)


@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="This test requires a CUDA device"
)
def test_simplicial_embedding(
device: torch.device,
):
attention_config = {
"name": "scaled_dot_product",
"dropout": 0.1,
"causal": False,
"window_size": SEQ // 8 + 1,
"seq_len": SEQ,
"dim_model": MODEL,
"num_heads": 4,
}

multi_head_config = {
"num_heads": 4,
"dim_model": MODEL,
"residual_dropout": 0.1,
"attention": attention_config,
}

feedforward_config = {
"name": "MLP",
"dim_model": MODEL,
"dropout": DROPOUT,
"activation": "relu",
"hidden_layer_multiplier": 4,
}

position_encoding_config = {
"name": "sine",
"dim_model": MODEL,
"seq_len": SEQ,
"vocab_size": VOCAB_SIZE,
}

block_config = xFormerEncoderConfig(
dim_model=MODEL,
multi_head_config=multi_head_config,
feedforward_config=feedforward_config,
position_encoding_config=position_encoding_config,
layer_norm_style="pre",
reversible=False,
simplicial_embeddings={"L": 4},
)

# Test that the whole block can be instantiated
block = xFormerEncoderBlock.from_config(block_config).to(device)

# Check that the dimensions make sense with a FW pass
inputs = torch.rand(BATCH, SEQ, device=device)
_ = block(inputs)

# Check that we support attention masking, at least interface wise (do not check correctness yet)
att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
_ = block(inputs, att_mask=att_mask)

# Check that we support input masking, at least interface wise (do not check correctness yet)
input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
input_mask[input_mask < 0.0] = -float("inf")
_ = block(inputs, input_mask=input_mask)

# Check that a faulty L is caught
block_config = xFormerEncoderConfig(
dim_model=MODEL,
multi_head_config=multi_head_config,
feedforward_config=feedforward_config,
position_encoding_config=position_encoding_config,
layer_norm_style="pre",
reversible=False,
simplicial_embeddings={"L": 3},
)

# Test that the whole block can be instantiated
with pytest.raises(AssertionError):
block = xFormerEncoderBlock.from_config(block_config).to(device)
_ = block(inputs)
75 changes: 75 additions & 0 deletions xformers/components/simplicial_embedding.py
@@ -0,0 +1,75 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import asdict, dataclass
from typing import Optional, Type, TypeVar

import torch

from xformers import _is_triton_available

Self = TypeVar("Self", bound="SimplicialEmbedding")


if _is_triton_available:
from xformers.triton.softmax import softmax as triton_softmax


@dataclass
class SimplicialEmbeddingConfig:
L: int
temperature: float


class SimplicialEmbedding(torch.nn.Module):
"""
An implementation of the "Simplicial Embeddings"_, as proposed by Lavoie et al.
Arguments:
- L: the number of embedding chunks
- temperature: optional scaling parameter for the softmax operation.
A small (<1.) temperature will lead to a sparse representation (up to one-hot),
while a large (>1.) temperature will make the vector more uniform
_"Simplicial Embeddings": https://arxiv.org/pdf/2204.00616.pdf
"""

def __init__(self, L: int, temperature: Optional[float] = None) -> None:
super().__init__()
self.L = L
self.temperature = temperature

def forward(self, x: torch.Tensor) -> torch.Tensor:
assert (
x.shape[-1] % self.L == 0
), f"The embedding dimension {x.shape[-1]} is not divisible by the chosen L parameter {self.L}"

# Separate the input tensor into L chunks of dimension V
B, C, E = x.shape
V = E // self.L

Vs = x.reshape(B, C, self.L, V)

# Softmax normalize them, with the proposed temperature
# This is done over the last dimension, so only within Vs
if self.temperature is not None:
Vs /= self.temperature

if _is_triton_available:
Vs = triton_softmax(
Vs, mask=None, causal=False
) # the softmax is on the last dimension
else:
Vs = torch.nn.functional.softmax(Vs, dim=-1)

# Concatenate back and return
return Vs.reshape(B, C, E)

@classmethod
def from_config(cls: Type[Self], config: SimplicialEmbeddingConfig) -> Self:
# Generate the class inputs from the config
fields = asdict(config)

return cls(**fields)
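
A quick property check for the module above, written as a hedged sketch (tensor sizes and temperature values are illustrative, and it assumes triton is not picked up so the plain PyTorch softmax branch runs): after the forward pass every chunk of size V = E // L lies on the probability simplex, and the optional temperature only controls how peaked each chunk is.

import torch
from xformers.components.simplicial_embedding import SimplicialEmbedding

B, S, E, L = 2, 8, 32, 4
x = torch.randn(B, S, E)

for temperature in (0.1, None, 10.0):
    emb = SimplicialEmbedding(L=L, temperature=temperature)
    y = emb(x.clone())  # clone: with a temperature set, the forward divides a view of its input in place
    chunks = y.reshape(B, S, L, E // L)
    assert torch.all(chunks >= 0)
    assert torch.allclose(chunks.sum(dim=-1), torch.ones(B, S, L))
    # smaller temperatures push each chunk towards one-hot, larger ones towards uniform
    print(temperature, chunks.max(dim=-1).values.mean().item())
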
15 changes: 15 additions & 0 deletions xformers/factory/block_factory.py
@@ -30,6 +30,7 @@
build_positional_embedding,
)
from xformers.components.residual import get_deepnorm_coefficients
from xformers.components.simplicial_embedding import SimplicialEmbedding
from xformers.utils import generate_matching_config


@@ -183,6 +184,7 @@ class xFormerEncoderConfig(xFormerBlockConfig):

multi_head_config: Dict[str, Any]
use_triton: bool
simplicial_embeddings: Optional[Dict[str, Any]]

def __init__(
self,
@@ -192,6 +194,7 @@ def __init__(
position_encoding_config: Optional[Dict[str, Any]] = None,
layer_norm_style: str = "post",
use_triton: bool = True,
simplicial_embeddings: Optional[Dict[str, Any]] = None,
**kwargs,
):
# Convenience, fill in duplicated field
@@ -224,6 +227,7 @@

self.multi_head_config = multi_head_config
self.use_triton = use_triton
self.simplicial_embeddings = simplicial_embeddings


@dataclass(init=False)
@@ -351,6 +355,13 @@ def __init__(self, config: xFormerEncoderConfig, **kwargs):
):
self.wrap_ff = PostNorm(config.dim_model, self.wrap_ff)

# Simplicial embeddings are only used if specified, and on the last layer
self.simplicial_embedding: Optional[SimplicialEmbedding] = None
if config.simplicial_embeddings is not None and config.layer_position.is_last():
self.simplicial_embedding = SimplicialEmbedding(
**config.simplicial_embeddings
)

@classmethod
def from_config(cls, config: xFormerEncoderConfig):
return cls(config)
@@ -395,6 +406,10 @@ def forward(
x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
x = self.wrap_ff(inputs=[x])

# Optional simplicial embeddings
if self.simplicial_embedding is not None:
x = self.simplicial_embedding(x)

return x


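
Condensing the new unit test into a plain usage sketch of the factory path touched here (the dimensions are illustrative, and it assumes a CPU run taking the PyTorch softmax fallback; the test above exercises the same path on CUDA, where the triton kernel is picked up):

import torch
from xformers.factory.block_factory import xFormerEncoderBlock, xFormerEncoderConfig

SEQ, MODEL = 32, 64

config = xFormerEncoderConfig(
    dim_model=MODEL,
    multi_head_config={
        "num_heads": 4,
        "dim_model": MODEL,
        "residual_dropout": 0.0,
        "attention": {
            "name": "scaled_dot_product",
            "dropout": 0.0,
            "causal": False,
            "seq_len": SEQ,
            "dim_model": MODEL,
            "num_heads": 4,
        },
    },
    feedforward_config={
        "name": "MLP",
        "dim_model": MODEL,
        "dropout": 0.0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
    },
    position_encoding_config={
        "name": "sine",
        "dim_model": MODEL,
        "seq_len": SEQ,
        "vocab_size": 64,
    },
    layer_norm_style="pre",
    reversible=False,
    simplicial_embeddings={"L": 4, "temperature": 0.5},  # MODEL % L must be 0
)

# A single block is also the last one, so the simplicial embedding is applied to its output
block = xFormerEncoderBlock.from_config(config)
out = block(torch.rand(2, SEQ))  # (batch, seq) in, (batch, seq, MODEL) out
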
