Source code for rofunc.utils.robolab.coord.transform_tensor

#  Copyright (C) 2024, Junjia Liu
# 
#  This file is part of Rofunc.
# 
#  Rofunc is licensed under the GNU General Public License v3.0.
#  You may use, distribute, and modify this code under the terms of the GPL-3.0.
# 
#  Additional Terms for Commercial Use:
#  Commercial use requires sharing 50% of net profits with the copyright holder.
#  Financial reports and regular payments must be provided as agreed in writing.
#  Non-compliance results in revocation of commercial rights.
# 
#  For more details, see <https://www.gnu.org/licenses/>.
#  Contact: skylark0924@gmail.com

"""
Coordinate transformation functions with tensor support
----------------------------------------------------------
| This module provides functions to convert between different coordinate systems with tensor support.
| Note

1. Quaternions ix+jy+kz+w are represented as [x, y, z, w].
2. Euler angles are represented as [roll, pitch, yaw], in radians. The rotation order is ZYX.
3. Rotation matrices are represented as (3, 3).
4. Homogeneous matrices are represented as (4, 4).
"""

import numpy as np
import torch

# epsilon for testing whether a number is close to zero
_EPS = torch.finfo(torch.float32).eps * 4.0


def check_pos_tensor(pos):
    """
    Convert a position to a float32 tensor with a leading batch dimension.

    :param pos: (batch, 3) or (3, ); list, tuple, np.ndarray or torch.Tensor
    :return: float32 position tensor; a 1-D input is expanded to (1, 3)
    :raises AssertionError: if the last dimension of the input is not 3

    >>> check_pos_tensor([0, 0, 0])
    tensor([[0., 0., 0.]])
    >>> check_pos_tensor([[0, 0, 0]])
    tensor([[0., 0., 0.]])
    >>> check_pos_tensor(np.array([0, 0, 0]))
    tensor([[0., 0., 0.]])
    """
    if isinstance(pos, torch.Tensor):
        # Same result as torch.tensor(pos, dtype=...) -- a detached copy -- but
        # without the "To copy construct from a tensor" UserWarning.
        pos = pos.detach().clone().to(torch.float32)
    else:
        pos = torch.tensor(pos, dtype=torch.float32)

    if pos.dim() == 1:
        pos = pos.unsqueeze(0)

    # Raised explicitly (instead of an `assert` statement) so that the check
    # is not stripped when Python runs with -O; the exception type is unchanged.
    if pos.shape[-1] != 3:
        raise AssertionError("The last dimension of the input tensor should be 3.")
    return pos
def check_quat_tensor(quat):
    """
    Convert a quat to a float32 tensor with a leading batch dimension and normalize it.

    :param quat: (batch, 4) or (4, ), [x, y, z, w]; list, tuple, np.ndarray or torch.Tensor
    :return: normalized quat; a 1-D input is expanded to (1, 4)
    :raises AssertionError: if the last dimension of the input is not 4
    :raises ValueError: if any quat has zero norm

    >>> check_quat_tensor([0, 5, 0, 1])
    tensor([[0.0000, 0.9806, 0.0000, 0.1961]])
    >>> check_quat_tensor([[0, 2, 0, 1]])
    tensor([[0.0000, 0.8944, 0.0000, 0.4472]])
    >>> check_quat_tensor(np.array([1, 5, 5.435, 1]))
    tensor([[0.1330, 0.6650, 0.7228, 0.1330]])
    """
    if isinstance(quat, torch.Tensor):
        # Same result as torch.tensor(quat, dtype=...) -- a detached copy -- but
        # without the "To copy construct from a tensor" UserWarning.
        quat = quat.detach().clone().to(torch.float32)
    else:
        quat = torch.tensor(quat, dtype=torch.float32)

    if quat.dim() == 1:
        quat = quat.unsqueeze(0)

    # Raised explicitly (instead of an `assert` statement) so that the check
    # is not stripped when Python runs with -O; the exception type is unchanged.
    if quat.shape[-1] != 4:
        raise AssertionError("The last dimension of the input tensor should be 4.")

    norm = torch.norm(quat, dim=-1, keepdim=True)
    zero_norm = norm == 0
    if torch.any(zero_norm):
        raise ValueError(
            f"The input quat is invalid. The index of the invalid quat is {torch.where(zero_norm)[0]}")

    # _EPS keeps the denominator strictly positive.
    return quat / (norm + _EPS)
