
update to add setting for fixing a seed before execution of f and g f… #4

Merged · 2 commits · Jan 23, 2020

Changes from 1 commit
add fix_random_seed keyword attribute when instantiating ReversibleBlock
lucidrains committed Jan 23, 2020
commit 89cb849533c6e208dd4b5ccede63bc944ee5f05c
revtorch/revtorch.py: 18 changes (9 additions & 9 deletions)
@@ -4,8 +4,6 @@
 import sys
 import random
 
-MAXINT = 2147483647
-
 class ReversibleBlock(nn.Module):
     '''
     Elementary building block for building (partially) reversible architectures
@@ -19,17 +17,19 @@ class ReversibleBlock(nn.Module):
         g_block (nn.Module): arbitrary subnetwork whos output shape is equal to its input shape
     '''
 
-    def __init__(self, f_block, g_block, split_along_dim=1):
+    def __init__(self, f_block, g_block, split_along_dim=1, fix_random_seed = False):
         super(ReversibleBlock, self).__init__()
         self.f_block = f_block
         self.g_block = g_block
         self.split_along_dim = split_along_dim
+        self.fix_random_seed = fix_random_seed
         self.random_seeds = {}
 
-    def set_random_seed(self, namespace, new = False):
+    def set_seed(self, namespace, new = False):
+        if not self.fix_random_seed:
+            return
         if new:
-            self.random_seeds[namespace] = random.randint(0, MAXINT)
+            self.random_seeds[namespace] = random.randint(0, sys.maxsize)
         torch.manual_seed(self.random_seeds[namespace])
 
     def forward(self, x):
@@ -41,9 +41,9 @@ def forward(self, x):
         x1, x2 = torch.chunk(x, 2, dim=self.split_along_dim)
         y1, y2 = None, None
         with torch.no_grad():
-            self.set_random_seed('f', new=True)
+            self.set_seed('f', new=True)
             y1 = x1 + self.f_block(x2)
-            self.set_random_seed('g', new=True)
+            self.set_seed('g', new=True)
             y2 = x2 + self.g_block(y1)
 
         return torch.cat([y1, y2], dim=self.split_along_dim)
@@ -77,7 +77,7 @@ def backward_pass(self, y, dy):
 
         # Ensures that PyTorch tracks the operations in a DAG
         with torch.enable_grad():
-            self.set_random_seed('g')
+            self.set_seed('g')
             gy1 = self.g_block(y1)
 
             # Use autograd framework to differentiate the calculation. The
@@ -97,7 +97,7 @@ def backward_pass(self, y, dy):
 
         with torch.enable_grad():
             x2.requires_grad = True
-            self.set_random_seed('f')
+            self.set_seed('f')
             fx2 = self.f_block(x2)
 
             # Use autograd framework to differentiate the calculation. The
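For context, a minimal sketch of how the new flag might be used. The toy subnetwork, the channel sizes, and the ReversibleSequence wrapper are assumptions for illustration, not part of this diff. With fix_random_seed=True, each block draws a fresh seed per namespace ('f' and 'g') before running f and g in the forward pass, and restores that seed before re-running them in backward_pass, so stochastic layers such as dropout produce the same mask both times.

```python
import torch
import torch.nn as nn
import revtorch as rv  # module patched by this PR

# Toy subnetwork (hypothetical); its output shape must equal its input shape.
# Dropout is the kind of stochastic layer that makes seed fixing matter: the
# same mask has to be re-drawn when activations are recomputed in backward_pass.
def subnet(channels=64):
    return nn.Sequential(nn.Linear(channels, channels), nn.ReLU(), nn.Dropout(p=0.1))

blocks = nn.ModuleList([
    rv.ReversibleBlock(subnet(), subnet(), split_along_dim=1, fix_random_seed=True)
    for _ in range(4)
])
sequence = rv.ReversibleSequence(blocks)  # wrapper assumed from the same library

x = torch.randn(8, 128, requires_grad=True)  # 2 * channels along split_along_dim
y = sequence(x)
y.sum().backward()
```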