Files
fusion_LCD/network_learning/04_generator_fusion_demo.py
2026-05-09 17:03:40 +08:00

305 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Generator & FusionHead demo: panoramic feature generator and fusion head
=========================================================================
Generator: produces a fixed number of panoramic features from variable-length
image features.
    Self-Attention -> ConvTranspose1d(k3,s3) -> AdaptiveMaxPool1d(150)
    Input:  (B, 128, N), N is variable
    Output: (B, 128, 150), always 150 features
FusionHead: fuses features coming from several sources.
    For the four features [original, gen, gen_gen, kpl_gen]:
    pair-wise Self-Attention -> max aggregation -> Cross-Attention -> output
    Input:  (B, 128, 150, 4)
    Output: (B, 128, 150), the fused features
"""
import os
import sys

import matplotlib

# Select the non-interactive Agg backend *before* pyplot is imported: older
# Matplotlib releases ignore a backend change made after the pyplot import,
# which breaks figure creation on headless machines.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import torch

# Put the parent of this script's directory on the import path; presumably
# that is where the project's `net` module lives.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from net import Generator, FusionHead, Attention

# Figures are written to an `output` folder next to this script. The path is
# made absolute (as for the sys.path entry above) so it does not depend on the
# current working directory. The folder is created at import time.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)
def test_generator():
    """Exercise the Generator on random input and plot input/output features.

    Builds ``Generator(in_c=128, num=150)`` in eval mode, runs it on a seeded
    random tensor of shape (2, 128, 200), prints both shapes, and saves a 2x3
    diagnostic figure as ``generator_demo.png`` under ``OUTPUT_DIR``. It then
    prints the output shape obtained for several other input lengths.
    """
    print('\n--- Generator 全景特征生成器 ---')
    model = Generator(in_c=128, num=150)
    model.eval()

    # Simulated variable-length input: B=2, C=128, N=200.
    torch.manual_seed(42)
    x = torch.randn(2, 128, 200)
    with torch.no_grad():
        output = model(x)
    print(f'输入: {x.shape} (变长N=200)')
    print(f'输出: {output.shape} (固定K=150)')

    def unit_columns(features):
        # Scale every column to unit L2 norm; the epsilon avoids dividing by 0.
        return features / (features.norm(dim=0, keepdim=True) + 1e-8)

    def draw_similarity(ax, matrix, title):
        # Similarity matrices share one diverging colour scale over [-1, 1].
        image = ax.imshow(matrix.detach().numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1)
        ax.set_title(title)
        plt.colorbar(image, ax=ax)

    def draw_heatmap(ax, tensor, title):
        # Raw feature values, channels on the y axis and points on the x axis.
        image = ax.imshow(tensor.detach().numpy(), cmap='viridis', aspect='auto')
        ax.set_title(title)
        ax.set_xlabel('Point Index')
        ax.set_ylabel('Channel')
        plt.colorbar(image, ax=ax)

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    (ax_sim_in, ax_sim_out, ax_heat_in), (ax_heat_out, ax_flow, ax_hist) = axes

    # Cosine similarity among the first 50 input points of sample 0.
    head = unit_columns(x[0])[:, :50]
    draw_similarity(ax_sim_in, head.T @ head, '输入特征相似度 (前50点)')

    # Cosine similarity among all output points of sample 0.
    unit_out = unit_columns(output[0])
    draw_similarity(ax_sim_out, unit_out.T @ unit_out, '输出特征相似度 (150点)')

    # Raw feature heatmaps restricted to the first 30 points.
    draw_heatmap(ax_heat_in, x[0, :, :30], '输入特征 (30点)')
    draw_heatmap(ax_heat_out, output[0, :, :30], '输出特征 (30点)')

    # Text-only panel outlining the ConvTranspose + AdaptiveMaxPool stages.
    ax_flow.set_title('Generator 内部变换', fontsize=12)
    flow_rows = [
        (0.8, 'ConvTranspose1d(k3,s3)',
         {'fontsize': 11, 'bbox': dict(boxstyle='round', facecolor='lightblue')}),
        (0.6, f'200 → 200*3 = 600', {'fontsize': 10}),
        (0.4, 'AdaptiveMaxPool1d(150)',
         {'fontsize': 11, 'bbox': dict(boxstyle='round', facecolor='lightgreen')}),
        (0.2, f'600 → 150', {'fontsize': 10}),
    ]
    for y_pos, caption, extra in flow_rows:
        ax_flow.text(0.5, y_pos, caption, transform=ax_flow.transAxes,
                     ha='center', **extra)
    ax_flow.axis('off')

    # Overlaid histograms of all feature values of sample 0, input vs. output.
    for tensor, label, colour in ((x[0], 'Input', 'steelblue'),
                                  (output[0], 'Output', 'coral')):
        ax_hist.hist(tensor.detach().numpy().flatten(), bins=50, alpha=0.5,
                     label=label, color=colour)
    ax_hist.set_title('特征值分布对比')
    ax_hist.legend()

    plt.suptitle('Generator: 变长特征→固定大小特征', fontsize=14, fontweight='bold')
    plt.tight_layout()
    save_path = os.path.join(OUTPUT_DIR, 'generator_demo.png')
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {save_path}')

    # Run the same model on several other input lengths and report the shapes.
    print('\nGenerator 对不同输入长度的适应:')
    for length in (50, 100, 200, 500):
        probe = torch.randn(1, 128, length)
        with torch.no_grad():
            probe_out = model(probe)
        print(f' N={length:4d} → 输出形状 {probe_out.shape}')
def _pairwise_cosine(feat_a, feat_b):
    """Return the (K, K) cosine-similarity matrix of two (C, K) feature maps.

    Entry ``[i, j]`` compares column ``i`` of ``feat_a`` with column ``j`` of
    ``feat_b``.
    """
    # (K, 1, C) against (1, K, C) broadcasts to (K, K, C); the channel axis is
    # then reduced. The previous (K, C, 1) vs (1, K, C) pairing does not
    # broadcast unless K == C and raised a RuntimeError.
    return torch.nn.functional.cosine_similarity(
        feat_a.T.unsqueeze(1), feat_b.T.unsqueeze(0), dim=-1
    )


def test_fusion_head():
    """Run FusionHead on four correlated synthetic sources and plot the result.

    Builds ``FusionHead(in_c=128)`` in eval mode, creates four seeded random
    feature tensors of shape (B, C, K) = (2, 128, 150) that share a common
    base, stacks them on a trailing source axis, and feeds them to the model.
    Prints the input and output shapes and saves ``fusion_head_demo.png`` under
    ``OUTPUT_DIR``: one point-by-point cosine-similarity map per source
    (against the original features of sample 0) plus one for the fused output.
    """
    print('\n--- FusionHead 融合头 ---')
    fusion_head = FusionHead(in_c=128)
    fusion_head.eval()

    # The four simulated sources, in stacking order:
    #   [0] fea_kpt_original     - original BEV keypoint features
    #   [1] fea_kpt_original_gen - BEV features produced by the Generator
    #   [2] fea_kpt_gen_gen      - output of the dual-path converter
    #   [3] fea_kpl_gen          - BEV -> image-space features
    B, C, K = 2, 128, 150
    torch.manual_seed(42)
    # Derive every source from one base tensor plus noise so the sources are
    # correlated without being identical.
    base = torch.randn(B, C, K)
    fea_original = base
    fea_gen = base + 0.3 * torch.randn(B, C, K)
    fea_gen_gen = fea_gen + 0.2 * torch.randn(B, C, K)
    fea_kpl_gen = base + 0.5 * torch.randn(B, C, K)

    # Stack on a new trailing axis so the tensor is (B, C, K, 4), which is the
    # layout the label printed below and the `[..., idx]` indexing rely on.
    # (dim=2 produced (B, C, 4, K) instead.)
    # NOTE(review): FusionHead is defined in `net`, not in this file; confirm
    # that its forward() expects (B, C, K, 4).
    fea_kpts = torch.stack(
        [fea_original, fea_gen, fea_gen_gen, fea_kpl_gen], dim=-1
    )
    print(f'输入: {fea_kpts.shape} [B, C, K, 4来源]')
    with torch.no_grad():
        fea_fused = fusion_head(fea_kpts)
    print(f'输出: {fea_fused.shape} [B, C, K] 融合特征')

    # Visualisation: the four sources fill the left 2x2 panels.
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    names = ['Original (BEV原始)', 'Generated (全景生成)',
             'Gen_Gen (双路径)', 'KPL_Gen (图像空间)']
    reference = fea_kpts[0, :, :, 0]  # (C, K) original features of sample 0
    for idx, name in enumerate(names):
        ax = axes[idx // 2, idx % 2]
        sim = _pairwise_cosine(reference, fea_kpts[0, :, :, idx])
        im = ax.imshow(sim.detach().numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1)
        ax.set_title(f'{name}\nvs Original 相似度')
        ax.set_xlabel('Point')
        ax.set_ylabel('Point')
        plt.colorbar(im, ax=ax)

    # The top-right panel has no content; hide its empty frame.
    axes[0, 2].axis('off')

    # Fused output vs. original features, bottom-right panel.
    # Assumes fea_fused[0] is (C, K), as the printed label states.
    ax = axes[1, 2]
    sim_fused = _pairwise_cosine(fea_original[0], fea_fused[0])
    im = ax.imshow(sim_fused.detach().numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1)
    ax.set_title('Fused vs Original 相似度')
    ax.set_xlabel('Point')
    ax.set_ylabel('Point')
    plt.colorbar(im, ax=ax)

    plt.suptitle('FusionHead: 多来源特征融合分析', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'fusion_head_demo.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f' [保存] {path}')
def visualize_attention_detail():
    """Plot one self-attention weight matrix from Attention plus a text outline.

    Runs ``Attention(d_model=128)`` in eval mode as self-attention on a seeded
    random tensor of shape (4, 3, 128), prints the shapes involved, and saves a
    two-panel figure as ``fusion_attention_detail.png`` under ``OUTPUT_DIR``:
    the weights of the first batch entry as an annotated heatmap, and a textual
    outline of the FusionHead attention steps.
    """
    print('\n--- FusionHead Attention 详细分析 ---')
    attention = Attention(d_model=128)
    attention.eval()

    # Simulated self-attention input with 3 feature slots per row.
    batch, n_pair, channels = 2, 3, 128
    torch.manual_seed(42)
    x = torch.randn(batch * 2, n_pair, channels)
    with torch.no_grad():
        output, weights = attention(x, x, x)
    print(f'Self-Attention 输入: {x.shape}')
    print(f'输出: {output.shape}')
    print(f'Attention权重: {weights.shape} (B, 3, 3)')

    fig, (ax_weights, ax_flow) = plt.subplots(1, 2, figsize=(14, 5))

    # Left panel: weights of the first batch entry, each cell annotated.
    # Presumably a 3x3 matrix, as the printed label says; only [0:3, 0:3] is
    # read below.
    first_weights = weights[0].detach().numpy()
    heatmap = ax_weights.imshow(first_weights, cmap='YlOrRd', vmin=0, vmax=1)
    ax_weights.set_title('Self-Attention 权重 (3对特征)')
    source_labels = ['Original', 'Generated', 'Gen_Gen']
    ax_weights.set_xticks(range(3))
    ax_weights.set_xticklabels(source_labels)
    ax_weights.set_yticks(range(3))
    ax_weights.set_yticklabels(source_labels)
    for row in range(3):
        for col in range(3):
            value = first_weights[row, col]
            # Switch the annotation to white once the weight exceeds 0.5.
            if value > 0.5:
                text_colour = 'white'
            else:
                text_colour = 'black'
            ax_weights.text(col, row, f'{value:.3f}', ha='center', va='center',
                            fontsize=12, color=text_colour)
    plt.colorbar(heatmap, ax=ax_weights)

    # Right panel: text-only outline of the attention steps.
    ax_flow.set_title('FusionHead Attention 流程', fontsize=12)
    steps = [
        '1. 拼接4种特征 [original, gen, gen_gen, kpl_gen]',
        '2. 取前3种 [original, gen, gen_gen]',
        '3. 对每个样本的3对特征做Self-Attention',
        '4. max聚合 → 每样本1个特征',
        '5. Cross-Attention with kpl_gen (图像空间特征)',
        '6. concat(original, cross_out) → Conv1d → 输出'
    ]
    box_style = dict(boxstyle='round', facecolor='lightyellow', alpha=0.7)
    for position, caption in enumerate(steps):
        # Lay the steps out top to bottom, 0.15 axes units apart.
        ax_flow.text(0.1, 0.9 - position * 0.15, caption,
                     transform=ax_flow.transAxes, fontsize=10,
                     family='monospace', bbox=box_style)
    ax_flow.axis('off')

    plt.suptitle('FusionHead Attention 机制详解', fontsize=14, fontweight='bold')
    plt.tight_layout()
    save_path = os.path.join(OUTPUT_DIR, 'fusion_attention_detail.png')
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {save_path}')
def analyze_parameters():
    """Print total and per-child parameter counts for Generator and FusionHead.

    Instantiates ``Generator(in_c=128, num=150)`` and ``FusionHead(in_c=128)``
    and, for each, prints the total number of parameters followed by the
    number held by every direct child module.
    """
    print('\n--- 参数量分析 ---')

    def count_params(module):
        # Total number of scalar entries across all parameters of `module`.
        return sum(tensor.numel() for tensor in module.parameters())

    models = (
        ('Generator', Generator(in_c=128, num=150)),
        ('FusionHead', FusionHead(in_c=128)),
    )
    for label, model in models:
        total = count_params(model)
        print(f'\n{label}: {total:,} params ({total / 1e3:.1f}K)')
        for child_name, child in model.named_children():
            child_total = count_params(child)
            print(f' {child_name:15s}: {child_total:>10,} params')
def main():
    """Run every demo stage in order, then print a structural summary.

    Prints a banner, runs the parameter analysis and the three visual demos,
    prints an ASCII summary of both modules, and reports where the figures
    were saved.
    """
    rule = '=' * 60
    # The summary is kept flush-left so the printed text has no extra indent.
    summary = """
Generator (全景特征生成器):
┌──────────────────────────────────────────────┐
│ 输入: (B, 128, N) N可变 │
│ ↓ Self-Attention (MHA) │
│ x2: (B, 128, N) 全局上下文特征 │
│ ↓ ConvTranspose1d(k3,s3) │
│ x3: (B, 128, N*3) 上采样扩展 │
│ ↓ AdaptiveMaxPool1d(150) │
│ 输出: (B, 128, 150) 固定K个全景特征 │
└──────────────────────────────────────────────┘
作用: 将BEV中可变数量的匹配点特征压缩为固定150个
与BEV关键点数量对齐
FusionHead (跨模态融合头):
┌──────────────────────────────────────────────┐
│ 输入: (B, 128, 150, 4) │
│ [original, gen, gen_gen, kpl_gen] │
│ ↓ │
│ 对前3对 (B*N, 3, C): │
│ Self-Attn → max(dim=1) → (B*N, C) │
│ ↓ reshape → (B, N, C) │
│ Cross-Attention with kpl_gen │
│ ↓ │
│ concat(original, cross_out) → Conv1d(256→128) │
│ 输出: (B, 128, 150) 融合特征 │
└──────────────────────────────────────────────┘
作用: 整合多来源特征,增强融合表示
"""

    print(rule)
    print('Generator & FusionHead 结构与功能可视化')
    print(rule)

    # Stages run in this fixed order.
    stages = (
        analyze_parameters,
        test_generator,
        test_fusion_head,
        visualize_attention_detail,
    )
    for stage in stages:
        stage()

    print('\n' + rule)
    print('结构总结:')
    print(rule)
    print(summary)
    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')


# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()