Files
fusion_LCD/network_learning/03_converter_demo.py
2026-05-09 17:03:40 +08:00

231 lines
8.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Converter 跨模态特征转换器 Demo
================================
Converter 是跨模态融合的核心组件,负责在不同模态之间转换特征:
- cvt_bev: 图像特征 → BEV空间特征
- cvt_img: BEV特征 → 图像空间特征
结构:
Self-Attention (MHA) + Conv1d瓶颈残差块
输入: (B, 128, N) N个特征点
输出: (B, 128, N) 转换后的特征
作用: 使两个模态的特征在同一个空间中对齐,便于后续匹配和融合
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from net import Converter
# Directory next to this script where every generated figure is written.
# abspath() mirrors the sys.path setup above, so the directory (and the paths
# printed after each save) are absolute regardless of how the script is launched.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)
def _cosine_similarity_matrix(fea):
    """Return the pairwise cosine-similarity matrix of the points in batch item 0.

    ``fea`` is indexed as (B, C, N): each point's C-dim vector is L2-normalised
    along dim 1 (the 1e-8 term avoids division by zero) and compared with every
    other point of sample 0. Returns an (N, N) NumPy array.
    """
    fea_norm = fea / (fea.norm(dim=1, keepdim=True) + 1e-8)
    points = fea_norm[0]  # (C, N)
    # (N, C) @ (C, N) -> (N, N); .cpu() lets CUDA tensors be converted too.
    return (points.T @ points).detach().cpu().numpy()


def visualize_feature_similarity(fea_before, fea_after, title, save_name):
    """Plot how a feature transform changes point similarity and value distribution.

    Saves a 2x3 figure to ``os.path.join(OUTPUT_DIR, save_name)``:
      * top row: cosine-similarity matrices (over channels) of batch item 0
        before and after the transform, plus their absolute difference;
      * bottom row: histograms of all feature values of batch item 0 before,
        after, and both overlaid on shared bin edges.

    Args:
        fea_before: tensor indexed as (B, C, N), features before conversion.
        fea_after: tensor indexed as (B, C, N), features after conversion.
        title: figure super-title.
        save_name: output file name, relative to OUTPUT_DIR.
    """
    sim_before = _cosine_similarity_matrix(fea_before)
    sim_after = _cosine_similarity_matrix(fea_after)
    vals_before = fea_before[0].detach().cpu().numpy().ravel()
    vals_after = fea_after[0].detach().cpu().numpy().ravel()

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    # Top row: similarity matrices before / after, on a fixed [-1, 1] scale.
    for ax, sim, label in (
        (axes[0, 0], sim_before, '转换前 特征相似度矩阵'),
        (axes[0, 1], sim_after, '转换后 特征相似度矩阵'),
    ):
        im = ax.imshow(sim, cmap='RdYlBu_r', vmin=-1, vmax=1)
        ax.set_title(label)
        ax.set_xlabel('Point j')
        ax.set_ylabel('Point i')
        fig.colorbar(im, ax=ax)

    # Absolute change in similarity caused by the transform.
    im = axes[0, 2].imshow(np.abs(sim_after - sim_before), cmap='YlOrRd')
    axes[0, 2].set_title('相似度变化 |差值|')
    axes[0, 2].set_xlabel('Point j')
    axes[0, 2].set_ylabel('Point i')
    fig.colorbar(im, ax=axes[0, 2])

    # Bottom row: individual value histograms.
    axes[1, 0].hist(vals_before, bins=50, color='steelblue', edgecolor='white', alpha=0.7)
    axes[1, 0].set_title('转换前 特征值分布')
    axes[1, 0].set_xlabel('Feature Value')
    axes[1, 1].hist(vals_after, bins=50, color='coral', edgecolor='white', alpha=0.7)
    axes[1, 1].set_title('转换后 特征值分布')
    axes[1, 1].set_xlabel('Feature Value')

    # Overlay on shared bin edges. With bins=50 per call, each histogram would
    # span its own value range and the overlaid bars would not be comparable.
    shared_bins = np.histogram_bin_edges(
        np.concatenate([vals_before, vals_after]), bins=50
    )
    axes[1, 2].hist(vals_before, bins=shared_bins, color='steelblue',
                    edgecolor='white', alpha=0.5, label='Before')
    axes[1, 2].hist(vals_after, bins=shared_bins, color='coral',
                    edgecolor='white', alpha=0.5, label='After')
    axes[1, 2].set_title('分布对比')
    axes[1, 2].set_xlabel('Feature Value')
    axes[1, 2].legend()

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()
    path = os.path.join(OUTPUT_DIR, save_name)
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)  # close exactly this figure, not whatever is "current"
    print(f' [保存] {path}')
