150 lines
5.9 KiB
Python
150 lines
5.9 KiB
Python
import math
|
|
import torch
|
|
import torch._utils
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torchvision.models import resnet
|
|
from typing import Optional, Callable
|
|
|
|
|
|
def compute_rigid_transform(points1, points2, weights):
    """Compute rigid transforms between two point clouds via weighted SVD.

    Adapted from https://github.com/yewzijian/RPMNet/

    Points are paired row by row: ``points1[..., i, :]`` is matched with
    ``points2[..., i, :]`` and the pair is weighted by ``weights[..., i]``.
    The same weights are applied to both clouds, so both must contain the
    same number of points. Any number of leading batch dimensions
    (including none) is accepted.

    Args:
        points1 (torch.Tensor): (..., M, 3) coordinates of the first point cloud
        points2 (torch.Tensor): (..., M, 3) coordinates of the second point cloud
        weights (torch.Tensor): (..., M) weight of each point pair

    Returns:
        Transform T (..., 3, 4) to get from points1 to points2, i.e. T*points1 = points2.
        ``T[..., :3]`` is the rotation and ``T[..., 3]`` the translation.
    """
    # Normalise the weights over the point axis; the 1e-5 keeps the division
    # finite when all weights are zero.
    w = weights[..., None]
    weights_normalized = w / (torch.sum(w, dim=-2, keepdim=True) + 1e-5)

    # Weighted centroids and centred clouds.
    centroid_a = torch.sum(points1 * weights_normalized, dim=-2)
    centroid_b = torch.sum(points2 * weights_normalized, dim=-2)
    a_centered = points1 - centroid_a[..., None, :]
    b_centered = points2 - centroid_b[..., None, :]

    # Weighted 3x3 cross-covariance between the two centred clouds.
    cov = a_centered.transpose(-2, -1) @ (b_centered * weights_normalized)

    # Compute rotation using Kabsch algorithm. Two candidates are built, one
    # with V as returned and one with the last column of V negated, and the
    # one with positive determinant is kept to avoid reflections.
    # NOTE: torch.svd returns V itself (not its transpose).
    u, _, v = torch.svd(cov, some=False, compute_uv=True)
    u_t = u.transpose(-1, -2)
    rot_mat_pos = v @ u_t
    v_neg = v.clone()
    v_neg[..., 2] *= -1
    rot_mat_neg = v_neg @ u_t
    is_proper = torch.det(rot_mat_pos)[..., None, None] > 0
    rot_mat = torch.where(is_proper, rot_mat_pos, rot_mat_neg)
    assert torch.all(torch.det(rot_mat) > 0)

    # Compute translation (uncenter centroid)
    translation = -rot_mat @ centroid_a[..., None] + centroid_b[..., None]

    transform = torch.cat((rot_mat, translation), dim=-1)
    return transform
|
|
|
|
|
|
def sinkhorn_unbalanced(feature1, feature2, epsilon, gamma, max_iter, matrix='cosine'):
    """
    Sinkhorn algorithm for Unbalanced Optimal Transport.

    Modified from https://github.com/valeoai/FLOT/

    Args:
        feature1 (torch.Tensor):
            (B, N, C) Point-wise features for points cloud 1.
        feature2 (torch.Tensor):
            (B, M, C) Point-wise features for points cloud 2.
        epsilon (torch.Tensor):
            Entropic regularization.
        gamma (torch.Tensor):
            Mass regularization.
        max_iter (int):
            Number of iteration of the Sinkhorn algorithm.
        matrix (str):
            Transport cost to use: 'cosine' (one minus the cosine similarity
            of the features) or 'euclidean' (L2 distance between features).

    Returns:
        T (torch.Tensor):
            (B, N, M) Transport plan between point cloud 1 and 2. When
            ``max_iter == 0`` the raw kernel ``exp(-C / epsilon)`` is
            returned instead.

    Raises:
        ValueError: If ``matrix`` is neither 'cosine' nor 'euclidean'.
    """
    # Transport cost matrix
    if matrix == 'cosine':
        feature1 = feature1 / torch.sqrt(torch.sum(feature1 ** 2, -1, keepdim=True) + 1e-8)
        feature2 = feature2 / torch.sqrt(torch.sum(feature2 ** 2, -1, keepdim=True) + 1e-8)
        C = 1.0 - torch.bmm(feature1, feature2.transpose(1, 2))
    elif matrix == 'euclidean':
        # Squared distance via |a|^2 + |b|^2 - 2 a.b
        sq_dist = torch.sum(feature1 ** 2, -1, keepdim=True)
        sq_dist = sq_dist + torch.sum(feature2 ** 2, -1, keepdim=True).transpose(1, 2)
        sq_dist = sq_dist - 2 * torch.bmm(feature1, feature2.transpose(1, 2))
        # Round-off can push the squared distance of (near-)identical features
        # slightly below zero; clamp so the root does not produce NaN.
        C = torch.clamp(sq_dist, min=0.0) ** 0.5
    else:
        raise ValueError(
            f"Unknown cost matrix type {matrix!r}; expected 'cosine' or 'euclidean'"
        )

    # Entropic regularisation
    K = torch.exp(-C / epsilon)

    # Early return if no iteration
    if max_iter == 0:
        return K

    # Init. of Sinkhorn algorithm: uniform marginals on both point clouds.
    power = gamma / (gamma + epsilon + 1e-8)
    batch_size, n_pts1, n_pts2 = K.shape[0], K.shape[1], K.shape[2]
    prob1 = torch.ones((batch_size, n_pts1, 1), device=feature1.device, dtype=feature1.dtype) / n_pts1
    prob2 = torch.ones((batch_size, n_pts2, 1), device=feature2.device, dtype=feature2.dtype) / n_pts2
    # The scaling vector `a` starts equal to prob1; it is only ever rebound
    # below, never modified in place, so sharing the tensor is safe.
    a = prob1

    # Sinkhorn algorithm
    for _ in range(max_iter):
        # Update b
        KTa = torch.bmm(K.transpose(1, 2), a)
        b = torch.pow(prob2 / (KTa + 1e-8), power)
        # Update a
        Kb = torch.bmm(K, b)
        a = torch.pow(prob1 / (Kb + 1e-8), power)

    # Transportation map
    T = torch.mul(torch.mul(a, K), b.transpose(1, 2))
    return T
