import threading import torch import os import time def farthest_point_sample(xyz, npoint): """Iterative farthest point sampling Args: xyz: pointcloud data_loader, [B, N, C] npoint: number of samples Returns: centroids: sampled pointcloud index, [B, npoint] """ device = xyz.device B, N, C = xyz.shape centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) distance = torch.ones(B, N).to(device) * 1e10 farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) batch_indices = torch.arange(B, dtype=torch.long).to(device) for i in range(npoint): centroids[:, i] = farthest centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) dist = torch.sum((xyz - centroid) ** 2, -1) mask = dist < distance distance[mask] = dist[mask] farthest = torch.max(distance, -1)[1] return centroids def batch_distance(feature1,feature2,mode='cosine'): if mode == 'cosine': # Transport cost matrix feature1 = feature1 / torch.sqrt(torch.sum(feature1 ** 2, -1, keepdim=True) + 1e-8) feature2 = feature2 / torch.sqrt(torch.sum(feature2 ** 2, -1, keepdim=True) + 1e-8) dis = 1.0 - torch.bmm(feature1, feature2.transpose(1, 2)) elif mode == 'euclidean': feature=torch.cat([feature1,feature2],dim=1) feature_mean=torch.mean(feature,dim=1,keepdim=True) feature1=feature1-feature_mean feature2=feature2-feature_mean distance_matrix = torch.sum(feature1 ** 2, -1, keepdim=True) distance_matrix = distance_matrix + torch.sum(feature2 ** 2, -1, keepdim=True).transpose(1, 2) distance_matrix = distance_matrix - 2 * torch.bmm(feature1, feature2.transpose(1, 2)) # c^2=a^2+b^2-2abcos distance_matrix = distance_matrix ** 0.5 dis = distance_matrix return dis def nn_match(fea1, fea2, matrix='cosine'): assert len(fea1.shape) == 2 and len(fea2.shape) == 2, 'nnmatch error' if not isinstance(fea1, torch.Tensor): fea1 = torch.tensor(fea1) if not isinstance(fea2, torch.Tensor): fea2 = torch.tensor(fea2) if matrix == 'cosine': # Transport cost matrix fea1 = fea1 / torch.sqrt(torch.sum(fea1 ** 2, -1, keepdim=True) + 1e-8) fea2 = fea2 / torch.sqrt(torch.sum(fea2 ** 2, -1, keepdim=True) + 1e-8) dis = 1.0 - torch.mm(fea1, fea2.transpose(0, 1)) elif matrix == 'euclidean': distance_matrix = torch.sum(fea1 ** 2, -1, keepdim=True) distance_matrix = distance_matrix + torch.sum(fea2 ** 2, -1, keepdim=True).transpose(0, 1) distance_matrix = distance_matrix - 2 * torch.mm(fea1, fea2.transpose(0, 1)) # c^2=a^2+b^2-2abcos dis = distance_matrix ** 0.5 else: dis = 0 print('Invalid matrix') idx0_min = torch.argmin(dis, dim=0) idx1_min = torch.argmin(dis, dim=1) ids1 = torch.arange(0, dis.shape[1]).to(fea1.device) idx = idx1_min[idx0_min] idx_match = ids1 == idx idx1 = ids1[idx_match] idx2 = idx0_min[idx_match] dis_min = dis[idx2, idx1] return idx2, idx1, dis_min def path_join(*args): names = list(args) path = names[0] for i in range(len(names) - 1): path = os.path.join(path, names[i + 1]) path = list(path) while "\\" in path: idx = path.index("\\") path[idx] = "/" path = ''.join(path) return path def make_save_path(*args): path = path_join(*args) if not os.path.exists(path): os.makedirs(path) return path def read_cfg(data): if type(data) is int: result = [data] else: result = data.split(',') return result class Timer: """A module to record the program running time""" def __init__(self, name="Now"): self.strat = time.time() self.cnt = 0 self.end = time.time() self.avg = 0 self.all = 0 self.now = 0 self.name = name time_now = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) print('Init timer: ',time_now) def update(self, name=None): if name is not None: self.name = name self.cnt = self.cnt + 1 self.end = time.time() self.avg = (self.end - self.strat) / self.cnt self.now = self.end - self.all - self.strat self.all = self.end - self.strat time_now = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time())) if self.avg<1: print("%s | %s | using %d | each %.3f" % (time_now, self.name, self.all, self.now)) elif self.avg<10: print("%s | %s | using %d | each %.2f" % (time_now, self.name, self.all, self.now)) elif self.avg<100: print("%s | %s | using %d | each %.1f" % (time_now, self.name, self.all, self.now)) else: print("%s | %s | using %d | each %d" % (time_now, self.name, self.all, self.now)) if __name__ == '__main__': # draw_trace() pass