Files
fusion_LCD/visualize_localization.py
2026-04-11 20:29:07 +08:00

611 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
可视化重定位效果:点云、轨迹、预测位置与实际位置对比
支持 KITTI 数据集
"""
import argparse
import os
import numpy as np
import torch
import yaml
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
from skimage.measure import ransac
from skimage.transform import EuclideanTransform
import net
import tools
from dataset import KittiTotalLoader, KittiDataset
def load_model(cfg, model_path, device):
    """Build a ``net.Fusion`` network and restore its trained weights.

    Args:
        cfg: Experiment configuration, passed unchanged to ``net.Fusion``.
        model_path: Checkpoint file; its ``'model'`` entry is the state dict.
        device: Device the network is moved to and the checkpoint mapped onto.

    Returns:
        The restored network on ``device``, switched to evaluation mode.
    """
    fusion = net.Fusion(cfg).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    fusion.load_state_dict(checkpoint['model'])
    fusion.eval()
    return fusion
def retrieve_loops_with_poses(vlads, feas, kpts, poses, num_cand=5, exclude_window=50):
    """Retrieve the nearest global descriptors of every frame as loop candidates.

    For each query frame ``i`` the Euclidean distance between ``vlads[i]`` and
    every frame's descriptor is computed. Frames with index in
    ``[i - exclude_window, i + exclude_window)`` (the query's temporal
    neighbourhood, including the query itself) are masked out, so consecutive
    scans are not reported as loops.

    Args:
        vlads: Global descriptors, one entry per frame along dim 0 (each entry
            is flattened before the distance computation).
        feas, kpts, poses: Unused; kept for interface compatibility.
        num_cand: Maximum number of candidates per query. Fewer are returned
            when not enough frames lie outside the exclusion window.
        exclude_window: Half-width, in frames, of the temporal exclusion window.

    Returns:
        One dict per query with keys ``'query_idx'`` (int), ``'candidates'``
        (numpy array of frame indices, nearest first) and ``'distances'``
        (numpy array of the matching descriptor distances).
    """
    n = len(vlads)
    flat = vlads.reshape(n, -1)
    loops = []
    for i in tqdm(range(n), desc="Retrieving loops"):
        lo = max(0, i - exclude_window)
        hi = min(n, i + exclude_window)
        excluded = max(0, hi - lo)
        # Never ask topk for more entries than there are admissible frames.
        k = min(num_cand, n - excluded)
        if k <= 0:
            loops.append({
                'query_idx': i,
                'candidates': np.empty(0, dtype=np.int64),
                'distances': np.empty(0, dtype=np.float32),
            })
            continue
        dist = torch.cdist(flat[i:i + 1], flat).view(-1)
        # Mask the temporal neighbourhood instead of rebuilding an index set per query.
        dist[lo:hi] = float('inf')
        dis, idx_cand = torch.topk(dist, k, largest=False)
        loops.append({
            'query_idx': i,
            'candidates': idx_cand.cpu().numpy(),
            'distances': dis.cpu().numpy(),
        })
    return loops
def estimate_relative_pose_ransac(fea1, kpts1, fea2, kpts2, min_matches=20, min_samples=15,
                                  max_trials=3, residual_threshold=1.7, min_inliers=30):
    """Estimate a 2D rigid transform between two frames with RANSAC.

    Local features are matched with ``tools.nn_match`` (cosine metric). The
    x/y coordinates of the matched keypoints are then fitted with an
    ``EuclideanTransform`` inside ``skimage.measure.ransac``.

    Args:
        fea1, fea2: Local features of the two frames.
        kpts1, kpts2: Keypoint coordinates indexed like the features; only the
            first two columns are used.
        min_matches: Minimum number of feature matches needed to try RANSAC.
        min_samples, max_trials, residual_threshold: Forwarded to ``ransac``.
        min_inliers: The estimate is accepted only with strictly more inliers.

    Returns:
        ``(T, num_inliers)`` where ``T`` is the 3x3 homogeneous matrix
        (``EuclideanTransform.params``) mapping frame-1 points onto frame-2
        points. Returns ``(None, 0)`` when there are too few matches or
        inliers, or when RANSAC cannot fit a model.
    """
    idx1, idx2, _ = tools.nn_match(fea1, fea2, 'cosine')
    if len(idx1) < min_matches:
        return None, 0
    p1 = kpts1[idx1].cpu().detach().numpy()[:, 0:2]
    p2 = kpts2[idx2].cpu().detach().numpy()[:, 0:2]
    try:
        model, inliers = ransac(
            (p1, p2),
            model_class=EuclideanTransform,
            min_samples=min_samples,
            max_trials=max_trials,
            residual_threshold=residual_threshold,
        )
    except (ValueError, np.linalg.LinAlgError):
        # Degenerate correspondences: treat as "no estimate" rather than aborting the run.
        return None, 0
    # ransac may return (None, None) when no model could be fitted.
    if model is None or inliers is None:
        return None, 0
    num_inlier = int(np.sum(inliers))
    if num_inlier > min_inliers:
        return model.params, num_inlier
    return None, 0
def compute_corrected_trajectory(poses_gt, loops, feas, kpts, device):
    """Build a loop-closure "corrected" trajectory for visualisation.

    NOTE: this is not an independent trajectory estimate. The output starts
    from the ground-truth poses, and each frame's translation is shifted by a
    small amount. The shift comes from the disagreement between the RANSAC
    relative pose of the frame's loop edges and the ground-truth relative pose.

    Steps:
      1. For every query, keep its best candidate as a loop edge when the two
         frame indices differ by more than 10.
      2. Estimate a 2D relative transform per edge with
         ``estimate_relative_pose_ransac``; edges without an estimate are dropped.
      3. For each kept edge whose x/y translational disagreement is below 2.0,
         add +/- 0.1 * (half the disagreement) to the two endpoints. This is
         repeated 3 times. The inputs do not change between passes, so the
         passes simply triple the accumulated correction.
      4. Add half of the accumulated translation correction to every frame
         except frame 0, and record the resulting shift as that frame's error.

    NOTE(review): the RANSAC transform is fitted on keypoint coordinates while
    the ground-truth relative pose is in pose units; the two are compared
    directly, which presumes matching units - TODO confirm.

    Args:
        poses_gt: Ground-truth poses, (N, 4, 4) numpy array or tensor.
        loops: Output of ``retrieve_loops_with_poses``.
        feas: Per-frame local features, indexed by frame along dim 0.
        kpts: Per-frame keypoints, indexed by frame along dim 0.
        device: Device the per-frame features/keypoints are moved to for matching.

    Returns:
        ``(poses_corrected, pose_errors)``:
        - ``poses_corrected``: an (N, 4, 4) numpy array.
        - ``pose_errors``: an (N,) array holding the translation distance
          between each corrected pose and its ground truth. It is all zeros
          when no loop edge survives.
    """
    n = len(poses_gt)
    poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else np.asarray(poses_gt)
    # Start from the ground truth.
    poses_corrected = poses_gt_np.copy()
    pose_errors = np.zeros(n)

    # 1. Loop edges: best candidate of each query, excluding temporally close frames.
    loop_edges = []
    for loop in loops:
        query_idx = loop['query_idx']
        candidates = loop['candidates']
        if len(candidates) > 0:
            best_match = int(candidates[0])
            if abs(best_match - query_idx) > 10:
                loop_edges.append((query_idx, best_match))
    print(f" Found {len(loop_edges)} loop edges")

    # 2. Relative pose per edge (on the requested device, not a hard-coded GPU).
    valid_edges = []
    for idx1, idx2 in loop_edges:
        T_rel, num_inliers = estimate_relative_pose_ransac(
            feas[idx1].to(device), kpts[idx1].to(device),
            feas[idx2].to(device), kpts[idx2].to(device))
        if T_rel is not None and num_inliers > 0:
            valid_edges.append({
                'idx1': idx1,
                'idx2': idx2,
                'T_rel': T_rel,
                'inliers': num_inliers
            })
    print(f" Valid loop edges with pose estimation: {len(valid_edges)}")
    if not valid_edges:
        return poses_corrected, pose_errors

    # 3. Accumulate corrections. Layout is [dx, dy, dz, roll, pitch, yaw]; only
    #    the translational part (first three columns) is ever filled or used.
    corrections = np.zeros((n, 6))
    for iteration in range(3):
        total_correction = 0.0
        for edge in valid_edges:
            idx1, idx2, T_rel = edge['idx1'], edge['idx2'], edge['T_rel']
            # Lift the 3x3 homogeneous 2D transform into a 4x4 matrix.
            T_pred = np.eye(4)
            T_pred[:2, :2] = T_rel[:2, :2]
            T_pred[:2, 3] = T_rel[:2, 2]
            T_gt_rel = np.linalg.inv(poses_gt_np[idx1]) @ poses_gt_np[idx2]
            T_error = np.linalg.inv(T_pred) @ T_gt_rel
            error_trans = np.linalg.norm(T_error[:2, 3])
            if error_trans < 2.0:  # only small disagreements are distributed
                correction = T_error[:3, 3] / 2
                # Index the translational columns explicitly: a 3-vector cannot
                # be added in place to a 6-column row.
                corrections[idx1, :3] += correction * 0.1
                corrections[idx2, :3] -= correction * 0.1
                total_correction += error_trans
        print(f" Iteration {iteration+1}: avg correction = {total_correction / len(valid_edges):.3f}m")

    # 4. Apply half of the translational correction; frame 0 stays at ground truth.
    poses_corrected[1:, :3, 3] += 0.5 * corrections[1:, :3]
    pose_errors[1:] = np.linalg.norm(poses_corrected[1:, :3, 3] - poses_gt_np[1:, :3, 3], axis=1)
    return poses_corrected, pose_errors
def load_pointcloud(dataset, idx):
    """Read the raw scan of frame ``idx`` from ``dataset.scans``.

    The file is read as a flat float32 buffer and reshaped into rows of four
    values per point.

    Args:
        dataset: Object exposing a ``scans`` sequence of file paths.
        idx: Frame index into ``dataset.scans``.

    Returns:
        A float32 array of shape (num_points, 4).
    """
    raw = np.fromfile(dataset.scans[idx], dtype=np.float32)
    return raw.reshape(-1, 4)
def transform_points(points, pose):
    """Apply a homogeneous pose to the x/y/z columns of ``points``.

    Args:
        points: Array whose first three columns are x, y, z; further columns
            are ignored.
        pose: Homogeneous transform, applied as ``pose @ [x, y, z, 1]^T``.

    Returns:
        One transformed homogeneous row per point (4 columns for a 4x4 pose).
    """
    xyz1 = np.column_stack((points[:, :3], np.ones(points.shape[0])))
    return np.transpose(pose @ np.transpose(xyz1))
def visualize_trajectory_comparison(poses_gt, poses_pred, pose_errors, loops, seq_name, save_path=None,
                                    max_loops=100):
    """Plot ground-truth vs. predicted trajectories and the per-frame error.

    The 2x2 figure contains:
      1. XY trajectories with start/end markers, plus green segments linking
         each query to its best loop candidate (first ``max_loops`` entries
         of ``loops`` only).
      2. The error curve with its mean and the median of its strictly
         positive values. The median line is omitted when no error is positive.
      3. XZ trajectories.
      4. A 3D view of both trajectories.

    Args:
        poses_gt: Ground-truth poses (N, 4, 4), numpy array or tensor.
        poses_pred: Predicted poses (N, 4, 4), numpy array or tensor.
        pose_errors: Per-frame position error, length N.
        loops: Output of ``retrieve_loops_with_poses``.
        seq_name: Sequence name used in subplot titles.
        save_path: If given, the figure is saved there at 150 dpi.
        max_loops: Number of leading ``loops`` entries drawn as loop links.

    Returns:
        The matplotlib figure.
    """
    poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else np.asarray(poses_gt)
    poses_pred_np = poses_pred.cpu().numpy() if isinstance(poses_pred, torch.Tensor) else np.asarray(poses_pred)
    pose_errors = np.asarray(pose_errors, dtype=float)

    traj_gt_x, traj_gt_y, traj_gt_z = poses_gt_np[:, :3, 3].T
    traj_pred_x, traj_pred_y, traj_pred_z = poses_pred_np[:, :3, 3].T

    fig = plt.figure(figsize=(18, 12))

    # 1. XY-plane trajectories with loop links.
    ax1 = fig.add_subplot(2, 2, 1)
    ax1.plot(traj_gt_x, traj_gt_y, 'b-', linewidth=1, alpha=0.6, label='Ground Truth')
    ax1.plot(traj_pred_x, traj_pred_y, 'r--', linewidth=1, alpha=0.6, label='Predicted')
    ax1.plot(traj_gt_x[0], traj_gt_y[0], 'go', markersize=12, label='Start')
    ax1.plot(traj_gt_x[-1], traj_gt_y[-1], 'r^', markersize=12, label='End')
    for loop in loops[:max_loops]:
        candidates = loop['candidates']
        if len(candidates) == 0:
            continue
        query_idx = loop['query_idx']
        best_match = int(candidates[0])
        ax1.plot([traj_gt_x[query_idx], traj_gt_x[best_match]],
                 [traj_gt_y[query_idx], traj_gt_y[best_match]],
                 'g-', alpha=0.2, linewidth=0.5)
    ax1.set_xlabel('X (m)')
    ax1.set_ylabel('Y (m)')
    ax1.set_title(f'{seq_name} - 2D Trajectory Comparison (XY Plane)')
    ax1.legend()
    ax1.axis('equal')
    ax1.grid(True, linestyle='--', alpha=0.4)

    # 2. Error over time.
    ax2 = fig.add_subplot(2, 2, 2)
    ax2.plot(pose_errors, 'b-', linewidth=0.5, alpha=0.7)
    mean_error = np.mean(pose_errors)
    ax2.axhline(mean_error, color='r', linestyle='--', label=f'Mean: {mean_error:.3f}m')
    positive_errors = pose_errors[pose_errors > 0]
    # np.median of an empty array yields NaN (and a RuntimeWarning); this is the
    # normal case when no loop edge produced a correction, so skip the line then.
    if positive_errors.size:
        median_error = np.median(positive_errors)
        ax2.axhline(median_error, color='g', linestyle='--',
                    label=f'Median: {median_error:.3f}m')
    ax2.set_xlabel('Frame Index')
    ax2.set_ylabel('Position Error (m)')
    ax2.set_title(f'{seq_name} - Localization Error over Time')
    ax2.legend()
    ax2.grid(True, linestyle='--', alpha=0.4)

    # 3. XZ-plane trajectories.
    ax3 = fig.add_subplot(2, 2, 3)
    ax3.plot(traj_gt_x, traj_gt_z, 'b-', linewidth=1, alpha=0.6, label='Ground Truth')
    ax3.plot(traj_pred_x, traj_pred_z, 'r--', linewidth=1, alpha=0.6, label='Predicted')
    ax3.set_xlabel('X (m)')
    ax3.set_ylabel('Z (m)')
    ax3.set_title(f'{seq_name} - Trajectory XZ View')
    ax3.legend()
    ax3.axis('equal')
    ax3.grid(True, linestyle='--', alpha=0.4)

    # 4. 3D trajectories.
    ax4 = fig.add_subplot(2, 2, 4, projection='3d')
    ax4.plot(traj_gt_x, traj_gt_y, traj_gt_z, 'b-', linewidth=1, alpha=0.6, label='Ground Truth')
    ax4.plot(traj_pred_x, traj_pred_y, traj_pred_z, 'r--', linewidth=1, alpha=0.6, label='Predicted')
    ax4.scatter(traj_gt_x[0], traj_gt_y[0], traj_gt_z[0], c='green', s=100, marker='o', label='Start')
    ax4.scatter(traj_gt_x[-1], traj_gt_y[-1], traj_gt_z[-1], c='blue', s=100, marker='^', label='End')
    ax4.set_xlabel('X (m)')
    ax4.set_ylabel('Y (m)')
    ax4.set_zlabel('Z (m)')
    ax4.set_title(f'{seq_name} - 3D Trajectory Comparison')
    ax4.legend()

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Trajectory comparison saved to {save_path}")
    return fig
def visualize_error_distribution(poses_gt, poses_pred, seq_name, save_path=None):
    """Plot the histogram and empirical CDF of per-frame position errors.

    A frame's error is the Euclidean distance between the translation parts
    (``[:3, 3]``) of its predicted and ground-truth poses.

    Args:
        poses_gt: Ground-truth poses, numpy array or tensor.
        poses_pred: Predicted poses, numpy array or tensor.
        seq_name: Sequence name used in subplot titles.
        save_path: If given, the figure is saved there at 150 dpi.

    Returns:
        ``(fig, errors)``: the matplotlib figure and the per-frame error array.
    """
    def _as_numpy(p):
        return p.cpu().numpy() if isinstance(p, torch.Tensor) else p

    gt = _as_numpy(poses_gt)
    pred = _as_numpy(poses_pred)
    errors = np.linalg.norm(pred[:, :3, 3] - gt[:, :3, 3], axis=1)
    mean_err = np.mean(errors)
    median_err = np.median(errors)

    fig, (ax_hist, ax_cdf) = plt.subplots(1, 2, figsize=(14, 5))

    # Histogram with mean / median markers.
    ax_hist.hist(errors, bins=50, edgecolor='black', alpha=0.7)
    ax_hist.axvline(mean_err, color='r', linestyle='--', label=f'Mean: {mean_err:.3f}m')
    ax_hist.axvline(median_err, color='g', linestyle='--', label=f'Median: {median_err:.3f}m')
    ax_hist.set_xlabel('Position Error (m)')
    ax_hist.set_ylabel('Frequency')
    ax_hist.set_title(f'{seq_name} - Localization Error Distribution')
    ax_hist.legend()
    ax_hist.grid(True, linestyle='--', alpha=0.4)

    # Empirical CDF with reference lines at 50% and 95%.
    ordered = np.sort(errors)
    total = len(ordered)
    probability = np.arange(1, total + 1) / total
    ax_cdf.plot(ordered, probability, 'b-', linewidth=2)
    for level in (0.5, 0.95):
        ax_cdf.axhline(level, color='gray', linestyle='--', alpha=0.5)
    ax_cdf.set_xlabel('Position Error (m)')
    ax_cdf.set_ylabel('Cumulative Probability')
    ax_cdf.set_title(f'{seq_name} - Error Cumulative Distribution')
    ax_cdf.grid(True, linestyle='--', alpha=0.4)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Error distribution saved to {save_path}")
    return fig, errors
def visualize_pointcloud_with_pose(dataset, poses, idx_list, save_path=None):
    """Draw one 3D subplot per frame of ``idx_list`` with its world-frame scan.

    Each subplot shows the following:
      - Up to 5000 randomly sampled points of the frame's scan, transformed
        with the frame's pose and coloured by height (z).
      - A red triangle at the pose position.

    Subplots are laid out on two rows.

    Args:
        dataset: Object passed to ``load_pointcloud`` to read each scan.
        poses: Per-frame poses indexable by frame (numpy or tensor entries).
        idx_list: Frame indices to draw.
        save_path: If given, the figure is saved there at 150 dpi.

    Returns:
        The matplotlib figure.
    """
    fig = plt.figure(figsize=(16, 12))
    n_cols = (len(idx_list) + 1) // 2
    for slot, frame in enumerate(idx_list, start=1):
        ax = fig.add_subplot(2, n_cols, slot, projection='3d')
        scan = load_pointcloud(dataset, frame)
        frame_pose = poses[frame]
        if isinstance(frame_pose, torch.Tensor):
            frame_pose = frame_pose.cpu().numpy()
        world = transform_points(scan, frame_pose)
        if len(world) > 5000:
            world = world[np.random.choice(len(world), 5000, replace=False)]
        ax.scatter(world[:, 0], world[:, 1], world[:, 2],
                   c=world[:, 2], cmap='jet', s=0.5, alpha=0.5)
        ax.scatter(frame_pose[0, 3], frame_pose[1, 3], frame_pose[2, 3],
                   c='red', s=100, marker='^', label='Current Pose')
        ax.set_xlabel('X (m)')
        ax.set_ylabel('Y (m)')
        ax.set_zlabel('Z (m)')
        ax.set_title(f'Frame {frame}')
        ax.legend()
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Point cloud visualization saved to {save_path}")
    return fig
def visualize_bev_with_loops(dataset, poses, idx_list, save_path=None):
    """Bird's-eye view of the full trajectory plus the scans of selected frames.

    The x/y of every pose is drawn as a faint blue line. For each frame in
    ``idx_list``:
      - Its scan is transformed into the world frame with that frame's pose.
      - The scan is down-sampled to at most 3000 random points and scattered
        in one of five cycling colours.
      - A black-edged triangle marks the pose position.

    Despite the name, no loop links are drawn here.

    Args:
        dataset: Object passed to ``load_pointcloud`` to read each scan.
        poses: Per-frame poses (numpy array or tensor).
        idx_list: Frame indices whose scans are drawn.
        save_path: If given, the figure is saved there at 150 dpi.

    Returns:
        The matplotlib figure.
    """
    fig, ax = plt.subplots(figsize=(12, 12))
    all_poses = poses.cpu().numpy() if isinstance(poses, torch.Tensor) else poses
    ax.plot(all_poses[:, 0, 3], all_poses[:, 1, 3], 'b-', linewidth=1, alpha=0.3, label='Full Trajectory')

    palette = ['red', 'green', 'purple', 'orange', 'cyan']
    for k, frame in enumerate(idx_list):
        colour = palette[k % len(palette)]
        scan = load_pointcloud(dataset, frame)
        frame_pose = poses[frame]
        if isinstance(frame_pose, torch.Tensor):
            frame_pose = frame_pose.cpu().numpy()
        world = transform_points(scan, frame_pose)
        if len(world) > 3000:
            world = world[np.random.choice(len(world), 3000, replace=False)]
        ax.scatter(world[:, 0], world[:, 1], c=colour, s=0.5, alpha=0.5, label=f'Frame {frame}')
        ax.scatter(frame_pose[0, 3], frame_pose[1, 3], c=colour, s=100, marker='^', edgecolors='black')

    ax.set_xlabel('X (m)')
    ax.set_ylabel('Y (m)')
    ax.set_title('Bird Eye View - Point Clouds and Trajectory')
    ax.legend()
    ax.axis('equal')
    ax.grid(True, linestyle='--', alpha=0.4)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"BEV visualization saved to {save_path}")
    return fig
def _load_config():
    """Return the ``experiment`` section of ``config.yaml``.

    ``config.yaml`` in the working directory is tried first. If that file does
    not exist, ``project/FUSIONLCD/config.yaml`` (also relative to the working
    directory) is used. Other errors, e.g. a malformed YAML file, propagate.
    """
    primary = os.path.join(os.getcwd(), "config.yaml")
    fallback = os.path.join(os.getcwd(), "project/FUSIONLCD/config.yaml")
    try:
        with open(primary, "r") as ymlfile:
            cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
    except FileNotFoundError:
        with open(fallback, "r") as ymlfile:
            cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
    return cfg['experiment']


def _channels_first(tensor):
    """Permute an NHWC batch to NCHW, or return 0 when that is not possible.

    NOTE(review): a non-4D input is presumably the collated placeholder of a
    missing modality; the model then receives 0 - TODO confirm.
    """
    try:
        return tensor.permute(0, 3, 1, 2)
    except (RuntimeError, AttributeError, IndexError):
        return 0


def _extract_sequence_features(model, loader, seq, device):
    """Run ``model`` over every sample of sequence ``seq`` found in ``loader``.

    Batches may straddle two sequences, so samples are filtered individually
    rather than by the first sample of each batch. Outputs are assumed to be
    batch-first, which is how they are later indexed per frame.

    Returns:
        ``(vlads, feas, kpts, poses)`` CPU tensors concatenated along dim 0 in
        loader order, or ``None`` when no sample of ``seq`` was found.
    """
    seq_vlads, seq_feas, seq_kpts, seq_poses = [], [], [], []
    with torch.no_grad():
        for data in loader:
            batch_seqs = data['sequence']
            keep = torch.tensor([int(s) == seq for s in batch_seqs], dtype=torch.bool)
            if not keep.any():
                continue
            bev_query = data['bev_query'].to(device)
            pose_query = data['pose_query'].to(device)
            img_query = data['img_query'].to(device)
            try:
                relation = data['relation'].to(device)
            except (KeyError, AttributeError):
                relation = 0
            batch_dict = {
                'bev': _channels_first(bev_query),
                'img': _channels_first(img_query),
                'relation': relation,
                'id_query': data['id_query'],
                'sequence': batch_seqs,
                'pose_query': pose_query,
                'batch_size': len(data['id_query'])
            }
            model(batch_dict)
            fea = batch_dict.get('fea_kpt_fusion', batch_dict.get('fea_kpt_original', batch_dict['vlads']))
            seq_vlads.append(batch_dict['vlads'].detach().cpu()[keep])
            seq_feas.append(fea.detach().cpu()[keep])
            seq_kpts.append(batch_dict['key_points'].detach().cpu()[keep])
            seq_poses.append(pose_query.detach().cpu()[keep])
    if not seq_vlads:
        return None
    return tuple(torch.cat(parts) for parts in (seq_vlads, seq_feas, seq_kpts, seq_poses))


def main(args):
    """Extract features, retrieve loops and write the visualizations per sequence.

    For each requested test sequence the trained model is run on its samples.
    Loop candidates are retrieved from the global descriptors, and the
    loop-corrected trajectory is computed. Trajectory, error-distribution, BEV
    and 3D point-cloud figures are then saved under
    ``<path_result>/<result_name>/visualization``.

    Figures are kept open for a final ``plt.show()`` unless ``args.no_show``
    is set; in that case each sequence's figures are closed once saved, so
    memory does not grow with the number of sequences.
    """
    cfg = _load_config()
    device = torch.device("cuda" if torch.cuda.is_available() and cfg['cuda'] else "cpu")
    print(f"Using device: {device}")

    sequence_test = [int(x) for x in tools.read_cfg(cfg['test'])]
    sequence_to_vis = [args.sequence] if args.sequence is not None else sequence_test

    output_dir = os.path.join(cfg['path_result'], args.result_name, 'visualization')
    os.makedirs(output_dir, exist_ok=True)

    model_path = os.path.join(cfg['path_result'], args.result_name, 'models', args.model_name)
    if not os.path.exists(model_path):
        model_path = args.model_name
    print(f"Loading model from: {model_path}")
    model = load_model(cfg, model_path, device)

    _, _, loader_test = KittiTotalLoader(cfg)

    for seq in sequence_to_vis:
        print(f"\n{'='*60}")
        print(f"Processing Sequence {seq:02d}")
        print('='*60)
        seq_datasets = [d for d in loader_test.dataset.datasets if int(d.sequence) == seq]
        if not seq_datasets:
            print(f"Sequence {seq:02d} not found in test set, skipping...")
            continue
        dataset = seq_datasets[0]

        print("Extracting features...")
        extracted = _extract_sequence_features(model, loader_test, seq, device)
        if extracted is None:
            print(f"No data extracted for sequence {seq:02d}")
            continue
        vlads, feas, kpts, poses = extracted

        print("Retrieving loops...")
        # Only the descriptors are used by the retrieval; move them to the chosen
        # device instead of a hard-coded GPU.
        loops = retrieve_loops_with_poses(vlads.to(device), feas, kpts, poses,
                                          num_cand=args.num_candidates)

        print("Computing corrected trajectory...")
        poses_corrected, pose_errors = compute_corrected_trajectory(poses, loops, feas, kpts, device)

        seq_name = f"seq_{seq:02d}"
        figures = []

        # 1. Trajectory comparison (ground truth vs. prediction).
        traj_path = os.path.join(output_dir, f'{seq_name}_trajectory_comparison.png')
        figures.append(visualize_trajectory_comparison(poses, poses_corrected, pose_errors, loops, seq_name,
                                                       save_path=traj_path))

        # 2. Error distribution.
        error_path = os.path.join(output_dir, f'{seq_name}_error_distribution.png')
        fig_errors, errors = visualize_error_distribution(poses, poses_corrected, seq_name, save_path=error_path)
        figures.append(fig_errors)
        if len(errors) > 0:
            print(f" Localization errors - Mean: {np.mean(errors):.3f}m, Median: {np.median(errors):.3f}m, Max: {np.max(errors):.3f}m")

        # 3./4. Point-cloud views. Keyframes must index both the extracted poses
        # and the dataset's scans, so stay within the shorter of the two.
        n_frames = min(len(dataset.poses), len(poses))
        if n_frames > args.num_frames:
            keyframe_indices = np.linspace(0, n_frames - 1, args.num_frames, dtype=int)
        else:
            keyframe_indices = range(n_frames)
        keyframe_indices = [int(i) for i in keyframe_indices]

        bev_path = os.path.join(output_dir, f'{seq_name}_bev.png')
        figures.append(visualize_bev_with_loops(dataset, poses, keyframe_indices, save_path=bev_path))

        pc_path = os.path.join(output_dir, f'{seq_name}_pointcloud.png')
        figures.append(visualize_pointcloud_with_pose(dataset, poses, keyframe_indices, save_path=pc_path))

        if args.no_show:
            for fig in figures:
                plt.close(fig)

        print(f"\nVisualization for sequence {seq:02d} saved to: {output_dir}")

    print(f"\n{'='*60}")
    print(f"All visualizations saved to: {output_dir}")
    print('='*60)
    if not args.no_show:
        plt.show()
if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments) pairs.
    cli_options = (
        ('--result_name', dict(type=str, default='log', help='Result directory name')),
        ('--model_name', dict(type=str, default='checkpoint_fusion.pth.tar', help='Model checkpoint filename')),
        ('--sequence', dict(type=int, default=None, help='Specific sequence to visualize')),
        ('--num_candidates', dict(type=int, default=5, help='Number of loop candidates')),
        ('--num_frames', dict(type=int, default=6, help='Number of key frames for point cloud')),
        ('--no_show', dict(action='store_true', help='Do not display plots')),
        ('--gpu', dict(type=str, default=None, help='GPU device id')),
    )
    parser = argparse.ArgumentParser(description='Visualize localization results')
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    args = parser.parse_args()
    if args.gpu:
        # Restrict the visible GPU(s); set before main() runs.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    main(args)