def visualize_attention(converter, fea_input):
    """Recompute and plot self-attention weights for the first few query points.

    The weights are recomputed by hand from the converter's MHA projections as
    softmax(Q K^T / sqrt(d_model)), with Q = mha.w_q(x) and K = mha.w_k(x), where
    x is ``fea_input`` permuted from (B, C, N) to (B, N, C). For batch item 0,
    up to 6 query points are plotted against up to the first 50 keys. A dashed
    line marks the uniform weight 1/N. The figure is saved to
    ``OUTPUT_DIR/converter_attention.png``.

    NOTE(review): this is a single-head computation scaled by sqrt(d_model).
    If ``converter.mha`` splits Q/K into several heads (the architecture
    summary in this file mentions num_heads=4), the weights used in its
    forward pass will differ. Confirm against the MHA implementation in `net`.

    Args:
        converter: a Converter exposing ``mha.w_q``, ``mha.w_k`` and ``mha.d_model``.
        fea_input: tensor indexed as (B, C, N).
    """
    n_show = 6      # query points to plot; the figure is a fixed 2x3 grid
    max_keys = 50   # keys drawn per bar chart
    x = fea_input.permute(0, 2, 1)  # (B, C, N) -> (B, N, C)

    with torch.no_grad():
        q = converter.mha.w_q(x)
        k = converter.mha.w_k(x)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (converter.mha.d_model ** 0.5)
        weights = torch.softmax(scores, dim=-1)  # (B, N, N); rows sum to 1
    weights = weights[0].cpu().numpy()  # keep batch item 0 only -> (N, N)

    n_points = weights.shape[0]
    n_plotted = min(n_show, n_points)
    n_keys_shown = min(n_points, max_keys)

    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
    for idx in range(n_plotted):
        ax = axes[idx // 3, idx % 3]
        ax.bar(range(n_keys_shown), weights[idx, :n_keys_shown],
               color='steelblue', width=1.0)
        ax.set_title(f'Query Point {idx} 的 Attention')
        ax.set_xlabel('Key Point')
        ax.set_ylabel('Weight')
        # Uniform attention over all N keys, for reference.
        ax.axhline(y=1.0 / n_points, color='red', linestyle='--', alpha=0.5,
                   label=f'平均={1/n_points:.3f}')
        ax.legend(fontsize=8)
    # Hide grid cells left empty when there are fewer than n_show points.
    for idx in range(n_plotted, 6):
        axes[idx // 3, idx % 3].axis('off')

    fig.suptitle('Converter Self-Attention 权重分析', fontsize=14, fontweight='bold')
    fig.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'converter_attention.png')
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f' [保存] {path}')
def test_cross_modal_convert():
    """Demo: run simulated image->BEV and BEV->image conversions and plot the results.

    Builds two independent ``Converter(in_c=128)`` instances (cvt_bev, cvt_img).
    Each is fed a random (B=2, C=128, N=100) feature tensor. The input and
    output shapes are printed, and the similarity and attention figures are
    written via the visualize_* helpers.
    """
    print('\n--- 跨模态转换测试 ---')
    # Seed BEFORE constructing the converters so that their randomly initialised
    # weights, not only the input features, are reproducible between runs.
    torch.manual_seed(42)
    converter_bev = Converter(in_c=128)
    converter_img = Converter(in_c=128)
    # Inference-only demo: eval() disables dropout and makes normalisation
    # layers use running statistics, if Converter contains any such layers.
    converter_bev.eval()
    converter_img.eval()

    # Simulated features for the two modalities.
    # Image-space features (N points sampled from an image feature map).
    fea_img_space = torch.randn(2, 128, 100)  # B=2, C=128, N=100
    # BEV-space features.
    fea_bev_space = torch.randn(2, 128, 100)

    with torch.no_grad():
        # Image -> BEV: map image-space features into BEV space.
        fea_to_bev = converter_bev(fea_img_space)
        # BEV -> image: map BEV-space features into image space.
        fea_to_img = converter_img(fea_bev_space)

    print(f'图像空间特征输入: {fea_img_space.shape}')
    print(f'→ cvt_bev 转换后: {fea_to_bev.shape}')
    print(f'BEV空间特征输入: {fea_bev_space.shape}')
    print(f'→ cvt_img 转换后: {fea_to_img.shape}')

    # Before/after visualisations for both directions.
    visualize_feature_similarity(
        fea_img_space, fea_to_bev,
        'cvt_bev: 图像特征 → BEV空间',
        'converter_img_to_bev.png'
    )
    visualize_feature_similarity(
        fea_bev_space, fea_to_img,
        'cvt_img: BEV特征 → 图像空间',
        'converter_bev_to_img.png'
    )
    # Attention weights of the image -> BEV converter.
    visualize_attention(converter_bev, fea_img_space)
