From e2e578cdabbc7b81a91db2f14310cb30f27cd633 Mon Sep 17 00:00:00 2001 From: autocommit Date: Wed, 27 May 2026 20:15:33 -0700 Subject: [PATCH] =?UTF-8?q?feat(rl-self-play):=20=E2=9C=A8=20Add=20learned?= =?UTF-8?q?=20opponent=20policy=20evaluation=20options=20to=20RL=20self-pl?= =?UTF-8?q?ay=20evaluation=20script?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Lilith Autocommit --- tooling/rl_self_play/evaluate.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tooling/rl_self_play/evaluate.py b/tooling/rl_self_play/evaluate.py index ce736fb2..238745c7 100644 --- a/tooling/rl_self_play/evaluate.py +++ b/tooling/rl_self_play/evaluate.py @@ -41,6 +41,16 @@ def _build_argparser() -> argparse.ArgumentParser: help="Eval episode seeds = offset + episode_idx; avoids overlap with train seeds") p.add_argument("--players", type=int, default=2) p.add_argument("--map-size", default="duel") + p.add_argument("--opponent-model", default=None, type=Path, + help=("Frozen MaskablePPO snapshot (.zip) to use as the " + "opponent instead of the harness MCTS. Set to the " + "graduated snapshot to measure win-rate against a " + "learned policy (self-play curriculum gate).")) + p.add_argument("--opponent-slots", default="1", + help="Comma-separated opponent slot indices (default '1').") + p.add_argument("--opponent-device", default="cpu") + p.add_argument("--opponent-deterministic", action="store_true", + help="Argmax opponent actions (default: stochastic sampling).") return p @@ -75,9 +85,24 @@ def main() -> int: from sb3_contrib import MaskablePPO # type: ignore[import-not-found] from tooling.rl_self_play.harness_client import HarnessConfig # type: ignore[import-not-found] from tooling.rl_self_play.magic_civ_env import MagicCivEnv # type: ignore[import-not-found] + from tooling.rl_self_play.opponent import ModelOpponent # type: ignore[import-not-found] model = MaskablePPO.load(str(args.model_path)) + opp_slots: tuple[int, ...] = tuple( + int(s) for s in str(args.opponent_slots).split(",") if s.strip() + ) + + def _make_opponent() -> ModelOpponent | None: + if not args.opponent_model: + return None + return ModelOpponent( + model_path=str(args.opponent_model), + slots=opp_slots, + device=args.opponent_device, + deterministic=args.opponent_deterministic, + ) + wins = losses = draws = step_caps = turn_caps = 0 turns_per_episode: list[int] = [] for episode in range(args.episodes): @@ -87,7 +112,9 @@ def main() -> int: player_slot=0, map_size=args.map_size, ) - env = MagicCivEnv(harness_config=cfg, max_turns=args.max_turns) + env = MagicCivEnv( + harness_config=cfg, max_turns=args.max_turns, opponent=_make_opponent() + ) try: obs, info = env.reset() done = False