426 lines
16 KiB
Python
426 lines
16 KiB
Python
"""
|
||
RICNN 旋转不变CNN网络结构可视化 Demo
=======================================
RICNN 是点云BEV分支的特征提取网络,核心创新是"旋转不变性"。
与标准CNN不同,RICNN的卷积核根据像素到中心的欧氏距离分组,
使得旋转后的特征保持一致。

输入:BEV图像 (B, 3, 320, 320)
输出:score_map (B, 1, 320, 320) + descriptor_map (B, 128, 320, 320)

关键组件:
RIConv2d: 旋转不变卷积(按距离分组共享权重)
RIMaxpool2d: 旋转不变最大池化(只对圆形邻域取max)
RIAvgpool2d: 旋转不变平均池化(只对圆形邻域取avg)
RIResBlock: 旋转不变残差块
|
||
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib
|
||
matplotlib.use('Agg')
|
||
|
||
import sys
|
||
import os
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
||
from BEVNet import RICNN, RIConv2d, RIMaxpool2d, RIAvgpool2d, EncodePosition
|
||
|
||
# Every figure produced by this demo is written to an ``output`` directory
# located next to this script.  The path is made absolute (matching how the
# sys.path entry is derived from ``__file__``) so it does not
# depend on the current working directory when ``__file__`` is relative.
OUTPUT_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||
|
||
|
||
def visualize_tensor(tensor, title, save_name, cmap='viridis', n_channels=8):
    """Plot the leading channels of a feature map on a grid and save a PNG.

    Args:
        tensor: Feature map with 3 dims ``(C, H, W)`` or 4 dims
            ``(B, C, H, W)``.  For a 4-D input only batch item 0 is drawn.
        title: Super-title of the figure.
        save_name: File name of the image, created inside ``OUTPUT_DIR``.
        cmap: Matplotlib colormap name forwarded to ``imshow``.
        n_channels: Upper bound on the number of channels to draw.  The grid
            grows in rows of four when more than eight channels are requested
            (it used to be fixed at 2x4 and raised ``IndexError``).

    Raises:
        ValueError: If ``tensor`` is neither 3-D nor 4-D.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]
    if tensor.dim() != 3:
        raise ValueError(
            f'expected a 3-D or 4-D tensor, got shape {tuple(tensor.shape)}')

    n_show = max(0, min(n_channels, tensor.shape[0]))

    # Four panels per row; never fewer than two rows so that the default
    # call keeps producing the same 2x4, 16x8-inch figure as before.
    n_cols = 4
    n_rows = max(2, (n_show + n_cols - 1) // n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    for idx, ax in enumerate(axes.flatten()):
        if idx < n_show:
            im = ax.imshow(tensor[idx].detach().cpu().numpy(), cmap=cmap)
            ax.set_title(f'Channel {idx}')
            ax.axis('off')
            plt.colorbar(im, ax=ax, fraction=0.046)
        else:
            # Unused grid cells are blanked out.
            ax.axis('off')

    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, save_name)
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f' [保存] {path}')
|
||
|
||
|
||
def visualize_score_map(score_map, title, save_name):
    """Save a two-panel figure: score heat map and histogram of its values.

    Args:
        score_map: Tensor of scores.  A 4-D input is reduced with
            ``[0, 0]`` and a 3-D input with ``[0]``; any other rank is
            used as given.
        title: Super-title of the figure.
        save_name: File name of the image, created inside ``OUTPUT_DIR``.
    """
    rank = score_map.dim()
    if rank == 4:
        score_map = score_map[0, 0]
    elif rank == 3:
        score_map = score_map[0]

    # Convert once; both panels read the same array.
    scores = score_map.detach().cpu().numpy()

    fig, (ax_map, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    heat = ax_map.imshow(scores, cmap='hot')
    ax_map.set_title('Score Map')
    ax_map.axis('off')
    plt.colorbar(heat, ax=ax_map)

    ax_hist.hist(scores.flatten(), bins=50, color='steelblue', edgecolor='white')
    ax_hist.set_title('Score 分布')
    ax_hist.set_xlabel('Score')

    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, save_name)
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')
|
||
|
||
|
||
def visualize_ri_conv_kernel():
    """Draw the distance-group index of every cell for 3x3, 5x5 and 7x7 kernels.

    For each kernel size the rounded Euclidean distance of every cell to the
    kernel centre is computed; cells farther than ``0.5 * (kz - 1)`` are
    marked ``-1``.  Each cell is annotated with its value and the figure is
    saved as ``ricnn_kernel_groups.png`` in ``OUTPUT_DIR``.
    """
    print('\n--- RIConv2d 卷积核分组可视化 ---')

    fig, axes = plt.subplots(1, 3, figsize=(16, 5))

    for ax, kz in zip(axes, (3, 5, 7)):
        # Build the (row, col) coordinate of every kernel cell.
        flat_idx = torch.arange(kz ** 2).view(-1, 1)
        cell_rc = torch.cat(
            [torch.div(flat_idx, kz, rounding_mode='floor'),
             torch.fmod(flat_idx, kz)],
            dim=1,
        )

        # Distance to the kernel centre; the second term is zero for odd
        # sizes and -0.5 for even sizes.
        radius = 0.5 * (kz - 1)
        groups = (cell_rc - radius).norm(dim=1) + 0.5 * (kz % 2 - 1)
        groups = torch.round(groups.view(kz, kz)).long()
        groups[groups > radius] = -1  # outside the radius

        ax.imshow(groups.numpy(), cmap='tab10')
        ax.set_title(f'Kernel {kz}x{kz}\nDistance Groups: {groups.max().item() + 1}')

        # Annotate each cell with its group index (red marks excluded cells).
        for i, row_vals in enumerate(groups.tolist()):
            for j, val in enumerate(row_vals):
                ax.text(j, i, str(val), ha='center', va='center', fontsize=8,
                        color='red' if val < 0 else 'white')
        ax.axis('off')

    plt.suptitle('RIConv2d: 按到中心距离分组的卷积核权重', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'ricnn_kernel_groups.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')
