网络测试和学习demo

This commit is contained in:
cyy_mac
2026-05-09 17:03:40 +08:00
parent edbe8fdbf9
commit 78298e56f1
9 changed files with 2868 additions and 0 deletions

View File

@@ -0,0 +1,516 @@
"""
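Full pipeline demo: end-to-end network structure visualization
==============================================================
Integrates all sub-networks and shows the complete data flow from input to output.
Run modes (when --mode is omitted, all three modes run in sequence):
    python 08_full_pipeline_demo.py --mode bev      # BEV branch only
    python 08_full_pipeline_demo.py --mode img      # image branch only
    python 08_full_pipeline_demo.py --mode fusion   # full fusion mode
    python 08_full_pipeline_demo.py --mode all      # default: bev, img, then fusion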
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import sys
import os
import argparse
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from net import Fusion, BEVHead, ImgHead, FusionHead
from BEVNet import RICNN
from ALIKE.alnet import ALNet
from netvlad import NetVLAD
from uot import UOTHead
# Every rendered figure is written into an ``output`` folder located next to
# this script; the folder is created at import time when it does not exist yet.
OUTPUT_DIR = os.path.join(os.path.split(__file__)[0], 'output')
os.makedirs(name=OUTPUT_DIR, exist_ok=True)
def create_dummy_batch_dict(mode='fusion'):
    """Build a synthetic ``batch_dict`` filled with random tensors.

    With ``B = 2`` the sensor tensors hold ``2 * B`` samples, while the pose
    and label tensors hold ``B`` entries.

    Args:
        mode: ``'bev'`` adds the ``'bev'`` tensor, ``'img'`` adds the ``'img'``
            tensor, and ``'fusion'`` adds both plus the ``'relation'`` tensor.

    Returns:
        A dict with ``'batch_size'``, the mode-dependent sensor tensors, and
        the pose / label / id / sequence entries.
    """
    B = 2
    n_samples = 2 * B
    batch_dict = {'batch_size': n_samples}

    if mode in ('fusion', 'bev'):
        bev = torch.randn(n_samples, 7, 320, 320)
        # Squash the first three channels into (0, 1) so they can be displayed.
        bev[:, :3] = torch.sigmoid(bev[:, :3])
        # Channel 2 is then binarised at 0.3 (labelled "guider mask" upstream).
        bev[:, 2:3] = (bev[:, 2:3] > 0.3).float()
        batch_dict['bev'] = bev

    if mode in ('fusion', 'img'):
        batch_dict['img'] = torch.randint(0, 256, (n_samples, 5, 192, 576)).float()

    if mode == 'fusion':
        # relation has shape (2B, max_len, K, 2); K = 1 + 10, where the last
        # slot along K carries a BEV coordinate and the first K - 1 slots carry
        # image coordinates.
        max_len, K = 200, 11
        n_valid = 150
        relation = torch.zeros(n_samples, max_len, K, 2, dtype=torch.long)
        for idx in range(n_samples):
            valid_rows = relation[idx, :n_valid]  # view: writes land in `relation`
            valid_rows[:, :K - 1, 0] = torch.randint(0, 576, (n_valid, K - 1))
            valid_rows[:, :K - 1, 1] = torch.randint(0, 192, (n_valid, K - 1))
            valid_rows[:, K - 1, 0] = torch.randint(0, 320, (n_valid,))
            valid_rows[:, K - 1, 1] = torch.randint(0, 320, (n_valid,))
        batch_dict['relation'] = relation

    def _identity_poses():
        # A fresh (B, 4, 4) stack of identity matrices on every call.
        return torch.eye(4).unsqueeze(0).repeat(B, 1, 1)

    # pose_to_frame (needed for training): a 0.3 rad in-plane rotation
    # combined with a (2.0, -1.0) translation.
    angle = torch.tensor(0.3)
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)
    pose = _identity_poses()
    pose[:, 0, 0] = cos_a
    pose[:, 0, 1] = -sin_a
    pose[:, 1, 0] = sin_a
    pose[:, 1, 1] = cos_a
    pose[:, 0, 3] = 2.0
    pose[:, 1, 3] = -1.0

    batch_dict['pose_to_frame'] = pose.clone()
    batch_dict['pose_query'] = _identity_poses()
    batch_dict['pose_positive'] = _identity_poses()
    batch_dict['label_score'] = torch.zeros(B, 320, 320, 2)
    batch_dict['id_query'] = torch.arange(B)
    batch_dict['id_positive'] = torch.arange(B)
    batch_dict['sequence'] = torch.zeros(B, dtype=torch.long)
    return batch_dict
def run_bev_only():
    """Run the model with ``flag='bev'`` on synthetic data and save a figure.

    Builds ``Fusion`` in eval mode, feeds it a dummy batch, prints every
    output entry (the shape for tensors, the value otherwise) and writes a
    2x4 panel figure to ``OUTPUT_DIR/full_pipeline_bev.png``. Data-dependent
    panels are drawn only when the corresponding key is present in the output.
    """
    rule = '=' * 60
    print('\n' + rule)
    print('模式: BEV Only (仅点云分支)')
    print(rule)

    cfg = dict(
        flag='bev',
        kpts_number_bev=150,
        kpts_number_img=150,
        cluster_num_bev=16,
        cluster_num_img=16,
        cluster_num_fusion=16,
        sinkhorn_iter=5,
        vlad_size=256,
    )
    model = Fusion(cfg)
    model.eval()

    total_params = sum(param.numel() for param in model.parameters())
    print(f'模型参数量: {total_params:,} ({total_params / 1e6:.2f}M)')

    batch_dict = create_dummy_batch_dict('bev')
    with torch.no_grad():
        output = model(batch_dict)

    print('\n输出:')
    for key, value in output.items():
        shown = list(value.shape) if isinstance(value, torch.Tensor) else value
        print(f' {key:30s}: {shown}')

    # Visualise the BEV branch data flow on a 2x4 grid.
    _, axes = plt.subplots(2, 4, figsize=(18, 9))
    top_row, bottom_row = axes

    # BEV input (first three channels shown as an RGB image).
    if 'bev' in output or 'bev' in batch_dict:
        panel = top_row[0]
        panel.imshow(batch_dict['bev'][0, :3].permute(1, 2, 0).numpy())
        panel.set_title('BEV输入 (3通道)')
        panel.axis('off')

    # Score map of the first sample.
    if 'score_bev' in output:
        panel = top_row[1]
        panel.imshow(output['score_bev'][0].numpy(), cmap='hot')
        panel.set_title('BEV Score Map')
        panel.axis('off')

    # Keypoint locations over the BEV image; column 1 is plotted as x and
    # column 0 as y.
    if 'key_points' in output and 'pixels_kpt' in output:
        panel = top_row[2]
        panel.imshow(batch_dict['bev'][0, :3].permute(1, 2, 0).numpy())
        kpt = output['pixels_kpt'][0].numpy()
        panel.scatter(kpt[:, 1], kpt[:, 0], c='red', s=5, alpha=0.8)
        panel.set_title(f'BEV Top-{len(kpt)} 关键点')
        panel.axis('off')

    # Descriptor map (channel 0 of the first sample).
    if 'fea_bev' in output:
        panel = top_row[3]
        panel.imshow(output['fea_bev'][0, 0].numpy(), cmap='viridis')
        panel.set_title('BEV Descriptor ch0')
        panel.axis('off')

    # Keypoint feature similarity: first half of the batch (query) against the
    # second half (positive).
    if 'fea_kpt_original' in output:
        kpt_fea = output['fea_kpt_original']
        half = kpt_fea.shape[0] // 2
        query_fea = kpt_fea[:half].permute(0, 2, 1).unsqueeze(-1)
        positive_fea = kpt_fea[half:].permute(0, 2, 1).unsqueeze(-2)
        # NOTE(review): the reduction runs over axis 1 *after* the permute, so
        # this is a keypoint-by-keypoint matrix only if the tensor is laid out
        # (batch, keypoints, channels); the flow text below says (128,150) --
        # confirm the layout against the model.
        similarity = torch.nn.functional.cosine_similarity(
            query_fea, positive_fea, dim=1
        )[0]
        panel = bottom_row[0]
        heat = panel.imshow(similarity.numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1)
        panel.set_title('Query-Positive 特征相似度')
        panel.set_xlabel('Positive')
        panel.set_ylabel('Query')
        plt.colorbar(heat, ax=panel)

    # VLAD descriptor of the first sample, viewed as 16 rows of 128 values.
    if 'vlads' in output:
        panel = bottom_row[1]
        vlad_grid = output['vlads'][0].view(16, 128).numpy()
        heat = panel.imshow(vlad_grid, cmap='RdBu_r', aspect='auto')
        panel.set_title('VLAD描述子 (16×128)')
        panel.set_xlabel('Feature Dim')
        panel.set_ylabel('Cluster')
        plt.colorbar(heat, ax=panel)

    # Textual data-flow summary.
    panel = bottom_row[2]
    panel.set_title('BEV分支数据流')
    flow = [
        'bev (7,320,320)',
        '→ x = bev[:3] (可视BEV)',
        '→ points = bev[3:7] (坐标)',
        '→ RICNN前向',
        '→ score_bev (1,320,320)',
        '→ fea_bev (128,320,320)',
        '→ NMS + Top-K(150)',
        '→ key_points (150,4)',
        '→ fea_kpt (128,150)',
        '→ EncodePosition',
        '→ NetVLAD → vlad_bev (2048)',
    ]
    for row, step in enumerate(flow):
        panel.text(0.1, 0.95 - row * 0.1, step, transform=panel.transAxes,
                   fontsize=9, family='monospace')
    panel.axis('off')

    # Pie chart of parameter counts per direct child of the BEV feature
    # extractor; children without parameters are skipped.
    panel = bottom_row[3]
    panel.set_title('BEV分支参数分布')
    sizes = []
    labels = []
    for child_name, child in dict(model.bev.feature_extractor.named_children()).items():
        n_params = sum(param.numel() for param in child.parameters())
        if n_params > 0:
            sizes.append(n_params)
            labels.append(f'{child_name}\n({n_params/1e3:.0f}K)')
    panel.pie(sizes, labels=labels, autopct='%1.1f%%', textprops={'fontsize': 8})

    plt.suptitle('BEV Only 模式: 点云分支可视化', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'full_pipeline_bev.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f'[保存] {path}')
def run_img_only():
    """Run the model with ``flag='img'`` on synthetic data and save a figure.

    Builds ``Fusion`` in eval mode, feeds it a dummy batch, prints every
    output entry (the shape for tensors, the value otherwise) and writes a
    2x4 panel figure to ``OUTPUT_DIR/full_pipeline_img.png``. Data-dependent
    panels are drawn only when the corresponding key is present in the output.
    """
    rule = '=' * 60
    print('\n' + rule)
    print('模式: Image Only (仅图像分支)')
    print(rule)

    cfg = dict(
        flag='img',
        kpts_number_bev=150,
        kpts_number_img=150,
        cluster_num_bev=16,
        cluster_num_img=16,
        cluster_num_fusion=16,
        sinkhorn_iter=5,
        vlad_size=256,
    )
    model = Fusion(cfg)
    model.eval()

    total_params = sum(param.numel() for param in model.parameters())
    print(f'模型参数量: {total_params:,} ({total_params / 1e6:.2f}M)')

    batch_dict = create_dummy_batch_dict('img')
    with torch.no_grad():
        output = model(batch_dict)

    print('\n输出:')
    for key, value in output.items():
        shown = list(value.shape) if isinstance(value, torch.Tensor) else value
        print(f' {key:30s}: {shown}')

    # Visualisation on a 2x4 grid.
    _, axes = plt.subplots(2, 4, figsize=(18, 9))
    top_row, bottom_row = axes

    # Input image: first three channels of sample 0 as uint8 RGB.
    img_in = batch_dict['img'][0, :3].permute(1, 2, 0).numpy().astype(np.uint8)
    panel = top_row[0]
    panel.imshow(img_in)
    panel.set_title('图像输入 (192×576)')
    panel.axis('off')

    # Score map.
    if 'score_img' in output:
        panel = top_row[1]
        panel.imshow(output['score_img'][0, 0].numpy(), cmap='hot')
        panel.set_title('图像 Score Map')
        panel.axis('off')

    # Keypoints over the input image; column 1 is plotted as x and column 0
    # as y.
    if 'key_pixels' in output:
        panel = top_row[2]
        panel.imshow(img_in)
        kpt = output['key_pixels'][0].numpy()
        panel.scatter(kpt[:, 1], kpt[:, 0], c='red', s=5, alpha=0.8)
        panel.set_title(f'Top-{len(kpt)} 关键点')
        panel.axis('off')

    # Descriptor map (channel 0 of the first sample).
    if 'fea_img' in output:
        panel = top_row[3]
        panel.imshow(output['fea_img'][0, 0].numpy(), cmap='viridis')
        panel.set_title('图像 Descriptor ch0')
        panel.axis('off')

    # Keypoint feature similarity: first half of the batch (query) against the
    # second half (positive).
    if 'fea_kpl' in output:
        kpl_fea = output['fea_kpl']
        half = kpl_fea.shape[0] // 2
        query_fea = kpl_fea[:half].permute(0, 2, 1).unsqueeze(-1)
        positive_fea = kpl_fea[half:].permute(0, 2, 1).unsqueeze(-2)
        # NOTE(review): the reduction runs over axis 1 *after* the permute, so
        # this is a keypoint-by-keypoint matrix only if the tensor is laid out
        # (batch, keypoints, channels); the flow text below says (128,150) --
        # confirm the layout against the model.
        similarity = torch.nn.functional.cosine_similarity(
            query_fea, positive_fea, dim=1
        )[0]
        panel = bottom_row[0]
        heat = panel.imshow(similarity.numpy(), cmap='RdYlBu_r', vmin=-1, vmax=1)
        panel.set_title('Query-Positive 特征相似度')
        plt.colorbar(heat, ax=panel)

    # Textual data-flow summary.
    panel = bottom_row[1]
    panel.set_title('图像分支数据流')
    flow = [
        'img (5,192,576)',
        '→ x = img[:3]/255',
        '→ ALNet前向',
        '→ score_img (1,192,576)',
        '→ fea_img (128,192,576)',
        '→ NMS(2) + Top-K(150)',
        '→ key_pixels (150,2)',
        '→ fea_kpl (128,150)',
    ]
    for row, step in enumerate(flow):
        panel.text(0.1, 0.95 - row * 0.11, step, transform=panel.transAxes,
                   fontsize=9, family='monospace')
    panel.axis('off')

    # Pie chart of parameter counts per direct child of the image feature
    # extractor; children without parameters are skipped.
    panel = bottom_row[2]
    panel.set_title('图像分支参数分布')
    sizes = []
    labels = []
    for child_name, child in dict(model.img.feature_extractor.named_children()).items():
        n_params = sum(param.numel() for param in child.parameters())
        if n_params > 0:
            sizes.append(n_params)
            labels.append(f'{child_name}\n({n_params/1e3:.0f}K)')
    panel.pie(sizes, labels=labels, autopct='%1.1f%%', textprops={'fontsize': 8})

    # The last panel is intentionally left blank.
    bottom_row[3].axis('off')

    plt.suptitle('Image Only 模式: 图像分支可视化', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'full_pipeline_img.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f'[保存] {path}')
def run_fusion():
    """Run the model with ``flag='fusion'`` on synthetic data and save a figure.

    Builds ``Fusion`` in eval mode, feeds it a dummy batch containing BEV,
    image and relation tensors, prints every output entry (the shape for
    tensors, the value otherwise) and writes a 3x4 panel figure to
    ``OUTPUT_DIR/full_pipeline_fusion.png``. Data-dependent panels are drawn
    only when the corresponding key is present in the output.
    """
    print('\n' + '=' * 60)
    print('模式: Fusion (完整融合)')
    print('=' * 60)

    cfg = {
        'flag': 'fusion',
        'kpts_number_bev': 150,
        'kpts_number_img': 150,
        'cluster_num_bev': 16,
        'cluster_num_img': 16,
        'cluster_num_fusion': 16,
        'sinkhorn_iter': 5,
        'vlad_size': 256,
    }
    model = Fusion(cfg)
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    print(f'模型参数量: {total_params:,} ({total_params / 1e6:.2f}M)')

    batch_dict = create_dummy_batch_dict('fusion')
    with torch.no_grad():
        output = model(batch_dict)

    print('\n输出:')
    for k, v in output.items():
        if isinstance(v, torch.Tensor):
            print(f' {k:30s}: {list(v.shape)}')
        else:
            print(f' {k:30s}: {v}')

    # Visualise the fusion data flow on a 3x4 grid.
    fig, axes = plt.subplots(3, 4, figsize=(22, 15))

    # BEV input (first three channels shown as an RGB image).
    bev_in = batch_dict['bev'][0, :3].permute(1, 2, 0).numpy()
    axes[0, 0].imshow(bev_in)
    axes[0, 0].set_title('BEV 输入 (320×320)')
    axes[0, 0].axis('off')

    # Image input: first three channels of sample 0 as uint8 RGB.
    img_in = batch_dict['img'][0, :3].permute(1, 2, 0).numpy().astype(np.uint8)
    axes[0, 1].imshow(img_in)
    axes[0, 1].set_title('图像输入 (192×576)')
    axes[0, 1].axis('off')

    # Score maps.
    if 'score_bev' in output:
        axes[0, 2].imshow(output['score_bev'][0].numpy(), cmap='hot')
        axes[0, 2].set_title('BEV Score')
        axes[0, 2].axis('off')
    if 'score_img' in output:
        axes[0, 3].imshow(output['score_img'][0, 0].numpy(), cmap='hot')
        axes[0, 3].set_title('Image Score')
        axes[0, 3].axis('off')

    # Query-vs-positive similarity before and after fusion. The first half of
    # the batch is treated as query, the second half as positive.
    if 'fea_kpt_original' in output and 'fea_kpt_fusion' in output:
        fea_orig = output['fea_kpt_original']
        fea_fusion = output['fea_kpt_fusion']
        B = fea_orig.shape[0] // 2
        # NOTE(review): the reduction runs over axis 1 *after* the permute, so
        # these are keypoint-by-keypoint matrices only if the tensors are laid
        # out (batch, keypoints, channels) -- confirm against the model.
        sim_orig = torch.nn.functional.cosine_similarity(
            fea_orig[:B].permute(0, 2, 1).unsqueeze(-1),
            fea_orig[B:].permute(0, 2, 1).unsqueeze(-2),
            dim=1
        )[0].numpy()
        sim_fusion = torch.nn.functional.cosine_similarity(
            fea_fusion[:B].permute(0, 2, 1).unsqueeze(-1),
            fea_fusion[B:].permute(0, 2, 1).unsqueeze(-2),
            dim=1
        )[0].numpy()

        im1 = axes[1, 0].imshow(sim_orig, cmap='RdYlBu_r', vmin=-1, vmax=1)
        axes[1, 0].set_title('原始特征 相似度 (150×150)')
        plt.colorbar(im1, ax=axes[1, 0])

        im2 = axes[1, 1].imshow(sim_fusion, cmap='RdYlBu_r', vmin=-1, vmax=1)
        axes[1, 1].set_title('融合特征 相似度 (150×150)')
        plt.colorbar(im2, ax=axes[1, 1])

        # Keep the handle of the difference image so that its colorbar shows
        # this panel's own colormap and value range instead of reusing the
        # -1..1 similarity scale of the neighbouring panel.
        im3 = axes[1, 2].imshow(np.abs(sim_orig - sim_fusion), cmap='YlOrRd')
        axes[1, 2].set_title('相似度变化 |差异|')
        plt.colorbar(im3, ax=axes[1, 2])

    # VLAD descriptor of the first sample, viewed as 16 rows of 128 values.
    if 'vlads' in output:
        vlad = output['vlads'][0].view(16, 128).numpy()
        im = axes[1, 3].imshow(vlad, cmap='RdBu_r', aspect='auto')
        axes[1, 3].set_title('VLAD 融合 (16×128)')
        plt.colorbar(im, ax=axes[1, 3])

    # Overall architecture sketch (plain text).
    axes[2, 0].set_title('完整架构')
    arch = [
        '┌─ BEVHead ─────────────┐',
        '│ RICNN + EncodePos │',
        '│ → fea_kpt_original │',
        '│ → vlad_bev │',
        '└───────────────────────┘',
        '┌─ ImgHead ─────────────┐',
        '│ ALNet + NMS │',
        '│ → fea_kpl │',
        '│ → fea_img │',
        '└───────────────────────┘',
        '┌─ FusionHead ──────────┐',
        '│ LocalPool + Converter │',
        '│ Generator + FusionHead│',
        '│ → fea_kpt_fusion │',
        '└───────────────────────────────────────────────────────┘',
        ' VLAD = w·vlad_fusion + (1-w)·vlad_bev'
    ]
    for i, a in enumerate(arch):
        axes[2, 0].text(0.05, 0.98 - i * 0.075, a, transform=axes[2, 0].transAxes,
                        fontsize=7.5, family='monospace')
    axes[2, 0].axis('off')

    # Parameter count per direct child module of the model; children without
    # parameters are skipped.
    axes[2, 1].set_title('各模块参数量')
    module_names = []
    module_params = []
    for name, mod in model.named_children():
        p = sum(pm.numel() for pm in mod.parameters())
        if p > 0:
            module_names.append(name)
            module_params.append(p)
    colors = plt.cm.Set3(np.linspace(0, 1, len(module_names)))
    axes[2, 1].barh(range(len(module_names)), module_params, color=colors)
    axes[2, 1].set_yticks(range(len(module_names)))
    axes[2, 1].set_yticklabels(module_names, fontsize=8)
    for i, p in enumerate(module_params):
        axes[2, 1].text(p, i, f' {p/1e3:.0f}K', va='center', fontsize=8)

    # Textual data-flow summary.
    axes[2, 2].set_title('融合模式数据流')
    flow = [
        'img, bev, relation 输入',
        '├─ ImgHead → ALNet',
        '│ ├─ score_img',
        '│ ├─ fea_img (密集描述子)',
        '│ └─ fea_kpl (关键点)',
        '├─ BEVHead → RICNN',
        '│ ├─ score_bev',
        '│ ├─ fea_bev (密集描述子)',
        '│ ├─ fea_kpt_original',
        '│ └─ vlad_bev',
        '└─ FusionHead',
        ' ├─ GridSample → fea_pl_dual, fea_pt_dual',
        ' ├─ Converters → 跨模态转换',
        ' ├─ Generator → 全景特征',
        ' ├─ FusionHead → 融合特征',
        ' └─ NetVLAD → vlad_fusion',
        '最终: vlads = w·vlad_fusion + (1-w)·vlad_bev',
        ' UOT: → transformation (位姿)',
    ]
    for i, f in enumerate(flow):
        axes[2, 2].text(0.05, 0.98 - i * 0.06, f, transform=axes[2, 2].transAxes,
                        fontsize=7.5, family='monospace')
    axes[2, 2].axis('off')

    # The last panel is intentionally left blank.
    axes[2, 3].axis('off')

    plt.suptitle('Fusion 模式: 完整跨模态融合可视化', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'full_pipeline_fusion.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f'[保存] {path}')
def main():
    """Parse ``--mode`` and run the selected visualisation(s).

    ``--mode`` accepts ``all`` (default), ``bev``, ``img`` or ``fusion``.
    With ``all`` the three runs execute in the order bev, img, fusion.
    """
    parser = argparse.ArgumentParser(description='全流水线可视化')
    parser.add_argument(
        '--mode',
        type=str,
        default='all',
        choices=['all', 'bev', 'img', 'fusion'],
        help='运行模式',
    )
    selected = parser.parse_args().mode

    # Ordered (mode name, runner) pairs; the order fixes the sequence for 'all'.
    runners = (
        ('bev', run_bev_only),
        ('img', run_img_only),
        ('fusion', run_fusion),
    )
    for mode_name, runner in runners:
        if selected == 'all' or selected == mode_name:
            runner()

    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')


if __name__ == '__main__':
    main()