Files
fusion_LCD/loss.py
2026-03-04 20:07:57 +08:00

340 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning import distances
def tr_loss(batch_dict, key):
    """Return translation and rotation errors of a predicted pose.

    Args:
        batch_dict: mapping holding the predicted transform under ``key`` and
            the reference transform under ``'pose_to_frame'``. Both are
            indexed as ``[:, 0:3, 3]`` and ``[:, 0, 0]``, i.e. batched
            matrices with at least 3 rows and 4 columns.
        key: name of the predicted transform inside ``batch_dict``.

    Returns:
        ``(translation_error, rotation_error)``. The translation error is the
        batch mean of the Euclidean norm of the difference between the two
        translation columns. The rotation error is the batch mean of
        ``|acos(pred[0, 0]) - acos(gt[0, 0])|`` converted to degrees.
    """
    pred = batch_dict[key]
    gt = batch_dict['pose_to_frame']

    trans_err = (pred[:, 0:3, 3] - gt[:, 0:3, 3]).norm(dim=1).mean()

    # Only element [0, 0] of each matrix is inspected.
    # NOTE(review): this equals the yaw angle only for a pure rotation about
    # the z axis -- confirm that is the intended setting.
    # Clipping keeps acos finite when round-off pushes the entry past +-1.
    pred_angle = torch.acos(torch.clip(pred[:, 0, 0].view(-1, 1), -1, 1))
    gt_angle = torch.acos(torch.clip(gt[:, 0, 0].view(-1, 1), -1, 1))

    # Radians -> degrees with full-precision pi (a truncated 3.1415 biases
    # the reported angle by roughly 0.003 %).
    rot_err = (pred_angle - gt_angle).norm(dim=1).mean() / np.pi * 180
    return trans_err, rot_err
def gen_points_loss(batch_dict):
    """Mean absolute xy error between pose-aligned real and generated key points.

    Reads ``'key_points_gen'``, ``'key_points'`` and ``'pose_to_frame'`` from
    ``batch_dict``. The first ``B`` entries of each point tensor (``B`` being
    the number of poses) are treated as sources and the rest as targets.
    Real sources are compared with generated targets, and generated sources
    with real targets, after the sources are transformed by the pose.

    NOTE(review): the code assumes generated points carry 2 channels and real
    points 4 (so that the concatenation below yields 4 channels) -- confirm
    against the producer of ``batch_dict``.

    Returns:
        Scalar tensor: mean of ``|transformed_source - target|`` over the
        first two coordinates.
    """
    generated = batch_dict['key_points_gen']
    real = batch_dict['key_points']

    # Append a zeroed copy of the generated channels, then force channel 3
    # to 1 so the points can be multiplied by the pose matrices.
    generated_h = torch.cat((generated, generated * 0), dim=2)
    generated_h[:, :, 3] = 1

    pose = batch_dict['pose_to_frame']
    n_pairs = pose.shape[0]

    # First half of the stacked batch: real source vs. generated target.
    # Second half: generated source vs. real target.
    sources = torch.cat((real[:n_pairs], generated_h[:n_pairs]), dim=0)
    targets = torch.cat((generated_h[n_pairs:], real[n_pairs:]), dim=0)
    stacked_pose = torch.cat((pose, pose), dim=0)

    warped = torch.bmm(stacked_pose, sources.permute(0, 2, 1)).permute(0, 2, 1)
    xy_residual = warped[:, :, :2] - targets[:, :, :2]
    return torch.mean(torch.abs(xy_residual))
def rand_dis(x, y):
    """Compare two 2-D matrices at one random off-diagonal entry per row.

    For every row ``i`` a column ``j != i`` is drawn uniformly at random.
    The result is the mean of ``|x[i, j] - y[i, j]|`` and ``|y[i, i]|``
    (the diagonal of ``y`` is compared against zero), plus a hinge term
    ``relu(0.2 - |y[i, j]|)`` averaged over the rows.

    Args:
        x, y: 2-D tensors; the row/column count is taken from ``x``.

    Returns:
        Scalar tensor with the combined penalty.
    """
    assert len(x.shape)==2 and len(y.shape)==2,'x and y must be 2 dim'
    _, n = x.size()
    rows = torch.arange(n).to(x.device)

    # Draw from n - 1 slots and skip over the diagonal: a draw at or beyond
    # the row index is shifted right by one, so column i is never selected.
    draws = torch.randint(n - 1, size=(n,)).to(x.device)
    cols = draws + (draws >= rows).to(draws.dtype)

    x_off = x[rows, cols]
    y_off = y[rows, cols]
    y_diag = y[rows, rows]

    # Second column pairs a zero (derived from x) with the diagonal of y.
    lhs = torch.stack([x_off, x_off * 0], dim=1)
    rhs = torch.stack([y_off, y_diag], dim=1)

    dis = torch.abs(lhs - rhs).mean() + F.relu(0.2 - torch.abs(y_off)).mean()
    return dis