|
||
|
||
|
||
def visualize_ri_pooling():
    """Save a 2x2 explainer figure about rotation-invariant pooling.

    Panels:
      * top-left: footprint of a plain 5x5 window (every cell used);
      * top-right: the same window with cells whose rounded centre distance
        exceeds ``0.5 * (kz - 1)`` removed;
      * bottom-left: a short textual comparison;
      * bottom-right: grouped bar chart of hard-coded illustrative scores
        (these numbers are constants in this function, not measurements).

    The figure is written to ``ricnn_pooling_visualization.png`` in
    ``OUTPUT_DIR``.
    """
    print('\n--- 旋转不变池化区域可视化 ---')

    fig, axes = plt.subplots(2, 2, figsize=(10, 10))

    # Footprint for kernel_size=5: keep only cells within the window radius.
    kz = 5
    flat_idx = torch.arange(kz ** 2).view(-1, 1)
    cell_rc = torch.cat(
        [torch.div(flat_idx, kz, rounding_mode='floor'),
         torch.fmod(flat_idx, kz)],
        dim=1,
    )
    radius = 0.5 * (kz - 1)
    dist = (cell_rc - radius).norm(dim=1) + 0.5 * (kz % 2 - 1)
    dist = torch.round(dist.view(kz, kz))
    dist[dist > radius] = -1
    mask_ri = (dist > -1).numpy().astype(float)

    # A standard pooling window uses the full square.
    mask_std = np.ones((kz, kz))

    # --- top-left: standard window -------------------------------------
    ax = axes[0, 0]
    ax.imshow(mask_std, cmap='Blues')
    ax.set_title(f'标准 MaxPool {kz}x{kz}\n有效区域: {mask_std.sum():.0f} 个像素', fontsize=12)
    for i, j in np.ndindex(kz, kz):
        ax.text(j, i, '✓', ha='center', va='center', fontsize=10)
    ax.axis('off')

    # --- top-right: circular window ------------------------------------
    ax = axes[0, 1]
    ax.imshow(mask_ri, cmap='Oranges')
    ax.set_title(f'RI MaxPool {kz}x{kz}\n有效区域: {mask_ri.sum():.0f} 个像素 (圆形)', fontsize=12)
    for i, j in np.ndindex(kz, kz):
        if mask_ri[i, j]:
            mark, mark_color = '✓', 'white'
        else:
            mark, mark_color = '✗', 'red'
        ax.text(j, i, mark, ha='center', va='center', fontsize=10, color=mark_color)
    ax.axis('off')

    # --- bottom-left: textual comparison --------------------------------
    ax = axes[1, 0]
    ax.set_title('旋转不变性原理', fontsize=12)
    notes = (
        # (y position, text, font size, box colour or None)
        (0.7, '标准CNN:', 11, 'lightblue'),
        (0.5, '旋转图像 → 特征也旋转 → 不匹配', 10, None),
        (0.3, 'RICNN:', 11, 'lightgreen'),
        (0.1, '旋转图像 → 特征不变 → 可以匹配', 10, None),
    )
    for y_pos, note, size, box_color in notes:
        text_kwargs = dict(transform=ax.transAxes, fontsize=size, ha='center')
        if box_color is not None:
            text_kwargs['bbox'] = dict(boxstyle='round', facecolor=box_color)
        ax.text(0.5, y_pos, note, **text_kwargs)
    ax.axis('off')

    # --- bottom-right: illustrative bar chart ---------------------------
    ax = axes[1, 1]
    ax.set_title('RI vs 标准 CNN 对比', fontsize=12)
    categories = ['旋转鲁棒性', '计算效率', '平移不变性', '尺度不变性']
    ri_scores = [0.9, 0.7, 0.8, 0.5]
    std_scores = [0.3, 1.0, 0.8, 0.5]
    positions = np.arange(len(categories))
    bar_width = 0.35
    ax.bar(positions - bar_width / 2, ri_scores, bar_width, label='RICNN', color='orange', alpha=0.8)
    ax.bar(positions + bar_width / 2, std_scores, bar_width, label='标准CNN', color='blue', alpha=0.8)
    ax.set_xticks(positions)
    ax.set_xticklabels(categories, fontsize=9)
    ax.set_ylim(0, 1.2)
    ax.legend()
    ax.set_ylabel('能力评分')

    plt.suptitle('RICNN 旋转不变池化详解', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'ricnn_pooling_visualization.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')
