""" 可视化重定位效果:点云、轨迹、预测位置与实际位置对比 支持 KITTI 数据集 """ import argparse import os import numpy as np import torch import yaml import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from tqdm import tqdm from skimage.measure import ransac from skimage.transform import EuclideanTransform import net import tools from dataset import KittiTotalLoader, KittiDataset def load_model(cfg, model_path, device): """加载训练好的模型""" model = net.Fusion(cfg) model = model.to(device) checkpoint = torch.load(model_path, map_location=device) model.load_state_dict(checkpoint['model']) model.eval() return model def retrieve_loops_with_poses(vlads, feas, kpts, poses, num_cand=5): """检索闭环候选,返回详细信息""" loops = [] for i in tqdm(range(len(feas)), desc="Retrieving loops"): valid_idx = list(set(range(0, len(feas))) - set(range(max(0, i - 50), min(len(feas), i + 50)))) valid_idx = torch.tensor(valid_idx).to(vlads.device) vlad_query = vlads[i].view(1, -1) vlad_valid = vlads[valid_idx] dis_vlad = torch.cdist(vlad_query, vlad_valid).view(-1) dis, idx_cand = torch.topk(dis_vlad, num_cand, largest=False) idx_cand = valid_idx[idx_cand] loops.append({ 'query_idx': i, 'candidates': idx_cand.cpu().numpy(), 'distances': dis.cpu().numpy() }) return loops def estimate_relative_pose_ransac(fea1, kpts1, fea2, kpts2): """ 使用RANSAC估计两帧之间的相对位姿 Args: fea1, fea2: 局部特征 (N, D) kpts1, kpts2: 关键点坐标 (N, 3) or (N, 2) Returns: T: 相对位姿变换矩阵 (3, 3) 或 None inliers: 内点数量 """ # 特征匹配 idx1, idx2, dis = tools.nn_match(fea1, fea2, 'cosine') if len(idx1) < 20: return None, 0 p1 = kpts1[idx1].cpu().detach().numpy() p2 = kpts2[idx2].cpu().detach().numpy() try: # 使用2D欧式变换(RANSAC) result, inliers = ransac( (p1[:, 0:2], p2[:, 0:2]), model_class=EuclideanTransform, min_samples=15, max_trials=3, residual_threshold=1.7 ) num_inlier = np.sum(inliers) if num_inlier > 30: return result.params, num_inlier except: pass return None, 0 def compute_corrected_trajectory(poses_gt, loops, feas, kpts, device): """ 基于检测到的闭环和相对位姿估计,构建校正后的轨迹 Args: poses_gt: 真值轨迹 (N, 4, 4) loops: 检索到的闭环 feas: 特征 (N, D) kpts: 关键点 (N, 3) device: 计算设备 Returns: poses_corrected: 校正后的轨迹 (N, 4, 4) errors: 每个位姿的误差 """ n = len(poses_gt) poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else poses_gt # 初始化校正轨迹,从真值开始 poses_corrected = poses_gt_np.copy() pose_errors = np.zeros(n) # 构建闭环边 loop_edges = [] for loop in loops: query_idx = loop['query_idx'] candidates = loop['candidates'] if len(candidates) > 0: best_match = candidates[0] if abs(best_match - query_idx) > 10: # 排除时间上太近的 loop_edges.append((query_idx, best_match)) print(f" Found {len(loop_edges)} loop edges") # 对每个闭环边估计相对位姿 valid_edges = [] for i, (idx1, idx2) in enumerate(loop_edges): fea1 = feas[idx1].cuda() fea2 = feas[idx2].cuda() kp1 = kpts[idx1].cuda() kp2 = kpts[idx2].cuda() T_rel, num_inliers = estimate_relative_pose_ransac(fea1, kp1, fea2, kp2) if T_rel is not None and num_inliers > 0: valid_edges.append({ 'idx1': idx1, 'idx2': idx2, 'T_rel': T_rel, 'inliers': num_inliers }) print(f" Valid loop edges with pose estimation: {len(valid_edges)}") if len(valid_edges) == 0: return poses_corrected, pose_errors # 使用图优化简单校正 # 方法:从真值开始,通过闭环约束累积校正 corrections = np.zeros((n, 6)) # [dx, dy, dz, roll, pitch, yaw] # 迭代优化 for iteration in range(3): total_correction = 0 for edge in valid_edges: idx1 = edge['idx1'] idx2 = edge['idx2'] T_rel = edge['T_rel'] # 预测的相对位姿(基于校正后的轨迹) T_pred = np.eye(4) T_pred[:2, :2] = T_rel[:2, :2] T_pred[:2, 3] = T_rel[:2, 2] # 真值的相对位姿 T_gt_rel = np.linalg.inv(poses_gt_np[idx1]) @ poses_gt_np[idx2] # 计算校正量 T_error = np.linalg.inv(T_pred) @ T_gt_rel error_trans = np.linalg.norm(T_error[:2, 3]) if error_trans < 2.0: # 只处理误差较小的 # 分配校正量 correction = T_error[:3, 3] / 2 corrections[idx1] += correction[:3] * 0.1 corrections[idx2] -= correction[:3] * 0.1 total_correction += error_trans print(f" Iteration {iteration+1}: avg correction = {total_correction / len(valid_edges):.3f}m") # 应用校正 for i in range(n): if i == 0: continue # 简单的累积校正 delta = corrections[i] pose_corrected = poses_gt_np[i].copy() pose_corrected[:3, 3] += delta * 0.5 poses_corrected[i] = pose_corrected # 计算误差 error = np.linalg.norm(poses_corrected[i][:3, 3] - poses_gt_np[i][:3, 3]) pose_errors[i] = error return poses_corrected, pose_errors def load_pointcloud(dataset, idx): """加载指定索引的点云""" scan_path = dataset.scans[idx] scan = np.fromfile(scan_path, dtype=np.float32).reshape((-1, 4)) return scan def transform_points(points, pose): """将点云从激光雷达坐标系转换到世界坐标系""" points_hom = np.hstack([points[:, :3], np.ones((len(points), 1))]) points_world = (pose @ points_hom.T).T return points_world def visualize_trajectory_comparison(poses_gt, poses_pred, pose_errors, loops, seq_name, save_path=None): """ 可视化真值轨迹与预测轨迹对比 Args: poses_gt: 真值轨迹 (N, 4, 4) poses_pred: 预测轨迹 (N, 4, 4) pose_errors: 每个位姿的误差 loops: 检索到的闭环 seq_name: 序列名称 save_path: 保存路径 """ poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else poses_gt poses_pred_np = poses_pred.cpu().numpy() if isinstance(poses_pred, torch.Tensor) else poses_pred # 提取轨迹位置 traj_gt_x = poses_gt_np[:, 0, 3] traj_gt_y = poses_gt_np[:, 1, 3] traj_gt_z = poses_gt_np[:, 2, 3] traj_pred_x = poses_pred_np[:, 0, 3] traj_pred_y = poses_pred_np[:, 1, 3] traj_pred_z = poses_pred_np[:, 2, 3] fig = plt.figure(figsize=(18, 12)) # 1. 2D轨迹对比(XY平面) ax1 = fig.add_subplot(2, 2, 1) ax1.plot(traj_gt_x, traj_gt_y, 'b-', linewidth=1, alpha=0.6, label='Ground Truth') ax1.plot(traj_pred_x, traj_pred_y, 'r--', linewidth=1, alpha=0.6, label='Predicted') ax1.plot(traj_gt_x[0], traj_gt_y[0], 'go', markersize=12, label='Start') ax1.plot(traj_gt_x[-1], traj_gt_y[-1], 'r^', markersize=12, label='End') # 绘制闭环连接 for loop in loops[:100]: query_idx = loop['query_idx'] candidates = loop['candidates'] if len(candidates) > 0: best_match = candidates[0] ax1.plot([traj_gt_x[query_idx], traj_gt_x[best_match]], [traj_gt_y[query_idx], traj_gt_y[best_match]], 'g-', alpha=0.2, linewidth=0.5) ax1.set_xlabel('X (m)') ax1.set_ylabel('Y (m)') ax1.set_title(f'{seq_name} - 2D Trajectory Comparison (XY Plane)') ax1.legend() ax1.axis('equal') ax1.grid(True, linestyle='--', alpha=0.4) # 2. 误差沿时间变化 ax2 = fig.add_subplot(2, 2, 2) ax2.plot(pose_errors, 'b-', linewidth=0.5, alpha=0.7) ax2.axhline(np.mean(pose_errors), color='r', linestyle='--', label=f'Mean: {np.mean(pose_errors):.3f}m') ax2.axhline(np.median(pose_errors[pose_errors > 0]), color='g', linestyle='--', label=f'Median: {np.median(pose_errors[pose_errors > 0]):.3f}m') ax2.set_xlabel('Frame Index') ax2.set_ylabel('Position Error (m)') ax2.set_title(f'{seq_name} - Localization Error over Time') ax2.legend() ax2.grid(True, linestyle='--', alpha=0.4) # 3. XZ平面轨迹 ax3 = fig.add_subplot(2, 2, 3) ax3.plot(traj_gt_x, traj_gt_z, 'b-', linewidth=1, alpha=0.6, label='Ground Truth') ax3.plot(traj_pred_x, traj_pred_z, 'r--', linewidth=1, alpha=0.6, label='Predicted') ax3.set_xlabel('X (m)') ax3.set_ylabel('Z (m)') ax3.set_title(f'{seq_name} - Trajectory XZ View') ax3.legend() ax3.axis('equal') ax3.grid(True, linestyle='--', alpha=0.4) # 4. 3D轨迹对比 ax4 = fig.add_subplot(2, 2, 4, projection='3d') ax4.plot(traj_gt_x, traj_gt_y, traj_gt_z, 'b-', linewidth=1, alpha=0.6, label='Ground Truth') ax4.plot(traj_pred_x, traj_pred_y, traj_pred_z, 'r--', linewidth=1, alpha=0.6, label='Predicted') ax4.scatter(traj_gt_x[0], traj_gt_y[0], traj_gt_z[0], c='green', s=100, marker='o', label='Start') ax4.scatter(traj_gt_x[-1], traj_gt_y[-1], traj_gt_z[-1], c='blue', s=100, marker='^', label='End') ax4.set_xlabel('X (m)') ax4.set_ylabel('Y (m)') ax4.set_zlabel('Z (m)') ax4.set_title(f'{seq_name} - 3D Trajectory Comparison') ax4.legend() plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"Trajectory comparison saved to {save_path}") return fig def visualize_error_distribution(poses_gt, poses_pred, seq_name, save_path=None): """ 可视化定位误差分布 Args: poses_gt: 真值轨迹 poses_pred: 预测轨迹 seq_name: 序列名称 save_path: 保存路径 """ poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else poses_gt poses_pred_np = poses_pred.cpu().numpy() if isinstance(poses_pred, torch.Tensor) else poses_pred errors = np.linalg.norm(poses_pred_np[:, :3, 3] - poses_gt_np[:, :3, 3], axis=1) fig, axes = plt.subplots(1, 2, figsize=(14, 5)) # 误差直方图 ax1 = axes[0] ax1.hist(errors, bins=50, edgecolor='black', alpha=0.7) ax1.axvline(np.mean(errors), color='r', linestyle='--', label=f'Mean: {np.mean(errors):.3f}m') ax1.axvline(np.median(errors), color='g', linestyle='--', label=f'Median: {np.median(errors):.3f}m') ax1.set_xlabel('Position Error (m)') ax1.set_ylabel('Frequency') ax1.set_title(f'{seq_name} - Localization Error Distribution') ax1.legend() ax1.grid(True, linestyle='--', alpha=0.4) # 误差累积分布 ax2 = axes[1] sorted_errors = np.sort(errors) cumulative = np.arange(1, len(sorted_errors) + 1) / len(sorted_errors) ax2.plot(sorted_errors, cumulative, 'b-', linewidth=2) ax2.axhline(0.5, color='gray', linestyle='--', alpha=0.5) ax2.axhline(0.95, color='gray', linestyle='--', alpha=0.5) ax2.set_xlabel('Position Error (m)') ax2.set_ylabel('Cumulative Probability') ax2.set_title(f'{seq_name} - Error Cumulative Distribution') ax2.grid(True, linestyle='--', alpha=0.4) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"Error distribution saved to {save_path}") return fig, errors def visualize_pointcloud_with_pose(dataset, poses, idx_list, save_path=None): """ 可视化指定帧的点云和位置 """ fig = plt.figure(figsize=(16, 12)) num_frames = len(idx_list) for i, idx in enumerate(idx_list): ax = fig.add_subplot(2, (num_frames + 1) // 2, i + 1, projection='3d') scan = load_pointcloud(dataset, idx) pose = poses[idx] pose_np = pose.cpu().numpy() if isinstance(pose, torch.Tensor) else pose scan_world = transform_points(scan, pose_np) if len(scan_world) > 5000: sample_idx = np.random.choice(len(scan_world), 5000, replace=False) scan_world = scan_world[sample_idx] ax.scatter(scan_world[:, 0], scan_world[:, 1], scan_world[:, 2], c=scan_world[:, 2], cmap='jet', s=0.5, alpha=0.5) ax.scatter(pose_np[0, 3], pose_np[1, 3], pose_np[2, 3], c='red', s=100, marker='^', label='Current Pose') ax.set_xlabel('X (m)') ax.set_ylabel('Y (m)') ax.set_zlabel('Z (m)') ax.set_title(f'Frame {idx}') ax.legend() plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"Point cloud visualization saved to {save_path}") return fig def visualize_bev_with_loops(dataset, poses, idx_list, save_path=None): """ 可视化BEV(鸟瞰图)视角下的点云 """ fig, ax = plt.subplots(figsize=(12, 12)) poses_np = poses.cpu().numpy() if isinstance(poses, torch.Tensor) else poses traj_x = poses_np[:, 0, 3] traj_y = poses_np[:, 1, 3] ax.plot(traj_x, traj_y, 'b-', linewidth=1, alpha=0.3, label='Full Trajectory') colors = ['red', 'green', 'purple', 'orange', 'cyan'] for i, idx in enumerate(idx_list): scan = load_pointcloud(dataset, idx) pose = poses[idx] pose_np = pose.cpu().numpy() if isinstance(pose, torch.Tensor) else pose scan_world = transform_points(scan, pose_np) if len(scan_world) > 3000: sample_idx = np.random.choice(len(scan_world), 3000, replace=False) scan_sample = scan_world[sample_idx] else: scan_sample = scan_world ax.scatter(scan_sample[:, 0], scan_sample[:, 1], c=colors[i % len(colors)], s=0.5, alpha=0.5, label=f'Frame {idx}') ax.scatter(pose_np[0, 3], pose_np[1, 3], c=colors[i % len(colors)], s=100, marker='^', edgecolors='black') ax.set_xlabel('X (m)') ax.set_ylabel('Y (m)') ax.set_title('Bird Eye View - Point Clouds and Trajectory') ax.legend() ax.axis('equal') ax.grid(True, linestyle='--', alpha=0.4) plt.tight_layout() if save_path: plt.savefig(save_path, dpi=150, bbox_inches='tight') print(f"BEV visualization saved to {save_path}") return fig def main(args): # 加载配置 try: with open(os.path.join(os.getcwd(), "config.yaml"), "r") as ymlfile: cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader) except: with open(os.path.join(os.getcwd(), "project/FUSIONLCD/config.yaml"), "r") as ymlfile: cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader) cfg = cfg['experiment'] device = torch.device("cuda" if torch.cuda.is_available() and cfg['cuda'] else "cpu") print(f"Using device: {device}") sequence_test = [int(x) for x in tools.read_cfg(cfg['test'])] if args.sequence is not None: sequence_to_vis = [args.sequence] else: sequence_to_vis = sequence_test output_dir = os.path.join(cfg['path_result'], args.result_name, 'visualization') os.makedirs(output_dir, exist_ok=True) # 加载模型 model_path = os.path.join(cfg['path_result'], args.result_name, 'models', args.model_name) if not os.path.exists(model_path): model_path = args.model_name print(f"Loading model from: {model_path}") model = load_model(cfg, model_path, device) # 加载数据集 _, _, loader_test = KittiTotalLoader(cfg) for seq in sequence_to_vis: print(f"\n{'='*60}") print(f"Processing Sequence {seq:02d}") print('='*60) seq_datasets = [d for d in loader_test.dataset.datasets if int(d.sequence) == seq] if not seq_datasets: print(f"Sequence {seq:02d} not found in test set, skipping...") continue dataset = seq_datasets[0] # 提取特征 print("Extracting features...") seq_vlads = [] seq_feas = [] seq_kpts = [] seq_poses = [] with torch.no_grad(): for i, data in enumerate(loader_test): batch_seqs = data['sequence'] if int(batch_seqs[0]) != seq: continue bev_query = data['bev_query'].to(device) pose_query = data['pose_query'].to(device) img_query = data['img_query'].to(device) try: bev = bev_query.permute(0, 3, 1, 2) except: bev = 0 try: img = img_query.permute(0, 3, 1, 2) except: img = 0 try: relation = data['relation'].to(device) except: relation = 0 batch_dict = { 'bev': bev, 'img': img, 'relation': relation, 'id_query': data['id_query'], 'sequence': batch_seqs, 'pose_query': pose_query, 'batch_size': len(data['id_query']) } model(batch_dict) seq_vlads.append(batch_dict['vlads'].detach().cpu()) seq_feas.append(batch_dict.get('fea_kpt_fusion', batch_dict.get('fea_kpt_original', batch_dict['vlads'])).detach().cpu()) seq_kpts.append(batch_dict['key_points'].detach().cpu()) seq_poses.append(pose_query.detach().cpu()) if not seq_vlads: print(f"No data extracted for sequence {seq:02d}") continue vlads = torch.cat(seq_vlads) feas = torch.cat(seq_feas) kpts = torch.cat(seq_kpts) poses = torch.cat(seq_poses) # 检索闭环 print("Retrieving loops...") loops = retrieve_loops_with_poses(vlads.cuda(), feas.cuda(), kpts.cuda(), poses.cuda(), num_cand=args.num_candidates) # 计算校正轨迹 print("Computing corrected trajectory...") poses_corrected, pose_errors = compute_corrected_trajectory(poses, loops, feas, kpts, device) # 生成可视化 seq_name = f"seq_{seq:02d}" # 1. 轨迹对比(真值 vs 预测) traj_path = os.path.join(output_dir, f'{seq_name}_trajectory_comparison.png') visualize_trajectory_comparison(poses, poses_corrected, pose_errors, loops, seq_name, save_path=traj_path) # 2. 误差分布 error_path = os.path.join(output_dir, f'{seq_name}_error_distribution.png') fig_errors, errors = visualize_error_distribution(poses, poses_corrected, seq_name, save_path=error_path) if len(errors) > 0: print(f" Localization errors - Mean: {np.mean(errors):.3f}m, Median: {np.median(errors):.3f}m, Max: {np.max(errors):.3f}m") # 3. BEV点云可视化 if len(dataset.poses) > args.num_frames: keyframe_indices = np.linspace(0, len(dataset.poses) - 1, args.num_frames, dtype=int) else: keyframe_indices = list(range(len(dataset.poses))) keyframe_indices = [int(i) for i in keyframe_indices] bev_path = os.path.join(output_dir, f'{seq_name}_bev.png') visualize_bev_with_loops(dataset, poses, keyframe_indices, save_path=bev_path) # 4. 3D点云可视化 pc_path = os.path.join(output_dir, f'{seq_name}_pointcloud.png') visualize_pointcloud_with_pose(dataset, poses, keyframe_indices, save_path=pc_path) print(f"\nVisualization for sequence {seq:02d} saved to: {output_dir}") print(f"\n{'='*60}") print(f"All visualizations saved to: {output_dir}") print('='*60) if not args.no_show: plt.show() if __name__ == '__main__': parser = argparse.ArgumentParser(description='Visualize localization results') parser.add_argument('--result_name', type=str, default='log', help='Result directory name') parser.add_argument('--model_name', type=str, default='checkpoint_fusion.pth.tar', help='Model checkpoint filename') parser.add_argument('--sequence', type=int, default=None, help='Specific sequence to visualize') parser.add_argument('--num_candidates', type=int, default=5, help='Number of loop candidates') parser.add_argument('--num_frames', type=int, default=6, help='Number of key frames for point cloud') parser.add_argument('--no_show', action='store_true', help='Do not display plots') parser.add_argument('--gpu', type=str, default=None, help='GPU device id') args = parser.parse_args() if args.gpu: os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu main(args)