forked from bayesiains/nflows
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from mackelab/tensor-construction-warnings
Tensor construction warnings
- Loading branch information
Showing
5 changed files
with
60 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,5 +9,4 @@ | |
from nflows.distributions.uniform import ( | ||
LotkaVolterraOscillating, | ||
MG1Uniform, | ||
TweakedUniform, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,82 +1,89 @@ | ||
"""Implementations of some standard transforms.""" | ||
|
||
from typing import Optional, Union, Tuple, Iterable | ||
import warnings | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from nflows.transforms.base import Transform | ||
|
||
|
||
class IdentityTransform(Transform): | ||
"""Transform that leaves input unchanged.""" | ||
|
||
def forward(self, inputs, context=None): | ||
batch_size = inputs.shape[0] | ||
def forward(self, inputs: Tensor, context=Optional[Tensor]): | ||
batch_size = inputs.size(0) | ||
logabsdet = torch.zeros(batch_size) | ||
return inputs, logabsdet | ||
|
||
def inverse(self, inputs, context=None): | ||
def inverse(self, inputs: Tensor, context=Optional[Tensor]): | ||
return self(inputs, context) | ||
|
||
|
||
class AffineScalarTransform(Transform): | ||
"""Computes X = X * scale + shift, where scale and shift are scalars, and scale is non-zero.""" | ||
class PointwiseAffineTransform(Transform): | ||
"""Forward transform X = X * scale + shift.""" | ||
|
||
def __init__(self, shift=None, scale=None): | ||
def __init__( | ||
self, shift: Union[Tensor, float] = 0.0, scale: Union[Tensor, float] = 1.0, | ||
): | ||
super().__init__() | ||
shift, scale = map(torch.as_tensor, (shift, scale)) | ||
|
||
if shift is None and scale is None: | ||
raise ValueError("At least one of scale and shift must be provided.") | ||
if scale == 0.0: | ||
raise ValueError("Scale cannot be zero.") | ||
if not (scale > 0.0).all(): | ||
raise ValueError("Scale must be strictly positive.") | ||
|
||
self.register_buffer( | ||
"_shift", torch.tensor(shift if (shift is not None) else 0.0) | ||
) | ||
self.register_buffer( | ||
"_scale", torch.tensor(scale if (scale is not None) else 1.0) | ||
) | ||
self.register_buffer("_shift", shift) | ||
self.register_buffer("_scale", scale) | ||
|
||
@property | ||
def _log_scale(self): | ||
return torch.log(torch.abs(self._scale)) | ||
def _log_scale(self) -> Tensor: | ||
return torch.log(self._scale) | ||
|
||
# XXX memoize? | ||
def _batch_logabsdet(self, batch_shape: Iterable[int]) -> Tensor: | ||
"""Return log abs det with input batch shape.""" | ||
|
||
if self._log_scale.numel() > 1: | ||
return self._log_scale.expand(batch_shape).sum() | ||
else: | ||
# when log_scale is a scalar, we use n*log_scale, which is more | ||
# numerically accurate than \sum_1^n log_scale. | ||
return self._log_scale * torch.Size(batch_shape).numel() | ||
|
||
def forward(self, inputs: Tensor, context=Optional[Tensor]) -> Tuple[Tensor]: | ||
batch_size, *batch_shape = inputs.size() | ||
|
||
def forward(self, inputs, context=None): | ||
batch_size = inputs.shape[0] | ||
num_dims = torch.prod(torch.tensor(inputs.shape[1:]), dtype=torch.float) | ||
# RuntimeError here => shift/scale not broadcastable to input | ||
outputs = inputs * self._scale + self._shift | ||
logabsdet = torch.full([batch_size], self._log_scale * num_dims) | ||
logabsdet = self._batch_logabsdet(batch_shape).expand(batch_size) | ||
|
||
return outputs, logabsdet | ||
|
||
def inverse(self, inputs, context=None): | ||
batch_size = inputs.shape[0] | ||
num_dims = torch.prod(torch.tensor(inputs.shape[1:]), dtype=torch.float) | ||
def inverse(self, inputs: Tensor, context=Optional[Tensor]) -> Tuple[Tensor]: | ||
batch_size, *batch_shape = inputs.size() | ||
outputs = (inputs - self._shift) / self._scale | ||
logabsdet = torch.full([batch_size], -self._log_scale * num_dims) | ||
logabsdet = -self._batch_logabsdet(batch_shape).expand(batch_size) | ||
|
||
return outputs, logabsdet | ||
|
||
|
||
class AffineTransform(Transform): | ||
def __init__(self, shift=None, scale=None): | ||
super().__init__() | ||
class AffineTransform(PointwiseAffineTransform): | ||
def __init__( | ||
self, shift: Union[Tensor, float] = 0.0, scale: Union[Tensor, float] = 1.0, | ||
): | ||
|
||
self.register_buffer( | ||
"_shift", torch.tensor(shift if (shift is not None) else 0.0) | ||
) | ||
self.register_buffer( | ||
"_scale", torch.tensor(scale if (scale is not None) else 1.0) | ||
) | ||
warnings.warn("Use PointwiseAffineTransform", DeprecationWarning) | ||
|
||
@property | ||
def _log_scale(self): | ||
return torch.log(torch.abs(self._scale)) | ||
if shift is None: | ||
shift = 0.0 | ||
warnings.warn(f"`shift=None` deprecated; default is {shift}") | ||
|
||
def forward(self, inputs, context=None): | ||
batch_size = inputs.shape[0] | ||
outputs = inputs * self._scale + self._shift | ||
logabsdet = self._log_scale.reshape(1, -1).repeat(batch_size, 1).sum(dim=-1) | ||
return outputs, logabsdet | ||
if scale is None: | ||
scale = 1.0 | ||
warnings.warn(f"`scale=None` deprecated; default is {scale}.") | ||
|
||
def inverse(self, inputs, context=None): | ||
batch_size = inputs.shape[0] | ||
outputs = (inputs - self._shift) / self._scale | ||
logabsdet = -self._log_scale.reshape(1, -1).repeat(batch_size, 1).sum(dim=-1) | ||
return outputs, logabsdet | ||
super().__init__(shift, scale) | ||
|
||
|
||
AffineScalarTransform = AffineTransform |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters