test_model

This commit is contained in:
MobKBK
2026-04-11 09:06:04 +08:00
parent bc0498e453
commit 71199090c8
2 changed files with 308 additions and 0 deletions

308
auto_eval_checkpoints.py Normal file
View File

@@ -0,0 +1,308 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Automatically evaluate multiple FUSIONLCD checkpoints.

Usage example:
    python auto_eval_checkpoints.py \
        --project_dir /home/adlab36/chenyouyuan/FUSIONLCD \
        --config /home/adlab36/chenyouyuan/FUSIONLCD/config.yaml \
        --train_script /home/adlab36/chenyouyuan/FUSIONLCD/train.py \
        --models_dir /home/adlab36/chenyouyuan/FUSIONLCD/result/log/models \
        --result_name auto_eval \
        --gpu 1

Notes:
    1. The original config.yaml is backed up as config.yaml.bak_auto_eval.
    2. Before each checkpoint is tested, the config is rewritten so that:
        - train_flag = 0
        - validate_flag = 0
        - test_flag = 1
        - load_model = 1
        - last_model = the current checkpoint
    3. After each checkpoint has been tested, result/<result_name>.txt is
       re-read to collect the rows recorded for that checkpoint's epoch.
    4. Finally, result/<result_name>_summary.csv is written.
