""" UOT (Unbalanced Optimal Transport) 位姿估计 Demo ================================================= UOTHead 使用 Sinkhorn 非平衡最优传输进行特征匹配和位姿估计。 流程: 1. Cosine Cost Matrix: C = 1 - cosine_sim(feat1, feat2) 2. Sinkhorn Unbalanced OT: 迭代求解运输计划 T 3. Point Projection: project_kpts = T @ kpts2 / sum(T) 4. Weighted SVD: 从匹配点对估计刚体变换 R|t 关键参数: - epsilon: 熵正则化(控制运输计划的平滑度) - gamma: 质量正则化(允许部分匹配) - sinkhorn_iter: 5次迭代 """ import torch import numpy as np import matplotlib.pyplot as plt import matplotlib matplotlib.use('Agg') import sys import os sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from uot import UOTHead, sinkhorn_unbalanced, compute_rigid_transform OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output') os.makedirs(OUTPUT_DIR, exist_ok=True) def visualize_cost_matrix(): """可视化代价矩阵""" print('\n--- 代价矩阵 (Cost Matrix) ---') torch.manual_seed(42) # 模拟query和positive的150个关键点特征 feat1 = torch.randn(2, 150, 128) # query feat2 = torch.randn(2, 150, 128) # positive # 让部分特征相似(模拟真实闭环场景) # 前100个特征点有对应关系 feat2[:, :100] = feat1[:, :100] + 0.1 * torch.randn(2, 100, 128) # 计算cosine cost matrix feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8) feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8) C = 1.0 - torch.bmm(feat1_norm, feat2_norm.transpose(1, 2)) C_np = C[0].numpy() fig, axes = plt.subplots(1, 3, figsize=(18, 5)) im0 = axes[0].imshow(C_np, cmap='YlOrRd') axes[0].set_title('Cost Matrix C = 1 - cos_sim') axes[0].set_xlabel('Positive Point j') axes[0].set_ylabel('Query Point i') plt.colorbar(im0, ax=axes[0]) # 缩放看前30个点(有对应关系的) im1 = axes[1].imshow(C_np[:30, :30], cmap='YlOrRd') axes[1].set_title('Cost Matrix (前30×30)\n有模拟对应关系') axes[1].set_xlabel('Positive Point j') axes[1].set_ylabel('Query Point i') plt.colorbar(im1, ax=axes[1]) # 对角线cost分布 vs 非对角线 diag_cost = np.diag(C_np) off_diag = C_np[~np.eye(150, dtype=bool)] axes[2].hist(diag_cost, bins=30, alpha=0.6, label=f'对角线(匹配点)\nmean={diag_cost.mean():.3f}', color='green') axes[2].hist(off_diag, bins=30, alpha=0.6, label=f'非对角线\nmean={off_diag.mean():.3f}', color='gray') axes[2].set_title('Cost分布: 匹配 vs 非匹配') axes[2].set_xlabel('Cost') axes[2].legend(fontsize=8) plt.suptitle('UOT 代价矩阵分析', fontsize=14, fontweight='bold') plt.tight_layout() path = os.path.join(OUTPUT_DIR, 'uot_cost_matrix.png') plt.savefig(path, dpi=150, bbox_inches='tight') plt.close() print(f' [保存] {path}') def visualize_sinkhorn(): """可视化Sinkhorn迭代过程""" print('\n--- Sinkhorn 迭代过程 ---') torch.manual_seed(42) # 构造有明显对应关系的特征 B, N, C = 1, 50, 128 feat1 = torch.randn(B, N, C) feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8) # feat2是feat1的扰动版本 feat2 = feat1 + 0.15 * torch.randn(B, N, C) feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8) C = 1.0 - torch.bmm(feat1_norm, feat2_norm.transpose(1, 2)) epsilon = torch.tensor([0.05]) gamma = torch.tensor([1.0]) # 逐步可视化Sinkhorn迭代 K = torch.exp(-C / epsilon) max_iter = 5 power = gamma / (gamma + epsilon + 1e-8) a = torch.ones((B, N, 1)) / N prob1 = torch.ones((B, N, 1)) / N prob2 = torch.ones((B, N, 1)) / N fig, axes = plt.subplots(2, 4, figsize=(18, 9)) # K (初始) K_np = K[0].numpy() im0 = axes[0, 0].imshow(K_np, cmap='YlOrRd') axes[0, 0].set_title('K (exp(-C/ε))\n迭代0') axes[0, 0].set_xlabel('Positive'); axes[0, 0].set_ylabel('Query') plt.colorbar(im0, ax=axes[0, 0]) for iteration in range(1, min(max_iter + 1, 7)): # Update b KTa = torch.bmm(K.transpose(1, 2), a) b = torch.pow(prob2 / (KTa + 1e-8), power) # Update a Kb = torch.bmm(K, b) a = torch.pow(prob1 / (Kb + 1e-8), power) T = torch.mul(torch.mul(a, K), b.transpose(1, 2)) T_np = T[0].numpy() ax = axes[(iteration) // 4, (iteration) % 4] im = ax.imshow(T_np, cmap='YlOrRd') ax.set_title(f'Transport Plan T\n迭代{iteration}') ax.set_xlabel('Positive'); ax.set_ylabel('Query') plt.colorbar(im, ax=ax) # 空余位置 for i in range(max_iter + 1, 8): axes[i // 4, i % 4].axis('off') plt.suptitle('Sinkhorn 非平衡最优传输迭代过程', fontsize=14, fontweight='bold') plt.tight_layout() path = os.path.join(OUTPUT_DIR, 'uot_sinkhorn_iterations.png') plt.savefig(path, dpi=150, bbox_inches='tight') plt.close() print(f' [保存] {path}') def test_rigid_transform(): """测试刚体变换估计""" print('\n--- Weighted SVD 刚体变换估计 ---') torch.manual_seed(42) B, N = 2, 150 # 真实变换 angle = torch.tensor(0.5) # ~28.6度 R_true = torch.tensor([ [torch.cos(angle), -torch.sin(angle), 0], [torch.sin(angle), torch.cos(angle), 0], [0, 0, 1] ]).unsqueeze(0).repeat(B, 1, 1) t_true = torch.tensor([2.0, -1.0, 0.1]).unsqueeze(0).unsqueeze(-1).repeat(B, 1, 1) # query点云 pts1 = torch.randn(B, N, 3) * 20 # positive点云 = R * query + t + noise pts2 = R_true @ pts1.transpose(1, 2) + t_true pts2 = pts2.transpose(1, 2) + 0.3 * torch.randn(B, N, 3) # 模拟transport weights(前80个点匹配好,后70个匹配差) weights = torch.ones(B, N) weights[:, 80:] = 0.1 # 降低后70个点的权重 # 估计变换 transform = compute_rigid_transform(pts1, pts2, weights) # 评估 R_est = transform[:, :3, :3] t_est = transform[:, :3, 3] # 旋转误差 R_err = R_est @ R_true.transpose(1, 2) trace = torch.diagonal(R_err, dim1=1, dim2=2).sum(dim=1) angle_err = torch.acos(torch.clamp((trace - 1) / 2, -1, 1)) * 180 / np.pi # 平移误差 t_err = (t_est - t_true.squeeze(-1)).norm(dim=1) print(f'旋转误差: {angle_err[0].item():.2f}°') print(f'平移误差: {t_err[0].item():.3f}m') # 可视化 fig, axes = plt.subplots(1, 2, figsize=(14, 5)) # 3D点云(XY平面投影) pts1_2d = pts1[0, :, :2].numpy() pts2_2d = pts2[0, :, :2].numpy() # 投影点 pts1_transformed = (R_est[0] @ pts1[0].T + t_est[0].unsqueeze(-1)).T[:, :2].numpy() axes[0].scatter(pts1_2d[:, 0], pts1_2d[:, 1], c='blue', s=10, alpha=0.6, label='Query') axes[0].scatter(pts2_2d[:, 0], pts2_2d[:, 1], c='red', s=10, alpha=0.6, label='Positive') for i in range(min(20, N)): if weights[0, i] > 0.5: axes[0].plot([pts1_2d[i, 0], pts2_2d[i, 0]], [pts1_2d[i, 1], pts2_2d[i, 1]], 'gray', alpha=0.3, linewidth=0.5) axes[0].set_title('匹配点对 (蓝色→红色)') axes[0].set_xlabel('X (m)'); axes[0].set_ylabel('Y (m)') axes[0].legend(fontsize=8) axes[0].set_aspect('equal') # 变换后 axes[1].scatter(pts1_transformed[:, 0], pts1_transformed[:, 1], c='blue', s=10, alpha=0.6, label='Query (变换后)') axes[1].scatter(pts2_2d[:, 0], pts2_2d[:, 1], c='red', s=10, alpha=0.6, label='Positive (目标)') axes[1].set_title(f'变换后对比\n旋转误差:{angle_err[0].item():.2f}° 平移误差:{t_err[0].item():.3f}m') axes[1].set_xlabel('X (m)'); axes[1].set_ylabel('Y (m)') axes[1].legend(fontsize=8) axes[1].set_aspect('equal') plt.suptitle('Weighted SVD 刚体变换估计', fontsize=14, fontweight='bold') plt.tight_layout() path = os.path.join(OUTPUT_DIR, 'uot_rigid_transform.png') plt.savefig(path, dpi=150, bbox_inches='tight') plt.close() print(f' [保存] {path}') def visualize_epsilon_gamma(): """可视化epsilon和gamma参数的影响""" print('\n--- epsilon/gamma 参数分析 ---') torch.manual_seed(42) N = 50 feat1 = torch.randn(1, N, 128) feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8) feat2 = feat1 + 0.2 * torch.randn(1, N, 128) feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8) epsilons = [0.01, 0.05, 0.1, 0.5] gammas = [0.1, 1.0, 10.0] fig, axes = plt.subplots(len(gammas), len(epsilons), figsize=(16, 12)) for gi, gamma in enumerate(gammas): for ei, eps in enumerate(epsilons): epsilon = torch.tensor([eps]) gam = torch.tensor([gamma]) T = sinkhorn_unbalanced( feat1_norm, feat2_norm, epsilon=epsilon, gamma=gam, max_iter=5, matrix='cosine' ) ax = axes[gi, ei] im = ax.imshow(T[0].numpy(), cmap='YlOrRd') ax.set_title(f'ε={eps}, γ={gamma}') ax.set_xlabel('Positive'); ax.set_ylabel('Query') plt.colorbar(im, ax=ax) plt.suptitle('epsilon (熵正则) 和 gamma (质量正则) 对 Transport Plan 的影响', fontsize=14, fontweight='bold') plt.tight_layout() path = os.path.join(OUTPUT_DIR, 'uot_epsilon_gamma.png') plt.savefig(path, dpi=150, bbox_inches='tight') plt.close() print(f' [保存] {path}') print(""" 参数解释: - epsilon (ε): 熵正则化强度 - 小ε → Transport Plan更稀疏(hard matching) - 大ε → Transport Plan更平滑(soft matching) - gamma (γ): 质量正则化强度 - 小γ → 允许部分匹配(质量可增减) - 大γ → 要求质量守恒(所有点必须匹配) """) def analyze_parameters(): """参数量分析""" print('\n--- 参数量分析 ---') uot = UOTHead(nb_iter=5, name='original') total = sum(p.numel() for p in uot.parameters()) print(f'总参数量: {total} (仅 epsilon, gamma 两个可学习标量)') for name, param in uot.named_parameters(): print(f' {name}: {param.data.item():.4f}') def main(): print('=' * 60) print('UOT (Unbalanced Optimal Transport) 位姿估计可视化') print('=' * 60) analyze_parameters() visualize_cost_matrix() visualize_sinkhorn() test_rigid_transform() visualize_epsilon_gamma() print('\n' + '=' * 60) print('结构总结:') print('=' * 60) print(""" UOTHead (非平衡最优传输位姿估计): ┌──────────────────────────────────────────────────────┐ │ 输入: feat1(B,150,128), feat2(B,150,128) │ │ kpts1(B,150,3), kpts2(B,150,3) │ │ │ │ 1. Cost Matrix: C = 1 - cosine_sim(feat1, feat2) │ │ → (B, 150, 150) │ │ │ │ 2. Sinkhorn Unbalanced OT (迭代5次): │ │ K = exp(-C / epsilon) │ │ for i in range(5): │ │ b = (prob2 / Kᵀa)^(γ/(γ+ε)) │ │ a = (prob1 / Kb)^(γ/(γ+ε)) │ │ T = a ⊙ K ⊙ bᵀ │ │ → (B, 150, 150) 运输计划 │ │ │ │ 3. 投影: project_kpts = T @ kpts2 / ΣT │ │ → (B, 150, 3) query匹配点在positive空间的投影坐标 │ │ │ │ 4. Weighted SVD 刚体变换: │ │ - 加权中心化 │ │ - SVD分解协方差 │ │ - 输出 R(3×3), t(3×1) │ │ → transformation: (B, 3, 4) │ └──────────────────────────────────────────────────────┘ 为什么用Unbalanced OT(非平衡最优传输)? - 标准OT要求两个点集大小相同且质量守恒 - 实际场景:部分关键点在另一帧中可能被遮挡 - Unbalanced OT允许部分匹配,更鲁棒 两个可学习参数: - epsilon (ε): 熵正则化,exp(ε)+0.03 - gamma (γ): 质量正则化,exp(γ) """) print(f'\n所有可视化结果保存在: {OUTPUT_DIR}') if __name__ == '__main__': main()