Commit 89cb849

add fix_random_seed keyword attribute when instantiating ReversibleBlock
lucidrains committed Jan 23, 2020
1 parent 0cb733b commit 89cb849
Showing 1 changed file with 9 additions and 9 deletions.
revtorch/revtorch.py: 18 changes (9 additions & 9 deletions)
@@ -4,8 +4,6 @@
 import sys
 import random
 
-MAXINT = 2147483647
-
 class ReversibleBlock(nn.Module):
     '''
     Elementary building block for building (partially) reversible architectures
@@ -19,17 +17,19 @@ class ReversibleBlock(nn.Module):
         g_block (nn.Module): arbitrary subnetwork whos output shape is equal to its input shape
     '''
 
-    def __init__(self, f_block, g_block, split_along_dim=1):
+    def __init__(self, f_block, g_block, split_along_dim=1, fix_random_seed = False):
         super(ReversibleBlock, self).__init__()
         self.f_block = f_block
         self.g_block = g_block
         self.split_along_dim = split_along_dim
+        self.fix_random_seed = fix_random_seed
         self.random_seeds = {}
 
-    def set_random_seed(self, namespace, new = False):
+    def set_seed(self, namespace, new = False):
+        if not self.fix_random_seed:
+            return
         if new:
-            self.random_seeds[namespace] = random.randint(0, MAXINT)
+            self.random_seeds[namespace] = random.randint(0, sys.maxsize)
         torch.manual_seed(self.random_seeds[namespace])
 
     def forward(self, x):
@@ -41,9 +41,9 @@ def forward(self, x):
         x1, x2 = torch.chunk(x, 2, dim=self.split_along_dim)
         y1, y2 = None, None
         with torch.no_grad():
-            self.set_random_seed('f', new=True)
+            self.set_seed('f', new=True)
             y1 = x1 + self.f_block(x2)
-            self.set_random_seed('g', new=True)
+            self.set_seed('g', new=True)
             y2 = x2 + self.g_block(y1)
 
         return torch.cat([y1, y2], dim=self.split_along_dim)
@@ -77,7 +77,7 @@ def backward_pass(self, y, dy):
 
         # Ensures that PyTorch tracks the operations in a DAG
         with torch.enable_grad():
-            self.set_random_seed('g')
+            self.set_seed('g')
             gy1 = self.g_block(y1)
 
         # Use autograd framework to differentiate the calculation. The
@@ -97,7 +97,7 @@ def backward_pass(self, y, dy):
 
         with torch.enable_grad():
             x2.requires_grad = True
-            self.set_random_seed('f')
+            self.set_seed('f')
             fx2 = self.f_block(x2)
 
         # Use autograd framework to differentiate the calculation. The
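The effect of the change: when fix_random_seed is enabled, each block draws and stores a fresh seed per namespace ('f', 'g') during the forward pass and restores it before the corresponding sub-network is recomputed in backward_pass, so stochastic layers such as dropout produce the same mask both times. Replacing the hard-coded MAXINT constant with sys.maxsize also widens the seed range on 64-bit builds. Below is a minimal, self-contained sketch of that seed-replay idea; it is illustrative only and not code from this repository (the nn.Sequential stands in for an arbitrary f_block):

import sys
import random

import torch
import torch.nn as nn

f_block = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))  # stand-in for a stochastic sub-network
x2 = torch.randn(4, 8)

seed = random.randint(0, sys.maxsize)  # what set_seed('f', new=True) stores during the forward pass
torch.manual_seed(seed)
y_forward = f_block(x2)                # the activation is discarded after the forward pass

torch.manual_seed(seed)                # what set_seed('f') replays inside backward_pass
y_recomputed = f_block(x2)

assert torch.equal(y_forward, y_recomputed)  # identical dropout mask, so the recomputation is exact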

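Usage-wise the new keyword is opt-in (it defaults to False, so existing instantiations are unchanged). A hedged usage sketch follows; it assumes the fork keeps upstream RevTorch's ReversibleSequence next to ReversibleBlock in revtorch/revtorch.py, and the sub_block helper, layer sizes and block count are illustrative only:

import torch
import torch.nn as nn

from revtorch.revtorch import ReversibleBlock, ReversibleSequence  # assumed import path

def sub_block(dim=64):
    # a stochastic sub-network; without a fixed seed its dropout mask would differ
    # between the forward pass and the recomputation done in backward_pass
    return nn.Sequential(nn.Linear(dim, dim), nn.Dropout(p=0.1), nn.ReLU())

blocks = nn.ModuleList([
    ReversibleBlock(sub_block(), sub_block(), split_along_dim=1, fix_random_seed=True)
    for _ in range(4)
])
net = ReversibleSequence(blocks)

x = torch.randn(8, 128, requires_grad=True)  # 128 features, chunked into two halves of 64 along dim 1
y = net(x)                                   # activations are not stored; they are recomputed during backward
y.sum().backward()
print(blocks[0].f_block[0].weight.grad.shape)  # gradients reach the sub-block parameters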