Files
fusion_LCD/network_learning/01_alnet_demo.py
2026-05-09 17:03:40 +08:00

261 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
ALNet 网络结构可视化 Demo
===========================
ALNet 是图像分支的特征提取网络,基于 ALIKE 架构。
输入:图像 (B, 3, 192, 576)
输出score_map (B, 1, 192, 576) + descriptor_map (B, 128, 192, 576)
网络由以下部分组成:
block1: ConvBlock(3→16) - 保持分辨率
pool2: MaxPool2d(2) - 下采样 2x
block2: ResBlock(16→32) - 残差块
pool4: MaxPool2d(4) - 下采样 4x
block3: ResBlock(32→64) - 残差块
pool4: MaxPool2d(4) - 下采样 4x
block4: ResBlock(64→128) - 残差块
特征聚合: 4层concat + 上采样 - 多尺度融合
输出头: Conv1x1(128→129) - score + descriptor
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # 非交互后端,适合服务器
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from ALIKE.alnet import ALNet, ConvBlock, ResBlock
# ============================================================
# Configuration
# ============================================================
# Every figure produced by this script is written to an "output" folder next
# to the script; the folder is created at import time if it does not exist.
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Use the alike-n configuration (the one used in the paper).
# These entries are passed to the ALNet constructor as keyword arguments.
CFG = {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True}
def visualize_tensor(tensor, title, save_name, cmap='viridis', n_channels=8):
    """Plot the leading channels of a feature map and save the figure.

    Args:
        tensor: Feature map shaped (C, H, W) or (B, C, H, W). For a 4-D
            input only the first batch element is plotted.
        title: Figure-level title.
        save_name: File name of the image written under ``OUTPUT_DIR``.
        cmap: Matplotlib colormap name applied to every channel.
        n_channels: Maximum number of channels to plot.

    Raises:
        ValueError: If ``tensor`` is not 3-D after the batch axis is dropped.

    The grid is 4 columns wide with at least 2 rows, so the default call
    still yields the 2x4 layout. When more than 8 channels are requested the
    grid grows by whole rows instead of indexing past a fixed 2x4 grid.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]  # keep only the first batch element
    if tensor.dim() != 3:
        raise ValueError(
            f'expected a (C, H, W) or (B, C, H, W) tensor, got shape {tuple(tensor.shape)}'
        )
    n_show = min(n_channels, tensor.shape[0])

    n_cols = 4
    # Ceiling division; never fewer than 2 rows to keep the default layout.
    n_rows = max(2, (n_show + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    for i, ax in enumerate(axes.flat):
        if i < n_show:
            im = ax.imshow(tensor[i].detach().cpu().numpy(), cmap=cmap)
            ax.set_title(f'Channel {i}')
            ax.axis('off')
            plt.colorbar(im, ax=ax, fraction=0.046)
        else:
            ax.axis('off')  # unused grid slot stays blank

    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, save_name)
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)  # close this figure explicitly rather than "the current one"
    print(f' [保存] {path}')
def visualize_score_map(score_map, title, save_name):
    """Save a two-panel figure of a score map: heat map and value histogram.

    Args:
        score_map: Score tensor. A 4-D input is reduced to element ``[0, 0]``
            and a 3-D input to element ``[0]``; any other rank is used as is.
        title: Figure-level title.
        save_name: File name of the image written under ``OUTPUT_DIR``.
    """
    rank = score_map.dim()
    if rank == 4:
        score_map = score_map[0, 0]
    elif rank == 3:
        score_map = score_map[0]
    # Convert once; both panels draw from the same array.
    scores = score_map.detach().cpu().numpy()

    fig, (heat_ax, hist_ax) = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    # Left panel: heat map of the scores.
    heat_img = heat_ax.imshow(scores, cmap='hot')
    heat_ax.set_title('Score Map (热力图)')
    heat_ax.axis('off')
    plt.colorbar(heat_img, ax=heat_ax)

    # Right panel: histogram of all score values.
    hist_ax.hist(scores.flatten(), bins=50, color='steelblue', edgecolor='white')
    hist_ax.set_title('Score 分布直方图')
    hist_ax.set_xlabel('Score Value')
    hist_ax.set_ylabel('Frequency')

    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, save_name)
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')
def visualize_intermediate_features(model, input_tensor):
    """Run the network stage by stage and plot every intermediate result.

    The stages are invoked through the model's own sub-modules (``block1`` ..
    ``block4``, ``pool2``/``pool4``, ``conv1`` .. ``conv4``, ``gate``,
    ``upsample2``/``upsample8``/``upsample32`` and ``convhead2``). Each
    result's shape is printed and the result is saved as an image through
    ``visualize_tensor`` / ``visualize_score_map``.

    The whole pass runs under ``torch.no_grad()``: every result is only
    printed or detached for plotting, so recording an autograd graph at full
    resolution would cost memory for nothing.

    Args:
        model: Network exposing the sub-modules listed above (an ``ALNet``
            instance in this script).
        input_tensor: Image batch fed to ``model.block1``.

    NOTE(review): only ``convhead2`` is applied as the output head, which
    presumably matches the ``single_head=True`` configuration used by this
    script — confirm against ``ALNet.forward`` before reusing it with other
    configurations.
    """
    print('\n' + '=' * 60)
    print('ALNet 中间特征逐层可视化')
    print('=' * 60)

    with torch.no_grad():
        x = input_tensor
        print(f'输入: {x.shape}')

        # Block 1: ConvBlock
        x1 = model.block1(x)
        print(f'block1 (ConvBlock 3→16): {x1.shape}')
        visualize_tensor(x1, 'Block1: ConvBlock 输出 (16通道)', 'alnet_block1.png')

        # Pool2 + Block 2
        x2 = model.pool2(x1)
        x2 = model.block2(x2)
        print(f'pool2 + block2 (ResBlock 16→32): {x2.shape}')
        visualize_tensor(x2, 'Block2: ResBlock 输出 (32通道) [1/2分辨率]', 'alnet_block2.png')

        # Pool4 + Block 3
        x3 = model.pool4(x2)
        x3 = model.block3(x3)
        print(f'pool4 + block3 (ResBlock 32→64): {x3.shape}')
        visualize_tensor(x3, 'Block3: ResBlock 输出 (64通道) [1/8分辨率]', 'alnet_block3.png')

        # Pool4 + Block 4 (the same pool4 module is reused)
        x4 = model.pool4(x3)
        x4 = model.block4(x4)
        print(f'pool4 + block4 (ResBlock 64→128): {x4.shape}')
        visualize_tensor(x4, 'Block4: ResBlock 输出 (128通道) [1/32分辨率]', 'alnet_block4.png')

        # Feature aggregation: project every stage, then bring the three
        # coarser ones back up before concatenation.
        f1 = model.gate(model.conv1(x1))  # presumably dim // 4 channels each — confirm in ALNet
        f2 = model.gate(model.conv2(x2))
        f3 = model.gate(model.conv3(x3))
        f4 = model.gate(model.conv4(x4))
        f2_up = model.upsample2(f2)
        f3_up = model.upsample8(f3)
        f4_up = model.upsample32(f4)
        print(f'特征聚合: f1={f1.shape}, f2_up={f2_up.shape}, f3_up={f3_up.shape}, f4_up={f4_up.shape}')

        fused = torch.cat([f1, f2_up, f3_up, f4_up], dim=1)
        print(f'多尺度拼接后: {fused.shape}')
        visualize_tensor(fused, '多尺度特征拼接 (128通道)', 'alnet_fused_features.png', n_channels=8)

        # Output head: the last channel becomes the score (after a sigmoid),
        # all preceding channels form the descriptor map.
        output = model.convhead2(fused)
        score_map = torch.sigmoid(output[:, -1:, :, :])
        descriptor_map = output[:, :-1, :, :]
        print(f'Score Map: {score_map.shape}')
        print(f'Descriptor Map: {descriptor_map.shape}')
        visualize_score_map(score_map, 'ALNet 最终输出 Score Map', 'alnet_final_score.png')
        visualize_tensor(descriptor_map, 'ALNet 最终输出 Descriptor Map (128通道)', 'alnet_final_descriptor.png')
def visualize_receptive_field():
    """Plot the input-gradient magnitude of the centre score-map pixel.

    A fresh ``ALNet(**CFG)`` in eval mode is fed a random (1, 3, 192, 576)
    input that requires grad. The score at the centre of the score map is
    back-propagated, and the absolute input gradient, summed over the channel
    axis, is saved as ``alnet_receptive_field.png`` under ``OUTPUT_DIR``.
    """
    print('\n--- 感受野分析 ---')
    net = ALNet(**CFG)
    net.eval()

    probe = torch.randn(1, 3, 192, 576, requires_grad=True)
    scores, _ = net(probe)

    # Back-propagate from the single pixel at the centre of the score map.
    centre_row = scores.shape[2] // 2
    centre_col = scores.shape[3] // 2
    scores[0, 0, centre_row, centre_col].backward()

    # Collapse the channel axis of |d score / d input| for the only sample.
    saliency = probe.grad.abs().sum(dim=1)[0]

    fig, ax = plt.subplots(figsize=(12, 4))
    heat = ax.imshow(saliency.detach().cpu().numpy(), cmap='hot')
    ax.set_title('ALNet 有效感受野 (梯度幅度)', fontsize=14)
    ax.axis('off')
    plt.colorbar(heat, ax=ax)

    path = os.path.join(OUTPUT_DIR, 'alnet_receptive_field.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')
def analyze_parameters():
    """Print parameter counts of a freshly built ``ALNet(**CFG)``.

    Reports the total and the trainable (``requires_grad``) parameter count,
    followed by one line per direct child module.
    """
    print('\n--- 参数量分析 ---')
    net = ALNet(**CFG)

    # Walk the parameters once and derive both totals from the same list.
    sizes = [(p.numel(), p.requires_grad) for p in net.parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    print(f'总参数量: {total:,} ({total / 1e6:.2f}M)')
    print(f'可训练参数: {trainable:,} ({trainable / 1e6:.2f}M)')

    # Per-module breakdown over the direct children.
    for name, child in net.named_children():
        params = sum(p.numel() for p in child.parameters())
        print(f' {name:20s}: {params:>10,} params ({params / 1e3:.1f}K)')
def main():
    """Run the whole demo.

    Steps: print parameter statistics, run one forward pass on a random
    input, plot the intermediate features stage by stage, plot the
    input-gradient map, then print a text summary of the network layout.
    """
    rule = '=' * 60
    summary = """