def check_rot_matrix_tensor(rot_matrix):
    """
    Convert a rotation matrix to a batched float32 tensor and rescale its columns to unit length.

    Orthogonality is intentionally not enforced; only squareness is checked and
    the column norms are normalized.

    :param rot_matrix: Input rotation matrix, (n, n) or (batch, n, n)
    :return: Validated and column-normalized rotation matrix; a 2-D input is expanded to (1, n, n)
    :raises ValueError: if the last two dimensions are not equal

    >>> from rofunc.utils.robolab.coord.transform import random_rot_matrix
    >>> rot_matrix = random_rot_matrix() * 3
    >>> torch.allclose(check_rot_matrix_tensor(rot_matrix) * 3, torch.tensor(rot_matrix, dtype=torch.float32))
    True
    """
    if isinstance(rot_matrix, torch.Tensor):
        # Same result as torch.tensor(rot_matrix, dtype=...) -- a detached copy --
        # but without the "To copy construct from a tensor" UserWarning.
        rot_matrix = rot_matrix.detach().clone().to(torch.float32)
    else:
        rot_matrix = torch.tensor(rot_matrix, dtype=torch.float32)

    if rot_matrix.dim() == 2:
        rot_matrix = rot_matrix.unsqueeze(0)

    if rot_matrix.shape[-1] != rot_matrix.shape[-2]:
        raise ValueError("Input matrix is not square.")

    # Reducing over the row axis (dim=-2) yields one norm per column.
    column_norms = torch.norm(rot_matrix, dim=-2)
    if torch.allclose(column_norms, torch.ones_like(column_norms)):
        return rot_matrix

    # unsqueeze(-2) aligns norm j with column j. (unsqueeze(-1) would divide
    # row i by the norm of column i, which is only right for uniform scaling.)
    return rot_matrix / column_norms.unsqueeze(-2)
def check_euler_tensor(euler):
    """
    Convert euler angles to a float32 tensor with a leading batch dimension.

    :param euler: (batch, 3) or (3, ); list, tuple, np.ndarray or torch.Tensor
    :return: float32 euler-angle tensor; a 1-D input is expanded to (1, 3)
    :raises AssertionError: if the last dimension of the input is not 3

    >>> check_euler_tensor([1.57, 0, 0])
    tensor([[1.5700, 0.0000, 0.0000]])
    >>> check_euler_tensor([[0, 0, 0]])
    tensor([[0., 0., 0.]])
    >>> check_euler_tensor(np.array([0, 0, 0]))
    tensor([[0., 0., 0.]])
    """
    if isinstance(euler, torch.Tensor):
        # Same result as torch.tensor(euler, dtype=...) -- a detached copy -- but
        # without the "To copy construct from a tensor" UserWarning.
        euler = euler.detach().clone().to(torch.float32)
    else:
        euler = torch.tensor(euler, dtype=torch.float32)

    if euler.dim() == 1:
        euler = euler.unsqueeze(0)

    # Raised explicitly (instead of an `assert` statement) so that the check
    # is not stripped when Python runs with -O; the exception type is unchanged.
    if euler.shape[-1] != 3:
        raise AssertionError("The last dimension of the input tensor should be 3.")
    return euler
