网络测试和学习demo

This commit is contained in:
cyy_mac
2026-05-09 17:03:40 +08:00
parent edbe8fdbf9
commit 78298e56f1
9 changed files with 2868 additions and 0 deletions

View File

@@ -0,0 +1,260 @@
"""
ALNet 网络结构可视化 Demo
===========================
ALNet 是图像分支的特征提取网络,基于 ALIKE 架构。
输入:图像 (B, 3, 192, 576)
输出score_map (B, 1, 192, 576) + descriptor_map (B, 128, 192, 576)
网络由以下部分组成:
block1: ConvBlock(3→16) - 保持分辨率
pool2: MaxPool2d(2) - 下采样 2x
block2: ResBlock(16→32) - 残差块
pool4: MaxPool2d(4) - 下采样 4x
block3: ResBlock(32→64) - 残差块
pool4: MaxPool2d(4) - 下采样 4x
block4: ResBlock(64→128) - 残差块
特征聚合: 4层concat + 上采样 - 多尺度融合
输出头: Conv1x1(128→129) - score + descriptor
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # 非交互后端,适合服务器
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ALIKE.alnet import ALNet, ConvBlock, ResBlock
# ============================================================
# 配置
# ============================================================
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)
# 使用 alike-n 配置(论文中使用)
CFG = {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True}
def visualize_tensor(tensor, title, save_name, cmap='viridis', n_channels=8):
"""可视化特征图的多个通道"""
if tensor.dim() == 4:
tensor = tensor[0] # 取第一个batch
C, H, W = tensor.shape
n_show = min(n_channels, C)
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
fig.suptitle(title, fontsize=14, fontweight='bold')
for i in range(n_show):
ax = axes[i // 4, i % 4]
im = ax.imshow(tensor[i].detach().cpu().numpy(), cmap=cmap)
ax.set_title(f'Channel {i}')
ax.axis('off')
plt.colorbar(im, ax=ax, fraction=0.046)
for i in range(n_show, 8):
axes[i // 4, i % 4].axis('off')
plt.tight_layout()
path = os.path.join(OUTPUT_DIR, save_name)
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
print(f' [保存] {path}')
def visualize_score_map(score_map, title, save_name):
"""可视化得分图"""
if score_map.dim() == 4:
score_map = score_map[0, 0]
elif score_map.dim() == 3:
score_map = score_map[0]
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.suptitle(title, fontsize=14, fontweight='bold')
im0 = axes[0].imshow(score_map.detach().cpu().numpy(), cmap='hot')
axes[0].set_title('Score Map (热力图)')
axes[0].axis('off')
plt.colorbar(im0, ax=axes[0])
# 直方图
axes[1].hist(score_map.detach().cpu().numpy().flatten(), bins=50, color='steelblue', edgecolor='white')
axes[1].set_title('Score 分布直方图')
axes[1].set_xlabel('Score Value')
axes[1].set_ylabel('Frequency')
plt.tight_layout()
path = os.path.join(OUTPUT_DIR, save_name)
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
print(f' [保存] {path}')
def visualize_intermediate_features(model, input_tensor):
"""逐层提取并可视化中间特征图"""
print('\n' + '=' * 60)
print('ALNet 中间特征逐层可视化')
print('=' * 60)
x = input_tensor
print(f'输入: {x.shape}')
# Block 1: ConvBlock
x1 = model.block1(x)
print(f'block1 (ConvBlock 3→16): {x1.shape}')
visualize_tensor(x1, 'Block1: ConvBlock 输出 (16通道)', 'alnet_block1.png')
# Pool2 + Block 2
x2 = model.pool2(x1)
x2 = model.block2(x2)
print(f'pool2 + block2 (ResBlock 16→32): {x2.shape}')
visualize_tensor(x2, 'Block2: ResBlock 输出 (32通道) [1/2分辨率]', 'alnet_block2.png')
# Pool4 + Block 3
x3 = model.pool4(x2)
x3 = model.block3(x3)
print(f'pool4 + block3 (ResBlock 32→64): {x3.shape}')
visualize_tensor(x3, 'Block3: ResBlock 输出 (64通道) [1/8分辨率]', 'alnet_block3.png')
# Pool4 + Block 4
x4 = model.pool4(x3)
x4 = model.block4(x4)
print(f'pool4 + block4 (ResBlock 64→128): {x4.shape}')
visualize_tensor(x4, 'Block4: ResBlock 输出 (128通道) [1/32分辨率]', 'alnet_block4.png')
# 特征聚合
f1 = model.gate(model.conv1(x1)) # dim//4 通道
f2 = model.gate(model.conv2(x2))
f3 = model.gate(model.conv3(x3))
f4 = model.gate(model.conv4(x4))
f2_up = model.upsample2(f2)
f3_up = model.upsample8(f3)
f4_up = model.upsample32(f4)
print(f'特征聚合: f1={f1.shape}, f2_up={f2_up.shape}, f3_up={f3_up.shape}, f4_up={f4_up.shape}')
fused = torch.cat([f1, f2_up, f3_up, f4_up], dim=1)
print(f'多尺度拼接后: {fused.shape}')
visualize_tensor(fused, '多尺度特征拼接 (128通道)', 'alnet_fused_features.png', n_channels=8)
# 输出头
output = model.convhead2(fused)
score_map = torch.sigmoid(output[:, -1:, :, :])
descriptor_map = output[:, :-1, :, :]
print(f'Score Map: {score_map.shape}')
print(f'Descriptor Map: {descriptor_map.shape}')
visualize_score_map(score_map, 'ALNet 最终输出 Score Map', 'alnet_final_score.png')
visualize_tensor(descriptor_map, 'ALNet 最终输出 Descriptor Map (128通道)', 'alnet_final_descriptor.png')
def visualize_receptive_field():
"""可视化有效感受野(通过梯度反传)"""
print('\n--- 感受野分析 ---')
model = ALNet(**CFG)
model.eval()
input_tensor = torch.randn(1, 3, 192, 576, requires_grad=True)
score_map, _ = model(input_tensor)
# 对score_map中心点的梯度反传
h, w = score_map.shape[2], score_map.shape[3]
score_map[0, 0, h // 2, w // 2].backward()
grad = input_tensor.grad.abs().sum(dim=1)[0]
fig, ax = plt.subplots(figsize=(12, 4))
im = ax.imshow(grad.detach().cpu().numpy(), cmap='hot')
ax.set_title('ALNet 有效感受野 (梯度幅度)', fontsize=14)
ax.axis('off')
plt.colorbar(im, ax=ax)
path = os.path.join(OUTPUT_DIR, 'alnet_receptive_field.png')
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
print(f' [保存] {path}')
def analyze_parameters():
"""分析网络参数量"""
print('\n--- 参数量分析 ---')
model = ALNet(**CFG)
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'总参数量: {total:,} ({total / 1e6:.2f}M)')
print(f'可训练参数: {trainable:,} ({trainable / 1e6:.2f}M)')
# 逐模块分析
for name, module in model.named_children():
params = sum(p.numel() for p in module.parameters())
print(f' {name:20s}: {params:>10,} params ({params / 1e3:.1f}K)')
def main():
print('=' * 60)
print('ALNet (图像特征提取网络) 结构与特征可视化')
print('=' * 60)
analyze_parameters()
# 构建模型
model = ALNet(**CFG)
model.eval()
# 模拟输入: 裁剪后的KITTI图像 (192, 576)
input_tensor = torch.randn(1, 3, 192, 576)
# 前向传播
with torch.no_grad():
score_map, descriptor_map = model(input_tensor)
print(f'\n输入尺寸: {input_tensor.shape}')
print(f'Score Map 输出: {score_map.shape} (范围: [{score_map.min():.3f}, {score_map.max():.3f}])')
print(f'Descriptor Map 输出: {descriptor_map.shape}')
# 逐层可视化中间特征
visualize_intermediate_features(model, input_tensor)
# 感受野分析
visualize_receptive_field()
# 网络结构文本总结
print('\n' + '=' * 60)
print('网络结构总结:')
print('=' * 60)
print("""
ALNet (alike-n config):
┌──────────────────────────────────────────────────────┐
│ 输入: (B, 3, 192, 576) │
│ ↓ │
│ block1: ConvBlock(3→16) → (B, 16, 192, 576) │
│ ↓ MaxPool2d(2) │
│ block2: ResBlock(16→32) → (B, 32, 96, 288) │
│ ↓ MaxPool2d(4) │
│ block3: ResBlock(32→64) → (B, 64, 24, 72) │
│ ↓ MaxPool2d(4) │
│ block4: ResBlock(64→128) → (B, 128, 6, 18) │
│ ↓ │
│ 特征聚合: 4尺度1×1conv + 上采样 + concat → (B,128,192,576) │
│ ↓ Conv1x1(128→129) │
│ 输出: score(B,1,192,576) + desc(B,128,192,576) │
└──────────────────────────────────────────────────────┘
block1/2/3/4 各阶段的作用:
- block1: 浅层特征(边缘、角点等低级特征)
- block2: 中层特征(纹理、局部形状)
- block3: 高层特征(语义信息、物体部件)
- block4: 最抽象特征(全局上下文)
- 多尺度融合: 结合各层信息,兼顾定位精度和语义鲁棒性
""")
print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
if __name__ == '__main__':
main()