Files
fusion_LCD/network_learning/06_uot_demo.py
2026-05-09 17:03:40 +08:00

357 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
UOT (Unbalanced Optimal Transport) 位姿估计 Demo
=================================================
UOTHead 使用 Sinkhorn 非平衡最优传输进行特征匹配和位姿估计。
流程:
1. Cosine Cost Matrix: C = 1 - cosine_sim(feat1, feat2)
2. Sinkhorn Unbalanced OT: 迭代求解运输计划 T
3. Point Projection: project_kpts = T @ kpts2 / sum(T)
4. Weighted SVD: 从匹配点对估计刚体变换 R|t
关键参数:
- epsilon: 熵正则化(控制运输计划的平滑度)
- gamma: 质量正则化(允许部分匹配)
- sinkhorn_iter: 5次迭代
"""
import os
import sys

import matplotlib

# Select the non-interactive Agg backend BEFORE pyplot is imported. Older
# Matplotlib releases ignore (and warn about) a `use()` call made after pyplot
# has already picked a backend, which breaks the script on headless machines.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import torch

# Put the parent of this script's directory on sys.path so that the
# project-local `uot` module can be imported when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from uot import UOTHead, sinkhorn_unbalanced, compute_rigid_transform

# Every figure is written next to this script, in an `output` sub-directory.
# The path is anchored to the absolute script location (same as the sys.path
# line above) so it does not depend on how `__file__` was spelled at launch.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)
def visualize_cost_matrix():
    """Plot the cosine cost matrix between two synthetic keypoint feature sets.

    Draws random query/positive features of shape (2, 150, 128), replaces the
    first 100 positive rows with noisy copies of the corresponding query rows,
    and computes ``C = 1 - cosine_similarity`` for every query/positive pair.
    Three panels are rendered for batch element 0: the full cost matrix, a
    30x30 zoom on its top-left corner, and density histograms comparing the
    cost of the simulated matches with the off-diagonal costs.

    The figure is saved as ``uot_cost_matrix.png`` inside ``OUTPUT_DIR``.
    """
    print('\n--- 代价矩阵 (Cost Matrix) ---')
    torch.manual_seed(42)

    batch_size, num_points, feat_dim = 2, 150, 128
    # Only the first `num_matched` points get a simulated correspondence.
    num_matched = 100

    # Simulated keypoint features for the query and the positive frame.
    feat1 = torch.randn(batch_size, num_points, feat_dim)  # query
    feat2 = torch.randn(batch_size, num_points, feat_dim)  # positive
    # Make part of the features similar (mimics a real loop-closure pair).
    feat2[:, :num_matched] = (
        feat1[:, :num_matched]
        + 0.1 * torch.randn(batch_size, num_matched, feat_dim)
    )

    # Cosine cost matrix: L2-normalise the features, then 1 - dot product.
    feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
    feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
    cost = 1.0 - torch.bmm(feat1_norm, feat2_norm.transpose(1, 2))
    cost_np = cost[0].numpy()

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    im0 = axes[0].imshow(cost_np, cmap='YlOrRd')
    axes[0].set_title('Cost Matrix C = 1 - cos_sim')
    axes[0].set_xlabel('Positive Point j')
    axes[0].set_ylabel('Query Point i')
    plt.colorbar(im0, ax=axes[0])

    # Zoom on the first 30 points, all of which have a simulated match.
    im1 = axes[1].imshow(cost_np[:30, :30], cmap='YlOrRd')
    axes[1].set_title('Cost Matrix (前30×30)\n有模拟对应关系')
    axes[1].set_xlabel('Positive Point j')
    axes[1].set_ylabel('Query Point i')
    plt.colorbar(im1, ax=axes[1])

    # Matched vs. off-diagonal cost distribution. Only the first `num_matched`
    # diagonal entries are real matches; the remaining diagonal entries pair
    # unrelated random features and must not be counted as "matched".
    diag_cost = np.diag(cost_np)[:num_matched]
    off_diag = cost_np[~np.eye(num_points, dtype=bool)]
    # density=True: the two groups differ in size by two orders of magnitude,
    # so raw counts would make the matched histogram invisible.
    axes[2].hist(diag_cost, bins=30, alpha=0.6, density=True,
                 label=f'对角线(匹配点)\nmean={diag_cost.mean():.3f}',
                 color='green')
    axes[2].hist(off_diag, bins=30, alpha=0.6, density=True,
                 label=f'非对角线\nmean={off_diag.mean():.3f}',
                 color='gray')
    axes[2].set_title('Cost分布: 匹配 vs 非匹配')
    axes[2].set_xlabel('Cost')
    axes[2].legend(fontsize=8)

    plt.suptitle('UOT 代价矩阵分析', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'uot_cost_matrix.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')
def visualize_sinkhorn():
    """Plot the transport plan after each unbalanced Sinkhorn iteration.

    Builds 50 random 128-d features and a perturbed copy of them, forms the
    cosine cost matrix and the kernel ``K = exp(-C / epsilon)``, then runs the
    scaling updates inline (rather than through ``sinkhorn_unbalanced``) so the
    intermediate plans can be drawn:

        b = (prob2 / (K^T a)) ** power
        a = (prob1 / (K b)) ** power
        T = a * K * b^T

    with ``power = gamma / (gamma + epsilon)``. Panel 0 shows ``K``; the next
    panels show ``T`` after iterations 1, 2, ...; unused panels are hidden.

    The figure is saved as ``uot_sinkhorn_iterations.png`` in ``OUTPUT_DIR``.
    """
    print('\n--- Sinkhorn 迭代过程 ---')
    torch.manual_seed(42)

    # Features with an obvious one-to-one correspondence.
    batch_size, num_points, feat_dim = 1, 50, 128
    feat1 = torch.randn(batch_size, num_points, feat_dim)
    feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
    # feat2 is a perturbed version of feat1.
    feat2 = feat1 + 0.15 * torch.randn(batch_size, num_points, feat_dim)
    feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
    # Kept under its own name so it no longer overwrites the feature dimension.
    cost = 1.0 - torch.bmm(feat1_norm, feat2_norm.transpose(1, 2))

    epsilon = torch.tensor([0.05])
    gamma = torch.tensor([1.0])

    kernel = torch.exp(-cost / epsilon)
    max_iter = 5
    power = gamma / (gamma + epsilon + 1e-8)

    # Uniform initial scaling vector and uniform marginals.
    a = torch.ones((batch_size, num_points, 1)) / num_points
    prob1 = torch.ones((batch_size, num_points, 1)) / num_points
    prob2 = torch.ones((batch_size, num_points, 1)) / num_points

    fig, axes = plt.subplots(2, 4, figsize=(18, 9))
    panels = axes.ravel()  # row-major, so panels[i] is axes[i // 4, i % 4]

    # Panel 0: the kernel before any scaling.
    im0 = panels[0].imshow(kernel[0].numpy(), cmap='YlOrRd')
    panels[0].set_title('K (exp(-C/ε))\n迭代0')
    panels[0].set_xlabel('Positive')
    panels[0].set_ylabel('Query')
    plt.colorbar(im0, ax=panels[0])

    # One panel is taken by K, so at most len(panels) - 1 iterations can be
    # drawn. Deriving the cap from the grid keeps the drawing loop and the
    # blank-out loop below in agreement for any max_iter.
    shown_iters = min(max_iter, len(panels) - 1)
    for iteration in range(1, shown_iters + 1):
        # Update b
        kt_a = torch.bmm(kernel.transpose(1, 2), a)
        b = torch.pow(prob2 / (kt_a + 1e-8), power)
        # Update a
        k_b = torch.bmm(kernel, b)
        a = torch.pow(prob1 / (k_b + 1e-8), power)

        plan = torch.mul(torch.mul(a, kernel), b.transpose(1, 2))

        ax = panels[iteration]
        im = ax.imshow(plan[0].numpy(), cmap='YlOrRd')
        ax.set_title(f'Transport Plan T\n迭代{iteration}')
        ax.set_xlabel('Positive')
        ax.set_ylabel('Query')
        plt.colorbar(im, ax=ax)

    # Hide the panels that were not used.
    for unused_ax in panels[shown_iters + 1:]:
        unused_ax.axis('off')

    plt.suptitle('Sinkhorn 非平衡最优传输迭代过程', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'uot_sinkhorn_iterations.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')
def test_rigid_transform():
    """Estimate a known rigid transform with the weighted-SVD helper and plot it.

    A ground-truth rotation of 0.5 rad about the z axis plus a translation of
    (2.0, -1.0, 0.1) is applied to a random query cloud; Gaussian noise is
    added to obtain the positive cloud. ``compute_rigid_transform`` is called
    with per-point weights, and the rotation/translation errors of batch
    element 0 are printed and shown in the figure title.

    The result of ``compute_rigid_transform`` is indexed as ``[:, :3, :3]``
    for the rotation and ``[:, :3, 3]`` for the translation (helper is defined
    elsewhere; its exact return shape is not visible here).

    The figure is saved as ``uot_rigid_transform.png`` in ``OUTPUT_DIR``.
    """
    print('\n--- Weighted SVD 刚体变换估计 ---')
    torch.manual_seed(42)
    B, N = 2, 150

    # Ground-truth transform: rotation about z (~28.6 degrees) + translation.
    angle = torch.tensor(0.5)
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    rot_single = torch.tensor([
        [cos_a, -sin_a, 0],
        [sin_a, cos_a, 0],
        [0, 0, 1]
    ])
    R_true = rot_single.unsqueeze(0).repeat(B, 1, 1)
    trans_single = torch.tensor([2.0, -1.0, 0.1])
    t_true = trans_single.unsqueeze(0).unsqueeze(-1).repeat(B, 1, 1)

    # Query cloud, then positive cloud = R * query + t + noise.
    pts1 = torch.randn(B, N, 3) * 20
    moved = R_true @ pts1.transpose(1, 2) + t_true
    pts2 = moved.transpose(1, 2) + 0.3 * torch.randn(B, N, 3)

    # Simulated transport weights: the first 80 points keep weight 1.0,
    # the remaining 70 are down-weighted to 0.1.
    weights = torch.ones(B, N)
    weights[:, 80:] = 0.1

    # Estimate the transform and split it into rotation and translation.
    transform = compute_rigid_transform(pts1, pts2, weights)
    R_est = transform[:, :3, :3]
    t_est = transform[:, :3, 3]

    # Rotation error: angle of the residual rotation R_est * R_true^T, degrees.
    residual = R_est @ R_true.transpose(1, 2)
    trace = torch.diagonal(residual, dim1=1, dim2=2).sum(dim=1)
    cos_err = torch.clamp((trace - 1) / 2, -1, 1)
    angle_err = torch.acos(cos_err) * 180 / np.pi

    # Translation error: Euclidean distance between the two translations.
    t_err = (t_est - t_true.squeeze(-1)).norm(dim=1)

    rot_err_deg = angle_err[0].item()
    trans_err_m = t_err[0].item()
    print(f'旋转误差: {rot_err_deg:.2f}°')
    print(f'平移误差: {trans_err_m:.3f}m')

    # Visualisation: XY-plane projection of batch element 0.
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    before_ax, after_ax = axes[0], axes[1]

    query_xy = pts1[0, :, :2].numpy()
    target_xy = pts2[0, :, :2].numpy()
    # Query points mapped through the estimated transform.
    mapped = R_est[0] @ pts1[0].T + t_est[0].unsqueeze(-1)
    mapped_xy = mapped.T[:, :2].numpy()

    before_ax.scatter(query_xy[:, 0], query_xy[:, 1],
                      c='blue', s=10, alpha=0.6, label='Query')
    before_ax.scatter(target_xy[:, 0], target_xy[:, 1],
                      c='red', s=10, alpha=0.6, label='Positive')
    # Connect at most the first 20 pairs, skipping the down-weighted ones.
    for idx in range(min(20, N)):
        if not weights[0, idx] > 0.5:
            continue
        before_ax.plot([query_xy[idx, 0], target_xy[idx, 0]],
                       [query_xy[idx, 1], target_xy[idx, 1]],
                       'gray', alpha=0.3, linewidth=0.5)
    before_ax.set_title('匹配点对 (蓝色→红色)')
    before_ax.set_xlabel('X (m)')
    before_ax.set_ylabel('Y (m)')
    before_ax.legend(fontsize=8)
    before_ax.set_aspect('equal')

    # After applying the estimated transform.
    after_ax.scatter(mapped_xy[:, 0], mapped_xy[:, 1],
                     c='blue', s=10, alpha=0.6, label='Query (变换后)')
    after_ax.scatter(target_xy[:, 0], target_xy[:, 1],
                     c='red', s=10, alpha=0.6, label='Positive (目标)')
    after_ax.set_title(f'变换后对比\n旋转误差:{rot_err_deg:.2f}° 平移误差:{trans_err_m:.3f}m')
    after_ax.set_xlabel('X (m)')
    after_ax.set_ylabel('Y (m)')
    after_ax.legend(fontsize=8)
    after_ax.set_aspect('equal')

    plt.suptitle('Weighted SVD 刚体变换估计', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'uot_rigid_transform.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')
def visualize_epsilon_gamma():
    """Plot transport plans over a grid of epsilon and gamma values.

    Builds 50 random 128-d features and a perturbed copy, L2-normalises both,
    and calls ``sinkhorn_unbalanced`` (5 iterations, ``matrix='cosine'``) for
    every combination of four epsilon values and three gamma values. Each
    resulting plan (batch element 0) is drawn in a grid with one row per
    gamma and one column per epsilon, and a short parameter explanation is
    printed afterwards.

    The figure is saved as ``uot_epsilon_gamma.png`` in ``OUTPUT_DIR``.
    """
    print('\n--- epsilon/gamma 参数分析 ---')
    torch.manual_seed(42)

    N = 50
    feat1 = torch.randn(1, N, 128)
    feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
    feat2 = feat1 + 0.2 * torch.randn(1, N, 128)
    feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)

    epsilons = [0.01, 0.05, 0.1, 0.5]
    gammas = [0.1, 1.0, 10.0]

    fig, axes = plt.subplots(len(gammas), len(epsilons), figsize=(16, 12))

    # Rows follow gamma, columns follow epsilon.
    for row_axes, gamma in zip(axes, gammas):
        gamma_t = torch.tensor([gamma])
        for ax, eps in zip(row_axes, epsilons):
            plan = sinkhorn_unbalanced(
                feat1_norm, feat2_norm,
                epsilon=torch.tensor([eps]), gamma=gamma_t,
                max_iter=5, matrix='cosine'
            )
            im = ax.imshow(plan[0].numpy(), cmap='YlOrRd')
            ax.set_title(f'ε={eps}, γ={gamma}')
            ax.set_xlabel('Positive')
            ax.set_ylabel('Query')
            plt.colorbar(im, ax=ax)

    plt.suptitle('epsilon (熵正则) 和 gamma (质量正则) 对 Transport Plan 的影响',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'uot_epsilon_gamma.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')

    # Runtime text: kept flush-left so the printed output is unchanged.
    explanation = """
