feat(rl-self-play): Add learned opponent policy evaluation options to RL self-play evaluation script

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
autocommit 2026-05-27 20:15:33 -07:00
parent bb15503079
commit e2e578cdab

View file

@ -41,6 +41,16 @@ def _build_argparser() -> argparse.ArgumentParser:
help="Eval episode seeds = offset + episode_idx; avoids overlap with train seeds")
p.add_argument("--players", type=int, default=2)
p.add_argument("--map-size", default="duel")
p.add_argument("--opponent-model", default=None, type=Path,
help=("Frozen MaskablePPO snapshot (.zip) to use as the "
"opponent instead of the harness MCTS. Set to the "
"graduated snapshot to measure win-rate against a "
"learned policy (self-play curriculum gate)."))
p.add_argument("--opponent-slots", default="1",
help="Comma-separated opponent slot indices (default '1').")
p.add_argument("--opponent-device", default="cpu")
p.add_argument("--opponent-deterministic", action="store_true",
help="Argmax opponent actions (default: stochastic sampling).")
return p
@ -75,9 +85,24 @@ def main() -> int:
from sb3_contrib import MaskablePPO # type: ignore[import-not-found]
from tooling.rl_self_play.harness_client import HarnessConfig # type: ignore[import-not-found]
from tooling.rl_self_play.magic_civ_env import MagicCivEnv # type: ignore[import-not-found]
from tooling.rl_self_play.opponent import ModelOpponent # type: ignore[import-not-found]
model = MaskablePPO.load(str(args.model_path))
opp_slots: tuple[int, ...] = tuple(
int(s) for s in str(args.opponent_slots).split(",") if s.strip()
)
def _make_opponent() -> ModelOpponent | None:
if not args.opponent_model:
return None
return ModelOpponent(
model_path=str(args.opponent_model),
slots=opp_slots,
device=args.opponent_device,
deterministic=args.opponent_deterministic,
)
wins = losses = draws = step_caps = turn_caps = 0
turns_per_episode: list[int] = []
for episode in range(args.episodes):
@ -87,7 +112,9 @@ def main() -> int:
player_slot=0,
map_size=args.map_size,
)
env = MagicCivEnv(harness_config=cfg, max_turns=args.max_turns)
env = MagicCivEnv(
harness_config=cfg, max_turns=args.max_turns, opponent=_make_opponent()
)
try:
obs, info = env.reset()
done = False