#!/usr/bin/env python3
# -*- coding: utf-8 -*-
r"""
Automatically evaluate multiple FUSIONLCD checkpoints.

Example:
    python auto_eval_checkpoints.py \
        --project_dir /home/adlab36/chenyouyuan/FUSIONLCD \
        --config /home/adlab36/chenyouyuan/FUSIONLCD/config.yaml \
        --train_script /home/adlab36/chenyouyuan/FUSIONLCD/train.py \
        --models_dir /home/adlab36/chenyouyuan/FUSIONLCD/result/log/models \
        --result_name auto_eval \
        --gpu 1

Notes:
    1. The original config.yaml is backed up as config.yaml.bak_auto_eval and
       copied back when the script exits (normally or through an exception).
       If a backup with that name already exists (e.g. left behind by a run
       that was killed), it is first moved aside with a timestamp suffix so it
       is never overwritten.
    2. Before each checkpoint is evaluated, the ``experiment`` section of the
       config is rewritten to:
         - train_flag = 0
         - validate_flag = 0
         - test_flag = 1
         - load_model = 1
         - last_model = <current checkpoint>
    3. After each evaluation only the lines newly appended to
       result/<result_name>.txt are parsed, so rows left by earlier runs with
       the same result_name are not mixed into the summary.
    4. result/<result_name>_summary.csv is rewritten after every successful
       checkpoint, so partial results survive an interrupted run.
"""

from __future__ import annotations

import argparse
import csv
import os
import re
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

try:
    import yaml
except ImportError:  # pragma: no cover - depends on the environment
    # Defer the failure to the first config read/write so the pure parsing
    # and filtering helpers (and ``--help``) work without PyYAML installed.
    yaml = None

CKPT_RE = re.compile(r"checkpoint_(\d+)\.pth\.tar$")
# <14-digit value in the "Time" column> <seq> <epoch> <metric values...>
RESULT_LINE_RE = re.compile(r"^\d{14}\s+(\d+)\s+(\d+)\s+(.*)$")

# Metric columns of a result line, in file order (header written by the
# training script's log_result):
# AP R100 F1 R@1 R@2 R@3 R@4 R@5 R@6 R@7 R@8 R@9 R@10 R@15 R@20 R@25
METRIC_KEYS: Tuple[str, ...] = (
    "AP", "R100", "F1",
    "R@1", "R@2", "R@3", "R@4", "R@5", "R@6", "R@7", "R@8", "R@9", "R@10",
    "R@15", "R@20", "R@25",
)
# Subset of metrics averaged into the per-checkpoint summary.
SUMMARY_KEYS: Tuple[str, ...] = ("AP", "R100", "F1", "R@1", "R@5", "R@10", "R@25")


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation driver."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--project_dir", type=str, required=True, help="项目根目录")
    parser.add_argument("--config", type=str, required=True, help="config.yaml 路径")
    parser.add_argument("--train_script", type=str, required=True, help="train.py 路径")
    parser.add_argument("--models_dir", type=str, required=True, help="checkpoint 目录")
    parser.add_argument("--result_name", type=str, default="auto_eval", help="train.py 的 result_name")
    parser.add_argument("--gpu", type=str, default="0", help="GPU id,例如 0 或 1")
    parser.add_argument("--epochs_filter", type=str, default="", help="只测试指定 epoch,逗号分隔,如 99,109,119")
    parser.add_argument("--min_epoch", type=int, default=None, help="最小 epoch 过滤")
    parser.add_argument("--max_epoch", type=int, default=None, help="最大 epoch 过滤")
    parser.add_argument("--sleep_sec", type=float, default=1.0, help="每次测试后等待秒数")
    return parser.parse_args()


def _require_yaml() -> None:
    """Raise an actionable ImportError when PyYAML is not installed."""
    if yaml is None:
        raise ImportError("需要 PyYAML 才能读写配置文件: pip install pyyaml")


