diff --git a/pytorch3d/transforms/__init__.py b/pytorch3d/transforms/__init__.py
index efa0d6310..81da49571 100644
--- a/pytorch3d/transforms/__init__.py
+++ b/pytorch3d/transforms/__init__.py
@@ -20,6 +20,7 @@
     rotation_6d_to_matrix,
     standardize_quaternion,
 )
+from .se3 import se3_exp_map, se3_log_map
 from .so3 import (
     so3_exponential_map,
     so3_exp_map,
diff --git a/pytorch3d/transforms/se3.py b/pytorch3d/transforms/se3.py
new file mode 100644
index 000000000..e14c720ad
--- /dev/null
+++ b/pytorch3d/transforms/se3.py
@@ -0,0 +1,213 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+import torch
+
+from .so3 import hat, _so3_exp_map, so3_log_map
+
+
+def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
+    """
+    Convert a batch of logarithmic representations of SE(3) matrices `log_transform`
+    to a batch of 4x4 SE(3) matrices using the exponential map.
+    See e.g. [1], Sec 9.4.2. for a more detailed description.
+
+    An SE(3) matrix has the following form:
+        ```
+        [ R 0 ]
+        [ T 1 ] ,
+        ```
+    where `R` is a 3x3 rotation matrix and `T` is a 3-D translation vector.
+    SE(3) matrices are commonly used to represent rigid motions or camera extrinsics.
+
+    In the SE(3) logarithmic representation SE(3) matrices are
+    represented as 6-dimensional vectors `[log_translation | log_rotation]`,
+    i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.
+
+    The conversion from the 6D representation to a 4x4 SE(3) matrix `transform`
+    is done as follows:
+        ```
+        transform = exp( [ hat(log_rotation) 0 ]
+                         [ log_translation  0 ] ) ,
+        ```
+    where `exp` is the matrix exponential and `hat` is the Hat operator [2].
+
+    Note that for any `log_transform` with `0 <= ||log_rotation|| < pi`
+    (i.e. the rotation angle is between 0 and pi), the following identity holds:
+    ```
+    se3_log_map(se3_exp_map(log_transform)) == log_transform
+    ```
+
+    The conversion has a singularity around `||log(transform)|| = 0`
+    which is handled by clamping controlled with the `eps` argument.
+
+    Args:
+        log_transform: Batch of vectors of shape `(minibatch, 6)`.
+        eps: A threshold for clipping the squared norm of the rotation logarithm
+            to avoid unstable gradients in the singular case.
+
+    Returns:
+        Batch of transformation matrices of shape `(minibatch, 4, 4)`.
+
+    Raises:
+        ValueError if `log_transform` is of incorrect shape.
+
+    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
+    [2] https://en.wikipedia.org/wiki/Hat_operator
+    """
+
+    if log_transform.ndim != 2 or log_transform.shape[1] != 6:
+        raise ValueError("Expected input to be of shape (N, 6).")
+
+    N, _ = log_transform.shape
+
+    log_translation = log_transform[..., :3]
+    log_rotation = log_transform[..., 3:]
+
+    # rotation is an exponential map of log_rotation
+    (
+        R,
+        rotation_angles,
+        log_rotation_hat,
+        log_rotation_hat_square,
+    ) = _so3_exp_map(log_rotation, eps=eps)
+
+    # translation is V @ log_translation
+    V = _se3_V_matrix(
+        log_rotation,
+        log_rotation_hat,
+        log_rotation_hat_square,
+        rotation_angles,
+        eps=eps,
+    )
+    T = torch.bmm(V, log_translation[:, :, None])[:, :, 0]
+
+    transform = torch.zeros(
+        N, 4, 4, dtype=log_transform.dtype, device=log_transform.device
+    )
+
+    transform[:, :3, :3] = R
+    transform[:, :3, 3] = T
+    transform[:, 3, 3] = 1.0
+
+    return transform.permute(0, 2, 1)
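Reviewer note (illustrative sketch, not part of the patch): with the row-major layout returned above, the output of `se3_exp_map` should agree with `torch.matrix_exp` applied to the transposed se(3) log-matrix. The assembly of `A` below is my reading of that convention, so treat it as an assumption rather than documented behavior.
```
import torch
from pytorch3d.transforms.se3 import se3_exp_map
from pytorch3d.transforms.so3 import hat

log_transform = torch.randn(4, 6)
log_translation, log_rotation = log_transform[:, :3], log_transform[:, 3:]

# assemble the se(3) logarithm in the row-major convention: [[hat(w)^T, 0], [v^T, 0]]
A = torch.zeros(4, 4, 4)
A[:, :3, :3] = hat(log_rotation).transpose(-1, -2)
A[:, 3, :3] = log_translation

# both paths should agree up to float32 round-off
print(torch.allclose(torch.matrix_exp(A), se3_exp_map(log_transform), atol=1e-4))
```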
+
+
+def se3_log_map(
+    transform: torch.Tensor, eps: float = 1e-4, cos_bound: float = 1e-4
+) -> torch.Tensor:
+    """
+    Convert a batch of 4x4 transformation matrices `transform`
+    to a batch of 6-dimensional SE(3) logarithms of the SE(3) matrices.
+    See e.g. [1], Sec 9.4.2. for a more detailed description.
+
+    An SE(3) matrix has the following form:
+        ```
+        [ R 0 ]
+        [ T 1 ] ,
+        ```
+    where `R` is an orthonormal 3x3 rotation matrix and `T` is a 3-D translation vector.
+    SE(3) matrices are commonly used to represent rigid motions or camera extrinsics.
+
+    In the SE(3) logarithmic representation SE(3) matrices are
+    represented as 6-dimensional vectors `[log_translation | log_rotation]`,
+    i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.
+
+    The conversion from the 4x4 SE(3) matrix `transform` to the
+    6D representation `log_transform = [log_translation | log_rotation]`
+    is done as follows:
+        ```
+        log_transform = log(transform)
+        log_translation = log_transform[3, :3]
+        log_rotation = inv_hat(log_transform[:3, :3])
+        ```
+    where `log` is the matrix logarithm
+    and `inv_hat` is the inverse of the Hat operator [2].
+
+    Note that for any valid 4x4 `transform` matrix, the following identity holds:
+    ```
+    se3_exp_map(se3_log_map(transform)) == transform
+    ```
+
+    The conversion has a singularity around `(transform=I)` which is handled
+    by clamping controlled with the `eps` and `cos_bound` arguments.
+
+    Args:
+        transform: batch of SE(3) matrices of shape `(minibatch, 4, 4)`.
+        eps: A threshold for clipping the squared norm of the rotation logarithm
+            to avoid division by zero in the singular case.
+        cos_bound: Clamps the cosine of the rotation angle to
+            [-1 + cos_bound, 1 - cos_bound] to avoid non-finite outputs.
+            The non-finite outputs can be caused by passing small rotation angles
+            to the `acos` function in `so3_rotation_angle` of `so3_log_map`.
+
+    Returns:
+        Batch of logarithms of input SE(3) matrices
+        of shape `(minibatch, 6)`.
+
+    Raises:
+        ValueError if `transform` is of incorrect shape.
+        ValueError if `R` has an unexpected trace.
+
+    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
+    [2] https://en.wikipedia.org/wiki/Hat_operator
+    """
+
+    if transform.ndim != 3:
+        raise ValueError("Input tensor shape has to be (N, 4, 4).")
+
+    N, dim1, dim2 = transform.shape
+    if dim1 != 4 or dim2 != 4:
+        raise ValueError("Input tensor shape has to be (N, 4, 4).")
+
+    if not torch.allclose(transform[:, :3, 3], torch.zeros_like(transform[:, :3, 3])):
+        raise ValueError("All elements of `transform[:, :3, 3]` should be 0.")
+
+    # log_rotation is just so3_log_map of the transposed upper-left 3x3 block
+    R = transform[:, :3, :3].permute(0, 2, 1)
+    log_rotation = so3_log_map(R, eps=eps, cos_bound=cos_bound)
+
+    # log_translation is V^-1 @ T
+    T = transform[:, 3, :3]
+    V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
+    log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]
+
+    return torch.cat((log_translation, log_rotation), dim=1)
+
+
+def _se3_V_matrix(
+    log_rotation: torch.Tensor,
+    log_rotation_hat: torch.Tensor,
+    log_rotation_hat_square: torch.Tensor,
+    rotation_angles: torch.Tensor,
+    eps: float = 1e-4,
+) -> torch.Tensor:
+    """
+    A helper function that computes the "V" matrix from [1], Sec 9.4.2.
+    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
+    """
+
+    V = (
+        torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
+        + log_rotation_hat
+        * ((1 - torch.cos(rotation_angles)) / (rotation_angles ** 2))[:, None, None]
+        + (
+            log_rotation_hat_square
+            * ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles ** 3))[
+                :, None, None
+            ]
+        )
+    )
+
+    return V
+
+
+def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
+    """
+    A helper function that computes the input variables to the `_se3_V_matrix`
+    function.
+    """
+    nrms = (log_rotation ** 2).sum(-1)
+    rotation_angles = torch.clamp(nrms, eps).sqrt()
+    log_rotation_hat = hat(log_rotation)
+    log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat)
+    return log_rotation, log_rotation_hat, log_rotation_hat_square, rotation_angles
diff --git a/pytorch3d/transforms/so3.py b/pytorch3d/transforms/so3.py
index 750a47b84..5de8ee8e8 100644
--- a/pytorch3d/transforms/so3.py
+++ b/pytorch3d/transforms/so3.py
@@ -14,6 +14,7 @@ def so3_relative_angle(
     R2: torch.Tensor,
     cos_angle: bool = False,
     cos_bound: float = 1e-4,
+    eps: float = 1e-4,
 ) -> torch.Tensor:
     """
     Calculates the relative angle (in radians) between pairs of
@@ -33,7 +34,8 @@ def so3_relative_angle(
             of the `acos` call. Note that the non-finite outputs/gradients
             are returned when the angle is requested (i.e. `cos_angle==False`)
             and the rotation angle is close to 0 or π.
-
+        eps: Tolerance for the valid trace check of the relative rotation matrix
+            in `so3_rotation_angle`.
     Returns:
         Corresponding rotation angles of shape `(minibatch,)`.
         If `cos_angle==True`, returns the cosine of the angles.
@@ -43,7 +45,7 @@ def so3_relative_angle(
         ValueError if `R1` or `R2` has an unexpected trace.
     """
     R12 = torch.bmm(R1, R2.permute(0, 2, 1))
-    return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound)
+    return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound, eps=eps)
 
 
 def so3_rotation_angle(
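Reviewer note (illustrative sketch, not part of the patch): minimal usage of the new API, relying on the row-vector convention documented in `se3_exp_map`, where points transform as `x @ R + T`.
```
import torch
from pytorch3d.transforms import se3_exp_map, se3_log_map

log_transform = torch.randn(8, 6) * 0.5   # [log_translation | log_rotation]
transform = se3_exp_map(log_transform)    # (8, 4, 4), last column is (0, 0, 0, 1)

# transform a batch of points with the row-vector convention
points = torch.randn(8, 100, 3)
points_t = points @ transform[:, :3, :3] + transform[:, None, 3, :3]

# round trip back to the 6D logarithm
log_transform_ = se3_log_map(transform)
print((log_transform - log_transform_).abs().max())  # small, up to float32 precision
```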
diff --git a/tests/bm_se3.py b/tests/bm_se3.py
new file mode 100644
index 000000000..d6e92e200
--- /dev/null
+++ b/tests/bm_se3.py
@@ -0,0 +1,19 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+from fvcore.common.benchmark import benchmark
+from test_se3 import TestSE3
+
+
+def bm_se3() -> None:
+    kwargs_list = [
+        {"batch_size": 1},
+        {"batch_size": 10},
+        {"batch_size": 100},
+        {"batch_size": 1000},
+    ]
+    benchmark(TestSE3.se3_expmap, "SE3_EXP", kwargs_list, warmup_iters=1)
+    benchmark(TestSE3.se3_logmap, "SE3_LOG", kwargs_list, warmup_iters=1)
+
+
+if __name__ == "__main__":
+    bm_se3()
diff --git a/tests/test_se3.py b/tests/test_se3.py
new file mode 100644
index 000000000..0338e9b81
--- /dev/null
+++ b/tests/test_se3.py
@@ -0,0 +1,324 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+
+import unittest
+
+import numpy as np
+import torch
+from common_testing import TestCaseMixin
+from pytorch3d.transforms.rotation_conversions import random_rotations
+from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map
+from pytorch3d.transforms.so3 import (
+    so3_exp_map,
+    so3_log_map,
+    so3_rotation_angle,
+)
+
+
+class TestSE3(TestCaseMixin, unittest.TestCase):
+    precomputed_log_transform = torch.tensor(
+        [
+            [0.1900, 2.1600, -0.1700, 0.8500, -1.9200, 0.6500],
+            [-0.6500, -0.8200, 0.5300, -1.2800, -1.6600, -0.3000],
+            [-0.0900, 0.2000, -1.1200, 1.8600, -0.7100, 0.6900],
+            [0.8000, -0.0300, 1.4900, -0.5200, -0.2500, 1.4700],
+            [-0.3300, -1.1600, 2.3600, -0.6900, 0.1800, -1.1800],
+            [-1.8000, -1.5800, 0.8400, 1.4200, 0.6500, 0.4300],
+            [-1.5900, 0.6200, 1.6900, -0.6600, 0.9400, 0.0800],
+            [0.0800, -0.1400, 0.3300, -0.5900, -1.0700, 0.1000],
+            [-0.3300, -0.5300, -0.8800, 0.3900, 0.1600, -0.2000],
+            [1.0100, -1.3500, -0.3500, -0.6400, 0.4500, -0.5400],
+        ],
+        dtype=torch.float32,
+    )
+
+    precomputed_transform = torch.tensor(
+        [
+            [
+                [-0.3496, -0.2966, 0.8887, 0.0000],
+                [-0.7755, 0.6239, -0.0968, 0.0000],
+                [-0.5258, -0.7230, -0.4481, 0.0000],
+                [-0.7392, 1.9119, 0.3122, 1.0000],
+            ],
+            [
+                [0.0354, 0.5992, 0.7998, 0.0000],
+                [0.8413, 0.4141, -0.3475, 0.0000],
+                [-0.5395, 0.6852, -0.4894, 0.0000],
+                [-0.9902, -0.4840, 0.1226, 1.0000],
+            ],
+            [
+                [0.6664, -0.1679, 0.7264, 0.0000],
+                [-0.7309, -0.3394, 0.5921, 0.0000],
+                [0.1471, -0.9255, -0.3489, 0.0000],
+                [-0.0815, 0.8719, -0.4516, 1.0000],
+            ],
+            [
+                [0.1010, 0.9834, -0.1508, 0.0000],
+                [-0.8783, 0.0169, -0.4779, 0.0000],
+                [-0.4674, 0.1807, 0.8654, 0.0000],
+                [0.2375, 0.7043, 1.4159, 1.0000],
+            ],
+            [
+                [0.3935, -0.8930, 0.2184, 0.0000],
+                [0.7873, 0.2047, -0.5817, 0.0000],
+                [0.4747, 0.4009, 0.7836, 0.0000],
+                [-0.3476, -0.0424, 2.5408, 1.0000],
+            ],
+            [
+                [0.7572, 0.6342, -0.1567, 0.0000],
+                [0.1039, 0.1199, 0.9873, 0.0000],
+                [0.6449, -0.7638, 0.0249, 0.0000],
+                [-1.2885, -2.0666, -0.1137, 1.0000],
+            ],
+            [
+                [0.6020, -0.2140, -0.7693, 0.0000],
+                [-0.3409, 0.8024, -0.4899, 0.0000],
+                [0.7221, 0.5572, 0.4101, 0.0000],
+                [-0.7550, 1.1928, 1.8480, 1.0000],
+            ],
+            [
+                [0.4913, 0.3548, 0.7954, 0.0000],
+                [0.2013, 0.8423, -0.5000, 0.0000],
+                [-0.8474, 0.4058, 0.3424, 0.0000],
+                [-0.1003, -0.0406, 0.3295, 1.0000],
+            ],
+            [
+                [0.9678, -0.1622, -0.1926, 0.0000],
+                [0.2235, 0.9057, 0.3603, 0.0000],
+                [0.1160, -0.3917, 0.9128, 0.0000],
+                [-0.4417, -0.3111, -0.9227, 1.0000],
+            ],
+            [
+                [0.7710, -0.5957, -0.2250, 0.0000],
+                [0.3288, 0.6750, -0.6605, 0.0000],
+                [0.5454, 0.4352, 0.7163, 0.0000],
+                [0.5623, -1.5886, -0.0182, 1.0000],
+            ],
+        ],
+        dtype=torch.float32,
+    )
+
+    def setUp(self) -> None:
+        super().setUp()
+        torch.manual_seed(42)
+        np.random.seed(42)
+
+    @staticmethod
+    def init_log_transform(batch_size: int = 10):
+        """
+        Initialize a list of `batch_size` 6-dimensional vectors representing
+        randomly generated logarithms of SE(3) transforms.
+        """
+        device = torch.device("cuda:0")
+        log_rot = torch.randn((batch_size, 6), dtype=torch.float32, device=device)
+        return log_rot
+
+    @staticmethod
+    def init_transform(batch_size: int = 10):
+        """
+        Initialize a list of `batch_size` 4x4 SE(3) transforms.
+        """
+        device = torch.device("cuda:0")
+        transform = torch.zeros(batch_size, 4, 4, dtype=torch.float32, device=device)
+        transform[:, :3, :3] = random_rotations(
+            batch_size, dtype=torch.float32, device=device
+        )
+        transform[:, 3, :3] = torch.randn(
+            (batch_size, 3), dtype=torch.float32, device=device
+        )
+        transform[:, 3, 3] = 1.0
+        return transform
+
+    def test_se3_exp_output_format(self, batch_size: int = 100):
+        """
+        Check that the output of `se3_exp_map` is a valid SE3 matrix.
+        """
+        transform = se3_exp_map(TestSE3.init_log_transform(batch_size=batch_size))
+        R = transform[:, :3, :3]
+        T = transform[:, 3, :3]
+        rest = transform[:, :, 3]
+        Rdet = R.det()
+
+        # check det(R)==1
+        self.assertClose(Rdet, torch.ones_like(Rdet), atol=1e-4)
+
+        # check that the translation is a finite vector
+        self.assertTrue(torch.isfinite(T).all())
+
+        # check last column == [0,0,0,1]
+        last_col = rest.new_zeros(batch_size, 4)
+        last_col[:, -1] = 1.0
+        self.assertClose(rest, last_col)
+
+    def test_compare_with_precomputed(self):
+        """
+        Compare the outputs against precomputed results.
+        """
+        self.assertClose(
+            se3_log_map(self.precomputed_transform),
+            self.precomputed_log_transform,
+            atol=1e-4,
+        )
+        self.assertClose(
+            self.precomputed_transform,
+            se3_exp_map(self.precomputed_log_transform),
+            atol=1e-4,
+        )
+
+    def test_se3_exp_singularity(self, batch_size: int = 100):
+        """
+        Tests whether the `se3_exp_map` is robust to the input vectors
+        with low L2 norms, where the algorithm is numerically unstable.
+        """
+        # generate random log-rotations with a tiny angle
+        log_rot = TestSE3.init_log_transform(batch_size=batch_size)
+        log_rot_small = log_rot * 1e-6
+        log_rot_small.requires_grad = True
+        transforms = se3_exp_map(log_rot_small)
+        # tests whether all outputs are finite
+        self.assertTrue(torch.isfinite(transforms).all())
+        # tests whether all gradients are finite and not None
+        loss = transforms.sum()
+        loss.backward()
+        self.assertIsNotNone(log_rot_small.grad)
+        self.assertTrue(torch.isfinite(log_rot_small.grad).all())
+
+    def test_se3_log_singularity(self, batch_size: int = 100):
+        """
+        Tests whether the `se3_log_map` is robust to the input matrices
+        whose rotation angles and translations are close to the numerically
+        unstable region (i.e. matrices with low rotation angles
+        and 0 translation).
+ """ + # generate random rotations with a tiny angle + device = torch.device("cuda:0") + identity = torch.eye(3, device=device) + rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device) + r = [identity, rot180] + r.extend( + [ + torch.qr(identity + torch.randn_like(identity) * 1e-6)[0] + + float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8 + # this adds random noise to the second half + # of the random orthogonal matrices to generate + # near-orthogonal matrices + for i in range(batch_size - 2) + ] + ) + r = torch.stack(r) + # tiny translations + t = torch.randn(batch_size, 3, dtype=r.dtype, device=device) * 1e-6 + # create the transform matrix + transform = torch.zeros(batch_size, 4, 4, dtype=torch.float32, device=device) + transform[:, :3, :3] = r + transform[:, 3, :3] = t + transform[:, 3, 3] = 1.0 + transform.requires_grad = True + # the log of the transform + log_transform = se3_log_map(transform, eps=1e-4, cos_bound=1e-4) + # tests whether all outputs are finite + self.assertTrue(torch.isfinite(log_transform).all()) + # tests whether all gradients are finite and not None + loss = log_transform.sum() + loss.backward() + self.assertIsNotNone(transform.grad) + self.assertTrue(torch.isfinite(transform.grad).all()) + + def test_se3_exp_zero_translation(self, batch_size: int = 100): + """ + Check that `se3_exp_map` with zero translation gives + the same result as corresponding `so3_exp_map`. + """ + log_transform = TestSE3.init_log_transform(batch_size=batch_size) + log_transform[:, :3] *= 0.0 + transform = se3_exp_map(log_transform, eps=1e-8) + transform_so3 = so3_exp_map(log_transform[:, 3:], eps=1e-8) + self.assertClose( + transform[:, :3, :3], transform_so3.permute(0, 2, 1), atol=1e-4 + ) + self.assertClose( + transform[:, 3, :3], torch.zeros_like(transform[:, :3, 3]), atol=1e-4 + ) + + def test_se3_log_zero_translation(self, batch_size: int = 100): + """ + Check that `se3_log_map` with zero translation gives + the same result as corresponding `so3_log_map`. + """ + transform = TestSE3.init_transform(batch_size=batch_size) + transform[:, 3, :3] *= 0.0 + log_transform = se3_log_map(transform, eps=1e-8, cos_bound=1e-4) + log_transform_so3 = so3_log_map(transform[:, :3, :3], eps=1e-8, cos_bound=1e-4) + self.assertClose(log_transform[:, 3:], -log_transform_so3, atol=1e-4) + self.assertClose( + log_transform[:, :3], torch.zeros_like(log_transform[:, :3]), atol=1e-4 + ) + + def test_se3_exp_to_log_to_exp(self, batch_size: int = 10000): + """ + Check that `se3_exp_map(se3_log_map(A))==A` for + a batch of randomly generated SE(3) matrices `A`. + """ + transform = TestSE3.init_transform(batch_size=batch_size) + # Limit test transforms to those not around the singularity where + # the rotation angle~=pi. + nonsingular = so3_rotation_angle(transform[:, :3, :3]) < 3.134 + transform = transform[nonsingular] + transform_ = se3_exp_map( + se3_log_map(transform, eps=1e-8, cos_bound=0.0), eps=1e-8 + ) + self.assertClose(transform, transform_, atol=0.02) + + def test_se3_log_to_exp_to_log(self, batch_size: int = 100): + """ + Check that `se3_log_map(se3_exp_map(log_transform))==log_transform` + for a randomly generated batch of SE(3) matrix logarithms `log_transform`. 
+ """ + log_transform = TestSE3.init_log_transform(batch_size=batch_size) + log_transform_ = se3_log_map(se3_exp_map(log_transform, eps=1e-8), eps=1e-8) + self.assertClose(log_transform, log_transform_, atol=1e-1) + + def test_bad_se3_input_value_err(self): + """ + Tests whether `se3_exp_map` and `se3_log_map` correctly return + a ValueError if called with an argument of incorrect shape, or with + an tensor containing illegal values. + """ + device = torch.device("cuda:0") + + for size in ([5, 4], [3, 4, 5], [3, 5, 6]): + log_transform = torch.randn(size=size, device=device) + with self.assertRaises(ValueError): + se3_exp_map(log_transform) + + for size in ([5, 4], [3, 4, 5], [3, 5, 6], [2, 2, 3, 4]): + transform = torch.randn(size=size, device=device) + with self.assertRaises(ValueError): + se3_log_map(transform) + + # Test the case where transform[:, :, :3] != 0. + transform = torch.rand(size=[5, 4, 4], device=device) + 0.1 + with self.assertRaises(ValueError): + se3_log_map(transform) + + @staticmethod + def se3_expmap(batch_size: int = 10): + log_transform = TestSE3.init_log_transform(batch_size=batch_size) + torch.cuda.synchronize() + + def compute_transforms(): + se3_exp_map(log_transform) + torch.cuda.synchronize() + + return compute_transforms + + @staticmethod + def se3_logmap(batch_size: int = 10): + log_transform = TestSE3.init_transform(batch_size=batch_size) + torch.cuda.synchronize() + + def compute_logs(): + se3_log_map(log_transform) + torch.cuda.synchronize() + + return compute_logs