Files
fusion_LCD/evaluate_lcd.py
2026-03-04 20:07:57 +08:00

266 lines
10 KiB
Python

import os
import time
import matplotlib
# set non-interactive backend for server (must be set before pyplot import)
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
from skimage.measure import ransac
from skimage.transform import EuclideanTransform
import tools
from tqdm import tqdm
from sklearn.metrics import auc
from sklearn.neighbors import KDTree
import warnings
warnings.filterwarnings("ignore")
def recall_with_candidates(vlads, poses, sequence, recall_num=25, positive_distance=4):
recall_at_k = [0] * recall_num
num_with_loop = 0
if __name__ == '__main__':
flag = False
else:
flag = True
for i in tqdm(range(0, len(vlads)), disable=flag, ncols=60, desc='Recall@k'):
valid_idx = list(set(range(0, len(vlads))) - set(range(max(0, i - 50), min(len(vlads), i + 50))))
valid_idx = torch.tensor(valid_idx).to(vlads.device)
vlad_query = vlads[i].view(1, -1)
vlad_valid = vlads[valid_idx]
dis_valid = torch.linalg.norm((poses[i:i + 1, 0:3, 3] - poses[valid_idx, 0:3, 3]), dim=1)
min_dis = torch.min(dis_valid)
if min_dis > positive_distance:
continue
num_with_loop = num_with_loop + 1
# global feature to query quickly
dis_vlad = torch.cdist(vlad_query, vlad_valid).view(-1, )
dis, idx_cand = torch.topk(dis_vlad, recall_num, largest=False)
idx_cand = valid_idx[idx_cand]
for j in range(recall_num):
idx_cand1 = idx_cand[j]
dis = torch.linalg.norm((poses[i:i + 1, 0:3, 3] - poses[idx_cand1, 0:3, 3]), dim=1)
if dis <= positive_distance:
recall_at_k[j] = recall_at_k[j] + 1
break
time.sleep(1)
recall_at_k = np.cumsum(recall_at_k) / float(num_with_loop)
print('Sequence %02d, Recall@' % sequence, end='')
for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 14, 19, 24]:
if i == (len(recall_at_k) - 1):
print('%d[%.3f]' % (i + 1, recall_at_k[i]))
else:
print('%d[%.3f]' % (i + 1, recall_at_k[i]), end=', ')
return recall_at_k
def retrieve(vlads, feas, kpts, poses, num_cand=1, verify='ransac'):
loops = []
if __name__ == '__main__':
flag = False
else:
flag = True
ts = []
for i in tqdm(range(0, len(feas)), disable=flag, ncols=60, desc='retrieve loop'):
t0 = time.time()
valid_idx = list(set(range(0, len(feas))) - set(range(max(0, i - 50), min(len(feas), i + 50))))
valid_idx = torch.tensor(valid_idx).to(vlads.device)
vlad_query = vlads[i].view(1, -1)
vlad_valid = vlads[valid_idx]
# global feature to query quickly
dis_vlad = torch.cdist(vlad_query, vlad_valid).view(-1, )
dis, idx_cand = torch.topk(dis_vlad, num_cand, largest=False)
t_retrieve = time.time() - t0
idx_cand = valid_idx[idx_cand]
# local feature to qverify
fea_query = feas[i]
if verify == 'ransac':
p1 = kpts[i]
p1 = p1.cpu().detach().numpy()
min_dis = torch.tensor([9999])
idx_detect = idx_cand[0]
dis_truth = torch.tensor([9999])
for idx_cand1 in idx_cand:
fea_cand1 = feas[idx_cand1]
p2 = kpts[idx_cand1].cpu().detach().numpy()
idx1, idx2, dis = tools.nn_match(fea_query, fea_cand1, 'cosine')
if len(idx1) < 31:
continue
idx1 = idx1.cpu().detach().numpy()
idx2 = idx2.cpu().detach().numpy()
try:
model, inliers = ransac((p1[idx1, 0:2], p2[idx2, 0:2]), model_class=EuclideanTransform, min_samples=15, max_trials=3, residual_threshold=1)
num_inlier = np.sum(inliers)
# r = model.params[0:2, 0:2]
dis_estimate = np.linalg.norm(model.params[0:2, 2])
# rot = model.rotation
if num_inlier > 30: # ransac存在足够内点
if min_dis > dis_estimate:
min_dis = dis_estimate
idx_detect = idx_cand1
dis_truth = torch.linalg.norm((poses[i, 0:3, 3] - poses[idx_detect, 0:3, 3]))
except:
pass
loops.append([i, idx_detect.item(), min_dis.item(), dis_truth.item()])
else:
idx_detect = idx_cand[0]
dis_truth = torch.linalg.norm(poses[i, 0:3, 3] - poses[idx_detect, 0:3, 3])
loops.append([i, idx_detect.item(), dis[0].item(), dis_truth.item()])
t_verify = time.time() - t0 - t_retrieve
ts.append([t_retrieve, t_verify])
# if loops[-1][2] < 4 and loops[-1][1] < i:
# loop1.append(loops[-1][1])
# x = poses[:, 0, 3]
# y = poses[:, 1, 3]
# x1 = x[loop1]
# y1 = y[loop1]
# plt.plot(x, y, 'b.', markersize=1)
# plt.plot(x1, y1, 'ro', markersize=2, markerfacecolor='none')
# plt.axis('equal')
# plt.show()
ts = np.array(ts) * 1000
# np.savetxt('times.txt', ts)
# x=np.arange(len(ts))
# plt.plot(x,ts[:,0],'b.')
# plt.plot(x,ts[:,1],'r.')
# plt.show()
loops = np.array(loops)
return loops
def pr_curve(poses, loops, sequence, positive_distance=4):
map_tree_poses = KDTree(poses[:, 0:3, 3])
reverse_loops = []
real_loop = []
for i in range(0,len(poses)):
min_range = max(0, i - 50)
max_range = min(i + 50, poses.shape[0])
current_pose = poses[i]
indices = map_tree_poses.query_radius(np.expand_dims(current_pose[0:3, 3], 0), positive_distance)
valid_idxs = list(set(indices[0]) - set(range(min_range, max_range)))
valid_idxs = np.array(valid_idxs)
if len(valid_idxs) > 0:
# dis = np.linalg.norm(current_pose[0:3, 3]-poses[valid_idxs,0:3,3],axis=1)
real_loop.append(1)
r0 = poses[i, :3, :3]
rs = poses[valid_idxs, :3, :3]
dr = np.linalg.inv(r0) @ rs.swapaxes(0, 2)
angle = np.arccos(np.clip((np.trace(dr) - 1) / 2, -1, 1))
angle = angle * 180 / np.pi
if np.min(angle) > 90:
reverse_loops.append(1)
else:
reverse_loops.append(0)
else:
real_loop.append(0)
reverse_loops.append(0)
reverse_loops = np.array(reverse_loops)
real_loop = np.array(real_loop)
# loops=np.hstack((loops,real_loop.reshape(-1,1)))
# np.savetxt('loops_bev%02d.txt'%sequence,loops,fmt='%.6f')
# print('sequence %d, %d frames, %d loops, %d reverse loops' % (sequence,len(real_loop), np.sum(real_loop), np.sum(reverse_loops)))
# # return 0
distances = loops[:, 3]
detected_loop = loops[:, 2]
precision2 = [1]
recall2 = [0]
for thr in np.unique(detected_loop):
tp = detected_loop <= thr
tp = tp & real_loop
tp = tp & (distances <= positive_distance)
tp = tp.sum()
fp = (detected_loop <= thr).sum() - tp
fn = (real_loop.sum()) - tp
if (tp + fp) > 0.:
precision2.append(tp / (tp + fp))
else:
precision2.append(1.)
recall2.append(tp / (tp + fn))
f1s = []
for i in range(len(recall2)):
f1s.append((2 * precision2[i] * recall2[i]) / (precision2[i] + recall2[i]))
f1 = max(f1s)
recall_p1 = np.max(np.array(recall2)[np.array(precision2) == 1])
# plt.plot(recall2, precision2, 'b-')
# plt.show()
pr = np.array(precision2 + recall2).reshape(2, -1).T
# np.save('fusion_pr_%02d.npy' % sequence, pr)
ap = auc(recall2, precision2)
idx=loops[:,2]<9999
loops1=loops[idx]
rp=np.sum(np.abs(loops1[:,2]-loops1[:,3])<2)/len(loops1)
print('Sequence %02d, AP %.3f, Recall@100 %.3f, F1 %.3f, RP %.3f/%d' % (sequence, ap, recall_p1, f1, rp, len(loops1)))
# if ap<0.1:
# exit()
# --- save PR curve to file (for server usage) ---
try:
out_dir = os.path.join('/home/adlab36/chenyouyuan/FUSIONLCD', 'result', 'plots')
os.makedirs(out_dir, exist_ok=True)
plt.figure(figsize=(6, 5))
plt.plot(recall2, precision2, 'b-', marker='o', linewidth=2)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title(f'Sequence {int(sequence):02d} PR (AP={ap:.3f})')
plt.grid(True, linestyle='--', alpha=0.4)
fname = os.path.join(out_dir, f'pr_sequence_{int(sequence):02d}.png')
plt.tight_layout()
plt.savefig(fname, dpi=150)
plt.close()
except Exception as e:
# 如果保存失败也不要阻塞主流程
print(f'Warning: failed to save PR plot for sequence {sequence}: {e}')
return ap, recall_p1, f1
def lcd(data):
vlads = data['vlads'].cuda()
kpts = data['key_points']
sequences = data['sequences']
poses = data['pose_query'].cuda()
feas = data['fea_kpt'].cuda()
# feas = feas / torch.sqrt(torch.sum(feas ** 2, -1, keepdim=True) + 1e-8)
result = []
recall_at_ks = []
recall_at_k=[]
for s in torch.unique(sequences):
# if s==54:
# continue
mask = sequences == s
vlads1 = vlads[mask]
feas1 = feas[mask]
kpts1 = kpts[mask]
poses1 = poses[mask]
poses2 = poses1.cpu().detach().numpy()
# recall_at_k = recall_with_candidates(vlads1, poses1, s)
# idx=np.arange(len(vlads1)//2)
# idx=np.tile(idx, 2)
# vlads1, feas1, kpts1, poses1 =vlads1[idx], feas1[idx], kpts1[idx], poses1[idx]
loops = retrieve(vlads1, feas1, kpts1, poses1, 1, 'ransac')
ap, recall_p1, f1 = pr_curve(poses2, loops, s, 4)
recall_at_ks.append(recall_at_k)
result.append([ap, recall_p1, f1])
return result, recall_at_ks
if __name__ == '__main__':
np.random.seed(123)
# data = torch.load('/data4/caodanyang/results/FUSIONLCD/07030/database/database_bev.pth.tar')
# lcd(data)
print('----------------------------------------------------------------------')
data= torch.load('/home/adlab36/chenyouyuan/FUSIONLCD/result/log/database/database_bevp.pth.tar')
lcd(data)
print('----------------------------------------------------------------------')
# data=torch.load('/data4/caodanyang/results/FUSIONLCD/07030/database/database_fusion.pth.tar')
# lcd(data)