Files
fusion_LCD/auto_eval_checkpoints.py
2026-04-11 14:12:20 +08:00

377 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
自动评估 FUSIONLCD 多个 checkpoint 的脚本
支持:
1. 单卡串行测试:
--gpu 0
2. 多卡并行测试:
--gpu 0,1,2,3
python auto_eval_checkpoints.py \
--project_dir /home/adlab36/chenyouyuan/FUSIONLCD \
--config /home/adlab36/chenyouyuan/FUSIONLCD/config.yaml \
--train_script /home/adlab36/chenyouyuan/FUSIONLCD/train.py \
--models_dir /home/adlab36/chenyouyuan/FUSIONLCD/result/log/models \
--result_name auto_eval \
--gpu 2,3 \
--epochs_filter 119,139
(--epochs_filter 的值用逗号分隔且不要带空格,否则 shell 会把它拆成多个参数;可选 epoch 例如 099,119,139,159,179,199)
说明:
- 多卡模式下,每个 checkpoint 会分配到一个 GPU
- 每个子进程使用独立的临时工作目录和独立的 config.yaml,以避免冲突
- 会实时输出子进程日志
"""
from __future__ import annotations

import argparse
import csv
import os
import queue
import re
import shutil
import subprocess
import sys
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import yaml

# Matches checkpoint file names such as "checkpoint_119.pth.tar"; group 1 is the
# epoch number (used with .match on the bare file name).
CKPT_RE = re.compile(r"checkpoint_(\d+)\.pth\.tar$")
# Matches one data line of the result txt written by train.py: a 14-digit leading
# field (presumably a YYYYMMDDhhmmss timestamp -- verify against train.py), then
# group 1 = sequence id, group 2 = epoch, group 3 = the remaining whitespace-
# separated metric columns.
RESULT_LINE_RE = re.compile(r"^\d{14}\s+(\d+)\s+(\d+)\s+(.*)$")
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Required: --project_dir, --config, --train_script, --models_dir.
    Optional: --result_name, --gpu, --epochs_filter, --min_epoch, --max_epoch,
    --sleep_sec.
    """
    ap = argparse.ArgumentParser()

    # Mandatory path arguments (registered in this order so --help lists them first).
    for flag, text in (
        ("--project_dir", "项目根目录"),
        ("--config", "config.yaml 路径"),
        ("--train_script", "train.py 路径"),
        ("--models_dir", "checkpoint 目录"),
    ):
        ap.add_argument(flag, type=str, required=True, help=text)

    # Options controlling which checkpoints run and how.
    ap.add_argument("--result_name", type=str, default="auto_eval", help="train.py 的 result_name")
    ap.add_argument("--gpu", type=str, default="0", help="GPU id例如 0 或 0,1,2,3")
    ap.add_argument("--epochs_filter", type=str, default="", help="只测试指定 epoch逗号分隔如 99,109,119")
    for flag, text in (("--min_epoch", "最小 epoch 过滤"), ("--max_epoch", "最大 epoch 过滤")):
        ap.add_argument(flag, type=int, default=None, help=text)
    ap.add_argument("--sleep_sec", type=float, default=1.0, help="每次测试后等待秒数")

    return ap.parse_args()
