340 lines
15 KiB
Python
340 lines
15 KiB
Python
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from pytorch_metric_learning import distances
|
||
|
||
|
||
def tr_loss(batch_dict, key):
    """Translation and rotation error between a predicted and a reference pose.

    Args:
        batch_dict: Mapping holding the predicted transforms under ``key`` and
            the reference transforms under ``'pose_to_frame'``. Both are read
            as ``[:, 0:3, 3]`` and ``[:, 0, 0]`` (presumably batched 4x4
            homogeneous matrices -- confirm against the caller).
        key: Name of the ``batch_dict`` entry holding the predicted transforms.

    Returns:
        Tuple ``(translation_error, rotation_error)``:
        the mean L2 distance between the translation columns, and the mean
        absolute difference of ``acos(T[:, 0, 0])`` expressed in degrees.
    """
    pred = batch_dict[key]
    target = batch_dict['pose_to_frame']

    trans_err = (pred[:, 0:3, 3] - target[:, 0:3, 3]).norm(dim=1).mean()

    # Clamp before acos so values that drift slightly outside [-1, 1] do not
    # produce NaN.
    pred_angle = torch.acos(torch.clamp(pred[:, 0, 0].view(-1, 1), -1, 1))
    target_angle = torch.acos(torch.clamp(target[:, 0, 0].view(-1, 1), -1, 1))

    # Use an exact radians->degrees conversion instead of the truncated
    # constant 3.1415, which biased every reported angle.
    rot_err = torch.rad2deg((pred_angle - target_angle).norm(dim=1).mean())
    return trans_err, rot_err
|
||
|
||
|
||
def gen_points_loss(batch_dict):
    """Mean absolute x/y error between real and generated key points.

    Reads ``'key_points'``, ``'key_points_gen'`` and ``'pose_to_frame'`` from
    ``batch_dict``. The first ``B = pose_to_frame.shape[0]`` entries along
    dim 0 are treated as sources and the remaining ones as targets. Real
    sources are compared with generated targets and generated sources with
    real targets, after the sources are transformed by ``pose_to_frame``.

    Returns:
        Scalar tensor with the mean absolute difference of the first two
        coordinates.
    """
    real_pts = batch_dict['key_points']
    gen_raw = batch_dict['key_points_gen']

    # Append as many zero channels as the generated points already have, then
    # set channel 3 to 1 (assumes key_points_gen carries 2 channels so the
    # result is a homogeneous 4-vector -- TODO confirm).
    gen_pts = torch.cat((gen_raw, gen_raw * 0), dim=2)
    gen_pts[:, :, 3] = 1

    pose = batch_dict['pose_to_frame']
    half = pose.shape[0]

    sources = torch.cat((real_pts[:half], gen_pts[:half]), dim=0)
    targets = torch.cat((gen_pts[half:], real_pts[half:]), dim=0)
    poses = torch.cat((pose, pose), dim=0)

    moved = torch.bmm(poses, sources.transpose(1, 2)).transpose(1, 2)
    return (moved[:, :, :2] - targets[:, :, :2]).abs().mean()
|
||
|
||
|
||
def rand_dis(x, y):
    """Compare ``x`` and ``y`` on one random off-diagonal entry per row.

    For every row ``i`` a random column ``j != i`` is drawn. The result is
    the mean of ``|x[i, j] - y[i, j]|`` and ``|0 - y[i, i]|`` over all rows,
    plus a hinge term ``relu(0.2 - |y[i, j]|)`` averaged over rows.

    Args:
        x: 2-D tensor (indexed as a square matrix).
        y: 2-D tensor indexed with the same row/column indices as ``x``.

    Returns:
        Scalar tensor.
    """
    assert len(x.shape)==2 and len(y.shape)==2,'x and y must be 2 dim'
    _, n = x.size()
    rows = torch.arange(n).to(x.device)

    # col_grid[i, j] == j; dropping the diagonal leaves, for each row i, the
    # n - 1 column indices different from i.
    col_grid = rows.view(1, n).repeat(n, 1)
    off_diag = col_grid != col_grid.t()
    other_cols = col_grid[off_diag].view(n, n - 1)

    pick = torch.randint(n - 1, size=(n,)).to(x.device)
    rand_cols = torch.gather(other_cols, 1, pick.view(-1, 1)).view(-1)

    x_off = x[rows, rand_cols]
    y_off = y[rows, rand_cols]
    y_diag = y[rows, rows]

    x_pair = torch.stack([x_off, x_off * 0], dim=1)
    y_pair = torch.stack([y_off, y_diag], dim=1)

    return torch.abs(x_pair - y_pair).mean() + F.relu(0.2 - torch.abs(y_off)).mean()