def load_yaml(path: Path) -> dict:
    """Load the YAML file at ``path`` with ``yaml.safe_load``."""
    _require_yaml()
    with path.open("r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def save_yaml(path: Path, data: dict) -> None:
    """Write ``data`` to ``path`` as YAML, keeping key order and non-ASCII text."""
    _require_yaml()
    with path.open("w", encoding="utf-8") as f:
        yaml.safe_dump(data, f, allow_unicode=True, sort_keys=False)


def list_checkpoints(models_dir: Path) -> List[Tuple[int, Path]]:
    """Return ``(epoch, path)`` for every ``checkpoint_<N>.pth.tar``, sorted by epoch."""
    ckpts: List[Tuple[int, Path]] = []
    for p in models_dir.glob("checkpoint_*.pth.tar"):
        m = CKPT_RE.match(p.name)
        # The glob also matches names like checkpoint_best.pth.tar; the regex
        # keeps only purely numeric epochs.
        if m:
            ckpts.append((int(m.group(1)), p))
    ckpts.sort(key=lambda x: x[0])
    return ckpts


def filter_checkpoints(
    ckpts: List[Tuple[int, Path]],
    epochs_filter: str,
    min_epoch: Optional[int],
    max_epoch: Optional[int],
) -> List[Tuple[int, Path]]:
    """Keep checkpoints whose epoch passes every given filter.

    Args:
        ckpts: ``(epoch, path)`` pairs.
        epochs_filter: comma-separated whitelist of epochs; blank means "all".
        min_epoch / max_epoch: inclusive bounds; ``None`` disables the bound.

    Raises:
        ValueError: if ``epochs_filter`` contains a non-integer entry.
    """
    selected = ckpts
    if epochs_filter.strip():
        wanted = {int(x.strip()) for x in epochs_filter.split(",") if x.strip()}
        selected = [(e, p) for e, p in selected if e in wanted]
    if min_epoch is not None:
        selected = [(e, p) for e, p in selected if e >= min_epoch]
    if max_epoch is not None:
        selected = [(e, p) for e, p in selected if e <= max_epoch]
    return selected


def result_txt_path(project_dir: Path, result_name: str) -> Path:
    """Path of the results text file that train.py appends to."""
    return project_dir / "result" / f"{result_name}.txt"


def read_result_lines(path: Path) -> List[str]:
    """Return the lines of ``path`` without trailing newlines ([] if missing)."""
    if not path.exists():
        return []
    with path.open("r", encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def new_result_lines(before: List[str], after: List[str]) -> List[str]:
    """Return the lines added to a result file between two snapshots.

    If ``after`` still starts with ``before``, the file was appended to and
    only the tail is new. Otherwise the file was rewritten or truncated during
    the run, and all of ``after`` is treated as new.
    """
    n = len(before)
    if len(after) >= n and after[:n] == before:
        return after[n:]
    return list(after)


def parse_result_lines(lines: Iterable[str]) -> List[Dict]:
    """Parse result-file lines into dicts keyed by seq/epoch/metrics/raw.

    Blank lines, the "Time..." header, lines not matching ``RESULT_LINE_RE``,
    lines with fewer than ``len(METRIC_KEYS)`` values and lines with
    non-numeric metric values are skipped.
    """
    n_metrics = len(METRIC_KEYS)
    rows: List[Dict] = []
    for line in lines:
        line = line.strip()
        if not line or line.startswith("Time"):
            continue
        m = RESULT_LINE_RE.match(line)
        if not m:
            continue
        rest = m.group(3).split()
        if len(rest) < n_metrics:
            continue
        try:
            vals = [float(x) for x in rest[:n_metrics]]
        except ValueError:
            # A malformed value should drop this line, not abort the parse.
            continue
        rows.append(
            {
                "seq": int(m.group(1)),
                "epoch": int(m.group(2)),
                **dict(zip(METRIC_KEYS, vals)),
                "raw": line,
            }
        )
    return rows


def parse_result_file(path: Path) -> List[Dict]:
    """Parse every result row in ``path`` (see ``parse_result_lines``)."""
    return parse_result_lines(read_result_lines(path))


def overwrite_test_config(config_path: Path, ckpt_path: Path) -> None:
    """Rewrite the config in place so train.py only tests ``ckpt_path``.

    The file is re-dumped with ``yaml.safe_dump``, so comments and layout of
    the original are not preserved; main() copies the backup back at the end.

    Raises:
        KeyError: if the config has no ``experiment`` mapping.
    """
    cfg = load_yaml(config_path)
    exp = cfg.get("experiment") if isinstance(cfg, dict) else None
    if not isinstance(exp, dict):
        raise KeyError(f"{config_path} 中缺少 experiment 配置段")
    exp["train_flag"] = 0
    exp["validate_flag"] = 0
    exp["test_flag"] = 1
    exp["load_model"] = 1
    exp["last_model"] = str(ckpt_path)
    save_yaml(config_path, cfg)


def run_one_eval(project_dir: Path, train_script: Path, result_name: str, gpu: str) -> int:
    """Run train.py once and return its exit code.

    The child's stdout and stderr are merged and streamed live to this
    process's stdout instead of being buffered until the run finishes.
    """
    env = os.environ.copy()
    # NOTE(review): CUDA_VISIBLE_DEVICES=<gpu> leaves a single visible device
    # (index 0 inside the child). If train.py also uses --gpu as a device
    # index (e.g. "cuda:<gpu>"), a non-zero id would name an invisible
    # device -- confirm how train.py consumes --gpu.
    env["CUDA_VISIBLE_DEVICES"] = gpu
    cmd = [
        sys.executable,
        str(train_script),
        "--result_name", result_name,
        "--gpu", gpu,
        "--info", "auto_eval",
    ]
    # Flush our own buffered output so it is not interleaved after the child's.
    sys.stdout.flush()
    proc = subprocess.run(
        cmd,
        cwd=str(project_dir),
        env=env,
        stderr=subprocess.STDOUT,  # merge into the inherited stdout
        check=False,
    )
    return proc.returncode


def collect_epoch_rows(all_rows: List[Dict], epoch: int) -> List[Dict]:
    """Return the rows whose ``epoch`` field equals ``epoch``."""
    return [r for r in all_rows if r["epoch"] == epoch]


def aggregate_rows(rows: List[Dict]) -> Dict[str, float]:
    """Average each metric in ``SUMMARY_KEYS`` over ``rows`` ({} if empty)."""
    if not rows:
        return {}
    return {f"mean_{k}": sum(r[k] for r in rows) / len(rows) for k in SUMMARY_KEYS}


def save_summary_csv(path: Path, summary: List[Dict]) -> None:
    """Write ``summary`` to ``path`` as CSV; does nothing when it is empty.

    The column order is taken from the keys of the first row.
    """
    if not summary:
        return
    fieldnames = list(summary[0].keys())
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(summary)


def _backup_config(config_path: Path) -> Path:
    """Copy the config to ``<name>.bak_auto_eval`` without clobbering an old backup."""
    backup_path = config_path.with_name(config_path.name + ".bak_auto_eval")
    if backup_path.exists():
        # A leftover backup may be the only copy of the user's real config
        # (e.g. a previous run was killed before it could restore it).
        stale = backup_path.with_name(f"{backup_path.name}.{time.strftime('%Y%m%d%H%M%S')}")
        shutil.move(str(backup_path), str(stale))
        print(f"[WARN] 发现已存在的备份(可能来自上次中断的运行),已另存为: {stale}")
    shutil.copy2(config_path, backup_path)
    print(f"[INFO] 已备份配置到: {backup_path}")
    return backup_path


def main() -> None:
    """Evaluate each selected checkpoint and write a per-epoch summary CSV.

    Raises:
        FileNotFoundError: if one of the given paths does not exist.
        RuntimeError: if no checkpoint passes the filters.
    """
    args = parse_args()

    project_dir = Path(args.project_dir).resolve()
    config_path = Path(args.config).resolve()
    train_script = Path(args.train_script).resolve()
    models_dir = Path(args.models_dir).resolve()

    for label, path in (
        ("project_dir", project_dir),
        ("config", config_path),
        ("train_script", train_script),
        ("models_dir", models_dir),
    ):
        if not path.exists():
            raise FileNotFoundError(f"{label} 不存在: {path}")

    ckpts = filter_checkpoints(
        list_checkpoints(models_dir), args.epochs_filter, args.min_epoch, args.max_epoch
    )
    if not ckpts:
        raise RuntimeError("没有找到符合条件的 checkpoint")

    backup_path = _backup_config(config_path)

    result_txt = result_txt_path(project_dir, args.result_name)
    summary_csv = project_dir / "result" / f"{args.result_name}_summary.csv"
    summary_rows: List[Dict] = []

    try:
        for epoch, ckpt in ckpts:
            print("=" * 100)
            print(f"[INFO] 开始测试 checkpoint: epoch={epoch}, path={ckpt}")
            print("=" * 100)

            overwrite_test_config(config_path, ckpt)
            # Snapshot the results file so only this run's rows are used.
            lines_before = read_result_lines(result_txt)
            ret = run_one_eval(project_dir, train_script, args.result_name, args.gpu)

            if ret != 0:
                print(f"[WARN] checkpoint {epoch} 测试失败,返回码 {ret}")
                continue

            time.sleep(args.sleep_sec)

            appended = new_result_lines(lines_before, read_result_lines(result_txt))
            new_rows = parse_result_lines(appended)
            epoch_rows = collect_epoch_rows(new_rows, epoch)

            if not epoch_rows:
                if new_rows:
                    seen = sorted({r["epoch"] for r in new_rows})
                    print(f"[WARN] 新增结果中没有 epoch={epoch} 的记录(新增记录的 epoch: {seen})")
                else:
                    print(f"[WARN] 没有在结果文件中找到 epoch={epoch} 的记录")
                continue

            row = {
                "epoch": epoch,
                "checkpoint": str(ckpt),
                **aggregate_rows(epoch_rows),
            }
            summary_rows.append(row)
            print(f"[INFO] epoch={epoch} 汇总: {row}")
            # Persist after every checkpoint so an interruption keeps partial results.
            save_summary_csv(summary_csv, summary_rows)

        if summary_rows:
            print(f"[INFO] 汇总结果已保存到: {summary_csv}")
            best_by_ap = max(summary_rows, key=lambda x: x.get("mean_AP", float("-inf")))
            print("\n[INFO] 最佳 checkpoint(按 mean_AP):")
            print(best_by_ap)
        else:
            print("[WARN] 没有任何 checkpoint 产生可用结果,未生成汇总文件")

    finally:
        shutil.copy2(backup_path, config_path)
        print(f"[INFO] 已恢复原始配置: {config_path}")


if __name__ == "__main__":
    main()