def load_yaml(path: Path) -> dict:
    """Read the YAML file at *path* (UTF-8) with ``yaml.safe_load``.

    An empty or comment-only file yields ``{}`` rather than ``None``, so callers
    always get a mapping as the return annotation promises.
    """
    with path.open("r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load returns None when the document is empty.
    return {} if data is None else data
def save_yaml(path: Path, data: dict) -> None:
    """Serialize *data* to *path* as UTF-8 YAML.

    Keys keep their insertion order (``sort_keys=False``) and non-ASCII text is
    written as-is (``allow_unicode=True``).
    """
    out = path.open("w", encoding="utf-8")
    with out:
        yaml.safe_dump(data, stream=out, sort_keys=False, allow_unicode=True)
def list_checkpoints(models_dir: Path) -> List[Tuple[int, Path]]:
    """Return ``(epoch, path)`` pairs for ``checkpoint_<N>.pth.tar`` files in *models_dir*.

    Only names fully matched by CKPT_RE are kept. The result is ordered by epoch;
    equal epochs keep the order in which ``glob`` yielded them (stable sort).
    """
    matched = (
        (CKPT_RE.match(candidate.name), candidate)
        for candidate in models_dir.glob("checkpoint_*.pth.tar")
    )
    return sorted(
        ((int(hit.group(1)), candidate) for hit, candidate in matched if hit),
        key=lambda pair: pair[0],
    )
def filter_checkpoints(
    ckpts: List[Tuple[int, Path]],
    epochs_filter: str,
    min_epoch: Optional[int],
    max_epoch: Optional[int],
) -> List[Tuple[int, Path]]:
    """Keep the ``(epoch, path)`` pairs whose epoch passes every active filter.

    Args:
        ckpts: candidate checkpoints; their relative order is preserved.
        epochs_filter: comma-separated epoch list; ignored when blank. A non-blank
            string containing no numbers (e.g. ",") selects nothing. Tokens are
            parsed with ``int`` after stripping, so a bad token raises ValueError.
        min_epoch: inclusive lower bound, or None for no bound.
        max_epoch: inclusive upper bound, or None for no bound.
    """
    allowed: Optional[set] = None
    if epochs_filter.strip():
        allowed = {int(tok.strip()) for tok in epochs_filter.split(",") if tok.strip()}

    def _keep(epoch: int) -> bool:
        if allowed is not None and epoch not in allowed:
            return False
        if min_epoch is not None and epoch < min_epoch:
            return False
        if max_epoch is not None and epoch > max_epoch:
            return False
        return True

    return [(epoch, path) for epoch, path in ckpts if _keep(epoch)]
def parse_result_file(path: Path) -> List[Dict]:
    """Parse the result txt written by train.py into one dict per data line.

    A missing file yields ``[]``. Lines are skipped when they are blank, start
    with "Time" (header), do not match RESULT_LINE_RE, have fewer than 16 metric
    columns, or have a non-numeric value among their first 16 metric columns.

    Each returned dict has, in this order: ``seq`` and ``epoch`` (ints), the 16
    metric columns as floats keyed AP, R100, F1, R@1..R@10, R@15, R@20, R@25,
    and ``raw`` (the stripped source line).
    """
    metric_names = (
        "AP", "R100", "F1",
        "R@1", "R@2", "R@3", "R@4", "R@5",
        "R@6", "R@7", "R@8", "R@9", "R@10",
        "R@15", "R@20", "R@25",
    )
    rows: List[Dict] = []
    if not path.exists():
        return rows
    with path.open("r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith("Time"):
                continue
            m = RESULT_LINE_RE.match(line)
            if not m:
                continue
            rest = m.group(3).split()
            if len(rest) < len(metric_names):
                continue
            try:
                values = [float(tok) for tok in rest[: len(metric_names)]]
            except ValueError:
                # A non-numeric metric means this is not a data row we understand;
                # skip it like any other malformed line rather than aborting the
                # whole parse (and, with it, the checkpoint's summary).
                continue
            row: Dict = {"seq": int(m.group(1)), "epoch": int(m.group(2))}
            row.update(zip(metric_names, values))
            row["raw"] = line
            rows.append(row)
    return rows
def make_eval_config(base_config_path: Path, ckpt_path: Path, result_name: str, temp_config_path: Path) -> None:
    """Write an evaluation-only copy of the base config to *temp_config_path*.

    In the ``experiment`` section, train_flag and validate_flag become 0,
    test_flag and load_model become 1, and last_model is set to *ckpt_path*.
    All other settings (including path_result) are copied unchanged.
    *result_name* is accepted for interface compatibility but not used here.
    """
    config = load_yaml(base_config_path)
    experiment = config["experiment"]
    overrides = (
        ("train_flag", 0),
        ("validate_flag", 0),
        ("test_flag", 1),
        ("load_model", 1),
        ("last_model", str(ckpt_path)),
    )
    for key, value in overrides:
        experiment[key] = value
    save_yaml(temp_config_path, config)
def run_one_eval(
    work_dir: Path,
    train_script: Path,
    result_name: str,
    gpu: str,
    tag: str,
) -> int:
    """Run *train_script* once in *work_dir* and stream its output, prefixed with *tag*.

    The child runs under the current interpreter with CUDA_VISIBLE_DEVICES=*gpu*.
    It is passed ``--result_name``, ``--gpu`` and ``--info auto_eval_<tag>``, and
    stderr is merged into stdout.

    Returns:
        The child's exit code.

    Raises:
        KeyboardInterrupt: re-raised after terminating (or, if it does not exit
            within 5 s, killing) the child.
    """
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = gpu
    # NOTE(review): the same id is also passed as --gpu while CUDA_VISIBLE_DEVICES
    # restricts the child to that single device; if train.py uses --gpu as a CUDA
    # device index (rather than re-exporting it), ids other than 0 may be invalid
    # there -- confirm against train.py.
    cmd = [
        sys.executable,
        str(train_script),
        "--result_name",
        result_name,
        "--gpu",
        gpu,
        "--info",
        f"auto_eval_{tag}",
    ]
    print(f"[INFO][{tag}] Running command: {' '.join(cmd)}")
    print(f"[INFO][{tag}] CUDA_VISIBLE_DEVICES={gpu}")
    print(f"[INFO][{tag}] cwd={work_dir}")
    # The context manager closes the stdout pipe and reaps the child on exit.
    # errors="replace" keeps streaming alive if the child emits bytes that are
    # not valid in the locale encoding.
    with subprocess.Popen(
        cmd,
        cwd=str(work_dir),
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        errors="replace",
        bufsize=1,
    ) as proc:
        try:
            assert proc.stdout is not None  # guaranteed by stdout=PIPE
            for line in proc.stdout:
                print(f"[{tag}] {line}", end="")
            return proc.wait()
        except KeyboardInterrupt:
            print(f"\n[WARN][{tag}] 收到 Ctrl+C正在终止当前测试子进程...")
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                print(f"[WARN][{tag}] 子进程未及时退出,强制 kill")
                proc.kill()
                proc.wait()
            raise
def collect_epoch_rows(all_rows: List[Dict], epoch: int) -> List[Dict]:
    """Return, in their original order, the rows whose ``"epoch"`` equals *epoch*."""
    return list(filter(lambda row: row["epoch"] == epoch, all_rows))
def aggregate_rows(rows: List[Dict]) -> Dict[str, float]:
    """Average a fixed set of metric columns over *rows*.

    Returns ``{"mean_<metric>": average}`` for AP, R100, F1, R@1, R@5, R@10 and
    R@25 (in that order), or ``{}`` when *rows* is empty.
    """
    if not rows:
        return {}
    count = len(rows)
    reported = ("AP", "R100", "F1", "R@1", "R@5", "R@10", "R@25")
    return {f"mean_{metric}": sum(r[metric] for r in rows) / count for metric in reported}
def save_summary_csv(path: Path, summary: List[Dict]) -> None:
    """Write *summary* rows to *path* as a UTF-8 CSV with a header line.

    Does nothing (no file is created) when *summary* is empty. Missing parent
    directories are created. The header is the union of every row's keys in
    first-seen order, so rows with extra or missing keys are written (missing
    cells left blank) instead of making ``csv.DictWriter`` raise ValueError.
    """
    if not summary:
        return
    # dict.fromkeys gives an insertion-ordered, de-duplicated key union.
    fieldnames = list(dict.fromkeys(key for row in summary for key in row))
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(summary)
def run_single_checkpoint(
    epoch: int,
    ckpt: Path,
    gpu: str,
    args: argparse.Namespace,
    project_dir: Path,
    train_script: Path,
) -> Optional[Dict]:
    """Evaluate one checkpoint inside a throw-away working directory.

    The steps are:

    1. A fresh temp directory receives an eval-mode ``config.yaml`` (built by
       make_eval_config) and an empty ``result/`` folder.
    2. train.py is run there via run_one_eval.
    3. After ``args.sleep_sec`` seconds, the rows for *epoch* in
       ``result/<result_name>.txt`` are averaged with aggregate_rows.

    The temp directory is always removed.

    Returns:
        ``{"epoch", "checkpoint", "gpu", **mean_metrics}``, or None when train.py
        exits non-zero or the result file holds no row for *epoch*.
        *project_dir* is accepted for interface compatibility but not used here.
    """
    tag = f"gpu{gpu}_ep{epoch}"
    sandbox = Path(tempfile.mkdtemp(prefix=f"auto_eval_{tag}_"))
    try:
        # Per the original author's notes, train.py reads cwd/config.yaml first
        # and writes its txt to cwd/result/<result_name>.txt.
        make_eval_config(Path(args.config), ckpt, args.result_name, sandbox / "config.yaml")
        result_dir = sandbox / "result"
        result_dir.mkdir(parents=True, exist_ok=True)

        exit_code = run_one_eval(
            work_dir=sandbox,
            train_script=train_script,
            result_name=args.result_name,
            gpu=gpu,
            tag=tag,
        )
        if exit_code != 0:
            print(f"[WARN][{tag}] checkpoint {epoch} 测试失败,返回码 {exit_code}")
            return None

        time.sleep(args.sleep_sec)
        parsed_rows = parse_result_file(result_dir / f"{args.result_name}.txt")
        epoch_rows = collect_epoch_rows(parsed_rows, epoch)
        if not epoch_rows:
            print(f"[WARN][{tag}] 没有在结果文件中找到 epoch={epoch} 的记录")
            return None

        summary: Dict = {"epoch": epoch, "checkpoint": str(ckpt), "gpu": gpu}
        summary.update(aggregate_rows(epoch_rows))
        print(f"[INFO][{tag}] 汇总: {summary}")
        return summary
    finally:
        shutil.rmtree(sandbox, ignore_errors=True)
def _evaluate_or_warn(
    epoch: int,
    ckpt: Path,
    gpu: str,
    args: argparse.Namespace,
    project_dir: Path,
    train_script: Path,
) -> Optional[Dict]:
    """Run one checkpoint; on an unexpected error, log it and return None so the others still run."""
    try:
        return run_single_checkpoint(epoch, ckpt, gpu, args, project_dir, train_script)
    except Exception as exc:  # KeyboardInterrupt is not an Exception and still propagates
        print(f"[WARN] checkpoint {epoch} (gpu={gpu}) 测试出现异常: {exc!r}")
        return None


def _run_serial(
    ckpts: List[Tuple[int, Path]],
    gpu: str,
    args: argparse.Namespace,
    project_dir: Path,
    train_script: Path,
) -> List[Dict]:
    """Evaluate checkpoints one after another on a single GPU."""
    rows: List[Dict] = []
    for epoch, ckpt in ckpts:
        print("=" * 100)
        print(f"[INFO] 开始测试 checkpoint: epoch={epoch}, path={ckpt}, gpu={gpu}")
        print("=" * 100)
        row = _evaluate_or_warn(epoch, ckpt, gpu, args, project_dir, train_script)
        if row is not None:
            rows.append(row)
    return rows


def _run_parallel(
    ckpts: List[Tuple[int, Path]],
    gpu_list: List[str],
    args: argparse.Namespace,
    project_dir: Path,
    train_script: Path,
) -> List[Dict]:
    """Evaluate checkpoints concurrently, with at most one running job per listed GPU."""
    # Each job borrows a GPU id from this pool and returns it when finished.
    # Pre-assigning GPUs round-robin (idx % n) is not enough: the thread pool
    # hands the next queued job to whichever worker frees up first, which can
    # start two jobs on the same GPU while another GPU sits idle.
    free_gpus: queue.Queue = queue.Queue()
    for gpu in gpu_list:
        free_gpus.put(gpu)

    def _job(epoch: int, ckpt: Path) -> Optional[Dict]:
        gpu = free_gpus.get()
        try:
            print(f"[INFO] 开始测试 checkpoint: epoch={epoch}, path={ckpt}, gpu={gpu}")
            return _evaluate_or_warn(epoch, ckpt, gpu, args, project_dir, train_script)
        finally:
            free_gpus.put(gpu)

    rows: List[Dict] = []
    # max_workers == number of GPU slots, so get() above never waits long.
    with ThreadPoolExecutor(max_workers=len(gpu_list)) as ex:
        futures = [ex.submit(_job, epoch, ckpt) for epoch, ckpt in ckpts]
        for fut in as_completed(futures):
            row = fut.result()
            if row is not None:
                rows.append(row)
    return rows


def main() -> None:
    """Entry point: evaluate the selected checkpoints and write a summary CSV.

    Validates the input paths, selects checkpoints (list_checkpoints +
    filter_checkpoints), evaluates them serially for one GPU or in parallel for
    several, writes ``<project_dir>/result/<result_name>_summary.csv`` sorted by
    epoch, and prints the row with the highest mean_AP.

    Raises:
        FileNotFoundError: a required input path does not exist.
        RuntimeError: no checkpoint matches the filters, or --gpu lists no ids.
    """
    args = parse_args()
    project_dir = Path(args.project_dir).resolve()
    train_script = Path(args.train_script).resolve()
    models_dir = Path(args.models_dir).resolve()

    for label, candidate, shown in (
        ("project_dir", project_dir, project_dir),
        ("config", Path(args.config), args.config),
        ("train_script", train_script, train_script),
        ("models_dir", models_dir, models_dir),
    ):
        if not candidate.exists():
            raise FileNotFoundError(f"{label} 不存在: {shown}")

    ckpts = filter_checkpoints(
        list_checkpoints(models_dir), args.epochs_filter, args.min_epoch, args.max_epoch
    )
    if not ckpts:
        raise RuntimeError("没有找到符合条件的 checkpoint")

    gpu_list = [x.strip() for x in args.gpu.split(",") if x.strip()]
    if not gpu_list:
        raise RuntimeError("没有可用 GPU 参数")

    print(f"[INFO] 使用 GPU 列表: {gpu_list}")
    print(f"[INFO] 待测试 checkpoint 数量: {len(ckpts)}")

    if len(gpu_list) == 1:
        summary_rows = _run_serial(ckpts, gpu_list[0], args, project_dir, train_script)
    else:
        summary_rows = _run_parallel(ckpts, gpu_list, args, project_dir, train_script)
    summary_rows.sort(key=lambda x: x["epoch"])

    if not summary_rows:
        # Nothing succeeded: do not claim a summary file was written.
        print("[WARN] 没有任何 checkpoint 测试成功,未生成汇总文件")
        return

    summary_csv = project_dir / "result" / f"{args.result_name}_summary.csv"
    summary_csv.parent.mkdir(parents=True, exist_ok=True)
    save_summary_csv(summary_csv, summary_rows)
    print(f"[INFO] 汇总结果已保存到: {summary_csv}")

    best_by_ap = max(summary_rows, key=lambda x: x.get("mean_AP", float("-inf")))
    print("\n[INFO] 最佳 checkpoint按 mean_AP:")
    print(best_by_ap)
# Run the evaluator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()