From 236160134c066ec7ce83a03931bec7b07ee39e58 Mon Sep 17 00:00:00 2001 From: autocommit Date: Wed, 27 May 2026 20:15:34 -0700 Subject: [PATCH] =?UTF-8?q?feat(rl-self-play):=20=E2=9C=A8=20Implement=20o?= =?UTF-8?q?pponent=20model=20loading,=20execution,=20and=20behavior=20mana?= =?UTF-8?q?gement=20for=20reinforcement=20learning=20self-play?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Lilith Autocommit --- tooling/rl_self_play/opponent.py | 125 +++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 tooling/rl_self_play/opponent.py diff --git a/tooling/rl_self_play/opponent.py b/tooling/rl_self_play/opponent.py new file mode 100644 index 00000000..075ffe21 --- /dev/null +++ b/tooling/rl_self_play/opponent.py @@ -0,0 +1,125 @@ +"""Model-backed frozen opponent for the self-play curriculum. + +Drives one or more *externally-controlled* player slots with a saved +MaskablePPO snapshot, so the learning policy on slot 0 trains against a +fixed **learned** policy rather than the harness's built-in MCTS. This +is the AlphaZero-style curriculum rung described in `README.md`: +graduate a policy, freeze it, train the next generation against it. + +The opponent shares the exact obs/action encoders the learner uses +(`encoders.py`). Both slots are the same Dwarf race with the same +`PlayerView` shape, so the encoding is symmetric — `encode_observation` +keys "me" off `view["player"]`, which the harness fills with whichever +slot was queried via the multi-slot `slot` kwarg. + +The frozen model loads on CPU per env worker (the learner owns the GPU; +the MlpPolicy is well under 1 GB). Loading is lazy so the heavy +`sb3_contrib` import happens inside the SubprocVecEnv worker process, +never in the parent that forks them. +""" +from __future__ import annotations + +from typing import Any + +from .encoders import ( + decode_action_index, + encode_legal_actions, + encode_observation, +) +from .harness_client import HarnessClient + +# Hard ceiling on opponent micro-actions per turn. The frozen policy is +# sampled stochastically; without a budget a degenerate snapshot could +# loop on skip/noop and never emit end_turn, wedging the episode. +MAX_OPPONENT_ACTIONS_PER_TURN = 5000 + +# Wire-event tags that mean the game ended mid-opponent-turn — once one +# fires there is nothing left to drive, so stop immediately. +_TERMINAL_EVENT_TYPES = frozenset({"game_over", "player_eliminated"}) + + +def _has_terminal(events: list[dict[str, Any]]) -> bool: + return any(e.get("type") in _TERMINAL_EVENT_TYPES for e in events) + + +class ModelOpponent: + """A frozen MaskablePPO snapshot driving the opponent slot(s). + + One instance per env. `play_turn` is called by `MagicCivEnv.step` + after the learner ends its turn, and runs each opponent slot through + a full turn (many micro-actions then `end_turn`), returning every + wire event so the env can fold them into its termination + reward + bookkeeping. + """ + + def __init__( + self, + model_path: str, + slots: tuple[int, ...] = (1,), + device: str = "cpu", + deterministic: bool = False, + ) -> None: + self._model_path = model_path + self._slots = slots + self._device = device + self._deterministic = deterministic + self._model: Any = None # lazy-loaded inside the worker + + @property + def slots(self) -> tuple[int, ...]: + return self._slots + + @property + def model_path(self) -> str: + return self._model_path + + def _ensure_model(self) -> Any: + if self._model is None: + from sb3_contrib import MaskablePPO # type: ignore[import-not-found] + + self._model = MaskablePPO.load(self._model_path, device=self._device) + return self._model + + def play_turn(self, client: HarnessClient) -> list[dict[str, Any]]: + """Drive every opponent slot through one full turn. + + Returns all wire events collected across the opponent's actions + (each `act`/`end_turn` response carries an `events` array, plus a + final notification drain). The caller merges these into the + learner's `recent_events` so opponent-driven captures, kills, and + `game_over` are scored and detected. + + May raise `HarnessError` — the caller wraps the whole opponent + turn in the same try/except that guards the learner's action, so + a dead harness terminates the episode as a loss. + """ + model = self._ensure_model() + events: list[dict[str, Any]] = [] + for slot in self._slots: + for _ in range(MAX_OPPONENT_ACTIONS_PER_TURN): + view = client.view(slot=slot) + mask, idx_to_action = encode_legal_actions(view) + if not mask.any(): + resp = client.end_turn(slot=slot) + events.extend(resp.get("events", [])) + break + obs = encode_observation(view) + action_idx, _ = model.predict( + obs, action_masks=mask, deterministic=self._deterministic + ) + player_action = decode_action_index(int(action_idx), idx_to_action) + if player_action.get("type") == "end_turn": + resp = client.end_turn(slot=slot) + events.extend(resp.get("events", [])) + break + resp = client.act(player_action, slot=slot) + resp_events = resp.get("events", []) + events.extend(resp_events) + if _has_terminal(resp_events): + return events + else: + # Budget exhausted without an end_turn — force the boundary. + resp = client.end_turn(slot=slot) + events.extend(resp.get("events", [])) + events.extend(client.drain_notifications()) + return events