init fusion lcd orin config

This commit is contained in:
MobKBK
2026-03-04 20:07:57 +08:00
commit bc0498e453
42 changed files with 4750 additions and 0 deletions

432
train.py Normal file
View File

@@ -0,0 +1,432 @@
import argparse
import os
import time
import numpy as np
import torch
import torch.optim as optim
import yaml
import net
import tools
from dataset import KittiTotalLoader
from evaluate_lcd import lcd
from loss import TotalLoss
test_step = 10 # 保存测试点的步长
def save_checkpoint(model, optimizer, loss_total_fun, epoch, iter_train, path_result):
if (epoch + 1) % test_step == 0 and epoch+1>=test_step:
time_now = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
checkpoint = {'time': time_now,
'epoch': epoch,
'model': model.state_dict(),
'optimizer': optimizer.state_dict()}
savepath = tools.make_save_path(path_result, 'models')
torch.save(checkpoint, savepath + '/checkpoint_%03d.pth.tar' % epoch)
print(savepath + '/checkpoint_%03d.pth.tar is saved' % epoch)
class log_result():
def __init__(self,path_result):
self.path=path_result
if not os.path.exists(path_result):
with open(path_result, 'w') as file:
file.write('Time Sequence Epoch AP R100 F1 R@1 R@2 R@3 R@4 R@5')
file.write(' R@6 R@7 R@8 R@9 R@10 R@15 R@20 R@25\n')
for i in range(300):
file.write('\n')
def write(self,seq,epoch,row,x):
with open(self.path, 'r') as file:
lines = file.readlines()
time_now = time.strftime('%Y%m%d%H%M%S', time.localtime(time.time()))
new_content='%s %08d %06d'%(time_now,seq,epoch)
for x1 in x:
new_content=new_content + ' %.3f'%x1
lines[row] = new_content+'\n'
with open(self.path, 'w') as file:
file.writelines(lines)
def train(model, optimizer, loss_total_fun, data, device):
model.train()
sequences = data['sequence']
id_query = data['id_query']
id_positive = data['id_positive']
batchsize = len(id_query)
bev_query = data['bev_query'].to(device)
bev_positive = data['bev_positive'].to(device)
pose_query = data['pose_query'].to(device)
pose_positive = data['pose_positive'].to(device)
pose_to_frame = data['pose_to_frame'].to(device)
label_score = data['label_score'].to(device)
img_query = data['img_query'].to(device)
img_positive = data['img_positive'].to(device)
try:
bev = torch.cat([bev_query, bev_positive], dim=0)
bev = bev.permute(0, 3, 1, 2)
except:
bev = 0
try:
img = torch.cat([img_query, img_positive], dim=0)
img = img.permute(0, 3, 1, 2)
except:
img = 0
try:
relation = data['relation'].to(device)
except:
relation = 0
batch_dict = {'bev': bev,
'label_score': label_score,
'img': img,
'relation': relation,
'id_query': id_query,
'sequence': sequences,
'id_positive': id_positive,
'pose_to_frame': pose_to_frame,
'pose_query': pose_query,
'pose_positive': pose_positive,
'batch_size': int(batchsize * 2)}
model(batch_dict)
loss_total_fun(batch_dict)
l_total = batch_dict['loss'][0]
optimizer.zero_grad()
l_total.backward()
optimizer.step()
for p in model.parameters():
if torch.isnan(p).any():
print('Model NAN, ', p.shape)
exit()
return batch_dict
def validate(model, loss_total_fun, data, device):
model.eval()
with torch.no_grad():
sequences = data['sequence']
id_query = data['id_query']
id_positive = data['id_positive']
batchsize = len(id_query)
bev_query = data['bev_query'].to(device)
bev_positive = data['bev_positive'].to(device)
pose_query = data['pose_query'].to(device)
pose_positive = data['pose_positive'].to(device)
pose_to_frame = data['pose_to_frame'].to(device)
label_score = data['label_score'].to(device)
img_query = data['img_query'].to(device)
img_positive = data['img_positive'].to(device)
try:
bev = torch.cat([bev_query, bev_positive], dim=0)
bev = bev.permute(0, 3, 1, 2)
except:
bev = 0
try:
img = torch.cat([img_query, img_positive], dim=0)
img = img.permute(0, 3, 1, 2)
except:
img = 0
try:
relation = data['relation'].to(device)
except:
relation = 0
batch_dict = {'bev': bev,
'label_score': label_score,
'img': img,
'relation': relation,
'id_query': id_query,
'sequence': sequences,
'id_positive': id_positive,
'pose_to_frame': pose_to_frame,
'pose_query': pose_query,
'pose_positive': pose_positive,
'batch_size': int(batchsize * 2)}
model(batch_dict)
loss_total_fun(batch_dict)
return batch_dict
def test(model, data, device):
model.eval()
with torch.no_grad():
sequences = data['sequence']
id_query = data['id_query']
batchsize = len(id_query)
bev_query = data['bev_query'].to(device)
pose_query = data['pose_query'].to(device)
img_query = data['img_query'].to(device)
try:
bev = bev_query
bev = bev.permute(0, 3, 1, 2)
except:
bev = 0
try:
img = img_query
img = img.permute(0, 3, 1, 2)
except:
img = 0
try:
relation = data['relation'].to(device)
except:
relation = 0
batch_dict = {'bev': bev,
'img': img,
'relation': relation,
'id_query': id_query,
'sequence': sequences,
'pose_query': pose_query,
'batch_size': int(batchsize * 2)}
model(batch_dict)
return batch_dict
def main(args):
try:
with open(os.path.join(os.getcwd(), "config.yaml"), "r") as ymlfile:
cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
print('Loading config file from %s' % os.path.join(os.getcwd(), "config.yaml"))
except:
with open(os.path.join(os.getcwd(), "project/FUSIONLCD/config.yaml"), "r") as ymlfile:
cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader)
print('Loading config file from %s' % os.path.join(os.getcwd(), "project/BevNvLcd/config.yaml"))
cfg = cfg['experiment']
for k, v in cfg.items():
print(k, ':', v)
path_result = os.path.join(cfg['path_result'],args.result_name)
lres=log_result(os.path.join(os.getcwd(),'result',args.result_name+'.txt'))
device = torch.device("cuda" if torch.cuda.is_available() and cfg['cuda'] else "cpu")
start_epoch = 0
iter_train = 0
epochs = cfg['epochs']
model = net.Fusion(cfg)
print(model)
model = model.to(device)
loss_total_fun = TotalLoss(cfg).to(device)
print("Model params: %.6fM" % (sum(p.numel() for p in model.parameters()) / 1e6))
optimizer = optim.Adam(model.parameters(), lr=cfg['learning_rate'], betas=(cfg['beta1'], cfg['beta2']), eps=cfg['eps'], weight_decay=cfg['weight_decay'])
# optimizer = optim.Adam([{'params': model.bev.parameters(), 'lr': 0.0002},
# {'params': model.img.parameters(), 'lr': 0.0001},
# {'params': model.vlad_fusion_layer.parameters(), 'lr': 0.0001}],
# betas=(cfg['beta1'], cfg['beta2']), eps=cfg['eps'], weight_decay=cfg['weight_decay'])
# print(optimizer)
loader_train, loader_val, loader_test = KittiTotalLoader(cfg)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[1, 3, 5, 10, 50, 100], gamma=0.5, last_epoch=start_epoch - 1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.99)
# scheduler = warmup(optimizer, 5, 1e-6, cfg['learning_rate'])
# writer = SummaryWriter(tools.make_save_path(path_result, 'tensorboard_log'))
t = tools.Timer()
test_best = np.zeros([len(loader_test.dataset.datasets), 3])
if cfg['load_model']:
checkpoint = torch.load((cfg['last_model']))
start_epoch = checkpoint['epoch'] + 1 * cfg['train_flag']
state_dict_saved = checkpoint['model']
model.load_state_dict(state_dict_saved)
optimizer.load_state_dict(checkpoint['optimizer'])
print('loaded %s' % cfg['last_model'])
if not cfg['train_flag']:
print_frequency = 1e9
else:
print_frequency = 1
for epoch in range(start_epoch, epochs):
torch.cuda.empty_cache()
'''
============================== train ===============================
'''
if cfg['train_flag']:
if epoch - start_epoch == 0:
pf = print_frequency
print_frequency = min(len(loader_train), print_frequency * 10)
else:
print_frequency = pf
l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10,l11 = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
step_print = max(1, int(len(loader_train) / print_frequency))
step_now = 0
optimizer.zero_grad()
for id_sample, data in enumerate(loader_train):
batch_dict = train(model, optimizer, loss_total_fun, data, device)
# if (id_sample+1)%4==0 or (id_sample+1)==len(loader_train):
# optimizer.step()
# optimizer.zero_grad()
if step_now < step_print:
step_now = step_now + 1
l0 = l0 + batch_dict['loss'][0]
l1 = l1 + batch_dict['loss'][1]
l2 = l2 + batch_dict['loss'][2]
l3 = l3 + batch_dict['loss'][3]
l4 = l4 + batch_dict['loss'][4]
l5 = l5 + batch_dict['loss'][5]
l6 = l6 + batch_dict['loss'][6]
l7 = l7 + batch_dict['loss'][7]
l8 = l8 + batch_dict['loss'][8]
l9 = l9 + batch_dict['loss'][9]
l10 = l10 + batch_dict['loss'][10]
l11 = l11 + batch_dict['loss'][11]
if step_now == step_print:
step_now = 0
info = 'loss a%.3f p%.3f s%.3f m%.3f t%.3f tr%.3f_%.1f genb%.3f geni%.3f genpa%.3f genpo%.3f genkpl%.3f' % (
l0 / step_print, l1 / step_print, l2 / step_print, l3 / step_print,
l4 / step_print, l5 / step_print, l6 / step_print, l7 / step_print,
l8 / step_print, l9 / step_print, l10 / step_print, l11 / step_print)
t.update("Epoch %03d | train %04d/%04d | %s" %
(epoch, id_sample, len(loader_train) - 1, info))
l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10, l11 = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
save_checkpoint(model, optimizer, loss_total_fun, epoch, iter_train, path_result)
scheduler.step()
'''
============================= validate =============================
'''
if cfg['validate_flag'] and (epoch + 1) % test_step == 0:
l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10,l11 = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
step_print = max(1, int(len(loader_val) / print_frequency))
step_now = 0
for id_sample, data in enumerate(loader_val):
batch_dict = validate(model, loss_total_fun, data, device)
if step_now < step_print:
step_now = step_now + 1
l0 = l0 + batch_dict['loss'][0]
l1 = l1 + batch_dict['loss'][1]
l2 = l2 + batch_dict['loss'][2]
l3 = l3 + batch_dict['loss'][3]
l4 = l4 + batch_dict['loss'][4]
l5 = l5 + batch_dict['loss'][5]
l6 = l6 + batch_dict['loss'][6]
l7 = l7 + batch_dict['loss'][7]
l8 = l8 + batch_dict['loss'][8]
l9 = l9 + batch_dict['loss'][9]
l10 = l10 + batch_dict['loss'][10]
l11 = l11 + batch_dict['loss'][11]
if step_now == step_print:
step_now = 0
info = 'loss a%.3f p%.3f s%.3f m%.3f t%.3f tr%.3f_%.1f genb%.3f geni%.3f genpa%.3f genpo%.3f genkpl%.3f' % (
l0 / step_print, l1 / step_print, l2 / step_print, l3 / step_print,
l4 / step_print, l5 / step_print, l6 / step_print, l7 / step_print,
l8 / step_print, l9 / step_print, l10 / step_print, l11 / step_print)
t.update("Epoch %03d | validate %04d/%04d | %s" %
(epoch, id_sample, len(loader_val) - 1, info))
l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10,l11 = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
'''
============================== test ================================
'''
if cfg['test_flag'] and (epoch + 1) % test_step == 0:
step_print = max(1, int(len(loader_test) / print_frequency))
step_now = 0
vlads = []
kpts = []
feas_original = []
feas_fusion = []
sequences = []
poses = []
for id_sample, data in enumerate(loader_test):
batch_dict = test(model, data, device)
# save_figure(batch_dict, epoch, path_result, cfg)
sequences.append((batch_dict['sequence']).detach().cpu())
vlads.append(batch_dict['vlads'].detach().cpu())
poses.append(batch_dict['pose_query'].detach().cpu())
kpts.append(batch_dict['key_points'].detach().cpu())
if 'fea_kpt_fusion' in batch_dict.keys():
feas_fusion.append(batch_dict['fea_kpt_fusion'].detach().cpu().permute(0, 2, 1))
feas_original.append(batch_dict['fea_kpt_original'].detach().cpu().permute(0, 2, 1))
if step_now < step_print:
step_now = step_now + 1
if step_now == step_print:
step_now = 0
t.update("Epoch %03d | test %05d/%05d" % (epoch, id_sample, len(loader_test)))
vlads = torch.cat(vlads)
kpts = torch.cat(kpts)
feas_original = torch.cat(feas_original)
if 'fea_kpt_fusion' in batch_dict.keys():
feas_fusion = torch.cat(feas_fusion)
else:
feas_fusion=feas_original
poses = torch.cat(poses)
sequences = torch.cat(sequences)
database = {'vlads': vlads,
'key_points': kpts,
'fea_kpt_original': feas_original,
'fea_kpt_fusion': feas_fusion,
'fea_kpt': feas_fusion,
'sequences': sequences,
'pose_query': poses}
savepath = tools.make_save_path(path_result, 'database')
torch.save(database, savepath + '/database_bevp.pth.tar')
# print('save ' + savepath + '/database_%03d.pth.tar' % epoch)
# exit()
# database = torch.load('/data4/caodanyang/results/FUSIONLCD/07250/database/database_159.pth.tar')
print()
print('***************************************************************************************************************************************')
print('Epoch %03d' % epoch)
# feature_match(loader_val,database)
result,recall_at_k = lcd(database)
seq = torch.unique(sequences)
for i in range((test_best.shape[0])):
recall_at_k1=recall_at_k[i]
for j in range((test_best.shape[1])):
test_best[i, j] = max([test_best[i, j], result[i][j]])
print('Best, sequence %02d, AP=%.3f, R100=%.3f, F1=%.3f' % (seq[i], test_best[i, 0], test_best[i, 1],test_best[i, 2]))
lres.write(seq[i],epoch,(epoch+1)//test_step+i*(epochs)//test_step,
[result[i][0],result[i][1],result[i][2],recall_at_k1[0],recall_at_k1[1],recall_at_k1[2],recall_at_k1[3],recall_at_k1[4],
recall_at_k1[5],recall_at_k1[6],recall_at_k1[7],recall_at_k1[8],recall_at_k1[9],recall_at_k1[14],recall_at_k1[19],recall_at_k1[24]
])
# print('Sequence %02d, AP=%.3f[%.3f], R100=%.3f[%.3f], F1=%.3f[%.3f], Recall@1[%.3f] 2[%.3f] 5[%.3f] 10[%.3f] 15[%.3f] 25[%.3f]' %
# (sequences[i], result[i][0], test_best[i, 0], result[i][1], test_best[i, 1], result[i][2], test_best[i, 2],
# recall_at_k1[0],recall_at_k1[1],recall_at_k1[4],recall_at_k1[9],recall_at_k1[14],recall_at_k1[24]))
print('***************************************************************************************************************************************')
print()
if cfg['train_flag']:
pass
else:
exit()
# exit()
if __name__ == '__main__':
# CUDA_VISIBLE_DEVICES=2 nohup python -u train.py --result_name=08280 --info=cosim >log/08280.log 2>&1 &
# fuser /dev/nvidia*
parser = argparse.ArgumentParser()
parser.add_argument('--result_name', type=str, default='log', help='log name of result')
parser.add_argument('--pro_name', type=str, default='python', help='name of process')
parser.add_argument('--info', type=str, default='python', help='name of process')
parser.add_argument('--gpu', type=str, default=None, help="GPU id(s), e.g. '0' or '0,1'. Use 'cpu' to force CPU.")
args = parser.parse_args()
# set visible GPUs before any CUDA call / seed
if args.gpu:
if args.gpu.lower() == 'cpu':
# force CPU by hiding GPUs
os.environ['CUDA_VISIBLE_DEVICES'] = ''
else:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
np.random.seed(123)
torch.manual_seed(123)
# only call cuda seed if CUDA visible
if torch.cuda.is_available():
torch.cuda.manual_seed(123)
print(args.info)
try:
print("Using GPU device:", os.environ.get("CUDA_VISIBLE_DEVICES", ""))
except:
pass
main(args)