diff --git a/Cross_Fusion_of_Point_Cloud_and_Learned_Image_for_Loop_Closure_Detection.pdf b/Cross_Fusion_of_Point_Cloud_and_Learned_Image_for_Loop_Closure_Detection.pdf new file mode 100644 index 0000000..02f20a3 Binary files /dev/null and b/Cross_Fusion_of_Point_Cloud_and_Learned_Image_for_Loop_Closure_Detection.pdf differ diff --git a/auto_eval_checkpoints.py b/auto_eval_checkpoints.py new file mode 100644 index 0000000..869bb83 --- /dev/null +++ b/auto_eval_checkpoints.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +自动评估 FUSIONLCD 多个 checkpoint 的脚本 + +用法示例: +python auto_eval_checkpoints.py \ + --project_dir /home/adlab36/chenyouyuan/FUSIONLCD \ + --config /home/adlab36/chenyouyuan/FUSIONLCD/config.yaml \ + --train_script /home/adlab36/chenyouyuan/FUSIONLCD/train.py \ + --models_dir /home/adlab36/chenyouyuan/FUSIONLCD/result/log/models \ + --result_name auto_eval \ + --gpu 1 + +说明: +1. 会备份原 config.yaml 为 config.yaml.bak_auto_eval +2. 每个 checkpoint 测试前会把 config 改成: + - train_flag = 0 + - validate_flag = 0 + - test_flag = 1 + - load_model = 1 + - last_model = 当前 checkpoint +3. 每测完一个 checkpoint,会读取 result/.txt 追加的新结果 +4. 最终输出 summary.csv +""" + +from __future__ import annotations + +import argparse +import csv +import os +import re +import shutil +import subprocess +import sys +import time +from pathlib import Path +from typing import Dict, List, Tuple, Optional + +import yaml + + +CKPT_RE = re.compile(r"checkpoint_(\d+)\.pth\.tar$") +RESULT_LINE_RE = re.compile(r"^\d{14}\s+(\d+)\s+(\d+)\s+(.*)$") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--project_dir", type=str, required=True, help="项目根目录") + parser.add_argument("--config", type=str, required=True, help="config.yaml 路径") + parser.add_argument("--train_script", type=str, required=True, help="train.py 路径") + parser.add_argument("--models_dir", type=str, required=True, help="checkpoint 目录") + parser.add_argument("--result_name", type=str, default="auto_eval", help="train.py 的 result_name") + parser.add_argument("--gpu", type=str, default="0", help="GPU id,例如 0 或 1") + parser.add_argument("--epochs_filter", type=str, default="", help="只测试指定 epoch,逗号分隔,如 99,109,119") + parser.add_argument("--min_epoch", type=int, default=None, help="最小 epoch 过滤") + parser.add_argument("--max_epoch", type=int, default=None, help="最大 epoch 过滤") + parser.add_argument("--sleep_sec", type=float, default=1.0, help="每次测试后等待秒数") + return parser.parse_args() + + +def load_yaml(path: Path) -> dict: + with path.open("r", encoding="utf-8") as f: + return yaml.safe_load(f) + + +def save_yaml(path: Path, data: dict) -> None: + with path.open("w", encoding="utf-8") as f: + yaml.safe_dump(data, f, allow_unicode=True, sort_keys=False) + + +def list_checkpoints(models_dir: Path) -> List[Tuple[int, Path]]: + ckpts: List[Tuple[int, Path]] = [] + for p in models_dir.glob("checkpoint_*.pth.tar"): + m = CKPT_RE.match(p.name) + if m: + ckpts.append((int(m.group(1)), p)) + ckpts.sort(key=lambda x: x[0]) + return ckpts + + +def filter_checkpoints( + ckpts: List[Tuple[int, Path]], + epochs_filter: str, + min_epoch: Optional[int], + max_epoch: Optional[int], +) -> List[Tuple[int, Path]]: + selected = ckpts + + if epochs_filter.strip(): + wanted = {int(x.strip()) for x in epochs_filter.split(",") if x.strip()} + selected = [(e, p) for e, p in selected if e in wanted] + + if min_epoch is not None: + selected = [(e, p) for e, p in selected if e >= min_epoch] + + if max_epoch is not None: + selected = [(e, p) for e, p in selected if e <= max_epoch] + + return selected + + +def result_txt_path(project_dir: Path, result_name: str) -> Path: + return project_dir / "result" / f"{result_name}.txt" + + +def read_result_lines(path: Path) -> List[str]: + if not path.exists(): + return [] + with path.open("r", encoding="utf-8") as f: + return [line.rstrip("\n") for line in f.readlines()] + + +def parse_result_file(path: Path) -> List[Dict]: + rows: List[Dict] = [] + if not path.exists(): + return rows + + with path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line or line.startswith("Time"): + continue + + m = RESULT_LINE_RE.match(line) + if not m: + continue + + seq = int(m.group(1)) + epoch = int(m.group(2)) + rest = m.group(3).split() + + # 表头来自你的 log_result: + # AP R100 F1 R@1 R@2 R@3 R@4 R@5 R@6 R@7 R@8 R@9 R@10 R@15 R@20 R@25 + if len(rest) < 16: + continue + + vals = list(map(float, rest[:16])) + rows.append( + { + "seq": seq, + "epoch": epoch, + "AP": vals[0], + "R100": vals[1], + "F1": vals[2], + "R@1": vals[3], + "R@2": vals[4], + "R@3": vals[5], + "R@4": vals[6], + "R@5": vals[7], + "R@6": vals[8], + "R@7": vals[9], + "R@8": vals[10], + "R@9": vals[11], + "R@10": vals[12], + "R@15": vals[13], + "R@20": vals[14], + "R@25": vals[15], + "raw": line, + } + ) + return rows + + +def overwrite_test_config(config_path: Path, ckpt_path: Path) -> None: + cfg = load_yaml(config_path) + exp = cfg["experiment"] + + exp["train_flag"] = 0 + exp["validate_flag"] = 0 + exp["test_flag"] = 1 + exp["load_model"] = 1 + exp["last_model"] = str(ckpt_path) + + save_yaml(config_path, cfg) + + +def run_one_eval(project_dir: Path, train_script: Path, result_name: str, gpu: str) -> int: + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = gpu + + cmd = [ + sys.executable, + str(train_script), + "--result_name", + result_name, + "--gpu", + gpu, + "--info", + "auto_eval", + ] + + proc = subprocess.run( + cmd, + cwd=str(project_dir), + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + + print(proc.stdout) + return proc.returncode + + +def collect_epoch_rows(all_rows: List[Dict], epoch: int) -> List[Dict]: + return [r for r in all_rows if r["epoch"] == epoch] + + +def aggregate_rows(rows: List[Dict]) -> Dict[str, float]: + if not rows: + return {} + keys = ["AP", "R100", "F1", "R@1", "R@5", "R@10", "R@25"] + out = {} + for k in keys: + out[f"mean_{k}"] = sum(r[k] for r in rows) / len(rows) + return out + + +def save_summary_csv(path: Path, summary: List[Dict]) -> None: + if not summary: + return + fieldnames = list(summary[0].keys()) + with path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(summary) + + +def main() -> None: + args = parse_args() + + project_dir = Path(args.project_dir).resolve() + config_path = Path(args.config).resolve() + train_script = Path(args.train_script).resolve() + models_dir = Path(args.models_dir).resolve() + + if not project_dir.exists(): + raise FileNotFoundError(f"project_dir 不存在: {project_dir}") + if not config_path.exists(): + raise FileNotFoundError(f"config 不存在: {config_path}") + if not train_script.exists(): + raise FileNotFoundError(f"train_script 不存在: {train_script}") + if not models_dir.exists(): + raise FileNotFoundError(f"models_dir 不存在: {models_dir}") + + ckpts = list_checkpoints(models_dir) + ckpts = filter_checkpoints(ckpts, args.epochs_filter, args.min_epoch, args.max_epoch) + + if not ckpts: + raise RuntimeError("没有找到符合条件的 checkpoint") + + backup_path = config_path.with_suffix(config_path.suffix + ".bak_auto_eval") + shutil.copy2(config_path, backup_path) + print(f"[INFO] 已备份配置到: {backup_path}") + + result_txt = result_txt_path(project_dir, args.result_name) + summary_rows: List[Dict] = [] + + try: + for epoch, ckpt in ckpts: + print("=" * 100) + print(f"[INFO] 开始测试 checkpoint: epoch={epoch}, path={ckpt}") + print("=" * 100) + + overwrite_test_config(config_path, ckpt) + + ret = run_one_eval(project_dir, train_script, args.result_name, args.gpu) + if ret != 0: + print(f"[WARN] checkpoint {epoch} 测试失败,返回码 {ret}") + continue + + time.sleep(args.sleep_sec) + + parsed = parse_result_file(result_txt) + epoch_rows = collect_epoch_rows(parsed, epoch) + + if not epoch_rows: + print(f"[WARN] 没有在结果文件中找到 epoch={epoch} 的记录") + continue + + agg = aggregate_rows(epoch_rows) + row = { + "epoch": epoch, + "checkpoint": str(ckpt), + **agg, + } + summary_rows.append(row) + + print(f"[INFO] epoch={epoch} 汇总: {row}") + + summary_csv = project_dir / "result" / f"{args.result_name}_summary.csv" + save_summary_csv(summary_csv, summary_rows) + print(f"[INFO] 汇总结果已保存到: {summary_csv}") + + if summary_rows: + best_by_ap = max(summary_rows, key=lambda x: x.get("mean_AP", float("-inf"))) + print("\n[INFO] 最佳 checkpoint(按 mean_AP):") + print(best_by_ap) + + finally: + shutil.copy2(backup_path, config_path) + print(f"[INFO] 已恢复原始配置: {config_path}") + + +if __name__ == "__main__": + main() \ No newline at end of file