网络测试和学习demo
This commit is contained in:
356
network_learning/06_uot_demo.py
Normal file
356
network_learning/06_uot_demo.py
Normal file
@@ -0,0 +1,356 @@
|
||||
"""
|
||||
UOT (Unbalanced Optimal Transport) 位姿估计 Demo
|
||||
=================================================
|
||||
UOTHead 使用 Sinkhorn 非平衡最优传输进行特征匹配和位姿估计。
|
||||
|
||||
流程:
|
||||
1. Cosine Cost Matrix: C = 1 - cosine_sim(feat1, feat2)
|
||||
2. Sinkhorn Unbalanced OT: 迭代求解运输计划 T
|
||||
3. Point Projection: project_kpts = T @ kpts2 / sum(T)
|
||||
4. Weighted SVD: 从匹配点对估计刚体变换 R|t
|
||||
|
||||
关键参数:
|
||||
- epsilon: 熵正则化(控制运输计划的平滑度)
|
||||
- gamma: 质量正则化(允许部分匹配)
|
||||
- sinkhorn_iter: 5次迭代
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from uot import UOTHead, sinkhorn_unbalanced, compute_rigid_transform
|
||||
|
||||
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def visualize_cost_matrix():
|
||||
"""可视化代价矩阵"""
|
||||
print('\n--- 代价矩阵 (Cost Matrix) ---')
|
||||
|
||||
torch.manual_seed(42)
|
||||
# 模拟query和positive的150个关键点特征
|
||||
feat1 = torch.randn(2, 150, 128) # query
|
||||
feat2 = torch.randn(2, 150, 128) # positive
|
||||
|
||||
# 让部分特征相似(模拟真实闭环场景)
|
||||
# 前100个特征点有对应关系
|
||||
feat2[:, :100] = feat1[:, :100] + 0.1 * torch.randn(2, 100, 128)
|
||||
|
||||
# 计算cosine cost matrix
|
||||
feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
|
||||
feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
|
||||
C = 1.0 - torch.bmm(feat1_norm, feat2_norm.transpose(1, 2))
|
||||
|
||||
C_np = C[0].numpy()
|
||||
|
||||
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
|
||||
|
||||
im0 = axes[0].imshow(C_np, cmap='YlOrRd')
|
||||
axes[0].set_title('Cost Matrix C = 1 - cos_sim')
|
||||
axes[0].set_xlabel('Positive Point j')
|
||||
axes[0].set_ylabel('Query Point i')
|
||||
plt.colorbar(im0, ax=axes[0])
|
||||
|
||||
# 缩放看前30个点(有对应关系的)
|
||||
im1 = axes[1].imshow(C_np[:30, :30], cmap='YlOrRd')
|
||||
axes[1].set_title('Cost Matrix (前30×30)\n有模拟对应关系')
|
||||
axes[1].set_xlabel('Positive Point j')
|
||||
axes[1].set_ylabel('Query Point i')
|
||||
plt.colorbar(im1, ax=axes[1])
|
||||
|
||||
# 对角线cost分布 vs 非对角线
|
||||
diag_cost = np.diag(C_np)
|
||||
off_diag = C_np[~np.eye(150, dtype=bool)]
|
||||
|
||||
axes[2].hist(diag_cost, bins=30, alpha=0.6, label=f'对角线(匹配点)\nmean={diag_cost.mean():.3f}',
|
||||
color='green')
|
||||
axes[2].hist(off_diag, bins=30, alpha=0.6, label=f'非对角线\nmean={off_diag.mean():.3f}',
|
||||
color='gray')
|
||||
axes[2].set_title('Cost分布: 匹配 vs 非匹配')
|
||||
axes[2].set_xlabel('Cost')
|
||||
axes[2].legend(fontsize=8)
|
||||
|
||||
plt.suptitle('UOT 代价矩阵分析', fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
path = os.path.join(OUTPUT_DIR, 'uot_cost_matrix.png')
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f' [保存] {path}')
|
||||
|
||||
|
||||
def visualize_sinkhorn():
|
||||
"""可视化Sinkhorn迭代过程"""
|
||||
print('\n--- Sinkhorn 迭代过程 ---')
|
||||
|
||||
torch.manual_seed(42)
|
||||
# 构造有明显对应关系的特征
|
||||
B, N, C = 1, 50, 128
|
||||
feat1 = torch.randn(B, N, C)
|
||||
feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
|
||||
|
||||
# feat2是feat1的扰动版本
|
||||
feat2 = feat1 + 0.15 * torch.randn(B, N, C)
|
||||
feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
|
||||
|
||||
C = 1.0 - torch.bmm(feat1_norm, feat2_norm.transpose(1, 2))
|
||||
|
||||
epsilon = torch.tensor([0.05])
|
||||
gamma = torch.tensor([1.0])
|
||||
|
||||
# 逐步可视化Sinkhorn迭代
|
||||
K = torch.exp(-C / epsilon)
|
||||
max_iter = 5
|
||||
power = gamma / (gamma + epsilon + 1e-8)
|
||||
|
||||
a = torch.ones((B, N, 1)) / N
|
||||
prob1 = torch.ones((B, N, 1)) / N
|
||||
prob2 = torch.ones((B, N, 1)) / N
|
||||
|
||||
fig, axes = plt.subplots(2, 4, figsize=(18, 9))
|
||||
|
||||
# K (初始)
|
||||
K_np = K[0].numpy()
|
||||
im0 = axes[0, 0].imshow(K_np, cmap='YlOrRd')
|
||||
axes[0, 0].set_title('K (exp(-C/ε))\n迭代0')
|
||||
axes[0, 0].set_xlabel('Positive'); axes[0, 0].set_ylabel('Query')
|
||||
plt.colorbar(im0, ax=axes[0, 0])
|
||||
|
||||
for iteration in range(1, min(max_iter + 1, 7)):
|
||||
# Update b
|
||||
KTa = torch.bmm(K.transpose(1, 2), a)
|
||||
b = torch.pow(prob2 / (KTa + 1e-8), power)
|
||||
# Update a
|
||||
Kb = torch.bmm(K, b)
|
||||
a = torch.pow(prob1 / (Kb + 1e-8), power)
|
||||
|
||||
T = torch.mul(torch.mul(a, K), b.transpose(1, 2))
|
||||
T_np = T[0].numpy()
|
||||
|
||||
ax = axes[(iteration) // 4, (iteration) % 4]
|
||||
im = ax.imshow(T_np, cmap='YlOrRd')
|
||||
ax.set_title(f'Transport Plan T\n迭代{iteration}')
|
||||
ax.set_xlabel('Positive'); ax.set_ylabel('Query')
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
# 空余位置
|
||||
for i in range(max_iter + 1, 8):
|
||||
axes[i // 4, i % 4].axis('off')
|
||||
|
||||
plt.suptitle('Sinkhorn 非平衡最优传输迭代过程', fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
path = os.path.join(OUTPUT_DIR, 'uot_sinkhorn_iterations.png')
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f' [保存] {path}')
|
||||
|
||||
|
||||
def test_rigid_transform():
|
||||
"""测试刚体变换估计"""
|
||||
print('\n--- Weighted SVD 刚体变换估计 ---')
|
||||
|
||||
torch.manual_seed(42)
|
||||
B, N = 2, 150
|
||||
|
||||
# 真实变换
|
||||
angle = torch.tensor(0.5) # ~28.6度
|
||||
R_true = torch.tensor([
|
||||
[torch.cos(angle), -torch.sin(angle), 0],
|
||||
[torch.sin(angle), torch.cos(angle), 0],
|
||||
[0, 0, 1]
|
||||
]).unsqueeze(0).repeat(B, 1, 1)
|
||||
t_true = torch.tensor([2.0, -1.0, 0.1]).unsqueeze(0).unsqueeze(-1).repeat(B, 1, 1)
|
||||
|
||||
# query点云
|
||||
pts1 = torch.randn(B, N, 3) * 20
|
||||
|
||||
# positive点云 = R * query + t + noise
|
||||
pts2 = R_true @ pts1.transpose(1, 2) + t_true
|
||||
pts2 = pts2.transpose(1, 2) + 0.3 * torch.randn(B, N, 3)
|
||||
|
||||
# 模拟transport weights(前80个点匹配好,后70个匹配差)
|
||||
weights = torch.ones(B, N)
|
||||
weights[:, 80:] = 0.1 # 降低后70个点的权重
|
||||
|
||||
# 估计变换
|
||||
transform = compute_rigid_transform(pts1, pts2, weights)
|
||||
|
||||
# 评估
|
||||
R_est = transform[:, :3, :3]
|
||||
t_est = transform[:, :3, 3]
|
||||
|
||||
# 旋转误差
|
||||
R_err = R_est @ R_true.transpose(1, 2)
|
||||
trace = torch.diagonal(R_err, dim1=1, dim2=2).sum(dim=1)
|
||||
angle_err = torch.acos(torch.clamp((trace - 1) / 2, -1, 1)) * 180 / np.pi
|
||||
|
||||
# 平移误差
|
||||
t_err = (t_est - t_true.squeeze(-1)).norm(dim=1)
|
||||
|
||||
print(f'旋转误差: {angle_err[0].item():.2f}°')
|
||||
print(f'平移误差: {t_err[0].item():.3f}m')
|
||||
|
||||
# 可视化
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
# 3D点云(XY平面投影)
|
||||
pts1_2d = pts1[0, :, :2].numpy()
|
||||
pts2_2d = pts2[0, :, :2].numpy()
|
||||
|
||||
# 投影点
|
||||
pts1_transformed = (R_est[0] @ pts1[0].T + t_est[0].unsqueeze(-1)).T[:, :2].numpy()
|
||||
|
||||
axes[0].scatter(pts1_2d[:, 0], pts1_2d[:, 1], c='blue', s=10, alpha=0.6, label='Query')
|
||||
axes[0].scatter(pts2_2d[:, 0], pts2_2d[:, 1], c='red', s=10, alpha=0.6, label='Positive')
|
||||
for i in range(min(20, N)):
|
||||
if weights[0, i] > 0.5:
|
||||
axes[0].plot([pts1_2d[i, 0], pts2_2d[i, 0]],
|
||||
[pts1_2d[i, 1], pts2_2d[i, 1]],
|
||||
'gray', alpha=0.3, linewidth=0.5)
|
||||
axes[0].set_title('匹配点对 (蓝色→红色)')
|
||||
axes[0].set_xlabel('X (m)'); axes[0].set_ylabel('Y (m)')
|
||||
axes[0].legend(fontsize=8)
|
||||
axes[0].set_aspect('equal')
|
||||
|
||||
# 变换后
|
||||
axes[1].scatter(pts1_transformed[:, 0], pts1_transformed[:, 1],
|
||||
c='blue', s=10, alpha=0.6, label='Query (变换后)')
|
||||
axes[1].scatter(pts2_2d[:, 0], pts2_2d[:, 1],
|
||||
c='red', s=10, alpha=0.6, label='Positive (目标)')
|
||||
axes[1].set_title(f'变换后对比\n旋转误差:{angle_err[0].item():.2f}° 平移误差:{t_err[0].item():.3f}m')
|
||||
axes[1].set_xlabel('X (m)'); axes[1].set_ylabel('Y (m)')
|
||||
axes[1].legend(fontsize=8)
|
||||
axes[1].set_aspect('equal')
|
||||
|
||||
plt.suptitle('Weighted SVD 刚体变换估计', fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
path = os.path.join(OUTPUT_DIR, 'uot_rigid_transform.png')
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f' [保存] {path}')
|
||||
|
||||
|
||||
def visualize_epsilon_gamma():
|
||||
"""可视化epsilon和gamma参数的影响"""
|
||||
print('\n--- epsilon/gamma 参数分析 ---')
|
||||
|
||||
torch.manual_seed(42)
|
||||
N = 50
|
||||
feat1 = torch.randn(1, N, 128)
|
||||
feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
|
||||
feat2 = feat1 + 0.2 * torch.randn(1, N, 128)
|
||||
feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
|
||||
|
||||
epsilons = [0.01, 0.05, 0.1, 0.5]
|
||||
gammas = [0.1, 1.0, 10.0]
|
||||
|
||||
fig, axes = plt.subplots(len(gammas), len(epsilons), figsize=(16, 12))
|
||||
|
||||
for gi, gamma in enumerate(gammas):
|
||||
for ei, eps in enumerate(epsilons):
|
||||
epsilon = torch.tensor([eps])
|
||||
gam = torch.tensor([gamma])
|
||||
T = sinkhorn_unbalanced(
|
||||
feat1_norm, feat2_norm,
|
||||
epsilon=epsilon, gamma=gam,
|
||||
max_iter=5, matrix='cosine'
|
||||
)
|
||||
ax = axes[gi, ei]
|
||||
im = ax.imshow(T[0].numpy(), cmap='YlOrRd')
|
||||
ax.set_title(f'ε={eps}, γ={gamma}')
|
||||
ax.set_xlabel('Positive'); ax.set_ylabel('Query')
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
plt.suptitle('epsilon (熵正则) 和 gamma (质量正则) 对 Transport Plan 的影响',
|
||||
fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
path = os.path.join(OUTPUT_DIR, 'uot_epsilon_gamma.png')
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f' [保存] {path}')
|
||||
|
||||
print("""
|
||||
参数解释:
|
||||
- epsilon (ε): 熵正则化强度
|
||||
- 小ε → Transport Plan更稀疏(hard matching)
|
||||
- 大ε → Transport Plan更平滑(soft matching)
|
||||
- gamma (γ): 质量正则化强度
|
||||
- 小γ → 允许部分匹配(质量可增减)
|
||||
- 大γ → 要求质量守恒(所有点必须匹配)
|
||||
""")
|
||||
|
||||
|
||||
def analyze_parameters():
|
||||
"""参数量分析"""
|
||||
print('\n--- 参数量分析 ---')
|
||||
uot = UOTHead(nb_iter=5, name='original')
|
||||
total = sum(p.numel() for p in uot.parameters())
|
||||
print(f'总参数量: {total} (仅 epsilon, gamma 两个可学习标量)')
|
||||
for name, param in uot.named_parameters():
|
||||
print(f' {name}: {param.data.item():.4f}')
|
||||
|
||||
|
||||
def main():
|
||||
print('=' * 60)
|
||||
print('UOT (Unbalanced Optimal Transport) 位姿估计可视化')
|
||||
print('=' * 60)
|
||||
|
||||
analyze_parameters()
|
||||
visualize_cost_matrix()
|
||||
visualize_sinkhorn()
|
||||
test_rigid_transform()
|
||||
visualize_epsilon_gamma()
|
||||
|
||||
print('\n' + '=' * 60)
|
||||
print('结构总结:')
|
||||
print('=' * 60)
|
||||
print("""
|
||||
UOTHead (非平衡最优传输位姿估计):
|
||||
|
||||
┌──────────────────────────────────────────────────────┐
|
||||
│ 输入: feat1(B,150,128), feat2(B,150,128) │
|
||||
│ kpts1(B,150,3), kpts2(B,150,3) │
|
||||
│ │
|
||||
│ 1. Cost Matrix: C = 1 - cosine_sim(feat1, feat2) │
|
||||
│ → (B, 150, 150) │
|
||||
│ │
|
||||
│ 2. Sinkhorn Unbalanced OT (迭代5次): │
|
||||
│ K = exp(-C / epsilon) │
|
||||
│ for i in range(5): │
|
||||
│ b = (prob2 / Kᵀa)^(γ/(γ+ε)) │
|
||||
│ a = (prob1 / Kb)^(γ/(γ+ε)) │
|
||||
│ T = a ⊙ K ⊙ bᵀ │
|
||||
│ → (B, 150, 150) 运输计划 │
|
||||
│ │
|
||||
│ 3. 投影: project_kpts = T @ kpts2 / ΣT │
|
||||
│ → (B, 150, 3) query匹配点在positive空间的投影坐标 │
|
||||
│ │
|
||||
│ 4. Weighted SVD 刚体变换: │
|
||||
│ - 加权中心化 │
|
||||
│ - SVD分解协方差 │
|
||||
│ - 输出 R(3×3), t(3×1) │
|
||||
│ → transformation: (B, 3, 4) │
|
||||
└──────────────────────────────────────────────────────┘
|
||||
|
||||
为什么用Unbalanced OT(非平衡最优传输)?
|
||||
- 标准OT要求两个点集大小相同且质量守恒
|
||||
- 实际场景:部分关键点在另一帧中可能被遮挡
|
||||
- Unbalanced OT允许部分匹配,更鲁棒
|
||||
|
||||
两个可学习参数:
|
||||
- epsilon (ε): 熵正则化,exp(ε)+0.03
|
||||
- gamma (γ): 质量正则化,exp(γ)
|
||||
""")
|
||||
|
||||
print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user