def random_quat_tensor(batch_size, rand=None):
    """
    Sample a batch of uniform random unit quats.

    :param batch_size: Batch size
    :param rand: Optional callable invoked as ``rand(batch_size, 3)``; defaults to ``torch.rand``
    :return: Random unit quat, [x, y, z, w]

    >>> torch.allclose(torch.norm(random_quat_tensor(100), dim=-1), torch.ones(100))
    True
    >>> rand_quat = random_quat_tensor(100)
    >>> torch.allclose(check_quat_tensor(rand_quat), rand_quat)
    True
    """
    sampler = torch.rand if rand is None else rand
    samples = sampler(batch_size, 3)
    u0, u1, u2 = samples[:, 0], samples[:, 1], samples[:, 2]

    # The first sample splits the unit norm between the (x, y) and (z, w) pairs.
    radius_xy = torch.sqrt(1 - u0)
    radius_zw = torch.sqrt(u0)

    # The remaining two samples become angles in [0, 2*pi).
    two_pi = 2 * torch.pi
    angle_xy = two_pi * u1
    angle_zw = two_pi * u2

    components = [
        radius_xy * torch.sin(angle_xy),
        radius_xy * torch.cos(angle_xy),
        radius_zw * torch.sin(angle_zw),
        radius_zw * torch.cos(angle_zw),
    ]
    return torch.stack(components, dim=-1)
def random_rot_matrix_tensor(batch_size, rand=None):
    """
    Sample a batch of random rotation matrices by converting random unit quats.

    :param batch_size: Batch size
    :param rand: Optional callable invoked as ``rand(batch_size, 3)``; defaults to ``torch.rand``
    :return: Random rotation matrix, (batch_size, 3, 3)

    >>> rand_rot_matrix = random_rot_matrix_tensor(100)
    >>> torch.allclose(rand_rot_matrix.det(), torch.ones(100))
    True
    >>> torch.allclose(check_rot_matrix_tensor(rand_rot_matrix), rand_rot_matrix)
    True
    >>> from rofunc.utils.robolab.coord.transform import check_rot_matrix
    >>> torch.allclose(torch.tensor(check_rot_matrix(rand_rot_matrix[0]), dtype=torch.float32), rand_rot_matrix[0])
    True
    """
    sampler = torch.rand if rand is None else rand
    quat = random_quat_tensor(batch_size, sampler)
    # quat layout is [x, y, z, w].
    qx, qy, qz, qw = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]

    # Pairwise products shared by several matrix entries.
    xx, yy, zz = qx * qx, qy * qy, qz * qz
    xy, xz, yz = qx * qy, qx * qz, qy * qz
    xw, yw, zw = qx * qw, qy * qw, qz * qw

    # Row-major entries of the rotation matrix.
    entries = [
        1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw),
        2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw),
        2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy),
    ]
    return torch.stack(entries, dim=1).view(batch_size, 3, 3)
def quat_from_rot_matrix_tensor(rot_matrix):
    """
    Convert rotation matrix to quat. [x, y, z, w]

    The conversion pivots on the largest-magnitude quat component of each
    matrix, so it stays finite for rotations by (or close to) 180 degrees,
    where ``1 + trace`` reaches zero and a trace-only formula divides by zero.
    The returned quat always has ``w >= 0``.

    :param rot_matrix: (3, 3) or (batch, 3, 3) rotation matrix
    :return: quat, [x, y, z, w], of shape (batch, 4)

    >>> quat_from_rot_matrix_tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    tensor([[0., 0., 0., 1.]])
    >>> quat_from_rot_matrix_tensor([[0.9362934, -0.2896295, 0.1986693], [0.3129918, 0.9447025, -0.0978434], [-0.1593451, 0.1537920, 0.9751703]])
    tensor([[0.0641, 0.0912, 0.1534, 0.9819]])
    >>> rand_rot_matrix = random_rot_matrix_tensor(100)
    >>> torch.allclose(check_rot_matrix_tensor(rand_rot_matrix), rot_matrix_from_quat_tensor(quat_from_rot_matrix_tensor(rand_rot_matrix)), rtol=1e-03, atol=1e-03)
    True
    """
    rot_matrix = check_rot_matrix_tensor(rot_matrix)
    m00, m01, m02 = rot_matrix[:, 0, 0], rot_matrix[:, 0, 1], rot_matrix[:, 0, 2]
    m10, m11, m12 = rot_matrix[:, 1, 0], rot_matrix[:, 1, 1], rot_matrix[:, 1, 2]
    m20, m21, m22 = rot_matrix[:, 2, 0], rot_matrix[:, 2, 1], rot_matrix[:, 2, 2]

    # For a rotation matrix these equal 4*x^2, 4*y^2, 4*z^2 and 4*w^2.
    # They always sum to 4, so the largest one is >= 1 and safe to divide by.
    four_sq = torch.stack(
        [
            1 + m00 - m11 - m22,
            1 - m00 + m11 - m22,
            1 - m00 - m11 + m22,
            1 + m00 + m11 + m22,
        ],
        dim=-1,
    )
    pivot = torch.argmax(four_sq, dim=-1)

    # Row p holds [x, y, z, w] scaled by 4 * q_p, where q_p is the pivot component.
    candidates = torch.stack(
        [
            torch.stack([four_sq[:, 0], m01 + m10, m02 + m20, m21 - m12], dim=-1),
            torch.stack([m01 + m10, four_sq[:, 1], m12 + m21, m02 - m20], dim=-1),
            torch.stack([m02 + m20, m12 + m21, four_sq[:, 2], m10 - m01], dim=-1),
            torch.stack([m21 - m12, m02 - m20, m10 - m01, four_sq[:, 3]], dim=-1),
        ],
        dim=1,
    )

    batch_idx = torch.arange(rot_matrix.shape[0], device=rot_matrix.device)
    chosen = candidates[batch_idx, pivot]
    # 2 * sqrt(4 * q_p^2) == 4 * |q_p|
    scale = 2 * torch.sqrt(four_sq[batch_idx, pivot])
    quat = chosen / scale.unsqueeze(-1)

    # q and -q encode the same rotation; keep w >= 0 to match the original convention.
    quat = torch.where(quat[:, 3:4] < 0, -quat, quat)
    return quat
