"""
Visualize re-localization results: point clouds, trajectories, and the
predicted position compared with the ground-truth position.

Supports the KITTI dataset.
"""
# Recovered from the git format-patch (commit
# c3d268f46380fffdc457f916ef400bdaface5504) that adds visualize_localization.py.
import argparse
import os
import numpy as np
import torch
import yaml
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # kept: registers the '3d' projection on older matplotlib
from tqdm import tqdm
from skimage.measure import ransac
from skimage.transform import EuclideanTransform

import net
import tools
from dataset import KittiTotalLoader, KittiDataset


def load_model(cfg, model_path, device):
    """Load a trained fusion model and switch it to evaluation mode.

    Args:
        cfg: Experiment configuration passed straight to ``net.Fusion``.
        model_path: Path of the checkpoint file; the checkpoint must contain
            the state dict under the ``'model'`` key.
        device: Device the model and the checkpoint tensors are mapped to.

    Returns:
        The model, on ``device``, in ``eval()`` mode.
    """
    model = net.Fusion(cfg).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model


def retrieve_loops_with_poses(vlads, feas, kpts, poses, num_cand=5,
                              exclusion_radius=50):
    """Retrieve loop-closure candidates for every frame by global descriptor.

    For each query frame the ``num_cand`` nearest global descriptors (by
    ``torch.cdist`` distance) are returned, ignoring frames that are
    temporally adjacent to the query.

    Args:
        vlads: Global descriptors, one row per frame.
        feas: Local features; only its length (number of frames) is used.
        kpts: Unused; kept for interface compatibility.
        poses: Unused; kept for interface compatibility.
        num_cand: Maximum number of candidates returned per query.
        exclusion_radius: Frames with index in
            ``[i - exclusion_radius, i + exclusion_radius)`` are never
            candidates for query ``i``.

    Returns:
        A list with one dict per frame holding ``'query_idx'``,
        ``'candidates'`` (frame indices, nearest first) and ``'distances'``.
        Both arrays are empty when no frame lies outside the exclusion
        window (e.g. very short sequences).
    """
    num_frames = len(feas)
    frame_ids = torch.arange(num_frames, device=vlads.device)

    loops = []
    for i in tqdm(range(num_frames), desc="Retrieving loops"):
        # Everything outside the temporal exclusion window is a valid candidate.
        lo = max(0, i - exclusion_radius)
        hi = min(num_frames, i + exclusion_radius)
        valid_idx = torch.cat((frame_ids[:lo], frame_ids[hi:]))

        # topk raises if k exceeds the number of elements, so clamp it.
        k = min(num_cand, valid_idx.numel())
        if k <= 0:
            loops.append({
                'query_idx': i,
                'candidates': np.empty(0, dtype=np.int64),
                'distances': np.empty(0, dtype=np.float32)
            })
            continue

        vlad_query = vlads[i].view(1, -1)
        dis_vlad = torch.cdist(vlad_query, vlads[valid_idx]).view(-1)
        dis, idx_cand = torch.topk(dis_vlad, k, largest=False)

        loops.append({
            'query_idx': i,
            'candidates': valid_idx[idx_cand].cpu().numpy(),
            'distances': dis.cpu().numpy()
        })

    return loops