|
||
|
||
|
||
def test_rotation_invariance():
    """Compare descriptors of a BEV image with those of rotated copies.

    A ``RICNN`` is constructed with default arguments (no checkpoint is
    loaded in this function) and run on a synthetic image and on its 90 and
    180 degree rotations.  The rotated descriptors are rotated back and
    compared with the original through the cosine similarity of the
    flattened tensors.  A comparison figure is saved as
    ``ricnn_rotation_invariance.png`` in ``OUTPUT_DIR``.

    Returns:
        Tuple ``(cos_sim_90, cos_sim_180)`` of Python floats.
    """
    print('\n--- 旋转不变性测试 ---')

    model = RICNN()
    model.eval()

    # Synthetic BEV input: one axis-aligned rectangle per channel.
    bev = torch.zeros(1, 3, 320, 320)
    bev[0, 0, 100:120, 150:170] = 1.0
    bev[0, 1, 140:160, 100:140] = 0.8
    bev[0, 2, 150:170, 160:200] = 0.6

    spatial_dims = [2, 3]
    with torch.no_grad():
        score_orig, desc_orig = model(bev)

        # Rotate by 90 degrees, then rotate the descriptors back.
        bev_rot90 = torch.rot90(bev, k=1, dims=spatial_dims)
        score_rot90, desc_rot90 = model(bev_rot90)
        desc_rot90_back = torch.rot90(desc_rot90, k=-1, dims=spatial_dims)

        # Same for 180 degrees (its score map is not used afterwards).
        bev_rot180 = torch.rot90(bev, k=2, dims=spatial_dims)
        _, desc_rot180 = model(bev_rot180)
        desc_rot180_back = torch.rot90(desc_rot180, k=-2, dims=spatial_dims)

    # Similarity over the whole flattened descriptor tensor.
    cosine = torch.nn.functional.cosine_similarity
    cos_sim_90 = cosine(desc_orig.flatten(), desc_rot90_back.flatten(), dim=0)
    cos_sim_180 = cosine(desc_orig.flatten(), desc_rot180_back.flatten(), dim=0)

    print(f'原始 vs 旋转90°后特征 余弦相似度: {cos_sim_90.item():.4f}')
    print(f'原始 vs 旋转180°后特征 余弦相似度: {cos_sim_180.item():.4f}')
    print(f'(越接近1.0说明旋转不变性越好)')

    # Figure: inputs and score maps on top, descriptor channel 0 below.
    fig, axes = plt.subplots(2, 4, figsize=(18, 8))
    panels = (
        # (axes, image tensor, title, colormap)
        (axes[0, 0], bev[0].permute(1, 2, 0), '原始BEV', None),
        (axes[0, 1], bev_rot90[0].permute(1, 2, 0), '旋转90°', None),
        (axes[0, 2], score_orig[0, 0], '原始Score', 'hot'),
        (axes[0, 3], score_rot90[0, 0], '旋转90° Score', 'hot'),
        (axes[1, 0], desc_orig[0, 0], f'原始Desc ch0', 'viridis'),
        (axes[1, 1], desc_rot90_back[0, 0],
         f'旋回后Desc ch0\n相似度:{cos_sim_90.item():.3f}', 'viridis'),
        (axes[1, 2], (desc_orig[0, 0] - desc_rot90_back[0, 0]).abs(),
         '差异热图 ch0', 'Reds'),
    )
    for ax, image, caption, cmap in panels:
        ax.imshow(image.numpy(), cmap=cmap)
        ax.set_title(caption)

    # Hide the frame of every panel that received no image.
    for ax in axes.flatten():
        if not (ax.collections or ax.images):
            ax.axis('off')

    plt.suptitle('RICNN 旋转不变性测试', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'ricnn_rotation_invariance.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')

    return cos_sim_90.item(), cos_sim_180.item()
