# Path: fusion_LCD/network_learning/08_full_pipeline_demo.py
# Snapshot: 2026-05-09 17:03:40 +08:00 (517 lines, 18 KiB, Python)
# Note from the hosting page: this file contains ambiguous Unicode characters
# that might be confused with other characters; if that is intentional, the
# warning can be ignored.
"""
完整流水线 Demo: 端到端网络结构可视化
=====================================
集成所有子网络,展示从输入到输出的完整数据流。
运行模式:
python 08_full_pipeline_demo.py --mode bev # 仅BEV分支
python 08_full_pipeline_demo.py --mode img # 仅图像分支
python 08_full_pipeline_demo.py --mode fusion # 完整融合模式
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import sys
import os
import argparse
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from net import Fusion, BEVHead, ImgHead, FusionHead
from BEVNet import RICNN
from ALIKE.alnet import ALNet
from netvlad import NetVLAD
from uot import UOTHead
# Directory that receives every figure written by this demo. It is resolved
# from the absolute script path so it does not depend on the current working
# directory, and it is created (if missing) when the module is loaded.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)
def create_dummy_batch_dict(mode='fusion', *, num_pairs=2):
    """Create a synthetic ``batch_dict`` for driving ``Fusion`` without a dataset.

    The batch holds ``2 * num_pairs`` samples; the demos in this file treat
    the first half as queries and the second half as their positives.

    Args:
        mode: ``'fusion'`` (BEV + image + relation), ``'bev'`` or ``'img'``.
        num_pairs: Number of query/positive pairs (keyword-only, default 2,
            i.e. a batch of 4 samples).

    Returns:
        dict with ``'batch_size'``, the mode-specific sensor tensors, and
        pose / label / id entries whose leading dimension is ``num_pairs``.

    Raises:
        ValueError: if ``mode`` is not one of the supported values or
            ``num_pairs < 1``.
    """
    if mode not in ('fusion', 'bev', 'img'):
        raise ValueError(f"mode must be 'fusion', 'bev' or 'img', got {mode!r}")
    if num_pairs < 1:
        raise ValueError(f'num_pairs must be >= 1, got {num_pairs}')

    n_samples = 2 * num_pairs
    batch_dict = {'batch_size': n_samples}

    if mode in ('fusion', 'bev'):
        bev = torch.randn(n_samples, 7, 320, 320)
        # Channels 0-2 are the displayable BEV image, squashed into (0, 1) ...
        bev[:, :3] = torch.sigmoid(bev[:, :3])
        # ... after which channel 2 is replaced by a binary "guider" mask.
        bev[:, 2:3] = (bev[:, 2:3] > 0.3).float()
        batch_dict['bev'] = bev

    if mode in ('fusion', 'img'):
        # Integer-valued intensities in [0, 255], stored as float.
        batch_dict['img'] = torch.randint(0, 256, (n_samples, 5, 192, 576)).float()

    if mode == 'fusion':
        # relation: (2B, max_len, K, 2) long. For the first n_valid rows,
        # slots 0..K-2 hold coordinates whose index 0 spans the image width
        # (576) and index 1 its height (192); slot K-1 holds a coordinate in
        # the 320x320 BEV grid. Rows n_valid..max_len-1 are left as zeros.
        max_len, K, n_valid = 200, 11, 150
        relation = torch.zeros(n_samples, max_len, K, 2, dtype=torch.long)
        for i in range(n_samples):
            relation[i, :n_valid, :K - 1, 0] = torch.randint(0, 576, (n_valid, K - 1))
            relation[i, :n_valid, :K - 1, 1] = torch.randint(0, 192, (n_valid, K - 1))
            relation[i, :n_valid, K - 1, 0] = torch.randint(0, 320, (n_valid,))
            relation[i, :n_valid, K - 1, 1] = torch.randint(0, 320, (n_valid,))
        batch_dict['relation'] = relation

    # Per-pair pose / label / id entries (the original comment marks these as
    # needed for training). pose_to_frame is a 4x4 homogeneous transform made
    # of a 0.3 rad rotation in the plane of axes 0/1 plus a (2.0, -1.0, 0)
    # translation.
    angle = torch.tensor(0.3)
    pose = torch.eye(4).unsqueeze(0).repeat(num_pairs, 1, 1)
    pose[:, 0, 0] = torch.cos(angle)
    pose[:, 0, 1] = -torch.sin(angle)
    pose[:, 1, 0] = torch.sin(angle)
    pose[:, 1, 1] = torch.cos(angle)
    pose[:, 0, 3] = 2.0
    pose[:, 1, 3] = -1.0
    batch_dict['pose_to_frame'] = pose
    batch_dict['pose_query'] = torch.eye(4).unsqueeze(0).repeat(num_pairs, 1, 1)
    batch_dict['pose_positive'] = torch.eye(4).unsqueeze(0).repeat(num_pairs, 1, 1)
    batch_dict['label_score'] = torch.zeros(num_pairs, 320, 320, 2)
    batch_dict['id_query'] = torch.arange(num_pairs)
    batch_dict['id_positive'] = torch.arange(num_pairs)
    batch_dict['sequence'] = torch.zeros(num_pairs, dtype=torch.long)
    return batch_dict
def _demo_cfg(flag):
    """Return the ``Fusion`` configuration shared by all three demo modes.

    Only ``flag`` (``'bev'``, ``'img'`` or ``'fusion'``) differs between the
    modes; every other entry is identical.
    """
    return {
        'flag': flag,
        'kpts_number_bev': 150,
        'kpts_number_img': 150,
        'cluster_num_bev': 16,
        'cluster_num_img': 16,
        'cluster_num_fusion': 16,
        'sinkhorn_iter': 5,
        'vlad_size': 256,
    }


def _build_and_run(cfg, banner):
    """Build ``Fusion(cfg)`` in eval mode and run one no-grad forward pass on dummy data.

    Prints ``banner`` between separator lines, the total parameter count and
    the shape (or value) of every model output.

    Returns:
        ``(model, batch_dict, output)``.
    """
    print('\n' + '=' * 60)
    print(banner)
    print('=' * 60)
    model = Fusion(cfg)
    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f'模型参数量: {total_params:,} ({total_params / 1e6:.2f}M)')
    batch_dict = create_dummy_batch_dict(cfg['flag'])
    with torch.no_grad():
        output = model(batch_dict)
    print('\n输出:')
    for k, v in output.items():
        if isinstance(v, torch.Tensor):
            print(f' {k:30s}: {list(v.shape)}')
        else:
            print(f' {k:30s}: {v}')
    return model, batch_dict, output


def _to_2d(t):
    """Return tensor ``t`` as a numpy array with leading singleton dims dropped.

    Lets ``imshow`` accept both (H, W) and (1, H, W) maps.
    """
    arr = t.detach().cpu().numpy()
    while arr.ndim > 2 and arr.shape[0] == 1:
        arr = arr[0]
    return arr


def _query_positive_similarity(fea):
    """Keypoint-to-keypoint cosine similarity for the first query/positive pair.

    Assumes ``fea`` is (2B, C, N) or (2B, C, N, 1) -- descriptor channels on
    dim 1, keypoints on dim 2, as in the ``fea_kpt (128,150)`` flow labels --
    with samples 0..B-1 the queries and B..2B-1 their positives.

    Returns:
        (N, N) numpy array; entry [i, j] is the similarity between query
        keypoint i and positive keypoint j.
    """
    fea = fea.detach().cpu().float()
    fea = fea.reshape(fea.shape[0], fea.shape[1], -1)  # fold a trailing singleton dim
    half = fea.shape[0] // 2
    query = torch.nn.functional.normalize(fea[0], dim=0)        # (C, N), unit columns
    positive = torch.nn.functional.normalize(fea[half], dim=0)  # (C, N), unit columns
    return (query.t() @ positive).numpy()


def _vlad_grid(vlad, num_clusters):
    """Reshape one flattened VLAD vector to (num_clusters, D) for display.

    Falls back to a single row when the length is not divisible by
    ``num_clusters`` (e.g. if the model projects VLAD to another size).
    """
    flat = vlad.detach().cpu().reshape(-1)
    if num_clusters > 0 and flat.numel() % num_clusters == 0:
        return flat.reshape(num_clusters, -1).numpy()
    return flat.reshape(1, -1).numpy()


def _ascii_box(title, rows, inner):
    """Return the lines of a titled box whose content area is ``inner`` characters wide.

    Every returned line has the same length (``inner + 4``); ``inner`` must be
    at least ``len(title) + 1`` and at least the longest row.
    """
    head = f'─ {title} '
    lines = ['┌' + head + '─' * (inner + 2 - len(head)) + '┐']
    lines += ['│ ' + row.ljust(inner) + ' │' for row in rows]
    lines.append('└' + '─' * (inner + 2) + '┘')
    return lines


def _show_map(ax, image, title, cmap=None):
    """``imshow`` ``image`` on ``ax`` with ``title`` and hidden axis decorations."""
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')


def _plot_similarity(ax, sim, title):
    """Draw a query-vs-positive similarity matrix on a fixed [-1, 1] scale with a colorbar."""
    im = ax.imshow(sim, cmap='RdYlBu_r', vmin=-1, vmax=1)
    ax.set_title(title)
    ax.set_xlabel('Positive')
    ax.set_ylabel('Query')
    plt.colorbar(im, ax=ax)


def _plot_vlad(ax, vlad, num_clusters, label):
    """Draw one VLAD vector as a cluster x feature heat map titled with its shape."""
    grid = _vlad_grid(vlad, num_clusters)
    im = ax.imshow(grid, cmap='RdBu_r', aspect='auto')
    ax.set_title(f'{label} ({grid.shape[0]}×{grid.shape[1]})')
    ax.set_xlabel('Feature Dim')
    ax.set_ylabel('Cluster')
    plt.colorbar(im, ax=ax)


def _draw_text_panel(ax, title, lines, x, y_top, line_step, fontsize):
    """Write ``lines`` top-down in monospace on an axis-less panel.

    The step is shrunk when needed so the last line still lands inside the
    axes instead of spilling below it.
    """
    ax.set_title(title)
    step = min(line_step, y_top / max(len(lines), 1))
    for i, line in enumerate(lines):
        ax.text(x, y_top - i * step, line, transform=ax.transAxes,
                fontsize=fontsize, family='monospace')
    ax.axis('off')


def _plot_param_pie(ax, module, title):
    """Pie chart of parameter counts over ``module``'s direct children (parameter-free ones skipped)."""
    ax.set_title(title)
    sizes, labels = [], []
    for name, child in module.named_children():
        count = sum(p.numel() for p in child.parameters())
        if count > 0:
            sizes.append(count)
            labels.append(f'{name}\n({count/1e3:.0f}K)')
    if sizes:
        ax.pie(sizes, labels=labels, autopct='%1.1f%%', textprops={'fontsize': 8})
    else:
        ax.axis('off')  # nothing to chart; keep only the title


