"""Lightweight smoke test for the harness + encoder layer (stdlib + NumPy only).

Verifies — without needing `gymnasium`, `stable-baselines3`, or `torch` —
that the protocol shim works end-to-end:

1. `HarnessClient` spawns the Godot subprocess and returns a valid
   `view` JSON on first request.
2. `encode_observation` projects every view into a fixed-shape
   `np.float32[OBS_DIM]` without raising.
3. `encode_legal_actions` produces a boolean mask whose `True`
   positions decode to a `PlayerAction` via `decode_action_index`
   (each step, one randomly sampled legal index is decoded and submitted).
4. A random-policy loop bounded by `--turns` reaches the turn limit
   OR terminates cleanly without raising `HarnessError`.

Run:

    python3 -m tooling.rl_self_play.smoke [--turns 30] [--seed 42]
        [--episodes 1] [--players 2] [--map-size duel]

Output is one-line JSON like:

    {"steps": 87, "turns_reached": 30, "mask_violations": 0,
     "eliminations": 0, "harness_errors": 0, "obs_dim": 32,
     "action_dim": 322, "episodes": 1, "passed": true}

`eliminations` is informational only: the gate fails on any mask
violation or harness error, not on the random agent losing.

Exit 0 on `passed: true`; non-zero otherwise. Suitable as a CI gate
before any real training run.
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
THIS_DIR = Path(__file__).resolve().parent
# Two levels up from this file's directory (…/tooling/rl_self_play -> project root).
PROJECT_ROOT = THIS_DIR.parents[1]

# The absolute `tooling.rl_self_play...` imports below only resolve when the
# project root is importable. `__package__` is falsy (None or "") when this
# module has no enclosing package, e.g. when run as a plain script or imported
# as a top-level module; under `python3 -m tooling.rl_self_play.smoke` it is the
# real package name and no path tweak is needed. Skip the insert if the root is
# already present so repeated imports don't grow sys.path.
if not __package__ and str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))
|
|
|
|
from tooling.rl_self_play.encoders import ( # noqa: E402
|
|
ACTION_DIM,
|
|
OBS_DIM,
|
|
decode_action_index,
|
|
encode_legal_actions,
|
|
encode_observation,
|
|
)
|
|
from tooling.rl_self_play.harness_client import ( # noqa: E402
|
|
HarnessClient,
|
|
HarnessConfig,
|
|
HarnessError,
|
|
)
|
|
|
|
|
|
def _build_argparser() -> argparse.ArgumentParser:
|
|
p = argparse.ArgumentParser(description="Smoke-test the harness + encoder layer")
|
|
p.add_argument("--turns", type=int, default=30, help="Max turns per episode")
|
|
p.add_argument("--episodes", type=int, default=1, help="Episodes to run")
|
|
p.add_argument("--seed", type=int, default=42, help="Base RNG seed")
|
|
p.add_argument("--players", type=int, default=2)
|
|
p.add_argument("--map-size", default="duel")
|
|
return p
|
|
|
|
|
|
def _is_eliminated(view: dict) -> bool:
    """Return True when the viewing player has no cities and no living founder unit.

    Reads `score.city_count`, `player`, and each unit's `owner`/`type`/`hp`
    from the view; missing fields fall back to "absent" defaults.
    """
    score = view.get("score", {})
    if int(score.get("city_count", 0)) != 0:
        return False
    me = int(view.get("player", 0))
    return not any(
        int(u.get("owner", -1)) == me
        and "founder" in str(u.get("type", ""))
        and float(u.get("hp", 0)) > 0
        for u in view.get("units", [])
    )


def _run_episode(
    client: HarnessClient, rng: np.random.Generator, max_turns: int
) -> dict[str, int]:
    """Play one episode with a uniformly random legal-action policy.

    Each step encodes the current view, samples one legal action index,
    decodes it, and submits it (`client.end_turn()` for an `end_turn` action,
    otherwise `client.act(action)`). The loop stops when the view's turn
    reaches `max_turns`, when an encoder contract is violated, or when the
    viewing player is eliminated.

    Args:
        client: an open harness connection.
        rng: random generator used to sample among legal action indices.
        max_turns: stop once the view's `turn` is >= this value.

    Returns:
        A dict with `steps` (actions submitted), `turns_reached` (last
        observed turn), `mask_violations` (observation/mask shape error or an
        empty legal mask; at most 1 because the episode stops there) and
        `eliminations` (0 or 1).

    A `HarnessError` raised by the client is not caught here; the caller
    handles it.
    """
    steps = 0
    mask_violations = 0
    eliminations = 0
    view = client.view()
    # Start from the game's real turn so `turns_reached` is accurate even if
    # the very first step fails, and a game already at the limit is not stepped.
    last_turn = int(view.get("turn", 0))

    while last_turn < max_turns:
        obs = encode_observation(view)
        if obs.shape != (OBS_DIM,):
            mask_violations += 1
            break

        mask, idx_to_action = encode_legal_actions(view)
        # The mask must span the full fixed action space; otherwise sampled
        # indices would not line up with the ACTION_DIM the policy assumes.
        if np.shape(mask) != (ACTION_DIM,):
            mask_violations += 1
            break
        legal_indices = np.where(mask)[0]
        if legal_indices.size == 0:
            mask_violations += 1
            break

        idx = int(rng.choice(legal_indices))
        action = decode_action_index(idx, idx_to_action)
        if action.get("type") == "end_turn":
            client.end_turn()
        else:
            client.act(action)

        view = client.view()
        last_turn = int(view.get("turn", 0))
        steps += 1

        if _is_eliminated(view):
            eliminations += 1
            break

    return {
        "steps": steps,
        "turns_reached": last_turn,
        "mask_violations": mask_violations,
        "eliminations": eliminations,
    }
|
|
|
|
|
|
def main(argv: list[str] | None = None) -> int:
    """Run the smoke test, print a one-line JSON verdict, and return an exit code.

    Args:
        argv: command-line arguments excluding the program name; ``None``
            makes argparse read ``sys.argv[1:]`` (the original behavior).

    Returns:
        0 when at least one episode was requested and no mask violations or
        harness errors occurred; 1 otherwise.
    """
    args = _build_argparser().parse_args(argv)
    rng = np.random.default_rng(args.seed)
    totals = {
        "steps": 0,
        "turns_reached": 0,
        "mask_violations": 0,
        "eliminations": 0,
        "harness_errors": 0,
    }

    for episode in range(args.episodes):
        cfg = HarnessConfig(
            seed=args.seed + episode,
            players=args.players,
            player_slot=0,
            map_size=args.map_size,
        )
        # The try wraps the whole `with` so that a HarnessError raised while
        # constructing, entering, or exiting the client is counted and reported
        # in the JSON verdict, not only one raised during the episode itself.
        try:
            with HarnessClient(cfg) as client:
                result = _run_episode(client, rng, args.turns)
        except HarnessError as err:
            totals["harness_errors"] += 1
            # stderr keeps stdout a single machine-readable JSON line.
            print(f"episode {episode}: harness error: {err}", file=sys.stderr)
            continue

        totals["steps"] += result["steps"]
        totals["turns_reached"] = max(totals["turns_reached"], result["turns_reached"])
        totals["mask_violations"] += result["mask_violations"]
        totals["eliminations"] += result["eliminations"]

    verdict = {
        **totals,
        "obs_dim": OBS_DIM,
        "action_dim": ACTION_DIM,
        "episodes": args.episodes,
        # Running zero episodes proves nothing, so it must not pass the gate.
        "passed": (
            args.episodes > 0
            and totals["mask_violations"] == 0
            and totals["harness_errors"] == 0
        ),
    }
    print(json.dumps(verdict))
    return 0 if verdict["passed"] else 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|