def gen_feature_loss(batch_dict):
    """Cosine-distance losses between original and generated features.

    Reads from ``batch_dict``:
        * ``'fea_pt_dual'`` / ``'fea_pt_dual_gen'``: matched point-cloud
          features and their counterpart generated from the image
          (per the original author's comments).
        * ``'fea_pl_dual'`` / ``'fea_pl_dual_gen'``: matched image features
          and their counterpart generated from the point cloud
          (per the original author's comments).
        * ``'fea_kpt_original'`` / ``'fea_kpt_original_gen'``: key-point
          features and their generated counterpart.
        * ``'relation'``: indexed as ``relation[i, :, -1, 0]`` and
          ``relation[i, :, -1, 1]``; the number of rows where both are
          positive is the number of leading columns compared per sample.

    NOTE(review): feature tensors are indexed as ``[i, :, :cnt]`` and the
    similarity is taken over dim 0 of each slice, so they are presumably
    (batch, channel, n) -- confirm with the model code.

    Returns:
        ``(loss0, loss1, loss2, loss3)``: batch-averaged ``1 - cosine`` for
        the point-cloud pair, the image pair and the key-point pair;
        ``loss3`` is always 0.
    """
    fea_pt_dual_gen = batch_dict['fea_pt_dual_gen']
    fea_pl_dual_gen = batch_dict['fea_pl_dual_gen']
    fea_kpt_original_gen = batch_dict['fea_kpt_original_gen']
    fea_pt_dual = batch_dict['fea_pt_dual']
    fea_pl_dual = batch_dict['fea_pl_dual']
    fea_kpt_original = batch_dict['fea_kpt_original']
    relation = batch_dict['relation']

    b = fea_pl_dual.shape[0]
    loss0 = 0
    loss1 = 0
    loss2 = 0
    loss3 = 0  # placeholder slot, never updated

    for i in range(b):
        # Number of valid matches for this sample.
        cnt = int(torch.sum((relation[i, :, -1, 0] > 0) & (relation[i, :, -1, 1] > 0)))

        # With zero matches the slices are empty and the mean of an empty
        # tensor is NaN, which would contaminate the whole batch loss.
        # Such a sample contributes nothing to the paired terms.
        if cnt > 0:
            pt_real = fea_pt_dual[i, :, :cnt]         # matched point-cloud features
            pt_fake = fea_pt_dual_gen[i, :, :cnt]     # same, generated from the image
            pl_real = fea_pl_dual[i, :, :cnt]         # matched image features
            pl_fake = fea_pl_dual_gen[i, :, :cnt]     # same, generated from the point cloud
            loss0 = loss0 + (1 - F.cosine_similarity(pt_real, pt_fake, dim=0)).mean()
            loss1 = loss1 + (1 - F.cosine_similarity(pl_real, pl_fake, dim=0)).mean()

        # Loss of the panoramic feature generation module (original comment):
        # all key-point features of the sample are compared.
        loss2 = loss2 + (1 - F.cosine_similarity(fea_kpt_original[i], fea_kpt_original_gen[i], dim=0)).mean()

    loss0 = loss0 / b
    loss1 = loss1 / b
    loss2 = loss2 / b
    return loss0, loss1, loss2, loss3
