357 lines
13 KiB
Python
357 lines
13 KiB
Python
"""
|
||
UOT (Unbalanced Optimal Transport) 位姿估计 Demo
|
||
=================================================
|
||
UOTHead 使用 Sinkhorn 非平衡最优传输进行特征匹配和位姿估计。
|
||
|
||
流程:
|
||
1. Cosine Cost Matrix: C = 1 - cosine_sim(feat1, feat2)
|
||
2. Sinkhorn Unbalanced OT: 迭代求解运输计划 T
|
||
3. Point Projection: project_kpts = T @ kpts2 / sum(T)
|
||
4. Weighted SVD: 从匹配点对估计刚体变换 R|t
|
||
|
||
关键参数:
|
||
- epsilon: 熵正则化(控制运输计划的平滑度)
|
||
- gamma: 质量正则化(允许部分匹配)
|
||
- sinkhorn_iter: 5次迭代
|
||
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib
|
||
matplotlib.use('Agg')
|
||
|
||
import sys
|
||
import os
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from uot import UOTHead, sinkhorn_unbalanced, compute_rigid_transform
|
||
|
||
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
|
||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||
|
||
|
||
def visualize_cost_matrix():
|
||
"""可视化代价矩阵"""
|
||
print('\n--- 代价矩阵 (Cost Matrix) ---')
|
||
|
||
torch.manual_seed(42)
|
||
# 模拟query和positive的150个关键点特征
|
||
feat1 = torch.randn(2, 150, 128) # query
|
||
feat2 = torch.randn(2, 150, 128) # positive
|
||
|
||
# 让部分特征相似(模拟真实闭环场景)
|
||
# 前100个特征点有对应关系
|
||
feat2[:, :100] = feat1[:, :100] + 0.1 * torch.randn(2, 100, 128)
|
||
|
||
# 计算cosine cost matrix
|
||
feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
|
||
feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
|
||
C = 1.0 - torch.bmm(feat1_norm, feat2_norm.transpose(1, 2))
|
||
|
||
C_np = C[0].numpy()
|
||
|
||
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
|
||
|
||
im0 = axes[0].imshow(C_np, cmap='YlOrRd')
|
||
axes[0].set_title('Cost Matrix C = 1 - cos_sim')
|
||
axes[0].set_xlabel('Positive Point j')
|
||
axes[0].set_ylabel('Query Point i')
|
||
plt.colorbar(im0, ax=axes[0])
|
||
|
||
# 缩放看前30个点(有对应关系的)
|
||
im1 = axes[1].imshow(C_np[:30, :30], cmap='YlOrRd')
|
||
axes[1].set_title('Cost Matrix (前30×30)\n有模拟对应关系')
|
||
axes[1].set_xlabel('Positive Point j')
|
||
axes[1].set_ylabel('Query Point i')
|
||
plt.colorbar(im1, ax=axes[1])
|
||
|
||
# 对角线cost分布 vs 非对角线
|
||
diag_cost = np.diag(C_np)
|
||
off_diag = C_np[~np.eye(150, dtype=bool)]
|
||
|
||
axes[2].hist(diag_cost, bins=30, alpha=0.6, label=f'对角线(匹配点)\nmean={diag_cost.mean():.3f}',
|
||
color='green')
|
||
axes[2].hist(off_diag, bins=30, alpha=0.6, label=f'非对角线\nmean={off_diag.mean():.3f}',
|
||
color='gray')
|
||
axes[2].set_title('Cost分布: 匹配 vs 非匹配')
|
||
axes[2].set_xlabel('Cost')
|
||
axes[2].legend(fontsize=8)
|
||
|
||
plt.suptitle('UOT 代价矩阵分析', fontsize=14, fontweight='bold')
|
||
plt.tight_layout()
|
||
path = os.path.join(OUTPUT_DIR, 'uot_cost_matrix.png')
|
||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
print(f' [保存] {path}')
|
||
|
||
|
||
def visualize_sinkhorn():
|
||
"""可视化Sinkhorn迭代过程"""
|
||
print('\n--- Sinkhorn 迭代过程 ---')
|
||
|
||
torch.manual_seed(42)
|
||
# 构造有明显对应关系的特征
|
||
B, N, C = 1, 50, 128
|
||
feat1 = torch.randn(B, N, C)
|
||
feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
|
||
|
||
# feat2是feat1的扰动版本
|
||
feat2 = feat1 + 0.15 * torch.randn(B, N, C)
|
||
feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
|
||
|
||
C = 1.0 - torch.bmm(feat1_norm, feat2_norm.transpose(1, 2))
|
||
|
||
epsilon = torch.tensor([0.05])
|
||
gamma = torch.tensor([1.0])
|
||
|
||
# 逐步可视化Sinkhorn迭代
|
||
K = torch.exp(-C / epsilon)
|
||
max_iter = 5
|
||
power = gamma / (gamma + epsilon + 1e-8)
|
||
|
||
a = torch.ones((B, N, 1)) / N
|
||
prob1 = torch.ones((B, N, 1)) / N
|
||
prob2 = torch.ones((B, N, 1)) / N
|
||
|
||
fig, axes = plt.subplots(2, 4, figsize=(18, 9))
|
||
|
||
# K (初始)
|
||
K_np = K[0].numpy()
|
||
im0 = axes[0, 0].imshow(K_np, cmap='YlOrRd')
|
||
axes[0, 0].set_title('K (exp(-C/ε))\n迭代0')
|
||
axes[0, 0].set_xlabel('Positive'); axes[0, 0].set_ylabel('Query')
|
||
plt.colorbar(im0, ax=axes[0, 0])
|
||
|
||
for iteration in range(1, min(max_iter + 1, 7)):
|
||
# Update b
|
||
KTa = torch.bmm(K.transpose(1, 2), a)
|
||
b = torch.pow(prob2 / (KTa + 1e-8), power)
|
||
# Update a
|
||
Kb = torch.bmm(K, b)
|
||
a = torch.pow(prob1 / (Kb + 1e-8), power)
|
||
|
||
T = torch.mul(torch.mul(a, K), b.transpose(1, 2))
|
||
T_np = T[0].numpy()
|
||
|
||
ax = axes[(iteration) // 4, (iteration) % 4]
|
||
im = ax.imshow(T_np, cmap='YlOrRd')
|
||
ax.set_title(f'Transport Plan T\n迭代{iteration}')
|
||
ax.set_xlabel('Positive'); ax.set_ylabel('Query')
|
||
plt.colorbar(im, ax=ax)
|
||
|
||
# 空余位置
|
||
for i in range(max_iter + 1, 8):
|
||
axes[i // 4, i % 4].axis('off')
|
||
|
||
plt.suptitle('Sinkhorn 非平衡最优传输迭代过程', fontsize=14, fontweight='bold')
|
||
plt.tight_layout()
|
||
path = os.path.join(OUTPUT_DIR, 'uot_sinkhorn_iterations.png')
|
||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
print(f' [保存] {path}')
|
||
|
||
|
||
def test_rigid_transform():
|
||
"""测试刚体变换估计"""
|
||
print('\n--- Weighted SVD 刚体变换估计 ---')
|
||
|
||
torch.manual_seed(42)
|
||
B, N = 2, 150
|
||
|
||
# 真实变换
|
||
angle = torch.tensor(0.5) # ~28.6度
|
||
R_true = torch.tensor([
|
||
[torch.cos(angle), -torch.sin(angle), 0],
|
||
[torch.sin(angle), torch.cos(angle), 0],
|
||
[0, 0, 1]
|
||
]).unsqueeze(0).repeat(B, 1, 1)
|
||
t_true = torch.tensor([2.0, -1.0, 0.1]).unsqueeze(0).unsqueeze(-1).repeat(B, 1, 1)
|
||
|
||
# query点云
|
||
pts1 = torch.randn(B, N, 3) * 20
|
||
|
||
# positive点云 = R * query + t + noise
|
||
pts2 = R_true @ pts1.transpose(1, 2) + t_true
|
||
pts2 = pts2.transpose(1, 2) + 0.3 * torch.randn(B, N, 3)
|
||
|
||
# 模拟transport weights(前80个点匹配好,后70个匹配差)
|
||
weights = torch.ones(B, N)
|
||
weights[:, 80:] = 0.1 # 降低后70个点的权重
|
||
|
||
# 估计变换
|
||
transform = compute_rigid_transform(pts1, pts2, weights)
|
||
|
||
# 评估
|
||
R_est = transform[:, :3, :3]
|
||
t_est = transform[:, :3, 3]
|
||
|
||
# 旋转误差
|
||
R_err = R_est @ R_true.transpose(1, 2)
|
||
trace = torch.diagonal(R_err, dim1=1, dim2=2).sum(dim=1)
|
||
angle_err = torch.acos(torch.clamp((trace - 1) / 2, -1, 1)) * 180 / np.pi
|
||
|
||
# 平移误差
|
||
t_err = (t_est - t_true.squeeze(-1)).norm(dim=1)
|
||
|
||
print(f'旋转误差: {angle_err[0].item():.2f}°')
|
||
print(f'平移误差: {t_err[0].item():.3f}m')
|
||
|
||
# 可视化
|
||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||
|
||
# 3D点云(XY平面投影)
|
||
pts1_2d = pts1[0, :, :2].numpy()
|
||
pts2_2d = pts2[0, :, :2].numpy()
|
||
|
||
# 投影点
|
||
pts1_transformed = (R_est[0] @ pts1[0].T + t_est[0].unsqueeze(-1)).T[:, :2].numpy()
|
||
|
||
axes[0].scatter(pts1_2d[:, 0], pts1_2d[:, 1], c='blue', s=10, alpha=0.6, label='Query')
|
||
axes[0].scatter(pts2_2d[:, 0], pts2_2d[:, 1], c='red', s=10, alpha=0.6, label='Positive')
|
||
for i in range(min(20, N)):
|
||
if weights[0, i] > 0.5:
|
||
axes[0].plot([pts1_2d[i, 0], pts2_2d[i, 0]],
|
||
[pts1_2d[i, 1], pts2_2d[i, 1]],
|
||
'gray', alpha=0.3, linewidth=0.5)
|
||
axes[0].set_title('匹配点对 (蓝色→红色)')
|
||
axes[0].set_xlabel('X (m)'); axes[0].set_ylabel('Y (m)')
|
||
axes[0].legend(fontsize=8)
|
||
axes[0].set_aspect('equal')
|
||
|
||
# 变换后
|
||
axes[1].scatter(pts1_transformed[:, 0], pts1_transformed[:, 1],
|
||
c='blue', s=10, alpha=0.6, label='Query (变换后)')
|
||
axes[1].scatter(pts2_2d[:, 0], pts2_2d[:, 1],
|
||
c='red', s=10, alpha=0.6, label='Positive (目标)')
|
||
axes[1].set_title(f'变换后对比\n旋转误差:{angle_err[0].item():.2f}° 平移误差:{t_err[0].item():.3f}m')
|
||
axes[1].set_xlabel('X (m)'); axes[1].set_ylabel('Y (m)')
|
||
axes[1].legend(fontsize=8)
|
||
axes[1].set_aspect('equal')
|
||
|
||
plt.suptitle('Weighted SVD 刚体变换估计', fontsize=14, fontweight='bold')
|
||
plt.tight_layout()
|
||
path = os.path.join(OUTPUT_DIR, 'uot_rigid_transform.png')
|
||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
print(f' [保存] {path}')
|
||
|
||
|
||
def visualize_epsilon_gamma():
|
||
"""可视化epsilon和gamma参数的影响"""
|
||
print('\n--- epsilon/gamma 参数分析 ---')
|
||
|
||
torch.manual_seed(42)
|
||
N = 50
|
||
feat1 = torch.randn(1, N, 128)
|
||
feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
|
||
feat2 = feat1 + 0.2 * torch.randn(1, N, 128)
|
||
feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
|
||
|
||
epsilons = [0.01, 0.05, 0.1, 0.5]
|
||
gammas = [0.1, 1.0, 10.0]
|
||
|
||
fig, axes = plt.subplots(len(gammas), len(epsilons), figsize=(16, 12))
|
||
|
||
for gi, gamma in enumerate(gammas):
|
||
for ei, eps in enumerate(epsilons):
|
||
epsilon = torch.tensor([eps])
|
||
gam = torch.tensor([gamma])
|
||
T = sinkhorn_unbalanced(
|
||
feat1_norm, feat2_norm,
|
||
epsilon=epsilon, gamma=gam,
|
||
max_iter=5, matrix='cosine'
|
||
)
|
||
ax = axes[gi, ei]
|
||
im = ax.imshow(T[0].numpy(), cmap='YlOrRd')
|
||
ax.set_title(f'ε={eps}, γ={gamma}')
|
||
ax.set_xlabel('Positive'); ax.set_ylabel('Query')
|
||
plt.colorbar(im, ax=ax)
|
||
|
||
plt.suptitle('epsilon (熵正则) 和 gamma (质量正则) 对 Transport Plan 的影响',
|
||
fontsize=14, fontweight='bold')
|
||
plt.tight_layout()
|
||
path = os.path.join(OUTPUT_DIR, 'uot_epsilon_gamma.png')
|
||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
print(f' [保存] {path}')
|
||
|
||
print("""
|
||
参数解释:
|
||
- epsilon (ε): 熵正则化强度
|
||
- 小ε → Transport Plan更稀疏(hard matching)
|
||
- 大ε → Transport Plan更平滑(soft matching)
|
||
- gamma (γ): 质量正则化强度
|
||
- 小γ → 允许部分匹配(质量可增减)
|
||
- 大γ → 要求质量守恒(所有点必须匹配)
|
||
""")
|
||
|
||
|
||
def analyze_parameters():
|
||
"""参数量分析"""
|
||
print('\n--- 参数量分析 ---')
|
||
uot = UOTHead(nb_iter=5, name='original')
|
||
total = sum(p.numel() for p in uot.parameters())
|
||
print(f'总参数量: {total} (仅 epsilon, gamma 两个可学习标量)')
|
||
for name, param in uot.named_parameters():
|
||
print(f' {name}: {param.data.item():.4f}')
|
||
|
||
|
||
def main():
|
||
print('=' * 60)
|
||
print('UOT (Unbalanced Optimal Transport) 位姿估计可视化')
|
||
print('=' * 60)
|
||
|
||
analyze_parameters()
|
||
visualize_cost_matrix()
|
||
visualize_sinkhorn()
|
||
test_rigid_transform()
|
||
visualize_epsilon_gamma()
|
||
|
||
print('\n' + '=' * 60)
|
||
print('结构总结:')
|
||
print('=' * 60)
|
||
print("""
|
||
UOTHead (非平衡最优传输位姿估计):
|
||
|
||
┌──────────────────────────────────────────────────────┐
|
||
│ 输入: feat1(B,150,128), feat2(B,150,128) │
|
||
│ kpts1(B,150,3), kpts2(B,150,3) │
|
||
│ │
|
||
│ 1. Cost Matrix: C = 1 - cosine_sim(feat1, feat2) │
|
||
│ → (B, 150, 150) │
|
||
│ │
|
||
│ 2. Sinkhorn Unbalanced OT (迭代5次): │
|
||
│ K = exp(-C / epsilon) │
|
||
│ for i in range(5): │
|
||
│ b = (prob2 / Kᵀa)^(γ/(γ+ε)) │
|
||
│ a = (prob1 / Kb)^(γ/(γ+ε)) │
|
||
│ T = a ⊙ K ⊙ bᵀ │
|
||
│ → (B, 150, 150) 运输计划 │
|
||
│ │
|
||
│ 3. 投影: project_kpts = T @ kpts2 / ΣT │
|
||
│ → (B, 150, 3) query匹配点在positive空间的投影坐标 │
|
||
│ │
|
||
│ 4. Weighted SVD 刚体变换: │
|
||
│ - 加权中心化 │
|
||
│ - SVD分解协方差 │
|
||
│ - 输出 R(3×3), t(3×1) │
|
||
│ → transformation: (B, 3, 4) │
|
||
└──────────────────────────────────────────────────────┘
|
||
|
||
为什么用Unbalanced OT(非平衡最优传输)?
|
||
- 标准OT要求两个点集大小相同且质量守恒
|
||
- 实际场景:部分关键点在另一帧中可能被遮挡
|
||
- Unbalanced OT允许部分匹配,更鲁棒
|
||
|
||
两个可学习参数:
|
||
- epsilon (ε): 熵正则化,exp(ε)+0.03
|
||
- gamma (γ): 质量正则化,exp(γ)
|
||
""")
|
||
|
||
print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|