网络测试和学习demo
This commit is contained in:
260
network_learning/01_alnet_demo.py
Normal file
260
network_learning/01_alnet_demo.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
ALNet 网络结构可视化 Demo
|
||||
===========================
|
||||
ALNet 是图像分支的特征提取网络,基于 ALIKE 架构。
|
||||
输入:图像 (B, 3, 192, 576)
|
||||
输出:score_map (B, 1, 192, 576) + descriptor_map (B, 128, 192, 576)
|
||||
|
||||
网络由以下部分组成:
|
||||
block1: ConvBlock(3→16) - 保持分辨率
|
||||
pool2: MaxPool2d(2) - 下采样 2x
|
||||
block2: ResBlock(16→32) - 残差块
|
||||
pool4: MaxPool2d(4) - 下采样 4x
|
||||
block3: ResBlock(32→64) - 残差块
|
||||
pool4: MaxPool2d(4) - 下采样 4x
|
||||
block4: ResBlock(64→128) - 残差块
|
||||
特征聚合: 4层concat + 上采样 - 多尺度融合
|
||||
输出头: Conv1x1(128→129) - score + descriptor
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') # 非交互后端,适合服务器
|
||||
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from ALIKE.alnet import ALNet, ConvBlock, ResBlock
|
||||
|
||||
# ============================================================
|
||||
# 配置
|
||||
# ============================================================
|
||||
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
# 使用 alike-n 配置(论文中使用)
|
||||
CFG = {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True}
|
||||
|
||||
|
||||
def visualize_tensor(tensor, title, save_name, cmap='viridis', n_channels=8):
|
||||
"""可视化特征图的多个通道"""
|
||||
if tensor.dim() == 4:
|
||||
tensor = tensor[0] # 取第一个batch
|
||||
C, H, W = tensor.shape
|
||||
n_show = min(n_channels, C)
|
||||
|
||||
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
|
||||
fig.suptitle(title, fontsize=14, fontweight='bold')
|
||||
|
||||
for i in range(n_show):
|
||||
ax = axes[i // 4, i % 4]
|
||||
im = ax.imshow(tensor[i].detach().cpu().numpy(), cmap=cmap)
|
||||
ax.set_title(f'Channel {i}')
|
||||
ax.axis('off')
|
||||
plt.colorbar(im, ax=ax, fraction=0.046)
|
||||
|
||||
for i in range(n_show, 8):
|
||||
axes[i // 4, i % 4].axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
path = os.path.join(OUTPUT_DIR, save_name)
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f' [保存] {path}')
|
||||
|
||||
|
||||
def visualize_score_map(score_map, title, save_name):
|
||||
"""可视化得分图"""
|
||||
if score_map.dim() == 4:
|
||||
score_map = score_map[0, 0]
|
||||
elif score_map.dim() == 3:
|
||||
score_map = score_map[0]
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||||
fig.suptitle(title, fontsize=14, fontweight='bold')
|
||||
|
||||
im0 = axes[0].imshow(score_map.detach().cpu().numpy(), cmap='hot')
|
||||
axes[0].set_title('Score Map (热力图)')
|
||||
axes[0].axis('off')
|
||||
plt.colorbar(im0, ax=axes[0])
|
||||
|
||||
# 直方图
|
||||
axes[1].hist(score_map.detach().cpu().numpy().flatten(), bins=50, color='steelblue', edgecolor='white')
|
||||
axes[1].set_title('Score 分布直方图')
|
||||
axes[1].set_xlabel('Score Value')
|
||||
axes[1].set_ylabel('Frequency')
|
||||
|
||||
plt.tight_layout()
|
||||
path = os.path.join(OUTPUT_DIR, save_name)
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f' [保存] {path}')
|
||||
|
||||
|
||||
def visualize_intermediate_features(model, input_tensor):
|
||||
"""逐层提取并可视化中间特征图"""
|
||||
print('\n' + '=' * 60)
|
||||
print('ALNet 中间特征逐层可视化')
|
||||
print('=' * 60)
|
||||
|
||||
x = input_tensor
|
||||
print(f'输入: {x.shape}')
|
||||
|
||||
# Block 1: ConvBlock
|
||||
x1 = model.block1(x)
|
||||
print(f'block1 (ConvBlock 3→16): {x1.shape}')
|
||||
visualize_tensor(x1, 'Block1: ConvBlock 输出 (16通道)', 'alnet_block1.png')
|
||||
|
||||
# Pool2 + Block 2
|
||||
x2 = model.pool2(x1)
|
||||
x2 = model.block2(x2)
|
||||
print(f'pool2 + block2 (ResBlock 16→32): {x2.shape}')
|
||||
visualize_tensor(x2, 'Block2: ResBlock 输出 (32通道) [1/2分辨率]', 'alnet_block2.png')
|
||||
|
||||
# Pool4 + Block 3
|
||||
x3 = model.pool4(x2)
|
||||
x3 = model.block3(x3)
|
||||
print(f'pool4 + block3 (ResBlock 32→64): {x3.shape}')
|
||||
visualize_tensor(x3, 'Block3: ResBlock 输出 (64通道) [1/8分辨率]', 'alnet_block3.png')
|
||||
|
||||
# Pool4 + Block 4
|
||||
x4 = model.pool4(x3)
|
||||
x4 = model.block4(x4)
|
||||
print(f'pool4 + block4 (ResBlock 64→128): {x4.shape}')
|
||||
visualize_tensor(x4, 'Block4: ResBlock 输出 (128通道) [1/32分辨率]', 'alnet_block4.png')
|
||||
|
||||
# 特征聚合
|
||||
f1 = model.gate(model.conv1(x1)) # dim//4 通道
|
||||
f2 = model.gate(model.conv2(x2))
|
||||
f3 = model.gate(model.conv3(x3))
|
||||
f4 = model.gate(model.conv4(x4))
|
||||
|
||||
f2_up = model.upsample2(f2)
|
||||
f3_up = model.upsample8(f3)
|
||||
f4_up = model.upsample32(f4)
|
||||
|
||||
print(f'特征聚合: f1={f1.shape}, f2_up={f2_up.shape}, f3_up={f3_up.shape}, f4_up={f4_up.shape}')
|
||||
|
||||
fused = torch.cat([f1, f2_up, f3_up, f4_up], dim=1)
|
||||
print(f'多尺度拼接后: {fused.shape}')
|
||||
visualize_tensor(fused, '多尺度特征拼接 (128通道)', 'alnet_fused_features.png', n_channels=8)
|
||||
|
||||
# 输出头
|
||||
output = model.convhead2(fused)
|
||||
score_map = torch.sigmoid(output[:, -1:, :, :])
|
||||
descriptor_map = output[:, :-1, :, :]
|
||||
|
||||
print(f'Score Map: {score_map.shape}')
|
||||
print(f'Descriptor Map: {descriptor_map.shape}')
|
||||
|
||||
visualize_score_map(score_map, 'ALNet 最终输出 Score Map', 'alnet_final_score.png')
|
||||
visualize_tensor(descriptor_map, 'ALNet 最终输出 Descriptor Map (128通道)', 'alnet_final_descriptor.png')
|
||||
|
||||
|
||||
def visualize_receptive_field():
|
||||
"""可视化有效感受野(通过梯度反传)"""
|
||||
print('\n--- 感受野分析 ---')
|
||||
model = ALNet(**CFG)
|
||||
model.eval()
|
||||
|
||||
input_tensor = torch.randn(1, 3, 192, 576, requires_grad=True)
|
||||
score_map, _ = model(input_tensor)
|
||||
|
||||
# 对score_map中心点的梯度反传
|
||||
h, w = score_map.shape[2], score_map.shape[3]
|
||||
score_map[0, 0, h // 2, w // 2].backward()
|
||||
|
||||
grad = input_tensor.grad.abs().sum(dim=1)[0]
|
||||
fig, ax = plt.subplots(figsize=(12, 4))
|
||||
im = ax.imshow(grad.detach().cpu().numpy(), cmap='hot')
|
||||
ax.set_title('ALNet 有效感受野 (梯度幅度)', fontsize=14)
|
||||
ax.axis('off')
|
||||
plt.colorbar(im, ax=ax)
|
||||
path = os.path.join(OUTPUT_DIR, 'alnet_receptive_field.png')
|
||||
plt.savefig(path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
print(f' [保存] {path}')
|
||||
|
||||
|
||||
def analyze_parameters():
|
||||
"""分析网络参数量"""
|
||||
print('\n--- 参数量分析 ---')
|
||||
model = ALNet(**CFG)
|
||||
total = sum(p.numel() for p in model.parameters())
|
||||
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
print(f'总参数量: {total:,} ({total / 1e6:.2f}M)')
|
||||
print(f'可训练参数: {trainable:,} ({trainable / 1e6:.2f}M)')
|
||||
|
||||
# 逐模块分析
|
||||
for name, module in model.named_children():
|
||||
params = sum(p.numel() for p in module.parameters())
|
||||
print(f' {name:20s}: {params:>10,} params ({params / 1e3:.1f}K)')
|
||||
|
||||
|
||||
def main():
|
||||
print('=' * 60)
|
||||
print('ALNet (图像特征提取网络) 结构与特征可视化')
|
||||
print('=' * 60)
|
||||
|
||||
analyze_parameters()
|
||||
|
||||
# 构建模型
|
||||
model = ALNet(**CFG)
|
||||
model.eval()
|
||||
|
||||
# 模拟输入: 裁剪后的KITTI图像 (192, 576)
|
||||
input_tensor = torch.randn(1, 3, 192, 576)
|
||||
|
||||
# 前向传播
|
||||
with torch.no_grad():
|
||||
score_map, descriptor_map = model(input_tensor)
|
||||
|
||||
print(f'\n输入尺寸: {input_tensor.shape}')
|
||||
print(f'Score Map 输出: {score_map.shape} (范围: [{score_map.min():.3f}, {score_map.max():.3f}])')
|
||||
print(f'Descriptor Map 输出: {descriptor_map.shape}')
|
||||
|
||||
# 逐层可视化中间特征
|
||||
visualize_intermediate_features(model, input_tensor)
|
||||
|
||||
# 感受野分析
|
||||
visualize_receptive_field()
|
||||
|
||||
# 网络结构文本总结
|
||||
print('\n' + '=' * 60)
|
||||
print('网络结构总结:')
|
||||
print('=' * 60)
|
||||
print("""
|
||||
ALNet (alike-n config):
|
||||
┌──────────────────────────────────────────────────────┐
|
||||
│ 输入: (B, 3, 192, 576) │
|
||||
│ ↓ │
|
||||
│ block1: ConvBlock(3→16) → (B, 16, 192, 576) │
|
||||
│ ↓ MaxPool2d(2) │
|
||||
│ block2: ResBlock(16→32) → (B, 32, 96, 288) │
|
||||
│ ↓ MaxPool2d(4) │
|
||||
│ block3: ResBlock(32→64) → (B, 64, 24, 72) │
|
||||
│ ↓ MaxPool2d(4) │
|
||||
│ block4: ResBlock(64→128) → (B, 128, 6, 18) │
|
||||
│ ↓ │
|
||||
│ 特征聚合: 4尺度1×1conv + 上采样 + concat → (B,128,192,576) │
|
||||
│ ↓ Conv1x1(128→129) │
|
||||
│ 输出: score(B,1,192,576) + desc(B,128,192,576) │
|
||||
└──────────────────────────────────────────────────────┘
|
||||
|
||||
block1/2/3/4 各阶段的作用:
|
||||
- block1: 浅层特征(边缘、角点等低级特征)
|
||||
- block2: 中层特征(纹理、局部形状)
|
||||
- block3: 高层特征(语义信息、物体部件)
|
||||
- block4: 最抽象特征(全局上下文)
|
||||
- 多尺度融合: 结合各层信息,兼顾定位精度和语义鲁棒性
|
||||
""")
|
||||
|
||||
print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user