magicciv/tooling/rl_self_play/opponent.py
2026-05-27 20:15:34 -07:00

125 lines
5 KiB
Python

"""Model-backed frozen opponent for the self-play curriculum.
Drives one or more *externally-controlled* player slots with a saved
MaskablePPO snapshot, so the learning policy on slot 0 trains against a
fixed **learned** policy rather than the harness's built-in MCTS. This
is the AlphaZero-style curriculum rung described in `README.md`:
graduate a policy, freeze it, train the next generation against it.
The opponent shares the exact obs/action encoders the learner uses
(`encoders.py`). Both slots are the same Dwarf race with the same
`PlayerView` shape, so the encoding is symmetric — `encode_observation`
keys "me" off `view["player"]`, which the harness fills with whichever
slot was queried via the multi-slot `slot` kwarg.
The frozen model loads on CPU per env worker (the learner owns the GPU;
the MlpPolicy is well under 1 GB). Loading is lazy so the heavy
`sb3_contrib` import happens inside the SubprocVecEnv worker process,
never in the parent that forks them.
"""
from __future__ import annotations
from typing import Any
from .encoders import (
decode_action_index,
encode_legal_actions,
encode_observation,
)
from .harness_client import HarnessClient
# Upper bound on micro-actions the frozen policy may take for one slot in a
# single turn. Actions are sampled stochastically, so a degenerate snapshot
# could keep picking skip/noop and never choose end_turn; once this budget is
# spent, `ModelOpponent.play_turn` forces the turn boundary itself.
MAX_OPPONENT_ACTIONS_PER_TURN = 5_000

# Wire-event `type` tags signalling that the game ended while the opponent
# was acting. After one is seen there is nothing left to drive, so the
# opponent stops immediately.
# NOTE(review): with more than two players a single `player_eliminated` may
# not actually end the game -- confirm against the harness before relying on
# this for multi-opponent setups.
_TERMINAL_EVENT_TYPES = frozenset(("game_over", "player_eliminated"))
def _has_terminal(events: list[dict[str, Any]]) -> bool:
    """Report whether any event in *events* carries a game-ending `type` tag.

    Scans in order and stops at the first match; events without a `type`
    key never match.
    """
    for event in events:
        if event.get("type") in _TERMINAL_EVENT_TYPES:
            return True
    return False
class ModelOpponent:
    """A frozen MaskablePPO snapshot driving the opponent slot(s).

    One instance per env. `play_turn` is called by `MagicCivEnv.step`
    after the learner ends its turn, and runs each opponent slot through
    a full turn (many micro-actions then `end_turn`), returning every
    wire event so the env can fold them into its termination + reward
    bookkeeping.

    Args:
        model_path: Path handed to `MaskablePPO.load`.
        slots: Player slots this opponent drives, played in this order each
            turn. Any iterable of ints is accepted; it is stored as a tuple
            so it can be replayed on every turn.
        device: Device string handed to `MaskablePPO.load` (CPU by default).
        deterministic: Forwarded to `model.predict`; ``False`` samples
            actions stochastically.
    """

    def __init__(
        self,
        model_path: str,
        slots: tuple[int, ...] = (1,),
        device: str = "cpu",
        deterministic: bool = False,
    ) -> None:
        self._model_path = model_path
        # Materialise as a tuple: a list still matches the property's
        # annotation, and a one-shot iterable is not exhausted after turn 1.
        self._slots = tuple(slots)
        self._device = device
        self._deterministic = deterministic
        self._model: Any = None  # lazy-loaded inside the worker

    @property
    def slots(self) -> tuple[int, ...]:
        """Player slots driven by this opponent, in play order."""
        return self._slots

    @property
    def model_path(self) -> str:
        """Path of the frozen snapshot this opponent loads."""
        return self._model_path

    def _ensure_model(self) -> Any:
        """Load the snapshot on first use and cache it on the instance."""
        if self._model is None:
            # Imported here so the heavy dependency is only pulled in by the
            # process that actually plays (see module docstring).
            from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]

            self._model = MaskablePPO.load(self._model_path, device=self._device)
        return self._model

    @staticmethod
    def _response_events(resp: dict[str, Any]) -> list[dict[str, Any]]:
        """Return the `events` array of a harness response.

        A missing key or an explicit null is treated as "no events" so
        callers can always `extend`/scan the result.
        """
        return resp.get("events") or []

    def play_turn(self, client: HarnessClient) -> list[dict[str, Any]]:
        """Drive every opponent slot through one full turn.

        For each slot, repeatedly views the game, masks the legal actions,
        and asks the frozen model for an action. This stops when no action
        is legal, when the model picks `end_turn`, or after
        `MAX_OPPONENT_ACTIONS_PER_TURN` actions; in every case the slot's
        turn is then closed with `client.end_turn`.

        Returns all wire events collected across the opponent's actions
        (each `act`/`end_turn` response carries an `events` array, plus a
        final notification drain). The caller merges these into the
        learner's `recent_events` so opponent-driven captures, kills, and
        `game_over` are scored and detected.

        Terminal handling:
          * A terminal event from `act` returns immediately, with no
            `end_turn` and no notification drain.
          * A terminal event from `end_turn` stops driving the remaining
            slots, so a finished game is never acted on. The final
            notification drain still runs.

        May raise `HarnessError` -- the caller wraps the whole opponent
        turn in the same try/except that guards the learner's action, so
        a dead harness terminates the episode as a loss.
        """
        model = self._ensure_model()
        events: list[dict[str, Any]] = []
        for slot in self._slots:
            for _ in range(MAX_OPPONENT_ACTIONS_PER_TURN):
                view = client.view(slot=slot)
                mask, idx_to_action = encode_legal_actions(view)
                if not mask.any():
                    # Nothing legal to do: close the turn below.
                    break
                obs = encode_observation(view)
                action_idx, _ = model.predict(
                    obs, action_masks=mask, deterministic=self._deterministic
                )
                player_action = decode_action_index(int(action_idx), idx_to_action)
                if player_action.get("type") == "end_turn":
                    # end_turn uses its dedicated endpoint below, not `act`.
                    break
                act_events = self._response_events(
                    client.act(player_action, slot=slot)
                )
                events.extend(act_events)
                if _has_terminal(act_events):
                    return events
            # Reached on an empty mask, a policy-chosen end_turn, or an
            # exhausted budget (the latter forces the boundary for a
            # snapshot that never ends its own turn).
            end_events = self._response_events(client.end_turn(slot=slot))
            events.extend(end_events)
            if _has_terminal(end_events):
                # The game is over; driving further slots would act on a
                # finished game and could raise, losing these events.
                break
        events.extend(client.drain_notifications())
        return events