231 lines
8.0 KiB
Python
231 lines
8.0 KiB
Python
"""
|
||
Converter 跨模态特征转换器 Demo
|
||
================================
|
||
Converter 是跨模态融合的核心组件,负责在不同模态之间转换特征:
|
||
- cvt_bev: 图像特征 → BEV空间特征
|
||
- cvt_img: BEV特征 → 图像空间特征
|
||
|
||
结构:
|
||
Self-Attention (MHA) + Conv1d瓶颈残差块
|
||
输入: (B, 128, N) N个特征点
|
||
输出: (B, 128, N) 转换后的特征
|
||
|
||
作用: 使两个模态的特征在同一个空间中对齐,便于后续匹配和融合
|
||
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib
|
||
matplotlib.use('Agg')
|
||
|
||
import sys
|
||
import os
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from net import Converter
|
||
|
||
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
|
||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||
|
||
|
||
def visualize_feature_similarity(fea_before, fea_after, title, save_name):
|
||
"""可视化特征转换前后的相似度矩阵"""
|
||
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
||
|
||
# 转换前特征相似度
|
||
fea_before_norm = fea_before / (fea_before.norm(dim=1, keepdim=True) + 1e-8)
|
||
sim_before = (fea_before_norm[0].T @ fea_before_norm[0]).detach().numpy()
|
||
|
||
im0 = axes[0, 0].imshow(sim_before, cmap='RdYlBu_r', vmin=-1, vmax=1)
|
||
axes[0, 0].set_title('转换前 特征相似度矩阵')
|
||
axes[0, 0].set_xlabel('Point j'); axes[0, 0].set_ylabel('Point i')
|
||
plt.colorbar(im0, ax=axes[0, 0])
|
||
|
||
# 转换后特征相似度
|
||
fea_after_norm = fea_after / (fea_after.norm(dim=1, keepdim=True) + 1e-8)
|
||
sim_after = (fea_after_norm[0].T @ fea_after_norm[0]).detach().numpy()
|
||
|
||
im1 = axes[0, 1].imshow(sim_after, cmap='RdYlBu_r', vmin=-1, vmax=1)
|
||
axes[0, 1].set_title('转换后 特征相似度矩阵')
|
||
axes[0, 1].set_xlabel('Point j'); axes[0, 1].set_ylabel('Point i')
|
||
plt.colorbar(im1, ax=axes[0, 1])
|
||
|
||
# 差异
|
||
im2 = axes[0, 2].imshow(np.abs(sim_after - sim_before), cmap='YlOrRd')
|
||
axes[0, 2].set_title('相似度变化 |差值|')
|
||
axes[0, 2].set_xlabel('Point j'); axes[0, 2].set_ylabel('Point i')
|
||
plt.colorbar(im2, ax=axes[0, 2])
|
||
|
||
# 特征值分布 before
|
||
vals_before = fea_before[0].detach().numpy().flatten()
|
||
axes[1, 0].hist(vals_before, bins=50, color='steelblue', edgecolor='white', alpha=0.7)
|
||
axes[1, 0].set_title('转换前 特征值分布')
|
||
axes[1, 0].set_xlabel('Feature Value')
|
||
|
||
# 特征值分布 after
|
||
vals_after = fea_after[0].detach().numpy().flatten()
|
||
axes[1, 1].hist(vals_after, bins=50, color='coral', edgecolor='white', alpha=0.7)
|
||
axes[1, 1].set_title('转换后 特征值分布')
|
||
axes[1, 1].set_xlabel('Feature Value')
|
||
|
||
# 重叠对比
|
||
axes[1, 2].hist(vals_before, bins=50, color='steelblue', edgecolor='white',
|
||
alpha=0.5, label='Before')
|
||
axes[1, 2].hist(vals_after, bins=50, color='coral', edgecolor='white',
|
||
alpha=0.5, label='After')
|
||
axes[1, 2].set_title('分布对比')
|
||
axes[1, 2].legend()
|
||
|
||
plt.suptitle(title, fontsize=14, fontweight='bold')
|
||
plt.tight_layout()
|
||
path = os.path.join(OUTPUT_DIR, save_name)
|
||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
print(f' [保存] {path}')
|
||
|
||
|
||
def visualize_attention(converter, fea_input):
|
||
"""提取并可视化Self-Attention权重"""
|
||
b, c, n = fea_input.shape
|
||
x1 = fea_input.permute(0, 2, 1) # B, N, C
|
||
|
||
# 手动计算attention权重
|
||
with torch.no_grad():
|
||
q = converter.mha.w_q(x1)
|
||
k = converter.mha.w_k(x1)
|
||
weights = torch.nn.functional.softmax(
|
||
torch.matmul(q, k.transpose(-2, -1)) / (converter.mha.d_model ** 0.5),
|
||
dim=-1
|
||
)
|
||
|
||
# 可视化前几个点的attention
|
||
n_show = 6
|
||
n = min(n, weights.shape[1])
|
||
|
||
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
|
||
for idx in range(min(n_show, n)):
|
||
ax = axes[idx // 3, idx % 3]
|
||
ax.bar(range(min(n, 50)), weights[0, idx, :min(n, 50)].detach().numpy(),
|
||
color='steelblue', width=1.0)
|
||
ax.set_title(f'Query Point {idx} 的 Attention')
|
||
ax.set_xlabel('Key Point')
|
||
ax.set_ylabel('Weight')
|
||
ax.axhline(y=1.0 / n, color='red', linestyle='--', alpha=0.5, label=f'平均={1/n:.3f}')
|
||
ax.legend(fontsize=8)
|
||
|
||
for idx in range(n_show, 6):
|
||
axes[idx // 3, idx % 3].axis('off')
|
||
|
||
plt.suptitle('Converter Self-Attention 权重分析', fontsize=14, fontweight='bold')
|
||
plt.tight_layout()
|
||
path = os.path.join(OUTPUT_DIR, 'converter_attention.png')
|
||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
print(f' [保存] {path}')
|
||
|
||
|
||
def test_cross_modal_convert():
|
||
"""测试跨模态转换:模拟图像特征→BEV特征转换"""
|
||
print('\n--- 跨模态转换测试 ---')
|
||
|
||
converter_bev = Converter(in_c=128)
|
||
converter_img = Converter(in_c=128)
|
||
|
||
# 模拟两个模态的特征
|
||
# 图像空间特征 (从图像特征图采样的N个点)
|
||
torch.manual_seed(42)
|
||
fea_img_space = torch.randn(2, 128, 100) # B=2, C=128, N=100
|
||
|
||
# BEV空间特征
|
||
fea_bev_space = torch.randn(2, 128, 100)
|
||
|
||
with torch.no_grad():
|
||
# 图像→BEV: 将图像空间特征转换到BEV空间
|
||
fea_to_bev = converter_bev(fea_img_space)
|
||
|
||
# BEV→图像: 将BEV空间特征转换到图像空间
|
||
fea_to_img = converter_img(fea_bev_space)
|
||
|
||
print(f'图像空间特征输入: {fea_img_space.shape}')
|
||
print(f'→ cvt_bev 转换后: {fea_to_bev.shape}')
|
||
print(f'BEV空间特征输入: {fea_bev_space.shape}')
|
||
print(f'→ cvt_img 转换后: {fea_to_img.shape}')
|
||
|
||
# 可视化转换前后
|
||
visualize_feature_similarity(
|
||
fea_img_space, fea_to_bev,
|
||
'cvt_bev: 图像特征 → BEV空间',
|
||
'converter_img_to_bev.png'
|
||
)
|
||
|
||
visualize_feature_similarity(
|
||
fea_bev_space, fea_to_img,
|
||
'cvt_img: BEV特征 → 图像空间',
|
||
'converter_bev_to_img.png'
|
||
)
|
||
|
||
# 可视化attention
|
||
visualize_attention(converter_bev, fea_img_space)
|
||
|
||
|
||
def analyze_architecture():
|
||
"""分析Converter结构"""
|
||
print('\n--- Converter 架构分析 ---')
|
||
|
||
converter = Converter(in_c=128)
|
||
total = sum(p.numel() for p in converter.parameters())
|
||
print(f'总参数量: {total:,} ({total / 1e3:.1f}K)')
|
||
|
||
for name, module in converter.named_children():
|
||
params = sum(p.numel() for p in module.parameters())
|
||
print(f' {name:15s}: {params:>10,} params')
|
||
|
||
# 详细结构
|
||
print("""
|
||
Converter 内部结构:
|
||
|
||
┌──────────────────────────────────────────┐
|
||
│ 输入 x: (B, 128, N) │
|
||
│ │ │
|
||
│ ┌─────┴─────┐ │
|
||
│ │ 路径1: MHA │ 路径2: Conv1d瓶颈块 │
|
||
│ │ Self-Attn │ Conv1d(128→32→128) │
|
||
│ │ x → x2 │ x → x3 │
|
||
│ └─────┬─────┘ │
|
||
│ │ │
|
||
│ concat([x2, x3]) → Conv1d(256→128) │
|
||
│ │ │
|
||
│ 输出: (B, 128, N) │
|
||
└──────────────────────────────────────────┘
|
||
|
||
MHA (多头自注意力):
|
||
- d_model=128, num_heads=4
|
||
- Q,K,V → 点积attention → FFN
|
||
- 捕捉特征点之间的全局关系
|
||
|
||
Conv1d瓶颈块:
|
||
- 128→32→16→32→128→128 (bottleneck)
|
||
- 逐点卷积,提取通道间的非线性关系
|
||
|
||
两条路径互补:
|
||
- MHA: 全局上下文建模
|
||
- Conv1d: 局部特征变换
|
||
- 残差连接 + concat融合
|
||
""")
|
||
|
||
|
||
def main():
|
||
print('=' * 60)
|
||
print('Converter (跨模态特征转换器) 结构与功能可视化')
|
||
print('=' * 60)
|
||
|
||
analyze_architecture()
|
||
test_cross_modal_convert()
|
||
|
||
print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|