feat(rl-self-play): ✨ Implement opponent model loading, execution, and behavior management for reinforcement learning self-play
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
parent
4564074d86
commit
236160134c
1 changed files with 125 additions and 0 deletions
125
tooling/rl_self_play/opponent.py
Normal file
125
tooling/rl_self_play/opponent.py
Normal file
|
|
@ -0,0 +1,125 @@
|
|||
"""Model-backed frozen opponent for the self-play curriculum.
|
||||
|
||||
Drives one or more *externally-controlled* player slots with a saved
|
||||
MaskablePPO snapshot, so the learning policy on slot 0 trains against a
|
||||
fixed **learned** policy rather than the harness's built-in MCTS. This
|
||||
is the AlphaZero-style curriculum rung described in `README.md`:
|
||||
graduate a policy, freeze it, train the next generation against it.
|
||||
|
||||
The opponent shares the exact obs/action encoders the learner uses
|
||||
(`encoders.py`). Both slots are the same Dwarf race with the same
|
||||
`PlayerView` shape, so the encoding is symmetric — `encode_observation`
|
||||
keys "me" off `view["player"]`, which the harness fills with whichever
|
||||
slot was queried via the multi-slot `slot` kwarg.
|
||||
|
||||
The frozen model loads on CPU per env worker (the learner owns the GPU;
|
||||
the MlpPolicy is well under 1 GB). Loading is lazy so the heavy
|
||||
`sb3_contrib` import happens inside the SubprocVecEnv worker process,
|
||||
never in the parent that forks them.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .encoders import (
|
||||
decode_action_index,
|
||||
encode_legal_actions,
|
||||
encode_observation,
|
||||
)
|
||||
from .harness_client import HarnessClient
|
||||
|
||||
# Hard ceiling on opponent micro-actions per turn. The frozen policy is
|
||||
# sampled stochastically; without a budget a degenerate snapshot could
|
||||
# loop on skip/noop and never emit end_turn, wedging the episode.
|
||||
MAX_OPPONENT_ACTIONS_PER_TURN = 5000
|
||||
|
||||
# Wire-event tags that mean the game ended mid-opponent-turn — once one
|
||||
# fires there is nothing left to drive, so stop immediately.
|
||||
_TERMINAL_EVENT_TYPES = frozenset({"game_over", "player_eliminated"})
|
||||
|
||||
|
||||
def _has_terminal(events: list[dict[str, Any]]) -> bool:
|
||||
return any(e.get("type") in _TERMINAL_EVENT_TYPES for e in events)
|
||||
|
||||
|
||||
class ModelOpponent:
|
||||
"""A frozen MaskablePPO snapshot driving the opponent slot(s).
|
||||
|
||||
One instance per env. `play_turn` is called by `MagicCivEnv.step`
|
||||
after the learner ends its turn, and runs each opponent slot through
|
||||
a full turn (many micro-actions then `end_turn`), returning every
|
||||
wire event so the env can fold them into its termination + reward
|
||||
bookkeeping.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_path: str,
|
||||
slots: tuple[int, ...] = (1,),
|
||||
device: str = "cpu",
|
||||
deterministic: bool = False,
|
||||
) -> None:
|
||||
self._model_path = model_path
|
||||
self._slots = slots
|
||||
self._device = device
|
||||
self._deterministic = deterministic
|
||||
self._model: Any = None # lazy-loaded inside the worker
|
||||
|
||||
@property
|
||||
def slots(self) -> tuple[int, ...]:
|
||||
return self._slots
|
||||
|
||||
@property
|
||||
def model_path(self) -> str:
|
||||
return self._model_path
|
||||
|
||||
def _ensure_model(self) -> Any:
|
||||
if self._model is None:
|
||||
from sb3_contrib import MaskablePPO # type: ignore[import-not-found]
|
||||
|
||||
self._model = MaskablePPO.load(self._model_path, device=self._device)
|
||||
return self._model
|
||||
|
||||
def play_turn(self, client: HarnessClient) -> list[dict[str, Any]]:
|
||||
"""Drive every opponent slot through one full turn.
|
||||
|
||||
Returns all wire events collected across the opponent's actions
|
||||
(each `act`/`end_turn` response carries an `events` array, plus a
|
||||
final notification drain). The caller merges these into the
|
||||
learner's `recent_events` so opponent-driven captures, kills, and
|
||||
`game_over` are scored and detected.
|
||||
|
||||
May raise `HarnessError` — the caller wraps the whole opponent
|
||||
turn in the same try/except that guards the learner's action, so
|
||||
a dead harness terminates the episode as a loss.
|
||||
"""
|
||||
model = self._ensure_model()
|
||||
events: list[dict[str, Any]] = []
|
||||
for slot in self._slots:
|
||||
for _ in range(MAX_OPPONENT_ACTIONS_PER_TURN):
|
||||
view = client.view(slot=slot)
|
||||
mask, idx_to_action = encode_legal_actions(view)
|
||||
if not mask.any():
|
||||
resp = client.end_turn(slot=slot)
|
||||
events.extend(resp.get("events", []))
|
||||
break
|
||||
obs = encode_observation(view)
|
||||
action_idx, _ = model.predict(
|
||||
obs, action_masks=mask, deterministic=self._deterministic
|
||||
)
|
||||
player_action = decode_action_index(int(action_idx), idx_to_action)
|
||||
if player_action.get("type") == "end_turn":
|
||||
resp = client.end_turn(slot=slot)
|
||||
events.extend(resp.get("events", []))
|
||||
break
|
||||
resp = client.act(player_action, slot=slot)
|
||||
resp_events = resp.get("events", [])
|
||||
events.extend(resp_events)
|
||||
if _has_terminal(resp_events):
|
||||
return events
|
||||
else:
|
||||
# Budget exhausted without an end_turn — force the boundary.
|
||||
resp = client.end_turn(slot=slot)
|
||||
events.extend(resp.get("events", []))
|
||||
events.extend(client.drain_notifications())
|
||||
return events
|
||||
Loading…
Add table
Reference in a new issue