def quat_from_euler_tensor(euler):
    """
    Build quats from euler angles using the ZYX rotation order.

    :param euler: (batch, 3) or (3, ), [roll, pitch, yaw], the rotation order is ZYX.
    :return: quat, [x, y, z, w]

    >>> quat_from_euler_tensor([0, 0, 0])
    tensor([[0., 0., 0., 1.]])
    >>> quat_from_euler_tensor([[0, 0, 0]])
    tensor([[0., 0., 0., 1.]])
    >>> quat_from_euler_tensor(np.array([0, 0, 0]))
    tensor([[0., 0., 0., 1.]])
    >>> quat_from_euler_tensor([[0, 1.23, 0.57], [0.5, 0.3, 0.7], [0.1, 0.2, 0.3]])
    tensor([[-0.1622,  0.5537,  0.2296,  0.7838],
            [ 0.1801,  0.2199,  0.2938,  0.9126],
            [ 0.0343,  0.1060,  0.1436,  0.9833]])
    """
    angles = check_euler_tensor(euler)

    # Half angles of each axis, in [roll, pitch, yaw] column order.
    half_roll = angles[:, 0] * 0.5
    half_pitch = angles[:, 1] * 0.5
    half_yaw = angles[:, 2] * 0.5

    cos_r, sin_r = torch.cos(half_roll), torch.sin(half_roll)
    cos_p, sin_p = torch.cos(half_pitch), torch.sin(half_pitch)
    cos_y, sin_y = torch.cos(half_yaw), torch.sin(half_yaw)

    components = [
        sin_r * cos_p * cos_y - cos_r * sin_p * sin_y,  # x
        cos_r * sin_p * cos_y + sin_r * cos_p * sin_y,  # y
        cos_r * cos_p * sin_y - sin_r * sin_p * cos_y,  # z
        cos_r * cos_p * cos_y + sin_r * sin_p * sin_y,  # w
    ]
    return torch.stack(components, dim=-1)
