"""ALNet network structure visualization demo.

ALNet is the feature-extraction network of the image branch, based on the
ALIKE architecture.

Input:  image (B, 3, 192, 576)
Output: score_map (B, 1, 192, 576) + descriptor_map (B, 128, 192, 576)

The network consists of the following parts:
    block1:  ConvBlock(3→16)        - keeps the resolution
    pool2:   MaxPool2d(2)           - 2x downsampling
    block2:  ResBlock(16→32)        - residual block
    pool4:   MaxPool2d(4)           - 4x downsampling
    block3:  ResBlock(32→64)        - residual block
    pool4:   MaxPool2d(4)           - 4x downsampling
    block4:  ResBlock(64→128)       - residual block
    fusion:  4-level concat + upsampling - multi-scale fusion
    head:    Conv1x1(128→129)       - score + descriptor
"""
import os
import sys

import matplotlib

# Select the non-interactive backend (suitable for headless servers) before
# pyplot is imported, so no GUI backend is ever initialised.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import torch

# Make the parent of this script's directory importable so that the ALIKE
# package can be resolved when the script is run directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from ALIKE.alnet import ALNet, ConvBlock, ResBlock

# ============================================================
# Configuration
# ============================================================
# All figures are written to an "output" folder next to this script.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# alike-n configuration (the one used in the paper, per the original author).
CFG = {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True}

# Number of columns in the per-channel feature-map grid.
_GRID_COLS = 4
# Minimum number of grid rows; keeps the historical 2x4 layout for <= 8 channels.
_GRID_MIN_ROWS = 2


def _save_and_close(fig, save_name):
    """Save *fig* into OUTPUT_DIR as *save_name*, close it and report the path."""
    path = os.path.join(OUTPUT_DIR, save_name)
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f' [保存] {path}')