|
||
|
||
|
||
def gen_feature_loss(batch_dict):
    """Cosine-distance losses between original and generated features.

    Reads ``'fea_pt_dual'``, ``'fea_pl_dual'``, ``'fea_kpt_original'`` and
    their ``*_gen`` counterparts, plus ``'relation'``, from ``batch_dict``.
    Feature tensors are indexed as ``[sample, :, :count]`` (the original
    author's note says the layout is B, C, N).

    For each sample the number of valid relations is the count of rows where
    both ``relation[i, :, -1, 0]`` and ``relation[i, :, -1, 1]`` are positive;
    only the first ``count`` feature columns enter the first two losses
    (assumes valid matches are stored first -- TODO confirm with producer).

    Samples with zero valid relations are skipped for the first two losses:
    averaging over an empty slice would yield NaN and poison the total loss.

    Returns:
        Tuple ``(loss0, loss1, loss2, loss3)``, each averaged over the batch:
        point-cloud feature loss, image feature loss, key-point feature loss,
        and a placeholder that is always ``0``.
    """
    fea_pt_dual_gen = batch_dict['fea_pt_dual_gen']
    fea_pl_dual_gen = batch_dict['fea_pl_dual_gen']
    fea_kpt_original_gen = batch_dict['fea_kpt_original_gen']
    fea_pt_dual = batch_dict['fea_pt_dual']
    fea_pl_dual = batch_dict['fea_pl_dual']
    fea_kpt_original = batch_dict['fea_kpt_original']
    relation = batch_dict['relation']

    b = fea_pl_dual.shape[0]
    loss0 = 0
    loss1 = 0
    loss2 = 0
    loss3 = 0  # never accumulated; kept so the return arity is unchanged

    for i in range(b):
        valid = (relation[i, :, -1, 0] > 0) & (relation[i, :, -1, 1] > 0)
        cnt = int(valid.sum())

        if cnt > 0:
            # Matched point-cloud features vs. those generated from the image.
            pt_real = fea_pt_dual[i, :, :cnt]
            pt_gen = fea_pt_dual_gen[i, :, :cnt]
            loss0 = loss0 + (1 - F.cosine_similarity(pt_real, pt_gen, dim=0)).mean()

            # Matched image features vs. those generated from the point cloud.
            pl_real = fea_pl_dual[i, :, :cnt]
            pl_gen = fea_pl_dual_gen[i, :, :cnt]
            loss1 = loss1 + (1 - F.cosine_similarity(pl_real, pl_gen, dim=0)).mean()

        # Loss for the panoramic feature generation module (all columns used).
        loss2 = loss2 + (1 - F.cosine_similarity(
            fea_kpt_original[i], fea_kpt_original_gen[i], dim=0)).mean()

    loss0 = loss0 / b
    loss1 = loss1 / b
    loss2 = loss2 / b
    return loss0, loss1, loss2, loss3
|
||
|
||
|
||
def sinkhorn_matches_loss(batch_dict, key):
    """Mean L2 distance between projected key points and their pose-warped truth.

    The key points in ``batch_dict['key_points']`` are reshaped to
    ``(batch_size, -1, 4)``; the first half of the batch is kept, its last
    channel is overwritten with 1 and it is transformed by
    ``batch_dict['pose_to_frame']``. The first three coordinates of the
    result are compared with ``batch_dict[key]``.

    Args:
        batch_dict: Mapping with ``key``, ``'key_points'``,
            ``'pose_to_frame'`` and ``'batch_size'``.
        key: Entry holding the key points projected from correspondences.

    Returns:
        Scalar tensor.
    """
    predicted = batch_dict[key]
    points = batch_dict['key_points']
    pose = batch_dict['pose_to_frame']

    # Work on a copy so the caller's key points are left untouched.
    points = points.clone().view(batch_dict['batch_size'], -1, 4)
    half = points.shape[0] // 2

    anchors = points[:half]
    anchors[:, :, -1] = 1.

    expected = torch.bmm(pose, anchors.transpose(1, 2)).transpose(1, 2)[:, :, :3]
    return (expected - predicted).norm(dim=2).mean()