def analyze_architecture():
    """Print parameter counts of a fresh Converter(in_c=128) and a text sketch of it.

    Prints the total parameter count, then one line per direct child module
    (recursive count for that child), followed by a fixed, hand-written
    description of the architecture.
    """
    print('\n--- Converter 架构分析 ---')
    model = Converter(in_c=128)

    def _num_params(module):
        # Sum of element counts across every parameter tensor of the module.
        return sum(p.numel() for p in module.parameters())

    total = _num_params(model)
    print(f'总参数量: {total:,} ({total / 1e3:.1f}K)')
    for child_name, child in model.named_children():
        print(f' {child_name:15s}: {_num_params(child):>10,} params')

    # NOTE(review): the sketch below is a static string, not derived from the
    # model. Its bottleneck widths disagree with each other (128→32→128 in the
    # box vs 128→32→16→32→128→128 in the notes). Verify against net.Converter.
    print("""
Converter 内部结构:
┌──────────────────────────────────────────┐
│ 输入 x: (B, 128, N) │
│ │ │
│ ┌─────┴─────┐ │
│ │ 路径1: MHA │ 路径2: Conv1d瓶颈块 │
│ │ Self-Attn │ Conv1d(128→32→128) │
│ │ x → x2 │ x → x3 │
│ └─────┬─────┘ │
│ │ │
│ concat([x2, x3]) → Conv1d(256→128) │
│ │ │
│ 输出: (B, 128, N) │
└──────────────────────────────────────────┘
MHA (多头自注意力):
- d_model=128, num_heads=4
- Q,K,V → 点积attention → FFN
- 捕捉特征点之间的全局关系
Conv1d瓶颈块:
- 128→32→16→32→128→128 (bottleneck)
- 逐点卷积,提取通道间的非线性关系
两条路径互补:
- MHA: 全局上下文建模
- Conv1d: 局部特征变换
- 残差连接 + concat融合
""")
def main():
    """Entry point: print the architecture summary, then run the conversion demo."""
    separator = '=' * 60
    for line in (separator, 'Converter (跨模态特征转换器) 结构与功能可视化', separator):
        print(line)
    analyze_architecture()
    test_cross_modal_convert()
    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')


if __name__ == '__main__':
    main()