-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Implements the SE3 logarithm and exponential maps. (this is a second part of the split of D23326429) Outputs of `bm_se3`: ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- SE3_EXP_1 738 885 678 SE3_EXP_10 717 877 698 SE3_EXP_100 718 847 697 SE3_EXP_1000 729 1181 686 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- SE3_LOG_1 1451 2267 345 SE3_LOG_10 2185 2453 229 SE3_LOG_100 2217 2448 226 SE3_LOG_1000 2455 2599 204 -------------------------------------------------------------------------------- ``` Reviewed By: patricklabatut Differential Revision: D27852557 fbshipit-source-id: e42ccc9cfffe780e9cad24129de15624ae818472
- Loading branch information
1 parent
9f14e82
commit b2ac265
Showing
5 changed files
with
561 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
|
||
import torch | ||
|
||
from .so3 import hat, _so3_exp_map, so3_log_map | ||
|
||
|
||
def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor: | ||
""" | ||
Convert a batch of logarithmic representations of SE(3) matrices `log_transform` | ||
to a batch of 4x4 SE(3) matrices using the exponential map. | ||
See e.g. [1], Sec 9.4.2. for more detailed description. | ||
A SE(3) matrix has the following form: | ||
``` | ||
[ R 0 ] | ||
[ T 1 ] , | ||
``` | ||
where `R` is a 3x3 rotation matrix and `T` is a 3-D translation vector. | ||
SE(3) matrices are commonly used to represent rigid motions or camera extrinsics. | ||
In the SE(3) logarithmic representation SE(3) matrices are | ||
represented as 6-dimensional vectors `[log_translation | log_rotation]`, | ||
i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`. | ||
The conversion from the 6D representation to a 4x4 SE(3) matrix `transform` | ||
is done as follows: | ||
``` | ||
transform = exp( [ hat(log_rotation) 0 ] | ||
[ log_translation 1 ] ) , | ||
``` | ||
where `exp` is the matrix exponential and `hat` is the Hat operator [2]. | ||
Note that for any `log_transform` with `0 <= ||log_rotation|| < 2pi` | ||
(i.e. the rotation angle is between 0 and 2pi), the following identity holds: | ||
``` | ||
se3_log_map(se3_exponential_map(log_transform)) == log_transform | ||
``` | ||
The conversion has a singularity around `||log(transform)|| = 0` | ||
which is handled by clamping controlled with the `eps` argument. | ||
Args: | ||
log_transform: Batch of vectors of shape `(minibatch, 6)`. | ||
eps: A threshold for clipping the squared norm of the rotation logarithm | ||
to avoid unstable gradients in the singular case. | ||
Returns: | ||
Batch of transformation matrices of shape `(minibatch, 4, 4)`. | ||
Raises: | ||
ValueError if `log_transform` is of incorrect shape. | ||
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf | ||
[2] https://en.wikipedia.org/wiki/Hat_operator | ||
""" | ||
|
||
if log_transform.ndim != 2 or log_transform.shape[1] != 6: | ||
raise ValueError("Expected input to be of shape (N, 6).") | ||
|
||
N, _ = log_transform.shape | ||
|
||
log_translation = log_transform[..., :3] | ||
log_rotation = log_transform[..., 3:] | ||
|
||
# rotation is an exponential map of log_rotation | ||
( | ||
R, | ||
rotation_angles, | ||
log_rotation_hat, | ||
log_rotation_hat_square, | ||
) = _so3_exp_map(log_rotation, eps=eps) | ||
|
||
# translation is V @ T | ||
V = _se3_V_matrix( | ||
log_rotation, | ||
log_rotation_hat, | ||
log_rotation_hat_square, | ||
rotation_angles, | ||
eps=eps, | ||
) | ||
T = torch.bmm(V, log_translation[:, :, None])[:, :, 0] | ||
|
||
transform = torch.zeros( | ||
N, 4, 4, dtype=log_transform.dtype, device=log_transform.device | ||
) | ||
|
||
transform[:, :3, :3] = R | ||
transform[:, :3, 3] = T | ||
transform[:, 3, 3] = 1.0 | ||
|
||
return transform.permute(0, 2, 1) | ||
|
||
|
||
def se3_log_map( | ||
transform: torch.Tensor, eps: float = 1e-4, cos_bound: float = 1e-4 | ||
) -> torch.Tensor: | ||
""" | ||
Convert a batch of 4x4 transformation matrices `transform` | ||
to a batch of 6-dimensional SE(3) logarithms of the SE(3) matrices. | ||
See e.g. [1], Sec 9.4.2. for more detailed description. | ||
A SE(3) matrix has the following form: | ||
``` | ||
[ R 0 ] | ||
[ T 1 ] , | ||
``` | ||
where `R` is an orthonormal 3x3 rotation matrix and `T` is a 3-D translation vector. | ||
SE(3) matrices are commonly used to represent rigid motions or camera extrinsics. | ||
In the SE(3) logarithmic representation SE(3) matrices are | ||
represented as 6-dimensional vectors `[log_translation | log_rotation]`, | ||
i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`. | ||
The conversion from the 4x4 SE(3) matrix `transform` to the | ||
6D representation `log_transform = [log_translation | log_rotation]` | ||
is done as follows: | ||
``` | ||
log_transform = log(transform) | ||
log_translation = log_transform[3, :3] | ||
log_rotation = inv_hat(log_transform[:3, :3]) | ||
``` | ||
where `log` is the matrix logarithm | ||
and `inv_hat` is the inverse of the Hat operator [2]. | ||
Note that for any valid 4x4 `transform` matrix, the following identity holds: | ||
``` | ||
se3_exp_map(se3_log_map(transform)) == transform | ||
``` | ||
The conversion has a singularity around `(transform=I)` which is handled | ||
by clamping controlled with the `eps` and `cos_bound` arguments. | ||
Args: | ||
transform: batch of SE(3) matrices of shape `(minibatch, 4, 4)`. | ||
eps: A threshold for clipping the squared norm of the rotation logarithm | ||
to avoid division by zero in the singular case. | ||
cos_bound: Clamps the cosine of the rotation angle to | ||
[-1 + cos_bound, 3 - cos_bound] to avoid non-finite outputs. | ||
The non-finite outputs can be caused by passing small rotation angles | ||
to the `acos` function in `so3_rotation_angle` of `so3_log_map`. | ||
Returns: | ||
Batch of logarithms of input SE(3) matrices | ||
of shape `(minibatch, 6)`. | ||
Raises: | ||
ValueError if `transform` is of incorrect shape. | ||
ValueError if `R` has an unexpected trace. | ||
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf | ||
[2] https://en.wikipedia.org/wiki/Hat_operator | ||
""" | ||
|
||
if transform.ndim != 3: | ||
raise ValueError("Input tensor shape has to be (N, 4, 4).") | ||
|
||
N, dim1, dim2 = transform.shape | ||
if dim1 != 4 or dim2 != 4: | ||
raise ValueError("Input tensor shape has to be (N, 4, 4).") | ||
|
||
if not torch.allclose(transform[:, :3, 3], torch.zeros_like(transform[:, :3, 3])): | ||
raise ValueError("All elements of `transform[:, :3, 3]` should be 0.") | ||
|
||
# log_rot is just so3_log_map of the upper left 3x3 block | ||
R = transform[:, :3, :3].permute(0, 2, 1) | ||
log_rotation = so3_log_map(R, eps=eps, cos_bound=cos_bound) | ||
|
||
# log_translation is V^-1 @ T | ||
T = transform[:, 3, :3] | ||
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps) | ||
log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0] | ||
|
||
return torch.cat((log_translation, log_rotation), dim=1) | ||
|
||
|
||
def _se3_V_matrix( | ||
log_rotation: torch.Tensor, | ||
log_rotation_hat: torch.Tensor, | ||
log_rotation_hat_square: torch.Tensor, | ||
rotation_angles: torch.Tensor, | ||
eps: float = 1e-4, | ||
) -> torch.Tensor: | ||
""" | ||
A helper function that computes the "V" matrix from [1], Sec 9.4.2. | ||
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf | ||
""" | ||
|
||
V = ( | ||
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None] | ||
+ log_rotation_hat | ||
* ((1 - torch.cos(rotation_angles)) / (rotation_angles ** 2))[:, None, None] | ||
+ ( | ||
log_rotation_hat_square | ||
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles ** 3))[ | ||
:, None, None | ||
] | ||
) | ||
) | ||
|
||
return V | ||
|
||
|
||
def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4): | ||
""" | ||
A helper function that computes the input variables to the `_se3_V_matrix` | ||
function. | ||
""" | ||
nrms = (log_rotation ** 2).sum(-1) | ||
rotation_angles = torch.clamp(nrms, eps).sqrt() | ||
log_rotation_hat = hat(log_rotation) | ||
log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat) | ||
return log_rotation, log_rotation_hat, log_rotation_hat_square, rotation_angles |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | ||
|
||
from fvcore.common.benchmark import benchmark | ||
from test_se3 import TestSE3 | ||
|
||
|
||
def bm_se3() -> None: | ||
kwargs_list = [ | ||
{"batch_size": 1}, | ||
{"batch_size": 10}, | ||
{"batch_size": 100}, | ||
{"batch_size": 1000}, | ||
] | ||
benchmark(TestSE3.se3_expmap, "SE3_EXP", kwargs_list, warmup_iters=1) | ||
benchmark(TestSE3.se3_logmap, "SE3_LOG", kwargs_list, warmup_iters=1) | ||
|
||
|
||
if __name__ == "__main__": | ||
bm_se3() |
Oops, something went wrong.