|
||
|
||
|
||
def visualize_ricnn_intermediate():
    """Run a synthetic image through the RICNN stages and plot each output.

    The model's ``block1`` .. ``block4`` sub-modules are called one after
    another, with ``pool2`` before ``block2`` and ``pool4`` before both
    ``block3`` and ``block4``.  The shape of every stage output is printed
    and its leading channels are saved via ``visualize_tensor``.
    """
    print('\n--- RICNN 中间特征可视化 ---')

    model = RICNN()
    model.eval()

    # Structured test pattern on a [-1, 1] x [-1, 1] grid.
    x = torch.linspace(-1, 1, 320)
    y = torch.linspace(-1, 1, 320)
    grid_y, grid_x = torch.meshgrid(y, x, indexing='ij')
    radius = torch.sqrt(grid_x ** 2 + grid_y ** 2)

    channels = [
        (torch.sin(grid_x * 10) * torch.cos(grid_y * 10) + 1) / 2,  # checker-like waves
        (torch.cos(radius * 5) + 1) / 2,                            # concentric rings
        (radius < 0.5).float(),                                     # filled disc
    ]
    bev = torch.zeros(1, 3, 320, 320)
    for ch, pattern in enumerate(channels):
        bev[0, ch] = pattern

    # Stage-by-stage forward pass.
    with torch.no_grad():
        x1 = model.block1(bev)
        x2 = model.block2(model.pool2(x1))
        x3 = model.block3(model.pool4(x2))
        x4 = model.block4(model.pool4(x3))

    print(f'输入BEV: {bev.shape}')
    print(f'block1 (RIConvBlock 3→16): {x1.shape}')
    print(f'pool2+block2 (RIResBlock 16→32): {x2.shape}')
    print(f'pool4+block3 (RIResBlock 32→64): {x3.shape}')
    print(f'pool4+block4 (RIResBlock 64→128): {x4.shape}')

    stage_plots = (
        (x1, 'RICNN Block1 输出 (16通道)', 'ricnn_block1.png'),
        (x2, 'RICNN Block2 输出 (32通道)', 'ricnn_block2.png'),
        (x3, 'RICNN Block3 输出 (64通道)', 'ricnn_block3.png'),
        (x4, 'RICNN Block4 输出 (128通道)', 'ricnn_block4.png'),
    )
    for features, caption, file_name in stage_plots:
        visualize_tensor(features, caption, file_name)
|
||
|
||
|
||
def visualize_position_encoding():
    """Exercise ``EncodePosition`` on random keypoints and plot their distances.

    Random keypoints of shape ``(2, 150, 4)`` and features of shape
    ``(2, 128, 150)`` are passed through ``EncodePosition(feature_size=128)``
    and the resulting shapes are printed.  The figure, saved as
    ``ricnn_position_encoding.png`` in ``OUTPUT_DIR``, shows the pairwise
    distance matrix of the first sample and a 16-bin histogram of the
    distances from keypoint 0.

    NOTE(review): the histogram here is computed locally with
    ``torch.histc``; whether ``EncodePosition`` uses the same binning is not
    visible from this file — confirm against BEVNet.
    """
    print('\n--- EncodePosition 位置编码可视化 ---')

    ep = EncodePosition(feature_size=128)
    ep.eval()

    # Simulated keypoints, last axis laid out as [x, y, z, intensity].
    kpts = torch.randn(2, 150, 4)
    kpts[..., :2] *= 30   # spread x, y by a factor of 30
    kpts[..., 2] = 0      # z fixed to 0
    kpts[..., 3] = 1      # intensity fixed to 1

    # Simulated per-keypoint features.
    fea = torch.randn(2, 128, 150)

    with torch.no_grad():
        fea_encoded = ep(kpts, fea)

    print(f'关键点输入: {kpts.shape}')
    print(f'原始特征: {fea.shape}')
    print(f'位置编码后特征: {fea_encoded.shape}')

    # Pairwise L2 distances of the first sample via broadcasting:
    # (150, 1, 4) - (1, 150, 4) -> (150, 150, 4) -> norm over last axis.
    sample = kpts[0]
    offsets = sample.unsqueeze(1) - sample.unsqueeze(0)
    distance = offsets.norm(p=2, dim=2)

    fig, (ax_matrix, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))

    matrix_img = ax_matrix.imshow(distance.numpy(), cmap='plasma')
    ax_matrix.set_title('关键点间距离矩阵 (150x150)')
    ax_matrix.set_xlabel('Keypoint j')
    ax_matrix.set_ylabel('Keypoint i')
    plt.colorbar(matrix_img, ax=ax_matrix)

    # Example histogram: distances from the first keypoint to all others.
    hist = torch.histc(distance[0], bins=16, min=1, max=80)
    ax_hist.bar(range(16), hist.numpy(), color='steelblue')
    ax_hist.set_title('距离直方图 (16 bins, 1-80m)\n用于位置编码')
    ax_hist.set_xlabel('Distance Bin')
    ax_hist.set_ylabel('Count')

    plt.suptitle('EncodePosition 位置编码模块', fontsize=14, fontweight='bold')
    plt.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'ricnn_position_encoding.png')
    plt.savefig(path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f' [保存] {path}')
