magicciv/tooling/rl_self_play/smoke.py
Natalie 7cdc8178b7 feat(tooling): add smoke test for protocol layer
Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
2026-05-17 03:59:39 -07:00

153 lines
4.8 KiB
Python

"""Stdlib-only smoke test for the harness + encoder layer.
Verifies — without needing `gymnasium`, `stable-baselines3`, or `torch` —
that the protocol shim works end-to-end:
1. `HarnessClient` spawns the Godot subprocess and returns a valid
`view` JSON on first request.
2. `encode_observation` projects every view into a fixed-shape
`np.float32[OBS_DIM]` without raising.
3. `encode_legal_actions` produces a boolean mask whose `True`
positions all map back to a legal `PlayerAction` via
`decode_action_index`.
4. A random-policy loop bounded by `--turns` reaches the turn limit
OR terminates cleanly without raising `HarnessError`.
Run:
python3 -m tooling.rl_self_play.smoke [--turns 30] [--episodes 1] [--seed 42]
[--players 2] [--map-size duel]
Output is one-line JSON like:
{"steps": 87, "turns_reached": 30, "mask_violations": 0,
"eliminations": 0, "harness_errors": 0, "obs_dim": 32,
"action_dim": 322, "episodes": 1, "passed": true}
Exit 0 on `passed: true`; non-zero otherwise. Suitable as a CI gate
before any real training run.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import numpy as np
# Directory that holds this file (.../tooling/rl_self_play).
THIS_DIR = Path(__file__).resolve().parent
# Two levels above THIS_DIR, i.e. the directory that contains the `tooling`
# package the absolute imports below resolve against.
PROJECT_ROOT = THIS_DIR.parents[1]
# `__package__` is None when the file is executed as a plain script and the
# empty string when it is loaded as a top-level module (e.g. `python -m smoke`
# from inside this directory). In both cases `tooling` is not importable yet,
# so put the project root at the front of the module search path.
if not __package__:
    sys.path.insert(0, str(PROJECT_ROOT))
from tooling.rl_self_play.encoders import ( # noqa: E402
ACTION_DIM,
OBS_DIM,
decode_action_index,
encode_legal_actions,
encode_observation,
)
from tooling.rl_self_play.harness_client import ( # noqa: E402
HarnessClient,
HarnessConfig,
HarnessError,
)
def _build_argparser() -> argparse.ArgumentParser:
    """Create the command-line parser for the smoke test.

    Returns:
        A parser exposing ``--turns``, ``--episodes``, ``--seed``,
        ``--players`` (all integers) and ``--map-size`` (string).
    """
    parser = argparse.ArgumentParser(
        description="Smoke-test the harness + encoder layer"
    )
    # Integer options that carry a help string: (flag, default, help text).
    documented_int_options = (
        ("--turns", 30, "Max turns per episode"),
        ("--episodes", 1, "Episodes to run"),
        ("--seed", 42, "Base RNG seed"),
    )
    for flag, default_value, help_text in documented_int_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)
    parser.add_argument("--players", type=int, default=2)
    parser.add_argument("--map-size", default="duel")
    return parser
def _run_episode(
    client: HarnessClient, rng: np.random.Generator, max_turns: int
) -> dict[str, int]:
    """Play one episode by sampling among the legal actions with ``rng``.

    Stepping continues until the view's ``turn`` reaches ``max_turns``, a
    consistency check fails, or the controlled player appears eliminated.

    Args:
        client: Harness connection used to fetch views and submit actions.
        rng: Generator used to pick one of the legal action indices.
        max_turns: The loop ends once the reported turn is at least this value.

    Returns:
        Counters keyed ``steps``, ``turns_reached``, ``mask_violations`` and
        ``eliminations``.
    """

    def owns_living_founder(snapshot) -> bool:
        # True when the viewing player still has a unit whose type contains
        # "founder" and whose hp is positive.
        roster = snapshot.get("units", [])
        owner_id = int(snapshot.get("player", 0))
        for unit in roster:
            if int(unit.get("owner", -1)) != owner_id:
                continue
            if "founder" not in str(unit.get("type", "")):
                continue
            if float(unit.get("hp", 0)) > 0:
                return True
        return False

    step_count = 0
    violation_count = 0
    elimination_count = 0
    current_turn = 0
    view = client.view()

    while current_turn < max_turns:
        # A wrongly shaped observation is reported through the same counter
        # as an empty action mask.
        if encode_observation(view).shape != (OBS_DIM,):
            violation_count += 1
            break

        mask, idx_to_action = encode_legal_actions(view)
        candidates = np.where(mask)[0]
        if candidates.size == 0:
            # Nothing legal to play, so the episode cannot continue.
            violation_count += 1
            break

        chosen = int(rng.choice(candidates))
        action = decode_action_index(chosen, idx_to_action)
        if action.get("type") != "end_turn":
            client.act(action)
        else:
            client.end_turn()

        view = client.view()
        current_turn = int(view.get("turn", 0))
        step_count += 1

        # Elimination heuristic: no cities and no living founder unit.
        score = view.get("score", {})
        if int(score.get("city_count", 0)) != 0:
            continue
        if owns_living_founder(view):
            continue
        elimination_count += 1
        break

    return {
        "steps": step_count,
        "turns_reached": current_turn,
        "mask_violations": violation_count,
        "eliminations": elimination_count,
    }
def main(argv: list[str] | None = None) -> int:
    """Run the requested smoke episodes and print a one-line JSON verdict.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
        0 when no mask violation and no harness error was recorded, else 1.
    """
    args = _build_argparser().parse_args(argv)
    rng = np.random.default_rng(args.seed)
    totals = {
        "steps": 0,
        "turns_reached": 0,
        "mask_violations": 0,
        "eliminations": 0,
        "harness_errors": 0,
    }
    for episode in range(args.episodes):
        cfg = HarnessConfig(
            seed=args.seed + episode,
            players=args.players,
            player_slot=0,
            map_size=args.map_size,
        )
        # The whole `with` statement sits inside the try: a HarnessError
        # raised while constructing, entering or leaving the client is then
        # counted and reported in the verdict instead of escaping as a
        # traceback with no JSON output.
        try:
            with HarnessClient(cfg) as client:
                result = _run_episode(client, rng, args.turns)
        except HarnessError as exc:
            # Diagnostics go to stderr so stdout stays a single JSON line.
            print(f"episode {episode}: harness error: {exc}", file=sys.stderr)
            totals["harness_errors"] += 1
            continue
        totals["steps"] += result["steps"]
        # Report the furthest turn reached by any episode.
        totals["turns_reached"] = max(totals["turns_reached"], result["turns_reached"])
        totals["mask_violations"] += result["mask_violations"]
        totals["eliminations"] += result["eliminations"]
    verdict = {
        **totals,
        "obs_dim": OBS_DIM,
        "action_dim": ACTION_DIM,
        "episodes": args.episodes,
        "passed": totals["mask_violations"] == 0 and totals["harness_errors"] == 0,
    }
    print(json.dumps(verdict))
    return 0 if verdict["passed"] else 1
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())