261 lines
9.2 KiB
Python
261 lines
9.2 KiB
Python
"""
|
||
ALNet 网络结构可视化 Demo
|
||
===========================
|
||
ALNet 是图像分支的特征提取网络,基于 ALIKE 架构。
|
||
输入:图像 (B, 3, 192, 576)
|
||
输出:score_map (B, 1, 192, 576) + descriptor_map (B, 128, 192, 576)
|
||
|
||
网络由以下部分组成:
|
||
block1: ConvBlock(3→16) - 保持分辨率
|
||
pool2: MaxPool2d(2) - 下采样 2x
|
||
block2: ResBlock(16→32) - 残差块
|
||
pool4: MaxPool2d(4) - 下采样 4x
|
||
block3: ResBlock(32→64) - 残差块
|
||
pool4: MaxPool2d(4) - 下采样 4x
|
||
block4: ResBlock(64→128) - 残差块
|
||
特征聚合: 4层concat + 上采样 - 多尺度融合
|
||
输出头: Conv1x1(128→129) - score + descriptor
|
||
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib
|
||
matplotlib.use('Agg') # 非交互后端,适合服务器
|
||
|
||
import sys
|
||
import os
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from ALIKE.alnet import ALNet, ConvBlock, ResBlock
|
||
|
||
# ============================================================
|
||
# 配置
|
||
# ============================================================
|
||
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
|
||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||
|
||
# 使用 alike-n 配置(论文中使用)
|
||
CFG = {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True}
|
||
|
||
|
||
def visualize_tensor(tensor, title, save_name, cmap='viridis', n_channels=8):
|
||
"""可视化特征图的多个通道"""
|
||
if tensor.dim() == 4:
|
||
tensor = tensor[0] # 取第一个batch
|
||
C, H, W = tensor.shape
|
||
n_show = min(n_channels, C)
|
||
|
||
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
|
||
fig.suptitle(title, fontsize=14, fontweight='bold')
|
||
|
||
for i in range(n_show):
|
||
ax = axes[i // 4, i % 4]
|
||
im = ax.imshow(tensor[i].detach().cpu().numpy(), cmap=cmap)
|
||
ax.set_title(f'Channel {i}')
|
||
ax.axis('off')
|
||
plt.colorbar(im, ax=ax, fraction=0.046)
|
||
|
||
for i in range(n_show, 8):
|
||
axes[i // 4, i % 4].axis('off')
|
||
|
||
plt.tight_layout()
|
||
path = os.path.join(OUTPUT_DIR, save_name)
|
||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
print(f' [保存] {path}')
|
||
|
||
|
||
def visualize_score_map(score_map, title, save_name):
|
||
"""可视化得分图"""
|
||
if score_map.dim() == 4:
|
||
score_map = score_map[0, 0]
|
||
elif score_map.dim() == 3:
|
||
score_map = score_map[0]
|
||
|
||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||
fig.suptitle(title, fontsize=14, fontweight='bold')
|
||
|
||
im0 = axes[0].imshow(score_map.detach().cpu().numpy(), cmap='hot')
|
||
axes[0].set_title('Score Map (热力图)')
|
||
axes[0].axis('off')
|
||
plt.colorbar(im0, ax=axes[0])
|
||
|
||
# 直方图
|
||
axes[1].hist(score_map.detach().cpu().numpy().flatten(), bins=50, color='steelblue', edgecolor='white')
|
||
axes[1].set_title('Score 分布直方图')
|
||
axes[1].set_xlabel('Score Value')
|
||
axes[1].set_ylabel('Frequency')
|
||
|
||
plt.tight_layout()
|
||
path = os.path.join(OUTPUT_DIR, save_name)
|
||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
print(f' [保存] {path}')
|
||
|
||
|
||
def visualize_intermediate_features(model, input_tensor):
|
||
"""逐层提取并可视化中间特征图"""
|
||
print('\n' + '=' * 60)
|
||
print('ALNet 中间特征逐层可视化')
|
||
print('=' * 60)
|
||
|
||
x = input_tensor
|
||
print(f'输入: {x.shape}')
|
||
|
||
# Block 1: ConvBlock
|
||
x1 = model.block1(x)
|
||
print(f'block1 (ConvBlock 3→16): {x1.shape}')
|
||
visualize_tensor(x1, 'Block1: ConvBlock 输出 (16通道)', 'alnet_block1.png')
|
||
|
||
# Pool2 + Block 2
|
||
x2 = model.pool2(x1)
|
||
x2 = model.block2(x2)
|
||
print(f'pool2 + block2 (ResBlock 16→32): {x2.shape}')
|
||
visualize_tensor(x2, 'Block2: ResBlock 输出 (32通道) [1/2分辨率]', 'alnet_block2.png')
|
||
|
||
# Pool4 + Block 3
|
||
x3 = model.pool4(x2)
|
||
x3 = model.block3(x3)
|
||
print(f'pool4 + block3 (ResBlock 32→64): {x3.shape}')
|
||
visualize_tensor(x3, 'Block3: ResBlock 输出 (64通道) [1/8分辨率]', 'alnet_block3.png')
|
||
|
||
# Pool4 + Block 4
|
||
x4 = model.pool4(x3)
|
||
x4 = model.block4(x4)
|
||
print(f'pool4 + block4 (ResBlock 64→128): {x4.shape}')
|
||
visualize_tensor(x4, 'Block4: ResBlock 输出 (128通道) [1/32分辨率]', 'alnet_block4.png')
|
||
|
||
# 特征聚合
|
||
f1 = model.gate(model.conv1(x1)) # dim//4 通道
|
||
f2 = model.gate(model.conv2(x2))
|
||
f3 = model.gate(model.conv3(x3))
|
||
f4 = model.gate(model.conv4(x4))
|
||
|
||
f2_up = model.upsample2(f2)
|
||
f3_up = model.upsample8(f3)
|
||
f4_up = model.upsample32(f4)
|
||
|
||
print(f'特征聚合: f1={f1.shape}, f2_up={f2_up.shape}, f3_up={f3_up.shape}, f4_up={f4_up.shape}')
|
||
|
||
fused = torch.cat([f1, f2_up, f3_up, f4_up], dim=1)
|
||
print(f'多尺度拼接后: {fused.shape}')
|
||
visualize_tensor(fused, '多尺度特征拼接 (128通道)', 'alnet_fused_features.png', n_channels=8)
|
||
|
||
# 输出头
|
||
output = model.convhead2(fused)
|
||
score_map = torch.sigmoid(output[:, -1:, :, :])
|
||
descriptor_map = output[:, :-1, :, :]
|
||
|
||
print(f'Score Map: {score_map.shape}')
|
||
print(f'Descriptor Map: {descriptor_map.shape}')
|
||
|
||
visualize_score_map(score_map, 'ALNet 最终输出 Score Map', 'alnet_final_score.png')
|
||
visualize_tensor(descriptor_map, 'ALNet 最终输出 Descriptor Map (128通道)', 'alnet_final_descriptor.png')
|
||
|
||
|
||
def visualize_receptive_field():
|
||
"""可视化有效感受野(通过梯度反传)"""
|
||
print('\n--- 感受野分析 ---')
|
||
model = ALNet(**CFG)
|
||
model.eval()
|
||
|
||
input_tensor = torch.randn(1, 3, 192, 576, requires_grad=True)
|
||
score_map, _ = model(input_tensor)
|
||
|
||
# 对score_map中心点的梯度反传
|
||
h, w = score_map.shape[2], score_map.shape[3]
|
||
score_map[0, 0, h // 2, w // 2].backward()
|
||
|
||
grad = input_tensor.grad.abs().sum(dim=1)[0]
|
||
fig, ax = plt.subplots(figsize=(12, 4))
|
||
im = ax.imshow(grad.detach().cpu().numpy(), cmap='hot')
|
||
ax.set_title('ALNet 有效感受野 (梯度幅度)', fontsize=14)
|
||
ax.axis('off')
|
||
plt.colorbar(im, ax=ax)
|
||
path = os.path.join(OUTPUT_DIR, 'alnet_receptive_field.png')
|
||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||
plt.close()
|
||
print(f' [保存] {path}')
|
||
|
||
|
||
def analyze_parameters():
|
||
"""分析网络参数量"""
|
||
print('\n--- 参数量分析 ---')
|
||
model = ALNet(**CFG)
|
||
total = sum(p.numel() for p in model.parameters())
|
||
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||
|
||
print(f'总参数量: {total:,} ({total / 1e6:.2f}M)')
|
||
print(f'可训练参数: {trainable:,} ({trainable / 1e6:.2f}M)')
|
||
|
||
# 逐模块分析
|
||
for name, module in model.named_children():
|
||
params = sum(p.numel() for p in module.parameters())
|
||
print(f' {name:20s}: {params:>10,} params ({params / 1e3:.1f}K)')
|
||
|
||
|
||
def main():
|
||
print('=' * 60)
|
||
print('ALNet (图像特征提取网络) 结构与特征可视化')
|
||
print('=' * 60)
|
||
|
||
analyze_parameters()
|
||
|
||
# 构建模型
|
||
model = ALNet(**CFG)
|
||
model.eval()
|
||
|
||
# 模拟输入: 裁剪后的KITTI图像 (192, 576)
|
||
input_tensor = torch.randn(1, 3, 192, 576)
|
||
|
||
# 前向传播
|
||
with torch.no_grad():
|
||
score_map, descriptor_map = model(input_tensor)
|
||
|
||
print(f'\n输入尺寸: {input_tensor.shape}')
|
||
print(f'Score Map 输出: {score_map.shape} (范围: [{score_map.min():.3f}, {score_map.max():.3f}])')
|
||
print(f'Descriptor Map 输出: {descriptor_map.shape}')
|
||
|
||
# 逐层可视化中间特征
|
||
visualize_intermediate_features(model, input_tensor)
|
||
|
||
# 感受野分析
|
||
visualize_receptive_field()
|
||
|
||
# 网络结构文本总结
|
||
print('\n' + '=' * 60)
|
||
print('网络结构总结:')
|
||
print('=' * 60)
|
||
print("""
|
||
ALNet (alike-n config):
|
||
┌──────────────────────────────────────────────────────┐
|
||
│ 输入: (B, 3, 192, 576) │
|
||
│ ↓ │
|
||
│ block1: ConvBlock(3→16) → (B, 16, 192, 576) │
|
||
│ ↓ MaxPool2d(2) │
|
||
│ block2: ResBlock(16→32) → (B, 32, 96, 288) │
|
||
│ ↓ MaxPool2d(4) │
|
||
│ block3: ResBlock(32→64) → (B, 64, 24, 72) │
|
||
│ ↓ MaxPool2d(4) │
|
||
│ block4: ResBlock(64→128) → (B, 128, 6, 18) │
|
||
│ ↓ │
|
||
│ 特征聚合: 4尺度1×1conv + 上采样 + concat → (B,128,192,576) │
|
||
│ ↓ Conv1x1(128→129) │
|
||
│ 输出: score(B,1,192,576) + desc(B,128,192,576) │
|
||
└──────────────────────────────────────────────────────┘
|
||
|
||
block1/2/3/4 各阶段的作用:
|
||
- block1: 浅层特征(边缘、角点等低级特征)
|
||
- block2: 中层特征(纹理、局部形状)
|
||
- block3: 高层特征(语义信息、物体部件)
|
||
- block4: 最抽象特征(全局上下文)
|
||
- 多尺度融合: 结合各层信息,兼顾定位精度和语义鲁棒性
|
||
""")
|
||
|
||
print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|