magicciv/tooling/rl_self_play/evaluate.py
Natalie b7891991a4 feat(@projects/@magic-civilization): add rl_self_play tooling for self-play training
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-05-17 03:54:40 -07:00

117 lines
4.1 KiB
Python

"""Run a held-out evaluation of a trained MaskablePPO model.
Usage:
python -m tooling.rl_self_play.evaluate \
--model-path tooling/rl_self_play/models/duel-v1/final.zip \
--episodes 50
Prints a one-line JSON verdict:
{"episodes": 50, "wins": 28, "losses": 18, "draws": 4,
"win_rate": 0.56, "mean_turns": 142.3}
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
THIS_DIR = Path(__file__).resolve().parent
# .../tooling/rl_self_play -> .../tooling -> project root: the directory that
# holds the `tooling` package which main() imports from.
PROJECT_ROOT = THIS_DIR.parents[1]
if __package__ is None:
    # Executed as a plain script rather than via `python -m`, so `tooling.*`
    # is not importable yet: put the project root at the front of sys.path.
    sys.path[0:0] = [str(PROJECT_ROOT)]
def _build_argparser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description="Evaluate a trained policy against built-in AI")
p.add_argument("--model-path", required=True, type=Path)
p.add_argument("--episodes", type=int, default=50)
p.add_argument("--max-turns", type=int, default=200)
p.add_argument("--seed-offset", type=int, default=10_000,
help="Eval episode seeds = offset + episode_idx; avoids overlap with train seeds")
p.add_argument("--players", type=int, default=2)
p.add_argument("--map-size", default="duel")
return p
def _classify_episode(info_history: list[dict[str, object]], total_reward: float) -> str:
"""Decide win/loss/draw from the last step's info + accumulated reward.
Win: terminated with positive terminal reward (i.e. opponent eliminated
or score-fallback in our favour — currently only "opponent
eliminated" because the env doesn't yet read win events).
Loss: terminated with `reason=eliminated` (we ran out of cities).
Draw: truncated at max_turns OR rolled positive score-shaping but no
terminal signal — neither side decisively won.
"""
if not info_history:
return "draw"
last = info_history[-1]
reason = last.get("reason")
if reason == "eliminated":
return "loss"
if reason == "harness_error":
return "loss"
# No explicit win yet from the env; use score sign as tiebreaker.
if total_reward > 0.5:
return "win"
return "draw"
def _play_episode(model, env) -> tuple[str, int]:
    """Play one episode greedily with *model* in *env*, closing *env* afterwards.

    Each step queries the env's action mask and asks the model for a
    deterministic action, until the env reports ``terminated`` or ``truncated``.

    Returns:
        ``(verdict, turn)``. *verdict* is ``_classify_episode`` applied to the
        collected per-step infos and the summed reward. *turn* is
        ``int(info.get("turn", 0))`` taken from the final step's info.
    """
    try:
        obs, info = env.reset()
        finished = False
        episode_reward = 0.0
        history: list[dict[str, object]] = []
        while not finished:
            action, _ = model.predict(obs, action_masks=env.action_masks(), deterministic=True)
            obs, reward, terminated, truncated, info = env.step(int(action))
            episode_reward += reward
            history.append(info)
            finished = terminated or truncated
        return _classify_episode(history, episode_reward), int(info.get("turn", 0))
    finally:
        env.close()


def main() -> int:
    """CLI entry point: evaluate a saved MaskablePPO policy and print a JSON verdict.

    Plays ``--episodes`` games. Game ``i`` uses seed ``--seed-offset + i`` and
    puts the policy in player slot 0. The script then prints episodes, wins,
    losses, draws, the win rate and the mean final turn as one JSON line.

    Returns:
        0, used as the process exit status.
    """
    args = _build_argparser().parse_args()
    # Deferred until after argument parsing so `--help` and usage errors
    # work without these packages being importable.
    from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]
    from tooling.rl_self_play.harness_client import HarnessConfig  # type: ignore[import-not-found]
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv  # type: ignore[import-not-found]

    model = MaskablePPO.load(str(args.model_path))
    tally = {"win": 0, "loss": 0, "draw": 0}
    final_turns: list[int] = []
    for idx in range(args.episodes):
        harness_cfg = HarnessConfig(
            seed=args.seed_offset + idx,
            players=args.players,
            player_slot=0,
            map_size=args.map_size,
        )
        env = MagicCivEnv(harness_config=harness_cfg, max_turns=args.max_turns)
        outcome, last_turn = _play_episode(model, env)
        # Anything that is not an explicit win or loss is tallied as a draw.
        tally[outcome if outcome in ("win", "loss") else "draw"] += 1
        final_turns.append(last_turn)

    # Guard against dividing by zero when no episodes were requested.
    denominator = args.episodes if args.episodes > 1 else 1
    avg_turns = sum(final_turns) / (len(final_turns) or 1)
    summary = {
        "episodes": args.episodes,
        "wins": tally["win"],
        "losses": tally["loss"],
        "draws": tally["draw"],
        "win_rate": tally["win"] / denominator,
        "mean_turns": round(avg_turns, 1),
    }
    print(json.dumps(summary))
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())