|
||
|
||
|
||
|
||
def score_loss(batch_dict):
    """MSE between predicted scores and label scores where the prediction is non-trivial.

    ``batch_dict['label_score']`` is split along its 4th dimension and the
    two slices are concatenated along dim 0 to line up with
    ``batch_dict['score_bev']``. Only positions whose predicted score exceeds
    ``1e-8`` contribute.

    Returns:
        Scalar tensor.
    """
    predicted = batch_dict['score_bev']
    labels = batch_dict['label_score']
    target = torch.cat([labels.select(3, 0), labels.select(3, 1)], dim=0)

    keep = predicted > 1e-8
    return F.mse_loss(predicted[keep], target[keep])
|
||
|
||
|
||
def pose_loss(batch_dict, key):
    """Mean absolute error between key points warped by predicted and reference poses.

    The key points are reshaped to ``(batch_size, -1, 4)`` and the first half
    of the batch is transformed once with ``batch_dict['pose_to_frame']`` and
    once with ``batch_dict[key]``; the first three coordinates are compared.

    Args:
        batch_dict: Mapping with ``'key_points'``, ``'batch_size'``,
            ``'pose_to_frame'`` and ``key``.
        key: Entry holding the predicted transforms.

    Returns:
        Scalar tensor.
    """
    points = batch_dict['key_points'].clone().view(batch_dict['batch_size'], -1, 4)
    reference_pose = batch_dict['pose_to_frame']

    half = points.shape[0] // 2
    columns = points[:half].permute(0, 2, 1)

    # Only the reference branch is cast to float, as in the original code.
    target = torch.bmm(reference_pose, columns).float().permute(0, 2, 1)[:, :, :3]

    predicted_pose = batch_dict[key]
    predicted = torch.bmm(predicted_pose, columns).permute(0, 2, 1)[:, :, :3]

    return (predicted - target).abs().mean()
|
||
|
||
|
||
def get_all_triplets(dist_mat, pos_mask, neg_mask, is_inverted=False, margin=0.5, different_embedding=False):
    """Enumerate every (anchor, positive, negative) index combination.

    Args:
        dist_mat: Unused; present so all selectors share one signature.
        pos_mask: Matrix that is non-zero where (row, col) is a positive pair.
        neg_mask: Matrix that is non-zero where (row, col) is a negative pair.
        is_inverted: Unused.
        margin: Unused.
        different_embedding: When false, only the strict upper triangle of
            ``pos_mask`` is used so each symmetric pair is counted once.

    Returns:
        Tuple of three index tensors ``(anchors, positives, negatives)``.
    """
    anchor_positive = pos_mask if different_embedding else torch.triu(pos_mask, 1)
    # valid[a, p, n] is non-zero iff (a, p) is positive and (a, n) is negative.
    valid = anchor_positive[:, :, None] * neg_mask[:, None, :]
    return torch.nonzero(valid, as_tuple=True)
