init fusion lcd orin config
This commit is contained in:
393
BEVNet.py
Normal file
393
BEVNet.py
Normal file
@@ -0,0 +1,393 @@
|
||||
import math
|
||||
import torch
|
||||
import torch._utils
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Callable
|
||||
from torchvision.models import resnet
|
||||
|
||||
|
||||
class RIConv2d(nn.Module):
|
||||
def __init__(self, in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=True):
|
||||
super().__init__()
|
||||
self.padding = padding
|
||||
self.stride = stride
|
||||
self.use_bias = bias
|
||||
idx = torch.arange(kernel_size ** 2).view(-1, 1)
|
||||
row = torch.div(idx, kernel_size, rounding_mode='floor')
|
||||
col = torch.fmod(idx, kernel_size)
|
||||
idx = torch.cat([row, col], dim=1)
|
||||
dis = (idx - 0.5 * (kernel_size - 1)).norm(dim=1) + 0.5 * (kernel_size % 2 - 1)
|
||||
dis = dis.view(kernel_size, kernel_size)
|
||||
dis = torch.round(dis).long()
|
||||
dis[dis > 0.5 * (kernel_size - 1)] = -1
|
||||
self.mask = dis
|
||||
self.number = int(torch.max(dis).item() + 1)
|
||||
self.weight = torch.zeros([kernel_size, kernel_size, out_channel, in_channel])
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.rand([out_channel, ]))
|
||||
else:
|
||||
self.bias = None
|
||||
self.weight1 = torch.nn.Parameter(torch.rand([self.number, out_channel, in_channel]))
|
||||
|
||||
def forward(self, x):
|
||||
weight = self.weight.to(self.weight1.device)
|
||||
for i in range(self.number):
|
||||
mask = self.mask == i
|
||||
weight[mask] = self.weight1[i]
|
||||
weight = weight.permute(2, 3, 0, 1)
|
||||
y = torch.nn.functional.conv2d(x, weight, self.bias, self.stride, self.padding)
|
||||
return y
|
||||
|
||||
def __repr__(self):
|
||||
return f"RIConv2d(in_channel={self.weight.shape[3]}, out_channel={self.weight.shape[2]}," \
|
||||
f" kernel_size={self.weight.shape[0]}, stride={self.stride}, padding={self.padding}, bias={self.bias is not None})"
|
||||
|
||||
|
||||
class RIMaxpool2d(nn.Module):
|
||||
def __init__(self, kernel_size=1, stride=1, padding=0):
|
||||
super().__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
idx = torch.arange(kernel_size ** 2).view(-1, 1)
|
||||
row = torch.div(idx, kernel_size, rounding_mode='floor')
|
||||
col = torch.fmod(idx, kernel_size)
|
||||
idx = torch.cat([row, col], dim=1)
|
||||
dis = (idx - 0.5 * (kernel_size - 1)).norm(dim=1) + 0.5 * (kernel_size % 2 - 1)
|
||||
dis = dis.view(kernel_size, kernel_size)
|
||||
dis = torch.round(dis)
|
||||
dis[dis > 0.5 * (kernel_size - 1)] = -1
|
||||
self.mask = dis.view(-1, ) > -1
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
h_out = math.floor((H + 2 * self.padding - (self.kernel_size - 1) - 1) / self.stride + 1)
|
||||
w_out = math.floor((W + 2 * self.padding - (self.kernel_size - 1) - 1) / self.stride + 1)
|
||||
unfold_x = torch.nn.functional.unfold(x, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
|
||||
y = unfold_x.view(B, C, self.kernel_size * self.kernel_size, h_out, w_out)
|
||||
y = y.permute(2, 0, 1, 3, 4)
|
||||
y1 = y[self.mask]
|
||||
y_max = torch.max(y1, dim=0, keepdim=False)[0]
|
||||
return y_max
|
||||
|
||||
def __repr__(self):
|
||||
return f"RIMaxpool2d(kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding})"
|
||||
|
||||
|
||||
class RIAvgpool2d(nn.Module):
|
||||
def __init__(self, kernel_size=1, stride=1, padding=0):
|
||||
super().__init__()
|
||||
self.padding = padding
|
||||
self.stride = stride
|
||||
idx = torch.arange(kernel_size ** 2).view(-1, 1)
|
||||
row = torch.div(idx, kernel_size, rounding_mode='floor')
|
||||
col = torch.fmod(idx, kernel_size)
|
||||
idx = torch.cat([row, col], dim=1)
|
||||
dis = (idx - 0.5 * (kernel_size - 1)).norm(dim=1) + 0.5 * (kernel_size % 2 - 1)
|
||||
dis = dis.view(kernel_size, kernel_size)
|
||||
dis = torch.round(dis)
|
||||
dis[dis > 0.5 * (kernel_size - 1)] = -1
|
||||
mask = dis > -1
|
||||
self.number = torch.sum(mask)
|
||||
self.weight = torch.zeros([kernel_size, kernel_size, 1, 1])
|
||||
self.weight[mask] = 1
|
||||
|
||||
def forward(self, x):
|
||||
weight = self.weight.to(x.device)
|
||||
weight = weight.permute(2, 3, 0, 1)
|
||||
weight = weight.repeat(x.shape[1], 1, 1, 1)
|
||||
sum = torch.nn.functional.conv2d(x, weight, None, self.stride, self.padding, groups=x.shape[1])
|
||||
avg = sum / self.number
|
||||
return avg
|
||||
|
||||
def __repr__(self):
|
||||
return f"RIAvgpool2d(kernel_size={self.weight.shape[0]}, stride={self.stride}, padding={self.padding})"
|
||||
|
||||
|
||||
class RIConvBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels,
|
||||
gate: Optional[Callable[..., nn.Module]] = None,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None):
|
||||
super().__init__()
|
||||
if gate is None:
|
||||
self.gate = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.gate = gate
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
self.conv1 = RIConv2d(in_channel=in_channels, out_channel=out_channels, kernel_size=5, padding=2, bias=False)
|
||||
self.bn1 = norm_layer(out_channels)
|
||||
self.conv2 = RIConv2d(in_channel=out_channels, out_channel=out_channels, kernel_size=5, padding=2, bias=False)
|
||||
self.bn2 = norm_layer(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W
|
||||
x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
|
||||
return x
|
||||
|
||||
class RIResBlock(nn.Module):
|
||||
expansion: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
gate: Optional[Callable[..., nn.Module]] = None,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(RIResBlock, self).__init__()
|
||||
if gate is None:
|
||||
self.gate = nn.ReLU(inplace=True)
|
||||
else:
|
||||
self.gate = gate
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm2d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('ResBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in ResBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = RIConv2d(in_channel=inplanes, out_channel=planes, kernel_size=5, stride=1, padding=2, bias=False)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.conv2 = RIConv2d(in_channel=planes, out_channel=planes, kernel_size=5, stride=1, padding=2, bias=False)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.gate(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.gate(out)
|
||||
|
||||
return out
|
||||
|
||||
class RICNN(nn.Module):
|
||||
def __init__(self, c1: int = 8, c2: int = 16, c3: int = 32, c4: int = 64, dim: int = 64
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.gate = nn.ReLU(inplace=True)
|
||||
self.pool2 = RIMaxpool2d(kernel_size=2, stride=2)
|
||||
self.pool4 = RIMaxpool2d(kernel_size=5, stride=4, padding=1)
|
||||
self.block1 = RIConvBlock(3, c1, self.gate, nn.BatchNorm2d)
|
||||
self.block2 = RIResBlock(inplanes=c1, planes=c2, stride=1,
|
||||
downsample=nn.Conv2d(c1, c2, 1),
|
||||
gate=self.gate,
|
||||
norm_layer=nn.BatchNorm2d)
|
||||
self.block3 = RIResBlock(inplanes=c2, planes=c3, stride=1,
|
||||
downsample=nn.Conv2d(c2, c3, 1),
|
||||
gate=self.gate,
|
||||
norm_layer=nn.BatchNorm2d)
|
||||
self.block4 = RIResBlock(inplanes=c3, planes=c4, stride=1,
|
||||
downsample=nn.Conv2d(c3, c4, 1),
|
||||
gate=self.gate,
|
||||
norm_layer=nn.BatchNorm2d)
|
||||
|
||||
self.conv1 = resnet.conv1x1(c1, dim // 4)
|
||||
self.conv2 = resnet.conv1x1(c2, dim // 4)
|
||||
self.conv3 = resnet.conv1x1(c3, dim // 4)
|
||||
self.conv4 = resnet.conv1x1(dim, dim // 4)
|
||||
self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
||||
self.upsample3 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
|
||||
self.upsample4 = nn.Upsample(scale_factor=32, mode='bilinear', align_corners=True)
|
||||
self.out = nn.Conv2d(dim, dim + 1, 1)
|
||||
|
||||
def forward(self, image):
|
||||
x1 = self.block1(image)
|
||||
x2 = self.pool2(x1)
|
||||
x2 = self.block2(x2)
|
||||
x3 = self.pool4(x2)
|
||||
x3 = self.block3(x3)
|
||||
x4 = self.pool4(x3)
|
||||
x4 = self.block4(x4)
|
||||
x1 = self.gate(self.conv1(x1))
|
||||
x2 = self.gate(self.conv2(x2))
|
||||
x3 = self.gate(self.conv3(x3))
|
||||
x4 = self.gate(self.conv4(x4))
|
||||
x2_up = self.upsample2(x2)
|
||||
x3_up = self.upsample3(x3)
|
||||
x4_up = self.upsample4(x4)
|
||||
x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
|
||||
y = self.out(x1234)
|
||||
descriptor_map = y[:, :-1, :, :]
|
||||
scores_map = torch.sigmoid(y[:, -1, :, :]).unsqueeze(1)
|
||||
return scores_map, descriptor_map
|
||||
|
||||
def ri2maxpool(self, pool):
|
||||
stride = pool.stride
|
||||
pool_new = nn.MaxPool2d(stride)
|
||||
return pool_new
|
||||
|
||||
def maxpool2ri(self, pool):
|
||||
kernel_size = stride = pool.stride
|
||||
ds = round((math.sqrt(2) - 1) / 2 * stride - 0.25 * (stride % 2 - 1))
|
||||
kernel_size = kernel_size + ds
|
||||
pool_new = RIMaxpool2d(kernel_size, stride, ds)
|
||||
return pool_new
|
||||
|
||||
def ri2avgpool(self, pool):
|
||||
stride = pool.stride
|
||||
pool_new = nn.AvgPool2d(stride)
|
||||
return pool_new
|
||||
|
||||
def avgpool2ri(self, pool):
|
||||
kernel_size = stride = pool.stride
|
||||
if stride > 3:
|
||||
kernel_size = kernel_size + 1
|
||||
pool_new = RIAvgpool2d(kernel_size, stride)
|
||||
return pool_new
|
||||
|
||||
def ri2conv(self, conv):
|
||||
ri = conv
|
||||
weight = ri.weight
|
||||
device = ri.weight1.device
|
||||
bias = ri.bias
|
||||
use_bias = bias is not None
|
||||
weight_copy = weight.clone().to(device)
|
||||
for i in range(ri.number):
|
||||
mask = ri.mask == i
|
||||
weight_copy[mask] = ri.weight1[i]
|
||||
weight_copy = weight_copy.permute(2, 3, 0, 1)
|
||||
in_c = weight.shape[3]
|
||||
out_c = weight.shape[2]
|
||||
kz = weight.shape[0]
|
||||
sd = ri.stride
|
||||
pd = ri.padding
|
||||
conv_new = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=kz, stride=sd, padding=pd, bias=use_bias)
|
||||
if use_bias:
|
||||
state_dict = {'weight': weight_copy, 'bias': bias}
|
||||
else:
|
||||
state_dict = {'weight': weight_copy}
|
||||
conv_new.load_state_dict(state_dict)
|
||||
return conv_new.to(device)
|
||||
|
||||
def conv2ri(self, conv):
|
||||
weight = conv.weight
|
||||
bias = conv.bias
|
||||
device = weight.device
|
||||
in_c = weight.shape[1]
|
||||
out_c = weight.shape[0]
|
||||
kz = weight.shape[2]
|
||||
if kz < 3:
|
||||
return conv
|
||||
sd = conv.stride
|
||||
pd = conv.padding
|
||||
idx = torch.arange(kz ** 2).view(-1, 1)
|
||||
row = torch.div(idx, kz, rounding_mode='floor')
|
||||
col = torch.fmod(idx, kz)
|
||||
idx = torch.cat([row, col], dim=1)
|
||||
dis = (idx - 0.5 * (kz - 1)).norm(dim=1) + 0.5 * (kz % 2 - 1)
|
||||
dis = dis.view(kz, kz)
|
||||
dis = torch.round(dis).long()
|
||||
dis[dis > 0.5 * (kz - 1)] = -1
|
||||
mask = dis
|
||||
number = int(torch.max(dis).item() + 1)
|
||||
weight1 = torch.rand([number, out_c, in_c]).to(device)
|
||||
weight2 = weight.clone()
|
||||
weight2 = weight2.permute(2, 3, 0, 1)
|
||||
used_bias = bias is not None
|
||||
for i in range(number):
|
||||
mask1 = mask == i
|
||||
w = weight2[mask1]
|
||||
weight1[i] = torch.mean(w, dim=0)
|
||||
if used_bias:
|
||||
state_dict = {'weight1': weight1, 'bias': bias}
|
||||
else:
|
||||
state_dict = {'weight1': weight1}
|
||||
conv_new = RIConv2d(in_channel=in_c, out_channel=out_c, kernel_size=kz, stride=sd, padding=pd, bias=used_bias)
|
||||
conv_new.load_state_dict(state_dict)
|
||||
return conv_new.to(device)
|
||||
|
||||
def disable_ri(self):
|
||||
modules = self.__dict__['_modules']
|
||||
for key, value in modules.items():
|
||||
if isinstance(value, RIMaxpool2d):
|
||||
setattr(self, key, self.ri2maxpool(value))
|
||||
if isinstance(value, RIAvgpool2d):
|
||||
setattr(self, key, self.ri2avgpool(value))
|
||||
if isinstance(value, RIConv2d):
|
||||
setattr(self, key, self.ri2conv(value))
|
||||
if 'block' in key:
|
||||
block = value
|
||||
block_modules = block.__dict__['_modules']
|
||||
for bkey, bvalue in block_modules.items():
|
||||
if isinstance(bvalue, RIMaxpool2d):
|
||||
setattr(block, bkey, self.ri2maxpool(bvalue))
|
||||
if isinstance(bvalue, RIAvgpool2d):
|
||||
setattr(block, bkey, self.ri2avgpool(bvalue))
|
||||
if isinstance(bvalue, RIConv2d):
|
||||
setattr(block, bkey, self.ri2conv(bvalue))
|
||||
modules[key] = block_modules
|
||||
setattr(self, key, block)
|
||||
|
||||
def enable_ri(self):
|
||||
modules = self.__dict__['_modules']
|
||||
for key, value in modules.items():
|
||||
if isinstance(value, nn.MaxPool2d):
|
||||
setattr(self, key, self.maxpool2ri(value))
|
||||
if isinstance(value, nn.AvgPool2d):
|
||||
setattr(self, key, self.avgpool2ri(value))
|
||||
if isinstance(value, nn.Conv2d):
|
||||
setattr(self, key, self.conv2ri(value))
|
||||
if 'block' in key:
|
||||
block = value
|
||||
block_modules = block.__dict__['_modules']
|
||||
for bkey, bvalue in block_modules.items():
|
||||
if isinstance(bvalue, nn.MaxPool2d):
|
||||
setattr(block, bkey, self.maxpool2ri(bvalue))
|
||||
if isinstance(bvalue, nn.AvgPool2d):
|
||||
setattr(block, bkey, self.avgpool2ri(bvalue))
|
||||
if isinstance(bvalue, nn.Conv2d):
|
||||
setattr(block, bkey, self.conv2ri(bvalue))
|
||||
modules[key] = block_modules
|
||||
setattr(self, key, block)
|
||||
|
||||
|
||||
class EncodePosition(nn.Module):
|
||||
def __init__(self, feature_size=128):
|
||||
super().__init__()
|
||||
self.bins = 16
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv1d(in_channels=self.bins, out_channels=feature_size//2, kernel_size=1, stride=1, padding=0, bias=True), nn.BatchNorm1d(feature_size//2), nn.ReLU(),
|
||||
nn.Conv1d(in_channels=feature_size//2, out_channels=feature_size//2, kernel_size=1, stride=1, padding=0, bias=True), nn.BatchNorm1d(feature_size//2), nn.ReLU(),
|
||||
nn.Conv1d(in_channels=feature_size//2, out_channels=feature_size, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
)
|
||||
# self.conv2=(nn.Conv1d(in_channels=256,out_channels=128,kernel_size=1))
|
||||
|
||||
def forward(self, x, fea):
|
||||
b, n, c = x.shape
|
||||
x1 = x.unsqueeze(1)
|
||||
x2 = x.unsqueeze(2)
|
||||
dx = x1 - x2
|
||||
distance = dx.norm(p=2, dim=3)
|
||||
hists = torch.zeros([b, n, self.bins]).to(x.device)
|
||||
for i in range(b):
|
||||
for j in range(n):
|
||||
dis = distance[i, j]
|
||||
hist = torch.histc(dis, bins=self.bins, min=1, max=80)
|
||||
hists[i, j] = hist
|
||||
hists = hists / torch.sum(hists, dim=2, keepdim=True)
|
||||
x3 = hists.permute(0, 2, 1)
|
||||
x4 = self.conv1(x3)
|
||||
if hasattr(self, 'conv2'):
|
||||
x5 = torch.cat([fea, x4], dim=1)
|
||||
y = self.conv2(x5)
|
||||
else:
|
||||
y = fea + x4
|
||||
return y
|
||||
Reference in New Issue
Block a user