def estimate_relative_pose_ransac(fea1, kpts1, fea2, kpts2, min_matches=20,
                                  min_samples=15, max_trials=3,
                                  residual_threshold=1.7, min_inliers=30):
    """Estimate the 2D relative pose between two frames with RANSAC.

    Local features are matched with ``tools.nn_match`` (cosine metric) and a
    2D Euclidean (rotation + translation) transform is fitted to the first
    two coordinates of the matched key points.

    Args:
        fea1, fea2: Local features of the two frames, (N, D) per the
            original documentation.
        kpts1, kpts2: Key-point coordinates, (N, 3) or (N, 2); only the
            first two columns are used.
        min_matches: Minimum number of feature matches required.
        min_samples: ``min_samples`` forwarded to ``skimage.measure.ransac``.
        max_trials: ``max_trials`` forwarded to ``skimage.measure.ransac``.
        residual_threshold: Inlier threshold forwarded to ``ransac``.
        min_inliers: The estimate is accepted only when the inlier count is
            strictly greater than this value.

    Returns:
        ``(T, num_inliers)`` where ``T`` is the (3, 3) homogeneous transform,
        or ``(None, 0)`` when matching or estimation fails.
    """
    idx1, idx2, _ = tools.nn_match(fea1, fea2, 'cosine')

    if len(idx1) < min_matches:
        return None, 0

    p1 = kpts1[idx1].cpu().detach().numpy()
    p2 = kpts2[idx2].cpu().detach().numpy()

    try:
        result, inliers = ransac(
            (p1[:, 0:2], p2[:, 0:2]),
            model_class=EuclideanTransform,
            min_samples=min_samples,
            max_trials=max_trials,
            residual_threshold=residual_threshold
        )
    except (ValueError, np.linalg.LinAlgError):
        # Degenerate or insufficient correspondences: report "no pose".
        return None, 0

    # ransac returns (None, None) when no trial produced a valid model.
    if result is None or inliers is None:
        return None, 0

    num_inlier = int(np.sum(inliers))
    if num_inlier > min_inliers:
        return result.params, num_inlier

    return None, 0


def compute_corrected_trajectory(poses_gt, loops, feas, kpts, device):
    """Build a "corrected" trajectory from loop closures and relative poses.

    The best candidate of every query forms a loop edge; a 2D relative pose
    is estimated for each edge, compared against the ground-truth relative
    pose, and the resulting translation discrepancy is spread over the two
    frames of the edge.

    Args:
        poses_gt: Ground-truth trajectory, (N, 4, 4) tensor or array.
        loops: Output of ``retrieve_loops_with_poses``.
        feas: Per-frame local features.
        kpts: Per-frame key points.
        device: Device on which feature matching is performed.

    Returns:
        ``(poses_corrected, pose_errors)``: the corrected (N, 4, 4)
        trajectory and the per-frame translation offset from ground truth.
    """
    n = len(poses_gt)
    poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else poses_gt

    # The corrected trajectory starts as a copy of the ground truth.
    poses_corrected = poses_gt_np.copy()
    pose_errors = np.zeros(n)

    # Build loop edges from the best candidate of each query.
    loop_edges = []
    for loop in loops:
        candidates = loop['candidates']
        if len(candidates) == 0:
            continue
        query_idx = loop['query_idx']
        best_match = candidates[0]
        if abs(best_match - query_idx) > 10:  # drop temporally close pairs
            loop_edges.append((query_idx, best_match))

    print(f" Found {len(loop_edges)} loop edges")

    # Estimate a relative pose for every loop edge.
    valid_edges = []
    for idx1, idx2 in loop_edges:
        # Use the caller-supplied device so CPU-only hosts work as well.
        fea1 = feas[idx1].to(device)
        fea2 = feas[idx2].to(device)
        kp1 = kpts[idx1].to(device)
        kp2 = kpts[idx2].to(device)

        T_rel, num_inliers = estimate_relative_pose_ransac(fea1, kp1, fea2, kp2)

        if T_rel is not None and num_inliers > 0:
            valid_edges.append({
                'idx1': idx1,
                'idx2': idx2,
                'T_rel': T_rel,
                'inliers': num_inliers
            })

    print(f" Valid loop edges with pose estimation: {len(valid_edges)}")

    if not valid_edges:
        return poses_corrected, pose_errors

    # Per-frame translation correction [dx, dy, dz].  Only translations are
    # ever accumulated, so three columns are sufficient (and required for the
    # additions below to broadcast).
    corrections = np.zeros((n, 3))

    # NOTE(review): the per-edge quantities do not depend on `corrections`,
    # so every pass accumulates the same amounts; the three passes only
    # scale the total correction.
    for iteration in range(3):
        total_correction = 0
        for edge in valid_edges:
            idx1 = edge['idx1']
            idx2 = edge['idx2']
            T_rel = edge['T_rel']

            # Lift the estimated 2D transform into a 4x4 pose.
            T_pred = np.eye(4)
            T_pred[:2, :2] = T_rel[:2, :2]
            T_pred[:2, 3] = T_rel[:2, 2]

            # Ground-truth relative pose between the two frames.
            T_gt_rel = np.linalg.inv(poses_gt_np[idx1]) @ poses_gt_np[idx2]

            # Discrepancy between the estimate and the ground truth.
            T_error = np.linalg.inv(T_pred) @ T_gt_rel
            error_trans = np.linalg.norm(T_error[:2, 3])

            if error_trans < 2.0:  # only use edges with a small discrepancy
                # Split the correction between the two frames of the edge.
                correction = T_error[:3, 3] / 2
                corrections[idx1] += correction * 0.1
                corrections[idx2] -= correction * 0.1
                total_correction += error_trans

        print(f" Iteration {iteration+1}: avg correction = {total_correction / len(valid_edges):.3f}m")

    # Apply the corrections; the first frame is kept fixed as the anchor.
    for i in range(1, n):
        pose_corrected = poses_gt_np[i].copy()
        pose_corrected[:3, 3] += corrections[i] * 0.5
        poses_corrected[i] = pose_corrected

        pose_errors[i] = np.linalg.norm(poses_corrected[i][:3, 3] - poses_gt_np[i][:3, 3])

    return poses_corrected, pose_errors