|
||
|
||
|
||
def analyze_parameters():
    """Print the parameter count of a default ``RICNN`` and of each direct child."""

    def count_params(module):
        # Number of scalar entries over all parameters of ``module``.
        return sum(p.numel() for p in module.parameters())

    print('\n--- 参数量分析 ---')
    model = RICNN()

    total = count_params(model)
    print(f'总参数量: {total:,} ({total / 1e6:.2f}M)')

    for name, module in model.named_children():
        params = count_params(module)
        print(f' {name:20s}: {params:>10,} params ({params / 1e3:.1f}K)')
|
||
|
||
|
||
def main():
    """Run every analysis and visualisation step in order, then print a summary."""
    rule = '=' * 60
    print(rule)
    print('RICNN (旋转不变CNN) 网络结构与特征可视化')
    print(rule)

    steps = (
        analyze_parameters,            # parameter statistics
        visualize_ri_conv_kernel,      # 1. kernel distance groups
        visualize_ri_pooling,          # 2. pooling footprint
        visualize_ricnn_intermediate,  # 3. intermediate feature maps
        test_rotation_invariance,      # 4. rotation-invariance check
        visualize_position_encoding,   # 5. position encoding
    )
    for step in steps:
        step()

    print('\n' + rule)
    print('网络结构总结:')
    print(rule)
    # NOTE(review): the interior padding of this banner could not be recovered
    # from the source as received (runs of spaces were collapsed); the text is
    # reproduced as-is — re-align the box borders if the layout matters.
    print("""
RICNN (Rotation-Invariant CNN):
┌──────────────────────────────────────────────────────┐
│ 输入: BEV图像 (B, 3, 320, 320) │
│ ↓ │
│ block1: RIConvBlock(3→16) → (B, 16, 320, 320) │
│ ↓ RIMaxpool2d(2) │
│ block2: RIResBlock(16→32) → (B, 32, 160, 160) │
│ ↓ RIMaxpool2d(5, s=4) │
│ block3: RIResBlock(32→64) → (B, 64, 40, 40) │
│ ↓ RIMaxpool2d(5, s=4) │
│ block4: RIResBlock(64→128) → (B, 128, 10, 10) │
│ ↓ │
│ 多尺度特征聚合 (1x1conv + 上采样 + concat) │
│ → (B, 128, 320, 320) │
│ ↓ Conv1x1(128→129) │
│ 输出: score(B,1,320,320) + desc(B,128,320,320) │
└──────────────────────────────────────────────────────┘

旋转不变性的实现:
- RIConv2d: 根据kernel位置到中心的欧氏距离分组
同距离的位置共享权重 → 旋转后权重不变
- RIMaxpool2d: 只在圆形邻域内取max(忽略角点)
- RIAvgpool2d: 只在圆形邻域内取mean

EncodePosition (位置编码):
- 输入: 150个关键点的3D坐标
- 计算150×150距离矩阵 → 直方图(16 bins) → MLP
- 残差加到特征上,增强空间感知能力
""")

    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
|
||
|
||
|
||
# Run the full demo only when this file is executed as a script.
if __name__ == '__main__':
    main()
|