125 lines
5 KiB
Python
125 lines
5 KiB
Python
"""Model-backed frozen opponent for the self-play curriculum.
|
|
|
|
Drives one or more *externally-controlled* player slots with a saved
MaskablePPO snapshot, so the learning policy on slot 0 trains against a
fixed **learned** policy rather than the harness's built-in MCTS. This
is the AlphaZero-style curriculum rung described in `README.md`:
graduate a policy, freeze it, train the next generation against it.

The opponent shares the exact obs/action encoders the learner uses
(`encoders.py`). Both slots are the same Dwarf race with the same
`PlayerView` shape, so the encoding is symmetric — `encode_observation`
keys "me" off `view["player"]`, which the harness fills with whichever
slot was queried via the multi-slot `slot` kwarg.

The frozen model loads on CPU per env worker (the learner owns the GPU;
the MlpPolicy is well under 1 GB). Loading is lazy so the heavy
`sb3_contrib` import happens inside the SubprocVecEnv worker process,
never in the parent that forks them.
"""
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from .encoders import (
|
|
decode_action_index,
|
|
encode_legal_actions,
|
|
encode_observation,
|
|
)
|
|
from .harness_client import HarnessClient
|
|
|
|
# Upper bound on the number of micro-actions the opponent may take in a
# single turn. The frozen policy is sampled stochastically, so a degenerate
# snapshot could keep choosing skip/noop and never pick end_turn; this cap
# stops such a snapshot from wedging the episode.
MAX_OPPONENT_ACTIONS_PER_TURN = 5_000
|
|
|
|
# Wire-event tags that mean the game ended mid-opponent-turn — once one
|
|
# fires there is nothing left to drive, so stop immediately.
|
|
_TERMINAL_EVENT_TYPES = frozenset({"game_over", "player_eliminated"})
|
|
|
|
|
|
def _has_terminal(events: list[dict[str, Any]]) -> bool:
|
|
return any(e.get("type") in _TERMINAL_EVENT_TYPES for e in events)
|
|
|
|
|
|
class ModelOpponent:
    """A frozen MaskablePPO snapshot driving the opponent slot(s).

    One instance per env. `play_turn` is called by `MagicCivEnv.step`
    after the learner ends its turn, and runs each opponent slot through
    a full turn (many micro-actions then `end_turn`), returning every
    wire event so the env can fold them into its termination + reward
    bookkeeping.
    """

    def __init__(
        self,
        model_path: str,
        slots: tuple[int, ...] = (1,),
        device: str = "cpu",
        deterministic: bool = False,
    ) -> None:
        """Record the configuration; the model itself is loaded lazily.

        Args:
            model_path: Path handed to `MaskablePPO.load` on first use.
            slots: Player slots this opponent drives, in play order.
            device: Device string forwarded to `MaskablePPO.load`.
            deterministic: Forwarded to `model.predict`; False samples
                from the policy, True takes its deterministic action.
        """
        self._model_path = model_path
        # Snapshot into a tuple: a caller-supplied list can no longer be
        # mutated behind our back, and `slots` honours its annotation.
        self._slots: tuple[int, ...] = tuple(slots)
        self._device = device
        self._deterministic = deterministic
        self._model: Any = None  # lazy-loaded inside the worker

    @property
    def slots(self) -> tuple[int, ...]:
        """Player slots driven by this opponent, in play order."""
        return self._slots

    @property
    def model_path(self) -> str:
        """Path of the snapshot this opponent loads."""
        return self._model_path

    def _ensure_model(self) -> Any:
        """Load the snapshot on first call and return the cached model."""
        if self._model is None:
            # Imported here, not at module top, so the heavy dependency is
            # only pulled in by the process that actually plays a turn.
            from sb3_contrib import MaskablePPO  # type: ignore[import-not-found]

            self._model = MaskablePPO.load(self._model_path, device=self._device)
        return self._model

    def play_turn(self, client: HarnessClient) -> list[dict[str, Any]]:
        """Drive every opponent slot through one full turn.

        Returns all wire events collected across the opponent's actions
        (each `act`/`end_turn` response carries an `events` array, plus a
        final notification drain). The caller merges these into the
        learner's `recent_events` so opponent-driven captures, kills, and
        `game_over` are scored and detected.

        A slot's turn ends when it has no legal action, when the policy
        picks `end_turn`, or when `MAX_OPPONENT_ACTIONS_PER_TURN` actions
        have been taken; in each case `end_turn` is sent for that slot.

        Terminal events stop the drive early:

        * from an `act` response — return immediately, with no `end_turn`
          and no notification drain;
        * from an `end_turn` response — remaining slots are skipped, but
          the final notification drain still runs.

        May raise `HarnessError` — the caller wraps the whole opponent
        turn in the same try/except that guards the learner's action, so
        a dead harness terminates the episode as a loss.
        """
        model = self._ensure_model()
        events: list[dict[str, Any]] = []
        for slot in self._slots:
            for _ in range(MAX_OPPONENT_ACTIONS_PER_TURN):
                view = client.view(slot=slot)
                mask, idx_to_action = encode_legal_actions(view)
                if not mask.any():
                    # Nothing legal to do: go straight to the turn boundary.
                    break
                obs = encode_observation(view)
                action_idx, _state = model.predict(
                    obs, action_masks=mask, deterministic=self._deterministic
                )
                player_action = decode_action_index(int(action_idx), idx_to_action)
                if player_action.get("type") == "end_turn":
                    break
                resp_events = client.act(player_action, slot=slot).get("events", [])
                events.extend(resp_events)
                if _has_terminal(resp_events):
                    # Game ended mid-turn: nothing left to drive.
                    return events

            # Every non-terminal exit from the action loop (no legal action,
            # policy chose end_turn, or budget exhausted) closes the turn
            # the same way.
            boundary_events = client.end_turn(slot=slot).get("events", [])
            events.extend(boundary_events)
            if _has_terminal(boundary_events):
                # The game ended at the turn boundary; do not drive the
                # remaining slots against a finished game.
                break

        events.extend(client.drain_notifications())
        return events
|