|
||
|
||
|
||
def hardest_negative_selector(dist_mat, pos_mask, neg_mask, is_inverted, margin=0.5, different_embedding=False):
    """Pick, for every positive pair, the hardest negative of its anchor.

    The hardest negative is the one closest to the anchor: the smallest entry
    of ``dist_mat`` for ordinary distances, the largest entry when
    ``is_inverted`` is true (similarity-like measures).

    Args:
        dist_mat: Pairwise distance (or similarity) matrix.
        pos_mask: Matrix that is non-zero where (row, col) is a positive pair.
        neg_mask: Matrix that is non-zero where (row, col) is a negative pair.
        is_inverted: True when larger ``dist_mat`` values mean "closer".
        margin: Unused; present so all selectors share one signature.
        different_embedding: When false, only the strict upper triangle of
            ``pos_mask`` is used so each symmetric pair is counted once.

    Returns:
        Tuple ``(anchors, positives, negatives)`` of index tensors, or
        ``(anchors, positives, None)`` when ``neg_mask`` has no entry at all.
    """
    if not different_embedding:
        pos_mask = torch.triu(pos_mask, 1)
    a, p = torch.where(pos_mask)
    if neg_mask.sum() == 0:
        return a, p, None

    not_negative = ~neg_mask.bool()
    dist_neg = dist_mat.clone()
    if is_inverted:
        # Push non-negatives below every real value so they can never win the
        # arg-max. Multiplying by the mask (the previous approach) set them to
        # 0, which beats genuine negatives whose similarity is below zero.
        dist_neg[not_negative] = dist_neg.min() - 1.
        # torch.max(dim=...) returns (values, indices); only indices are needed.
        _, n = torch.max(dist_neg, dim=1)
    else:
        # Push non-negatives above every real value so they never win the
        # arg-min.
        dist_neg[not_negative] = dist_neg.max() + 1.
        _, n = torch.min(dist_neg, dim=1)

    # n holds one negative per row; gather the one belonging to each anchor.
    n = n[a]
    return a, p, n
|
||
|
||
|
||
def random_negative_selector(dist_mat, pos_mask, neg_mask, is_inverted, margin=0.5, different_embedding=False):
    """Pick, for every positive pair, a random negative that violates the margin.

    For each (anchor, positive) pair the anchor's negatives are restricted to
    those with a positive triplet loss when any exist; otherwise all of the
    anchor's negatives stay eligible. One is then drawn with
    ``np.random.choice``.

    Args:
        dist_mat: Pairwise distance (or similarity) matrix.
        pos_mask: Matrix that is non-zero where (row, col) is a positive pair.
        neg_mask: Matrix whose row ``a`` marks the negatives of anchor ``a``.
        is_inverted: True when larger ``dist_mat`` values mean "closer".
        margin: Triplet margin used to decide which negatives are violating.
        different_embedding: When false, only the strict upper triangle of
            ``pos_mask`` is used so each symmetric pair is counted once.

    Returns:
        Tuple ``(anchors, positives, negatives)`` of index tensors, or
        ``(anchors, positives, None)`` as soon as an anchor has no negative.
    """
    anchor_positive = pos_mask if different_embedding else torch.triu(pos_mask, 1)
    a, p = torch.where(anchor_positive)

    chosen = []
    for anchor, positive in zip(a, p):
        candidates = torch.where(neg_mask[anchor])[0]
        if len(candidates) == 0:
            return a, p, None

        d_ap = dist_mat[anchor, positive]
        d_an = dist_mat[anchor, candidates]
        violation = (-d_ap + d_an + margin) if is_inverted else (d_ap - d_an + margin)

        active = violation > 0
        if active.any():
            candidates = candidates[active]
        chosen.append(np.random.choice(candidates.cpu().numpy()))

    n = torch.tensor(chosen, dtype=a.dtype, device=a.device)
    return a, p, n
|
||
|
||
|
||
def semihard_negative_selector(dist_mat, pos_mask, neg_mask, is_inverted, margin=0.5, different_embedding=False):
    """Pick, for every positive pair, a random semi-hard negative.

    A negative is semi-hard when its triplet loss lies strictly between 0 and
    ``margin``. If an anchor has no semi-hard negative, all of its negatives
    stay eligible. One is then drawn with ``np.random.choice``.

    Args:
        dist_mat: Pairwise distance (or similarity) matrix.
        pos_mask: Matrix that is non-zero where (row, col) is a positive pair.
        neg_mask: Matrix whose row ``a`` marks the negatives of anchor ``a``.
        is_inverted: True when larger ``dist_mat`` values mean "closer".
        margin: Triplet margin bounding the semi-hard band.
        different_embedding: When false, only the strict upper triangle of
            ``pos_mask`` is used so each symmetric pair is counted once.

    Returns:
        Tuple ``(anchors, positives, negatives)`` of index tensors, or
        ``(anchors, positives, None)`` as soon as an anchor has no negative.
    """
    anchor_positive = pos_mask if different_embedding else torch.triu(pos_mask, 1)
    a, p = torch.where(anchor_positive)

    chosen = []
    for anchor, positive in zip(a, p):
        candidates = torch.where(neg_mask[anchor])[0]
        if len(candidates) == 0:
            return a, p, None

        d_ap = dist_mat[anchor, positive]
        d_an = dist_mat[anchor, candidates]
        violation = (-d_ap + d_an + margin) if is_inverted else (d_ap - d_an + margin)

        in_band = (violation > 0) & (violation < margin)
        if in_band.any():
            candidates = candidates[in_band]
        chosen.append(np.random.choice(candidates.cpu().numpy()))

    n = torch.tensor(chosen, dtype=a.dtype, device=a.device)
    return a, p, n
