# magicciv/tooling/rl_self_play/smoke_model_opponent.py
# 2026-05-27 20:26:00 -07:00
#
# 128 lines
# 4.7 KiB
# Python

"""Smoke test for the model-backed (self-play) opponent path.

Unlike `smoke.py` (stdlib-only, MCTS opponent) this necessarily imports
`sb3_contrib` + `torch` because it loads a frozen MaskablePPO snapshot
into the opponent slot. It exercises the full integration: a real
`MagicCivEnv` configured with a `ModelOpponent`, driven by a *random*
masked policy on the learner slot for a bounded number of steps.

Verifies, against a live harness:
  * the multi-slot wire is honoured (both slots externally driven, no
    harness errors),
  * the frozen opponent takes real turns (the turn counter advances),
  * opponent-driven events bubble into the env (no silent captures),
  * the episode can run without wedging and reports a terminal/truncation
    reason.

Usage:
    python3 -m tooling.rl_self_play.smoke_model_opponent \
        --opponent-model tooling/rl_self_play/models/duel-v4-encfix-s7/best_model.zip \
        --steps 400

Prints a one-line JSON verdict; exit 0 on `passed: true`.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any
import numpy as np
# Resolve this file once; both module-level paths are derived from it.
_this_file = Path(__file__).resolve()
# Directory that contains this script.
THIS_DIR = _this_file.parent
# Two directory levels above THIS_DIR (three above the file itself).
PROJECT_ROOT = _this_file.parents[2]
# `__package__` is None when the file is executed directly as a script; in
# that case put PROJECT_ROOT at the front of sys.path so the absolute
# `tooling.…` imports used below can be resolved.
if __package__ is None:
    sys.path[:0] = [str(PROJECT_ROOT)]
def main() -> int:
    """Run the model-opponent smoke test and print a one-line JSON verdict.

    Command-line flags (parsed from ``sys.argv``):
        --opponent-model   path to the opponent snapshot file (required)
        --steps            maximum number of learner steps (default 400)
        --seed             seed for the RNG and the harness config (default 42)
        --max-turns        ``max_turns`` passed to ``MagicCivEnv`` (default 200)
        --opponent-device  device passed to ``ModelOpponent`` (default "cpu")

    The learner slot is driven by uniformly random choices among the actions
    the current mask marks as legal. The run fails if an exception is raised,
    the mask becomes empty, the turn counter never advances, or the summed
    ``opp_events`` reported by the env stays at zero.

    Returns:
        0 when every check passed, 1 otherwise. The verdict
        (``passed`` / ``reasons`` / ``details``) is printed to stdout as a
        single JSON line in both cases.
    """
    p = argparse.ArgumentParser(description="Model-opponent self-play smoke")
    p.add_argument("--opponent-model", required=True, type=Path)
    p.add_argument("--steps", type=int, default=400)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--max-turns", type=int, default=200)
    p.add_argument("--opponent-device", default="cpu")
    args = p.parse_args()

    reasons: list[str] = []
    details: dict[str, Any] = {
        "steps": 0,
        "max_turn_seen": 0,
        "mask_violations": 0,
        "opp_turns_implied": 0,
        "opp_events_total": 0,
        "terminal_reason": None,
    }

    # Fail fast on a bad path *before* importing the project modules, so a
    # typo in --opponent-model does not pay for the heavy imports first.
    if not args.opponent_model.is_file():
        print(json.dumps({"passed": False,
                          "reasons": [f"opponent model not found: {args.opponent_model}"],
                          "details": details}))
        return 1

    # Deferred so that argument errors and --help stay cheap.
    from tooling.rl_self_play.harness_client import HarnessConfig
    from tooling.rl_self_play.magic_civ_env import MagicCivEnv
    from tooling.rl_self_play.opponent import ModelOpponent

    rng = np.random.default_rng(args.seed)
    env = None
    try:
        # Construction happens inside the try so that a snapshot that fails
        # to load, or an env that fails to start, is reported in the JSON
        # verdict instead of escaping as a bare traceback.
        opponent = ModelOpponent(
            model_path=str(args.opponent_model),
            slots=(1,),
            device=args.opponent_device,
            deterministic=False,
        )
        cfg = HarnessConfig(seed=args.seed, players=2, player_slot=0, map_size="duel")
        env = MagicCivEnv(harness_config=cfg, max_turns=args.max_turns, opponent=opponent)

        _, info = env.reset()
        mask = info["action_mask"]
        prev_turn = 0
        for step_idx in range(args.steps):
            legal = np.flatnonzero(mask)
            if legal.size == 0:
                reasons.append(f"empty action mask at step {step_idx}")
                break
            action = int(rng.choice(legal))
            # Defensive sanity check: the action was drawn from `legal`, so
            # this is not expected to fire; kept so the counter stays in the
            # reported details.
            if not mask[action]:
                details["mask_violations"] += 1
            _, _, terminated, truncated, info = env.step(action)
            # A missing mask degrades to all-zero, which the next iteration
            # reports as an empty action mask.
            mask = info.get("action_mask", np.zeros_like(mask))
            details["opp_events_total"] += int(info.get("opp_events", 0))
            turn = int(info.get("turn", 0))
            if turn > details["max_turn_seen"]:
                details["max_turn_seen"] = turn
            # Each turn advance past prev implies the opponent took a turn
            # (in all-external 2-slot play the processor steps per end_turn).
            if turn > prev_turn:
                details["opp_turns_implied"] += 1
                prev_turn = turn
            details["steps"] = step_idx + 1
            if terminated or truncated:
                details["terminal_reason"] = info.get("reason")
                break
    except Exception as e:  # noqa: BLE001 — smoke wants the failure surfaced
        reasons.append(f"exception: {type(e).__name__}: {e}")
    finally:
        # Only close an env that was actually created, and never let a
        # failing close() suppress the verdict.
        if env is not None:
            try:
                env.close()
            except Exception as e:  # noqa: BLE001 — report, don't mask the verdict
                reasons.append(f"exception during close: {type(e).__name__}: {e}")

    if details["max_turn_seen"] < 1:
        reasons.append("turn counter never advanced — opponent/turn loop stuck")
    if details["mask_violations"] > 0:
        reasons.append(f"{details['mask_violations']} mask violations")
    if details["opp_events_total"] < 1:
        reasons.append(
            "opponent produced zero wire events across the run — frozen "
            "opponent never acted (likely a stale binary not skipping the "
            "external slot, so the simulator AI drove it instead)"
        )

    passed = not reasons
    print(json.dumps({"passed": passed, "reasons": reasons, "details": details}))
    return 0 if passed else 1
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())