"""
from __future__ import annotations
import argparse
import csv
import os
import re
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import yaml
# Matches checkpoint file names ending in "checkpoint_<digits>.pth.tar";
# group 1 captures the digits (used as the epoch number by list_checkpoints).
CKPT_RE = re.compile(r"checkpoint_(\d+)\.pth\.tar$")
# Matches one result line: 14 leading digits (presumably a YYYYMMDDHHMMSS
# timestamp -- confirm against train.py), then two integer fields (read as
# seq and epoch by parse_result_file), then the remaining text in group 3.
RESULT_LINE_RE = re.compile(r"^\d{14}\s+(\d+)\s+(\d+)\s+(.*)$")
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the checkpoint evaluation script.

    Four path options are mandatory (``--project_dir``, ``--config``,
    ``--train_script``, ``--models_dir``); the remaining options have
    defaults.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    # Mandatory string options, registered in this order.
    required_options = (
        ("--project_dir", "项目根目录"),
        ("--config", "config.yaml 路径"),
        ("--train_script", "train.py 路径"),
        ("--models_dir", "checkpoint 目录"),
    )
    # Optional options with their argparse keyword arguments.
    optional_options = (
        ("--result_name", dict(type=str, default="auto_eval", help="train.py 的 result_name")),
        ("--gpu", dict(type=str, default="0", help="GPU id例如 0 或 1")),
        ("--epochs_filter", dict(type=str, default="", help="只测试指定 epoch逗号分隔如 99,109,119")),
        ("--min_epoch", dict(type=int, default=None, help="最小 epoch 过滤")),
        ("--max_epoch", dict(type=int, default=None, help="最大 epoch 过滤")),
        ("--sleep_sec", dict(type=float, default=1.0, help="每次测试后等待秒数")),
    )

    cli = argparse.ArgumentParser()
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    for flag, options in optional_options:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def load_yaml(path: Path) -> dict:
    """Read the UTF-8 YAML file at *path* and return the parsed document.

    The file is parsed with ``yaml.safe_load``; whatever that call returns
    is handed back unchanged.
    """
    with open(path, "r", encoding="utf-8") as stream:
        document = yaml.safe_load(stream)
    return document
def save_yaml(path: Path, data: dict) -> None:
    """Serialize *data* as YAML into the file at *path*, replacing its content.

    The file is written as UTF-8. Non-ASCII characters are emitted as-is
    (``allow_unicode=True``) and keys keep their insertion order
    (``sort_keys=False``).
    """
    dump_options = {"allow_unicode": True, "sort_keys": False}
    with open(path, "w", encoding="utf-8") as stream:
        yaml.safe_dump(data, stream, **dump_options)
def list_checkpoints(models_dir: Path) -> List[Tuple[int, Path]]:
    """Find the numbered checkpoint files directly inside *models_dir*.

    Files are discovered with the glob ``checkpoint_*.pth.tar``; only names
    accepted by ``CKPT_RE`` (a purely numeric index) are kept.

    Returns:
        ``(number, path)`` pairs sorted by ascending number.
    """
    candidates = (
        (CKPT_RE.match(entry.name), entry)
        for entry in models_dir.glob("checkpoint_*.pth.tar")
    )
    numbered = [(int(hit.group(1)), entry) for hit, entry in candidates if hit]
    return sorted(numbered, key=lambda pair: pair[0])
def filter_checkpoints(
    ckpts: List[Tuple[int, Path]],
    epochs_filter: str,
    min_epoch: Optional[int],
    max_epoch: Optional[int],
) -> List[Tuple[int, Path]]:
    """Restrict ``(epoch, path)`` pairs to the requested epochs.

    Args:
        ckpts: Candidate ``(epoch, path)`` pairs.
        epochs_filter: Comma-separated epoch numbers to keep; a blank string
            disables this filter. Empty items between commas are ignored.
        min_epoch: Inclusive lower bound, or ``None`` for no lower bound.
        max_epoch: Inclusive upper bound, or ``None`` for no upper bound.

    Returns:
        The pairs that satisfy every active filter, in their original order.
        When no filter is active, *ckpts* itself is returned.

    Raises:
        ValueError: If an item of *epochs_filter* is not an integer.
    """
    checks = []

    if epochs_filter.strip():
        pieces = (piece.strip() for piece in epochs_filter.split(","))
        allowed = {int(piece) for piece in pieces if piece}
        checks.append(lambda epoch: epoch in allowed)
    if min_epoch is not None:
        checks.append(lambda epoch: epoch >= min_epoch)
    if max_epoch is not None:
        checks.append(lambda epoch: epoch <= max_epoch)

    if not checks:
        return ckpts
    return [
        (epoch, location)
        for epoch, location in ckpts
        if all(check(epoch) for check in checks)
    ]
def result_txt_path(project_dir: Path, result_name: str) -> Path:
    """Return ``<project_dir>/result/<result_name>.txt``."""
    file_name = "{}.txt".format(result_name)
    return project_dir.joinpath("result", file_name)
def read_result_lines(path: Path) -> List[str]:
    """Return the lines of the UTF-8 text file at *path* without newlines.

    A path that does not exist yields an empty list. Only the trailing
    ``"\\n"`` is removed from each line; other whitespace is preserved.
    """
    if path.exists():
        with open(path, "r", encoding="utf-8") as handle:
            return [raw.rstrip("\n") for raw in handle]
    return []
def parse_result_file(path: Path) -> List[Dict]:
    """Parse the metric rows of a result text file.

    Each usable line must match ``RESULT_LINE_RE`` and carry at least 16
    whitespace-separated values after the two integer fields. Lines that are
    blank, start with ``"Time"``, do not match the pattern, have too few
    values, or contain a non-numeric value among the first 16 are skipped.

    Args:
        path: Location of the result file.

    Returns:
        One dict per parsed line, in file order, with the keys ``seq``,
        ``epoch``, the 16 metric names listed below (as floats) and ``raw``
        (the stripped source line). A missing file yields an empty list.
    """
    # Column order of the values that follow seq and epoch. Per the original
    # author's note, this mirrors the header written by the project's
    # log_result.
    metric_names = (
        "AP", "R100", "F1",
        "R@1", "R@2", "R@3", "R@4", "R@5",
        "R@6", "R@7", "R@8", "R@9", "R@10",
        "R@15", "R@20", "R@25",
    )
    rows: List[Dict] = []
    if not path.exists():
        return rows

    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("Time"):
                continue
            m = RESULT_LINE_RE.match(line)
            if not m:
                continue
            tokens = m.group(3).split()
            if len(tokens) < len(metric_names):
                continue
            try:
                vals = [float(tok) for tok in tokens[: len(metric_names)]]
            except ValueError:
                # A non-numeric metric makes the line malformed; skip it like
                # the other malformed lines instead of aborting the caller.
                continue
            row: Dict = {"seq": int(m.group(1)), "epoch": int(m.group(2))}
            row.update(zip(metric_names, vals))
            row["raw"] = line
            rows.append(row)
    return rows
def overwrite_test_config(config_path: Path, ckpt_path: Path) -> None:
    """Rewrite the config file in place so it tests *ckpt_path* only.

    Inside the ``experiment`` section, training and validation are switched
    off, testing and model loading are switched on, and ``last_model`` is
    set to the checkpoint path as a string. The whole document is then
    saved back to *config_path*.
    """
    config = load_yaml(config_path)
    config["experiment"].update(
        {
            "train_flag": 0,
            "validate_flag": 0,
            "test_flag": 1,
            "load_model": 1,
            "last_model": str(ckpt_path),
        }
    )
    save_yaml(config_path, config)
def run_one_eval(project_dir: Path, train_script: Path, result_name: str, gpu: str) -> int:
    """Run the training script once as a subprocess and return its exit code.

    The script is started with the current Python interpreter, with
    *project_dir* as working directory and ``CUDA_VISIBLE_DEVICES`` set to
    *gpu*. It receives ``--result_name``, ``--gpu`` and ``--info auto_eval``.
    Its combined stdout/stderr is captured and printed after it finishes.
    """
    child_env = dict(os.environ)
    child_env["CUDA_VISIBLE_DEVICES"] = gpu

    argv = [sys.executable, str(train_script)]
    argv += ["--result_name", result_name]
    argv += ["--gpu", gpu]
    argv += ["--info", "auto_eval"]

    completed = subprocess.run(
        argv,
        cwd=str(project_dir),
        env=child_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the captured stdout
        text=True,
    )
    print(completed.stdout)
    return completed.returncode
def collect_epoch_rows(all_rows: List[Dict], epoch: int) -> List[Dict]:
    """Return the rows whose ``"epoch"`` value equals *epoch*, in order."""
    return list(filter(lambda row: row["epoch"] == epoch, all_rows))
def aggregate_rows(rows: List[Dict]) -> Dict[str, float]:
    """Average a fixed set of metrics over *rows*.

    Returns:
        A dict mapping ``mean_<metric>`` to the arithmetic mean of that
        metric for AP, R100, F1, R@1, R@5, R@10 and R@25, or an empty dict
        when *rows* is empty.
    """
    if not rows:
        return {}
    count = len(rows)
    metrics = ("AP", "R100", "F1", "R@1", "R@5", "R@10", "R@25")
    return {
        "mean_" + metric: sum(row[metric] for row in rows) / count
        for metric in metrics
    }
def save_summary_csv(path: Path, summary: List[Dict]) -> None:
    """Write *summary* to *path* as a UTF-8 CSV file with a header row.

    The columns are the keys of the first entry, in their order. Nothing is
    written (and no file is created) when *summary* is empty.
    """
    if summary:
        columns = list(summary[0])
        with open(path, "w", newline="", encoding="utf-8") as handle:
            out = csv.DictWriter(handle, fieldnames=columns)
            out.writeheader()
            for record in summary:
                out.writerow(record)
def main() -> None:
    """Evaluate every selected checkpoint and write a per-epoch summary CSV.

    Steps:
        1. Resolve and validate the paths given on the command line.
        2. List the checkpoints in ``models_dir`` and apply the epoch filters.
        3. Back up the config file next to itself with the extra suffix
           ``.bak_auto_eval``.
        4. For each checkpoint: rewrite the config for testing, run the
           training script, then average the result rows that this run
           appended for the checkpoint's epoch.
        5. Write ``result/<result_name>_summary.csv`` and print the entry
           with the highest ``mean_AP``.
        6. Always restore the config file from the backup.

    Raises:
        FileNotFoundError: If one of the given paths does not exist.
        RuntimeError: If no checkpoint matches the filters.
    """
    args = parse_args()
    project_dir = Path(args.project_dir).resolve()
    config_path = Path(args.config).resolve()
    train_script = Path(args.train_script).resolve()
    models_dir = Path(args.models_dir).resolve()

    if not project_dir.exists():
        raise FileNotFoundError(f"project_dir 不存在: {project_dir}")
    if not config_path.exists():
        raise FileNotFoundError(f"config 不存在: {config_path}")
    if not train_script.exists():
        raise FileNotFoundError(f"train_script 不存在: {train_script}")
    if not models_dir.exists():
        raise FileNotFoundError(f"models_dir 不存在: {models_dir}")

    ckpts = list_checkpoints(models_dir)
    ckpts = filter_checkpoints(ckpts, args.epochs_filter, args.min_epoch, args.max_epoch)
    if not ckpts:
        raise RuntimeError("没有找到符合条件的 checkpoint")

    backup_path = config_path.with_suffix(config_path.suffix + ".bak_auto_eval")
    shutil.copy2(config_path, backup_path)
    print(f"[INFO] 已备份配置到: {backup_path}")

    result_txt = result_txt_path(project_dir, args.result_name)
    summary_rows: List[Dict] = []
    try:
        for epoch, ckpt in ckpts:
            print("=" * 100)
            print(f"[INFO] 开始测试 checkpoint: epoch={epoch}, path={ckpt}")
            print("=" * 100)
            overwrite_test_config(config_path, ckpt)

            # Remember how many rows exist before the run so that rows left
            # over from earlier runs are not averaged into this checkpoint.
            rows_before = len(parse_result_file(result_txt))

            ret = run_one_eval(project_dir, train_script, args.result_name, args.gpu)
            if ret != 0:
                print(f"[WARN] checkpoint {epoch} 测试失败,返回码 {ret}")
                continue
            time.sleep(args.sleep_sec)

            parsed = parse_result_file(result_txt)
            if len(parsed) >= rows_before:
                new_rows = parsed[rows_before:]
            else:
                # Fewer rows than before means the file was rewritten rather
                # than appended to, so everything in it stems from this run.
                new_rows = parsed
            epoch_rows = collect_epoch_rows(new_rows, epoch)
            if not epoch_rows:
                print(f"[WARN] 没有在结果文件中找到 epoch={epoch} 的记录")
                continue

            agg = aggregate_rows(epoch_rows)
            row = {
                "epoch": epoch,
                "checkpoint": str(ckpt),
                **agg,
            }
            summary_rows.append(row)
            print(f"[INFO] epoch={epoch} 汇总: {row}")

        summary_csv = project_dir / "result" / f"{args.result_name}_summary.csv"
        if summary_rows:
            save_summary_csv(summary_csv, summary_rows)
            print(f"[INFO] 汇总结果已保存到: {summary_csv}")
            best_by_ap = max(summary_rows, key=lambda x: x.get("mean_AP", float("-inf")))
            print("\n[INFO] 最佳 checkpoint按 mean_AP:")
            print(best_by_ap)
        else:
            # save_summary_csv writes nothing for an empty summary, so do not
            # claim that a file was saved.
            print(f"[WARN] 没有可汇总的结果,未生成汇总文件: {summary_csv}")
    finally:
        # Put the original config back even if an evaluation raised.
        shutil.copy2(backup_path, config_path)
        print(f"[INFO] 已恢复原始配置: {config_path}")
# Run the evaluation sweep only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()