# 517 lines · 18 KiB · Python (header copied from the file viewer)
"""
|
||
完整流水线 Demo: 端到端网络结构可视化
|
||
=====================================
|
||
集成所有子网络,展示从输入到输出的完整数据流。
|
||
|
||
运行模式:
|
||
python 08_full_pipeline_demo.py --mode bev # 仅BEV分支
|
||
python 08_full_pipeline_demo.py --mode img # 仅图像分支
|
||
python 08_full_pipeline_demo.py --mode fusion # 完整融合模式
|
||
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib
|
||
matplotlib.use('Agg')
|
||
|
||
import sys
|
||
import os
|
||
import argparse
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from net import Fusion, BEVHead, ImgHead, FusionHead
|
||
from BEVNet import RICNN
|
||
from ALIKE.alnet import ALNet
|
||
from netvlad import NetVLAD
|
||
from uot import UOTHead
|
||
|
||
# Every figure written by this script lands in an ``output`` folder next to the
# script file. The folder is created at import time so later savefig calls
# never hit a missing directory.
_SCRIPT_DIR = os.path.dirname(__file__)
OUTPUT_DIR = os.path.join(_SCRIPT_DIR, 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||
|
||
|
||
def create_dummy_batch_dict(mode='fusion'):
    """Build a synthetic ``batch_dict`` of random tensors for one forward pass.

    The batch stacks ``2 * B`` samples with ``B = 2`` pairs: the first ``B``
    samples play the query role and the last ``B`` the positive role.

    Args:
        mode: ``'bev'``, ``'img'`` or ``'fusion'``. ``'bev'`` and ``'fusion'``
            add ``'bev'``; ``'img'`` and ``'fusion'`` add ``'img'``; only
            ``'fusion'`` adds ``'relation'`` plus the pose/label/id/sequence
            entries. Any other value yields a dict holding only ``'batch_size'``.

    Returns:
        dict with (depending on ``mode``):
          * ``'batch_size'``: int, ``2 * B``.
          * ``'bev'``: float ``(2B, 7, 320, 320)``. Channels 0-1 go through a
            sigmoid, channel 2 becomes a {0, 1} mask (its sigmoid value > 0.3),
            and channels 3-6 stay standard-normal.
          * ``'img'``: float ``(2B, 5, 192, 576)`` holding integers in [0, 255].
          * ``'relation'``: long ``(2B, 200, 11, 2)``. In the first 150 rows,
            slots 0-9 hold random pairs with component 0 in [0, 576) and
            component 1 in [0, 192) (presumably image x/y). Slot 10 holds a
            random pair in [0, 320). All other rows are zero.
          * ``'pose_to_frame'``: ``(B, 4, 4)``, a 0.3 rad rotation in the x-y
            plane plus translation (2.0, -1.0, 0).
          * ``'pose_query'`` / ``'pose_positive'``: ``(B, 4, 4)`` identities.
          * ``'label_score'``: zeros ``(B, 320, 320, 2)``.
          * ``'id_query'`` / ``'id_positive'``: ``arange(B)``.
          * ``'sequence'``: long zeros ``(B,)``.
    """
    num_pairs = 2                # B: number of (query, positive) pairs
    num_samples = 2 * num_pairs  # queries stacked before positives along dim 0
    batch = {'batch_size': num_samples}

    if mode in ('fusion', 'bev'):
        bev = torch.randn(num_samples, 7, 320, 320)
        bev[:, :3] = torch.sigmoid(bev[:, :3])
        bev[:, 2:3] = (bev[:, 2:3] > 0.3).float()  # binarise channel 2 into a mask
        batch['bev'] = bev

    if mode in ('fusion', 'img'):
        batch['img'] = torch.randint(0, 256, (num_samples, 5, 192, 576)).float()

    if mode != 'fusion':
        return batch

    max_len, n_slots, n_valid = 200, 11, 150
    bev_slot = n_slots - 1  # slots [0, bev_slot) are image pixels; the last slot is the BEV coordinate
    relation = torch.zeros(num_samples, max_len, n_slots, 2, dtype=torch.long)
    batch['relation'] = relation
    for sample_idx in range(num_samples):
        rel = relation[sample_idx]  # basic indexing -> view; writes go into `relation`
        rel[:n_valid, :bev_slot, 0] = torch.randint(0, 576, (n_valid, bev_slot))
        rel[:n_valid, :bev_slot, 1] = torch.randint(0, 192, (n_valid, bev_slot))
        rel[:n_valid, bev_slot, 0] = torch.randint(0, 320, (n_valid,))
        rel[:n_valid, bev_slot, 1] = torch.randint(0, 320, (n_valid,))

    # Ground-truth transform: rotation of 0.3 rad in the x-y plane plus translation.
    angle = torch.tensor(0.3)
    cos_a, sin_a = torch.cos(angle), torch.sin(angle)
    pose = torch.eye(4).unsqueeze(0).repeat(num_pairs, 1, 1)
    pose[:, 0, 0] = cos_a
    pose[:, 0, 1] = -sin_a
    pose[:, 1, 0] = sin_a
    pose[:, 1, 1] = cos_a
    pose[:, :2, 3] = torch.tensor([2.0, -1.0])

    batch.update({
        'pose_to_frame': pose.clone(),
        'pose_query': torch.eye(4).unsqueeze(0).repeat(num_pairs, 1, 1),
        'pose_positive': torch.eye(4).unsqueeze(0).repeat(num_pairs, 1, 1),
        'label_score': torch.zeros(num_pairs, 320, 320, 2),
        'id_query': torch.arange(num_pairs),
        'id_positive': torch.arange(num_pairs),
        'sequence': torch.zeros(num_pairs, dtype=torch.long),
    })
    return batch
|
||
|
||
|
||
def run_bev_only():
    """Run the BEV-only (point-cloud branch) model on a dummy batch and plot it.

    Builds ``Fusion`` with ``flag='bev'`` and runs one forward pass under
    ``torch.no_grad()`` on ``create_dummy_batch_dict('bev')``. It prints the
    shape of every output entry and saves a 2x4 figure to
    ``OUTPUT_DIR/full_pipeline_bev.png``. Panels whose output key is missing
    are left blank.
    """
    print('\n' + '=' * 60)
    print('模式: BEV Only (仅点云分支)')
    print('=' * 60)

    cfg = {
        'flag': 'bev',
        'kpts_number_bev': 150,
        'kpts_number_img': 150,
        'cluster_num_bev': 16,
        'cluster_num_img': 16,
        'cluster_num_fusion': 16,
        'sinkhorn_iter': 5,
        'vlad_size': 256,
    }

    model = Fusion(cfg)
    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f'模型参数量: {total_params:,} ({total_params / 1e6:.2f}M)')

    batch_dict = create_dummy_batch_dict('bev')

    with torch.no_grad():
        output = model(batch_dict)

    print('\n输出:')
    for k, v in output.items():
        if isinstance(v, torch.Tensor):
            print(f' {k:30s}: {list(v.shape)}')
        else:
            print(f' {k:30s}: {v}')

    fig, axes = plt.subplots(2, 4, figsize=(18, 9))

    # BEV input: the first three channels of sample 0 shown as an RGB image.
    bev_rgb = batch_dict['bev'][0, :3].permute(1, 2, 0).numpy()
    axes[0, 0].imshow(bev_rgb)
    axes[0, 0].set_title('BEV输入 (3通道)')
    axes[0, 0].axis('off')

    # Score map
    if 'score_bev' in output:
        score = output['score_bev'][0]
        if score.dim() == 3:
            # A (1, H, W) per-sample map must be reduced to (H, W) for imshow.
            score = score[0]
        axes[0, 1].imshow(score.numpy(), cmap='hot')
        axes[0, 1].set_title('BEV Score Map')
        axes[0, 1].axis('off')

    # Keypoint locations
    if 'key_points' in output and 'pixels_kpt' in output:
        kpt = output['pixels_kpt'][0].numpy()
        axes[0, 2].imshow(bev_rgb)
        # NOTE(review): column 0 is plotted as the row (y), column 1 as x — confirm against BEVHead.
        axes[0, 2].scatter(kpt[:, 1], kpt[:, 0], c='red', s=5, alpha=0.8)
        axes[0, 2].set_title(f'BEV Top-{len(kpt)} 关键点')
        axes[0, 2].axis('off')

    # Dense descriptor map, first channel only
    if 'fea_bev' in output:
        axes[0, 3].imshow(output['fea_bev'][0, 0].numpy(), cmap='viridis')
        axes[0, 3].set_title('BEV Descriptor ch0')
        axes[0, 3].axis('off')

    # Keypoint-to-keypoint similarity between the first query and its positive
    if 'fea_kpt_original' in output:
        fea = output['fea_kpt_original']
        half = fea.shape[0] // 2  # samples [0, half) are queries, [half, 2*half) positives
        # Assumes fea is (2B, C, N), as in the flow panel ("fea_kpt (128,150)") —
        # TODO confirm. Normalising over C and contracting C gives an
        # (N_query, N_positive) cosine map rather than a C x C channel map.
        query = torch.nn.functional.normalize(fea[:half], dim=1)
        positive = torch.nn.functional.normalize(fea[half:], dim=1)
        sim = torch.bmm(query.transpose(1, 2), positive)[0]
        im = axes[1, 0].imshow(sim.numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1)
        axes[1, 0].set_title('Query-Positive 特征相似度')
        axes[1, 0].set_xlabel('Positive')
        axes[1, 0].set_ylabel('Query')
        fig.colorbar(im, ax=axes[1, 0])

    # VLAD descriptor reshaped to (clusters, feature dim)
    if 'vlads' in output:
        vlad = output['vlads'][0].reshape(cfg['cluster_num_bev'], -1).numpy()
        im = axes[1, 1].imshow(vlad, cmap='RdBu_r', aspect='auto')
        axes[1, 1].set_title(f'VLAD描述子 ({vlad.shape[0]}×{vlad.shape[1]})')
        axes[1, 1].set_xlabel('Feature Dim')
        axes[1, 1].set_ylabel('Cluster')
        fig.colorbar(im, ax=axes[1, 1])

    # Data-flow text panel
    axes[1, 2].set_title('BEV分支数据流')
    flow = [
        'bev (7,320,320)',
        '→ x = bev[:3] (可视BEV)',
        '→ points = bev[3:7] (坐标)',
        '→ RICNN前向',
        '→ score_bev (1,320,320)',
        '→ fea_bev (128,320,320)',
        '→ NMS + Top-K(150)',
        '→ key_points (150,4)',
        '→ fea_kpt (128,150)',
        '→ EncodePosition',
        '→ NetVLAD → vlad_bev (2048)',
    ]
    # Shrink the line spacing so every line stays inside the axes.
    step = min(0.1, 0.95 / len(flow))
    for i, line in enumerate(flow):
        axes[1, 2].text(0.1, 0.95 - i * step, line, transform=axes[1, 2].transAxes,
                        fontsize=9, family='monospace')
    axes[1, 2].axis('off')

    # Parameter share of each direct child of the BEV feature extractor
    axes[1, 3].set_title('BEV分支参数分布')
    sizes = []
    labels = []
    for name, mod in model.bev.feature_extractor.named_children():
        n_params = sum(pm.numel() for pm in mod.parameters())
        if n_params > 0:
            sizes.append(n_params)
            labels.append(f'{name}\n({n_params/1e3:.0f}K)')
    axes[1, 3].pie(sizes, labels=labels, autopct='%1.1f%%', textprops={'fontsize': 8})

    fig.suptitle('BEV Only 模式: 点云分支可视化', fontsize=14, fontweight='bold')
    fig.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'full_pipeline_bev.png')
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f'[保存] {path}')
|
||
|
||
|
||
def run_img_only():
    """Run the image-only branch model on a dummy batch and plot it.

    Builds ``Fusion`` with ``flag='img'`` and runs one forward pass under
    ``torch.no_grad()`` on ``create_dummy_batch_dict('img')``. It prints the
    shape of every output entry and saves a 2x4 figure to
    ``OUTPUT_DIR/full_pipeline_img.png``. Panels whose output key is missing
    are left blank.
    """
    print('\n' + '=' * 60)
    print('模式: Image Only (仅图像分支)')
    print('=' * 60)

    cfg = {
        'flag': 'img',
        'kpts_number_bev': 150,
        'kpts_number_img': 150,
        'cluster_num_bev': 16,
        'cluster_num_img': 16,
        'cluster_num_fusion': 16,
        'sinkhorn_iter': 5,
        'vlad_size': 256,
    }

    model = Fusion(cfg)
    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f'模型参数量: {total_params:,} ({total_params / 1e6:.2f}M)')

    batch_dict = create_dummy_batch_dict('img')

    with torch.no_grad():
        output = model(batch_dict)

    print('\n输出:')
    for k, v in output.items():
        if isinstance(v, torch.Tensor):
            print(f' {k:30s}: {list(v.shape)}')
        else:
            print(f' {k:30s}: {v}')

    fig, axes = plt.subplots(2, 4, figsize=(18, 9))

    # Input image: first three channels of sample 0 (values are already 0-255).
    img_rgb = batch_dict['img'][0, :3].permute(1, 2, 0).numpy().astype(np.uint8)
    axes[0, 0].imshow(img_rgb)
    axes[0, 0].set_title('图像输入 (192×576)')
    axes[0, 0].axis('off')

    # Score map
    if 'score_img' in output:
        score = output['score_img'][0]
        if score.dim() == 3:
            # (C, H, W) per sample -> channel 0, so imshow receives (H, W).
            score = score[0]
        axes[0, 1].imshow(score.numpy(), cmap='hot')
        axes[0, 1].set_title('图像 Score Map')
        axes[0, 1].axis('off')

    # Keypoints
    if 'key_pixels' in output:
        kpt = output['key_pixels'][0].numpy()
        axes[0, 2].imshow(img_rgb)
        # NOTE(review): column 0 is plotted as the row (y), column 1 as x — confirm against ImgHead.
        axes[0, 2].scatter(kpt[:, 1], kpt[:, 0], c='red', s=5, alpha=0.8)
        axes[0, 2].set_title(f'Top-{len(kpt)} 关键点')
        axes[0, 2].axis('off')

    # Dense descriptor map, first channel only
    if 'fea_img' in output:
        axes[0, 3].imshow(output['fea_img'][0, 0].numpy(), cmap='viridis')
        axes[0, 3].set_title('图像 Descriptor ch0')
        axes[0, 3].axis('off')

    # Keypoint-to-keypoint similarity between the first query and its positive
    if 'fea_kpl' in output:
        fea = output['fea_kpl']
        half = fea.shape[0] // 2  # samples [0, half) are queries, [half, 2*half) positives
        # Assumes fea is (2B, C, N), as in the flow panel ("fea_kpl (128,150)") —
        # TODO confirm. Normalising over C and contracting C gives an
        # (N_query, N_positive) cosine map rather than a C x C channel map.
        query = torch.nn.functional.normalize(fea[:half], dim=1)
        positive = torch.nn.functional.normalize(fea[half:], dim=1)
        sim = torch.bmm(query.transpose(1, 2), positive)[0]
        im = axes[1, 0].imshow(sim.numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1)
        axes[1, 0].set_title('Query-Positive 特征相似度')
        axes[1, 0].set_xlabel('Positive')
        axes[1, 0].set_ylabel('Query')
        fig.colorbar(im, ax=axes[1, 0])

    # Data-flow text panel
    axes[1, 1].set_title('图像分支数据流')
    flow = [
        'img (5,192,576)',
        '→ x = img[:3]/255',
        '→ ALNet前向',
        '→ score_img (1,192,576)',
        '→ fea_img (128,192,576)',
        '→ NMS(2) + Top-K(150)',
        '→ key_pixels (150,2)',
        '→ fea_kpl (128,150)',
    ]
    # Shrink the line spacing if needed so every line stays inside the axes.
    step = min(0.11, 0.95 / len(flow))
    for i, line in enumerate(flow):
        axes[1, 1].text(0.1, 0.95 - i * step, line, transform=axes[1, 1].transAxes,
                        fontsize=9, family='monospace')
    axes[1, 1].axis('off')

    # Parameter share of each direct child of the image feature extractor
    axes[1, 2].set_title('图像分支参数分布')
    sizes = []
    labels = []
    for name, mod in model.img.feature_extractor.named_children():
        n_params = sum(pm.numel() for pm in mod.parameters())
        if n_params > 0:
            sizes.append(n_params)
            labels.append(f'{name}\n({n_params/1e3:.0f}K)')
    axes[1, 2].pie(sizes, labels=labels, autopct='%1.1f%%', textprops={'fontsize': 8})

    axes[1, 3].axis('off')

    fig.suptitle('Image Only 模式: 图像分支可视化', fontsize=14, fontweight='bold')
    fig.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'full_pipeline_img.png')
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f'[保存] {path}')
|
||
|
||
|
||
def run_fusion():
    """Run the full cross-modal fusion model on a dummy batch and plot it.

    Builds ``Fusion`` with ``flag='fusion'`` and runs one forward pass under
    ``torch.no_grad()`` on ``create_dummy_batch_dict('fusion')``. It prints the
    shape of every output entry and saves a 3x4 overview figure to
    ``OUTPUT_DIR/full_pipeline_fusion.png``. Panels whose output key is missing
    are left blank.
    """
    print('\n' + '=' * 60)
    print('模式: Fusion (完整融合)')
    print('=' * 60)

    cfg = {
        'flag': 'fusion',
        'kpts_number_bev': 150,
        'kpts_number_img': 150,
        'cluster_num_bev': 16,
        'cluster_num_img': 16,
        'cluster_num_fusion': 16,
        'sinkhorn_iter': 5,
        'vlad_size': 256,
    }

    model = Fusion(cfg)
    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f'模型参数量: {total_params:,} ({total_params / 1e6:.2f}M)')

    batch_dict = create_dummy_batch_dict('fusion')

    with torch.no_grad():
        output = model(batch_dict)

    print('\n输出:')
    for k, v in output.items():
        if isinstance(v, torch.Tensor):
            print(f' {k:30s}: {list(v.shape)}')
        else:
            print(f' {k:30s}: {v}')

    fig, axes = plt.subplots(3, 4, figsize=(22, 15))

    # ---- Row 0: raw inputs and the two score maps -------------------------
    bev_rgb = batch_dict['bev'][0, :3].permute(1, 2, 0).numpy()
    axes[0, 0].imshow(bev_rgb)
    axes[0, 0].set_title('BEV 输入 (320×320)')
    axes[0, 0].axis('off')

    img_rgb = batch_dict['img'][0, :3].permute(1, 2, 0).numpy().astype(np.uint8)
    axes[0, 1].imshow(img_rgb)
    axes[0, 1].set_title('图像输入 (192×576)')
    axes[0, 1].axis('off')

    for col, key, title in ((2, 'score_bev', 'BEV Score'), (3, 'score_img', 'Image Score')):
        if key not in output:
            continue
        score = output[key][0]
        if score.dim() == 3:
            # A (1, H, W) per-sample map must be reduced to (H, W) for imshow.
            score = score[0]
        axes[0, col].imshow(score.numpy(), cmap='hot')
        axes[0, col].set_title(title)
        axes[0, col].axis('off')

    # ---- Row 1: keypoint similarity before/after fusion, VLAD ---------------
    if 'fea_kpt_original' in output and 'fea_kpt_fusion' in output:
        half = output['fea_kpt_original'].shape[0] // 2  # queries first, positives second

        def pair_similarity(fea):
            """Cosine similarity of every query keypoint vs. every positive keypoint (pair 0).

            Assumes ``fea`` is (2B, C, N) with channels before keypoints —
            TODO confirm against BEVHead/FusionHead. Returns an (N, N) array.
            """
            query = torch.nn.functional.normalize(fea[:half], dim=1)
            positive = torch.nn.functional.normalize(fea[half:], dim=1)
            return torch.bmm(query.transpose(1, 2), positive)[0].numpy()

        sim_orig = pair_similarity(output['fea_kpt_original'])
        sim_fusion = pair_similarity(output['fea_kpt_fusion'])

        im1 = axes[1, 0].imshow(sim_orig, cmap='RdYlBu_r', vmin=-1, vmax=1)
        axes[1, 0].set_title(f'原始特征 相似度 ({sim_orig.shape[0]}×{sim_orig.shape[1]})')
        fig.colorbar(im1, ax=axes[1, 0])

        im2 = axes[1, 1].imshow(sim_fusion, cmap='RdYlBu_r', vmin=-1, vmax=1)
        axes[1, 1].set_title(f'融合特征 相似度 ({sim_fusion.shape[0]}×{sim_fusion.shape[1]})')
        fig.colorbar(im2, ax=axes[1, 1])

        # The difference panel needs a colorbar for its own image: it shows an
        # auto-scaled |difference|, not the fixed [-1, 1] similarity range.
        im3 = axes[1, 2].imshow(np.abs(sim_orig - sim_fusion), cmap='YlOrRd')
        axes[1, 2].set_title('相似度变化 |差异|')
        fig.colorbar(im3, ax=axes[1, 2])

    if 'vlads' in output:
        vlad = output['vlads'][0].reshape(cfg['cluster_num_fusion'], -1).numpy()
        im = axes[1, 3].imshow(vlad, cmap='RdBu_r', aspect='auto')
        axes[1, 3].set_title(f'VLAD 融合 ({vlad.shape[0]}×{vlad.shape[1]})')
        fig.colorbar(im, ax=axes[1, 3])

    # ---- Row 2: architecture sketch, parameter counts, data flow ------------
    axes[2, 0].set_title('完整架构')
    box_width = 23  # characters between the two corner glyphs of each box
    arch = []
    for head, rows in (
        ('BEVHead', ('RICNN + EncodePos', '→ fea_kpt_original', '→ vlad_bev')),
        ('ImgHead', ('ALNet + NMS', '→ fea_kpl', '→ fea_img')),
        ('FusionHead', ('LocalPool + Converter', 'Generator + FusionHead', '→ fea_kpt_fusion')),
    ):
        label = f'─ {head} '
        arch.append('┌' + label + '─' * (box_width - len(label)) + '┐')
        arch.extend(f'│ {row:<{box_width - 1}}│' for row in rows)
        arch.append('└' + '─' * box_width + '┘')
    arch.append(' VLAD = w·vlad_fusion + (1-w)·vlad_bev')
    # Shrink the line spacing so every line stays inside the axes.
    arch_step = min(0.075, 0.96 / len(arch))
    for i, line in enumerate(arch):
        axes[2, 0].text(0.05, 0.98 - i * arch_step, line, transform=axes[2, 0].transAxes,
                        fontsize=7.5, family='monospace')
    axes[2, 0].axis('off')

    # Parameter count of each direct child module of the model
    axes[2, 1].set_title('各模块参数量')
    module_names = []
    module_params = []
    for name, mod in model.named_children():
        n_params = sum(pm.numel() for pm in mod.parameters())
        if n_params > 0:
            module_names.append(name)
            module_params.append(n_params)
    colors = plt.cm.Set3(np.linspace(0, 1, len(module_names)))
    positions = range(len(module_names))
    axes[2, 1].barh(positions, module_params, color=colors)
    axes[2, 1].set_yticks(positions)
    axes[2, 1].set_yticklabels(module_names, fontsize=8)
    for i, n_params in enumerate(module_params):
        axes[2, 1].text(n_params, i, f' {n_params/1e3:.0f}K', va='center', fontsize=8)

    axes[2, 2].set_title('融合模式数据流')
    flow = [
        'img, bev, relation 输入',
        '├─ ImgHead → ALNet',
        '│ ├─ score_img',
        '│ ├─ fea_img (密集描述子)',
        '│ └─ fea_kpl (关键点)',
        '├─ BEVHead → RICNN',
        '│ ├─ score_bev',
        '│ ├─ fea_bev (密集描述子)',
        '│ ├─ fea_kpt_original',
        '│ └─ vlad_bev',
        '└─ FusionHead',
        ' ├─ GridSample → fea_pl_dual, fea_pt_dual',
        ' ├─ Converters → 跨模态转换',
        ' ├─ Generator → 全景特征',
        ' ├─ FusionHead → 融合特征',
        ' └─ NetVLAD → vlad_fusion',
        '最终: vlads = w·vlad_fusion + (1-w)·vlad_bev',
        ' UOT: → transformation (位姿)',
    ]
    flow_step = min(0.06, 0.96 / len(flow))
    for i, line in enumerate(flow):
        axes[2, 2].text(0.05, 0.98 - i * flow_step, line, transform=axes[2, 2].transAxes,
                        fontsize=7.5, family='monospace')
    axes[2, 2].axis('off')

    axes[2, 3].axis('off')

    fig.suptitle('Fusion 模式: 完整跨模态融合可视化', fontsize=14, fontweight='bold')
    fig.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'full_pipeline_fusion.png')
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f'[保存] {path}')
|
||
|
||
|
||
def main():
    """Parse ``--mode`` and run the matching demo(s).

    ``--mode all`` (the default) runs the BEV, image and fusion demos in that
    order. Afterwards the output directory is printed.
    """
    parser = argparse.ArgumentParser(description='全流水线可视化')
    parser.add_argument('--mode', type=str, default='all',
                        choices=['all', 'bev', 'img', 'fusion'],
                        help='运行模式')
    selected = parser.parse_args().mode

    # Ordered (mode name, runner) table; 'all' triggers every entry.
    runners = (
        ('bev', run_bev_only),
        ('img', run_img_only),
        ('fusion', run_fusion),
    )
    for mode_name, runner in runners:
        if selected in ('all', mode_name):
            runner()

    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')


if __name__ == '__main__':
    main()
|