From 89cb849533c6e208dd4b5ccede63bc944ee5f05c Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 23 Jan 2020 11:03:36 -0800
Subject: [PATCH] add fix_random_seed keyword attribute when instantiating
 ReversibleBlock

---
 revtorch/revtorch.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/revtorch/revtorch.py b/revtorch/revtorch.py
index 82f20cf..3a5d24c 100644
--- a/revtorch/revtorch.py
+++ b/revtorch/revtorch.py
@@ -4,8 +4,6 @@
 import sys
 import random
 
-MAXINT = 2147483647
-
 class ReversibleBlock(nn.Module):
     '''
     Elementary building block for building (partially) reversible architectures
@@ -19,17 +17,19 @@ class ReversibleBlock(nn.Module):
         g_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
     '''
 
-    def __init__(self, f_block, g_block, split_along_dim=1):
+    def __init__(self, f_block, g_block, split_along_dim=1, fix_random_seed=False):
         super(ReversibleBlock, self).__init__()
         self.f_block = f_block
         self.g_block = g_block
         self.split_along_dim = split_along_dim
+        self.fix_random_seed = fix_random_seed
+        self.random_seeds = {}
 
-    def set_random_seed(self, namespace, new = False):
+    def set_seed(self, namespace, new=False):
         if not self.fix_random_seed: return
         if new:
-            self.random_seeds[namespace] = random.randint(0, MAXINT)
+            self.random_seeds[namespace] = random.randint(0, sys.maxsize)
         torch.manual_seed(self.random_seeds[namespace])
 
     def forward(self, x):
@@ -41,9 +41,9 @@ def forward(self, x):
         x1, x2 = torch.chunk(x, 2, dim=self.split_along_dim)
         y1, y2 = None, None
         with torch.no_grad():
-            self.set_random_seed('f', new=True)
+            self.set_seed('f', new=True)
             y1 = x1 + self.f_block(x2)
-            self.set_random_seed('g', new=True)
+            self.set_seed('g', new=True)
             y2 = x2 + self.g_block(y1)
 
         return torch.cat([y1, y2], dim=self.split_along_dim)
@@ -77,7 +77,7 @@ def backward_pass(self, y, dy):
 
         # Ensures that PyTorch tracks the operations in a DAG
         with torch.enable_grad():
-            self.set_random_seed('g')
+            self.set_seed('g')
             gy1 = self.g_block(y1)
 
             # Use autograd framework to differentiate the calculation. The
@@ -97,7 +97,7 @@ def backward_pass(self, y, dy):
 
         with torch.enable_grad():
             x2.requires_grad = True
-            self.set_random_seed('f')
+            self.set_seed('f')
             fx2 = self.f_block(x2)
 
             # Use autograd framework to differentiate the calculation. The
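
A note on intent, with a minimal usage sketch (not part of the patch): when
f_block or g_block contain stochastic layers such as nn.Dropout, the
recomputation in backward_pass would otherwise sample a different mask than the
forward pass did and yield wrong gradients. With fix_random_seed=True the block
draws a fresh seed per namespace ('f', 'g') on the way forward and replays it
via torch.manual_seed before each recomputation. The sketch below assumes
revtorch's ReversibleSequence API; the helper name make_block and the tensor
sizes are illustrative only.

    import torch
    import torch.nn as nn
    from revtorch import ReversibleBlock, ReversibleSequence

    def make_block(channels=64):  # hypothetical helper, for illustration only
        f = nn.Sequential(nn.LayerNorm(channels), nn.Linear(channels, channels), nn.Dropout(0.1))
        g = nn.Sequential(nn.LayerNorm(channels), nn.Linear(channels, channels), nn.Dropout(0.1))
        # fix_random_seed=True: record a seed per namespace in forward(),
        # replay it in backward_pass() so the dropout masks match exactly
        return ReversibleBlock(f, g, split_along_dim=-1, fix_random_seed=True)

    seq = ReversibleSequence(nn.ModuleList([make_block() for _ in range(2)]))
    x = torch.randn(8, 16, 128, requires_grad=True)  # last dim is split into 2 * 64
    seq(x).sum().backward()  # recomputed activations match, so gradients are exact

Since torch.manual_seed reseeds the global RNG, replaying a seed also advances
global random state. sys.maxsize (2**63 - 1 on 64-bit builds) stays within the
range manual_seed accepts, which is why it can replace the hard-coded 32-bit
MAXINT constant.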