""" Generator & FusionHead 全景生成器与融合头 Demo ============================================== Generator: 从变长图像特征生成固定数量的全景特征 Self-Attention → ConvTranspose1d(k3,s3) → AdaptiveMaxPool1d(150) 输入: (B, 128, N) N可变 输出: (B, 128, 150) 固定150个 FusionHead: 融合多来源特征 对 [original, gen, gen_gen, kpl_gen] 四个特征 → pair-wise Self-Attention → max聚合 → Cross-Attention → 输出 输入: (B, 128, 150, 4) 输出: (B, 128, 150) 融合后特征 """ import torch import numpy as np import matplotlib.pyplot as plt import matplotlib matplotlib.use('Agg') import sys import os sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from net import Generator, FusionHead, Attention OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output') os.makedirs(OUTPUT_DIR, exist_ok=True) def test_generator(): """测试Generator: 变长→定长特征转换""" print('\n--- Generator 全景特征生成器 ---') generator = Generator(in_c=128, num=150) generator.eval() # 模拟变长输入 (B=2, C=128, N=可变的200) torch.manual_seed(42) x = torch.randn(2, 128, 200) with torch.no_grad(): output = generator(x) print(f'输入: {x.shape} (变长,N=200)') print(f'输出: {output.shape} (固定,K=150)') # 可视化输入输出特征 fig, axes = plt.subplots(2, 3, figsize=(18, 10)) # 输入特征相似度矩阵 (前50个点) x_norm = x[0] / (x[0].norm(dim=0, keepdim=True) + 1e-8) sim_in = (x_norm.T[:50] @ x_norm[:, :50]).detach().numpy() im0 = axes[0, 0].imshow(sim_in, cmap='RdYlBu_r', vmin=-1, vmax=1) axes[0, 0].set_title('输入特征相似度 (前50点)') plt.colorbar(im0, ax=axes[0, 0]) # 输出特征相似度矩阵 out_norm = output[0] / (output[0].norm(dim=0, keepdim=True) + 1e-8) sim_out = (out_norm.T @ out_norm).detach().numpy() im1 = axes[0, 1].imshow(sim_out, cmap='RdYlBu_r', vmin=-1, vmax=1) axes[0, 1].set_title('输出特征相似度 (150点)') plt.colorbar(im1, ax=axes[0, 1]) # 输入特征热图 im2 = axes[0, 2].imshow(x[0, :, :30].detach().numpy(), cmap='viridis', aspect='auto') axes[0, 2].set_title('输入特征 (30点)') axes[0, 2].set_xlabel('Point Index'); axes[0, 2].set_ylabel('Channel') plt.colorbar(im2, ax=axes[0, 2]) # 输出特征热图 im3 = axes[1, 0].imshow(output[0, :, :30].detach().numpy(), cmap='viridis', aspect='auto') axes[1, 0].set_title('输出特征 (30点)') axes[1, 0].set_xlabel('Point Index'); axes[1, 0].set_ylabel('Channel') plt.colorbar(im3, ax=axes[1, 0]) # ConvTranspose + AdaptiveMaxPool 原理 axes[1, 1].set_title('Generator 内部变换', fontsize=12) axes[1, 1].text(0.5, 0.8, 'ConvTranspose1d(k3,s3)', transform=axes[1, 1].transAxes, ha='center', fontsize=11, bbox=dict(boxstyle='round', facecolor='lightblue')) axes[1, 1].text(0.5, 0.6, f'200 → 200*3 = 600', transform=axes[1, 1].transAxes, ha='center', fontsize=10) axes[1, 1].text(0.5, 0.4, 'AdaptiveMaxPool1d(150)', transform=axes[1, 1].transAxes, ha='center', fontsize=11, bbox=dict(boxstyle='round', facecolor='lightgreen')) axes[1, 1].text(0.5, 0.2, f'600 → 150', transform=axes[1, 1].transAxes, ha='center', fontsize=10) axes[1, 1].axis('off') # 特征值分布对比 axes[1, 2].hist(x[0].detach().numpy().flatten(), bins=50, alpha=0.5, label='Input', color='steelblue') axes[1, 2].hist(output[0].detach().numpy().flatten(), bins=50, alpha=0.5, label='Output', color='coral') axes[1, 2].set_title('特征值分布对比') axes[1, 2].legend() plt.suptitle('Generator: 变长特征→固定大小特征', fontsize=14, fontweight='bold') plt.tight_layout() path = os.path.join(OUTPUT_DIR, 'generator_demo.png') plt.savefig(path, dpi=150, bbox_inches='tight') plt.close() print(f' [保存] {path}') # 测试不同输入长度 print('\nGenerator 对不同输入长度的适应:') for n in [50, 100, 200, 500]: x_test = torch.randn(1, 128, n) with torch.no_grad(): out = generator(x_test) print(f' N={n:4d} → 输出形状 {out.shape}') def test_fusion_head(): """测试FusionHead: 多来源特征融合""" print('\n--- FusionHead 融合头 ---') fusion_head = FusionHead(in_c=128) fusion_head.eval() # 模拟4种特征: # [0]: fea_kpt_original - BEV原始关键点特征 # [1]: fea_kpt_original_gen - Generator生成的BEV特征 # [2]: fea_kpt_gen_gen - 双路径转换器输出 # [3]: fea_kpl_gen - BEV→图像空间特征 B, C, K = 2, 128, 150 torch.manual_seed(42) # 让不同来源的特征有相关性但不完全相同 base = torch.randn(B, C, K) fea_original = base fea_gen = base + 0.3 * torch.randn(B, C, K) fea_gen_gen = fea_gen + 0.2 * torch.randn(B, C, K) fea_kpl_gen = base + 0.5 * torch.randn(B, C, K) fea_kpts = torch.stack([fea_original, fea_gen, fea_gen_gen, fea_kpl_gen], dim=2) print(f'输入: {fea_kpts.shape} [B, C, K, 4来源]') with torch.no_grad(): fea_fused = fusion_head(fea_kpts) print(f'输出: {fea_fused.shape} [B, C, K] 融合特征') # 可视化 fig, axes = plt.subplots(2, 3, figsize=(18, 10)) names = ['Original (BEV原始)', 'Generated (全景生成)', 'Gen_Gen (双路径)', 'KPL_Gen (图像空间)'] for idx in range(4): ax = axes[idx // 2, idx % 2] sim = torch.nn.functional.cosine_similarity( fea_kpts[0, :, :, 0].T.unsqueeze(-1), fea_kpts[0, :, :, idx].T.unsqueeze(0), dim=1 ) im = ax.imshow(sim.detach().numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1) ax.set_title(f'{names[idx]}\nvs Original 相似度') ax.set_xlabel('Point'); ax.set_ylabel('Point') plt.colorbar(im, ax=ax) # 融合特征 vs 原始特征 ax = axes[1, 2] sim_fused = torch.nn.functional.cosine_similarity( fea_original[0].T.unsqueeze(-1), fea_fused[0].T.unsqueeze(0), dim=1 ) im = ax.imshow(sim_fused.detach().numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1) ax.set_title('Fused vs Original 相似度') ax.set_xlabel('Point'); ax.set_ylabel('Point') plt.colorbar(im, ax=ax) plt.suptitle('FusionHead: 多来源特征融合分析', fontsize=14, fontweight='bold') plt.tight_layout() path = os.path.join(OUTPUT_DIR, 'fusion_head_demo.png') plt.savefig(path, dpi=150, bbox_inches='tight') plt.close() print(f' [保存] {path}') def visualize_attention_detail(): """详细可视化FusionHead中的Attention机制""" print('\n--- FusionHead Attention 详细分析 ---') att = Attention(d_model=128) att.eval() # 模拟3对特征的Self-Attention B, N_pair, C = 2, 3, 128 torch.manual_seed(42) x = torch.randn(B * 2, N_pair, C) # 模拟batch*样本数的3对特征 with torch.no_grad(): output, weights = att(x, x, x) print(f'Self-Attention 输入: {x.shape}') print(f'输出: {output.shape}') print(f'Attention权重: {weights.shape} (B, 3, 3)') # 可视化attention权重 fig, axes = plt.subplots(1, 2, figsize=(14, 5)) weights_np = weights[0].detach().numpy() im0 = axes[0].imshow(weights_np, cmap='YlOrRd', vmin=0, vmax=1) axes[0].set_title('Self-Attention 权重 (3对特征)') axes[0].set_xticks(range(3)) axes[0].set_xticklabels(['Original', 'Generated', 'Gen_Gen']) axes[0].set_yticks(range(3)) axes[0].set_yticklabels(['Original', 'Generated', 'Gen_Gen']) for i in range(3): for j in range(3): axes[0].text(j, i, f'{weights_np[i, j]:.3f}', ha='center', va='center', fontsize=12, color='white' if weights_np[i, j] > 0.5 else 'black') plt.colorbar(im0, ax=axes[0]) # Cross-Attention 示意图 axes[1].set_title('FusionHead Attention 流程', fontsize=12) steps = [ '1. 拼接4种特征 [original, gen, gen_gen, kpl_gen]', '2. 取前3种 [original, gen, gen_gen]', '3. 对每个样本的3对特征做Self-Attention', '4. max聚合 → 每样本1个特征', '5. Cross-Attention with kpl_gen (图像空间特征)', '6. concat(original, cross_out) → Conv1d → 输出' ] for i, step in enumerate(steps): axes[1].text(0.1, 0.9 - i * 0.15, step, transform=axes[1].transAxes, fontsize=10, family='monospace', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.7)) axes[1].axis('off') plt.suptitle('FusionHead Attention 机制详解', fontsize=14, fontweight='bold') plt.tight_layout() path = os.path.join(OUTPUT_DIR, 'fusion_attention_detail.png') plt.savefig(path, dpi=150, bbox_inches='tight') plt.close() print(f' [保存] {path}') def analyze_parameters(): """参数量分析""" print('\n--- 参数量分析 ---') gen = Generator(in_c=128, num=150) fusion = FusionHead(in_c=128) for name, model in [('Generator', gen), ('FusionHead', fusion)]: total = sum(p.numel() for p in model.parameters()) print(f'\n{name}: {total:,} params ({total / 1e3:.1f}K)') for n, m in model.named_children(): p = sum(pmt.numel() for pmt in m.parameters()) print(f' {n:15s}: {p:>10,} params') def main(): print('=' * 60) print('Generator & FusionHead 结构与功能可视化') print('=' * 60) analyze_parameters() test_generator() test_fusion_head() visualize_attention_detail() print('\n' + '=' * 60) print('结构总结:') print('=' * 60) print(""" Generator (全景特征生成器): ┌──────────────────────────────────────────────┐ │ 输入: (B, 128, N) N可变 │ │ ↓ Self-Attention (MHA) │ │ x2: (B, 128, N) 全局上下文特征 │ │ ↓ ConvTranspose1d(k3,s3) │ │ x3: (B, 128, N*3) 上采样扩展 │ │ ↓ AdaptiveMaxPool1d(150) │ │ 输出: (B, 128, 150) 固定K个全景特征 │ └──────────────────────────────────────────────┘ 作用: 将BEV中可变数量的匹配点特征压缩为固定150个, 与BEV关键点数量对齐 FusionHead (跨模态融合头): ┌──────────────────────────────────────────────┐ │ 输入: (B, 128, 150, 4) │ │ [original, gen, gen_gen, kpl_gen] │ │ ↓ │ │ 对前3对 (B*N, 3, C): │ │ Self-Attn → max(dim=1) → (B*N, C) │ │ ↓ reshape → (B, N, C) │ │ Cross-Attention with kpl_gen │ │ ↓ │ │ concat(original, cross_out) → Conv1d(256→128) │ │ 输出: (B, 128, 150) 融合特征 │ └──────────────────────────────────────────────┘ 作用: 整合多来源特征,增强融合表示 """) print(f'\n所有可视化结果保存在: {OUTPUT_DIR}') if __name__ == '__main__': main()