参数解释:
- epsilon (ε): 熵正则化强度
- 小ε → Transport Plan更稀疏hard matching
- 大ε → Transport Plan更平滑soft matching
- gamma (γ): 质量正则化强度
- 小γ → 允许部分匹配(质量可增减)
- 大γ → 要求质量守恒(所有点必须匹配)
"""
    print(explanation)
def analyze_parameters():
    """Print the parameter count and each named parameter value of a UOTHead.

    Instantiates ``UOTHead(nb_iter=5, name='original')``, sums ``numel()``
    over its parameters, and prints every named parameter as a scalar.
    """
    print('\n--- 参数量分析 ---')
    head = UOTHead(nb_iter=5, name='original')

    total = 0
    for tensor in head.parameters():
        total += tensor.numel()
    print(f'总参数量: {total} (仅 epsilon, gamma 两个可学习标量)')

    # `.item()` requires every parameter to hold exactly one element.
    for label, tensor in head.named_parameters():
        value = tensor.data.item()
        print(f' {label}: {value:.4f}')
def main():
    """Run every demo step in order, then print a structural summary.

    The steps are the parameter report, the cost-matrix plot, the Sinkhorn
    iteration plot, the rigid-transform check, and the epsilon/gamma grid.
    Finally the directory holding the generated figures is printed.
    """
    banner = '=' * 60
    print(banner)
    print('UOT (Unbalanced Optimal Transport) 位姿估计可视化')
    print(banner)

    steps = (
        analyze_parameters,
        visualize_cost_matrix,
        visualize_sinkhorn,
        test_rigid_transform,
        visualize_epsilon_gamma,
    )
    for step in steps:
        step()

    print('\n' + banner)
    print('结构总结:')
    print(banner)

    # Runtime text: kept flush-left so the printed output is unchanged.
    summary = """
