Files
fusion_LCD/network_learning/02_ricnn_demo.py
2026-05-09 17:03:40 +08:00

426 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
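RICNN rotation-invariant CNN: network-structure visualization demo
==================================================================
RICNN is the feature-extraction network of the point-cloud BEV branch; its
core innovation is rotation invariance.
Unlike a standard CNN, RICNN groups the positions of a convolution kernel by
their Euclidean distance to the kernel centre, so that the features stay
consistent after the input is rotated.
Input:  BEV image (B, 3, 320, 320)
Output: score_map (B, 1, 320, 320) + descriptor_map (B, 128, 320, 320)
Key components:
    RIConv2d:    rotation-invariant convolution (weights shared per distance group)
    RIMaxpool2d: rotation-invariant max pooling (max over a circular neighbourhood only)
    RIAvgpool2d: rotation-invariant average pooling (average over a circular neighbourhood only)
    RIResBlock:  rotation-invariant residual block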
"""
import os
import sys

import matplotlib

# Select the non-interactive Agg backend BEFORE pyplot is imported, so the
# choice never depends on Matplotlib's ability to switch backends afterwards
# and the script can render straight to files.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import torch

# Put the parent of this script's directory on the import path; the BEVNet
# module imported below is presumably located there (verify against the repo).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from BEVNet import RICNN, RIConv2d, RIMaxpool2d, RIAvgpool2d, EncodePosition

# Every figure produced by this script is written into this directory.
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), 'output')
os.makedirs(OUTPUT_DIR, exist_ok=True)
def visualize_tensor(tensor, title, save_name, cmap='viridis', n_channels=8):
    """Save a grid of per-channel heat maps of a feature tensor.

    Args:
        tensor: 3-D tensor indexed as (channel, row, column), or a 4-D batch
            of such tensors, in which case only the first sample is drawn.
        title: Super-title of the figure.
        save_name: File name of the image; it is joined onto OUTPUT_DIR.
        cmap: Matplotlib colormap name forwarded to ``imshow``.
        n_channels: Maximum number of leading channels to draw.

    The grid is four columns wide and at least two rows tall (the layout the
    function always used); extra rows are added when more than eight channels
    are requested, instead of indexing past a fixed 2x4 grid.
    """
    if tensor.dim() == 4:
        tensor = tensor[0]
    # Unpacking also enforces that the tensor is exactly 3-D at this point.
    num_channels, _, _ = tensor.shape
    n_show = max(0, min(n_channels, num_channels))

    n_cols = 4
    n_rows = max(2, (n_show + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
    )
    fig.suptitle(title, fontsize=14, fontweight='bold')

    for i in range(n_show):
        ax = axes[i // n_cols, i % n_cols]
        im = ax.imshow(tensor[i].detach().cpu().numpy(), cmap=cmap)
        ax.set_title(f'Channel {i}')
        ax.axis('off')
        fig.colorbar(im, ax=ax, fraction=0.046)

    # Hide the grid cells that received no channel.
    for ax in axes.flat[n_show:]:
        ax.axis('off')

    fig.tight_layout()
    path = os.path.join(OUTPUT_DIR, save_name)
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f' [保存] {path}')
def visualize_score_map(score_map, title, save_name):
    """Save a two-panel figure: the score map as a heat map plus its histogram.

    Args:
        score_map: Tensor of scores. A 4-D input is reduced to ``[0, 0]`` and
            a 3-D input to ``[0]``; any other rank is used as given.
        title: Super-title of the figure.
        save_name: File name of the image; it is joined onto OUTPUT_DIR.
    """
    rank = score_map.dim()
    if rank == 4:
        score_map = score_map[0, 0]
    elif rank == 3:
        score_map = score_map[0]

    # Convert once; both panels read the same NumPy array.
    scores = score_map.detach().cpu().numpy()

    fig, (ax_map, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle(title, fontsize=14, fontweight='bold')

    heat = ax_map.imshow(scores, cmap='hot')
    ax_map.set_title('Score Map')
    ax_map.axis('off')
    fig.colorbar(heat, ax=ax_map)

    ax_hist.hist(scores.flatten(), bins=50, color='steelblue', edgecolor='white')
    ax_hist.set_title('Score 分布')
    ax_hist.set_xlabel('Score')

    fig.tight_layout()
    out_path = os.path.join(OUTPUT_DIR, save_name)
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f' [保存] {out_path}')
def visualize_ri_conv_kernel():
    """Plot how kernel positions are grouped by distance for 3x3, 5x5 and 7x7 kernels.

    For each kernel size every cell is labelled with its rounded distance to
    the kernel centre; cells whose rounded distance exceeds ``0.5 * (kz - 1)``
    are labelled -1. The figure is saved as ``ricnn_kernel_groups.png`` in
    OUTPUT_DIR.
    """
    print('\n--- RIConv2d 卷积核分组可视化 ---')
    kernel_sizes = (3, 5, 7)
    fig, axes = plt.subplots(1, len(kernel_sizes), figsize=(16, 5))

    for ax, kz in zip(axes, kernel_sizes):
        # Build the distance mask: (row, col) of every flattened kernel index.
        flat_idx = torch.arange(kz ** 2).view(-1, 1)
        rows = torch.div(flat_idx, kz, rounding_mode='floor')
        cols = torch.fmod(flat_idx, kz)
        offsets = torch.cat([rows, cols], dim=1) - 0.5 * (kz - 1)
        radius = offsets.norm(dim=1) + 0.5 * (kz % 2 - 1)
        groups = torch.round(radius.view(kz, kz)).long()
        groups[groups > 0.5 * (kz - 1)] = -1

        ax.imshow(groups.numpy(), cmap='tab10')
        ax.set_title(f'Kernel {kz}x{kz}\nDistance Groups: {groups.max().item() + 1}')

        # Write the group index into every cell; excluded cells (-1) are red.
        for i, row_values in enumerate(groups.tolist()):
            for j, val in enumerate(row_values):
                ax.text(j, i, str(val), ha='center', va='center', fontsize=8,
                        color='white' if val >= 0 else 'red')
        ax.axis('off')

    fig.suptitle('RIConv2d: 按到中心距离分组的卷积核权重', fontsize=14, fontweight='bold')
    fig.tight_layout()
    out_path = os.path.join(OUTPUT_DIR, 'ricnn_kernel_groups.png')
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f' [保存] {out_path}')
def visualize_ri_pooling():
    """Plot the pooling footprint of a 5x5 rotation-invariant pool vs. a square pool.

    The 2x2 figure contains:
      * top-left: the full square footprint of a standard 5x5 max pool;
      * top-right: the footprint kept by the rounded-distance mask computed
        below (cells farther than ``0.5 * (kz - 1)`` are excluded);
      * bottom-left: a short textual explanation;
      * bottom-right: a bar chart of fixed, hard-coded illustrative scores
        (they are literals in this function, not measurements).

    The figure is saved as ``ricnn_pooling_visualization.png`` in OUTPUT_DIR.
    """
    print('\n--- 旋转不变池化区域可视化 ---')
    fig, axes = plt.subplots(2, 2, figsize=(10, 10))

    # Cell markers. The annotation loops previously drew empty strings (and the
    # conditional picked between two identical empty strings), so nothing was
    # rendered; a check mark for included cells and a cross for excluded ones
    # matches the white/red colour scheme used below.
    inside_mark, outside_mark = '✓', '✗'

    # Effective region of RIMaxpool2d (kernel_size=5)
    kz = 5
    coords = torch.arange(kz ** 2).view(-1, 1)
    row = torch.div(coords, kz, rounding_mode='floor')
    col = torch.fmod(coords, kz)
    coords = torch.cat([row, col], dim=1)
    dis = (coords - 0.5 * (kz - 1)).norm(dim=1) + 0.5 * (kz % 2 - 1)
    dis = dis.view(kz, kz)
    dis = torch.round(dis)
    dis[dis > 0.5 * (kz - 1)] = -1
    mask_ri = (dis > -1).numpy().astype(float)

    # Effective region of a standard MaxPool2d (the whole square)
    mask_std = np.ones((kz, kz))

    ax = axes[0, 0]
    ax.imshow(mask_std, cmap='Blues')
    ax.set_title(f'标准 MaxPool {kz}x{kz}\n有效区域: {mask_std.sum():.0f} 个像素', fontsize=12)
    for i in range(kz):
        for j in range(kz):
            ax.text(j, i, inside_mark, ha='center', va='center', fontsize=10)
    ax.axis('off')

    ax = axes[0, 1]
    ax.imshow(mask_ri, cmap='Oranges')
    ax.set_title(f'RI MaxPool {kz}x{kz}\n有效区域: {mask_ri.sum():.0f} 个像素 (圆形)', fontsize=12)
    for i in range(kz):
        for j in range(kz):
            included = bool(mask_ri[i, j])
            ax.text(j, i, inside_mark if included else outside_mark,
                    ha='center', va='center', fontsize=10,
                    color='white' if included else 'red')
    ax.axis('off')

    # Textual panel: how rotation affects standard-CNN vs. RICNN features
    ax = axes[1, 0]
    ax.set_title('旋转不变性原理', fontsize=12)
    ax.text(0.5, 0.7, '标准CNN:', transform=ax.transAxes, fontsize=11, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightblue'))
    ax.text(0.5, 0.5, '旋转图像 → 特征也旋转 → 不匹配', transform=ax.transAxes, fontsize=10, ha='center')
    ax.text(0.5, 0.3, 'RICNN:', transform=ax.transAxes, fontsize=11, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightgreen'))
    ax.text(0.5, 0.1, '旋转图像 → 特征不变 → 可以匹配', transform=ax.transAxes, fontsize=10, ha='center')
    ax.axis('off')

    # Bar chart of hard-coded illustrative scores
    ax = axes[1, 1]
    ax.set_title('RI vs 标准 CNN 对比', fontsize=12)
    categories = ['旋转鲁棒性', '计算效率', '平移不变性', '尺度不变性']
    ri_scores = [0.9, 0.7, 0.8, 0.5]
    std_scores = [0.3, 1.0, 0.8, 0.5]
    x = np.arange(len(categories))
    width = 0.35
    ax.bar(x - width / 2, ri_scores, width, label='RICNN', color='orange', alpha=0.8)
    ax.bar(x + width / 2, std_scores, width, label='标准CNN', color='blue', alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(categories, fontsize=9)
    ax.set_ylim(0, 1.2)
    ax.legend()
    ax.set_ylabel('能力评分')

    fig.suptitle('RICNN 旋转不变池化详解', fontsize=14, fontweight='bold')
    fig.tight_layout()
    path = os.path.join(OUTPUT_DIR, 'ricnn_pooling_visualization.png')
    fig.savefig(path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f' [保存] {path}')
def test_rotation_invariance():
    """Compare RICNN descriptors of a BEV image with those of its rotated copies.

    A freshly constructed RICNN (no checkpoint is loaded here) is run on a
    synthetic image and on its 90-degree and 180-degree rotations. The
    descriptor maps of the rotated inputs are rotated back and compared to the
    original descriptors with a cosine similarity over all elements. A
    comparison figure for the 90-degree case is saved as
    ``ricnn_rotation_invariance.png`` in OUTPUT_DIR.

    Returns:
        Tuple ``(cos_sim_90, cos_sim_180)`` of Python floats.
    """
    print('\n--- 旋转不变性测试 ---')
    model = RICNN()
    model.eval()

    # Synthetic test image: one filled rectangle per channel.
    bev = torch.zeros(1, 3, 320, 320)
    rectangles = (
        (0, slice(100, 120), slice(150, 170), 1.0),
        (1, slice(140, 160), slice(100, 140), 0.8),
        (2, slice(150, 170), slice(160, 200), 0.6),
    )
    for channel, rows, cols, value in rectangles:
        bev[0, channel, rows, cols] = value

    spatial_dims = [2, 3]
    with torch.no_grad():
        score_orig, desc_orig = model(bev)

        # Rotate by 90 degrees, run the model, rotate the descriptors back.
        bev_rot90 = torch.rot90(bev, k=1, dims=spatial_dims)
        score_rot90, desc_rot90 = model(bev_rot90)
        desc_rot90_back = torch.rot90(desc_rot90, k=-1, dims=spatial_dims)

        # Same for 180 degrees; its score map is not used afterwards.
        bev_rot180 = torch.rot90(bev, k=2, dims=spatial_dims)
        _, desc_rot180 = model(bev_rot180)
        desc_rot180_back = torch.rot90(desc_rot180, k=-2, dims=spatial_dims)

    # Similarity between the original and the rotated-back descriptors.
    cosine = torch.nn.functional.cosine_similarity
    reference = desc_orig.flatten()
    cos_sim_90 = cosine(reference, desc_rot90_back.flatten(), dim=0)
    cos_sim_180 = cosine(reference, desc_rot180_back.flatten(), dim=0)
    print(f'原始 vs 旋转90°后特征 余弦相似度: {cos_sim_90.item():.4f}')
    print(f'原始 vs 旋转180°后特征 余弦相似度: {cos_sim_180.item():.4f}')
    print('(越接近1.0说明旋转不变性越好)')

    # Figure: (axes, image, title, colormap) for every populated panel.
    fig, axes = plt.subplots(2, 4, figsize=(18, 8))
    panels = (
        (axes[0, 0], bev[0].permute(1, 2, 0).numpy(), '原始BEV', None),
        (axes[0, 1], bev_rot90[0].permute(1, 2, 0).numpy(), '旋转90°', None),
        (axes[0, 2], score_orig[0, 0].numpy(), '原始Score', 'hot'),
        (axes[0, 3], score_rot90[0, 0].numpy(), '旋转90° Score', 'hot'),
        (axes[1, 0], desc_orig[0, 0].numpy(), '原始Desc ch0', 'viridis'),
        (axes[1, 1], desc_rot90_back[0, 0].numpy(),
         f'旋回后Desc ch0\n相似度:{cos_sim_90.item():.3f}', 'viridis'),
        (axes[1, 2], (desc_orig[0, 0] - desc_rot90_back[0, 0]).abs().numpy(),
         '差异热图 ch0', 'Reds'),
    )
    for ax, image, label, cmap in panels:
        ax.imshow(image, cmap=cmap)
        ax.set_title(label)

    # Switch off every panel that received no content (here: the last one).
    for ax in axes.flatten():
        if not (ax.collections or ax.images):
            ax.axis('off')

    fig.suptitle('RICNN 旋转不变性测试', fontsize=14, fontweight='bold')
    fig.tight_layout()
    out_path = os.path.join(OUTPUT_DIR, 'ricnn_rotation_invariance.png')
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f' [保存] {out_path}')
    return cos_sim_90.item(), cos_sim_180.item()
def visualize_ricnn_intermediate():
    """Run RICNN stage by stage on a synthetic image and plot each stage's output.

    The stages are invoked through the model's ``block1``..``block4``,
    ``pool2`` and ``pool4`` attributes (``pool4`` is applied before both
    block3 and block4). The shape of every stage output is printed and each
    output is passed to ``visualize_tensor``.
    """
    print('\n--- RICNN 中间特征可视化 ---')
    model = RICNN()
    model.eval()

    # Structured input: a sine/cosine pattern, concentric rings and a disc.
    axis = torch.linspace(-1, 1, 320)
    grid_y, grid_x = torch.meshgrid(axis, axis, indexing='ij')
    radius = torch.sqrt(grid_x ** 2 + grid_y ** 2)
    channels = [
        (torch.sin(grid_x * 10) * torch.cos(grid_y * 10) + 1) / 2,
        (torch.cos(radius * 5) + 1) / 2,
        (radius < 0.5).float(),
    ]
    bev = torch.zeros(1, 3, 320, 320)
    for index, plane in enumerate(channels):
        bev[0, index] = plane

    # Layer-by-layer forward pass.
    with torch.no_grad():
        feat1 = model.block1(bev)
        feat2 = model.block2(model.pool2(feat1))
        feat3 = model.block3(model.pool4(feat2))
        feat4 = model.block4(model.pool4(feat3))

    # (console label, tensor, figure title, file name) per stage.
    stages = (
        ('block1 (RIConvBlock 3→16)', feat1, 'RICNN Block1 输出 (16通道)', 'ricnn_block1.png'),
        ('pool2+block2 (RIResBlock 16→32)', feat2, 'RICNN Block2 输出 (32通道)', 'ricnn_block2.png'),
        ('pool4+block3 (RIResBlock 32→64)', feat3, 'RICNN Block3 输出 (64通道)', 'ricnn_block3.png'),
        ('pool4+block4 (RIResBlock 64→128)', feat4, 'RICNN Block4 输出 (128通道)', 'ricnn_block4.png'),
    )

    print(f'输入BEV: {bev.shape}')
    for label, feat, _, _ in stages:
        print(f'{label}: {feat.shape}')
    for _, feat, figure_title, file_name in stages:
        visualize_tensor(feat, figure_title, file_name)
def visualize_position_encoding():
    """Run EncodePosition on random keypoints and plot their pairwise distances.

    Random keypoints of shape (2, 150, 4) and random features of shape
    (2, 128, 150) are passed to ``EncodePosition(feature_size=128)`` and the
    input/output shapes are printed. Independently of the module, the pairwise
    distance matrix of the first sample's keypoints and a 16-bin histogram of
    its first row are drawn and saved as ``ricnn_position_encoding.png`` in
    OUTPUT_DIR.
    """
    print('\n--- EncodePosition 位置编码可视化 ---')
    encoder = EncodePosition(feature_size=128)
    encoder.eval()

    # Simulated keypoints (B, 150, 4); columns are [x, y, z, intensity].
    kpts = torch.randn(2, 150, 4)
    kpts[..., :2] *= 30  # scale x, y by 30
    kpts[..., 2] = 0     # z = 0 (BEV plane)
    kpts[..., 3] = 1     # intensity = 1

    # Simulated features (B, 128, 150).
    fea = torch.randn(2, 128, 150)
    with torch.no_grad():
        fea_encoded = encoder(kpts, fea)
    print(f'关键点输入: {kpts.shape}')
    print(f'原始特征: {fea.shape}')
    print(f'位置编码后特征: {fea_encoded.shape}')

    # Pairwise L2 distances between the keypoints of the first sample.
    points = kpts[0]
    distance = (points[:, None, :] - points[None, :, :]).norm(p=2, dim=2)  # (150, 150)

    fig, (ax_matrix, ax_hist) = plt.subplots(1, 2, figsize=(14, 5))
    heat = ax_matrix.imshow(distance.numpy(), cmap='plasma')
    ax_matrix.set_title('关键点间距离矩阵 (150x150)')
    ax_matrix.set_xlabel('Keypoint j')
    ax_matrix.set_ylabel('Keypoint i')
    fig.colorbar(heat, ax=ax_matrix)

    # Example histogram: distances from the first keypoint to all others.
    n_bins = 16
    counts = torch.histc(distance[0], bins=n_bins, min=1, max=80)
    ax_hist.bar(range(n_bins), counts.numpy(), color='steelblue')
    ax_hist.set_title('距离直方图 (16 bins, 1-80m)\n用于位置编码')
    ax_hist.set_xlabel('Distance Bin')
    ax_hist.set_ylabel('Count')

    fig.suptitle('EncodePosition 位置编码模块', fontsize=14, fontweight='bold')
    fig.tight_layout()
    out_path = os.path.join(OUTPUT_DIR, 'ricnn_position_encoding.png')
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f' [保存] {out_path}')
def analyze_parameters():
    """Print the parameter count of a new RICNN, in total and per direct child module."""
    print('\n--- 参数量分析 ---')
    model = RICNN()

    def count_params(module):
        # Number of scalar entries over all parameters of the module.
        return sum(param.numel() for param in module.parameters())

    total = count_params(model)
    print(f'总参数量: {total:,} ({total / 1e6:.2f}M)')
    for name, child in model.named_children():
        n_params = count_params(child)
        print(f' {name:20s}: {n_params:>10,} params ({n_params / 1e3:.1f}K)')
def main():
    """Run every analysis/visualization step in order, then print a summary."""
    banner = '=' * 60
    print(banner)
    print('RICNN (旋转不变CNN) 网络结构与特征可视化')
    print(banner)

    steps = (
        analyze_parameters,            # parameter statistics
        visualize_ri_conv_kernel,      # 1. kernel distance groups
        visualize_ri_pooling,          # 2. pooling footprint
        visualize_ricnn_intermediate,  # 3. intermediate features
        test_rotation_invariance,      # 4. rotation-invariance test
        visualize_position_encoding,   # 5. position encoding
    )
    for step in steps:
        step()

    print('\n' + banner)
    print('网络结构总结:')
    print(banner)
    # Printed verbatim; the text inside the literal is program output.
    summary = """
RICNN (Rotation-Invariant CNN):
┌──────────────────────────────────────────────────────┐
│ 输入: BEV图像 (B, 3, 320, 320) │
│ ↓ │
│ block1: RIConvBlock(3→16) → (B, 16, 320, 320) │
│ ↓ RIMaxpool2d(2) │
│ block2: RIResBlock(16→32) → (B, 32, 160, 160) │
│ ↓ RIMaxpool2d(5, s=4) │
│ block3: RIResBlock(32→64) → (B, 64, 40, 40) │
│ ↓ RIMaxpool2d(5, s=4) │
│ block4: RIResBlock(64→128) → (B, 128, 10, 10) │
│ ↓ │
│ 多尺度特征聚合 (1x1conv + 上采样 + concat) │
│ → (B, 128, 320, 320) │
│ ↓ Conv1x1(128→129) │
│ 输出: score(B,1,320,320) + desc(B,128,320,320) │
└──────────────────────────────────────────────────────┘
旋转不变性的实现:
- RIConv2d: 根据kernel位置到中心的欧氏距离分组
同距离的位置共享权重 → 旋转后权重不变
- RIMaxpool2d: 只在圆形邻域内取max忽略角点
- RIAvgpool2d: 只在圆形邻域内取mean
EncodePosition (位置编码):
- 输入: 150个关键点的3D坐标
- 计算150×150距离矩阵 → 直方图(16 bins) → MLP
- 残差加到特征上,增强空间感知能力
"""
    print(summary)
    print(f'\n所有可视化结果保存在: {OUTPUT_DIR}')
# Run the full demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()