网络测试和学习demo

This commit is contained in:
cyy_mac
2026-05-09 17:03:40 +08:00
parent edbe8fdbf9
commit 78298e56f1
9 changed files with 2868 additions and 0 deletions

View File

@@ -0,0 +1,356 @@
"""
UOT (Unbalanced Optimal Transport) 位姿估计 Demo
=================================================
UOTHead 使用 Sinkhorn 非平衡最优传输进行特征匹配和位姿估计。
流程:
1. Cosine Cost Matrix: C = 1 - cosine_sim(feat1, feat2)
2. Sinkhorn Unbalanced OT: 迭代求解运输计划 T
3. Point Projection: project_kpts = T @ kpts2 / sum(T)
4. Weighted SVD: 从匹配点对估计刚体变换 R|t
关键参数:
- epsilon: 熵正则化(控制运输计划的平滑度)
- gamma: 质量正则化(允许部分匹配)
- sinkhorn_iter: 5次迭代
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from uot import UOTHead, sinkhorn_unbalanced, compute_rigid_transform
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)
def visualize_cost_matrix():
"""可视化代价矩阵"""
print('\n--- 代价矩阵 (Cost Matrix) ---')
torch.manual_seed(42)
# 模拟query和positive的150个关键点特征
feat1 = torch.randn(2, 150, 128) # query
feat2 = torch.randn(2, 150, 128) # positive
# 让部分特征相似(模拟真实闭环场景)
# 前100个特征点有对应关系
feat2[:, :100] = feat1[:, :100] + 0.1 * torch.randn(2, 100, 128)
# 计算cosine cost matrix
feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
C = 1.0 - torch.bmm(feat1_norm, feat2_norm.transpose(1, 2))
C_np = C[0].numpy()
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
im0 = axes[0].imshow(C_np, cmap='YlOrRd')
axes[0].set_title('Cost Matrix C = 1 - cos_sim')
axes[0].set_xlabel('Positive Point j')
axes[0].set_ylabel('Query Point i')
plt.colorbar(im0, ax=axes[0])
# 缩放看前30个点有对应关系的
im1 = axes[1].imshow(C_np[:30, :30], cmap='YlOrRd')
axes[1].set_title('Cost Matrix (前30×30)\n有模拟对应关系')
axes[1].set_xlabel('Positive Point j')
axes[1].set_ylabel('Query Point i')
plt.colorbar(im1, ax=axes[1])
# 对角线cost分布 vs 非对角线
diag_cost = np.diag(C_np)
off_diag = C_np[~np.eye(150, dtype=bool)]
axes[2].hist(diag_cost, bins=30, alpha=0.6, label=f'对角线(匹配点)\nmean={diag_cost.mean():.3f}',
color='green')
axes[2].hist(off_diag, bins=30, alpha=0.6, label=f'非对角线\nmean={off_diag.mean():.3f}',
color='gray')
axes[2].set_title('Cost分布: 匹配 vs 非匹配')
axes[2].set_xlabel('Cost')
axes[2].legend(fontsize=8)
plt.suptitle('UOT 代价矩阵分析', fontsize=14, fontweight='bold')
plt.tight_layout()
path = os.path.join(OUTPUT_DIR, 'uot_cost_matrix.png')
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
print(f' [保存] {path}')
def visualize_sinkhorn():
"""可视化Sinkhorn迭代过程"""
print('\n--- Sinkhorn 迭代过程 ---')
torch.manual_seed(42)
# 构造有明显对应关系的特征
B, N, C = 1, 50, 128
feat1 = torch.randn(B, N, C)
feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
# feat2是feat1的扰动版本
feat2 = feat1 + 0.15 * torch.randn(B, N, C)
feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
C = 1.0 - torch.bmm(feat1_norm, feat2_norm.transpose(1, 2))
epsilon = torch.tensor([0.05])
gamma = torch.tensor([1.0])
# 逐步可视化Sinkhorn迭代
K = torch.exp(-C / epsilon)
max_iter = 5
power = gamma / (gamma + epsilon + 1e-8)
a = torch.ones((B, N, 1)) / N
prob1 = torch.ones((B, N, 1)) / N
prob2 = torch.ones((B, N, 1)) / N
fig, axes = plt.subplots(2, 4, figsize=(18, 9))
# K (初始)
K_np = K[0].numpy()
im0 = axes[0, 0].imshow(K_np, cmap='YlOrRd')
axes[0, 0].set_title('K (exp(-C/ε))\n迭代0')
axes[0, 0].set_xlabel('Positive'); axes[0, 0].set_ylabel('Query')
plt.colorbar(im0, ax=axes[0, 0])
for iteration in range(1, min(max_iter + 1, 7)):
# Update b
KTa = torch.bmm(K.transpose(1, 2), a)
b = torch.pow(prob2 / (KTa + 1e-8), power)
# Update a
Kb = torch.bmm(K, b)
a = torch.pow(prob1 / (Kb + 1e-8), power)
T = torch.mul(torch.mul(a, K), b.transpose(1, 2))
T_np = T[0].numpy()
ax = axes[(iteration) // 4, (iteration) % 4]
im = ax.imshow(T_np, cmap='YlOrRd')
ax.set_title(f'Transport Plan T\n迭代{iteration}')
ax.set_xlabel('Positive'); ax.set_ylabel('Query')
plt.colorbar(im, ax=ax)
# 空余位置
for i in range(max_iter + 1, 8):
axes[i // 4, i % 4].axis('off')
plt.suptitle('Sinkhorn 非平衡最优传输迭代过程', fontsize=14, fontweight='bold')
plt.tight_layout()
path = os.path.join(OUTPUT_DIR, 'uot_sinkhorn_iterations.png')
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
print(f' [保存] {path}')
def test_rigid_transform():
"""测试刚体变换估计"""
print('\n--- Weighted SVD 刚体变换估计 ---')
torch.manual_seed(42)
B, N = 2, 150
# 真实变换
angle = torch.tensor(0.5) # ~28.6度
R_true = torch.tensor([
[torch.cos(angle), -torch.sin(angle), 0],
[torch.sin(angle), torch.cos(angle), 0],
[0, 0, 1]
]).unsqueeze(0).repeat(B, 1, 1)
t_true = torch.tensor([2.0, -1.0, 0.1]).unsqueeze(0).unsqueeze(-1).repeat(B, 1, 1)
# query点云
pts1 = torch.randn(B, N, 3) * 20
# positive点云 = R * query + t + noise
pts2 = R_true @ pts1.transpose(1, 2) + t_true
pts2 = pts2.transpose(1, 2) + 0.3 * torch.randn(B, N, 3)
# 模拟transport weights前80个点匹配好后70个匹配差
weights = torch.ones(B, N)
weights[:, 80:] = 0.1 # 降低后70个点的权重
# 估计变换
transform = compute_rigid_transform(pts1, pts2, weights)
# 评估
R_est = transform[:, :3, :3]
t_est = transform[:, :3, 3]
# 旋转误差
R_err = R_est @ R_true.transpose(1, 2)
trace = torch.diagonal(R_err, dim1=1, dim2=2).sum(dim=1)
angle_err = torch.acos(torch.clamp((trace - 1) / 2, -1, 1)) * 180 / np.pi
# 平移误差
t_err = (t_est - t_true.squeeze(-1)).norm(dim=1)
print(f'旋转误差: {angle_err[0].item():.2f}°')
print(f'平移误差: {t_err[0].item():.3f}m')
# 可视化
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 3D点云XY平面投影
pts1_2d = pts1[0, :, :2].numpy()
pts2_2d = pts2[0, :, :2].numpy()
# 投影点
pts1_transformed = (R_est[0] @ pts1[0].T + t_est[0].unsqueeze(-1)).T[:, :2].numpy()
axes[0].scatter(pts1_2d[:, 0], pts1_2d[:, 1], c='blue', s=10, alpha=0.6, label='Query')
axes[0].scatter(pts2_2d[:, 0], pts2_2d[:, 1], c='red', s=10, alpha=0.6, label='Positive')
for i in range(min(20, N)):
if weights[0, i] > 0.5:
axes[0].plot([pts1_2d[i, 0], pts2_2d[i, 0]],
[pts1_2d[i, 1], pts2_2d[i, 1]],
'gray', alpha=0.3, linewidth=0.5)
axes[0].set_title('匹配点对 (蓝色→红色)')
axes[0].set_xlabel('X (m)'); axes[0].set_ylabel('Y (m)')
axes[0].legend(fontsize=8)
axes[0].set_aspect('equal')
# 变换后
axes[1].scatter(pts1_transformed[:, 0], pts1_transformed[:, 1],
c='blue', s=10, alpha=0.6, label='Query (变换后)')
axes[1].scatter(pts2_2d[:, 0], pts2_2d[:, 1],
c='red', s=10, alpha=0.6, label='Positive (目标)')
axes[1].set_title(f'变换后对比\n旋转误差:{angle_err[0].item():.2f}° 平移误差:{t_err[0].item():.3f}m')
axes[1].set_xlabel('X (m)'); axes[1].set_ylabel('Y (m)')
axes[1].legend(fontsize=8)
axes[1].set_aspect('equal')
plt.suptitle('Weighted SVD 刚体变换估计', fontsize=14, fontweight='bold')
plt.tight_layout()
path = os.path.join(OUTPUT_DIR, 'uot_rigid_transform.png')
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
print(f' [保存] {path}')
def visualize_epsilon_gamma():
"""可视化epsilon和gamma参数的影响"""
print('\n--- epsilon/gamma 参数分析 ---')
torch.manual_seed(42)
N = 50
feat1 = torch.randn(1, N, 128)
feat1_norm = feat1 / (feat1.norm(dim=2, keepdim=True) + 1e-8)
feat2 = feat1 + 0.2 * torch.randn(1, N, 128)
feat2_norm = feat2 / (feat2.norm(dim=2, keepdim=True) + 1e-8)
epsilons = [0.01, 0.05, 0.1, 0.5]
gammas = [0.1, 1.0, 10.0]
fig, axes = plt.subplots(len(gammas), len(epsilons), figsize=(16, 12))
for gi, gamma in enumerate(gammas):
for ei, eps in enumerate(epsilons):
epsilon = torch.tensor([eps])
gam = torch.tensor([gamma])
T = sinkhorn_unbalanced(
feat1_norm, feat2_norm,
epsilon=epsilon, gamma=gam,
max_iter=5, matrix='cosine'
)
ax = axes[gi, ei]
im = ax.imshow(T[0].numpy(), cmap='YlOrRd')
ax.set_title(f'ε={eps}, γ={gamma}')
ax.set_xlabel('Positive'); ax.set_ylabel('Query')
plt.colorbar(im, ax=ax)
plt.suptitle('epsilon (熵正则) 和 gamma (质量正则) 对 Transport Plan 的影响',
fontsize=14, fontweight='bold')
plt.tight_layout()
path = os.path.join(OUTPUT_DIR, 'uot_epsilon_gamma.png')
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
print(f' [保存] {path}')
print("""
参数解释:
- epsilon (ε): 熵正则化强度
- 小ε → Transport Plan更稀疏hard matching
- 大ε → Transport Plan更平滑soft matching
- gamma (γ): 质量正则化强度
- 小γ → 允许部分匹配(质量可增减)
- 大γ → 要求质量守恒(所有点必须匹配)
""")
def analyze_parameters():
"""参数量分析"""
print('\n--- 参数量分析 ---')
uot = UOTHead(nb_iter=5, name='original')
total = sum(p.numel() for p in uot.parameters())
print(f'总参数量: {total} (仅 epsilon, gamma 两个可学习标量)')
for name, param in uot.named_parameters():
print(f' {name}: {param.data.item():.4f}')
def main():
print('=' * 60)
print('UOT (Unbalanced Optimal Transport) 位姿估计可视化')
print('=' * 60)
analyze_parameters()
visualize_cost_matrix()
visualize_sinkhorn()
test_rigid_transform()
visualize_epsilon_gamma()
print('\n' + '=' * 60)
print('结构总结:')
print('=' * 60)
print("""
UOTHead (非平衡最优传输位姿估计):
┌──────────────────────────────────────────────────────┐
│ 输入: feat1(B,150,128), feat2(B,150,128) │
│ kpts1(B,150,3), kpts2(B,150,3) │
│ │
│ 1. Cost Matrix: C = 1 - cosine_sim(feat1, feat2) │
│ → (B, 150, 150) │
│ │
│ 2. Sinkhorn Unbalanced OT (迭代5次): │
│ K = exp(-C / epsilon) │
│ for i in range(5): │
│ b = (prob2 / Kᵀa)^(γ/(γ+ε)) │
│ a = (prob1 / Kb)^(γ/(γ+ε)) │
│ T = a ⊙ K ⊙ bᵀ │
│ → (B, 150, 150) 运输计划 │
│ │
│ 3. 投影: project_kpts = T @ kpts2 / ΣT │
│ → (B, 150, 3) query匹配点在positive空间的投影坐标 │
│ │
│ 4. Weighted SVD 刚体变换: │
│ - 加权中心化 │
│ - SVD分解协方差 │
│ - 输出 R(3×3), t(3×1) │
│ → transformation: (B, 3, 4) │
└──────────────────────────────────────────────────────┘
为什么用Unbalanced OT非平衡最优传输
- 标准OT要求两个点集大小相同且质量守恒
- 实际场景:部分关键点在另一帧中可能被遮挡
- Unbalanced OT允许部分匹配更鲁棒
两个可学习参数:
- epsilon (ε): 熵正则化exp(ε)+0.03
- gamma (γ): 质量正则化exp(γ)
""")
print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
if __name__ == '__main__':
main()