UOTHead (非平衡最优传输位姿估计):
┌──────────────────────────────────────────────────────┐
│ 输入: feat1(B,150,128), feat2(B,150,128) │
│ kpts1(B,150,3), kpts2(B,150,3) │
│ │
│ 1. Cost Matrix: C = 1 - cosine_sim(feat1, feat2) │
│ → (B, 150, 150) │
│ │
│ 2. Sinkhorn Unbalanced OT (迭代5次): │
│ K = exp(-C / epsilon) │
│ for i in range(5): │
│ b = (prob2 / Kᵀa)^(γ/(γ+ε)) │
│ a = (prob1 / Kb)^(γ/(γ+ε)) │
│ T = a ⊙ K ⊙ bᵀ │
│ → (B, 150, 150) 运输计划 │
│ │
│ 3. 投影: project_kpts = T @ kpts2 / ΣT │
│ → (B, 150, 3) query匹配点在positive空间的投影坐标 │
│ │
│ 4. Weighted SVD 刚体变换: │
│ - 加权中心化 │
│ - SVD分解协方差 │
│ - 输出 R(3×3), t(3×1) │
│ → transformation: (B, 3, 4) │
└──────────────────────────────────────────────────────┘
为什么用Unbalanced OT非平衡最优传输
- 标准OT要求两个点集大小相同且质量守恒
- 实际场景:部分关键点在另一帧中可能被遮挡
- Unbalanced OT允许部分匹配更鲁棒
两个可学习参数:
- epsilon (ε): 熵正则化exp(ε)+0.03
- gamma (γ): 质量正则化exp(γ)
"""
    print(summary)
    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
# Run the full demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()