def visualize_tensor(tensor, title, save_name, cmap='viridis', n_channels=8):
    """Plot the first channels of a feature map as a grid of images.

    Args:
        tensor: Feature map of rank 3 (C, H, W) or rank 4 (B, C, H, W); for a
            rank-4 input only the first batch element is shown.
        title: Figure title.
        save_name: File name of the image written into OUTPUT_DIR.
        cmap: Matplotlib colormap name used for every channel.
        n_channels: Maximum number of channels to display.

    The grid has 4 columns and at least 2 rows. It grows with the number of
    displayed channels, so requesting more than 8 channels no longer indexes
    past a fixed 2x4 grid.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]  # take the first batch element

    n_show = max(0, min(n_channels, tensor.shape[0]))
    # Ceiling division without importing math.
    n_rows = max(_GRID_MIN_ROWS, -(-n_show // _GRID_COLS))

    fig, axes = plt.subplots(
        n_rows, _GRID_COLS, figsize=(16, 4 * n_rows), squeeze=False
    )
    fig.suptitle(title, fontsize=14, fontweight='bold')

    channels = tensor[:n_show].detach().cpu().numpy()
    for idx, ax in enumerate(axes.ravel()):
        ax.axis('off')
        if idx >= n_show:
            continue  # unused grid cell: leave it blank
        im = ax.imshow(channels[idx], cmap=cmap)
        ax.set_title(f'Channel {idx}')
        plt.colorbar(im, ax=ax, fraction=0.046)

    plt.tight_layout()
    _save_and_close(fig, save_name)


def visualize_score_map(score_map, title, save_name):
    """Plot a score map as a heat map next to a histogram of its values.

    Args:
        score_map: Rank-4 tensor (first batch element, first channel is used),
            rank-3 tensor (first slice is used) or an already 2-D tensor.
        title: Figure title.
        save_name: File name of the image written into OUTPUT_DIR.
    """
    if score_map.dim() == 4:
        score_map = score_map[0, 0]
    elif score_map.dim() == 3:
        score_map = score_map[0]

    # Convert once; both panels read the same array.
    scores = score_map.detach().cpu().numpy()

    fig, (ax_heat, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    im0 = ax_heat.imshow(scores, cmap='hot')
    ax_heat.set_title('Score Map (热力图)')
    ax_heat.axis('off')
    plt.colorbar(im0, ax=ax_heat)

    # Histogram of all score values
    ax_hist.hist(scores.flatten(), bins=50, color='steelblue', edgecolor='white')
    ax_hist.set_title('Score 分布直方图')
    ax_hist.set_xlabel('Score Value')
    ax_hist.set_ylabel('Frequency')

    plt.tight_layout()
    _save_and_close(fig, save_name)


def visualize_intermediate_features(model, input_tensor):
    """Run the network stage by stage and save a figure for every stage.

    The stages are invoked through the model's own sub-modules (block1..block4,
    pool2, pool4, conv1..conv4, gate, upsample2/8/32, convhead2), printing the
    shape of each intermediate result.

    Args:
        model: ALNet instance exposing the sub-modules listed above.
        input_tensor: Image batch fed to ``model.block1``.

    NOTE(review): the head is applied as ``convhead2`` directly on the fused
    features, which matches CFG['single_head'] = True. Whether a model built
    with single_head=False needs an extra stage should be confirmed against
    ALIKE.alnet.
    """
    print('\n' + '=' * 60)
    print('ALNet 中间特征逐层可视化')
    print('=' * 60)

    # Only activations are visualised, so no autograd graph is needed.
    with torch.no_grad():
        x = input_tensor
        print(f'输入: {x.shape}')

        # Block 1: ConvBlock
        x1 = model.block1(x)
        print(f'block1 (ConvBlock 3→16): {x1.shape}')
        visualize_tensor(x1, 'Block1: ConvBlock 输出 (16通道)', 'alnet_block1.png')

        # Pool2 + Block 2
        x2 = model.block2(model.pool2(x1))
        print(f'pool2 + block2 (ResBlock 16→32): {x2.shape}')
        visualize_tensor(x2, 'Block2: ResBlock 输出 (32通道) [1/2分辨率]', 'alnet_block2.png')

        # Pool4 + Block 3
        x3 = model.block3(model.pool4(x2))
        print(f'pool4 + block3 (ResBlock 32→64): {x3.shape}')
        visualize_tensor(x3, 'Block3: ResBlock 输出 (64通道) [1/8分辨率]', 'alnet_block3.png')

        # Pool4 + Block 4
        x4 = model.block4(model.pool4(x3))
        print(f'pool4 + block4 (ResBlock 64→128): {x4.shape}')
        visualize_tensor(x4, 'Block4: ResBlock 输出 (128通道) [1/32分辨率]', 'alnet_block4.png')

        # Feature aggregation: per-level conv + gate, then upsample and concat.
        f1 = model.gate(model.conv1(x1))
        f2 = model.gate(model.conv2(x2))
        f3 = model.gate(model.conv3(x3))
        f4 = model.gate(model.conv4(x4))

        f2_up = model.upsample2(f2)
        f3_up = model.upsample8(f3)
        f4_up = model.upsample32(f4)

        print(f'特征聚合: f1={f1.shape}, f2_up={f2_up.shape}, f3_up={f3_up.shape}, f4_up={f4_up.shape}')

        fused = torch.cat([f1, f2_up, f3_up, f4_up], dim=1)
        print(f'多尺度拼接后: {fused.shape}')
        visualize_tensor(fused, '多尺度特征拼接 (128通道)', 'alnet_fused_features.png', n_channels=8)

        # Output head: last channel is the score logit, the rest the descriptor.
        output = model.convhead2(fused)
        score_map = torch.sigmoid(output[:, -1:, :, :])
        descriptor_map = output[:, :-1, :, :]

        print(f'Score Map: {score_map.shape}')
        print(f'Descriptor Map: {descriptor_map.shape}')

        visualize_score_map(score_map, 'ALNet 最终输出 Score Map', 'alnet_final_score.png')
        visualize_tensor(descriptor_map, 'ALNet 最终输出 Descriptor Map (128通道)',
                         'alnet_final_descriptor.png')


def visualize_receptive_field():
    """Visualize the effective receptive field via gradient back-propagation.

    A fresh ALNet is built from CFG, a random input is pushed through it, and
    the score at the centre pixel is back-propagated. The absolute input
    gradient, summed over the colour channels, is saved as a heat map.
    """
    print('\n--- 感受野分析 ---')

    model = ALNet(**CFG)
    model.eval()

    input_tensor = torch.randn(1, 3, 192, 576, requires_grad=True)
    score_map, _ = model(input_tensor)

    # Back-propagate from the centre pixel of the score map.
    h, w = score_map.shape[2], score_map.shape[3]
    score_map[0, 0, h // 2, w // 2].backward()

    grad = input_tensor.grad.abs().sum(dim=1)[0]

    fig, ax = plt.subplots(figsize=(12, 4))
    im = ax.imshow(grad.detach().cpu().numpy(), cmap='hot')
    ax.set_title('ALNet 有效感受野 (梯度幅度)', fontsize=14)
    ax.axis('off')
    plt.colorbar(im, ax=ax)

    _save_and_close(fig, 'alnet_receptive_field.png')


def analyze_parameters():
    """Print the total, trainable and per-child-module parameter counts."""
    print('\n--- 参数量分析 ---')

    model = ALNet(**CFG)
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f'总参数量: {total:,} ({total / 1e6:.2f}M)')
    print(f'可训练参数: {trainable:,} ({trainable / 1e6:.2f}M)')

    # Per-module breakdown over the direct children of the model.
    for name, module in model.named_children():
        params = sum(p.numel() for p in module.parameters())
        print(f' {name:20s}: {params:>10,} params ({params / 1e3:.1f}K)')


def main():
    """Run the full demo: parameter stats, forward pass, figures and summary."""
    print('=' * 60)
    print('ALNet (图像特征提取网络) 结构与特征可视化')
    print('=' * 60)

    analyze_parameters()

    # Build the model
    model = ALNet(**CFG)
    model.eval()

    # Simulated input with the size of a cropped KITTI image (192, 576)
    input_tensor = torch.randn(1, 3, 192, 576)

    # Forward pass
    with torch.no_grad():
        score_map, descriptor_map = model(input_tensor)

    print(f'\n输入尺寸: {input_tensor.shape}')
    print(f'Score Map 输出: {score_map.shape} (范围: [{score_map.min():.3f}, {score_map.max():.3f}])')
    print(f'Descriptor Map 输出: {descriptor_map.shape}')

    # Stage-by-stage visualization of intermediate features
    visualize_intermediate_features(model, input_tensor)

    # Receptive-field analysis
    visualize_receptive_field()

    # Textual summary of the network structure
    print('\n' + '=' * 60)
    print('网络结构总结:')
    print('=' * 60)
    print("""
    ALNet (alike-n config):
    ┌──────────────────────────────────────────────────────┐
    │  输入: (B, 3, 192, 576)                                │
    │    ↓                                                 │
    │  block1: ConvBlock(3→16)  → (B, 16, 192, 576)        │
    │    ↓ MaxPool2d(2)                                    │
    │  block2: ResBlock(16→32)  → (B, 32, 96, 288)         │
    │    ↓ MaxPool2d(4)                                    │
    │  block3: ResBlock(32→64)  → (B, 64, 24, 72)          │
    │    ↓ MaxPool2d(4)                                    │
    │  block4: ResBlock(64→128) → (B, 128, 6, 18)          │
    │    ↓                                                 │
    │  特征聚合: 4尺度1×1conv + 上采样 + concat → (B,128,192,576)   │
    │    ↓ Conv1x1(128→129)                                │
    │  输出: score(B,1,192,576) + desc(B,128,192,576)        │
    └──────────────────────────────────────────────────────┘

    block1/2/3/4 各阶段的作用:
    - block1: 浅层特征(边缘、角点等低级特征)
    - block2: 中层特征(纹理、局部形状)
    - block3: 高层特征(语义信息、物体部件)
    - block4: 最抽象特征(全局上下文)
    - 多尺度融合: 结合各层信息,兼顾定位精度和语义鲁棒性
    """)

    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')


if __name__ == '__main__':
    main()