def load_pointcloud(dataset, idx):
    """Load the point cloud of frame ``idx``.

    The file at ``dataset.scans[idx]`` is read as flat float32 values and
    reshaped to four columns per point.
    """
    scan_path = dataset.scans[idx]
    return np.fromfile(scan_path, dtype=np.float32).reshape((-1, 4))


def transform_points(points, pose):
    """Transform points from the LiDAR frame into the world frame.

    Args:
        points: (N, >=3) array; only the first three columns are used.
        pose: (4, 4) homogeneous transform.

    Returns:
        (N, 4) array of transformed homogeneous points.
    """
    ones = np.ones((len(points), 1))
    points_hom = np.hstack([points[:, :3], ones])
    return (pose @ points_hom.T).T


def visualize_trajectory_comparison(poses_gt, poses_pred, pose_errors, loops, seq_name, save_path=None):
    """Plot the ground-truth trajectory against the predicted trajectory.

    Four panels are drawn: XY trajectories with loop connections, error over
    time, XZ trajectories and a 3D view.

    Args:
        poses_gt: Ground-truth trajectory, (N, 4, 4).
        poses_pred: Predicted trajectory, (N, 4, 4).
        pose_errors: Per-frame position error.
        loops: Retrieved loops (only the first 100 entries are drawn).
        seq_name: Sequence name used in the titles.
        save_path: If given, the figure is saved to this path.

    Returns:
        The matplotlib figure.
    """
    poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else poses_gt
    poses_pred_np = poses_pred.cpu().numpy() if isinstance(poses_pred, torch.Tensor) else poses_pred
    pose_errors = np.asarray(pose_errors)

    # Trajectory positions.
    traj_gt_x = poses_gt_np[:, 0, 3]
    traj_gt_y = poses_gt_np[:, 1, 3]
    traj_gt_z = poses_gt_np[:, 2, 3]

    traj_pred_x = poses_pred_np[:, 0, 3]
    traj_pred_y = poses_pred_np[:, 1, 3]
    traj_pred_z = poses_pred_np[:, 2, 3]

    fig = plt.figure(figsize=(18, 12))

    # 1. 2D trajectory comparison (XY plane).
    ax1 = fig.add_subplot(2, 2, 1)
    ax1.plot(traj_gt_x, traj_gt_y, 'b-', linewidth=1, alpha=0.6, label='Ground Truth')
    ax1.plot(traj_pred_x, traj_pred_y, 'r--', linewidth=1, alpha=0.6, label='Predicted')
    ax1.plot(traj_gt_x[0], traj_gt_y[0], 'go', markersize=12, label='Start')
    ax1.plot(traj_gt_x[-1], traj_gt_y[-1], 'r^', markersize=12, label='End')

    # Loop-closure connections (query -> best candidate).
    for loop in loops[:100]:
        candidates = loop['candidates']
        if len(candidates) == 0:
            continue
        query_idx = loop['query_idx']
        best_match = candidates[0]
        ax1.plot([traj_gt_x[query_idx], traj_gt_x[best_match]],
                 [traj_gt_y[query_idx], traj_gt_y[best_match]],
                 'g-', alpha=0.2, linewidth=0.5)

    ax1.set_xlabel('X (m)')
    ax1.set_ylabel('Y (m)')
    ax1.set_title(f'{seq_name} - 2D Trajectory Comparison (XY Plane)')
    ax1.legend()
    ax1.axis('equal')
    ax1.grid(True, linestyle='--', alpha=0.4)

    # 2. Error over time.
    ax2 = fig.add_subplot(2, 2, 2)
    ax2.plot(pose_errors, 'b-', linewidth=0.5, alpha=0.7)
    mean_error = np.mean(pose_errors)
    ax2.axhline(mean_error, color='r', linestyle='--', label=f'Mean: {mean_error:.3f}m')
    # The median is taken over corrected frames only; skip the line when no
    # frame was corrected, otherwise the median of an empty array is NaN.
    nonzero_errors = pose_errors[pose_errors > 0]
    if nonzero_errors.size > 0:
        median_error = np.median(nonzero_errors)
        ax2.axhline(median_error, color='g', linestyle='--',
                    label=f'Median: {median_error:.3f}m')
    ax2.set_xlabel('Frame Index')
    ax2.set_ylabel('Position Error (m)')
    ax2.set_title(f'{seq_name} - Localization Error over Time')
    ax2.legend()
    ax2.grid(True, linestyle='--', alpha=0.4)

    # 3. Trajectory in the XZ plane.
    ax3 = fig.add_subplot(2, 2, 3)
    ax3.plot(traj_gt_x, traj_gt_z, 'b-', linewidth=1, alpha=0.6, label='Ground Truth')
    ax3.plot(traj_pred_x, traj_pred_z, 'r--', linewidth=1, alpha=0.6, label='Predicted')
    ax3.set_xlabel('X (m)')
    ax3.set_ylabel('Z (m)')
    ax3.set_title(f'{seq_name} - Trajectory XZ View')
    ax3.legend()
    ax3.axis('equal')
    ax3.grid(True, linestyle='--', alpha=0.4)

    # 4. 3D trajectory comparison.
    ax4 = fig.add_subplot(2, 2, 4, projection='3d')
    ax4.plot(traj_gt_x, traj_gt_y, traj_gt_z, 'b-', linewidth=1, alpha=0.6, label='Ground Truth')
    ax4.plot(traj_pred_x, traj_pred_y, traj_pred_z, 'r--', linewidth=1, alpha=0.6, label='Predicted')
    ax4.scatter(traj_gt_x[0], traj_gt_y[0], traj_gt_z[0], c='green', s=100, marker='o', label='Start')
    ax4.scatter(traj_gt_x[-1], traj_gt_y[-1], traj_gt_z[-1], c='blue', s=100, marker='^', label='End')

    ax4.set_xlabel('X (m)')
    ax4.set_ylabel('Y (m)')
    ax4.set_zlabel('Z (m)')
    ax4.set_title(f'{seq_name} - 3D Trajectory Comparison')
    ax4.legend()

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Trajectory comparison saved to {save_path}")

    return fig


