位姿可视化
This commit is contained in:
611
visualize_localization.py
Normal file
611
visualize_localization.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
可视化重定位效果:点云、轨迹、预测位置与实际位置对比
|
||||
支持 KITTI 数据集
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
from tqdm import tqdm
|
||||
from skimage.measure import ransac
|
||||
from skimage.transform import EuclideanTransform
|
||||
|
||||
import net
|
||||
import tools
|
||||
from dataset import KittiTotalLoader, KittiDataset
|
||||
|
||||
|
||||
def load_model(cfg, model_path, device):
|
||||
"""加载训练好的模型"""
|
||||
model = net.Fusion(cfg)
|
||||
model = model.to(device)
|
||||
checkpoint = torch.load(model_path, map_location=device)
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def retrieve_loops_with_poses(vlads, feas, kpts, poses, num_cand=5):
|
||||
"""检索闭环候选,返回详细信息"""
|
||||
loops = []
|
||||
for i in tqdm(range(len(feas)), desc="Retrieving loops"):
|
||||
valid_idx = list(set(range(0, len(feas))) - set(range(max(0, i - 50), min(len(feas), i + 50))))
|
||||
valid_idx = torch.tensor(valid_idx).to(vlads.device)
|
||||
|
||||
vlad_query = vlads[i].view(1, -1)
|
||||
vlad_valid = vlads[valid_idx]
|
||||
|
||||
dis_vlad = torch.cdist(vlad_query, vlad_valid).view(-1)
|
||||
dis, idx_cand = torch.topk(dis_vlad, num_cand, largest=False)
|
||||
idx_cand = valid_idx[idx_cand]
|
||||
|
||||
loops.append({
|
||||
'query_idx': i,
|
||||
'candidates': idx_cand.cpu().numpy(),
|
||||
'distances': dis.cpu().numpy()
|
||||
})
|
||||
|
||||
return loops
|
||||
|
||||
|
||||
def estimate_relative_pose_ransac(fea1, kpts1, fea2, kpts2):
|
||||
"""
|
||||
使用RANSAC估计两帧之间的相对位姿
|
||||
|
||||
Args:
|
||||
fea1, fea2: 局部特征 (N, D)
|
||||
kpts1, kpts2: 关键点坐标 (N, 3) or (N, 2)
|
||||
|
||||
Returns:
|
||||
T: 相对位姿变换矩阵 (3, 3) 或 None
|
||||
inliers: 内点数量
|
||||
"""
|
||||
# 特征匹配
|
||||
idx1, idx2, dis = tools.nn_match(fea1, fea2, 'cosine')
|
||||
|
||||
if len(idx1) < 20:
|
||||
return None, 0
|
||||
|
||||
p1 = kpts1[idx1].cpu().detach().numpy()
|
||||
p2 = kpts2[idx2].cpu().detach().numpy()
|
||||
|
||||
try:
|
||||
# 使用2D欧式变换(RANSAC)
|
||||
result, inliers = ransac(
|
||||
(p1[:, 0:2], p2[:, 0:2]),
|
||||
model_class=EuclideanTransform,
|
||||
min_samples=15,
|
||||
max_trials=3,
|
||||
residual_threshold=1.7
|
||||
)
|
||||
num_inlier = np.sum(inliers)
|
||||
if num_inlier > 30:
|
||||
return result.params, num_inlier
|
||||
except:
|
||||
pass
|
||||
|
||||
return None, 0
|
||||
|
||||
|
||||
def compute_corrected_trajectory(poses_gt, loops, feas, kpts, device):
|
||||
"""
|
||||
基于检测到的闭环和相对位姿估计,构建校正后的轨迹
|
||||
|
||||
Args:
|
||||
poses_gt: 真值轨迹 (N, 4, 4)
|
||||
loops: 检索到的闭环
|
||||
feas: 特征 (N, D)
|
||||
kpts: 关键点 (N, 3)
|
||||
device: 计算设备
|
||||
|
||||
Returns:
|
||||
poses_corrected: 校正后的轨迹 (N, 4, 4)
|
||||
errors: 每个位姿的误差
|
||||
"""
|
||||
n = len(poses_gt)
|
||||
poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else poses_gt
|
||||
|
||||
# 初始化校正轨迹,从真值开始
|
||||
poses_corrected = poses_gt_np.copy()
|
||||
pose_errors = np.zeros(n)
|
||||
|
||||
# 构建闭环边
|
||||
loop_edges = []
|
||||
for loop in loops:
|
||||
query_idx = loop['query_idx']
|
||||
candidates = loop['candidates']
|
||||
if len(candidates) > 0:
|
||||
best_match = candidates[0]
|
||||
if abs(best_match - query_idx) > 10: # 排除时间上太近的
|
||||
loop_edges.append((query_idx, best_match))
|
||||
|
||||
print(f" Found {len(loop_edges)} loop edges")
|
||||
|
||||
# 对每个闭环边估计相对位姿
|
||||
valid_edges = []
|
||||
for i, (idx1, idx2) in enumerate(loop_edges):
|
||||
fea1 = feas[idx1].cuda()
|
||||
fea2 = feas[idx2].cuda()
|
||||
kp1 = kpts[idx1].cuda()
|
||||
kp2 = kpts[idx2].cuda()
|
||||
|
||||
T_rel, num_inliers = estimate_relative_pose_ransac(fea1, kp1, fea2, kp2)
|
||||
|
||||
if T_rel is not None and num_inliers > 0:
|
||||
valid_edges.append({
|
||||
'idx1': idx1,
|
||||
'idx2': idx2,
|
||||
'T_rel': T_rel,
|
||||
'inliers': num_inliers
|
||||
})
|
||||
|
||||
print(f" Valid loop edges with pose estimation: {len(valid_edges)}")
|
||||
|
||||
if len(valid_edges) == 0:
|
||||
return poses_corrected, pose_errors
|
||||
|
||||
# 使用图优化简单校正
|
||||
# 方法:从真值开始,通过闭环约束累积校正
|
||||
corrections = np.zeros((n, 6)) # [dx, dy, dz, roll, pitch, yaw]
|
||||
|
||||
# 迭代优化
|
||||
for iteration in range(3):
|
||||
total_correction = 0
|
||||
for edge in valid_edges:
|
||||
idx1 = edge['idx1']
|
||||
idx2 = edge['idx2']
|
||||
T_rel = edge['T_rel']
|
||||
|
||||
# 预测的相对位姿(基于校正后的轨迹)
|
||||
T_pred = np.eye(4)
|
||||
T_pred[:2, :2] = T_rel[:2, :2]
|
||||
T_pred[:2, 3] = T_rel[:2, 2]
|
||||
|
||||
# 真值的相对位姿
|
||||
T_gt_rel = np.linalg.inv(poses_gt_np[idx1]) @ poses_gt_np[idx2]
|
||||
|
||||
# 计算校正量
|
||||
T_error = np.linalg.inv(T_pred) @ T_gt_rel
|
||||
error_trans = np.linalg.norm(T_error[:2, 3])
|
||||
|
||||
if error_trans < 2.0: # 只处理误差较小的
|
||||
# 分配校正量
|
||||
correction = T_error[:3, 3] / 2
|
||||
corrections[idx1] += correction[:3] * 0.1
|
||||
corrections[idx2] -= correction[:3] * 0.1
|
||||
total_correction += error_trans
|
||||
|
||||
print(f" Iteration {iteration+1}: avg correction = {total_correction / len(valid_edges):.3f}m")
|
||||
|
||||
# 应用校正
|
||||
for i in range(n):
|
||||
if i == 0:
|
||||
continue
|
||||
# 简单的累积校正
|
||||
delta = corrections[i]
|
||||
pose_corrected = poses_gt_np[i].copy()
|
||||
pose_corrected[:3, 3] += delta * 0.5
|
||||
poses_corrected[i] = pose_corrected
|
||||
|
||||
# 计算误差
|
||||
error = np.linalg.norm(poses_corrected[i][:3, 3] - poses_gt_np[i][:3, 3])
|
||||
pose_errors[i] = error
|
||||
|
||||
return poses_corrected, pose_errors
|
||||
|
||||
|
||||
def load_pointcloud(dataset, idx):
|
||||
"""加载指定索引的点云"""
|
||||
scan_path = dataset.scans[idx]
|
||||
scan = np.fromfile(scan_path, dtype=np.float32).reshape((-1, 4))
|
||||
return scan
|
||||
|
||||
|
||||
def transform_points(points, pose):
|
||||
"""将点云从激光雷达坐标系转换到世界坐标系"""
|
||||
points_hom = np.hstack([points[:, :3], np.ones((len(points), 1))])
|
||||
points_world = (pose @ points_hom.T).T
|
||||
return points_world
|
||||
|
||||
|
||||
def visualize_trajectory_comparison(poses_gt, poses_pred, pose_errors, loops, seq_name, save_path=None):
|
||||
"""
|
||||
可视化真值轨迹与预测轨迹对比
|
||||
|
||||
Args:
|
||||
poses_gt: 真值轨迹 (N, 4, 4)
|
||||
poses_pred: 预测轨迹 (N, 4, 4)
|
||||
pose_errors: 每个位姿的误差
|
||||
loops: 检索到的闭环
|
||||
seq_name: 序列名称
|
||||
save_path: 保存路径
|
||||
"""
|
||||
poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else poses_gt
|
||||
poses_pred_np = poses_pred.cpu().numpy() if isinstance(poses_pred, torch.Tensor) else poses_pred
|
||||
|
||||
# 提取轨迹位置
|
||||
traj_gt_x = poses_gt_np[:, 0, 3]
|
||||
traj_gt_y = poses_gt_np[:, 1, 3]
|
||||
traj_gt_z = poses_gt_np[:, 2, 3]
|
||||
|
||||
traj_pred_x = poses_pred_np[:, 0, 3]
|
||||
traj_pred_y = poses_pred_np[:, 1, 3]
|
||||
traj_pred_z = poses_pred_np[:, 2, 3]
|
||||
|
||||
fig = plt.figure(figsize=(18, 12))
|
||||
|
||||
# 1. 2D轨迹对比(XY平面)
|
||||
ax1 = fig.add_subplot(2, 2, 1)
|
||||
ax1.plot(traj_gt_x, traj_gt_y, 'b-', linewidth=1, alpha=0.6, label='Ground Truth')
|
||||
ax1.plot(traj_pred_x, traj_pred_y, 'r--', linewidth=1, alpha=0.6, label='Predicted')
|
||||
ax1.plot(traj_gt_x[0], traj_gt_y[0], 'go', markersize=12, label='Start')
|
||||
ax1.plot(traj_gt_x[-1], traj_gt_y[-1], 'r^', markersize=12, label='End')
|
||||
|
||||
# 绘制闭环连接
|
||||
for loop in loops[:100]:
|
||||
query_idx = loop['query_idx']
|
||||
candidates = loop['candidates']
|
||||
if len(candidates) > 0:
|
||||
best_match = candidates[0]
|
||||
ax1.plot([traj_gt_x[query_idx], traj_gt_x[best_match]],
|
||||
[traj_gt_y[query_idx], traj_gt_y[best_match]],
|
||||
'g-', alpha=0.2, linewidth=0.5)
|
||||
|
||||
ax1.set_xlabel('X (m)')
|
||||
ax1.set_ylabel('Y (m)')
|
||||
ax1.set_title(f'{seq_name} - 2D Trajectory Comparison (XY Plane)')
|
||||
ax1.legend()
|
||||
ax1.axis('equal')
|
||||
ax1.grid(True, linestyle='--', alpha=0.4)
|
||||
|
||||
# 2. 误差沿时间变化
|
||||
ax2 = fig.add_subplot(2, 2, 2)
|
||||
ax2.plot(pose_errors, 'b-', linewidth=0.5, alpha=0.7)
|
||||
ax2.axhline(np.mean(pose_errors), color='r', linestyle='--', label=f'Mean: {np.mean(pose_errors):.3f}m')
|
||||
ax2.axhline(np.median(pose_errors[pose_errors > 0]), color='g', linestyle='--',
|
||||
label=f'Median: {np.median(pose_errors[pose_errors > 0]):.3f}m')
|
||||
ax2.set_xlabel('Frame Index')
|
||||
ax2.set_ylabel('Position Error (m)')
|
||||
ax2.set_title(f'{seq_name} - Localization Error over Time')
|
||||
ax2.legend()
|
||||
ax2.grid(True, linestyle='--', alpha=0.4)
|
||||
|
||||
# 3. XZ平面轨迹
|
||||
ax3 = fig.add_subplot(2, 2, 3)
|
||||
ax3.plot(traj_gt_x, traj_gt_z, 'b-', linewidth=1, alpha=0.6, label='Ground Truth')
|
||||
ax3.plot(traj_pred_x, traj_pred_z, 'r--', linewidth=1, alpha=0.6, label='Predicted')
|
||||
ax3.set_xlabel('X (m)')
|
||||
ax3.set_ylabel('Z (m)')
|
||||
ax3.set_title(f'{seq_name} - Trajectory XZ View')
|
||||
ax3.legend()
|
||||
ax3.axis('equal')
|
||||
ax3.grid(True, linestyle='--', alpha=0.4)
|
||||
|
||||
# 4. 3D轨迹对比
|
||||
ax4 = fig.add_subplot(2, 2, 4, projection='3d')
|
||||
ax4.plot(traj_gt_x, traj_gt_y, traj_gt_z, 'b-', linewidth=1, alpha=0.6, label='Ground Truth')
|
||||
ax4.plot(traj_pred_x, traj_pred_y, traj_pred_z, 'r--', linewidth=1, alpha=0.6, label='Predicted')
|
||||
ax4.scatter(traj_gt_x[0], traj_gt_y[0], traj_gt_z[0], c='green', s=100, marker='o', label='Start')
|
||||
ax4.scatter(traj_gt_x[-1], traj_gt_y[-1], traj_gt_z[-1], c='blue', s=100, marker='^', label='End')
|
||||
|
||||
ax4.set_xlabel('X (m)')
|
||||
ax4.set_ylabel('Y (m)')
|
||||
ax4.set_zlabel('Z (m)')
|
||||
ax4.set_title(f'{seq_name} - 3D Trajectory Comparison')
|
||||
ax4.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Trajectory comparison saved to {save_path}")
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def visualize_error_distribution(poses_gt, poses_pred, seq_name, save_path=None):
|
||||
"""
|
||||
可视化定位误差分布
|
||||
|
||||
Args:
|
||||
poses_gt: 真值轨迹
|
||||
poses_pred: 预测轨迹
|
||||
seq_name: 序列名称
|
||||
save_path: 保存路径
|
||||
"""
|
||||
poses_gt_np = poses_gt.cpu().numpy() if isinstance(poses_gt, torch.Tensor) else poses_gt
|
||||
poses_pred_np = poses_pred.cpu().numpy() if isinstance(poses_pred, torch.Tensor) else poses_pred
|
||||
|
||||
errors = np.linalg.norm(poses_pred_np[:, :3, 3] - poses_gt_np[:, :3, 3], axis=1)
|
||||
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
# 误差直方图
|
||||
ax1 = axes[0]
|
||||
ax1.hist(errors, bins=50, edgecolor='black', alpha=0.7)
|
||||
ax1.axvline(np.mean(errors), color='r', linestyle='--', label=f'Mean: {np.mean(errors):.3f}m')
|
||||
ax1.axvline(np.median(errors), color='g', linestyle='--', label=f'Median: {np.median(errors):.3f}m')
|
||||
ax1.set_xlabel('Position Error (m)')
|
||||
ax1.set_ylabel('Frequency')
|
||||
ax1.set_title(f'{seq_name} - Localization Error Distribution')
|
||||
ax1.legend()
|
||||
ax1.grid(True, linestyle='--', alpha=0.4)
|
||||
|
||||
# 误差累积分布
|
||||
ax2 = axes[1]
|
||||
sorted_errors = np.sort(errors)
|
||||
cumulative = np.arange(1, len(sorted_errors) + 1) / len(sorted_errors)
|
||||
ax2.plot(sorted_errors, cumulative, 'b-', linewidth=2)
|
||||
ax2.axhline(0.5, color='gray', linestyle='--', alpha=0.5)
|
||||
ax2.axhline(0.95, color='gray', linestyle='--', alpha=0.5)
|
||||
ax2.set_xlabel('Position Error (m)')
|
||||
ax2.set_ylabel('Cumulative Probability')
|
||||
ax2.set_title(f'{seq_name} - Error Cumulative Distribution')
|
||||
ax2.grid(True, linestyle='--', alpha=0.4)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Error distribution saved to {save_path}")
|
||||
|
||||
return fig, errors
|
||||
|
||||
|
||||
def visualize_pointcloud_with_pose(dataset, poses, idx_list, save_path=None):
|
||||
"""
|
||||
可视化指定帧的点云和位置
|
||||
"""
|
||||
fig = plt.figure(figsize=(16, 12))
|
||||
num_frames = len(idx_list)
|
||||
|
||||
for i, idx in enumerate(idx_list):
|
||||
ax = fig.add_subplot(2, (num_frames + 1) // 2, i + 1, projection='3d')
|
||||
|
||||
scan = load_pointcloud(dataset, idx)
|
||||
pose = poses[idx]
|
||||
pose_np = pose.cpu().numpy() if isinstance(pose, torch.Tensor) else pose
|
||||
scan_world = transform_points(scan, pose_np)
|
||||
|
||||
if len(scan_world) > 5000:
|
||||
sample_idx = np.random.choice(len(scan_world), 5000, replace=False)
|
||||
scan_world = scan_world[sample_idx]
|
||||
|
||||
ax.scatter(scan_world[:, 0], scan_world[:, 1], scan_world[:, 2],
|
||||
c=scan_world[:, 2], cmap='jet', s=0.5, alpha=0.5)
|
||||
ax.scatter(pose_np[0, 3], pose_np[1, 3], pose_np[2, 3],
|
||||
c='red', s=100, marker='^', label='Current Pose')
|
||||
|
||||
ax.set_xlabel('X (m)')
|
||||
ax.set_ylabel('Y (m)')
|
||||
ax.set_zlabel('Z (m)')
|
||||
ax.set_title(f'Frame {idx}')
|
||||
ax.legend()
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
print(f"Point cloud visualization saved to {save_path}")
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def visualize_bev_with_loops(dataset, poses, idx_list, save_path=None):
|
||||
"""
|
||||
可视化BEV(鸟瞰图)视角下的点云
|
||||
"""
|
||||
fig, ax = plt.subplots(figsize=(12, 12))
|
||||
|
||||
poses_np = poses.cpu().numpy() if isinstance(poses, torch.Tensor) else poses
|
||||
traj_x = poses_np[:, 0, 3]
|
||||
traj_y = poses_np[:, 1, 3]
|
||||
ax.plot(traj_x, traj_y, 'b-', linewidth=1, alpha=0.3, label='Full Trajectory')
|
||||
|
||||
colors = ['red', 'green', 'purple', 'orange', 'cyan']
|
||||
for i, idx in enumerate(idx_list):
|
||||
scan = load_pointcloud(dataset, idx)
|
||||
pose = poses[idx]
|
||||
pose_np = pose.cpu().numpy() if isinstance(pose, torch.Tensor) else pose
|
||||
scan_world = transform_points(scan, pose_np)
|
||||
|
||||
if len(scan_world) > 3000:
|
||||
sample_idx = np.random.choice(len(scan_world), 3000, replace=False)
|
||||
scan_sample = scan_world[sample_idx]
|
||||
else:
|
||||
scan_sample = scan_world
|
||||
|
||||
ax.scatter(scan_sample[:, 0], scan_sample[:, 1],
|
||||
c=colors[i % len(colors)], s=0.5, alpha=0.5,
|
||||
label=f'Frame {idx}')
|
||||
ax.scatter(pose_np[0, 3], pose_np[1, 3],
|
||||
c=colors[i % len(colors)], s=100, marker='^', edgecolors='black')
|
||||
|
||||
ax.set_xlabel('X (m)')
|
||||
ax.set_ylabel('Y (m)')
|
||||
ax.set_title('Bird Eye View - Point Clouds and Trajectory')
|
||||
ax.legend()
|
||||
ax.axis('equal')
|
||||
ax.grid(True, linestyle='--', alpha=0.4)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
print(f"BEV visualization saved to {save_path}")
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def main(args):
|
||||
# 加载配置
|
||||
try:
|
||||
with open(os.path.join(os.getcwd(), "config.yaml"), "r") as ymlfile:
|
||||
cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
|
||||
except:
|
||||
with open(os.path.join(os.getcwd(), "project/FUSIONLCD/config.yaml"), "r") as ymlfile:
|
||||
cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
|
||||
cfg = cfg['experiment']
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and cfg['cuda'] else "cpu")
|
||||
print(f"Using device: {device}")
|
||||
|
||||
sequence_test = [int(x) for x in tools.read_cfg(cfg['test'])]
|
||||
if args.sequence is not None:
|
||||
sequence_to_vis = [args.sequence]
|
||||
else:
|
||||
sequence_to_vis = sequence_test
|
||||
|
||||
output_dir = os.path.join(cfg['path_result'], args.result_name, 'visualization')
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 加载模型
|
||||
model_path = os.path.join(cfg['path_result'], args.result_name, 'models', args.model_name)
|
||||
if not os.path.exists(model_path):
|
||||
model_path = args.model_name
|
||||
print(f"Loading model from: {model_path}")
|
||||
model = load_model(cfg, model_path, device)
|
||||
|
||||
# 加载数据集
|
||||
_, _, loader_test = KittiTotalLoader(cfg)
|
||||
|
||||
for seq in sequence_to_vis:
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Processing Sequence {seq:02d}")
|
||||
print('='*60)
|
||||
|
||||
seq_datasets = [d for d in loader_test.dataset.datasets if int(d.sequence) == seq]
|
||||
if not seq_datasets:
|
||||
print(f"Sequence {seq:02d} not found in test set, skipping...")
|
||||
continue
|
||||
|
||||
dataset = seq_datasets[0]
|
||||
|
||||
# 提取特征
|
||||
print("Extracting features...")
|
||||
seq_vlads = []
|
||||
seq_feas = []
|
||||
seq_kpts = []
|
||||
seq_poses = []
|
||||
|
||||
with torch.no_grad():
|
||||
for i, data in enumerate(loader_test):
|
||||
batch_seqs = data['sequence']
|
||||
if int(batch_seqs[0]) != seq:
|
||||
continue
|
||||
|
||||
bev_query = data['bev_query'].to(device)
|
||||
pose_query = data['pose_query'].to(device)
|
||||
img_query = data['img_query'].to(device)
|
||||
|
||||
try:
|
||||
bev = bev_query.permute(0, 3, 1, 2)
|
||||
except:
|
||||
bev = 0
|
||||
try:
|
||||
img = img_query.permute(0, 3, 1, 2)
|
||||
except:
|
||||
img = 0
|
||||
try:
|
||||
relation = data['relation'].to(device)
|
||||
except:
|
||||
relation = 0
|
||||
|
||||
batch_dict = {
|
||||
'bev': bev,
|
||||
'img': img,
|
||||
'relation': relation,
|
||||
'id_query': data['id_query'],
|
||||
'sequence': batch_seqs,
|
||||
'pose_query': pose_query,
|
||||
'batch_size': len(data['id_query'])
|
||||
}
|
||||
|
||||
model(batch_dict)
|
||||
|
||||
seq_vlads.append(batch_dict['vlads'].detach().cpu())
|
||||
seq_feas.append(batch_dict.get('fea_kpt_fusion', batch_dict.get('fea_kpt_original', batch_dict['vlads'])).detach().cpu())
|
||||
seq_kpts.append(batch_dict['key_points'].detach().cpu())
|
||||
seq_poses.append(pose_query.detach().cpu())
|
||||
|
||||
if not seq_vlads:
|
||||
print(f"No data extracted for sequence {seq:02d}")
|
||||
continue
|
||||
|
||||
vlads = torch.cat(seq_vlads)
|
||||
feas = torch.cat(seq_feas)
|
||||
kpts = torch.cat(seq_kpts)
|
||||
poses = torch.cat(seq_poses)
|
||||
|
||||
# 检索闭环
|
||||
print("Retrieving loops...")
|
||||
loops = retrieve_loops_with_poses(vlads.cuda(), feas.cuda(), kpts.cuda(), poses.cuda(),
|
||||
num_cand=args.num_candidates)
|
||||
|
||||
# 计算校正轨迹
|
||||
print("Computing corrected trajectory...")
|
||||
poses_corrected, pose_errors = compute_corrected_trajectory(poses, loops, feas, kpts, device)
|
||||
|
||||
# 生成可视化
|
||||
seq_name = f"seq_{seq:02d}"
|
||||
|
||||
# 1. 轨迹对比(真值 vs 预测)
|
||||
traj_path = os.path.join(output_dir, f'{seq_name}_trajectory_comparison.png')
|
||||
visualize_trajectory_comparison(poses, poses_corrected, pose_errors, loops, seq_name, save_path=traj_path)
|
||||
|
||||
# 2. 误差分布
|
||||
error_path = os.path.join(output_dir, f'{seq_name}_error_distribution.png')
|
||||
fig_errors, errors = visualize_error_distribution(poses, poses_corrected, seq_name, save_path=error_path)
|
||||
if len(errors) > 0:
|
||||
print(f" Localization errors - Mean: {np.mean(errors):.3f}m, Median: {np.median(errors):.3f}m, Max: {np.max(errors):.3f}m")
|
||||
|
||||
# 3. BEV点云可视化
|
||||
if len(dataset.poses) > args.num_frames:
|
||||
keyframe_indices = np.linspace(0, len(dataset.poses) - 1, args.num_frames, dtype=int)
|
||||
else:
|
||||
keyframe_indices = list(range(len(dataset.poses)))
|
||||
keyframe_indices = [int(i) for i in keyframe_indices]
|
||||
|
||||
bev_path = os.path.join(output_dir, f'{seq_name}_bev.png')
|
||||
visualize_bev_with_loops(dataset, poses, keyframe_indices, save_path=bev_path)
|
||||
|
||||
# 4. 3D点云可视化
|
||||
pc_path = os.path.join(output_dir, f'{seq_name}_pointcloud.png')
|
||||
visualize_pointcloud_with_pose(dataset, poses, keyframe_indices, save_path=pc_path)
|
||||
|
||||
print(f"\nVisualization for sequence {seq:02d} saved to: {output_dir}")
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"All visualizations saved to: {output_dir}")
|
||||
print('='*60)
|
||||
|
||||
if not args.no_show:
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Visualize localization results')
|
||||
parser.add_argument('--result_name', type=str, default='log',
|
||||
help='Result directory name')
|
||||
parser.add_argument('--model_name', type=str, default='checkpoint_fusion.pth.tar',
|
||||
help='Model checkpoint filename')
|
||||
parser.add_argument('--sequence', type=int, default=None,
|
||||
help='Specific sequence to visualize')
|
||||
parser.add_argument('--num_candidates', type=int, default=5,
|
||||
help='Number of loop candidates')
|
||||
parser.add_argument('--num_frames', type=int, default=6,
|
||||
help='Number of key frames for point cloud')
|
||||
parser.add_argument('--no_show', action='store_true',
|
||||
help='Do not display plots')
|
||||
parser.add_argument('--gpu', type=str, default=None,
|
||||
help='GPU device id')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.gpu:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
|
||||
|
||||
main(args)
|
||||
Reference in New Issue
Block a user