Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docsource/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ Care for numerical precision
Backward pass through :func:`~roma.utils.rotmat_geodesic_distance_naive` leads to unstable gradient estimations and produces *Not-a-Number* values for small angles,
whereas :func:`~roma.utils.rotmat_geodesic_distance_naive` is well-behaved, and returns *Not-a-Number* only for 0.0 angle where gradient is mathematically undefined.

However, both :func:`~roma.utils.rotmat_geodesic_distance_naive` and :func:`~roma.utils.rotmat_geodesic_distance` have low numerical precision around :math:`\theta=\pi`, in the same
way that :func:`~roma.utils.rotmat_geodesic_distance_naive` is unprecise around :math:`\theta=0`. For this reason, we also provide the function :func:`~roma.utils.rotmat_geodesic_distance_pi_stable`
which is numerically precise at :math:`\theta=\pi` too (i.e. precise for all :math:`\theta`), at the price of slower runtime performance around :math:`\theta=\pi`.

.. image:: rotmat_geodesic_distance_zero.svg

.. image:: rotmat_geodesic_distance_grads_zero.svg
Expand Down
76 changes: 76 additions & 0 deletions roma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,82 @@ def rotmat_geodesic_distance_naive(R1, R2):
cos = rotmat_cosine_angle(R)
return torch.acos(torch.clamp(cos, -1.0, 1.0))

def _rotmat_geodesic_distance_atan2(R1, R2):
r"""
Returns the angular distance alpha between a pair of rotation matrices.
Based on the equalities :math:`\alpha=atan2(sin(\alpha), cos(\alpha))`,
:math:`sin(\alpha)=\frac{1}{2}|(R_{21}-R_{12}, R_{02}-R_{20}, R_{10}-R_{01})|_2` and :math:`cos(\alpha)=\frac{1}{2}(Trace(R)-1)`.

More precise than :func:`~roma.utils.rotmat_geodesic_distance` at :math:`\alpha=\pi`, while still precise for nearby rotations (:math:`\alpha=0`).

.. warning::
This function is significantly slower than :func:`~roma.utils.rotmat_geodesic_distance`
(approximately 3x slower on CUDA and 10x slower on CPU in this implementation, which is optimized for CUDA).
To remedy this, the function :func:`~roma.utils.rotmat_geodesic_distance_pi_stable` is preferred in most cases,
because it first calculates :math:`\alpha` using :func:`~roma.utils.rotmat_geodesic_distance` and then recalculates
:math:`\alpha` using this function if :math:`\alpha` is too close to :math:`\pi`.

Args:
R1, R2 (...x3x3 tensor): batch of 3x3 rotation matrices.
Returns:
batch of angles in radians (... tensor).
"""

# Equivalent to R = R2 @ R1.transpose(-1, -2), but faster when using CUDA.
# For a more CPU-friendly method, use R = R2 @ R1.transpose(-1, -2) instead.
r00 = (R2[..., 0, :] * R1[..., 0, :]).sum(dim=-1)
r11 = (R2[..., 1, :] * R1[..., 1, :]).sum(dim=-1)
r22 = (R2[..., 2, :] * R1[..., 2, :]).sum(dim=-1)
r21 = (R2[..., 2, :] * R1[..., 1, :]).sum(dim=-1)
r12 = (R2[..., 1, :] * R1[..., 2, :]).sum(dim=-1)
r02 = (R2[..., 0, :] * R1[..., 2, :]).sum(dim=-1)
r20 = (R2[..., 2, :] * R1[..., 0, :]).sum(dim=-1)
r10 = (R2[..., 1, :] * R1[..., 0, :]).sum(dim=-1)
r01 = (R2[..., 0, :] * R1[..., 1, :]).sum(dim=-1)

x = r21 - r12
y = r02 - r20
z = r10 - r01

sin = torch.sqrt(x * x + y * y + z * z)
cos = r00 + r11 + r22 - 1.0

return torch.atan2(sin, cos)

def rotmat_geodesic_distance_pi_stable(R1, R2, tol=1e-2):
r"""
Returns the angular distance alpha between a pair of rotation matrices.
Based on the equalities :math:`\alpha=atan2(sin(\alpha), cos(\alpha))`,
:math:`sin(\alpha)=\frac{1}{2}|(R_{21}-R_{12}, R_{02}-R_{20}, R_{10}-R_{01})|_2` and :math:`cos(\alpha)=\frac{1}{2}(Trace(R)-1)`.

More precise than :func:`~roma.utils.rotmat_geodesic_distance` at :math:`\alpha=\pi`, while still precise for nearby rotations where :math:`\alpha=0`.

For performance reasons, this method first calls :func:`~roma.utils.rotmat_geodesic_distance`. If the resulting :math:`\alpha` is too close
to :math:`\alpha=\pi` given the tolerance tol, i.e. if :math:`\alpha \in (\pi-tol, \pi+tol)`, it recalculates the value of :math:`\alpha`
using a more precise, but slower, method using :math:`\alpha=atan2(sin(\alpha), cos(\alpha))`.
Therefore, for almost all :math:`\alpha` this function produces the same result as :func:`~roma.utils.rotmat_geodesic_distance`
with only a minimal overhead cost in runtime performance. However, for :math:`\alpha` near :math:`\pi`, this function produces
more precise results, but it is approximately 3x slower using CUDA tensors and 10x slower using CPU tensors
for these :math:`\alpha`.

Args:
R1, R2 (...x3x3 tensor): batch of 3x3 rotation matrices.
Returns:
batch of angles in radians (... tensor).
"""
alpha = rotmat_geodesic_distance(R1, R2)