def rot_matrix_from_quat_tensor(quat):
    """
    Build rotation matrices from quats.

    The input is normalized by ``check_quat_tensor`` first, and the result is
    passed through ``check_rot_matrix_tensor`` before being returned.

    :param quat: [x, y, z, w], (batch, 4) or (4, )
    :return: rotation matrix, (batch, 3, 3)

    >>> rot_matrix_from_quat_tensor([0, 0, 0, 1])
    tensor([[[1., 0., 0.],
             [0., 1., 0.],
             [0., 0., 1.]]])
    >>> rot_matrix_from_quat_tensor([[0.06146124, 0, 0, 0.99810947], [0.2794439, 0.0521324, 0.3632374, 0.8872722]])
    tensor([[[ 1.0000,  0.0000,  0.0000],
             [ 0.0000,  0.9924, -0.1227],
             [ 0.0000,  0.1227,  0.9924]],
    <BLANKLINE>
            [[ 0.7307, -0.6154,  0.2955],
             [ 0.6737,  0.5799, -0.4580],
             [ 0.1105,  0.5338,  0.8384]]])
    """
    unit_quat = check_quat_tensor(quat)
    qx, qy, qz, qw = unit_quat[:, 0], unit_quat[:, 1], unit_quat[:, 2], unit_quat[:, 3]

    # Squares and pairwise products shared by several matrix entries.
    xx, yy, zz = qx ** 2, qy ** 2, qz ** 2
    xy, xz, yz = qx * qy, qx * qz, qy * qz
    xw, yw, zw = qx * qw, qy * qw, qz * qw

    first_row = torch.stack([1 - 2 * (yy + zz), 2 * (xy - zw), 2 * (xz + yw)], dim=-1)
    second_row = torch.stack([2 * (xy + zw), 1 - 2 * (xx + zz), 2 * (yz - xw)], dim=-1)
    third_row = torch.stack([2 * (xz - yw), 2 * (yz + xw), 1 - 2 * (xx + yy)], dim=-1)

    matrix = torch.stack([first_row, second_row, third_row], dim=-2)
    return check_rot_matrix_tensor(matrix)
def rot_matrix_from_euler_tensor(euler):
    """
    Build rotation matrices from euler angles.

    :param euler: (batch, 3) or (3, ), [roll, pitch, yaw] in radian
    :return: Rotation matrix, (batch, 3, 3)

    >>> rot_matrix_from_euler_tensor([0, 0, 0])
    tensor([[[1., 0., 0.],
             [0., 1., 0.],
             [-0., 0., 1.]]])
    >>> rot_matrix_from_euler_tensor([[0.5, 0.3, 0.7], [1.33, 0.2, -0.03]])
    tensor([[[ 0.7307, -0.4570,  0.5072],
             [ 0.6154,  0.7625, -0.1996],
             [-0.2955,  0.4580,  0.8384]],
    <BLANKLINE>
            [[ 0.9796,  0.2000,  0.0182],
             [-0.0294,  0.2326, -0.9721],
             [-0.1987,  0.9518,  0.2337]]])
    """
    angles = check_euler_tensor(euler)
    roll, pitch, yaw = angles[:, 0], angles[:, 1], angles[:, 2]

    cr, sr = torch.cos(roll), torch.sin(roll)
    cp, sp = torch.cos(pitch), torch.sin(pitch)
    cy, sy = torch.cos(yaw), torch.sin(yaw)

    # Allocate the output on the input's device, then fill it one row at a time.
    matrix = torch.zeros((angles.size(0), 3, 3), device=angles.device)
    matrix[:, 0, :] = torch.stack(
        [cy * cp, cy * sp * sr - sy * cr, cy * sp * cr + sy * sr], dim=-1)
    matrix[:, 1, :] = torch.stack(
        [sy * cp, sy * sp * sr + cy * cr, sy * sp * cr - cy * sr], dim=-1)
    matrix[:, 2, :] = torch.stack(
        [-sp, cp * sr, cp * cr], dim=-1)
    return matrix
def euler_from_quat_tensor(quat):
    """
    Recover euler angles from quats.

    :param quat: [x, y, z, w], (batch, 4) or (4, )
    :return: euler angles, [roll, pitch, yaw] in radian

    >>> euler_from_quat_tensor(torch.tensor([[0, 0, 0, 1.]]))
    tensor([[0., 0., 0.]])
    >>> euler_from_quat_tensor(torch.tensor([[0.06146124, 0, 0, 0.99810947], [0.2794439, 0.0521324, 0.3632374, 0.8872722]]))
    tensor([[ 0.1230,  0.0000,  0.0000],
            [ 0.5669, -0.1107,  0.7449]])
    """
    unit_quat = check_quat_tensor(quat)
    x, y, z, w = unit_quat[:, 0], unit_quat[:, 1], unit_quat[:, 2], unit_quat[:, 3]

    roll_sin = 2 * (w * x + y * z)
    roll_cos = 1 - 2 * (x * x + y * y)

    # Clamp so that round-off cannot push the asin argument outside [-1, 1].
    pitch_sin = torch.clamp(2 * (w * y - z * x), -1, 1)

    yaw_sin = 2 * (w * z + x * y)
    yaw_cos = 1 - 2 * (y * y + z * z)

    angles = [
        torch.atan2(roll_sin, roll_cos),
        torch.asin(pitch_sin),
        torch.atan2(yaw_sin, yaw_cos),
    ]
    return torch.stack(angles, dim=-1)
