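Reference links:
Zhihu: The 3D rotation representation best suited to deep learning
PyTorch code for converting rotation matrices to quaternions and for converting between other rotation representations
The CVPR 2019 paper "On the Continuity of Rotation Representations in Neural Networks".
The paper analyzes and compares how common 3D rotation representations (rotation matrices, Euler angles, quaternions) affect neural network training, and proposes a 6D representation of 3D rotations that is well suited to deep learning.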
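A discontinuous rotation representation is not a good rotation representation
In the figure, the right panel shows one full turn as a geometric circle, and the left panel shows the same full turn as numeric values. I added the red annotations to make it easier to follow.
On the circle in the right panel, the sweep from 45° to -45° is continuous, but its numeric values in the left panel are not: they fall into two separate intervals, [0, π/4] and [7π/4, 2π].
The point of the figure is that a rotation expressed with the numeric values of the left panel is discontinuous, and a representation like that is a poor choice.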
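The jump described above can be reproduced with a few lines of plain Python. This is only a minimal sketch: the helper names `angle_gap` and `circle_gap` are made up for illustration, and using the point (cos θ, sin θ) as the "geometric circle" view is my assumption about what the right panel depicts.

```python
import math

def angle_gap(a, b):
    """Absolute difference between two raw angle values in [0, 2*pi)."""
    return abs(a - b)

def circle_gap(a, b):
    """Euclidean distance between the two (cos, sin) points on the unit circle."""
    return math.hypot(math.cos(a) - math.cos(b), math.sin(a) - math.sin(b))

eps = 0.01
just_above_zero = eps                  # slightly past 0
just_below_two_pi = 2 * math.pi - eps  # slightly before a full turn

# The two rotations are almost identical on the circle,
# yet their numeric angle values are almost 2*pi apart.
print("gap in angle value:", angle_gap(just_above_zero, just_below_two_pi))
print("gap on the circle: ", circle_gap(just_above_zero, just_below_two_pi))
```

A network that has to output the raw angle must reproduce this jump, while the points on the circle change smoothly.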
1. Rotate 90° about the x-axis
import math
import numpy as np

angle = 90 / 180 * math.pi  # 90 degrees in radians
transform_matrix_x = np.array([
    [1, 0, 0],
    [0, math.cos(angle), -math.sin(angle)],
    [0, math.sin(angle), math.cos(angle)]])  # rotation about the x-axis
2. Rotate 90° about the y-axis
angle = 90 / 180 * math.pi
transform_matrix_y = np.array([
    [math.cos(angle), 0, math.sin(angle)],
    [0, 1, 0],
    [-math.sin(angle), 0, math.cos(angle)]])  # rotation about the y-axis
3. Rotate 90° about the z-axis
angle = 90 / 180 * math.pi
transform_matrix_z = np.array([
    [math.cos(angle), -math.sin(angle), 0],
    [math.sin(angle), math.cos(angle), 0],
    [0, 0, 1]])  # rotation about the z-axis
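As a quick sanity check on these matrices, the sketch below applies the x-axis rotation to the unit y vector and verifies that the matrix is orthogonal. It is written in plain Python with nested lists so that it runs without NumPy; the helpers `rot_x`, `matvec` and `times_own_transpose` are hypothetical names used only here.

```python
import math

def rot_x(angle):
    """3x3 rotation about the x-axis as nested lists (column-vector convention)."""
    c, s = math.cos(angle), math.sin(angle)
    return [[1.0, 0.0, 0.0],
            [0.0, c, -s],
            [0.0, s, c]]

def matvec(m, v):
    """Multiply a 3x3 matrix by a 3-vector."""
    return [sum(m[i][k] * v[k] for k in range(3)) for i in range(3)]

def times_own_transpose(m):
    """Return m @ m^T, which is the identity for any rotation matrix."""
    return [[sum(m[i][k] * m[j][k] for k in range(3)) for j in range(3)]
            for i in range(3)]

R = rot_x(math.pi / 2)
y_axis = [0.0, 1.0, 0.0]
rotated = matvec(R, y_axis)

# Rotating the y-axis by 90 degrees about x should land on the z-axis.
print([round(c, 6) for c in rotated])
print([[round(c, 6) for c in row] for row in times_own_transpose(R)])
```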
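Rotation matrix to 6D representation
This is easy to understand: take the six numbers in the first two rows of the 3×3 rotation matrix.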
import torch

def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """
    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
    by dropping the last row. Note that 6D representation is not unique.
    Args:
        matrix: batch of rotation matrices of size (*, 3, 3)
    Returns:
        6D rotation representation, of size (*, 6)
    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
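Dropping the last row loses no information, because for a proper rotation matrix the third row equals the cross product of the first two. The plain-Python sketch below checks this on one example; `rot_x`, `rot_z`, `matmul` and `cross` are hypothetical helpers, and the 40° and 30° angles are arbitrary values chosen for illustration.

```python
import math

def rot_x(angle):
    """Rotation about the x-axis as nested lists."""
    c, s = math.cos(angle), math.sin(angle)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]

def rot_z(angle):
    """Rotation about the z-axis as nested lists."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

def matmul(a, b):
    """Product of two 3x3 matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def cross(u, v):
    """Cross product of two 3-vectors."""
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]

# A rotation that is not aligned with a single axis.
R = matmul(rot_x(math.radians(40)), rot_z(math.radians(30)))

d6 = R[0] + R[1]                           # first two rows, flattened to 6 numbers
rebuilt_third_row = cross(d6[:3], d6[3:])  # recover the dropped row

max_error = max(abs(a - b) for a, b in zip(rebuilt_third_row, R[2]))
print("largest difference from the true third row:", max_error)
```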
Recovering the rotation matrix from the 6D representation
This uses Gram-Schmidt orthogonalization.
import torch
import torch.nn.functional as F

def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalization per Section B of [1].
    Args:
        d6: 6D rotation representation, of size (*, 6)
    Returns:
        batch of rotation matrices of size (*, 3, 3)
    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    # Remove the component of a2 along b1. Since b1 is normalized, b1 . b1 = 1,
    # so the usual denominator of the projection is omitted.
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    # The third basis vector is perpendicular to the plane spanned by the
    # first two, so it is their cross product.
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)
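To see the Gram-Schmidt step on concrete numbers, here is a plain-Python sketch of the same procedure for a single 6D vector (no batching, no PyTorch). The function `six_d_to_matrix` and its helpers are hypothetical names, and the input values are made up. It also illustrates the docstring's remark that the 6D representation is not unique: two different 6D inputs yield the same rotation matrix.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def normalize(v):
    n = math.sqrt(dot(v, v))
    return [a / n for a in v]

def cross(u, v):
    return [u[1] * v[2] - u[2] * v[1],
            u[2] * v[0] - u[0] * v[2],
            u[0] * v[1] - u[1] * v[0]]

def six_d_to_matrix(d6):
    """Gram-Schmidt on two 3-vectors, then a cross product for the third row."""
    a1, a2 = d6[:3], d6[3:]
    b1 = normalize(a1)
    proj = dot(b1, a2)  # length of a2 along b1 (b1 has unit length)
    b2 = normalize([a - proj * b for a, b in zip(a2, b1)])
    b3 = cross(b1, b2)
    return [b1, b2, b3]

def det3(m):
    """Determinant of a 3x3 matrix, computed as row0 . (row1 x row2)."""
    return dot(m[0], cross(m[1], m[2]))

# The first input is neither unit length nor orthogonal.
R_a = six_d_to_matrix([2.0, 0.0, 0.0, 1.0, 3.0, 0.0])
R_b = six_d_to_matrix([1.0, 0.0, 0.0, 0.0, 1.0, 0.0])

print(R_a)
print(R_b)
print("determinant:", det3(R_a))
```

Because any 6 numbers with two linearly independent halves map to a valid rotation, a network can output the 6D vector freely and leave the orthogonalization to this step.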
Full code
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Optional
import torch
import torch.nn.functional as F
"""
The transformation matrices returned from the functions in this file assume
the points on which the transformation will be applied are column vectors.
i.e. the R matrix is structured as

    R = [
            [Rxx, Rxy, Rxz],
            [Ryx, Ryy, Ryz],
            [Rzx, Rzy, Rzz],
        ]  # (3, 3)

This matrix can be applied to column vectors by post multiplication
by the points e.g.

    points = [[0], [1], [2]]  # (3 x 1) xyz coordinates of a point
    transformed_points = R * points

To apply the same matrix to points which are row vectors, the R matrix
can be transposed and pre multiplied by the points:

e.g.
    points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
    transformed_points = points * R.transpose(1, 0)
"""
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalization per Section B of [1].
    Args:
        d6: 6D rotation representation, of size (*, 6)
    Returns:
        batch of rotation matrices of size (*, 3, 3)
    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)


def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """
    Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
    by dropping the last row. Note that 6D representation is not unique.
    Args:
        matrix: batch of rotation matrices of size (*, 3, 3)
    Returns:
        6D rotation representation, of size (*, 6)
    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
import math
import numpy as np

# Rotation angle: 90 degrees, in radians
angle = 90 / 180 * math.pi
# Build the 3D rotation matrices for each axis
transform_matrix_z = np.array([
    [math.cos(angle), -math.sin(angle), 0],
    [math.sin(angle), math.cos(angle), 0],
    [0, 0, 1]])  # rotation about the z-axis
transform_matrix_x = np.array([
    [1, 0, 0],
    [0, math.cos(angle), -math.sin(angle)],
    [0, math.sin(angle), math.cos(angle)]])  # rotation about the x-axis
transform_matrix_y = np.array([
    [math.cos(angle), 0, math.sin(angle)],
    [0, 1, 0],
    [-math.sin(angle), 0, math.cos(angle)]])  # rotation about the y-axis
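import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the 3D projection on older Matplotlib versions
fig = plt.figure(figsize=(10, 8))
# Axes3D(fig) no longer attaches the axes to the figure in recent Matplotlib
# releases, which leaves the plot empty; add_subplot works in old and new versions.
ax = fig.add_subplot(projection='3d')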
ax.scatter3D(0,0,0,s=5)
ax.text3D(0,0,0,"(0,0,0)")
ax.text3D(5,0,0,"X")
ax.text3D(0,5,0,"Y")
ax.text3D(0,0,5,"Z")
ax.plot3D([0,5], [0,0], [0,0], marker='>', linewidth=2)
ax.plot3D([0,0], [0,5], [0,0], marker='>', linewidth=2)
ax.plot3D([0,0], [0,0], [0,5], marker='>', linewidth=2)
z = np.linspace(0, 4, 200)
x = z * np.sin(20 * z)
y = z * np.cos(20 * z)
A = np.array([list(x), list(y), list(z)])
# Choose the rotation axis
R = transform_matrix_x
R_ = np.array(rotation_6d_to_matrix(matrix_to_rotation_6d(torch.tensor(R))))
print("R", R)
print("R_", R_)
A1 = np.matmul(R_, A)
ax.plot3D(A[0], A[1], A[2], marker='o')
ax.plot3D(A1[0], A1[1], A1[2], marker='>', linewidth=2)
plt.rcParams.update({
    'font.family': 'Times New Roman',
    'font.weight': 'normal',
    'font.size': 20})
# Tweaking display region and labels
# ax.set_xlim(-5, 5)
# ax.set_ylim(-5, 5)
# ax.set_zlim(-5, 5)
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')
plt.show()