def _save_figure(fig, suptitle, filename):
    """Add ``suptitle``, lay out ``fig``, save it to OUTPUT_DIR/``filename`` at 150 dpi and close it."""
    fig.suptitle(suptitle, fontsize=14, fontweight='bold')
    fig.tight_layout()
    path = os.path.join(OUTPUT_DIR, filename)
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f'[保存] {path}')


def run_bev_only():
    """Run only the BEV (point-cloud) branch and save ``full_pipeline_bev.png``.

    Plots the BEV input, score map, selected keypoints, descriptor channel 0,
    query-vs-positive keypoint similarity, the VLAD vector, a data-flow
    summary and the parameter split of ``model.bev.feature_extractor``.
    """
    cfg = _demo_cfg('bev')
    model, batch_dict, output = _build_and_run(cfg, '模式: BEV Only (仅点云分支)')

    fig, axes = plt.subplots(2, 4, figsize=(18, 9))
    bev_rgb = batch_dict['bev'][0, :3].permute(1, 2, 0).numpy()  # first 3 channels as RGB

    # Row 0: input, score map, keypoints, descriptor channel 0.
    _show_map(axes[0, 0], bev_rgb, 'BEV输入 (3通道)')
    if 'score_bev' in output:
        _show_map(axes[0, 1], _to_2d(output['score_bev'][0]), 'BEV Score Map', cmap='hot')
    if 'key_points' in output and 'pixels_kpt' in output:
        kpt = output['pixels_kpt'][0].detach().cpu().numpy()
        axes[0, 2].scatter(kpt[:, 1], kpt[:, 0], c='red', s=5, alpha=0.8)
        _show_map(axes[0, 2], bev_rgb, f'BEV Top-{len(kpt)} 关键点')
    if 'fea_bev' in output:
        _show_map(axes[0, 3], _to_2d(output['fea_bev'][0, 0]), 'BEV Descriptor ch0', cmap='viridis')

    # Row 1: keypoint similarity, VLAD, data flow, parameter split.
    if 'fea_kpt_original' in output:
        _plot_similarity(axes[1, 0], _query_positive_similarity(output['fea_kpt_original']),
                         'Query-Positive 特征相似度')
    if 'vlads' in output:
        _plot_vlad(axes[1, 1], output['vlads'][0], cfg['cluster_num_bev'], 'VLAD描述子')
    flow = [
        'bev (7,320,320)',
        '→ x = bev[:3] (可视BEV)',
        '→ points = bev[3:7] (坐标)',
        '→ RICNN前向',
        '→ score_bev (1,320,320)',
        '→ fea_bev (128,320,320)',
        '→ NMS + Top-K(150)',
        '→ key_points (150,4)',
        '→ fea_kpt (128,150)',
        '→ EncodePosition',
        '→ NetVLAD → vlad_bev (2048)',
    ]
    _draw_text_panel(axes[1, 2], 'BEV分支数据流', flow, x=0.1, y_top=0.95, line_step=0.1, fontsize=9)
    _plot_param_pie(axes[1, 3], model.bev.feature_extractor, 'BEV分支参数分布')

    _save_figure(fig, 'BEV Only 模式: 点云分支可视化', 'full_pipeline_bev.png')


