Skip to content

Commit

Permalink
patch
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 28, 2024
1 parent 717e398 commit 23a2959
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'tab-transformer-pytorch',
packages = find_packages(),
version = '0.4.0',
version = '0.4.1',
license='MIT',
description = 'Tab Transformer - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
5 changes: 2 additions & 3 deletions tab_transformer_pytorch/ft_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,10 @@ def forward(self, x, return_attn = False):
x = self.expand_streams(x)

for attn, ff in self.layers:
attn_out, post_softmax_attn = attn(x)
x, post_softmax_attn = attn(x)
post_softmax_attns.append(post_softmax_attn)

x = attn_out + x
x = ff(x) + x
x = ff(x)

x = self.reduce_streams(x)

Expand Down
5 changes: 2 additions & 3 deletions tab_transformer_pytorch/tab_transformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,10 @@ def forward(self, x, return_attn = False):
x = self.expand_streams(x)

for attn, ff in self.layers:
attn_out, post_softmax_attn = attn(x)
x, post_softmax_attn = attn(x)
post_softmax_attns.append(post_softmax_attn)

x = x + attn_out
x = ff(x) + x
x = ff(x)

x = self.reduce_streams(x)

Expand Down

0 comments on commit 23a2959

Please sign in to comment.