def visualize_error_distribution(poses_gt, poses_pred, seq_name, save_path=None):
    """Plot the histogram and cumulative distribution of position errors.

    Args:
        poses_gt: Ground-truth trajectory.
        poses_pred: Predicted trajectory.
        seq_name: Sequence name used in the titles.
        save_path: If given, the figure is saved to this path.

    Returns:
        ``(fig, errors)``: the figure and the per-frame position errors.
    """
    poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else poses_gt
    poses_pred_np = poses_pred.cpu().numpy() if isinstance(poses_pred, torch.Tensor) else poses_pred

    errors = np.linalg.norm(poses_pred_np[:, :3, 3] - poses_gt_np[:, :3, 3], axis=1)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Error histogram.
    ax1 = axes[0]
    ax1.hist(errors, bins=50, edgecolor='black', alpha=0.7)
    ax1.axvline(np.mean(errors), color='r', linestyle='--', label=f'Mean: {np.mean(errors):.3f}m')
    ax1.axvline(np.median(errors), color='g', linestyle='--', label=f'Median: {np.median(errors):.3f}m')
    ax1.set_xlabel('Position Error (m)')
    ax1.set_ylabel('Frequency')
    ax1.set_title(f'{seq_name} - Localization Error Distribution')
    ax1.legend()
    ax1.grid(True, linestyle='--', alpha=0.4)

    # Cumulative error distribution.
    ax2 = axes[1]
    sorted_errors = np.sort(errors)
    cumulative = np.arange(1, len(sorted_errors) + 1) / len(sorted_errors)
    ax2.plot(sorted_errors, cumulative, 'b-', linewidth=2)
    ax2.axhline(0.5, color='gray', linestyle='--', alpha=0.5)
    ax2.axhline(0.95, color='gray', linestyle='--', alpha=0.5)
    ax2.set_xlabel('Position Error (m)')
    ax2.set_ylabel('Cumulative Probability')
    ax2.set_title(f'{seq_name} - Error Cumulative Distribution')
    ax2.grid(True, linestyle='--', alpha=0.4)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Error distribution saved to {save_path}")

    return fig, errors