def run_img_only():
    """Run only the image branch and save ``full_pipeline_img.png``.

    Plots the image input, score map, selected keypoints, descriptor channel
    0, query-vs-positive keypoint similarity, a data-flow summary and the
    parameter split of ``model.img.feature_extractor``.
    """
    cfg = _demo_cfg('img')
    model, batch_dict, output = _build_and_run(cfg, '模式: Image Only (仅图像分支)')

    fig, axes = plt.subplots(2, 4, figsize=(18, 9))
    img_rgb = batch_dict['img'][0, :3].permute(1, 2, 0).numpy().astype(np.uint8)
    height, width = img_rgb.shape[:2]

    # Row 0: input, score map, keypoints, descriptor channel 0.
    _show_map(axes[0, 0], img_rgb, f'图像输入 ({height}×{width})')
    if 'score_img' in output:
        _show_map(axes[0, 1], _to_2d(output['score_img'][0]), '图像 Score Map', cmap='hot')
    if 'key_pixels' in output:
        kpt = output['key_pixels'][0].detach().cpu().numpy()
        axes[0, 2].scatter(kpt[:, 1], kpt[:, 0], c='red', s=5, alpha=0.8)
        _show_map(axes[0, 2], img_rgb, f'Top-{len(kpt)} 关键点')
    if 'fea_img' in output:
        _show_map(axes[0, 3], _to_2d(output['fea_img'][0, 0]), '图像 Descriptor ch0', cmap='viridis')

    # Row 1: keypoint similarity, data flow, parameter split, empty slot.
    if 'fea_kpl' in output:
        _plot_similarity(axes[1, 0], _query_positive_similarity(output['fea_kpl']),
                         'Query-Positive 特征相似度')
    flow = [
        'img (5,192,576)',
        '→ x = img[:3]/255',
        '→ ALNet前向',
        '→ score_img (1,192,576)',
        '→ fea_img (128,192,576)',
        '→ NMS(2) + Top-K(150)',
        '→ key_pixels (150,2)',
        '→ fea_kpl (128,150)',
    ]
    _draw_text_panel(axes[1, 1], '图像分支数据流', flow, x=0.1, y_top=0.95, line_step=0.11, fontsize=9)
    _plot_param_pie(axes[1, 2], model.img.feature_extractor, '图像分支参数分布')
    axes[1, 3].axis('off')

    _save_figure(fig, 'Image Only 模式: 图像分支可视化', 'full_pipeline_img.png')


