377 lines
11 KiB
Python
377 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
Automatically evaluate multiple FUSIONLCD checkpoints.

Supported modes:
1. Serial evaluation on a single GPU:
   --gpu 0
2. Parallel evaluation on multiple GPUs:
   --gpu 0,1,2,3

Example:
python auto_eval_checkpoints.py \
    --project_dir /home/adlab36/chenyouyuan/FUSIONLCD \
    --config /home/adlab36/chenyouyuan/FUSIONLCD/config.yaml \
    --train_script /home/adlab36/chenyouyuan/FUSIONLCD/train.py \
    --models_dir /home/adlab36/chenyouyuan/FUSIONLCD/result/log/models \
    --result_name auto_eval \
    --gpu 2,3 \
    --epochs_filter 119,139

Example epoch values for --epochs_filter (comma-separated, no spaces):
099, 119, 139, 159, 179, 199

Notes:
- In multi-GPU mode, each checkpoint is assigned to one GPU.
- Each subprocess uses its own temporary working directory and its own
  config.yaml to avoid conflicts.
- Subprocess logs are printed in real time.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import csv
|
||
import os
|
||
import re
|
||
import shutil
|
||
import subprocess
|
||
import sys
|
||
import time
|
||
import tempfile
|
||
from pathlib import Path
|
||
from typing import Dict, List, Tuple, Optional
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
|
||
import yaml
|
||
|
||
|
||
# Matches checkpoint file names such as "checkpoint_119.pth.tar";
# group 1 captures the digits that list_checkpoints converts to the epoch.
CKPT_RE = re.compile(r"checkpoint_(\d+)\.pth\.tar$")
# Matches one data line of the result text file: a leading 14-digit field
# (presumably a timestamp -- confirm against train.py), two integer fields
# (read as seq and epoch by parse_result_file), then the remaining text.
RESULT_LINE_RE = re.compile(r"^\d{14}\s+(\d+)\s+(\d+)\s+(.*)$")
|
||
|
||
|
||
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        The parsed namespace with ``project_dir``, ``config``,
        ``train_script``, ``models_dir`` (all required strings),
        ``result_name``, ``gpu``, ``epochs_filter``, ``min_epoch``,
        ``max_epoch`` and ``sleep_sec``.
    """
    required_paths = (
        ("--project_dir", "项目根目录"),
        ("--config", "config.yaml 路径"),
        ("--train_script", "train.py 路径"),
        ("--models_dir", "checkpoint 目录"),
    )
    optional_strings = (
        ("--result_name", "auto_eval", "train.py 的 result_name"),
        ("--gpu", "0", "GPU id,例如 0 或 0,1,2,3"),
        ("--epochs_filter", "", "只测试指定 epoch,逗号分隔,如 99,109,119"),
    )
    epoch_bounds = (
        ("--min_epoch", "最小 epoch 过滤"),
        ("--max_epoch", "最大 epoch 过滤"),
    )

    cli = argparse.ArgumentParser()
    # Registration order is kept so the generated --help output is unchanged.
    for flag, description in required_paths:
        cli.add_argument(flag, type=str, required=True, help=description)
    for flag, fallback, description in optional_strings:
        cli.add_argument(flag, type=str, default=fallback, help=description)
    for flag, description in epoch_bounds:
        cli.add_argument(flag, type=int, default=None, help=description)
    cli.add_argument("--sleep_sec", type=float, default=1.0, help="每次测试后等待秒数")
    return cli.parse_args()
|
||
|
||
|
||
def load_yaml(path: Path) -> dict:
    """Read the UTF-8 YAML file at ``path`` and return what ``yaml.safe_load`` yields."""
    with path.open("r", encoding="utf-8") as stream:
        document = yaml.safe_load(stream)
    return document
|
||
|
||
|
||
def save_yaml(path: Path, data: dict) -> None:
    """Write ``data`` to ``path`` as UTF-8 YAML, keeping key order and non-ASCII text."""
    dump_options = {"allow_unicode": True, "sort_keys": False}
    with path.open("w", encoding="utf-8") as stream:
        yaml.safe_dump(data, stream, **dump_options)
|
||
|
||
|
||
def list_checkpoints(models_dir: Path) -> List[Tuple[int, Path]]:
    """Find checkpoint files directly inside ``models_dir``.

    Returns:
        ``(epoch, path)`` pairs for every ``checkpoint_*.pth.tar`` file whose
        name matches ``CKPT_RE``, sorted by ascending epoch.
    """
    candidates = (
        (CKPT_RE.match(entry.name), entry)
        for entry in models_dir.glob("checkpoint_*.pth.tar")
    )
    found = [(int(hit.group(1)), entry) for hit, entry in candidates if hit]
    # Sort on the epoch only so paths are never compared.
    return sorted(found, key=lambda pair: pair[0])
|
||
|
||
|
||
def filter_checkpoints(
    ckpts: List[Tuple[int, Path]],
    epochs_filter: str,
    min_epoch: Optional[int],
    max_epoch: Optional[int],
) -> List[Tuple[int, Path]]:
    """Narrow ``ckpts`` by an explicit epoch list and/or an inclusive epoch range.

    Args:
        ckpts: ``(epoch, path)`` pairs to filter.
        epochs_filter: Comma-separated epochs to keep; blank disables this filter.
        min_epoch: Lowest epoch to keep (inclusive), or ``None`` for no lower bound.
        max_epoch: Highest epoch to keep (inclusive), or ``None`` for no upper bound.

    Returns:
        The surviving pairs in their original order.

    Raises:
        ValueError: If a non-empty token of ``epochs_filter`` is not an integer.
    """
    kept = ckpts

    if epochs_filter.strip():
        allowed = set()
        for token in epochs_filter.split(","):
            token = token.strip()
            if token:
                allowed.add(int(token))
        kept = [item for item in kept if item[0] in allowed]

    if min_epoch is not None:
        kept = [item for item in kept if item[0] >= min_epoch]

    if max_epoch is not None:
        kept = [item for item in kept if item[0] <= max_epoch]

    return kept
|
||
|
||
|
||
def parse_result_file(path: Path) -> List[Dict]:
    """Parse the metric rows of a result text file.

    Blank lines, lines starting with ``"Time"``, lines that do not match
    ``RESULT_LINE_RE`` and lines with fewer than 16 value tokens are skipped.

    Returns:
        One dict per accepted line with ``seq``, ``epoch``, the 16 metric
        values as floats, and the stripped source line under ``raw``.
        An empty list when ``path`` does not exist.
    """
    metric_names = (
        "AP", "R100", "F1",
        "R@1", "R@2", "R@3", "R@4", "R@5",
        "R@6", "R@7", "R@8", "R@9", "R@10",
        "R@15", "R@20", "R@25",
    )
    records: List[Dict] = []
    if not path.exists():
        return records

    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            text = raw_line.strip()
            if text == "" or text.startswith("Time"):
                continue

            hit = RESULT_LINE_RE.match(text)
            if hit is None:
                continue

            seq_field, epoch_field, tail = hit.groups()
            seq = int(seq_field)
            epoch = int(epoch_field)
            tokens = tail.split()
            if len(tokens) < 16:
                continue

            # Only the first 16 tokens are metrics; anything after is ignored.
            metrics = [float(token) for token in tokens[:16]]
            record: Dict = {"seq": seq, "epoch": epoch}
            record.update(zip(metric_names, metrics))
            record["raw"] = text
            records.append(record)
    return records
|
||
|
||
|
||
def make_eval_config(base_config_path: Path, ckpt_path: Path, result_name: str, temp_config_path: Path) -> None:
    """Write a test-only copy of the base config that loads ``ckpt_path``.

    Args:
        base_config_path: YAML config to start from.
        ckpt_path: Checkpoint stored as ``experiment.last_model``.
        result_name: Accepted for interface compatibility; not used here.
        temp_config_path: Destination of the modified YAML config.

    Raises:
        KeyError: If the base config has no ``experiment`` section.
    """
    config = load_yaml(base_config_path)
    experiment = config["experiment"]

    overrides = {
        "train_flag": 0,
        "validate_flag": 0,
        "test_flag": 1,
        "load_model": 1,
        "last_model": str(ckpt_path),
    }
    for key, value in overrides.items():
        experiment[key] = value

    # path_result is left as in the base config so output locations
    # (database etc.) are not changed.
    save_yaml(temp_config_path, config)
|
||
|
||
|
||
def _stop_child(proc: subprocess.Popen, tag: str) -> None:
    """Make sure ``proc`` has exited: terminate it, then kill it after 5 seconds."""
    if proc.poll() is not None:
        return
    proc.terminate()
    try:
        proc.wait(timeout=5)
    except subprocess.TimeoutExpired:
        print(f"[WARN][{tag}] 子进程未及时退出,强制 kill")
        proc.kill()
        proc.wait()


def run_one_eval(
    work_dir: Path,
    train_script: Path,
    result_name: str,
    gpu: str,
    tag: str,
) -> int:
    """Run ``train_script`` once in ``work_dir`` and stream its output.

    The child is started with the current interpreter, with
    ``CUDA_VISIBLE_DEVICES`` set to ``gpu``; its stdout and stderr are merged
    and echoed line by line, each line prefixed with ``[tag]``.

    Args:
        work_dir: Working directory of the child process.
        train_script: Script to execute.
        result_name: Value passed as ``--result_name``.
        gpu: GPU id, exported as ``CUDA_VISIBLE_DEVICES`` and passed as ``--gpu``.
        tag: Label used in log prefixes and in the ``--info`` argument.

    Returns:
        The exit code of the child process.

    Raises:
        KeyboardInterrupt: Re-raised after the child has been stopped.
    """
    env = os.environ.copy()
    # NOTE(review): the same id is exported as CUDA_VISIBLE_DEVICES and passed
    # as --gpu; whether train.py expects a physical or a remapped device index
    # cannot be determined from this file -- confirm.
    env["CUDA_VISIBLE_DEVICES"] = gpu

    cmd = [
        sys.executable,
        str(train_script),
        "--result_name",
        result_name,
        "--gpu",
        gpu,
        "--info",
        f"auto_eval_{tag}",
    ]

    print(f"[INFO][{tag}] Running command: {' '.join(cmd)}")
    print(f"[INFO][{tag}] CUDA_VISIBLE_DEVICES={gpu}")
    print(f"[INFO][{tag}] cwd={work_dir}")

    proc = subprocess.Popen(
        cmd,
        cwd=str(work_dir),
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        # A stray undecodable byte in the child's log must not abort the run.
        errors="backslashreplace",
        bufsize=1,
    )

    try:
        # stdout=PIPE guarantees a stream; the assert only narrows the type.
        assert proc.stdout is not None
        for line in proc.stdout:
            print(f"[{tag}] {line}", end="")
        return proc.wait()
    except KeyboardInterrupt:
        print(f"\n[WARN][{tag}] 收到 Ctrl+C,正在终止当前测试子进程...")
        raise
    finally:
        # Runs on every exit path so a failure while streaming never leaves
        # an orphaned child behind; a no-op when the child already exited.
        _stop_child(proc, tag)
        if proc.stdout is not None:
            proc.stdout.close()
|
||
|
||
|
||
def collect_epoch_rows(all_rows: List[Dict], epoch: int) -> List[Dict]:
    """Return the rows of ``all_rows`` whose ``"epoch"`` equals ``epoch``, in order."""
    matching: List[Dict] = []
    for record in all_rows:
        if record["epoch"] == epoch:
            matching.append(record)
    return matching
|
||
|
||
|
||
def aggregate_rows(rows: List[Dict]) -> Dict[str, float]:
    """Average the headline metrics over ``rows``.

    Returns:
        A dict mapping ``mean_<metric>`` to the arithmetic mean for AP, R100,
        F1, R@1, R@5, R@10 and R@25; an empty dict when ``rows`` is empty.
    """
    if not rows:
        return {}
    count = len(rows)
    averaged = ("AP", "R100", "F1", "R@1", "R@5", "R@10", "R@25")
    return {
        f"mean_{name}": sum(row[name] for row in rows) / count
        for name in averaged
    }
|
||
|
||
|
||
def save_summary_csv(path: Path, summary: List[Dict]) -> None:
    """Write ``summary`` to ``path`` as a UTF-8 CSV file with a header row.

    Nothing is written when ``summary`` is empty. Missing parent directories
    of ``path`` are created so that results are not lost at the final step.

    Args:
        path: Destination CSV file.
        summary: One dict per row. The header is the union of all row keys in
            first-seen order; a row lacking a column gets an empty cell.
    """
    if not summary:
        return
    # dict.fromkeys de-duplicates while keeping first-seen order.
    fieldnames = list(dict.fromkeys(key for row in summary for key in row))
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(summary)
|
||
|
||
|
||
def run_single_checkpoint(
    epoch: int,
    ckpt: Path,
    gpu: str,
    args: argparse.Namespace,
    project_dir: Path,
    train_script: Path,
) -> Optional[Dict]:
    """Evaluate one checkpoint in a private temporary directory.

    A temporary working directory is created, a test-only ``config.yaml`` is
    written into it, ``train_script`` is run there, and the rows of the
    produced result file that belong to ``epoch`` are averaged.

    Args:
        epoch: Epoch number of the checkpoint.
        ckpt: Checkpoint file to evaluate.
        gpu: GPU id handed to the evaluation subprocess.
        args: Parsed CLI options; ``config``, ``result_name`` and
            ``sleep_sec`` are read.
        project_dir: Accepted for interface compatibility; not used here.
        train_script: Script executed for the evaluation.

    Returns:
        A summary row with ``epoch``, ``checkpoint``, ``gpu`` and the averaged
        metrics, or ``None`` when the run fails or yields no matching rows.
    """
    tag = f"gpu{gpu}_ep{epoch}"
    scratch = Path(tempfile.mkdtemp(prefix=f"auto_eval_{tag}_"))
    try:
        # Per the original author's note, train.py reads cwd/config.yaml first.
        make_eval_config(Path(args.config), ckpt, args.result_name, scratch / "config.yaml")

        # Per the original author's note, train.py writes its text output to
        # cwd/result/<result_name>.txt, so that directory must exist.
        result_dir = scratch / "result"
        result_dir.mkdir(parents=True, exist_ok=True)

        exit_code = run_one_eval(
            work_dir=scratch,
            train_script=train_script,
            result_name=args.result_name,
            gpu=gpu,
            tag=tag,
        )
        if exit_code != 0:
            print(f"[WARN][{tag}] checkpoint {epoch} 测试失败,返回码 {exit_code}")
            return None

        time.sleep(args.sleep_sec)

        records = parse_result_file(result_dir / f"{args.result_name}.txt")
        matching = collect_epoch_rows(records, epoch)
        if not matching:
            print(f"[WARN][{tag}] 没有在结果文件中找到 epoch={epoch} 的记录")
            return None

        averages = aggregate_rows(matching)
        summary: Dict = {"epoch": epoch, "checkpoint": str(ckpt), "gpu": gpu}
        summary.update(averages)
        print(f"[INFO][{tag}] 汇总: {summary}")
        return summary
    finally:
        # Best-effort cleanup; a leftover temp directory must not mask the result.
        shutil.rmtree(scratch, ignore_errors=True)
|
||
|
||
|
||
def _run_checkpoints_on_gpu(
    gpu: str,
    assigned: List[Tuple[int, Path]],
    args: argparse.Namespace,
    project_dir: Path,
    train_script: Path,
) -> List[Dict]:
    """Evaluate ``assigned`` checkpoints one after another on a single GPU."""
    rows: List[Dict] = []
    for epoch, ckpt in assigned:
        row = run_single_checkpoint(epoch, ckpt, gpu, args, project_dir, train_script)
        if row is not None:
            rows.append(row)
    return rows


def main() -> None:
    """Evaluate the selected checkpoints and write a summary CSV.

    Checkpoints found in ``--models_dir`` are filtered, evaluated serially on
    one GPU or in parallel on several, and the per-checkpoint averages are
    saved to ``<project_dir>/result/<result_name>_summary.csv``. The
    checkpoint with the highest ``mean_AP`` is printed at the end.

    Raises:
        FileNotFoundError: If a required input path does not exist.
        RuntimeError: If no checkpoint matches or no GPU id was given.
    """
    args = parse_args()

    project_dir = Path(args.project_dir).resolve()
    train_script = Path(args.train_script).resolve()
    models_dir = Path(args.models_dir).resolve()

    if not project_dir.exists():
        raise FileNotFoundError(f"project_dir 不存在: {project_dir}")
    if not Path(args.config).exists():
        raise FileNotFoundError(f"config 不存在: {args.config}")
    if not train_script.exists():
        raise FileNotFoundError(f"train_script 不存在: {train_script}")
    if not models_dir.exists():
        raise FileNotFoundError(f"models_dir 不存在: {models_dir}")

    ckpts = list_checkpoints(models_dir)
    ckpts = filter_checkpoints(ckpts, args.epochs_filter, args.min_epoch, args.max_epoch)

    if not ckpts:
        raise RuntimeError("没有找到符合条件的 checkpoint")

    gpu_list = [x.strip() for x in args.gpu.split(",") if x.strip()]
    if not gpu_list:
        raise RuntimeError("没有可用 GPU 参数")

    print(f"[INFO] 使用 GPU 列表: {gpu_list}")
    print(f"[INFO] 待测试 checkpoint 数量: {len(ckpts)}")

    summary_rows: List[Dict] = []

    if len(gpu_list) == 1:
        # Single GPU: run serially in the main thread.
        gpu = gpu_list[0]
        for epoch, ckpt in ckpts:
            print("=" * 100)
            print(f"[INFO] 开始测试 checkpoint: epoch={epoch}, path={ckpt}, gpu={gpu}")
            print("=" * 100)
            row = run_single_checkpoint(epoch, ckpt, gpu, args, project_dir, train_script)
            if row is not None:
                summary_rows.append(row)
    else:
        # Multi GPU: one worker per GPU, each running its round-robin share
        # (ckpts[i::n]) sequentially. Submitting one task per checkpoint with
        # a pre-assigned GPU lets a freed worker start a job on a GPU that is
        # still busy, so two evaluations could overlap on the same device.
        gpu_count = len(gpu_list)
        with ThreadPoolExecutor(max_workers=gpu_count) as ex:
            futures = [
                ex.submit(
                    _run_checkpoints_on_gpu,
                    gpu,
                    ckpts[idx::gpu_count],
                    args,
                    project_dir,
                    train_script,
                )
                for idx, gpu in enumerate(gpu_list)
            ]
            for fut in as_completed(futures):
                summary_rows.extend(fut.result())

    summary_rows.sort(key=lambda x: x["epoch"])

    summary_csv = project_dir / "result" / f"{args.result_name}_summary.csv"
    if not summary_rows:
        # Nothing succeeded, so no file is written; say so instead of
        # claiming the summary was saved.
        print("[WARN] 没有可汇总的结果,未写入汇总文件")
        return

    # The result directory may not exist yet; create it so the evaluation
    # results are not lost at the very last step.
    summary_csv.parent.mkdir(parents=True, exist_ok=True)
    save_summary_csv(summary_csv, summary_rows)
    print(f"[INFO] 汇总结果已保存到: {summary_csv}")

    best_by_ap = max(summary_rows, key=lambda x: x.get("mean_AP", float("-inf")))
    print("\n[INFO] 最佳 checkpoint(按 mean_AP):")
    print(best_by_ap)


if __name__ == "__main__":
    main()