def sinkhorn_matches_loss(batch_dict, key):
    """Mean distance between projected key points and their ground-truth position.

    Args:
        batch_dict: mapping with ``'key_points'`` (reshaped here to
            ``(batch_size, -1, 4)``), ``'batch_size'``, ``'pose_to_frame'``
            and the projected key points stored under ``key``.
        key: name of the projected key points (computed from key-point
            correspondences, per the original comment).

    Returns:
        Scalar tensor: mean Euclidean distance between the pose-transformed
        first half of the key points and ``batch_dict[key]``.
    """
    predicted = batch_dict[key]
    raw_points = batch_dict['key_points']
    gt_pose = batch_dict['pose_to_frame']

    # Work on a private copy so the caller's key points are left untouched.
    points = raw_points.clone().view(batch_dict['batch_size'], -1, 4)
    n_pairs = points.shape[0] // 2

    # Only the first half of the batch is transformed; its last channel is
    # overwritten with 1 to form homogeneous coordinates.
    source_h = points[:n_pairs]
    source_h[:, :, -1] = 1.

    expected = torch.bmm(gt_pose, source_h.permute(0, 2, 1)).permute(0, 2, 1)
    expected_xyz = expected[:, :, :3]
    return (expected_xyz - predicted).norm(dim=2).mean()
def score_loss(batch_dict):
    """MSE between predicted and labelled scores where the prediction is non-zero.

    Reads ``'score_bev'`` and ``'label_score'`` from ``batch_dict``. The two
    slices ``label_score[:, :, :, 0]`` and ``label_score[:, :, :, 1]`` are
    stacked along dim 0 before being compared with the prediction, and only
    positions whose predicted score exceeds ``1e-8`` take part in the loss.

    Returns:
        Scalar tensor with the mean squared error over the selected positions.
    """
    predicted = batch_dict['score_bev']
    labels = batch_dict['label_score']

    # Fold the last label axis into the batch axis: slice 0 first, then 1.
    stacked_labels = torch.cat([labels.select(3, 0), labels.select(3, 1)], dim=0)

    # Selection is driven by the prediction alone.
    active = predicted > 1e-8
    return F.mse_loss(predicted[active], stacked_labels[active])
def pose_loss(batch_dict, key):
    """Mean absolute difference between points moved by predicted and reference poses.

    Args:
        batch_dict: mapping with ``'key_points'`` (reshaped here to
            ``(batch_size, -1, 4)``), ``'batch_size'``, the reference
            transform ``'pose_to_frame'`` and the predicted transform
            stored under ``key``.
        key: name of the predicted transform.

    Returns:
        Scalar tensor: mean of ``|pred_xyz - gt_xyz|`` over the first half of
        the key points. The fourth channel of the key points is used as
        stored; it is not reset here.
    """
    points = batch_dict['key_points'].clone().view(batch_dict['batch_size'], -1, 4)
    reference_pose = batch_dict['pose_to_frame']

    # Only the first half of the batch is transformed.
    n_pairs = points.shape[0] // 2
    source_t = points[:n_pairs].permute(0, 2, 1)

    reference = torch.bmm(reference_pose, source_t).float()
    reference_xyz = reference.permute(0, 2, 1)[:, :, :3]

    predicted = torch.bmm(batch_dict[key], source_t)
    predicted_xyz = predicted.permute(0, 2, 1)[:, :, :3]

    return (predicted_xyz - reference_xyz).abs().mean()
def get_all_triplets(dist_mat, pos_mask, neg_mask, is_inverted=False, margin=0.5, different_embedding=False):
    """Enumerate every (anchor, positive, negative) index triplet.

    Args:
        dist_mat, is_inverted, margin: accepted for signature compatibility
            with the other selectors; not used here.
        pos_mask: matrix whose non-zero entries ``[a, p]`` mark positive pairs.
        neg_mask: matrix whose non-zero entries ``[a, n]`` mark negative pairs.
        different_embedding: when false, only the strict upper triangle of
            ``pos_mask`` is kept so each unordered pair is used once.

    Returns:
        Tuple ``(anchors, positives, negatives)`` of index tensors.
    """
    anchor_positive = pos_mask if different_embedding else torch.triu(pos_mask, 1)
    # Entry [a, p, n] is non-zero iff (a, p) is positive and (a, n) is negative.
    valid = anchor_positive[:, :, None] * neg_mask[:, None, :]
    return torch.nonzero(valid, as_tuple=True)
