From e153c4fc373e357c495d05ab742223b29965989f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 24 Mar 2021 14:22:10 -0700 Subject: [PATCH] allow for one to pass in images of any dimension, provided height and width is divisible by block size --- README.md | 1 - halonet_pytorch/halonet_pytorch.py | 25 +++++++++---------------- setup.py | 2 +- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index cfe9d2b..cc7cc86 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,6 @@ from halonet_pytorch import HaloAttention attn = HaloAttention( dim = 512, # dimension of feature map - fmap_size = 32, # feature map height and width block_size = 8, # neighborhood block size (feature map must be divisible by this) halo_size = 4, # halo size (block receptive field) dim_head = 64, # dimension of each head diff --git a/halonet_pytorch/halonet_pytorch.py b/halonet_pytorch/halonet_pytorch.py index a289d24..b4db0db 100644 --- a/halonet_pytorch/halonet_pytorch.py +++ b/halonet_pytorch/halonet_pytorch.py @@ -47,15 +47,13 @@ class RelPosEmb(nn.Module): def __init__( self, block_size, - fmap_size, + rel_size, dim_head ): super().__init__() - fmap_size = pair(fmap_size) - height, width = fmap_size + height = width = rel_size scale = dim_head ** -0.5 - self.fmap_size = fmap_size self.block_size = block_size self.rel_height = nn.Parameter(torch.randn(height * 2 - 1, dim_head) * scale) self.rel_width = nn.Parameter(torch.randn(width * 2 - 1, dim_head) * scale) @@ -79,14 +77,12 @@ def __init__( self, *, dim, - fmap_size, block_size, halo_size, dim_head = 64, heads = 8 ): super().__init__() - assert fmap_size % block_size == 0, 'feature map height or width must be divisible by block size' assert halo_size > 0, 'halo size must be greater than 0' self.dim = dim @@ -100,7 +96,7 @@ def __init__( self.rel_pos_emb = RelPosEmb( block_size = block_size, - fmap_size = block_size + (halo_size * 2), + rel_size = block_size + (halo_size * 2), dim_head = dim_head ) @@ 
-108,16 +104,9 @@ def __init__( self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) self.to_out = nn.Linear(inner_dim, dim) - # prepare a mask for removing attention to padding, cached for performance - - mask = torch.ones(1, 1, fmap_size, fmap_size) - mask = F.unfold(mask, kernel_size = block_size + (halo_size * 2), stride = block_size, padding = halo_size) - mask = repeat(mask, 'b j i -> (b i h) j', h = heads) - self.register_buffer('mask', mask == 0) - def forward(self, x): b, c, h, w, block, halo, heads, device = *x.shape, self.block_size, self.halo_size, self.heads, x.device - assert h == w, 'dimensions of fmap must be same on both sides, for now' + assert h % block == 0 and w % block == 0, 'fmap dimensions must be divisible by the block size' assert c == self.dim, f'channels for input ({c}) does not equal to the correct dimension ({self.dim})' # get block neighborhoods, and prepare a halo-ed version (blocks with padding) for deriving key values @@ -150,7 +139,11 @@ def forward(self, x): # mask out padding (in the paper, they claim to not need masks, but what about padding?) - mask = repeat(self.mask, 'h j -> (b h) () j', b = b) + mask = torch.ones(1, 1, h, w, device = device) + mask = F.unfold(mask, kernel_size = block + (halo * 2), stride = block, padding = halo) + mask = repeat(mask, '() j i -> (b i h) () j', b = b, h = heads) + mask = mask == 0 + max_neg_value = -torch.finfo(sim.dtype).max sim.masked_fill_(mask, max_neg_value) diff --git a/setup.py b/setup.py index 66a2a6a..445e756 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'halonet-pytorch', packages = find_packages(), - version = '0.0.3', + version = '0.0.4', license='MIT', description = 'HaloNet - Pytorch', author = 'Phil Wang',