Rather use torch builtins numel and as_tensor.
Alvaro Tejero-Cantero committed Mar 30, 2020
1 parent a1b26d3 commit ffdf570
Showing 3 changed files with 8 additions and 40 deletions.
17 changes: 8 additions & 9 deletions nflows/transforms/standard.py
@@ -1,13 +1,12 @@
 """Implementations of some standard transforms."""
 
-from typing import Optional, Union
+from typing import Optional, Union, Tuple, Iterable
 import warnings
 
 import torch
 from torch import Tensor
 
 from nflows.transforms.base import Transform
-from nflows.utils.torchutils import ensure_tensor, numel
 
 
 class IdentityTransform(Transform):
@@ -29,7 +28,7 @@ def __init__(
         self, shift: Union[Tensor, float] = 0.0, scale: Union[Tensor, float] = 1.0,
     ):
         super().__init__()
-        shift, scale = map(ensure_tensor, (shift, scale))
+        shift, scale = map(torch.as_tensor, (shift, scale))
 
         if not (scale > 0.0).all():
             raise ValueError("Scale must be strictly positive.")
@@ -38,21 +37,21 @@ def __init__(
         self.register_buffer("_scale", scale)
 
     @property
-    def _log_scale(self):
+    def _log_scale(self) -> Tensor:
         return torch.log(self._scale)
 
     # XXX memoize?
-    def _batch_logabsdet(self, batch_shape: torch.Size):
+    def _batch_logabsdet(self, batch_shape: Iterable[int]) -> Tensor:
         """Return log abs det with input batch shape."""
 
-        if numel(self._log_scale) > 1:
+        if self._log_scale.numel() > 1:
             return self._log_scale.expand(batch_shape).sum()
         else:
             # when log_scale is a scalar, we use n*log_scale, which is more
             # numerically accurate than \sum_1^n log_scale.
-            return self._log_scale * numel(batch_shape)
+            return self._log_scale * torch.Size(batch_shape).numel()
 
-    def forward(self, inputs: Tensor, context=Optional[Tensor]):
+    def forward(self, inputs: Tensor, context=Optional[Tensor]) -> Tuple[Tensor]:
         batch_size, *batch_shape = inputs.size()
 
         # RuntimeError here => shift/scale not broadcastable to input
@@ -61,7 +60,7 @@ def forward(self, inputs: Tensor, context=Optional[Tensor]):
 
         return outputs, logabsdet
 
-    def inverse(self, inputs: Tensor, context=Optional[Tensor]):
+    def inverse(self, inputs: Tensor, context=Optional[Tensor]) -> Tuple[Tensor]:
         batch_size, *batch_shape = inputs.size()
         outputs = (inputs - self._shift) / self._scale
         logabsdet = -self._batch_logabsdet(batch_shape).expand(batch_size)
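Illustrative sketch (not part of the commit) of the two builtin behaviors this file now leans on: torch.as_tensor passes existing tensors through unchanged while converting Python scalars and numpy arrays, and the scalar branch of _batch_logabsdet computes n * log_scale in a single multiplication instead of summing n copies:

    import torch

    # torch.as_tensor stands in for the removed ensure_tensor helper:
    t = torch.randn(2, 3)
    assert torch.as_tensor(t) is t          # tensors pass through without a copy
    assert torch.as_tensor(2.0).ndim == 0   # Python floats become 0-dim tensors

    # Scalar fast path: for a 0-dim log-scale s and a batch shape with n
    # elements, s * n matches expanding s over the shape and summing, while
    # accumulating one rounding error instead of n.
    s = torch.log(torch.tensor(1.5))
    batch_shape = torch.Size([4, 5])
    assert torch.isclose(s * batch_shape.numel(), s.expand(batch_shape).sum())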
22 changes: 0 additions & 22 deletions nflows/utils/torchutils.py
@@ -175,25 +175,3 @@ def gaussian_kde_log_eval(samples, query):
     d = -np.log(N) - (D / 2) * np.log(2 * np.pi) - D * np.log(std)
     c += d
     return torch.logsumexp(c, dim=-1)
-
-
-def ensure_tensor(arg):
-    """Return argument cast into a tensor if it's not one already."""
-
-    if not isinstance(arg, torch.Tensor):
-        return torch.tensor(arg)
-
-    return arg
-
-
-def numel(t: Union[torch.Tensor, torch.Size]) -> int:
-    """Return number of elements given a tensor or its size.
-
-    Args:
-        t: a Tensor or tensor's Size.
-    """
-
-    if isinstance(t, torch.Tensor):
-        t = t.size()
-
-    return int(np.prod(t))
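The deleted helper's two cases map directly onto builtins, as a quick sketch (not part of the diff) shows: Tensor.numel() handles tensors, and torch.Size.numel() handles shapes, including the plain batch_shape list unpacked in standard.py:

    import numpy as np
    import torch

    # Tensor case of the deleted numel helper:
    x = torch.zeros(3, 4, 5)
    assert x.numel() == 60

    # Size/shape case; torch.Size accepts any iterable of ints:
    batch_shape = [4, 5]
    assert torch.Size(batch_shape).numel() == int(np.prod(batch_shape)) == 20

    # An empty shape has one element, matching np.prod(()) == 1.0:
    assert torch.Size([]).numel() == 1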
9 changes: 0 additions & 9 deletions tests/utils/torchutils_test.py
@@ -97,15 +97,6 @@ def test_searchsorted_arbitrary_shape(self):
         idx = torchutils.searchsorted(bin_locations, inputs)
         self.assertEqual(idx.shape, inputs.shape)
 
-    def test_ensure_tensor(self):
-        a_tensor = torch.randn([1, 2])
-        an_array = np.array([1, 2, 3])
-        a_scalar = 2.0
-
-        assert isinstance(torchutils.ensure_tensor(a_tensor), torch.Tensor)
-        assert isinstance(torchutils.ensure_tensor(an_array), torch.Tensor)
-        assert isinstance(torchutils.ensure_tensor(a_scalar), torch.Tensor)
-
 
 if __name__ == "__main__":
     unittest.main()
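The commit drops this test without a replacement; a hypothetical port that keeps the same coverage against the builtin (an assumption, not part of the diff) could read:

    import unittest

    import numpy as np
    import torch


    class AsTensorTest(unittest.TestCase):
        # Hypothetical port of the deleted test_ensure_tensor, exercising
        # torch.as_tensor on the same three input kinds.
        def test_as_tensor(self):
            a_tensor = torch.randn([1, 2])
            an_array = np.array([1, 2, 3])
            a_scalar = 2.0

            assert isinstance(torch.as_tensor(a_tensor), torch.Tensor)
            assert isinstance(torch.as_tensor(an_array), torch.Tensor)
            assert isinstance(torch.as_tensor(a_scalar), torch.Tensor)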