def hardest_negative_selector(dist_mat, pos_mask, neg_mask, is_inverted, margin=0.5, different_embedding=False):
    """Pick, for every anchor of a positive pair, its hardest negative.

    Args:
        dist_mat: pairwise distance (or similarity) matrix.
        pos_mask: matrix whose non-zero entries ``[a, p]`` mark positive pairs.
        neg_mask: matrix whose non-zero/True entries ``[a, n]`` mark negatives.
        is_inverted: true when larger values in ``dist_mat`` mean "closer"
            (a similarity); the hardest negative is then the row maximum
            instead of the row minimum.
        margin: unused; kept for signature parity with the other selectors.
        different_embedding: when false, only the strict upper triangle of
            ``pos_mask`` is used.

    Returns:
        ``(anchors, positives, negatives)`` index tensors, or
        ``(anchors, positives, None)`` when ``neg_mask`` has no negatives.
    """
    if not different_embedding:
        pos_mask = torch.triu(pos_mask, 1)
    a, p = torch.where(pos_mask)
    if neg_mask.sum() == 0:
        return a, p, None

    if is_inverted:
        # Mask non-negatives with -inf rather than multiplying by the mask:
        # a zeroed entry would beat real negatives with similarity below 0.
        dist_neg = dist_mat.masked_fill(~neg_mask.bool(), float('-inf'))
        # torch.max(dim=...) returns (values, indices); only indices are needed.
        _, n = torch.max(dist_neg, dim=1)
    else:
        dist_neg = dist_mat.clone()
        # Push non-negatives above every real distance so argmin skips them.
        dist_neg[~neg_mask] = dist_neg.max() + 1.
        _, n = torch.min(dist_neg, dim=1)

    # One hardest negative per row; keep the rows that are anchors.
    n = n[a]
    return a, p, n
def random_negative_selector(dist_mat, pos_mask, neg_mask, is_inverted, margin=0.5, different_embedding=False):
    """Pick a random margin-violating negative for every positive pair.

    For each (anchor, positive) pair the candidate negatives are those marked
    in the anchor's row of ``neg_mask``. Candidates whose triplet term
    (``d_ap - d_an + margin``, or ``-d_ap + d_an + margin`` when
    ``is_inverted``) is positive are preferred; if none is, every candidate
    stays eligible. One index is drawn with ``np.random.choice``.

    Args:
        dist_mat: pairwise distance (or similarity) matrix.
        pos_mask: matrix whose non-zero entries mark positive pairs.
        neg_mask: matrix whose non-zero/True entries mark negatives.
        is_inverted: true when ``dist_mat`` holds similarities.
        margin: triplet margin used to detect violations.
        different_embedding: when false, only the strict upper triangle of
            ``pos_mask`` is used.

    Returns:
        ``(anchors, positives, negatives)`` index tensors, or
        ``(anchors, positives, None)`` as soon as an anchor has no negative.
    """
    if not different_embedding:
        pos_mask = torch.triu(pos_mask, 1)
    a, p = torch.where(pos_mask)

    picked = []
    for anchor, positive in zip(a, p):
        candidates = torch.where(neg_mask[anchor])[0]
        if candidates.numel() == 0:
            return a, p, None

        d_ap = dist_mat[anchor, positive]
        d_an = dist_mat[anchor, candidates]
        violation = (-d_ap + d_an + margin) if is_inverted else (d_ap - d_an + margin)

        # Restrict to violating negatives when there is at least one.
        violating = violation > 0
        if violating.any():
            candidates = candidates[violating]

        picked.append(np.random.choice(candidates.cpu().numpy()))

    n = torch.tensor(picked, dtype=a.dtype, device=a.device)
    return a, p, n