|
||
|
||
|
||
class TripletLoss(nn.Module):
    """Triplet margin loss driven by a pluggable triplet selector.

    The selector is called as
    ``selector(dist_mat, pos_mask, neg_mask, distance.is_inverted)`` and must
    return ``(anchors, positives, negatives)``; ``negatives`` may be ``None``
    when no negative is available.
    """

    def __init__(self, margin: float, triplet_selector, distance: distances.BaseDistance):
        """Store the margin, the triplet selector and the distance object."""
        super().__init__()
        self.margin = margin
        self.triplet_selector = triplet_selector
        self.distance = distance

    def forward(self, embeddings, pos_mask, neg_mask, other_embeddings=None):
        """Return the mean triplet loss over the selected triplets.

        Args:
            embeddings: Embeddings used as anchors.
            pos_mask: Positive-pair mask passed through to the selector.
            neg_mask: Negative-pair mask passed through to the selector.
            other_embeddings: Optional second set of embeddings to measure
                against; defaults to ``embeddings`` itself.

        Returns:
            Scalar tensor. When the selector reports no negatives, only the
            positive distances are penalised.
        """
        reference = embeddings if other_embeddings is None else other_embeddings
        dist_mat = self.distance(embeddings, reference)

        inverted = self.distance.is_inverted
        triplets = self.triplet_selector(dist_mat, pos_mask, neg_mask, inverted)
        distance_positive = dist_mat[triplets[0], triplets[1]]

        if triplets[-1] is None:
            # No negatives: pull positives together on their own.
            positive_only = 1 - distance_positive if inverted else distance_positive
            return F.relu(positive_only).mean()

        distance_negative = dist_mat[triplets[0], triplets[2]]
        gap = self.distance.margin(distance_positive, distance_negative)
        return F.relu(gap + self.margin).mean()
