feat(rl-self-play): Add mine divergence metric for evaluating strategy differences in RL self-play

Co-Authored-By: Lilith Autocommit <noreply@atlilith.com>
This commit is contained in:
autocommit 2026-05-27 20:04:30 -07:00
parent 2d7357550e
commit bb15503079

View file

@ -0,0 +1,242 @@
"""Mine a trained MaskablePPO policy for heuristic-patch candidates.
Drives a scripted-vs-scripted duel (both slots controlled by the
built-in `suggest()` chain the SAME trajectory p1-29c/29d measure
against) and, at the start of every turn, probes the trained policy on
the identical PlayerView WITHOUT applying its choice. Logs where the
policy's preferred action diverges from what the scripted AI actually
does, bucketed by board state (sole-city? trailing?).
The divergence we can extract is constrained by the policy's action
space (encoders.py): there is NO research action only end_turn,
per-unit micro (skip/fortify/sentry/found_city/move/attack) and per-city
`queue_production` over a 16-item building roster. So the high-resolution
signal is build-queue composition + unit posture (settle / fortify /
attack), nothing else.
Usage:
python -m tooling.rl_self_play.mine_divergence \
--model-path tooling/rl_self_play/models/duel-v4-encfix-s7/best_model.zip \
--seeds 1,2,3,4,5,6,7,8,9,10 --turns 60 \
--out tooling/rl_self_play/runs/duel-v4-encfix-s7/divergence.json
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any
THIS_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = THIS_DIR.parents[1]
if __package__ is None:
sys.path.insert(0, str(PROJECT_ROOT))
from tooling.rl_self_play.encoders import ( # noqa: E402
decode_action_index,
encode_legal_actions,
encode_observation,
)
from tooling.rl_self_play.harness_client import ( # noqa: E402
HarnessClient,
HarnessConfig,
)
from tooling.rl_self_play.record_expert import ( # noqa: E402
_DropStats,
_resolve,
)
def _my_units(view: dict[str, Any]) -> list[dict[str, Any]]:
me = int(view.get("player", 0))
return [u for u in view.get("units", []) if int(u.get("owner", -1)) == me]
def _my_cities(view: dict[str, Any]) -> list[dict[str, Any]]:
me = int(view.get("player", 0))
return [c for c in view.get("cities", []) if int(c.get("owner", -1)) == me]
def _action_label(a: dict[str, Any]) -> str:
"""Compact label: type, plus the queued item for queue_production."""
t = a.get("type", "?")
if t == "queue_production":
return f"build:{a.get('item','?')}"
return t
def _policy_top_action(model, view: dict[str, Any]) -> dict[str, Any]:
mask, idx_to_action = encode_legal_actions(view)
if not mask.any():
return {"type": "end_turn"}
obs = encode_observation(view)
action, _ = model.predict(obs, action_masks=mask, deterministic=True)
return decode_action_index(int(action), idx_to_action)
def _scripted_first_meaningful(client: HarnessClient, slot: int,
view: dict[str, Any]) -> dict[str, Any]:
"""First scripted action that resolves to a legal index and is not a
bare end_turn (so we compare 'what the AI does this turn', not the
turn boundary)."""
chain = client.suggest(slot=slot)
_, idx_to_action = encode_legal_actions(view)
for a in chain:
if a.get("type") == "end_turn":
continue
idx = _resolve(a, view, idx_to_action, _DropStats())
if idx is not None:
return idx_to_action[idx]
return {"type": "end_turn"}
def _advance_slot(client: HarnessClient, slot: int) -> None:
"""Apply the scripted chain for one slot, then end its turn — the
diag_suggest advance logic."""
for a in client.suggest(slot=slot):
v = client.view(slot=slot)
_, i2a = encode_legal_actions(v)
idx = _resolve(a, v, i2a, _DropStats())
if idx is None:
continue
r = i2a[idx]
try:
if r.get("type") == "end_turn":
client.end_turn(slot=slot)
else:
client.act(r, slot=slot)
except Exception: # noqa: BLE001 best-effort advance
break
try:
client.end_turn(slot=slot)
except Exception: # noqa: BLE001
pass
client.drain_notifications()
def _bucket(view: dict[str, Any]) -> str:
"""Coarse board-state bucket for the probed slot."""
ncities = len(_my_cities(view))
units = _my_units(view)
nmil = sum(1 for u in units if "founder" not in str(u.get("type", "")))
turn = int(view.get("turn", 0))
phase = "early" if turn <= 20 else ("mid" if turn <= 50 else "late")
city_state = "solo" if ncities <= 1 else "multi"
mil_state = "weak" if nmil <= 2 else "armed"
return f"{phase}/{city_state}/{mil_state}"
def mine_seed(model, seed: int, turns: int, probe_slot: int) -> dict[str, Any]:
cfg = HarnessConfig(
seed=seed, players=2, player_slots=(0, 1),
map_size="duel", map_type="continents", victory_mode="domination",
)
client = HarnessClient(cfg)
rows: list[dict[str, Any]] = []
try:
for _ in range(turns):
view = client.view(slot=probe_slot)
if not view.get("legal_actions") and not view.get("units"):
break
policy_a = _policy_top_action(model, view)
scripted_a = _scripted_first_meaningful(client, probe_slot, view)
rows.append({
"turn": int(view.get("turn", 0)),
"bucket": _bucket(view),
"ncities": len(_my_cities(view)),
"nunits": len(_my_units(view)),
"score": float(view.get("score", {}).get("score_estimate", 0.0)),
"policy": _action_label(policy_a),
"scripted": _action_label(scripted_a),
"agree": _action_label(policy_a) == _action_label(scripted_a),
})
# Advance BOTH slots on the scripted path.
for s in (0, 1):
_advance_slot(client, s)
# Stop if probed slot lost all cities and has no founder.
v2 = client.view(slot=probe_slot)
cs = _my_cities(v2)
founders = [u for u in _my_units(v2)
if "founder" in str(u.get("type", ""))]
if not cs and not founders:
rows.append({"turn": int(v2.get("turn", 0)),
"bucket": "ELIMINATED", "ncities": 0,
"nunits": len(_my_units(v2)), "score": 0.0,
"policy": "-", "scripted": "-", "agree": True})
break
finally:
client.shutdown()
return {"seed": seed, "probe_slot": probe_slot, "rows": rows}
def aggregate(results: list[dict[str, Any]]) -> dict[str, Any]:
# Divergences bucketed: when policy != scripted, what did each pick?
by_bucket: dict[str, dict[str, Counter]] = defaultdict(
lambda: {"policy": Counter(), "scripted": Counter()})
pair_counts: Counter = Counter()
total = agree = 0
elim_turns: list[int] = []
for res in results:
for r in res["rows"]:
if r["bucket"] == "ELIMINATED":
elim_turns.append(r["turn"])
continue
total += 1
if r["agree"]:
agree += 1
continue
by_bucket[r["bucket"]]["policy"][r["policy"]] += 1
by_bucket[r["bucket"]]["scripted"][r["scripted"]] += 1
pair_counts[(r["scripted"], r["policy"])] += 1
return {
"decisions": total,
"agreements": agree,
"agree_rate": round(agree / max(total, 1), 3),
"eliminations": len(elim_turns),
"elim_turns": sorted(elim_turns),
"top_divergences (scripted -> policy)": [
{"scripted": k[0], "policy": k[1], "count": v}
for k, v in pair_counts.most_common(20)
],
"by_bucket": {
b: {"policy_prefers": dict(d["policy"].most_common(6)),
"scripted_prefers": dict(d["scripted"].most_common(6))}
for b, d in sorted(by_bucket.items())
},
}
def main() -> int:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--model-path", required=True, type=Path)
p.add_argument("--seeds", default="1,2,3,4,5,6,7,8,9,10")
p.add_argument("--turns", type=int, default=60)
p.add_argument("--probe-slot", type=int, default=1,
help="Slot to probe (1 = the trailing seat in p1-29d).")
p.add_argument("--out", type=Path, default=None)
args = p.parse_args()
from sb3_contrib import MaskablePPO # type: ignore[import-not-found]
model = MaskablePPO.load(str(args.model_path))
seeds = [int(s) for s in args.seeds.split(",") if s.strip()]
results = []
for seed in seeds:
print(f"[mine] seed {seed} probe_slot {args.probe_slot} ...",
file=sys.stderr, flush=True)
results.append(mine_seed(model, seed, args.turns, args.probe_slot))
agg = aggregate(results)
out = {"model": str(args.model_path), "seeds": seeds,
"turns": args.turns, "probe_slot": args.probe_slot,
"aggregate": agg, "per_seed": results}
print(json.dumps(agg, indent=2, default=str))
if args.out:
args.out.write_text(json.dumps(out, indent=2, default=str))
print(f"\n[mine] wrote {args.out}", file=sys.stderr)
return 0
if __name__ == "__main__":
sys.exit(main())