def run_fusion():
    """Run the full BEV + image fusion model and save ``full_pipeline_fusion.png``.

    Plots both inputs and score maps, keypoint similarity before and after
    fusion plus their absolute difference, the fused VLAD vector, an ASCII
    architecture sketch, per-submodule parameter counts and a data-flow tree.
    """
    cfg = _demo_cfg('fusion')
    model, batch_dict, output = _build_and_run(cfg, '模式: Fusion (完整融合)')

    fig, axes = plt.subplots(3, 4, figsize=(22, 15))

    # Row 0: inputs and score maps.
    bev_rgb = batch_dict['bev'][0, :3].permute(1, 2, 0).numpy()
    _show_map(axes[0, 0], bev_rgb, f'BEV 输入 ({bev_rgb.shape[0]}×{bev_rgb.shape[1]})')
    img_rgb = batch_dict['img'][0, :3].permute(1, 2, 0).numpy().astype(np.uint8)
    _show_map(axes[0, 1], img_rgb, f'图像输入 ({img_rgb.shape[0]}×{img_rgb.shape[1]})')
    if 'score_bev' in output:
        _show_map(axes[0, 2], _to_2d(output['score_bev'][0]), 'BEV Score', cmap='hot')
    if 'score_img' in output:
        _show_map(axes[0, 3], _to_2d(output['score_img'][0]), 'Image Score', cmap='hot')

    # Row 1: similarity before/after fusion, their difference, fused VLAD.
    if 'fea_kpt_original' in output and 'fea_kpt_fusion' in output:
        sim_orig = _query_positive_similarity(output['fea_kpt_original'])
        sim_fusion = _query_positive_similarity(output['fea_kpt_fusion'])
        _plot_similarity(axes[1, 0], sim_orig,
                         f'原始特征 相似度 ({sim_orig.shape[0]}×{sim_orig.shape[1]})')
        _plot_similarity(axes[1, 1], sim_fusion,
                         f'融合特征 相似度 ({sim_fusion.shape[0]}×{sim_fusion.shape[1]})')
        if sim_orig.shape == sim_fusion.shape:
            # The colorbar must belong to the difference image itself, not to
            # the [-1, 1] similarity plot next to it.
            im_diff = axes[1, 2].imshow(np.abs(sim_orig - sim_fusion), cmap='YlOrRd')
            axes[1, 2].set_title('相似度变化 |差异|')
            plt.colorbar(im_diff, ax=axes[1, 2])
    if 'vlads' in output:
        _plot_vlad(axes[1, 3], output['vlads'][0], cfg['cluster_num_fusion'], 'VLAD 融合')

    # Row 2: architecture sketch, parameter counts, data-flow tree.
    arch_boxes = [
        ('BEVHead', ['RICNN + EncodePos', '→ fea_kpt_original', '→ vlad_bev']),
        ('ImgHead', ['ALNet + NMS', '→ fea_kpl', '→ fea_img']),
        ('FusionHead', ['LocalPool + Converter', 'Generator + FusionHead', '→ fea_kpt_fusion']),
    ]
    # One shared content width so all boxes line up.
    inner = max(
        max([len(title) + 1] + [len(row) for row in rows])
        for title, rows in arch_boxes
    )
    arch = [line for title, rows in arch_boxes for line in _ascii_box(title, rows, inner)]
    arch.append(' VLAD = w·vlad_fusion + (1-w)·vlad_bev')
    _draw_text_panel(axes[2, 0], '完整架构', arch, x=0.05, y_top=0.98, line_step=0.075, fontsize=7.5)

    axes[2, 1].set_title('各模块参数量')
    module_names = []
    module_params = []
    for name, mod in model.named_children():
        count = sum(p.numel() for p in mod.parameters())
        if count > 0:
            module_names.append(name)
            module_params.append(count)
    colors = plt.cm.Set3(np.linspace(0, 1, len(module_names)))
    axes[2, 1].barh(range(len(module_names)), module_params, color=colors)
    axes[2, 1].set_yticks(range(len(module_names)))
    axes[2, 1].set_yticklabels(module_names, fontsize=8)
    for i, count in enumerate(module_params):
        axes[2, 1].text(count, i, f' {count/1e3:.0f}K', va='center', fontsize=8)

    flow = [
        'img, bev, relation 输入',
        '├─ ImgHead → ALNet',
        '│ ├─ score_img',
        '│ ├─ fea_img (密集描述子)',
        '│ └─ fea_kpl (关键点)',
        '├─ BEVHead → RICNN',
        '│ ├─ score_bev',
        '│ ├─ fea_bev (密集描述子)',
        '│ ├─ fea_kpt_original',
        '│ └─ vlad_bev',
        '└─ FusionHead',
        ' ├─ GridSample → fea_pl_dual, fea_pt_dual',
        ' ├─ Converters → 跨模态转换',
        ' ├─ Generator → 全景特征',
        ' ├─ FusionHead → 融合特征',
        ' └─ NetVLAD → vlad_fusion',
        '最终: vlads = w·vlad_fusion + (1-w)·vlad_bev',
        ' UOT: → transformation (位姿)',
    ]
    _draw_text_panel(axes[2, 2], '融合模式数据流', flow, x=0.05, y_top=0.98, line_step=0.06, fontsize=7.5)
    axes[2, 3].axis('off')

    _save_figure(fig, 'Fusion 模式: 完整跨模态融合可视化', 'full_pipeline_fusion.png')
def main():
    """Parse ``--mode`` and run the selected demo(s), then report the output folder.

    ``--mode all`` (the default) runs the BEV, image and fusion demos in that
    order; any other choice runs only the matching one.
    """
    parser = argparse.ArgumentParser(description='全流水线可视化')
    parser.add_argument('--mode', type=str, default='all',
                        choices=['all', 'bev', 'img', 'fusion'],
                        help='运行模式')
    selected = parser.parse_args().mode
    # Ordered (mode, runner) pairs; 'all' matches every entry.
    runners = (
        ('bev', run_bev_only),
        ('img', run_img_only),
        ('fusion', run_fusion),
    )
    for mode_name, runner in runners:
        if selected in ('all', mode_name):
            runner()
    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')


if __name__ == '__main__':
    main()