"""
Generator & FusionHead 全景生成器与融合头 Demo
==============================================

Generator: 从变长图像特征生成固定数量的全景特征
    Self-Attention → ConvTranspose1d(k3,s3) → AdaptiveMaxPool1d(150)
    输入: (B, 128, N)    N 可变
    输出: (B, 128, 150)  固定 150 个

FusionHead: 融合多来源特征
    输入的四个来源特征 [original, gen, gen_gen, kpl_gen] 中,
    前三个 [original, gen, gen_gen] 逐点做 Self-Attention → max 聚合,
    再与 kpl_gen 做 Cross-Attention → 输出
    输入: (B, 128, 150, 4)
    输出: (B, 128, 150)  融合后特征
"""

import os
import sys

import matplotlib

# Select the non-interactive Agg backend BEFORE pyplot is imported, so the demo
# can write PNG files on machines without a display (servers, CI).
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
import numpy as np  # noqa: E402
import torch  # noqa: E402

# Make the parent directory importable so the project-local `net` module can be
# found when this script is run directly from its own folder.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from net import Generator, FusionHead, Attention  # noqa: E402

# Figure titles and labels in this demo contain Chinese text. Matplotlib's
# default DejaVu Sans has no CJK glyphs, so those characters would render as
# empty boxes. The first installed font from this list is used; DejaVu Sans is
# the last-resort fallback. If no CJK font is installed, the boxes remain.
plt.rcParams['font.sans-serif'] = [
    'SimHei', 'Microsoft YaHei', 'Noto Sans CJK SC', 'WenQuanYi Micro Hei',
    'Arial Unicode MS', 'DejaVu Sans',
]
# Many CJK fonts lack the Unicode minus sign; use ASCII '-' on tick labels.
plt.rcParams['axes.unicode_minus'] = False

# All figures are written to ./output next to this script; the directory is
# created up front so plt.savefig never fails on a missing folder.
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)


def test_generator():
    """Demo the Generator: turn a variable-length feature set into a fixed-length one.

    Builds ``Generator(in_c=128, num=150)`` in eval mode and runs it on a
    random (2, 128, 200) tensor, printing the input and output shapes. It saves
    a 2x3 diagnostic figure to ``OUTPUT_DIR/generator_demo.png``, then prints
    the output shape for several other input lengths N.

    This is a visual demo: it makes no assertions.
    """
    print('\n--- Generator 全景特征生成器 ---')

    # Seed BEFORE constructing the model so the randomly initialised weights,
    # not only the input tensor, are identical on every run.
    torch.manual_seed(42)
    generator = Generator(in_c=128, num=150)
    generator.eval()

    # Simulated variable-length input (B=2, C=128, N=200).
    n_in = 200
    x = torch.randn(2, 128, n_in)

    with torch.no_grad():
        output = generator(x)
    n_out = output.shape[-1]

    print(f'输入: {x.shape} (变长,N={n_in})')
    print(f'输出: {output.shape} (固定,K=150)')

    def cosine_matrix(feat):
        """Map (C, K) features to a (K, K) numpy matrix of column-wise cosine similarity."""
        unit = feat / (feat.norm(dim=0, keepdim=True) + 1e-8)
        return (unit.T @ unit).detach().numpy()

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    # Input similarity, restricted to the first 50 points to keep it readable.
    im0 = axes[0, 0].imshow(cosine_matrix(x[0, :, :50]), cmap='RdYlBu_r', vmin=-1, vmax=1)
    axes[0, 0].set_title('输入特征相似度 (前50点)')
    plt.colorbar(im0, ax=axes[0, 0])

    # Output similarity over all generated points.
    im1 = axes[0, 1].imshow(cosine_matrix(output[0]), cmap='RdYlBu_r', vmin=-1, vmax=1)
    axes[0, 1].set_title('输出特征相似度 (150点)')
    plt.colorbar(im1, ax=axes[0, 1])

    # Raw channel x point heatmaps of the first 30 points, input then output.
    for ax, feat, title in ((axes[0, 2], x, '输入特征 (30点)'),
                            (axes[1, 0], output, '输出特征 (30点)')):
        im = ax.imshow(feat[0, :, :30].detach().numpy(), cmap='viridis', aspect='auto')
        ax.set_title(title)
        ax.set_xlabel('Point Index')
        ax.set_ylabel('Channel')
        plt.colorbar(im, ax=ax)

    # Text-only panel showing the length bookkeeping inside the Generator.
    # Assumes ConvTranspose1d(k=3, s=3) with no padding, so N -> 3N
    # (TODO confirm against net.Generator).
    n_up = n_in * 3
    diagram = axes[1, 1]
    diagram.set_title('Generator 内部变换', fontsize=12)
    diagram.text(0.5, 0.8, 'ConvTranspose1d(k3,s3)', transform=diagram.transAxes,
                 ha='center', fontsize=11, bbox=dict(boxstyle='round', facecolor='lightblue'))
    diagram.text(0.5, 0.6, f'{n_in} → {n_in}*3 = {n_up}', transform=diagram.transAxes,
                 ha='center', fontsize=10)
    diagram.text(0.5, 0.4, 'AdaptiveMaxPool1d(150)', transform=diagram.transAxes,
                 ha='center', fontsize=11, bbox=dict(boxstyle='round', facecolor='lightgreen'))
    diagram.text(0.5, 0.2, f'{n_up} → {n_out}', transform=diagram.transAxes,
                 ha='center', fontsize=10)
    diagram.axis('off')

    # Value distribution of input vs. output features (sample 0).
    axes[1, 2].hist(x[0].detach().numpy().flatten(), bins=50, alpha=0.5,
                    label='Input', color='steelblue')
    axes[1, 2].hist(output[0].detach().numpy().flatten(), bins=50, alpha=0.5,
                    label='Output', color='coral')
    axes[1, 2].set_title('特征值分布对比')
    axes[1, 2].legend()

    plt.suptitle('Generator: 变长特征→固定大小特征', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'generator_demo.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')

    # Show that the output length stays fixed for different input lengths.
    print('\nGenerator 对不同输入长度的适应:')
    for n in (50, 100, 200, 500):
        x_test = torch.randn(1, 128, n)
        with torch.no_grad():
            out = generator(x_test)
        print(f' N={n:4d} → 输出形状 {out.shape}')


