From 31ae0c6cede6c08245c8ca13cc54a79612b5a7a3 Mon Sep 17 00:00:00 2001 From: Hampus Serneke Date: Mon, 4 May 2026 16:15:03 +0200 Subject: [PATCH 1/2] Add roma.rotmat_geodesic_distance_pi_stable --- docsource/source/index.rst | 4 ++ roma/utils.py | 76 ++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 26 +++++++++++++ 3 files changed, 106 insertions(+) diff --git a/docsource/source/index.rst b/docsource/source/index.rst index e23d4ec..c103c0e 100644 --- a/docsource/source/index.rst +++ b/docsource/source/index.rst @@ -175,6 +175,10 @@ Care for numerical precision Backward pass through :func:`~roma.utils.rotmat_geodesic_distance_naive` leads to unstable gradient estimations and produces *Not-a-Number* values for small angles, whereas :func:`~roma.utils.rotmat_geodesic_distance_naive` is well-behaved, and returns *Not-a-Number* only for 0.0 angle where gradient is mathematically undefined. + However, both :func:`~roma.utils.rotmat_geodesic_distance_naive` and :func:`~roma.utils.rotmat_geodesic_distance` have low numerical precision around :math:`\theta=\pi`, in the same + way that :func:`~roma.utils.rotmat_geodesic_distance_naive` is unprecise around :math:`\theta=0`. For this reason, we also provide the function :func:`~roma.utils.rotmat_geodesic_distance_pi_stable` + which is numerically precise at :math:`\theta=\pi` too (i.e. precise for all :math:`\theta`), at the price of slower runtime performance around :math:`\theta=\pi`. + .. image:: rotmat_geodesic_distance_zero.svg .. image:: rotmat_geodesic_distance_grads_zero.svg diff --git a/roma/utils.py b/roma/utils.py index 7217f60..28a45da 100644 --- a/roma/utils.py +++ b/roma/utils.py @@ -164,6 +164,82 @@ def rotmat_geodesic_distance_naive(R1, R2): cos = rotmat_cosine_angle(R) return torch.acos(torch.clamp(cos, -1.0, 1.0)) +def _rotmat_geodesic_distance_atan2(R1, R2): + r""" + Returns the angular distance alpha between a pair of rotation matrices. + Based on the equalities :math:`\alpha=atan2(sin(\alpha), cos(\alpha))`, + :math:`sin(\alpha)=\frac{1}{2}|(R_{21}-R_{12}, R_{02}-R_{20}, R_{10}-R_{01})|_2` and :math:`cos(\alpha)=\frac{1}{2}(Trace(R)-1)`. + + More precise than :func:`~roma.utils.rotmat_geodesic_distance` at :math:`\alpha=\pi`, while still precise for nearby rotations (:math:`\alpha=0`). + + .. warning:: + This function is significantly slower than :func:`~roma.utils.rotmat_geodesic_distance` + (approximately 3x slower on CUDA and 10x slower on CPU in this implementation, which is optimized for CUDA). + To remedy this, the function :func:`~roma.utils.rotmat_geodesic_distance_pi_stable` is preferred in most cases, + because it first calculates :math:`\alpha` using :func:`~roma.utils.rotmat_geodesic_distance` and then recalculates + :math:`\alpha` using this function if :math:`\alpha` is too close to :math:`\pi`. + + Args: + R1, R2 (...x3x3 tensor): batch of 3x3 rotation matrices. + Returns: + batch of angles in radians (... tensor). + """ + + # Equivalent to R = R2 @ R1.transpose(-1, -2), but faster when using CUDA. + # For a more CPU-friendly method, use R = R2 @ R1.transpose(-1, -2) instead. + r00 = (R2[..., 0, :] * R1[..., 0, :]).sum(dim=-1) + r11 = (R2[..., 1, :] * R1[..., 1, :]).sum(dim=-1) + r22 = (R2[..., 2, :] * R1[..., 2, :]).sum(dim=-1) + r21 = (R2[..., 2, :] * R1[..., 1, :]).sum(dim=-1) + r12 = (R2[..., 1, :] * R1[..., 2, :]).sum(dim=-1) + r02 = (R2[..., 0, :] * R1[..., 2, :]).sum(dim=-1) + r20 = (R2[..., 2, :] * R1[..., 0, :]).sum(dim=-1) + r10 = (R2[..., 1, :] * R1[..., 0, :]).sum(dim=-1) + r01 = (R2[..., 0, :] * R1[..., 1, :]).sum(dim=-1) + + x = r21 - r12 + y = r02 - r20 + z = r10 - r01 + + sin = torch.sqrt(x * x + y * y + z * z) + cos = r00 + r11 + r22 - 1.0 + + return torch.atan2(sin, cos) + +def rotmat_geodesic_distance_pi_stable(R1, R2, tol=1e-2): + r""" + Returns the angular distance alpha between a pair of rotation matrices. + Based on the equalities :math:`\alpha=atan2(sin(\alpha), cos(\alpha))`, + :math:`sin(\alpha)=\frac{1}{2}|(R_{21}-R_{12}, R_{02}-R_{20}, R_{10}-R_{01})|_2` and :math:`cos(\alpha)=\frac{1}{2}(Trace(R)-1)`. + + More precise than :func:`~roma.utils.rotmat_geodesic_distance` at :math:`\alpha=\pi`, while still precise for nearby rotations where :math:`\alpha=0`. + + For performance reasons, this method first calls :func:`~roma.utils.rotmat_geodesic_distance`. If the resulting :math:`\alpha` is too close + to :math:`\alpha=\pi` given the tolerance tol, i.e. if :math:`\alpha \in (\pi-tol, \pi+tol)`, it recalculates the value of :math:`\alpha` + using a more precise, but slower, method using :math:`\alpha=atan2(sin(\alpha), cos(\alpha))`. + Therefore, for almost all :math:`\alpha` this function produces the same result as :func:`~roma.utils.rotmat_geodesic_distance` + with only a minimal overhead cost in runtime performance. However, for :math:`\alpha` near :math:`\pi`, this function produces + more precise results, but it is approximately 3x slower using CUDA tensors and 10x slower using CPU tensors + for these :math:`\alpha`. + + Args: + R1, R2 (...x3x3 tensor): batch of 3x3 rotation matrices. + Returns: + batch of angles in radians (... tensor). + """ + alpha = rotmat_geodesic_distance(R1, R2) + + mask = (alpha > math.pi - tol) & (alpha < math.pi + tol) + + if not mask.any().item(): + return alpha + + if mask.all().item(): + return _rotmat_geodesic_distance_atan2(R1, R2) + + alpha = alpha.clone() + alpha[mask] = _rotmat_geodesic_distance_atan2(R1[mask], R2[mask]) + return alpha def unitquat_geodesic_distance(q1, q2): r""" diff --git a/test/test_utils.py b/test/test_utils.py index 7b015d0..71c0359 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -6,6 +6,7 @@ import numpy as np import roma import roma.internal +from roma import utils from test.utils import is_close class TestUtils(unittest.TestCase): @@ -62,6 +63,31 @@ def test_other_geodesic_distance(self): alpha_rotvec = roma.rotvec_geodesic_distance(rotvec1, rotvec2) self.assertTrue(is_close(alpha_rotvec, alpha_q)) + def test_rotmat_geodesic_distance_atan2(self): + # The rotmat_geodesic_distance_pi_stable also passes this test, + # but _rotmat_geodesic_distance_atan2 is its internal helper function + # which has the strongest preciseness guarantess. + + # This test also passes with atol=1e-6 with all possible eps_end. + # roma.rotmat_geodesic_distance notably does not pass this test, + # as it stops working at eps_end=1e-2 for atol=1e-5. + + batch_size = 100 + eps_end = 1e-5 + eps_start = eps_end * 10 + atol = 1e-5 + + for dtype in (torch.float32, torch.float64): + axis = torch.nn.functional.normalize(torch.randn((batch_size,3), dtype=dtype), dim=-1) + alpha = torch.linspace(np.pi - eps_start, np.pi - eps_end, batch_size, dtype=dtype) + x = alpha[:,None] * axis + R = roma.rotvec_to_rotmat(x) + I = torch.eye(3, dtype=dtype) + + geo_dist = roma.utils._rotmat_geodesic_distance_atan2(R, I[None,:,:]) + self.assertTrue(torch.all(torch.abs(geo_dist - alpha) < atol), + msg=f"rotmat_geodesic_distance_pi_stable failed near pi") + def test_identity_quat(self): q = roma.identity_quat() self.assertTrue(q.shape == (4,)) From 15c99dba6b75865c3418bc15916cc2b7294aa197 Mon Sep 17 00:00:00 2001 From: Hampus Serneke Date: Mon, 4 May 2026 16:24:11 +0200 Subject: [PATCH 2/2] Add and refactor test_rotmat_geodesic_distance Include the same tests for rotmat_geodesic_distance_pi_stable as rotmat_geodesic_distance. Note that this refactoring also tests the naive method for right invariance and the ordinary test; previously it was only tested for left invariance. --- test/test_utils.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 71c0359..b74bba8 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -6,7 +6,6 @@ import numpy as np import roma import roma.internal -from roma import utils from test.utils import is_close class TestUtils(unittest.TestCase): @@ -26,26 +25,31 @@ def test_rotmat_geodesic_distance(self): alpha = (np.pi - 1e-5) * 2 * (torch.rand(batch_size, dtype=dtype)-0.5) x = alpha[:,None] * axis R = roma.rotvec_to_rotmat(x) - + cosine = roma.rotmat_cosine_angle(R) self.assertTrue(is_close(cosine, torch.cos(alpha))) - + I = torch.eye(3, dtype=dtype) M = roma.random_rotmat(batch_size, dtype=dtype) - - geo_dist = roma.rotmat_geodesic_distance(R, I[None,:,:]) - self.assertTrue(is_close(torch.abs(alpha), geo_dist)) - - # Left-invariance of the metric - geo_dist_bis = roma.rotmat_geodesic_distance(M @ R, M @ I[None,:,:]) - self.assertTrue(is_close(geo_dist_bis, geo_dist)) - - # Right-invariance of the metric - geo_dist_ter = roma.rotmat_geodesic_distance(R @ M, I[None,:,:] @ M) - self.assertTrue(is_close(geo_dist_ter, geo_dist)) - - geo_dist_naive = roma.rotmat_geodesic_distance_naive(M @ R, M @ I[None,:,:]) - self.assertTrue(is_close(torch.abs(alpha), geo_dist_naive)) + + for geodesic_distance_function in [roma.rotmat_geodesic_distance, + roma.rotmat_geodesic_distance_naive, + roma.rotmat_geodesic_distance_pi_stable, + roma.utils._rotmat_geodesic_distance_atan2]: + + geo_dist = geodesic_distance_function(R, I[None,:,:]) + self.assertTrue(is_close(torch.abs(alpha), geo_dist), + msg=f"{geodesic_distance_function.__name__} failed geodesic distance function") + + # Left-invariance of the metric + geo_dist_bis = geodesic_distance_function(M @ R, M @ I[None,:,:]) + self.assertTrue(is_close(geo_dist_bis, geo_dist), + msg=f"{geodesic_distance_function.__name__} failed left-invariance for the geodesic distance function") + + # Right-invariance of the metric + geo_dist_ter = geodesic_distance_function(R @ M, I[None,:,:] @ M) + self.assertTrue(is_close(geo_dist_ter, geo_dist), + msg=f"{geodesic_distance_function.__name__} failed right-invariance for the geodesic distance function") def test_other_geodesic_distance(self): batch_size = 100