183 lines
6.9 KiB
Python
183 lines
6.9 KiB
Python
"""Run a held-out evaluation of a trained MaskablePPO model.

Usage:

    python -m tooling.rl_self_play.evaluate \
        --model-path tooling/rl_self_play/models/duel-v1/final.zip \
        --episodes 50

Prints a one-line JSON verdict to stdout (wrapped here for readability):

    {"episodes": 50, "wins": 28, "losses": 18, "draws": 4,
     "turn_caps": 0, "step_caps": 0,
     "win_rate": 0.56, "mean_turns": 142.3}

Verdicts:

    win       Terminated with `reason=won` (opponent eliminated).
    loss      Terminated with `reason=eliminated` or `reason=harness_error`.
    draw      Any other (or missing) `reason`: no clear winner.
    turn_cap  Episode hit --max-turns with no terminal verdict. Not a win:
              PlayerView doesn't expose the opponent's score, so standings at
              the cap are not comparable.
    step_cap  Episode hit the per-episode step cap. Degenerate non-result.
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Directory holding this script and the project root two levels above it
# (assumes the file lives at <root>/tooling/rl_self_play/).
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parents[1]

# When executed as a plain script (``__package__`` is None) or imported as a
# top-level module (``__package__`` == ""), the ``tooling.rl_self_play``
# imports inside ``main`` only resolve if the project root is importable.
# Skip the insert if the root is already on the path so repeated imports
# don't stack duplicate entries.
if not __package__ and str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
|
|
def _build_argparser() -> argparse.ArgumentParser:
|
|
p = argparse.ArgumentParser(description="Evaluate a trained policy against built-in AI")
|
|
p.add_argument("--model-path", required=True, type=Path)
|
|
p.add_argument("--episodes", type=int, default=50)
|
|
p.add_argument("--max-turns", type=int, default=1000)
|
|
p.add_argument("--seed-offset", type=int, default=10_000,
|
|
help="Eval episode seeds = offset + episode_idx; avoids overlap with train seeds")
|
|
p.add_argument("--players", type=int, default=2)
|
|
p.add_argument("--map-size", default="duel")
|
|
p.add_argument("--opponent-model", default=None, type=Path,
|
|
help=("Frozen MaskablePPO snapshot (.zip) to use as the "
|
|
"opponent instead of the harness MCTS. Set to the "
|
|
"graduated snapshot to measure win-rate against a "
|
|
"learned policy (self-play curriculum gate)."))
|
|
p.add_argument("--opponent-slots", default="1",
|
|
help="Comma-separated opponent slot indices (default '1').")
|
|
p.add_argument("--opponent-device", default="cpu")
|
|
p.add_argument("--opponent-deterministic", action="store_true",
|
|
help="Argmax opponent actions (default: stochastic sampling).")
|
|
p.add_argument("--learner-deterministic", action=argparse.BooleanOptionalAction,
|
|
default=True,
|
|
help=("Argmax the evaluated (slot-0) policy. Default True. "
|
|
"For a symmetric self-play sanity check (e.g. v4 vs "
|
|
"v4, expect ~50%) pass --no-learner-deterministic so "
|
|
"both sides sample from the masked softmax — matching "
|
|
"the stochastic training-eval regime."))
|
|
return p
|
|
|
|
|
|
def _classify_episode(info_history: list[dict[str, object]]) -> str:
|
|
"""Decide verdict from the last step's info.
|
|
|
|
Only terminal events count as win/loss. Truncations (turn_cap, step_cap)
|
|
are reported as their own categories; we don't promote positive
|
|
score-shaping to a "win" because the simulator's PlayerView doesn't
|
|
expose opponent score, so we have no honest way to compare standings
|
|
when the clock runs out.
|
|
"""
|
|
if not info_history:
|
|
return "draw"
|
|
last = info_history[-1]
|
|
reason = last.get("reason")
|
|
if reason == "won":
|
|
return "win"
|
|
if reason == "eliminated":
|
|
return "loss"
|
|
if reason == "harness_error":
|
|
return "loss"
|
|
if reason == "step_cap":
|
|
return "step_cap"
|
|
if reason == "turn_cap":
|
|
return "turn_cap"
|
|
return "draw"
|
|
|
|
|
|
def main() -> int:
    """Evaluate a trained policy over held-out seeds and print a JSON summary.

    Loads the MaskablePPO policy from ``--model-path`` and plays
    ``--episodes`` episodes; episode ``i`` uses seed ``--seed-offset + i`` and
    the evaluated policy always occupies slot 0. Each episode is classified by
    ``_classify_episode`` from its final step's info. The tallies, the win
    rate and the mean final ``turn`` are printed to stdout as one JSON line;
    notes about step/turn caps go to stderr.

    Returns:
        0 once every episode has finished. A negative ``--episodes`` or a
        non-integer entry in ``--opponent-slots`` aborts earlier through
        ``parser.error`` (usage message, exit status 2).
    """
    parser = _build_argparser()
    args = parser.parse_args()

    # Validate cheap CLI inputs before the (potentially slow) model load so
    # typos fail fast with a usage message instead of a traceback.
    if args.episodes < 0:
        parser.error(f"--episodes must be >= 0, got {args.episodes}")
    try:
        opp_slots: tuple[int, ...] = tuple(
            int(s) for s in str(args.opponent_slots).split(",") if s.strip()
        )
    except ValueError:
        parser.error(
            "--opponent-slots must be comma-separated integers, "
            f"got {args.opponent_slots!r}"
        )

    # Third-party / project-local imports are deferred until after argument
    # parsing (and after the module-level sys.path setup has run).
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from tooling.rl_self_play.harness_client import HarnessConfig  # type: ignore[import-not-found]
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]
    from tooling.rl_self_play.opponent import ModelOpponent  # type: ignore[import-not-found]

    model = MaskablePPO.load(str(args.model_path))

    def _make_opponent() -> ModelOpponent | None:
        """Build a fresh frozen-policy opponent, or None when --opponent-model is unset."""
        if not args.opponent_model:
            # None leaves the env on its default opponent (per the
            # --opponent-model help text, the harness MCTS).
            return None
        return ModelOpponent(
            model_path=str(args.opponent_model),
            slots=opp_slots,
            device=args.opponent_device,
            deterministic=args.opponent_deterministic,
        )

    wins = losses = draws = step_caps = turn_caps = 0
    turns_per_episode: list[int] = []
    for episode in range(args.episodes):
        cfg = HarnessConfig(
            seed=args.seed_offset + episode,
            players=args.players,
            player_slot=0,  # the evaluated policy always plays slot 0
            map_size=args.map_size,
        )
        env = MagicCivEnv(
            harness_config=cfg, max_turns=args.max_turns, opponent=_make_opponent()
        )
        try:
            obs, info = env.reset()
            done = False
            while not done:
                mask = env.action_masks()
                action, _ = model.predict(
                    obs, action_masks=mask, deterministic=args.learner_deterministic
                )
                obs, _reward, terminated, truncated, info = env.step(int(action))
                done = terminated or truncated
            # The loop always runs at least once, so `info` is the final
            # step's info. Only that one decides the verdict, so there is no
            # need to retain every step's info dict for the whole episode.
            verdict = _classify_episode([info])
            if verdict == "win":
                wins += 1
            elif verdict == "loss":
                losses += 1
            elif verdict == "step_cap":
                step_caps += 1
            elif verdict == "turn_cap":
                turn_caps += 1
            else:
                draws += 1
            turns_per_episode.append(int(info.get("turn", 0)))
        finally:
            env.close()

    total = max(args.episodes, 1)  # avoid division by zero for --episodes 0
    mean_turns = sum(turns_per_episode) / max(len(turns_per_episode), 1)
    summary = {
        "episodes": args.episodes,
        "wins": wins,
        "losses": losses,
        "draws": draws,
        "turn_caps": turn_caps,
        "step_caps": step_caps,
        "win_rate": wins / total,
        "mean_turns": round(mean_turns, 1),
    }
    print(json.dumps(summary))
    if step_caps:
        print(
            f"WARNING: {step_caps}/{args.episodes} eval episodes hit the "
            f"per-episode step cap — policy got stuck in a no-progress "
            f"loop. Check encoder/reward shaping.",
            file=sys.stderr,
        )
    if turn_caps:
        print(
            f"NOTE: {turn_caps}/{args.episodes} eval episodes hit the "
            f"--max-turns={args.max_turns} cap with no terminal verdict. "
            f"Raise --max-turns or expose opponent score for tie-breaking.",
            file=sys.stderr,
        )
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|