|
||
|
||
|
||
def _pairwise_distance(x, squared=False, eps=1e-16):
|
||
# Compute the 2D matrix of distances between all the embeddings.
|
||
|
||
cor_mat = torch.matmul(x, x.t())
|
||
norm_mat = cor_mat.diag()
|
||
distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0)
|
||
distances = F.relu(distances)
|
||
|
||
if not squared:
|
||
mask = torch.eq(distances, 0.0).float()
|
||
distances = distances + mask * eps
|
||
distances = torch.sqrt(distances)
|
||
distances = distances * (1.0 - mask)
|
||
|
||
return distances
|
||
|
||
|
||
class TotalLoss(nn.Module):
    """Aggregate every training loss into ``batch_dict['loss']``.

    ``cfg['negetative_selsector']`` chooses the negative-mining strategy
    (``'hardest'``, ``'semihard'``, anything else -> random) and
    ``cfg['trip_margin']`` sets the triplet margin. The misspelled key and
    attribute names are kept because external configs/callers may use them.
    """

    def __init__(self, cfg):
        """Build the triplet loss from the configuration mapping ``cfg``."""
        super(TotalLoss, self).__init__()
        selector_name = cfg['negetative_selsector']
        if selector_name == 'hardest':
            neg_selector = hardest_negative_selector
        elif selector_name == 'semihard':
            neg_selector = semihard_negative_selector
        else:
            neg_selector = random_negative_selector
        self.trip_fun = TripletLoss(margin=cfg['trip_margin'], triplet_selector=neg_selector, distance=distances.LpDistance())
        # Pairs of query poses farther apart than this count as negatives.
        self.negetative_distcance = 50

    def forward(self, batch_dict):
        """Compute all losses and store them in ``batch_dict['loss']``.

        The stored list is ordered as
        ``[total, pose, score, match, triplet, translation, rotation,
        l_gb, l_gi, l_gpa, l_gpo, l_kpl]``. Entries equal to 0 are replaced
        by a zero derived from the total so every entry has the same type.
        ``l_gpo`` is reported but is not part of the total.

        Returns:
            ``batch_dict`` itself, with the ``'loss'`` entry added.
        """
        l_pose = l_score = l_match = l_tra = l_rot = 0
        l_gb = l_gi = l_gpa = l_gpo = l_kpl = 0

        # NOTE(review): the original indentation was lost; the pose-head
        # losses are placed under this guard because sinkhorn_matches_loss and
        # pose_loss both read batch_dict['key_points'] -- confirm upstream.
        if 'key_points' in batch_dict:
            l_score = score_loss(batch_dict)
            l_match1 = l_pose1 = l_tra1 = l_rot1 = 0
            l_match2 = l_pose2 = l_tra2 = l_rot2 = 0
            heads = 0

            if 'transformation_original' in batch_dict:
                l_match1 = sinkhorn_matches_loss(batch_dict, 'project_kpts_original')
                l_tra1, l_rot1 = tr_loss(batch_dict, 'transformation_original')
                l_pose1 = pose_loss(batch_dict, 'transformation_original')
                heads += 1
            if 'transformation_fusion' in batch_dict:
                l_match2 = sinkhorn_matches_loss(batch_dict, 'project_kpts_fusion')
                l_tra2, l_rot2 = tr_loss(batch_dict, 'transformation_fusion')
                l_pose2 = pose_loss(batch_dict, 'transformation_fusion')
                heads += 1

            # Average over the heads that are actually present. Counting by
            # key presence (rather than testing whether both rotation losses
            # are > 0) stays correct when a head's rotation error is exactly 0.
            cnt = max(heads, 1)
            l_match = (l_match1 + l_match2) / cnt
            l_pose = (l_pose1 + l_pose2) / cnt
            l_tra = (l_tra1 + l_tra2) / cnt
            l_rot = (l_rot1 + l_rot2) / cnt

        if ('fea_pt_dual_gen' in batch_dict) or ('fea_pl_dual_gen' in batch_dict):
            l_gb, l_gi, l_gpa, l_kpl = gen_feature_loss(batch_dict)
        if 'key_points_gen' in batch_dict:
            l_gpo = gen_points_loss(batch_dict)

        pair_dist = _pairwise_distance(batch_dict['pose_query'][:, 0:3, 3])
        if 'sequence' in batch_dict:
            sequence = batch_dict['sequence']
            neg_mask = sequence.view(1, -1) != sequence.view(-1, 1)
        else:
            # Must match pair_dist (one row/column per query): the mask is
            # OR-ed with it and only afterwards tiled to cover query+positive.
            # The previous (2B, 2B) fallback could not broadcast for B > 1.
            neg_mask = torch.zeros_like(pair_dist, dtype=torch.bool)
        neg_mask = ((pair_dist > self.negetative_distcance) | neg_mask.to(pair_dist.device))
        neg_mask = neg_mask.repeat(2, 2)

        batch_size = batch_dict['batch_size']
        half = batch_size // 2
        pos_mask = torch.zeros((batch_size, batch_size), device=neg_mask.device)
        # Sample i in the first half is paired with sample i + half.
        idx = torch.arange(half, device=neg_mask.device)
        pos_mask[idx, idx + half] = 1
        pos_mask[idx + half, idx] = 1

        l_triplet = self.trip_fun(batch_dict['vlads'], pos_mask, neg_mask)
        l_total = l_score + l_pose + 0.05 * l_match + l_triplet + (l_gb + l_gi + l_gpa + l_kpl)

        loss = [l_total, l_pose, l_score, l_match, l_triplet, l_tra, l_rot, l_gb, l_gi, l_gpa, l_gpo, l_kpl]
        for i, value in enumerate(loss):
            if value == 0:
                loss[i] = loss[0] * 0
        batch_dict['loss'] = loss
        return batch_dict