def semihard_negative_selector(dist_mat, pos_mask, neg_mask, is_inverted, margin=0.5, different_embedding=False):
    """Pick a random semi-hard negative for every positive pair.

    A candidate negative is semi-hard when its triplet term
    (``d_ap - d_an + margin``, or ``-d_ap + d_an + margin`` when
    ``is_inverted``) lies strictly between 0 and ``margin``. If an anchor has
    no semi-hard candidate, all of its negatives stay eligible. One index is
    drawn with ``np.random.choice``.

    Args:
        dist_mat: pairwise distance (or similarity) matrix.
        pos_mask: matrix whose non-zero entries mark positive pairs.
        neg_mask: matrix whose non-zero/True entries mark negatives.
        is_inverted: true when ``dist_mat`` holds similarities.
        margin: triplet margin bounding the semi-hard band.
        different_embedding: when false, only the strict upper triangle of
            ``pos_mask`` is used.

    Returns:
        ``(anchors, positives, negatives)`` index tensors, or
        ``(anchors, positives, None)`` as soon as an anchor has no negative.
    """
    if not different_embedding:
        pos_mask = torch.triu(pos_mask, 1)
    a, p = torch.where(pos_mask)

    picked = []
    for anchor, positive in zip(a, p):
        candidates = torch.where(neg_mask[anchor])[0]
        if candidates.numel() == 0:
            return a, p, None

        d_ap = dist_mat[anchor, positive]
        d_an = dist_mat[anchor, candidates]
        violation = (-d_ap + d_an + margin) if is_inverted else (d_ap - d_an + margin)

        # Semi-hard band: violating the margin, but by less than the margin.
        in_band = (violation > 0) & (violation < margin)
        if in_band.any():
            candidates = candidates[in_band]

        picked.append(np.random.choice(candidates.cpu().numpy()))

    n = torch.tensor(picked, dtype=a.dtype, device=a.device)
    return a, p, n
class TripletLoss(nn.Module):
    """Triplet margin loss driven by a pluggable triplet selector.

    Args:
        margin: constant added to the distance margin before the ReLU.
        triplet_selector: callable invoked as
            ``selector(dist_mat, pos_mask, neg_mask, is_inverted)`` and
            returning ``(anchors, positives, negatives)``; ``negatives`` may
            be ``None`` when no negative is available.
        distance: distance object from ``pytorch_metric_learning``; it is
            called on the embeddings and its ``is_inverted`` attribute and
            ``margin`` method are used.
    """

    def __init__(self, margin: float, triplet_selector, distance: distances.BaseDistance):
        super().__init__()
        self.margin = margin
        self.triplet_selector = triplet_selector
        self.distance = distance

    def forward(self, embeddings, pos_mask, neg_mask, other_embeddings=None):
        """Return the mean triplet loss over the selected triplets.

        When ``other_embeddings`` is omitted the distance matrix is computed
        between ``embeddings`` and itself. If the selector reports no
        negatives, only the positive pairs are penalized.
        """
        reference = embeddings if other_embeddings is None else other_embeddings
        dist_mat = self.distance(embeddings, reference)

        selection = self.triplet_selector(
            dist_mat, pos_mask, neg_mask, self.distance.is_inverted)
        anchors, positives = selection[0], selection[1]
        d_pos = dist_mat[anchors, positives]

        if selection[-1] is None:
            # No negatives: pull positives together (towards 1 for an
            # inverted distance, towards 0 otherwise).
            positive_term = 1 - d_pos if self.distance.is_inverted else d_pos
            return F.relu(positive_term).mean()

        d_neg = dist_mat[anchors, selection[2]]
        # NOTE(review): distance.margin is presumably the signed gap between
        # positive and negative distances -- see pytorch_metric_learning.
        gap = self.distance.margin(d_pos, d_neg)
        return F.relu(gap + self.margin).mean()
