From fb7bbcb7162e324af731815d9b170f59f0afafd2 Mon Sep 17 00:00:00 2001
From: Benjamin Lefaudeux
Date: Fri, 8 Apr 2022 16:44:28 -0400
Subject: [PATCH] [feat][minor] Simplicial embeddings (#261)

* Initial implementation, needs testing
  adding the option in the encoder
  Adding it to the microViT example

* code review, pulling in triton softmax if possible
---
 CHANGELOG.md                                |  3 +
 HOWTO.md                                    |  3 +
 README.md                                   |  1 +
 examples/microViT.py                        |  1 +
 tests/test_block_factory.py                 | 82 ++++++++++++++++++++
 xformers/components/simplicial_embedding.py | 75 +++++++++++++++++
 xformers/factory/block_factory.py           | 15 ++++
 7 files changed, 180 insertions(+)
 create mode 100644 xformers/components/simplicial_embedding.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ec3952d54b..ef46383efe 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fix some torchscriptability [#246]
 - Fix FourierMix being compatible with AMP [#258]
 
+### Added
+- Simplicial Embeddings [#259]
+
 ## [0.0.10] - 2022-03-14
 ### Fixed
 - Expose bias flag for feedforwards, same default as Timm [#220]

diff --git a/HOWTO.md b/HOWTO.md
index 8de549d96c..55cddb1561 100644
--- a/HOWTO.md
+++ b/HOWTO.md
@@ -509,6 +509,9 @@ my_config = [
             "activation": "relu",
             "hidden_layer_multiplier": 4,
         },
+        # Optional Simplicial Embeddings on the last encoder layer
+        # the temperature parameter is itself optional
+        "simplicial_embeddings": {"L": 4, "temperature": 0.5}
     },
     {
         "reversible": False,  # Optionally make these layers reversible, to save memory

diff --git a/README.md b/README.md
index 02db93d31b..1fd409de70 100644
--- a/README.md
+++ b/README.md
@@ -163,6 +163,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
 - [Sine](xformers/components/positional_embedding/sine.py)
 - [Vocabulary](xformers/components/positional_embedding/vocab.py)
 - [Rotary](xformers/components/positional_embedding/rotary.py)
+- [Simplicial](xformers/components/simplicial_embedding.py)

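A quick usage sketch, not part of the patch itself: the HOWTO.md entry above maps onto the block factory as follows. It mirrors the configuration fields exercised by the test added further down; the sizes, dropout values, and the output-shape comment are illustrative assumptions, and the device handling accounts for the triton softmax path expecting CUDA tensors.

import torch

from xformers.factory.block_factory import xFormerEncoderBlock, xFormerEncoderConfig

SEQ, MODEL = 128, 96  # illustrative sizes; MODEL must be divisible by the chosen L
device = "cuda" if torch.cuda.is_available() else "cpu"

config = xFormerEncoderConfig(
    dim_model=MODEL,
    multi_head_config={
        "num_heads": 4,
        "dim_model": MODEL,
        "residual_dropout": 0.0,
        "attention": {
            "name": "scaled_dot_product",
            "dropout": 0.0,
            "causal": False,
            "seq_len": SEQ,
        },
    },
    feedforward_config={
        "name": "MLP",
        "dim_model": MODEL,
        "dropout": 0.0,
        "activation": "relu",
        "hidden_layer_multiplier": 4,
    },
    position_encoding_config={
        "name": "sine",
        "dim_model": MODEL,
        "seq_len": SEQ,
        "vocab_size": 64,
    },
    layer_norm_style="pre",
    reversible=False,
    # The new, optional entry: split each token embedding into L chunks and
    # softmax-normalize each chunk, with an optional temperature
    simplicial_embeddings={"L": 4, "temperature": 0.5},
)

block = xFormerEncoderBlock.from_config(config).to(device)
x = torch.rand(2, SEQ, device=device)  # same input convention as the factory tests
y = block(x)                           # expected shape: (2, SEQ, MODEL)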
diff --git a/examples/microViT.py b/examples/microViT.py
index dfa0dced2f..879d906e83 100644
--- a/examples/microViT.py
+++ b/examples/microViT.py
@@ -86,6 +86,7 @@ def __init__(
                     "activation": "gelu",
                     "hidden_layer_multiplier": hidden_layer_multiplier,
                 },
+                # "simplicial_embeddings": {"L": n_head, "temperature": 2.0},
             }
         ]
 

diff --git a/tests/test_block_factory.py b/tests/test_block_factory.py
index 7b5450faaa..391038d3eb 100644
--- a/tests/test_block_factory.py
+++ b/tests/test_block_factory.py
@@ -310,3 +310,85 @@ def test_embedding_projection():
     input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
     input_mask[input_mask < 0.0] = -float("inf")
     _ = block(inputs, input_mask=input_mask)
+
+
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="This test requires a CUDA device"
+)
+def test_simplicial_embedding(
+    device: torch.device,
+):
+    attention_config = {
+        "name": "scaled_dot_product",
+        "dropout": 0.1,
+        "causal": False,
+        "window_size": SEQ // 8 + 1,
+        "seq_len": SEQ,
+        "dim_model": MODEL,
+        "num_heads": 4,
+    }
+
+    multi_head_config = {
+        "num_heads": 4,
+        "dim_model": MODEL,
+        "residual_dropout": 0.1,
+        "attention": attention_config,
+    }
+
+    feedforward_config = {
+        "name": "MLP",
+        "dim_model": MODEL,
+        "dropout": DROPOUT,
+        "activation": "relu",
+        "hidden_layer_multiplier": 4,
+    }
+
+    position_encoding_config = {
+        "name": "sine",
+        "dim_model": MODEL,
+        "seq_len": SEQ,
+        "vocab_size": VOCAB_SIZE,
+    }
+
+    block_config = xFormerEncoderConfig(
+        dim_model=MODEL,
+        multi_head_config=multi_head_config,
+        feedforward_config=feedforward_config,
+        position_encoding_config=position_encoding_config,
+        layer_norm_style="pre",
+        reversible=False,
+        simplicial_embeddings={"L": 4},
+    )
+
+    # Test that the whole block can be instantiated
+    block = xFormerEncoderBlock.from_config(block_config).to(device)
+
+    # Check that the dimensions make sense, run a FW pass
+    inputs = torch.rand(BATCH, SEQ, device=device)
+    _ = block(inputs)
+
+    # Check that we support attention masking, at least interface wise (do not check correctness yet)
+    att_mask = torch.ones(SEQ, SEQ, dtype=torch.bool, device=device)
+    _ = block(inputs, att_mask=att_mask)
+
+    # Check that we support input masking, at least interface wise (do not check correctness yet)
+    input_mask = torch.randn(SEQ, dtype=torch.float, device=device)
+    input_mask[input_mask < 0.0] = -float("inf")
+    _ = block(inputs, input_mask=input_mask)
+
+    # Check that a faulty L is caught
+    block_config = xFormerEncoderConfig(
+        dim_model=MODEL,
+        multi_head_config=multi_head_config,
+        feedforward_config=feedforward_config,
+        position_encoding_config=position_encoding_config,
+        layer_norm_style="pre",
+        reversible=False,
+        simplicial_embeddings={"L": 3},
+    )
+
+    # An L which does not divide the embedding dimension must be rejected
+    with pytest.raises(AssertionError):
+        block = xFormerEncoderBlock.from_config(block_config).to(device)
+        _ = block(inputs)

diff --git a/xformers/components/simplicial_embedding.py b/xformers/components/simplicial_embedding.py
new file mode 100644
index 0000000000..7864ca9473
--- /dev/null
+++ b/xformers/components/simplicial_embedding.py
@@ -0,0 +1,75 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import asdict, dataclass
+from typing import Optional, Type, TypeVar
+
+import torch
+
+from xformers import _is_triton_available
+
+Self = TypeVar("Self", bound="SimplicialEmbedding")
+
+
+if _is_triton_available:
+    from xformers.triton.softmax import softmax as triton_softmax
+
+
+@dataclass
+class SimplicialEmbeddingConfig:
+    L: int
+    temperature: float
+
+
+class SimplicialEmbedding(torch.nn.Module):
+    """
+    An implementation of "Simplicial Embeddings"_, as proposed by Lavoie et al.
+
+    Arguments:
+        - L: the number of embedding chunks
+        - temperature: optional scaling parameter for the softmax operation.
+          A small (<1.) temperature will lead to a sparse representation (up to one-hot),
+          while a large (>1.) temperature will make the vector more uniform.
+
+    _"Simplicial Embeddings": https://arxiv.org/pdf/2204.00616.pdf
+    """
+
+    def __init__(self, L: int, temperature: Optional[float] = None) -> None:
+        super().__init__()
+        self.L = L
+        self.temperature = temperature
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        assert (
+            x.shape[-1] % self.L == 0
+        ), f"The embedding dimension {x.shape[-1]} is not divisible by the chosen L parameter {self.L}"
+
+        # Separate the input tensor into L chunks of size V each
+        B, C, E = x.shape
+        V = E // self.L
+
+        Vs = x.reshape(B, C, self.L, V)
+
+        # Softmax normalize the chunks, with the proposed temperature
+        # This is done over the last dimension, so within each chunk
+        if self.temperature is not None:
+            Vs /= self.temperature
+
+        if _is_triton_available:
+            Vs = triton_softmax(
+                Vs, mask=None, causal=False
+            )  # the softmax is on the last dimension
+        else:
+            Vs = torch.nn.functional.softmax(Vs, dim=-1)
+
+        # Concatenate the chunks back and return
+        return Vs.reshape(B, C, E)
+
+    @classmethod
+    def from_config(cls: Type[Self], config: SimplicialEmbeddingConfig) -> Self:
+        # Generate the class inputs from the config
+        fields = asdict(config)
+
+        return cls(**fields)

diff --git a/xformers/factory/block_factory.py b/xformers/factory/block_factory.py
index 662b6d60d7..14bd9941f4 100644
--- a/xformers/factory/block_factory.py
+++ b/xformers/factory/block_factory.py
@@ -30,6 +30,7 @@
     build_positional_embedding,
 )
 from xformers.components.residual import get_deepnorm_coefficients
+from xformers.components.simplicial_embedding import SimplicialEmbedding
 from xformers.utils import generate_matching_config
 
 
@@ -183,6 +184,7 @@ class xFormerEncoderConfig(xFormerBlockConfig):
     multi_head_config: Dict[str, Any]
     use_triton: bool
+    simplicial_embeddings: Optional[Dict[str, Any]]
 
     def __init__(
         self,
@@ -192,6 +194,7 @@ def __init__(
         position_encoding_config: Optional[Dict[str, Any]] = None,
         layer_norm_style: str = "post",
         use_triton: bool = True,
+        simplicial_embeddings: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
         # Convenience, fill in duplicated field
@@ -224,6 +227,7 @@ def __init__(
         self.multi_head_config = multi_head_config
         self.use_triton = use_triton
+        self.simplicial_embeddings = simplicial_embeddings
 
 
 @dataclass(init=False)
@@ -351,6 +355,13 @@ def __init__(self, config: xFormerEncoderConfig, **kwargs):
         ):
             self.wrap_ff = PostNorm(config.dim_model, self.wrap_ff)
 
+        # Simplicial embeddings are only used if specified, and on the last layer
+        self.simplicial_embedding: Optional[SimplicialEmbedding] = None
+        if config.simplicial_embeddings is not None and config.layer_position.is_last():
+            self.simplicial_embedding = SimplicialEmbedding(
+                **config.simplicial_embeddings
+            )
+
     @classmethod
     def from_config(cls, config: xFormerEncoderConfig):
         return cls(config)
@@ -395,6 +406,10 @@ def forward(
         x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask)
         x = self.wrap_ff(inputs=[x])
 
+        # Optional simplicial embeddings
+        if self.simplicial_embedding is not None:
+            x = self.simplicial_embedding(x)
+
         return x
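Finally, a standalone sketch of the new component itself, outside of the factory. The sizes below are illustrative assumptions, and the device handling mirrors the CUDA gating of the triton softmax path used above; the assertions only check what the module documents, namely that the shape is preserved and that each chunk is softmax-normalized.

import torch

from xformers.components.simplicial_embedding import SimplicialEmbedding

# Illustrative sizes: batch B, context C, embedding E split into L chunks of V = E // L entries
B, C, E, L = 2, 16, 64, 4
device = "cuda" if torch.cuda.is_available() else "cpu"

se = SimplicialEmbedding(L=L, temperature=0.5)
x = torch.randn(B, C, E, device=device)
y = se(x)

# The shape is preserved, but each of the L chunks is now softmax-normalized,
# so it sums to one over its V entries
assert y.shape == (B, C, E)
chunks = y.reshape(B, C, L, E // L)
assert torch.allclose(chunks.sum(dim=-1), torch.ones(B, C, L, device=device), atol=1e-4)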