def test_fusion_head():
|
||
"""测试FusionHead: 多来源特征融合"""
|
||
print('\n--- FusionHead 融合头 ---')
|
||
|
||
fusion_head = FusionHead(in_c=128)
|
||
fusion_head.eval()
|
||
|
||
# 模拟4种特征:
|
||
# [0]: fea_kpt_original - BEV原始关键点特征
|
||
# [1]: fea_kpt_original_gen - Generator生成的BEV特征
|
||
# [2]: fea_kpt_gen_gen - 双路径转换器输出
|
||
# [3]: fea_kpl_gen - BEV→图像空间特征
|
||
B, C, K = 2, 128, 150
|
||
torch.manual_seed(42)
|
||
|
||
# 让不同来源的特征有相关性但不完全相同
|
||
base = torch.randn(B, C, K)
|
||
fea_original = base
|
||
fea_gen = base + 0.3 * torch.randn(B, C, K)
|
||
fea_gen_gen = fea_gen + 0.2 * torch.randn(B, C, K)
|
||
fea_kpl_gen = base + 0.5 * torch.randn(B, C, K)
|
||
|
||
fea_kpts = torch.stack([fea_original, fea_gen, fea_gen_gen, fea_kpl_gen], dim=2)
|
||
print(f'输入: {fea_kpts.shape} [B, C, K, 4来源]')
|
||
|
||
with torch.no_grad():
|
||
fea_fused = fusion_head(fea_kpts)
|
||
|
||
print(f'输出: {fea_fused.shape} [B, C, K] 融合特征')
|
||
|
||
# 可视化
|
||
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
|
||
|
||
names = ['Original (BEV原始)', 'Generated (全景生成)',
|
||
'Gen_Gen (双路径)', 'KPL_Gen (图像空间)']
|
||
|
||
for idx in range(4):
|
||
ax = axes[idx // 2, idx % 2]
|
||
sim = torch.nn.functional.cosine_similarity(
|
||
fea_kpts[0, :, :, 0].T.unsqueeze(-1),
|
||
fea_kpts[0, :, :, idx].T.unsqueeze(0),
|
||
dim=1
|
||
)
|
||
im = ax.imshow(sim.detach().numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1)
|
||
ax.set_title(f'{names[idx]}\nvs Original 相似度')
|
||
ax.set_xlabel('Point'); ax.set_ylabel('Point')
|
||
plt.colorbar(im, ax=ax)
|
||
|
||
# 融合特征 vs 原始特征
|
||
ax = axes[1, 2]
|
||
sim_fused = torch.nn.functional.cosine_similarity(
|
||
fea_original[0].T.unsqueeze(-1),
|
||
fea_fused[0].T.unsqueeze(0),
|
||
dim=1
|
||
)
|
||
im = ax.imshow(sim_fused.detach().numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1)
|
||
ax.set_title('Fused vs Original 相似度')
|
||
ax.set_xlabel('Point'); ax.set_ylabel('Point')
|
||
plt.colorbar(im, ax=ax)
|
||
|
||
plt.suptitle('FusionHead: 多来源特征融合分析', fontsize=14, fontweight='bold')
|
||
plt.tight_layout()
|
||
path = os.path.join(OUTPUT_DIR, 'fusion_head_demo.png')
|
||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
print(f' [保存] {path}')
|
||
|
||
|
||
def visualize_attention_detail():
    """Visualise the attention weights of a standalone ``Attention`` module.

    Runs ``Attention(d_model=128)`` as self-attention (query = key = value)
    on a random (4, 3, 128) tensor and prints the shapes. It saves a figure to
    ``OUTPUT_DIR/fusion_attention_detail.png`` with two panels: the 3x3 weight
    matrix of the first row, and a text outline of the FusionHead attention flow.
    """
    print('\n--- FusionHead Attention 详细分析 ---')

    # Seed BEFORE constructing the module so its weights are reproducible too.
    torch.manual_seed(42)
    att = Attention(d_model=128)
    att.eval()

    # Self-attention over 3 "pair" tokens; B * 2 rows stand in for batch * samples.
    B, N_pair, C = 2, 3, 128
    x = torch.randn(B * 2, N_pair, C)

    with torch.no_grad():
        output, weights = att(x, x, x)

    print(f'Self-Attention 输入: {x.shape}')
    print(f'输出: {output.shape}')
    print(f'Attention权重: {weights.shape} (B, 3, 3)')

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # NOTE(review): assumes `weights` is (batch, 3, 3), e.g. averaged over
    # heads. A per-head (batch, heads, 3, 3) result would need one more index.
    weights_np = weights[0].detach().numpy()
    labels = ['Original', 'Generated', 'Gen_Gen']

    heat = axes[0]
    im0 = heat.imshow(weights_np, cmap='YlOrRd', vmin=0, vmax=1)
    heat.set_title('Self-Attention 权重 (3对特征)')
    heat.set_xticks(range(3))
    heat.set_xticklabels(labels)
    heat.set_yticks(range(3))
    heat.set_yticklabels(labels)

    # Annotate each cell; use white text on dark (high-weight) cells.
    for i in range(3):
        for j in range(3):
            w = weights_np[i, j]
            heat.text(j, i, f'{w:.3f}', ha='center', va='center',
                      fontsize=12, color='white' if w > 0.5 else 'black')
    plt.colorbar(im0, ax=heat)

    # Text outline of the FusionHead attention pipeline.
    flow = axes[1]
    flow.set_title('FusionHead Attention 流程', fontsize=12)
    steps = [
        '1. 拼接4种特征 [original, gen, gen_gen, kpl_gen]',
        '2. 取前3种 [original, gen, gen_gen]',
        '3. 对每个样本的3对特征做Self-Attention',
        '4. max聚合 → 每样本1个特征',
        '5. Cross-Attention with kpl_gen (图像空间特征)',
        '6. concat(original, cross_out) → Conv1d → 输出'
    ]
    for i, step in enumerate(steps):
        flow.text(0.1, 0.9 - i * 0.15, step, transform=flow.transAxes,
                  fontsize=10, family='monospace',
                  bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.7))
    flow.axis('off')

    plt.suptitle('FusionHead Attention 机制详解', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'fusion_attention_detail.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')


