Files
fusion_LCD/uot.py
2026-03-04 20:07:57 +08:00

150 lines
5.9 KiB
Python

import math
import torch
import torch._utils
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet
from typing import Optional, Callable
def compute_rigid_transform(points1, points2, weights):
    """Estimate a batched rigid transform with a weighted Kabsch / SVD solve.

    Adapted from https://github.com/yewzijian/RPMNet/

    Args:
        points1 (torch.Tensor): (B, M, 3) coordinates of the first point cloud.
        points2 (torch.Tensor): (B, N, 3) coordinates of the second point cloud.
            The element-wise products below pair point i of ``points1`` with
            point i of ``points2``, so both clouds need the same point count.
        weights (torch.Tensor): (B, M) per-correspondence weights.

    Returns:
        torch.Tensor: (B, 3, 4) transform ``[R | t]`` mapping ``points1`` onto
        ``points2``, i.e. ``R @ p1 + t ~= p2``.

    Raises:
        AssertionError: if a selected rotation does not have a positive
            determinant.
    """
    # Column of weights, rescaled so each batch element sums to (almost) one;
    # the 1e-5 guards against an all-zero weight vector.
    w = weights.unsqueeze(-1)
    w = w / (w.sum(dim=1, keepdim=True) + 1e-5)

    # Weighted centroids of both clouds, each (B, 3).
    mean_src = (points1 * w).sum(dim=1)
    mean_dst = (points2 * w).sum(dim=1)

    # Weighted cross-covariance of the centred clouds, (B, 3, 3).
    src_zero = points1 - mean_src.unsqueeze(1)
    dst_zero = points2 - mean_dst.unsqueeze(1)
    covariance = torch.matmul(src_zero.transpose(-2, -1), dst_zero * w)

    # Kabsch: build the candidate rotation twice, once with the last column of
    # V negated, and keep whichever one is a proper rotation (no reflection).
    u, _, v = torch.svd(covariance, some=False, compute_uv=True)
    u_t = u.transpose(-1, -2)
    rot_plain = torch.matmul(v, u_t)
    v_flipped = v.clone()
    v_flipped[:, :, 2] *= -1
    rot_flipped = torch.matmul(v_flipped, u_t)
    keep_plain = torch.det(rot_plain)[:, None, None] > 0
    rot_mat = torch.where(keep_plain, rot_plain, rot_flipped)
    assert torch.all(torch.det(rot_mat) > 0)

    # Translation that carries the rotated source centroid onto the target one.
    translation = -rot_mat @ mean_src[:, :, None] + mean_dst[:, :, None]
    return torch.cat((rot_mat, translation), dim=2)
def sinkhorn_unbalanced(feature1, feature2, epsilon, gamma, max_iter, matrix='cosine'):
    """Sinkhorn algorithm for Unbalanced Optimal Transport.

    Modified from https://github.com/valeoai/FLOT/

    Args:
        feature1 (torch.Tensor):
            (B, N, C) Point-wise features for point cloud 1.
        feature2 (torch.Tensor):
            (B, M, C) Point-wise features for point cloud 2.
        epsilon (torch.Tensor):
            Entropic regularization.
        gamma (torch.Tensor):
            Mass regularization.
        max_iter (int):
            Number of iterations of the Sinkhorn algorithm. With ``0`` the
            kernel ``exp(-C / epsilon)`` is returned without any scaling.
        matrix (str):
            Transport cost to use: ``'cosine'`` (one minus cosine similarity)
            or ``'euclidean'`` (pairwise L2 distance between features).

    Returns:
        T (torch.Tensor):
            (B, N, M) Transport plan between point cloud 1 and 2.

    Raises:
        ValueError: if ``matrix`` is not one of the supported cost types, or
            if ``max_iter`` is negative.
    """
    if max_iter < 0:
        # A negative count would skip the loop and leave the column scaling
        # undefined; reject it up front with a clear message.
        raise ValueError(f"max_iter must be non-negative, got {max_iter!r}")

    # Transport cost matrix
    if matrix == 'cosine':
        feature1 = feature1 / torch.sqrt(torch.sum(feature1 ** 2, -1, keepdim=True) + 1e-8)
        feature2 = feature2 / torch.sqrt(torch.sum(feature2 ** 2, -1, keepdim=True) + 1e-8)
        C = 1.0 - torch.bmm(feature1, feature2.transpose(1, 2))
    elif matrix == 'euclidean':
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
        sq_dist = torch.sum(feature1 ** 2, -1, keepdim=True)
        sq_dist = sq_dist + torch.sum(feature2 ** 2, -1, keepdim=True).transpose(1, 2)
        sq_dist = sq_dist - 2 * torch.bmm(feature1, feature2.transpose(1, 2))
        # Round-off can push the squared distance of (near-)identical features
        # slightly below zero, which would make the root NaN; clamp it first.
        C = torch.clamp(sq_dist, min=0.0) ** 0.5
    else:
        raise ValueError(
            f"Unsupported cost matrix type {matrix!r}; expected 'cosine' or 'euclidean'."
        )

    # Entropic regularisation
    K = torch.exp(-C / epsilon)

    # Early return if no iteration
    if max_iter == 0:
        return K

    # Init. of Sinkhorn algorithm: uniform marginals over both point sets.
    power = gamma / (gamma + epsilon + 1e-8)
    prob1 = torch.ones((K.shape[0], K.shape[1], 1), device=feature1.device, dtype=feature1.dtype) / K.shape[1]
    prob2 = torch.ones((K.shape[0], K.shape[2], 1), device=feature2.device, dtype=feature2.dtype) / K.shape[2]
    # The row scaling starts at the uniform marginal; it is only ever rebound
    # below, never modified in place, so sharing the tensor is safe.
    a = prob1

    # Sinkhorn algorithm
    for _ in range(max_iter):
        # Update b
        KTa = torch.bmm(K.transpose(1, 2), a)
        b = torch.pow(prob2 / (KTa + 1e-8), power)
        # Update a
        Kb = torch.bmm(K, b)
        a = torch.pow(prob1 / (Kb + 1e-8), power)

    # Transportation map: diag(a) @ K @ diag(b)
    T = torch.mul(torch.mul(a, K), b.transpose(1, 2))
    return T
