网络测试和学习demo
This commit is contained in:
230
network_learning/03_converter_demo.py
Normal file
230
network_learning/03_converter_demo.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Converter 跨模态特征转换器 Demo
|
||||
================================
|
||||
Converter 是跨模态融合的核心组件,负责在不同模态之间转换特征:
|
||||
- cvt_bev: 图像特征 → BEV空间特征
|
||||
- cvt_img: BEV特征 → 图像空间特征
|
||||
|
||||
结构:
|
||||
Self-Attention (MHA) + Conv1d瓶颈残差块
|
||||
输入: (B, 128, N) N个特征点
|
||||
输出: (B, 128, N) 转换后的特征
|
||||
|
||||
作用: 使两个模态的特征在同一个空间中对齐,便于后续匹配和融合
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from net import Converter
|
||||
|
||||
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def visualize_feature_similarity(fea_before, fea_after, title, save_name):
|
||||
"""可视化特征转换前后的相似度矩阵"""
|
||||
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
||||
|
||||
# 转换前特征相似度
|
||||
fea_before_norm = fea_before / (fea_before.norm(dim=1, keepdim=True) + 1e-8)
|
||||
sim_before = (fea_before_norm[0].T @ fea_before_norm[0]).detach().numpy()
|
||||
|
||||
im0 = axes[0, 0].imshow(sim_before, cmap='RdYlBu_r', vmin=-1, vmax=1)
|
||||
axes[0, 0].set_title('转换前 特征相似度矩阵')
|
||||
axes[0, 0].set_xlabel('Point j'); axes[0, 0].set_ylabel('Point i')
|
||||
plt.colorbar(im0, ax=axes[0, 0])
|
||||
|
||||
# 转换后特征相似度
|
||||
fea_after_norm = fea_after / (fea_after.norm(dim=1, keepdim=True) + 1e-8)
|
||||
sim_after = (fea_after_norm[0].T @ fea_after_norm[0]).detach().numpy()
|
||||
|
||||
im1 = axes[0, 1].imshow(sim_after, cmap='RdYlBu_r', vmin=-1, vmax=1)
|
||||
axes[0, 1].set_title('转换后 特征相似度矩阵')
|
||||
axes[0, 1].set_xlabel('Point j'); axes[0, 1].set_ylabel('Point i')
|
||||
plt.colorbar(im1, ax=axes[0, 1])
|
||||
|
||||
# 差异
|
||||
im2 = axes[0, 2].imshow(np.abs(sim_after - sim_before), cmap='YlOrRd')
|
||||
axes[0, 2].set_title('相似度变化 |差值|')
|
||||
axes[0, 2].set_xlabel('Point j'); axes[0, 2].set_ylabel('Point i')
|
||||
plt.colorbar(im2, ax=axes[0, 2])
|
||||
|
||||
# 特征值分布 before
|
||||
vals_before = fea_before[0].detach().numpy().flatten()
|
||||
axes[1, 0].hist(vals_before, bins=50, color='steelblue', edgecolor='white', alpha=0.7)
|
||||
axes[1, 0].set_title('转换前 特征值分布')
|
||||
axes[1, 0].set_xlabel('Feature Value')
|
||||
|
||||
# 特征值分布 after
|
||||
vals_after = fea_after[0].detach().numpy().flatten()
|
||||
axes[1, 1].hist(vals_after, bins=50, color='coral', edgecolor='white', alpha=0.7)
|
||||
axes[1, 1].set_title('转换后 特征值分布')
|
||||
axes[1, 1].set_xlabel('Feature Value')
|
||||
|
||||
# 重叠对比
|
||||
axes[1, 2].hist(vals_before, bins=50, color='steelblue', edgecolor='white',
|
||||
alpha=0.5, label='Before')
|
||||
axes[1, 2].hist(vals_after, bins=50, color='coral', edgecolor='white',
|
||||
alpha=0.5, label='After')
|
||||
axes[1, 2].set_title('分布对比')
|
||||
axes[1, 2].legend()
|
||||
|
||||
plt.suptitle(title, fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
path = os.path.join(OUTPUT_DIR, save_name)
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f' [保存] {path}')
|
||||
|
||||
|
||||
def visualize_attention(converter, fea_input):
|
||||
"""提取并可视化Self-Attention权重"""
|
||||
b, c, n = fea_input.shape
|
||||
x1 = fea_input.permute(0, 2, 1) # B, N, C
|
||||
|
||||
# 手动计算attention权重
|
||||
with torch.no_grad():
|
||||
q = converter.mha.w_q(x1)
|
||||
k = converter.mha.w_k(x1)
|
||||
weights = torch.nn.functional.softmax(
|
||||
torch.matmul(q, k.transpose(-2, -1)) / (converter.mha.d_model ** 0.5),
|
||||
dim=-1
|
||||
)
|
||||
|
||||
# 可视化前几个点的attention
|
||||
n_show = 6
|
||||
n = min(n, weights.shape[1])
|
||||
|
||||
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
|
||||
for idx in range(min(n_show, n)):
|
||||
ax = axes[idx // 3, idx % 3]
|
||||
ax.bar(range(min(n, 50)), weights[0, idx, :min(n, 50)].detach().numpy(),
|
||||
color='steelblue', width=1.0)
|
||||
ax.set_title(f'Query Point {idx} 的 Attention')
|
||||
ax.set_xlabel('Key Point')
|
||||
ax.set_ylabel('Weight')
|
||||
ax.axhline(y=1.0 / n, color='red', linestyle='--', alpha=0.5, label=f'平均={1/n:.3f}')
|
||||
ax.legend(fontsize=8)
|
||||
|
||||
for idx in range(n_show, 6):
|
||||
axes[idx // 3, idx % 3].axis('off')
|
||||
|
||||
plt.suptitle('Converter Self-Attention 权重分析', fontsize=14, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
path = os.path.join(OUTPUT_DIR, 'converter_attention.png')
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f' [保存] {path}')
|
||||
|
||||
|
||||
def test_cross_modal_convert():
|
||||
"""测试跨模态转换:模拟图像特征→BEV特征转换"""
|
||||
print('\n--- 跨模态转换测试 ---')
|
||||
|
||||
converter_bev = Converter(in_c=128)
|
||||
converter_img = Converter(in_c=128)
|
||||
|
||||
# 模拟两个模态的特征
|
||||
# 图像空间特征 (从图像特征图采样的N个点)
|
||||
torch.manual_seed(42)
|
||||
fea_img_space = torch.randn(2, 128, 100) # B=2, C=128, N=100
|
||||
|
||||
# BEV空间特征
|
||||
fea_bev_space = torch.randn(2, 128, 100)
|
||||
|
||||
with torch.no_grad():
|
||||
# 图像→BEV: 将图像空间特征转换到BEV空间
|
||||
fea_to_bev = converter_bev(fea_img_space)
|
||||
|
||||
# BEV→图像: 将BEV空间特征转换到图像空间
|
||||
fea_to_img = converter_img(fea_bev_space)
|
||||
|
||||
print(f'图像空间特征输入: {fea_img_space.shape}')
|
||||
print(f'→ cvt_bev 转换后: {fea_to_bev.shape}')
|
||||
print(f'BEV空间特征输入: {fea_bev_space.shape}')
|
||||
print(f'→ cvt_img 转换后: {fea_to_img.shape}')
|
||||
|
||||
# 可视化转换前后
|
||||
visualize_feature_similarity(
|
||||
fea_img_space, fea_to_bev,
|
||||
'cvt_bev: 图像特征 → BEV空间',
|
||||
'converter_img_to_bev.png'
|
||||
)
|
||||
|
||||
visualize_feature_similarity(
|
||||
fea_bev_space, fea_to_img,
|
||||
'cvt_img: BEV特征 → 图像空间',
|
||||
'converter_bev_to_img.png'
|
||||
)
|
||||
|
||||
# 可视化attention
|
||||
visualize_attention(converter_bev, fea_img_space)
|
||||
|
||||
|
||||
def analyze_architecture():
|
||||
"""分析Converter结构"""
|
||||
print('\n--- Converter 架构分析 ---')
|
||||
|
||||
converter = Converter(in_c=128)
|
||||
total = sum(p.numel() for p in converter.parameters())
|
||||
print(f'总参数量: {total:,} ({total / 1e3:.1f}K)')
|
||||
|
||||
for name, module in converter.named_children():
|
||||
params = sum(p.numel() for p in module.parameters())
|
||||
print(f' {name:15s}: {params:>10,} params')
|
||||
|
||||
# 详细结构
|
||||
print("""
|
||||
Converter 内部结构:
|
||||
|
||||
┌──────────────────────────────────────────┐
|
||||
│ 输入 x: (B, 128, N) │
|
||||
│ │ │
|
||||
│ ┌─────┴─────┐ │
|
||||
│ │ 路径1: MHA │ 路径2: Conv1d瓶颈块 │
|
||||
│ │ Self-Attn │ Conv1d(128→32→128) │
|
||||
│ │ x → x2 │ x → x3 │
|
||||
│ └─────┬─────┘ │
|
||||
│ │ │
|
||||
│ concat([x2, x3]) → Conv1d(256→128) │
|
||||
│ │ │
|
||||
│ 输出: (B, 128, N) │
|
||||
└──────────────────────────────────────────┘
|
||||
|
||||
MHA (多头自注意力):
|
||||
- d_model=128, num_heads=4
|
||||
- Q,K,V → 点积attention → FFN
|
||||
- 捕捉特征点之间的全局关系
|
||||
|
||||
Conv1d瓶颈块:
|
||||
- 128→32→16→32→128→128 (bottleneck)
|
||||
- 逐点卷积,提取通道间的非线性关系
|
||||
|
||||
两条路径互补:
|
||||
- MHA: 全局上下文建模
|
||||
- Conv1d: 局部特征变换
|
||||
- 残差连接 + concat融合
|
||||
""")
|
||||
|
||||
|
||||
def main():
|
||||
print('=' * 60)
|
||||
print('Converter (跨模态特征转换器) 结构与功能可视化')
|
||||
print('=' * 60)
|
||||
|
||||
analyze_architecture()
|
||||
test_cross_modal_convert()
|
||||
|
||||
print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user