def analyze_parameters():
|
||
"""参数量分析"""
|
||
print('\n--- 参数量分析 ---')
|
||
|
||
gen = Generator(in_c=128, num=150)
|
||
fusion = FusionHead(in_c=128)
|
||
|
||
for name, model in [('Generator', gen), ('FusionHead', fusion)]:
|
||
total = sum(p.numel() for p in model.parameters())
|
||
print(f'\n{name}: {total:,} params ({total / 1e3:.1f}K)')
|
||
for n, m in model.named_children():
|
||
p = sum(pmt.numel() for pmt in m.parameters())
|
||
print(f' {n:15s}: {p:>10,} params')
|
||
|
||
|
||
# Horizontal rule used for the section banners printed by main().
_RULE = '=' * 60

# ASCII overview of both modules, printed verbatim at the end of main().
_STRUCTURE_SUMMARY = """
Generator (全景特征生成器):
┌──────────────────────────────────────────────┐
│ 输入: (B, 128, N) N可变 │
│ ↓ Self-Attention (MHA) │
│ x2: (B, 128, N) 全局上下文特征 │
│ ↓ ConvTranspose1d(k3,s3) │
│ x3: (B, 128, N*3) 上采样扩展 │
│ ↓ AdaptiveMaxPool1d(150) │
│ 输出: (B, 128, 150) 固定K个全景特征 │
└──────────────────────────────────────────────┘
作用: 将BEV中可变数量的匹配点特征压缩为固定150个,
与BEV关键点数量对齐

FusionHead (跨模态融合头):
┌──────────────────────────────────────────────┐
│ 输入: (B, 128, 150, 4) │
│ [original, gen, gen_gen, kpl_gen] │
│ ↓ │
│ 对前3对 (B*N, 3, C): │
│ Self-Attn → max(dim=1) → (B*N, C) │
│ ↓ reshape → (B, N, C) │
│ Cross-Attention with kpl_gen │
│ ↓ │
│ concat(original, cross_out) → Conv1d(256→128) │
│ 输出: (B, 128, 150) 融合特征 │
└──────────────────────────────────────────────┘
作用: 整合多来源特征,增强融合表示
"""


def main():
    """Run every demo in order, then print a structural summary and the output folder."""
    print(_RULE)
    print('Generator & FusionHead 结构与功能可视化')
    print(_RULE)

    # The demos are independent; run them in a fixed order.
    demos = (analyze_parameters, test_generator, test_fusion_head,
             visualize_attention_detail)
    for demo in demos:
        demo()

    print('\n' + _RULE)
    print('结构总结:')
    print(_RULE)
    print(_STRUCTURE_SUMMARY)

    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')


if __name__ == '__main__':
    main()