class UOTHead(nn.Module):
    """Matching head built on unbalanced optimal transport.

    The batch is split into a first and a second half (anchor and positive
    samples). Keypoint features of the two halves are matched with
    ``sinkhorn_unbalanced``; the resulting plan is used to project the second
    half's keypoints and features onto the first half, and a rigid transform
    is fitted with ``compute_rigid_transform``.
    """

    def __init__(self, nb_iter=5, name='original'):
        """Create the head.

        Args:
            nb_iter (int): Number of Sinkhorn iterations run in ``forward``.
            name (str): Suffix appended to every ``batch_dict`` key this head
                reads or writes.
        """
        super().__init__()
        # Both regularisers are learned in log-space; ``forward`` passes them
        # through ``exp`` before use.
        self.epsilon = nn.Parameter(torch.zeros(1))  # entropic regularisation
        self.gamma = nn.Parameter(torch.zeros(1))  # mass regularisation
        self.nb_iter = nb_iter
        self.name = name

    def forward(self, batch_dict, src_coords=None, mode='pairs'):
        """Match the two batch halves and store the results in ``batch_dict``.

        Args:
            batch_dict (dict): Must hold ``'fea_kpt_<name>'`` (3-D after the
                trailing singleton axis is squeezed, with an even leading
                size) and ``'key_points'`` (presumably (2B, NUM, >=3); only
                the first three channels are used - confirm against caller).
            src_coords: Not used by this implementation.
            mode: Not used by this implementation.

        Returns:
            dict: The same ``batch_dict``, extended with
            ``'project_kpts_<name>'``, ``'project_feas_<name>'``,
            ``'correspondences_feature_<name>'`` and
            ``'transformation_<name>'``.

        Raises:
            AssertionError: if the leading batch size is odd.
        """
        tag = self.name

        feats = batch_dict['fea_kpt_' + tag].squeeze(-1)
        B, _, _ = feats.shape
        assert B % 2 == 0, "Batch size must be multiple of 2: B anchor + B positive samples"
        half = B // 2

        # Channel-last descriptors for each half of the batch.
        desc_anchor = feats[:half].permute(0, 2, 1)
        desc_positive = feats[half:].permute(0, 2, 1)

        keypoints = batch_dict['key_points']
        xyz_anchor = keypoints[:half, :, 0:3]
        xyz_positive = keypoints[half:, :, 0:3]

        plan = sinkhorn_unbalanced(
            desc_anchor,
            desc_positive,
            epsilon=torch.exp(self.epsilon) + 0.03,
            gamma=torch.exp(self.gamma),
            max_iter=self.nb_iter,
            matrix='cosine',
        )

        # Mass transported out of every anchor keypoint; used both to
        # normalise the projections and as the weight of the rigid fit.
        row_mass = plan.sum(-1, keepdim=True)
        denom = row_mass + 1e-8
        projected_xyz = (plan @ xyz_positive) / denom
        projected_desc = (plan @ desc_positive) / denom

        batch_dict['project_kpts_' + tag] = projected_xyz
        batch_dict['project_feas_' + tag] = projected_desc.permute(0, 2, 1)
        batch_dict['correspondences_feature_' + tag] = plan

        batch_dict['transformation_' + tag] = compute_rigid_transform(
            xyz_anchor, projected_xyz, row_mass.squeeze(-1)
        )
        return batch_dict