def euler_from_rot_matrix_tensor(rot_matrix):
    """
    Convert rotation matrix to euler angles.

    The pitch is recovered from ``r31`` (clamped to [-1, 1] so that round-off
    cannot produce NaN). When ``|r31|`` is within 1e-6 of 1 (gimbal lock), roll
    and yaw are not separately observable; roll is then set to 0 and the whole
    remaining rotation is reported as yaw.

    :param rot_matrix: (3, 3) or (batch, 3, 3) rotation matrix
    :return: euler angles, [roll, pitch, yaw] in radian, of shape (batch, 3)

    >>> euler_from_rot_matrix_tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    tensor([[0., -0., 0.]])
    >>> euler_from_rot_matrix_tensor([[[ 0.7307, -0.4570, 0.5072], [ 0.6154, 0.7625, -0.1996], [-0.2955, 0.4580, 0.8384]], [[ 0.9796, 0.2000, 0.0182], [-0.0294, 0.2326, -0.9721], [-0.1987, 0.9518, 0.2337]]])
    tensor([[ 0.5000,  0.3000,  0.6999],
            [ 1.3300,  0.2000, -0.0300]])
    """
    rot_matrix = check_rot_matrix_tensor(rot_matrix)
    r11, r12 = rot_matrix[:, 0, 0], rot_matrix[:, 0, 1]
    r21, r22 = rot_matrix[:, 1, 0], rot_matrix[:, 1, 1]
    r31, r32, r33 = rot_matrix[:, 2, 0], rot_matrix[:, 2, 1], rot_matrix[:, 2, 2]

    # r31 == -sin(pitch); clamp for consistency with euler_from_quat_tensor.
    pitch = -torch.asin(torch.clamp(r31, -1.0, 1.0))

    # cos(pitch) >= 0 on the asin range, so dividing both atan2 arguments by it
    # never changes the angle; it only breaks down when cos(pitch) reaches 0.
    gimbal_lock = torch.abs(r31) > 1.0 - 1e-6

    roll = torch.where(gimbal_lock, torch.zeros_like(r32), torch.atan2(r32, r33))
    # With roll fixed at 0 and cos(pitch) == 0: r12 == -sin(yaw), r22 == cos(yaw).
    yaw = torch.where(gimbal_lock, torch.atan2(-r12, r22), torch.atan2(r21, r11))

    euler = torch.stack([roll, pitch, yaw], dim=-1)
    return euler
def homo_matrix_from_quat_tensor(quat, pos=None):
    """
    Assemble homogeneous matrices from quats and optional positions.

    :param quat: [x, y, z, w], (batch, 4) or (4, )
    :param pos: (batch, 3) or (3, ); zeros are used when omitted
    :return: homogeneous matrix, (batch, 4, 4)
    """
    quat = check_quat_tensor(quat)
    n_batch = quat.shape[0]

    if pos is None:
        pos = torch.zeros((n_batch, 3))
    else:
        pos = check_pos_tensor(pos)
        assert quat.shape[0] == pos.shape[0]

    rotation = rot_matrix_from_quat_tensor(quat)

    # Start from zeros, then fill the rotation block, the translation column
    # and the bottom-right 1.
    homo_matrix = torch.zeros((n_batch, 4, 4), device=quat.device)
    homo_matrix[:, :3, :3] = rotation
    homo_matrix[:, :3, 3] = pos
    homo_matrix[:, 3, 3] = 1
    return homo_matrix