|
|
|
|
|
|
class UOTHead(nn.Module):
    """Keypoint matching head based on unbalanced optimal transport.

    Keypoint features of the first half of the batch are soft-matched to
    those of the second half with ``sinkhorn_unbalanced``; the resulting
    transport plan is used to project keypoints and features from the second
    half onto the first and to fit a rigid transform with
    ``compute_rigid_transform``.
    """

    def __init__(self, nb_iter=5, name='original'):
        """
        Args:
            nb_iter (int): Number of Sinkhorn iterations run in ``forward``.
            name (str): Suffix appended to the ``batch_dict`` keys this head
                reads and writes.
        """
        super().__init__()
        # Both regularisers are learned; forward() passes them through exp().
        self.epsilon = torch.nn.Parameter(torch.zeros(1))  # Entropic regularisation
        self.gamma = torch.nn.Parameter(torch.zeros(1))  # Mass regularisation
        self.nb_iter = nb_iter
        self.name = name

    def forward(self, batch_dict, src_coords=None, mode='pairs'):
        """Match the two halves of the batch and store the results in ``batch_dict``.

        Reads ``'fea_kpt_' + name`` (3-D after squeezing a trailing singleton
        axis; unpacked as batch, channels, keypoints) and ``'key_points'``
        (first three entries of the last axis are used as coordinates,
        presumably xyz -- confirm against the producer of this key).

        Writes, each suffixed with ``name``: ``project_kpts_``,
        ``project_feas_``, ``correspondences_feature_`` and
        ``transformation_``.

        Args:
            batch_dict (dict): Input/output dictionary, updated in place.
            src_coords: Unused.
            mode: Unused.

        Returns:
            dict: The same ``batch_dict`` object.
        """
        keypoint_feats = batch_dict['fea_kpt_' + self.name].squeeze(-1)

        total, _channels, _num_kpts = keypoint_feats.shape

        assert total % 2 == 0, "Batch size must be multiple of 2: B anchor + B positive samples"
        half = total // 2
        anchor_feats = keypoint_feats[:half]
        positive_feats = keypoint_feats[half:]

        keypoints = batch_dict['key_points']
        anchor_xyz = keypoints[:half, :, 0:3]
        positive_xyz = keypoints[half:, :, 0:3]

        # sinkhorn_unbalanced expects features as (batch, points, channels).
        positive_feats_t = positive_feats.permute(0, 2, 1)
        regularisers = {
            'epsilon': torch.exp(self.epsilon) + 0.03,
            'gamma': torch.exp(self.gamma),
        }
        transport = sinkhorn_unbalanced(
            anchor_feats.permute(0, 2, 1),
            positive_feats_t,
            max_iter=self.nb_iter,
            matrix='cosine',
            **regularisers,
        )

        # Mass each anchor keypoint sends out; used to turn the plan rows into
        # weighted averages and as per-point weight for the transform fit.
        row_mass = transport.sum(-1, keepdim=True)
        denominator = row_mass + 1e-8
        projected_xyz = (transport @ positive_xyz) / denominator
        projected_feats = (transport @ positive_feats_t) / denominator

        batch_dict['project_kpts_' + self.name] = projected_xyz
        # Stored back in the (batch, channels, keypoints) layout of the input.
        batch_dict['project_feas_' + self.name] = projected_feats.permute(0, 2, 1)
        batch_dict['correspondences_feature_' + self.name] = transport

        batch_dict['transformation_' + self.name] = compute_rigid_transform(
            anchor_xyz, projected_xyz, row_mass.squeeze(-1)
        )

        return batch_dict