def visualize_pointcloud_with_pose(dataset, poses, idx_list, save_path=None):
    """Plot the world-frame point cloud and sensor position of given frames.

    One 3D subplot is created per entry of ``idx_list``; clouds larger than
    5000 points are randomly subsampled.
    """
    fig = plt.figure(figsize=(16, 12))
    num_frames = len(idx_list)
    num_cols = (num_frames + 1) // 2

    for i, idx in enumerate(idx_list):
        ax = fig.add_subplot(2, num_cols, i + 1, projection='3d')

        scan = load_pointcloud(dataset, idx)
        pose = poses[idx]
        pose_np = pose.cpu().numpy() if isinstance(pose, torch.Tensor) else pose
        scan_world = transform_points(scan, pose_np)

        if len(scan_world) > 5000:
            sample_idx = np.random.choice(len(scan_world), 5000, replace=False)
            scan_world = scan_world[sample_idx]

        ax.scatter(scan_world[:, 0], scan_world[:, 1], scan_world[:, 2],
                   c=scan_world[:, 2], cmap='jet', s=0.5, alpha=0.5)
        ax.scatter(pose_np[0, 3], pose_np[1, 3], pose_np[2, 3],
                   c='red', s=100, marker='^', label='Current Pose')

        ax.set_xlabel('X (m)')
        ax.set_ylabel('Y (m)')
        ax.set_zlabel('Z (m)')
        ax.set_title(f'Frame {idx}')
        ax.legend()

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Point cloud visualization saved to {save_path}")

    return fig


def visualize_bev_with_loops(dataset, poses, idx_list, save_path=None):
    """Plot a bird's-eye view of selected point clouds over the trajectory.

    The full trajectory is drawn in the XY plane and the (subsampled, at most
    3000 points) world-frame cloud of every frame in ``idx_list`` is overlaid
    in its own colour.
    """
    fig, ax = plt.subplots(figsize=(12, 12))

    poses_np = poses.cpu().numpy() if isinstance(poses, torch.Tensor) else poses
    traj_x = poses_np[:, 0, 3]
    traj_y = poses_np[:, 1, 3]
    ax.plot(traj_x, traj_y, 'b-', linewidth=1, alpha=0.3, label='Full Trajectory')

    colors = ['red', 'green', 'purple', 'orange', 'cyan']
    for i, idx in enumerate(idx_list):
        color = colors[i % len(colors)]

        scan = load_pointcloud(dataset, idx)
        pose = poses[idx]
        pose_np = pose.cpu().numpy() if isinstance(pose, torch.Tensor) else pose
        scan_world = transform_points(scan, pose_np)

        if len(scan_world) > 3000:
            sample_idx = np.random.choice(len(scan_world), 3000, replace=False)
            scan_sample = scan_world[sample_idx]
        else:
            scan_sample = scan_world

        ax.scatter(scan_sample[:, 0], scan_sample[:, 1],
                   c=color, s=0.5, alpha=0.5,
                   label=f'Frame {idx}')
        ax.scatter(pose_np[0, 3], pose_np[1, 3],
                   c=color, s=100, marker='^', edgecolors='black')

    ax.set_xlabel('X (m)')
    ax.set_ylabel('Y (m)')
    ax.set_title('Bird Eye View - Point Clouds and Trajectory')
    ax.legend()
    ax.axis('equal')
    ax.grid(True, linestyle='--', alpha=0.4)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"BEV visualization saved to {save_path}")

    return fig