def _pairwise_distance(x, squared=False, eps=1e-16):
# Compute the 2D matrix of distances between all the embeddings.
cor_mat = torch.matmul(x, x.t())
norm_mat = cor_mat.diag()
distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0)
distances = F.relu(distances)
if not squared:
mask = torch.eq(distances, 0.0).float()
distances = distances + mask * eps
distances = torch.sqrt(distances)
distances = distances * (1.0 - mask)
return distances
class TotalLoss(nn.Module):
    """Aggregate all training losses and store them in ``batch_dict['loss']``.

    Args:
        cfg: mapping with ``'negetative_selsector'`` (``'hardest'``,
            ``'semihard'`` or anything else for random selection) and
            ``'trip_margin'`` (margin of the triplet loss).
    """

    def __init__(self, cfg):
        super(TotalLoss, self).__init__()
        selector_name = cfg['negetative_selsector']
        if selector_name == 'hardest':
            neg_selector = hardest_negative_selector
        elif selector_name == 'semihard':
            neg_selector = semihard_negative_selector
        else:
            neg_selector = random_negative_selector
        self.trip_fun = TripletLoss(margin=cfg['trip_margin'], triplet_selector=neg_selector, distance=distances.LpDistance())
        # Query positions further apart than this are treated as negatives.
        self.negetative_distcance = 50

    def forward(self, batch_dict):
        """Compute every loss term available for this batch.

        Terms whose inputs are missing from ``batch_dict`` stay at zero.
        The result is written to ``batch_dict['loss']`` as the list
        ``[total, pose, score, match, triplet, tra, rot, gb, gi, gpa, gpo, kpl]``
        and ``batch_dict`` is returned. ``tra``, ``rot`` and ``gpo`` are
        reported but are not part of ``total``.
        """
        l_pose = l_score = l_match = l_tra = l_rot = 0
        l_gb = l_gi = l_gpa = l_gpo = l_kpl = 0

        if 'key_points' in batch_dict:
            l_score = score_loss(batch_dict)

            # Pose-related terms are averaged over the transformation
            # branches actually present. Counting the branches directly
            # (rather than testing whether both rotation errors are > 0)
            # keeps the average correct when a branch has zero rotation error.
            branches = (
                ('transformation_original', 'project_kpts_original'),
                ('transformation_fusion', 'project_kpts_fusion'),
            )
            n_branches = 0
            for pose_key, kpts_key in branches:
                if pose_key not in batch_dict:
                    continue
                n_branches += 1
                l_match = l_match + sinkhorn_matches_loss(batch_dict, kpts_key)
                branch_tra, branch_rot = tr_loss(batch_dict, pose_key)
                l_tra = l_tra + branch_tra
                l_rot = l_rot + branch_rot
                l_pose = l_pose + pose_loss(batch_dict, pose_key)
            denom = max(n_branches, 1)
            l_match = l_match / denom
            l_pose = l_pose / denom
            l_tra = l_tra / denom
            l_rot = l_rot / denom

        if ('fea_pt_dual_gen' in batch_dict) or ('fea_pl_dual_gen' in batch_dict):
            l_gb, l_gi, l_gpa, l_kpl = gen_feature_loss(batch_dict)
        if 'key_points_gen' in batch_dict:
            l_gpo = gen_points_loss(batch_dict)

        # Negative mask over the queries: different sequence, or further
        # apart than the distance threshold.
        query_pose = batch_dict['pose_query']
        if 'sequence' in batch_dict:
            sequence = batch_dict['sequence']
            neg_mask = sequence.view(1, -1) != sequence.view(-1, 1)
        else:
            # Must match the (n_query, n_query) distance matrix below; the
            # mask is tiled to the full batch afterwards.
            n_query = query_pose.shape[0]
            neg_mask = torch.zeros((n_query, n_query), dtype=torch.bool)
        pair_dist = _pairwise_distance(query_pose[:, 0:3, 3])
        neg_mask = (pair_dist > self.negetative_distcance) | neg_mask.to(pair_dist.device)
        # Tile so the mask also covers the second half of the batch.
        neg_mask = neg_mask.repeat(2, 2)

        # Sample i and sample i + batch_size // 2 form a positive pair.
        batch_size = batch_dict['batch_size']
        half = batch_size // 2
        pos_mask = torch.zeros((batch_size, batch_size), device=neg_mask.device)
        pair_idx = torch.arange(half, device=neg_mask.device)
        pos_mask[pair_idx, pair_idx + half] = 1
        pos_mask[pair_idx + half, pair_idx] = 1

        l_triplet = self.trip_fun(batch_dict['vlads'], pos_mask, neg_mask)
        l_total = l_score + l_pose + 0.05 * l_match + l_triplet + (l_gb + l_gi + l_gpa + l_kpl)

        loss = [l_total, l_pose, l_score, l_match, l_triplet, l_tra, l_rot, l_gb, l_gi, l_gpa, l_gpo, l_kpl]
        # Replace plain zeros by a zero derived from the total loss,
        # presumably so every entry has the same type/device for logging.
        for i in range(len(loss)):
            if loss[i] == 0:
                loss[i] = loss[0] * 0
        batch_dict['loss'] = loss
        return batch_dict