128 lines
4.7 KiB
Python
128 lines
4.7 KiB
Python
"""Smoke test for the model-backed (self-play) opponent path.

Unlike `smoke.py` (stdlib-only, MCTS opponent) this necessarily imports
`sb3_contrib` + `torch` because it loads a frozen MaskablePPO snapshot
into the opponent slot. It exercises the full integration: a real
`MagicCivEnv` configured with a `ModelOpponent`, driven by a *random*
masked policy on the learner slot for a bounded number of steps.

Verifies, against a live harness:
  * the multi-slot wire is honoured (both slots externally driven, no
    harness errors),
  * the frozen opponent takes real turns (the turn counter advances),
  * opponent-driven events bubble into the env (no silent captures),
  * the episode can run without wedging and reports a terminal/truncation
    reason.

Usage:
    python3 -m tooling.rl_self_play.smoke_model_opponent \
        --opponent-model tooling/rl_self_play/models/duel-v4-encfix-s7/best_model.zip \
        --steps 400

Prints a one-line JSON verdict; exit 0 on `passed: true`.
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
# Directory containing this script, and the directory two levels above it.
# The usage line (`python3 -m tooling.rl_self_play.…`) implies that the latter
# is the repository root that holds the `tooling` package.
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parents[1]

# `__package__` is None when the file is executed by path and "" when it is
# loaded as a top-level module (e.g. `python -m smoke_model_opponent` from
# inside this directory). In either case the `tooling.…` imports performed in
# `main` may not resolve unless the project root is on sys.path, so add it.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
|
|
def main() -> int:
    """Run the model-opponent smoke test and print a one-line JSON verdict.

    Parses CLI flags from ``sys.argv``, builds a ``MagicCivEnv`` whose
    opponent is a ``ModelOpponent`` loaded from ``--opponent-model``, and
    plays uniformly random legal actions for at most ``--steps`` steps. The
    loop stops early on termination/truncation or on an empty action mask.

    A failure reason is recorded when any of the following holds:
      * an exception is raised while importing the project modules,
        constructing the opponent/env, or stepping the env,
      * the action mask is empty at some step,
      * the highest ``info["turn"]`` seen is below 1,
      * a selected action was not allowed by the mask,
      * the summed ``info["opp_events"]`` is zero.

    Returns:
        0 when no failure reason was recorded, 1 otherwise (including when
        the opponent model file does not exist).
    """
    parser = argparse.ArgumentParser(description="Model-opponent self-play smoke")
    parser.add_argument("--opponent-model", required=True, type=Path)
    parser.add_argument("--steps", type=int, default=400)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--max-turns", type=int, default=200)
    parser.add_argument("--opponent-device", default="cpu")
    args = parser.parse_args()

    reasons: list[str] = []
    details: dict[str, Any] = {
        "steps": 0,
        "max_turn_seen": 0,
        "mask_violations": 0,
        "opp_turns_implied": 0,
        "opp_events_total": 0,
        "terminal_reason": None,
    }

    # Validate the snapshot path before touching the project modules: a typo
    # in the path should yield the JSON verdict without first paying for (or
    # failing on) the heavy imports below.
    if not args.opponent_model.is_file():
        print(json.dumps({"passed": False,
                          "reasons": [f"opponent model not found: {args.opponent_model}"],
                          "details": details}))
        return 1

    # Stays None until the env is actually built, so the `finally` clause
    # does not try to close something that never existed.
    env = None
    try:
        # Imported lazily and inside the `try` so that import or construction
        # failures are reported through the JSON verdict rather than as a raw
        # traceback.
        from tooling.rl_self_play.harness_client import HarnessConfig
        from tooling.rl_self_play.magic_civ_env import MagicCivEnv
        from tooling.rl_self_play.opponent import ModelOpponent

        rng = np.random.default_rng(args.seed)
        opponent = ModelOpponent(
            model_path=str(args.opponent_model),
            slots=(1,),
            device=args.opponent_device,
            deterministic=False,
        )
        cfg = HarnessConfig(seed=args.seed, players=2, player_slot=0, map_size="duel")
        env = MagicCivEnv(harness_config=cfg, max_turns=args.max_turns, opponent=opponent)

        _, info = env.reset()
        mask = info["action_mask"]
        prev_turn = 0
        for step_idx in range(args.steps):
            legal = np.flatnonzero(mask)
            if legal.size == 0:
                reasons.append(f"empty action mask at step {step_idx}")
                break
            action = int(rng.choice(legal))
            # Defensive cross-check: the action was drawn from the legal set,
            # so this should never fire.
            if not mask[action]:
                details["mask_violations"] += 1
            _, _, terminated, truncated, info = env.step(action)
            # A missing mask is treated as "nothing legal", which trips the
            # empty-mask check on the next iteration.
            mask = info.get("action_mask", np.zeros_like(mask))
            details["opp_events_total"] += int(info.get("opp_events", 0))
            turn = int(info.get("turn", 0))
            if turn > details["max_turn_seen"]:
                details["max_turn_seen"] = turn
            # Each turn advance past prev implies the opponent took a turn
            # (in all-external 2-slot play the processor steps per end_turn).
            if turn > prev_turn:
                details["opp_turns_implied"] += 1
                prev_turn = turn
            details["steps"] = step_idx + 1
            if terminated or truncated:
                details["terminal_reason"] = info.get("reason")
                break
    except Exception as e:  # noqa: BLE001 — smoke wants the failure surfaced
        reasons.append(f"exception: {type(e).__name__}: {e}")
    finally:
        if env is not None:
            env.close()

    if details["max_turn_seen"] < 1:
        reasons.append("turn counter never advanced — opponent/turn loop stuck")
    if details["mask_violations"] > 0:
        reasons.append(f"{details['mask_violations']} mask violations")
    if details["opp_events_total"] < 1:
        reasons.append(
            "opponent produced zero wire events across the run — frozen "
            "opponent never acted (likely a stale binary not skipping the "
            "external slot, so the simulator AI drove it instead)"
        )

    passed = not reasons
    print(json.dumps({"passed": passed, "reasons": reasons, "details": details}))
    return 0 if passed else 1
|
|
|
|
|
|
# Script entry point: hand main()'s return value to the interpreter as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|