def _channels_first_or_zero(batch):
    """Permute a 4-D channels-last batch to channels-first, else return 0.

    ``0`` is the placeholder the pipeline uses when a modality is unavailable.
    """
    if torch.is_tensor(batch) and batch.dim() == 4:
        return batch.permute(0, 3, 1, 2)
    return 0


def main(args):
    """Run feature extraction, loop retrieval and all visualizations.

    Args:
        args: Parsed command-line arguments (see the ``__main__`` section).
    """
    # Load the configuration, falling back to the project sub-directory.
    try:
        with open(os.path.join(os.getcwd(), "config.yaml"), "r") as ymlfile:
            cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
    except (OSError, yaml.YAMLError):
        with open(os.path.join(os.getcwd(), "project/FUSIONLCD/config.yaml"), "r") as ymlfile:
            cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
    cfg = cfg['experiment']

    device = torch.device("cuda" if torch.cuda.is_available() and cfg['cuda'] else "cpu")
    print(f"Using device: {device}")

    sequence_test = [int(x) for x in tools.read_cfg(cfg['test'])]
    if args.sequence is not None:
        sequence_to_vis = [args.sequence]
    else:
        sequence_to_vis = sequence_test

    output_dir = os.path.join(cfg['path_result'], args.result_name, 'visualization')
    os.makedirs(output_dir, exist_ok=True)

    # Load the model; a missing result-relative path is treated as a direct path.
    model_path = os.path.join(cfg['path_result'], args.result_name, 'models', args.model_name)
    if not os.path.exists(model_path):
        model_path = args.model_name
    print(f"Loading model from: {model_path}")
    model = load_model(cfg, model_path, device)

    # Load the dataset.
    _, _, loader_test = KittiTotalLoader(cfg)

    for seq in sequence_to_vis:
        print(f"\n{'='*60}")
        print(f"Processing Sequence {seq:02d}")
        print('='*60)

        seq_datasets = [d for d in loader_test.dataset.datasets if int(d.sequence) == seq]
        if not seq_datasets:
            print(f"Sequence {seq:02d} not found in test set, skipping...")
            continue

        dataset = seq_datasets[0]

        # Extract features.
        print("Extracting features...")
        seq_vlads = []
        seq_feas = []
        seq_kpts = []
        seq_poses = []

        with torch.no_grad():
            for data in loader_test:
                batch_seqs = data['sequence']
                # NOTE(review): only the first sample decides whether the
                # batch belongs to this sequence - confirm batches never mix
                # sequences.
                if int(batch_seqs[0]) != seq:
                    continue

                bev_query = data['bev_query'].to(device)
                pose_query = data['pose_query'].to(device)
                img_query = data['img_query'].to(device)

                bev = _channels_first_or_zero(bev_query)
                img = _channels_first_or_zero(img_query)
                try:
                    relation = data['relation'].to(device)
                except (KeyError, AttributeError):
                    relation = 0

                batch_dict = {
                    'bev': bev,
                    'img': img,
                    'relation': relation,
                    'id_query': data['id_query'],
                    'sequence': batch_seqs,
                    'pose_query': pose_query,
                    'batch_size': len(data['id_query'])
                }

                # The model writes its outputs into batch_dict.
                model(batch_dict)

                # Prefer fused key-point features, then the original ones,
                # and finally fall back to the global descriptors.
                if 'fea_kpt_fusion' in batch_dict:
                    fea = batch_dict['fea_kpt_fusion']
                elif 'fea_kpt_original' in batch_dict:
                    fea = batch_dict['fea_kpt_original']
                else:
                    fea = batch_dict['vlads']

                seq_vlads.append(batch_dict['vlads'].detach().cpu())
                seq_feas.append(fea.detach().cpu())
                seq_kpts.append(batch_dict['key_points'].detach().cpu())
                seq_poses.append(pose_query.detach().cpu())

        if not seq_vlads:
            print(f"No data extracted for sequence {seq:02d}")
            continue

        vlads = torch.cat(seq_vlads)
        feas = torch.cat(seq_feas)
        kpts = torch.cat(seq_kpts)
        poses = torch.cat(seq_poses)

        # Retrieve loops on the selected device (works on CPU-only hosts).
        print("Retrieving loops...")
        loops = retrieve_loops_with_poses(vlads.to(device), feas.to(device),
                                          kpts.to(device), poses.to(device),
                                          num_cand=args.num_candidates)

        # Compute the corrected trajectory.
        print("Computing corrected trajectory...")
        poses_corrected, pose_errors = compute_corrected_trajectory(poses, loops, feas, kpts, device)

        # Generate the visualizations.
        seq_name = f"seq_{seq:02d}"
        figures = []

        # 1. Trajectory comparison (ground truth vs. prediction).
        traj_path = os.path.join(output_dir, f'{seq_name}_trajectory_comparison.png')
        figures.append(visualize_trajectory_comparison(
            poses, poses_corrected, pose_errors, loops, seq_name, save_path=traj_path))

        # 2. Error distribution.
        error_path = os.path.join(output_dir, f'{seq_name}_error_distribution.png')
        fig_errors, errors = visualize_error_distribution(poses, poses_corrected, seq_name, save_path=error_path)
        figures.append(fig_errors)
        if len(errors) > 0:
            print(f" Localization errors - Mean: {np.mean(errors):.3f}m, Median: {np.median(errors):.3f}m, Max: {np.max(errors):.3f}m")

        # 3. BEV point-cloud visualization.
        # Key frames index both the dataset scans and the extracted poses, so
        # stay within the shorter of the two.
        num_available = min(len(dataset.poses), len(poses))
        if num_available > args.num_frames:
            keyframe_indices = np.linspace(0, num_available - 1, args.num_frames, dtype=int)
        else:
            keyframe_indices = range(num_available)
        keyframe_indices = [int(i) for i in keyframe_indices]

        bev_path = os.path.join(output_dir, f'{seq_name}_bev.png')
        figures.append(visualize_bev_with_loops(dataset, poses, keyframe_indices, save_path=bev_path))

        # 4. 3D point-cloud visualization.
        pc_path = os.path.join(output_dir, f'{seq_name}_pointcloud.png')
        figures.append(visualize_pointcloud_with_pose(dataset, poses, keyframe_indices, save_path=pc_path))

        print(f"\nVisualization for sequence {seq:02d} saved to: {output_dir}")

        # Nothing will be displayed, so release the figures instead of
        # letting them accumulate across sequences.
        if args.no_show:
            for fig in figures:
                plt.close(fig)

    print(f"\n{'='*60}")
    print(f"All visualizations saved to: {output_dir}")
    print('='*60)

    if not args.no_show:
        plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Visualize localization results')
    parser.add_argument('--result_name', type=str, default='log',
                        help='Result directory name')
    parser.add_argument('--model_name', type=str, default='checkpoint_fusion.pth.tar',
                        help='Model checkpoint filename')
    parser.add_argument('--sequence', type=int, default=None,
                        help='Specific sequence to visualize')
    parser.add_argument('--num_candidates', type=int, default=5,
                        help='Number of loop candidates')
    parser.add_argument('--num_frames', type=int, default=6,
                        help='Number of key frames for point cloud')
    parser.add_argument('--no_show', action='store_true',
                        help='Do not display plots')
    parser.add_argument('--gpu', type=str, default=None,
                        help='GPU device id')

    args = parser.parse_args()

    # Restrict visible GPUs before any CUDA context is created in main().
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    main(args)