Skip to content

Commit

Permalink
Merge pull request #1 from mackelab/tensor-construction-warnings
Browse files Browse the repository at this point in the history
Tensor construction warnings
  • Loading branch information
arturbekasov authored Apr 1, 2020
2 parents fb2cf10 + ffdf570 commit 6d4f5dd
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 52 deletions.
1 change: 0 additions & 1 deletion nflows/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,4 @@
from nflows.distributions.uniform import (
LotkaVolterraOscillating,
MG1Uniform,
TweakedUniform,
)
103 changes: 55 additions & 48 deletions nflows/transforms/standard.py
Original file line number Diff line number Diff line change
@@ -1,82 +1,89 @@
"""Implementations of some standard transforms."""

from typing import Optional, Union, Tuple, Iterable
import warnings

import torch
from torch import Tensor

from nflows.transforms.base import Transform


class IdentityTransform(Transform):
    """Transform that leaves input unchanged.

    Both directions return the inputs untouched together with a zero
    log-abs-det Jacobian of shape ``(batch_size,)``.
    """

    def forward(
        self, inputs: Tensor, context: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        # FIX: `context=Optional[Tensor]` used the typing construct as the
        # *default value*; the intent is an annotation with default None.
        batch_size = inputs.size(0)
        # Jacobian of the identity map is the identity => log|det J| == 0.
        # NOTE(review): zeros are created on the default device/dtype, not
        # on inputs.device — confirm callers never run this on GPU inputs.
        logabsdet = torch.zeros(batch_size)
        return inputs, logabsdet

    def inverse(
        self, inputs: Tensor, context: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        # The identity transform is its own inverse; delegate to forward.
        return self(inputs, context)


class PointwiseAffineTransform(Transform):
    """Forward transform X = X * scale + shift.

    `shift` and `scale` may be Python floats or tensors broadcastable to
    the input; `scale` must be strictly positive everywhere so that
    log|det J| is well-defined without taking absolute values.
    """

    def __init__(
        self,
        shift: Union[Tensor, float] = 0.0,
        scale: Union[Tensor, float] = 1.0,
    ):
        super().__init__()
        # as_tensor avoids the torch.tensor() copy/construction warnings
        # this commit exists to silence.
        shift, scale = map(torch.as_tensor, (shift, scale))

        if not (scale > 0.0).all():
            raise ValueError("Scale must be strictly positive.")

        # Buffers (not parameters): move with .to()/.cuda() and are saved
        # in state_dict, but are not trained.
        self.register_buffer("_shift", shift)
        self.register_buffer("_scale", scale)

    @property
    def _log_scale(self) -> Tensor:
        # scale was validated strictly positive, so no abs() is needed.
        return torch.log(self._scale)

    # XXX memoize?
    def _batch_logabsdet(self, batch_shape: Iterable[int]) -> Tensor:
        """Return log abs det of the Jacobian for one sample of `batch_shape`."""

        if self._log_scale.numel() > 1:
            # Element-wise scale: broadcast to the sample shape and sum the
            # per-element log-scales.
            return self._log_scale.expand(batch_shape).sum()
        else:
            # When log_scale is a scalar, n * log_scale is more numerically
            # accurate than summing log_scale n times.
            return self._log_scale * torch.Size(batch_shape).numel()

    def forward(
        self, inputs: Tensor, context: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        # FIX: annotation was `context=Optional[Tensor]` (typing object as
        # default) and return type `Tuple[Tensor]` (this returns a 2-tuple).
        batch_size, *batch_shape = inputs.size()

        # RuntimeError here => shift/scale not broadcastable to input.
        outputs = inputs * self._scale + self._shift
        logabsdet = self._batch_logabsdet(batch_shape).expand(batch_size)

        return outputs, logabsdet

    def inverse(
        self, inputs: Tensor, context: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_size, *batch_shape = inputs.size()
        outputs = (inputs - self._shift) / self._scale
        # Inverse Jacobian log-det is the negation of the forward one.
        logabsdet = -self._batch_logabsdet(batch_shape).expand(batch_size)

        return outputs, logabsdet


class AffineTransform(PointwiseAffineTransform):
    """Deprecated alias of PointwiseAffineTransform.

    Kept for backward compatibility: accepts the legacy ``shift=None`` /
    ``scale=None`` arguments and maps them to the new defaults.
    """

    def __init__(
        self,
        shift: Union[Tensor, float, None] = 0.0,
        scale: Union[Tensor, float, None] = 1.0,
    ):

        warnings.warn("Use PointwiseAffineTransform", DeprecationWarning)

        # FIX: the two None-deprecation warnings below previously omitted the
        # DeprecationWarning category (defaulting to UserWarning), inconsistent
        # with the class-level deprecation above. Messages unchanged.
        if shift is None:
            shift = 0.0
            warnings.warn(f"`shift=None` deprecated; default is {shift}",
                          DeprecationWarning)

        if scale is None:
            scale = 1.0
            warnings.warn(f"`scale=None` deprecated; default is {scale}.",
                          DeprecationWarning)

        super().__init__(shift, scale)


# Legacy name, pre-dating the rename to PointwiseAffineTransform.
AffineScalarTransform = AffineTransform
5 changes: 3 additions & 2 deletions nflows/utils/torchutils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Various PyTorch utility functions."""

from typing import Union
import numpy as np
import torch

Expand Down Expand Up @@ -160,8 +161,8 @@ def get_temperature(max_value, bound=1 - 1e-3):
return temperature


def notinfnotnan(x: torch.Tensor) -> torch.Tensor:
    """Return a 0-dim bool tensor: True iff `x` has no NaN and no +/-inf.

    Uses a single `isfinite` pass instead of chaining `isnan`/`isinf`
    with Python `and`, which relies on tensor truthiness and returns the
    second operand rather than a combined result.
    """
    return torch.isfinite(x).all()


def gaussian_kde_log_eval(samples, query):
Expand Down
2 changes: 1 addition & 1 deletion tests/transforms/coupling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def create_net(in_channels, out_channels):
batch_size = 10


class AffineTransformTest(TransformTest):
class AffineCouplingTransformTest(TransformTest):
shapes = [[20], [2, 4, 4]]

def test_forward(self):
Expand Down
1 change: 1 addition & 0 deletions tests/utils/torchutils_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for the PyTorch utility functions."""

import unittest
import numpy as np

import torch
import torchtestcase
Expand Down

0 comments on commit 6d4f5dd

Please sign in to comment.