"""RICNN rotation-invariant CNN structure visualization demo.

RICNN is the feature-extraction network of the point-cloud BEV branch; its
core idea is rotation invariance. Unlike a standard CNN, RICNN groups the
kernel weights by the Euclidean distance from each kernel pixel to the kernel
center, so that features stay consistent after the input is rotated.

Input:  BEV image (B, 3, 320, 320)
Output: score_map (B, 1, 320, 320) + descriptor_map (B, 128, 320, 320)

Key components:
    RIConv2d:    rotation-invariant convolution (weights shared per distance group)
    RIMaxpool2d: rotation-invariant max pooling (max over a circular neighborhood)
    RIAvgpool2d: rotation-invariant average pooling (mean over a circular neighborhood)
    RIResBlock:  rotation-invariant residual block
"""
import os
import sys
from typing import Tuple

import matplotlib

# Select the non-interactive backend before pyplot is imported so the script
# never tries to open a GUI window (figures are only written to disk).
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402

# Absolute directory of this script; used for both the import path and the
# output directory so they do not depend on the current working directory.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_THIS_DIR))

from BEVNet import RICNN, RIConv2d, RIMaxpool2d, RIAvgpool2d, EncodePosition  # noqa: E402,F401

OUTPUT_DIR = os.path.join(_THIS_DIR, 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _save_figure(save_name: str) -> str:
    """Tighten, save and close the current pyplot figure; return its path."""
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, save_name)
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')
    return path


def _ri_distance_groups(kz: int) -> torch.Tensor:
    """Return a (kz, kz) long tensor of distance-group indices for a kernel.

    Each entry is the rounded Euclidean distance from that kernel position to
    the kernel center; positions farther than ``0.5 * (kz - 1)`` are set to -1.
    """
    flat_idx = torch.arange(kz ** 2).view(-1, 1)
    row = torch.div(flat_idx, kz, rounding_mode='floor')
    col = torch.fmod(flat_idx, kz)
    coords = torch.cat([row, col], dim=1)
    # The second term is 0 for odd kernels and -0.5 for even kernels.
    dis = (coords - 0.5 * (kz - 1)).norm(dim=1) + 0.5 * (kz % 2 - 1)
    dis = torch.round(dis.view(kz, kz)).long()
    dis[dis > 0.5 * (kz - 1)] = -1
    return dis


