132 lines
4.7 KiB
Python
132 lines
4.7 KiB
Python
"""Run a held-out evaluation of a trained MaskablePPO model.
|
|
|
|
Usage:
|
|
python -m tooling.rl_self_play.evaluate \
|
|
--model-path tooling/rl_self_play/models/duel-v1/final.zip \
|
|
--episodes 50
|
|
|
|
Prints a one-line JSON verdict on stdout (plus a warning on stderr if any
episode hit the per-episode step cap):

    {"episodes": 50, "wins": 28, "losses": 18, "draws": 4, "step_caps": 0,
     "win_rate": 0.56, "mean_turns": 142.3}
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
THIS_DIR = Path(__file__).resolve().parent
# This file lives at <project>/tooling/rl_self_play/, so two levels up is the root.
PROJECT_ROOT = THIS_DIR.parents[1]

# When this file is not loaded as part of the ``tooling.rl_self_play`` package
# (run as a plain script, where ``__package__`` is None, or imported as a
# top-level module, where it is ""), put the project root on sys.path so the
# ``tooling.rl_self_play.*`` imports inside ``main`` resolve. Skip the insert
# if the root is already present to avoid duplicate entries.
if not __package__ and str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
|
|
def _build_argparser() -> argparse.ArgumentParser:
|
|
p = argparse.ArgumentParser(description="Evaluate a trained policy against built-in AI")
|
|
p.add_argument("--model-path", required=True, type=Path)
|
|
p.add_argument("--episodes", type=int, default=50)
|
|
p.add_argument("--max-turns", type=int, default=200)
|
|
p.add_argument("--seed-offset", type=int, default=10_000,
|
|
help="Eval episode seeds = offset + episode_idx; avoids overlap with train seeds")
|
|
p.add_argument("--players", type=int, default=2)
|
|
p.add_argument("--map-size", default="duel")
|
|
return p
|
|
|
|
|
|
def _classify_episode(info_history: list[dict[str, object]], total_reward: float) -> str:
|
|
"""Decide win/loss/draw from the last step's info + accumulated reward.
|
|
|
|
Win: terminated with positive terminal reward (i.e. opponent eliminated
|
|
or score-fallback in our favour — currently only "opponent
|
|
eliminated" because the env doesn't yet read win events).
|
|
Loss: terminated with `reason=eliminated` (we ran out of cities).
|
|
Draw: truncated at max_turns OR rolled positive score-shaping but no
|
|
terminal signal — neither side decisively won.
|
|
"""
|
|
if not info_history:
|
|
return "draw"
|
|
last = info_history[-1]
|
|
reason = last.get("reason")
|
|
if reason == "eliminated":
|
|
return "loss"
|
|
if reason == "harness_error":
|
|
return "loss"
|
|
if reason == "step_cap":
|
|
# Policy stuck in a no-progress loop and the env truncated the
|
|
# whole episode — degenerate non-result, surfaced as its own
|
|
# category so it's visible in the eval JSON.
|
|
return "step_cap"
|
|
# No explicit win yet from the env; use score sign as tiebreaker.
|
|
if total_reward > 0.5:
|
|
return "win"
|
|
return "draw"
|
|
|
|
|
|
def _run_episode(model, env) -> tuple[list[dict[str, object]], float]:
    """Play one episode to completion with deterministic, action-masked predictions.

    Args:
        model: A loaded ``MaskablePPO`` policy, or anything with a compatible
            ``predict(obs, action_masks=..., deterministic=...)``.
        env: A freshly constructed env exposing ``reset``, ``step`` and
            ``action_masks``.

    Returns:
        ``(info_history, total_reward)``. ``info_history`` holds the info dict
        from every ``step`` call, in order; it is never empty because at least
        one step is always taken. ``total_reward`` is the summed step reward.
    """
    obs, _ = env.reset()
    info_history: list[dict[str, object]] = []
    total_reward = 0.0
    done = False
    while not done:
        mask = env.action_masks()
        action, _ = model.predict(obs, action_masks=mask, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(int(action))
        total_reward += float(reward)
        info_history.append(info)
        # 5-tuple step API: stop on either termination or truncation.
        done = terminated or truncated
    return info_history, total_reward


def main() -> int:
    """Evaluate a saved MaskablePPO policy on held-out seeds and print a JSON verdict.

    Each episode gets a fresh env seeded with ``--seed-offset + episode_idx``.
    The verdict contains the episode count, wins/losses/draws/step_caps, win
    rate and mean final turn. It is printed as one JSON line on stdout. If any
    episode hit the step cap, a warning is also written to stderr.

    Returns:
        Process exit code; always 0.
    """
    args = _build_argparser().parse_args()

    # Imported only after argument parsing, so ``--help`` and argument errors
    # never need the RL stack or the project modules.
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from tooling.rl_self_play.harness_client import HarnessConfig  # type: ignore[import-not-found]
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]

    model = MaskablePPO.load(str(args.model_path))

    # Outcome counts keyed by _classify_episode's verdicts; any unrecognised
    # verdict is counted as a draw.
    tally = {"win": 0, "loss": 0, "draw": 0, "step_cap": 0}
    turns_per_episode: list[int] = []
    for episode in range(args.episodes):
        cfg = HarnessConfig(
            seed=args.seed_offset + episode,
            players=args.players,
            player_slot=0,
            map_size=args.map_size,
        )
        env = MagicCivEnv(harness_config=cfg, max_turns=args.max_turns)
        try:
            info_history, total_reward = _run_episode(model, env)
        finally:
            env.close()
        outcome = _classify_episode(info_history, total_reward)
        tally[outcome if outcome in tally else "draw"] += 1
        turns_per_episode.append(int(info_history[-1].get("turn", 0)))

    total = max(args.episodes, 1)  # guards the division if no episodes ran
    mean_turns = sum(turns_per_episode) / max(len(turns_per_episode), 1)
    step_caps = tally["step_cap"]
    summary = {
        "episodes": args.episodes,
        "wins": tally["win"],
        "losses": tally["loss"],
        "draws": tally["draw"],
        "step_caps": step_caps,
        "win_rate": tally["win"] / total,
        "mean_turns": round(mean_turns, 1),
    }
    print(json.dumps(summary))
    if step_caps:
        print(
            f"WARNING: {step_caps}/{args.episodes} eval episodes hit the "
            "per-episode step cap — policy got stuck in a no-progress "
            "loop. Check encoder/reward shaping.",
            file=sys.stderr,
        )
    return 0
|
|
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|