Files
fusion_LCD/netvlad.py
2026-03-04 20:07:57 +08:00

156 lines
5.6 KiB
Python

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class NetVLAD(nn.Module):
    """NetVLAD pooling: aggregate a grid of local descriptors into one vector.

    Each spatial location is softly assigned to ``num_clusters`` learned
    centroids; the assignment-weighted residuals to every centroid are summed,
    L2-normalised per cluster, flattened, and L2-normalised again.

    Args:
        fea_size: channel dimension ``C`` of the incoming descriptors.
        num_clusters: number of centroids ``K``.
    """

    def __init__(self, fea_size=128, num_clusters=16):
        super(NetVLAD, self).__init__()
        self.num_clusters = num_clusters
        # 1x1 convolution producing one assignment logit per cluster and location.
        self.conv = nn.Conv2d(fea_size, num_clusters, kernel_size=(1, 1), bias=True)
        self.centroids = nn.Parameter(torch.randn(num_clusters, fea_size))
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Compute the VLAD descriptor of ``x``.

        Args:
            x: tensor of shape (B, C, H, W). The original note on this method
               says W is expected to be 1; the computation itself only
               flattens H and W together, so any H, W is processed.
               Non-contiguous inputs are accepted.

        Returns:
            Tensor of shape (B, num_clusters * C) with unit L2 norm per row.
        """
        B, C, _, _ = x.shape

        # Soft assignment: logits -> ReLU -> softmax over the cluster axis.
        soft_assign = self.relu(self.conv(x))  # (B, K, H, W)
        soft_assign = F.softmax(soft_assign, dim=1)  # (B, K, H, W)

        # reshape (not view) so permuted / transposed inputs do not raise.
        soft_assign = soft_assign.reshape(B, self.num_clusters, -1)  # (B, K, N)
        x_flatten = x.reshape(B, C, -1)  # (B, C, N)

        # sum_n a[k, n] * (x[n] - c[k]) == (a @ x^T)[k] - (sum_n a[k, n]) * c[k]
        # This avoids materialising the (B, K, N, C) residual tensor.
        weighted_sum = torch.matmul(soft_assign, x_flatten.transpose(1, 2))  # (B, K, C)
        assign_mass = soft_assign.sum(dim=-1, keepdim=True)  # (B, K, 1)
        vlad = weighted_sum - assign_mass * self.centroids  # (B, K, C)

        # Intra-cluster normalisation, then global normalisation of the flat vector.
        vlad = F.normalize(vlad, p=2, dim=2)  # (B, K, C)
        vlad = vlad.reshape(B, -1)  # (B, K * C)
        vlad = F.normalize(vlad, p=2, dim=1)
        return vlad
class NetVLADLoupe(nn.Module):
    """VLAD pooling over a set of descriptors, followed by a linear projection.

    Port of the TensorFlow LOUPE implementation:
    https://github.com/antoine77340/LOUPE

    Args:
        feature_size: dimension of each input descriptor.
        cluster_size: number of soft-assignment clusters.
        output_dim: dimension of the projected output vector.
        gating: if true, pass the projected vector through ``GatingContext``.
        add_norm: if true, normalise the assignment logits with a norm layer;
            otherwise add a learned bias instead.
        is_training: stored on the instance; not read anywhere in this class.
        normalization: ``'instance'`` selects ``nn.LayerNorm``, ``'group'``
            selects ``nn.GroupNorm`` with 8 groups, any other value selects
            ``nn.BatchNorm1d``.
    """

    def __init__(self, feature_size, cluster_size, output_dim,
                 gating=True, add_norm=True, is_training=True, normalization='batch'):
        super(NetVLADLoupe, self).__init__()

        def make_norm(width):
            # Build the norm layer selected by ``normalization`` for ``width`` channels.
            if normalization == 'instance':
                return nn.LayerNorm(width)
            if normalization == 'group':
                return nn.GroupNorm(8, width)
            return nn.BatchNorm1d(width)

        self.feature_size = feature_size
        self.cluster_size = cluster_size
        self.output_dim = output_dim
        self.is_training = is_training
        self.gating = gating
        self.add_batch_norm = add_norm

        self.softmax = nn.Softmax(dim=-1)

        # Every randomly initialised weight is scaled by 1 / sqrt(feature_size).
        root = math.sqrt(feature_size)
        self.cluster_weights = nn.Parameter(
            torch.randn(feature_size, cluster_size) / root)
        self.cluster_weights2 = nn.Parameter(
            torch.randn(1, feature_size, cluster_size) / root)
        self.hidden1_weights = nn.Parameter(
            torch.randn(cluster_size * feature_size, output_dim) / root)

        # Assignment logits get either a norm layer or a learned bias, never both.
        if add_norm:
            self.cluster_biases = None
            self.bn1 = make_norm(cluster_size)
        else:
            self.cluster_biases = nn.Parameter(torch.randn(cluster_size) / root)
            self.bn1 = None

        self.bn2 = make_norm(output_dim)

        if gating:
            self.context_gating = GatingContext(
                output_dim, add_batch_norm=add_norm, normalization=normalization)

    def forward(self, x):
        """Pool ``x`` into one vector per batch element.

        Args:
            x: tensor laid out as (B, N, C), with C equal to ``feature_size``.

        Returns:
            Tensor of shape (B, output_dim). No L2 normalisation is applied
            after the final norm layer / gating.
        """
        num_points = x.shape[1]

        # Soft assignment of every descriptor to every cluster.
        logits = x @ self.cluster_weights  # (B, N, K)
        if self.add_batch_norm:
            # The norm layer is applied on a 2-D (B * N, K) view.
            flat_logits = self.bn1(logits.view(-1, self.cluster_size))
            logits = flat_logits.view(-1, num_points, self.cluster_size)
        else:
            logits = logits + self.cluster_biases
        assign = self.softmax(logits)  # (B, N, K)

        # Per-cluster assignment mass times the learned centres.
        centre_term = torch.sum(assign, dim=-2, keepdim=True) * self.cluster_weights2  # (B, C, K)

        # Assignment-weighted descriptor sums, laid out as (B, C, K).
        points = x.view((-1, num_points, self.feature_size))
        pooled = torch.matmul(assign.transpose(2, 1), points)  # (B, K, C)
        pooled = pooled.transpose(2, 1).contiguous()  # (B, C, K)

        # Residual to the centres, normalised along dim 1 and then over the flat vector.
        residual = F.normalize(pooled - centre_term, dim=1, p=2, eps=1e-6)
        flat = residual.view((-1, self.cluster_size * self.feature_size))
        flat = F.normalize(flat, dim=1, p=2, eps=1e-6)

        # Linear projection to output_dim, norm layer, optional context gating.
        projected = self.bn2(torch.matmul(flat, self.hidden1_weights))
        if self.gating:
            projected = self.context_gating(projected)
        return projected
class GatingContext(nn.Module):
    """Context gating: rescale each input feature by a learned sigmoid gate.

    Port of the TensorFlow LOUPE implementation:
    https://github.com/antoine77340/LOUPE

    Args:
        dim: size of the feature vector; the gate matrix is (dim, dim).
        add_batch_norm: if true, the gate logits go through a norm layer;
            otherwise a learned bias is added to them.
        normalization: ``'instance'`` selects ``nn.LayerNorm``, ``'group'``
            selects ``nn.GroupNorm`` with 8 groups, any other value selects
            ``nn.BatchNorm1d``. Only consulted when ``add_batch_norm`` is true.
    """

    def __init__(self, dim, add_batch_norm=True, normalization='batch'):
        super(GatingContext, self).__init__()
        self.dim = dim
        self.add_batch_norm = add_batch_norm

        # Weights and bias share the same 1 / sqrt(dim) initial scale.
        root = math.sqrt(dim)
        self.gating_weights = nn.Parameter(torch.randn(dim, dim) / root)
        self.sigmoid = nn.Sigmoid()

        if not add_batch_norm:
            self.gating_biases = nn.Parameter(torch.randn(dim) / root)
            self.bn1 = None
            return

        self.gating_biases = None
        if normalization == 'instance':
            self.bn1 = nn.LayerNorm(dim)
        elif normalization == 'group':
            self.bn1 = nn.GroupNorm(8, dim)
        else:
            self.bn1 = nn.BatchNorm1d(dim)

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its sigmoid gates.

        The gate logits are ``x @ gating_weights``, passed through the norm
        layer or shifted by the learned bias depending on ``add_batch_norm``.
        """
        logits = x @ self.gating_weights
        if self.add_batch_norm:
            logits = self.bn1(logits)
        else:
            logits = logits + self.gating_biases
        return x * self.sigmoid(logits)