ALNet (alike-n config):
┌──────────────────────────────────────────────────────┐
│ 输入: (B, 3, 192, 576) │
│ ↓ │
│ block1: ConvBlock(3→16) → (B, 16, 192, 576) │
│ ↓ MaxPool2d(2) │
│ block2: ResBlock(16→32) → (B, 32, 96, 288) │
│ ↓ MaxPool2d(4) │
│ block3: ResBlock(32→64) → (B, 64, 24, 72) │
│ ↓ MaxPool2d(4) │
│ block4: ResBlock(64→128) → (B, 128, 6, 18) │
│ ↓ │
│ 特征聚合: 4尺度1×1conv + 上采样 + concat → (B,128,192,576) │
│ ↓ Conv1x1(128→129) │
│ 输出: score(B,1,192,576) + desc(B,128,192,576) │
└──────────────────────────────────────────────────────┘
block1/2/3/4 各阶段的作用:
- block1: 浅层特征(边缘、角点等低级特征)
- block2: 中层特征(纹理、局部形状)
- block3: 高层特征(语义信息、物体部件)
- block4: 最抽象特征(全局上下文)
- 多尺度融合: 结合各层信息,兼顾定位精度和语义鲁棒性
"""

    print(rule)
    print('ALNet (图像特征提取网络) 结构与特征可视化')
    print(rule)

    analyze_parameters()

    # Build the model in inference mode.
    model = ALNet(**CFG)
    model.eval()

    # Simulated input with the size of a cropped KITTI image (192, 576).
    input_tensor = torch.randn(1, 3, 192, 576)

    # Plain forward pass; no gradients are needed for the shape report.
    with torch.no_grad():
        score_map, descriptor_map = model(input_tensor)

    score_lo = score_map.min()
    score_hi = score_map.max()
    print(f'\n输入尺寸: {input_tensor.shape}')
    print(f'Score Map 输出: {score_map.shape} (范围: [{score_lo:.3f}, {score_hi:.3f}])')
    print(f'Descriptor Map 输出: {descriptor_map.shape}')

    # Stage-by-stage visualization of the intermediate features.
    visualize_intermediate_features(model, input_tensor)

    # Receptive-field analysis; it calls backward(), so it must run with
    # gradient tracking enabled (i.e. outside the no_grad block above).
    visualize_receptive_field()

    # Text summary of the network structure.
    print('\n' + rule)
    print('网络结构总结:')
    print(rule)
    print(summary)
    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()