mask = (alpha > math.pi - tol) & (alpha < math.pi + tol)

if not mask.any().item():
return alpha

if mask.all().item():
return _rotmat_geodesic_distance_atan2(R1, R2)

alpha = alpha.clone()
alpha[mask] = _rotmat_geodesic_distance_atan2(R1[mask], R2[mask])
return alpha

def unitquat_geodesic_distance(q1, q2):
r"""
Expand Down
62 changes: 46 additions & 16 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,31 @@ def test_rotmat_geodesic_distance(self):
alpha = (np.pi - 1e-5) * 2 * (torch.rand(batch_size, dtype=dtype)-0.5)
x = alpha[:,None] * axis
R = roma.rotvec_to_rotmat(x)

cosine = roma.rotmat_cosine_angle(R)
self.assertTrue(is_close(cosine, torch.cos(alpha)))

I = torch.eye(3, dtype=dtype)
M = roma.random_rotmat(batch_size, dtype=dtype)

geo_dist = roma.rotmat_geodesic_distance(R, I[None,:,:])
self.assertTrue(is_close(torch.abs(alpha), geo_dist))

# Left-invariance of the metric
geo_dist_bis = roma.rotmat_geodesic_distance(M @ R, M @ I[None,:,:])
self.assertTrue(is_close(geo_dist_bis, geo_dist))

# Right-invariance of the metric
geo_dist_ter = roma.rotmat_geodesic_distance(R @ M, I[None,:,:] @ M)
self.assertTrue(is_close(geo_dist_ter, geo_dist))

geo_dist_naive = roma.rotmat_geodesic_distance_naive(M @ R, M @ I[None,:,:])
self.assertTrue(is_close(torch.abs(alpha), geo_dist_naive))

for geodesic_distance_function in [roma.rotmat_geodesic_distance,
roma.rotmat_geodesic_distance_naive,
roma.rotmat_geodesic_distance_pi_stable,
roma.utils._rotmat_geodesic_distance_atan2]:

geo_dist = geodesic_distance_function(R, I[None,:,:])
self.assertTrue(is_close(torch.abs(alpha), geo_dist),
msg=f"{geodesic_distance_function.__name__} failed geodesic distance function")

# Left-invariance of the metric
geo_dist_bis = geodesic_distance_function(M @ R, M @ I[None,:,:])
self.assertTrue(is_close(geo_dist_bis, geo_dist),
msg=f"{geodesic_distance_function.__name__} failed left-invariance for the geodesic distance function")

# Right-invariance of the metric
geo_dist_ter = geodesic_distance_function(R @ M, I[None,:,:] @ M)
self.assertTrue(is_close(geo_dist_ter, geo_dist),
msg=f"{geodesic_distance_function.__name__} failed right-invariance for the geodesic distance function")

def test_other_geodesic_distance(self):
batch_size = 100
Expand All @@ -62,6 +67,31 @@ def test_other_geodesic_distance(self):
alpha_rotvec = roma.rotvec_geodesic_distance(rotvec1, rotvec2)
self.assertTrue(is_close(alpha_rotvec, alpha_q))

def test_rotmat_geodesic_distance_atan2(self):
# The rotmat_geodesic_distance_pi_stable also passes this test,
# but _rotmat_geodesic_distance_atan2 is its internal helper function
# which has the strongest preciseness guarantess.

# This test also passes with atol=1e-6 with all possible eps_end.
# roma.rotmat_geodesic_distance notably does not pass this test,
# as it stops working at eps_end=1e-2 for atol=1e-5.

batch_size = 100
eps_end = 1e-5
eps_start = eps_end * 10
atol = 1e-5

for dtype in (torch.float32, torch.float64):
axis = torch.nn.functional.normalize(torch.randn((batch_size,3), dtype=dtype), dim=-1)
alpha = torch.linspace(np.pi - eps_start, np.pi - eps_end, batch_size, dtype=dtype)
x = alpha[:,None] * axis
R = roma.rotvec_to_rotmat(x)
I = torch.eye(3, dtype=dtype)

geo_dist = roma.utils._rotmat_geodesic_distance_atan2(R, I[None,:,:])
self.assertTrue(torch.all(torch.abs(geo_dist - alpha) < atol),
msg=f"rotmat_geodesic_distance_pi_stable failed near pi")

def test_identity_quat(self):
q = roma.identity_quat()
self.assertTrue(q.shape == (4,))
Expand Down