def visualize_tensor(tensor, title: str, save_name: str, cmap: str = 'viridis',
                     n_channels: int = 8) -> None:
    """Plot the first channels of a feature map and save the figure.

    Args:
        tensor: Tensor of shape (B, C, H, W), (C, H, W) or (H, W). For a 4-D
            tensor only the first batch element is shown.
        title: Figure title.
        save_name: File name (inside ``OUTPUT_DIR``) of the saved image.
        cmap: Matplotlib colormap name.
        n_channels: Maximum number of channels to draw. The grid has four
            columns and at least two rows, growing when more than eight
            channels are requested.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]
    elif tensor.dim() == 2:
        tensor = tensor.unsqueeze(0)

    n_show = min(n_channels, tensor.shape[0])
    n_cols = 4
    # Keep the original 2x4 layout for <= 8 channels; add rows beyond that
    # instead of indexing past the grid.
    n_rows = max(2, (n_show + n_cols - 1) // n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 4 * n_rows))
    fig.suptitle(title, fontsize=14, fontweight='bold')
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_show:
            continue
        im = ax.imshow(tensor[i].detach().cpu().numpy(), cmap=cmap)
        ax.set_title(f'Channel {i}')
        plt.colorbar(im, ax=ax, fraction=0.046)
    _save_figure(save_name)


def visualize_score_map(score_map, title: str, save_name: str) -> None:
    """Plot a score map as a heat map next to a histogram of its values.

    Args:
        score_map: Tensor of shape (B, 1, H, W), (1, H, W) or (H, W); the
            first batch element / channel is used.
        title: Figure title.
        save_name: File name (inside ``OUTPUT_DIR``) of the saved image.
    """
    if score_map.dim() == 4:
        score_map = score_map[0, 0]
    elif score_map.dim() == 3:
        score_map = score_map[0]
    scores = score_map.detach().cpu().numpy()

    fig, (ax_map, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    im0 = ax_map.imshow(scores, cmap='hot')
    ax_map.set_title('Score Map')
    ax_map.axis('off')
    plt.colorbar(im0, ax=ax_map)

    ax_hist.hist(scores.flatten(), bins=50, color='steelblue', edgecolor='white')
    ax_hist.set_title('Score 分布')
    ax_hist.set_xlabel('Score')
    _save_figure(save_name)


def visualize_ri_conv_kernel() -> None:
    """Draw the distance-group pattern of 3x3, 5x5 and 7x7 kernels."""
    print('\n--- RIConv2d 卷积核分组可视化 ---')
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))
    for ax, kz in zip(axes, (3, 5, 7)):
        dis = _ri_distance_groups(kz)
        ax.imshow(dis.numpy(), cmap='tab10')
        ax.set_title(f'Kernel {kz}x{kz}\nDistance Groups: {dis.max().item() + 1}')
        # Annotate every kernel position with its group index (-1 in red).
        for i in range(kz):
            for j in range(kz):
                val = dis[i, j].item()
                color = 'white' if val >= 0 else 'red'
                ax.text(j, i, str(val), ha='center', va='center',
                        fontsize=8, color=color)
        ax.axis('off')
    plt.suptitle('RIConv2d: 按到中心距离分组的卷积核权重', fontsize=14, fontweight='bold')
    _save_figure('ricnn_kernel_groups.png')


def visualize_ri_pooling() -> None:
    """Compare the square pooling window with the circular RI window."""
    print('\n--- 旋转不变池化区域可视化 ---')
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))

    # Positions with a non-negative distance group form the circular window.
    kz = 5
    mask_ri = (_ri_distance_groups(kz) > -1).numpy().astype(float)
    # A standard max pool uses the full square window.
    mask_std = np.ones((kz, kz))

    ax = axes[0, 0]
    ax.imshow(mask_std, cmap='Blues')
    ax.set_title(f'标准 MaxPool {kz}x{kz}\n有效区域: {mask_std.sum():.0f} 个像素', fontsize=12)
    for i in range(kz):
        for j in range(kz):
            ax.text(j, i, '✓', ha='center', va='center', fontsize=10)
    ax.axis('off')

    ax = axes[0, 1]
    ax.imshow(mask_ri, cmap='Oranges')
    ax.set_title(f'RI MaxPool {kz}x{kz}\n有效区域: {mask_ri.sum():.0f} 个像素 (圆形)', fontsize=12)
    for i in range(kz):
        for j in range(kz):
            inside = bool(mask_ri[i, j])
            ax.text(j, i, '✓' if inside else '✗', ha='center', va='center',
                    fontsize=10, color='white' if inside else 'red')
    ax.axis('off')

    # Text panel summarizing the rotation-invariance idea.
    ax = axes[1, 0]
    ax.set_title('旋转不变性原理', fontsize=12)
    ax.text(0.5, 0.7, '标准CNN:', transform=ax.transAxes, fontsize=11, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightblue'))
    ax.text(0.5, 0.5, '旋转图像 → 特征也旋转 → 不匹配', transform=ax.transAxes,
            fontsize=10, ha='center')
    ax.text(0.5, 0.3, 'RICNN:', transform=ax.transAxes, fontsize=11, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightgreen'))
    ax.text(0.5, 0.1, '旋转图像 → 特征不变 → 可以匹配', transform=ax.transAxes,
            fontsize=10, ha='center')
    ax.axis('off')

    # Bar chart of hard-coded illustrative scores (not measured values).
    ax = axes[1, 1]
    ax.set_title('RI vs 标准 CNN 对比', fontsize=12)
    categories = ['旋转鲁棒性', '计算效率', '平移不变性', '尺度不变性']
    ri_scores = [0.9, 0.7, 0.8, 0.5]
    std_scores = [0.3, 1.0, 0.8, 0.5]
    x = np.arange(len(categories))
    width = 0.35
    ax.bar(x - width / 2, ri_scores, width, label='RICNN', color='orange', alpha=0.8)
    ax.bar(x + width / 2, std_scores, width, label='标准CNN', color='blue', alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(categories, fontsize=9)
    ax.set_ylim(0, 1.2)
    ax.legend()
    ax.set_ylabel('能力评分')

    plt.suptitle('RICNN 旋转不变池化详解', fontsize=14, fontweight='bold')
    _save_figure('ricnn_pooling_visualization.png')


def test_rotation_invariance() -> Tuple[float, float]:
    """Compare RICNN descriptors of an image and its 90/180 degree rotations.

    The rotated outputs are rotated back before comparison, so a cosine
    similarity close to 1.0 means the descriptor map follows the rotation.

    Returns:
        ``(cos_sim_90, cos_sim_180)`` as Python floats.
    """
    print('\n--- 旋转不变性测试 ---')
    model = RICNN()
    model.eval()

    # Synthetic BEV image with a few distinct rectangles, one per channel.
    bev = torch.zeros(1, 3, 320, 320)
    bev[0, 0, 100:120, 150:170] = 1.0
    bev[0, 1, 140:160, 100:140] = 0.8
    bev[0, 2, 150:170, 160:200] = 0.6

    with torch.no_grad():
        score_orig, desc_orig = model(bev)

        # Rotate by 90 degrees, run the model, rotate the descriptors back.
        bev_rot90 = torch.rot90(bev, k=1, dims=[2, 3])
        score_rot90, desc_rot90 = model(bev_rot90)
        desc_rot90_back = torch.rot90(desc_rot90, k=-1, dims=[2, 3])

        # Same for 180 degrees.
        bev_rot180 = torch.rot90(bev, k=2, dims=[2, 3])
        _, desc_rot180 = model(bev_rot180)
        desc_rot180_back = torch.rot90(desc_rot180, k=-2, dims=[2, 3])

        cos_sim_90 = torch.nn.functional.cosine_similarity(
            desc_orig.flatten(), desc_rot90_back.flatten(), dim=0)
        cos_sim_180 = torch.nn.functional.cosine_similarity(
            desc_orig.flatten(), desc_rot180_back.flatten(), dim=0)

    print(f'原始 vs 旋转90°后特征 余弦相似度: {cos_sim_90.item():.4f}')
    print(f'原始 vs 旋转180°后特征 余弦相似度: {cos_sim_180.item():.4f}')
    print('(越接近1.0说明旋转不变性越好)')

    fig, axes = plt.subplots(2, 4, figsize=(18, 8))
    axes[0, 0].imshow(bev[0].permute(1, 2, 0).numpy())
    axes[0, 0].set_title('原始BEV')
    axes[0, 1].imshow(bev_rot90[0].permute(1, 2, 0).numpy())
    axes[0, 1].set_title('旋转90°')
    axes[0, 2].imshow(score_orig[0, 0].numpy(), cmap='hot')
    axes[0, 2].set_title('原始Score')
    axes[0, 3].imshow(score_rot90[0, 0].numpy(), cmap='hot')
    axes[0, 3].set_title('旋转90° Score')
    axes[1, 0].imshow(desc_orig[0, 0].numpy(), cmap='viridis')
    axes[1, 0].set_title('原始Desc ch0')
    axes[1, 1].imshow(desc_rot90_back[0, 0].numpy(), cmap='viridis')
    axes[1, 1].set_title(f'旋回后Desc ch0\n相似度:{cos_sim_90.item():.3f}')
    axes[1, 2].imshow((desc_orig[0, 0] - desc_rot90_back[0, 0]).abs().numpy(), cmap='Reds')
    axes[1, 2].set_title('差异热图 ch0')

    # Hide the frame of every panel that holds no image (here: axes[1, 3]).
    for ax in axes.flatten():
        if not (ax.collections or ax.images):
            ax.axis('off')

    plt.suptitle('RICNN 旋转不变性测试', fontsize=14, fontweight='bold')
    _save_figure('ricnn_rotation_invariance.png')
    return cos_sim_90.item(), cos_sim_180.item()


def visualize_ricnn_intermediate() -> None:
    """Run RICNN block by block on a structured input and plot each output."""
    print('\n--- RICNN 中间特征可视化 ---')
    model = RICNN()
    model.eval()

    # Structured input: a sinusoidal grid, radial rings and a filled disc.
    xs = torch.linspace(-1, 1, 320)
    ys = torch.linspace(-1, 1, 320)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    r = torch.sqrt(grid_x ** 2 + grid_y ** 2)
    bev = torch.zeros(1, 3, 320, 320)
    bev[0, 0] = (torch.sin(grid_x * 10) * torch.cos(grid_y * 10) + 1) / 2
    bev[0, 1] = (torch.cos(r * 5) + 1) / 2
    bev[0, 2] = (r < 0.5).float()

    # Layer-by-layer forward pass; pool4 is applied before block3 and block4.
    with torch.no_grad():
        x1 = model.block1(bev)
        x2 = model.block2(model.pool2(x1))
        x3 = model.block3(model.pool4(x2))
        x4 = model.block4(model.pool4(x3))

    print(f'输入BEV: {bev.shape}')
    print(f'block1 (RIConvBlock 3→16): {x1.shape}')
    print(f'pool2+block2 (RIResBlock 16→32): {x2.shape}')
    print(f'pool4+block3 (RIResBlock 32→64): {x3.shape}')
    print(f'pool4+block4 (RIResBlock 64→128): {x4.shape}')

    visualize_tensor(x1, 'RICNN Block1 输出 (16通道)', 'ricnn_block1.png')
    visualize_tensor(x2, 'RICNN Block2 输出 (32通道)', 'ricnn_block2.png')
    visualize_tensor(x3, 'RICNN Block3 输出 (64通道)', 'ricnn_block3.png')
    visualize_tensor(x4, 'RICNN Block4 输出 (128通道)', 'ricnn_block4.png')


def visualize_position_encoding() -> None:
    """Run EncodePosition on random keypoints and plot their distances."""
    print('\n--- EncodePosition 位置编码可视化 ---')
    ep = EncodePosition(feature_size=128)
    ep.eval()

    # Simulated keypoints (B, 150, 4) laid out as [x, y, z, intensity].
    kpts = torch.randn(2, 150, 4)
    kpts[:, :, :2] = kpts[:, :, :2] * 30  # spread x, y (roughly +-30 m scale)
    kpts[:, :, 2] = 0  # z = 0 (BEV plane)
    kpts[:, :, 3] = 1  # intensity = 1
    # Simulated features (B, 128, 150).
    fea = torch.randn(2, 128, 150)

    with torch.no_grad():
        fea_encoded = ep(kpts, fea)

    print(f'关键点输入: {kpts.shape}')
    print(f'原始特征: {fea.shape}')
    print(f'位置编码后特征: {fea_encoded.shape}')

    # Pairwise distances of the first batch element. The norm runs over all
    # four columns; z and intensity are constant, so only x, y contribute.
    dx = kpts[0].unsqueeze(1) - kpts[0].unsqueeze(0)  # (150, 150, 4)
    distance = dx.norm(p=2, dim=2)  # (150, 150)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    im0 = axes[0].imshow(distance.numpy(), cmap='plasma')
    axes[0].set_title('关键点间距离矩阵 (150x150)')
    axes[0].set_xlabel('Keypoint j')
    axes[0].set_ylabel('Keypoint i')
    plt.colorbar(im0, ax=axes[0])

    # Example histogram for the first keypoint.
    hist = torch.histc(distance[0], bins=16, min=1, max=80)
    axes[1].bar(range(16), hist.numpy(), color='steelblue')
    axes[1].set_title('距离直方图 (16 bins, 1-80m)\n用于位置编码')
    axes[1].set_xlabel('Distance Bin')
    axes[1].set_ylabel('Count')

    plt.suptitle('EncodePosition 位置编码模块', fontsize=14, fontweight='bold')
    _save_figure('ricnn_position_encoding.png')


def analyze_parameters() -> None:
    """Print the total and per-child parameter counts of a fresh RICNN."""
    print('\n--- 参数量分析 ---')
    model = RICNN()
    total = sum(p.numel() for p in model.parameters())
    print(f'总参数量: {total:,} ({total / 1e6:.2f}M)')
    for name, module in model.named_children():
        params = sum(p.numel() for p in module.parameters())
        print(f' {name:20s}: {params:>10,} params ({params / 1e3:.1f}K)')


def main() -> None:
    """Run every analysis/visualization step and print a structure summary."""
    print('=' * 60)
    print('RICNN (旋转不变CNN) 网络结构与特征可视化')
    print('=' * 60)

    analyze_parameters()
    # 1. Kernel distance-group visualization
    visualize_ri_conv_kernel()
    # 2. Pooling-region visualization
    visualize_ri_pooling()
    # 3. Intermediate feature visualization
    visualize_ricnn_intermediate()
    # 4. Rotation-invariance test
    test_rotation_invariance()
    # 5. Position-encoding visualization
    visualize_position_encoding()

    print('\n' + '=' * 60)
    print('网络结构总结:')
    print('=' * 60)
    print("""
  RICNN (Rotation-Invariant CNN):
  ┌──────────────────────────────────────────────────────┐
  │  输入: BEV图像 (B, 3, 320, 320)                      │
  │      ↓                                               │
  │  block1: RIConvBlock(3→16) → (B, 16, 320, 320)       │
  │      ↓ RIMaxpool2d(2)                                │
  │  block2: RIResBlock(16→32) → (B, 32, 160, 160)       │
  │      ↓ RIMaxpool2d(5, s=4)                           │
  │  block3: RIResBlock(32→64) → (B, 64, 40, 40)         │
  │      ↓ RIMaxpool2d(5, s=4)                           │
  │  block4: RIResBlock(64→128) → (B, 128, 10, 10)       │
  │      ↓                                               │
  │  多尺度特征聚合 (1x1conv + 上采样 + concat)          │
  │  → (B, 128, 320, 320)                                │
  │      ↓ Conv1x1(128→129)                              │
  │  输出: score(B,1,320,320) + desc(B,128,320,320)      │
  └──────────────────────────────────────────────────────┘

  旋转不变性的实现:
    - RIConv2d: 根据kernel位置到中心的欧氏距离分组
      同距离的位置共享权重 → 旋转后权重不变
    - RIMaxpool2d: 只在圆形邻域内取max(忽略角点)
    - RIAvgpool2d: 只在圆形邻域内取mean

  EncodePosition (位置编码):
    - 输入: 150个关键点的3D坐标
    - 计算150×150距离矩阵 → 直方图(16 bins) → MLP
    - 残差加到特征上,增强空间感知能力
    """)
    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')


if __name__ == '__main__':
    main()