magicciv/tooling/rl_self_play/verify_obs_contract.py
Natalie a6fb75a480
feat(ai): v2 richer 96-dim clan-conditioned observation (schema data + 4 new ops)
Grows the obs contract 32->96 as a SCHEMA change, not a dual hand-rewrite. obs_schema.json
v2 adds the channels the trained AI needs: economy aggregates (yields/territory/golden-age),
tech/culture/civics, military (army-health, experience, posture, equipment), per-city blocks,
terrain summary + biome histogram, richer diplomacy, and the 6-wide CLAN ONE-HOT (the
clan-conditioning input; -1 generalist = all-zero).
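The clan one-hot described above can be sketched as follows. This is only an illustration under stated assumptions: the function name `clan_onehot` and its `width` parameter are hypothetical and are not taken from `obs_contract`. The 6-wide layout and the "-1 generalist = all-zero" rule come from the commit message.

```python
def clan_onehot(clan_id: int, width: int = 6) -> list[float]:
    """Return a `width`-long one-hot vector for `clan_id`.

    Hypothetical helper: any id outside [0, width), including the
    generalist id -1, yields an all-zero vector.
    """
    vec = [0.0] * width
    if 0 <= clan_id < width:
        vec[clan_id] = 1.0
    return vec


if __name__ == "__main__":
    print(clan_onehot(2))   # slot 2 is set, all others are zero
    print(clan_onehot(-1))  # generalist: every slot is zero
```

Encoding the generalist as all-zero, rather than giving it a seventh slot, keeps the block at 6 slots of the observation.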

Both interpreters gain the same new ops — onehot, frac (nested operands), histogram,
per_entity, plus reduce/sum_len, count_nonnull, truthy, where_any (OR), bool->1.0 coercion
in scalar reads. Multi-slot ops emit a run of consecutive slots. OBS_DIM 32->96,
OBS_SCHEMA_VERSION 1->2.

The contract gate earned its keep: it caught a real Python<->Rust divergence (a stale
isinstance(dict) guard zeroed counts over string lists like city.buildings) then confirmed
the fix byte-exact. Verified on the DO fleet: learned_encoder_parity green (Rust v2 ==
Python v2, 56 fixtures incl clan one-hot variety -1..5, zero drift); mc-player-api 188/188.
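The kind of divergence described above can be reproduced in miniature. The sketch below is an assumption about the shape of the bug, not the actual interpreter code: both function names and the building names are hypothetical. It shows how a guard that only counts `dict` entries returns zero for a list of strings.

```python
def count_with_stale_guard(items: list) -> float:
    """Hypothetical buggy count: only dict entries are counted."""
    return float(sum(1 for x in items if isinstance(x, dict)))


def count_fixed(items: list) -> float:
    """Hypothetical fixed count: every non-null entry is counted."""
    return float(sum(1 for x in items if x is not None))


if __name__ == "__main__":
    buildings = ["granary", "forge", "walls"]  # illustrative string list
    print(count_with_stale_guard(buildings))  # the guard drops every string
    print(count_fixed(buildings))
```

A parity fixture containing a non-empty string list is enough to expose this, since the two interpreters then disagree on that slot.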

Next: learned:clan-v1 controller wiring, then training on the richer obs.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-30 12:38:50 -04:00

"""Verify the Python half of the shared observation contract.

Checks, in order:
  1. The schema (`obs_schema.json`) is well-formed: version + obs_dim present,
     every field index < obs_dim, every op + agg in the known vocabulary.
  2. The Python interpreter (`obs_contract`) reproduces the recorded encoder
     parity fixtures to within an absolute tolerance of 1e-6 per element,
     i.e. the schema actually encodes the contract the fixtures pin.

The Rust half is verified by `cargo test -p mc-player-api learned_parity`
(orchestrated together by `scripts/verify-obs-contract.sh`).

Exit 0 only if the schema + Python interpreter agree with the fixtures.
Run: `python3 -m tooling.rl_self_play.verify_obs_contract` (or directly).
"""
from __future__ import annotations

import json
import os
import sys

import numpy as np

_HERE = os.path.dirname(os.path.abspath(__file__))
_REPO_ROOT = os.path.abspath(os.path.join(_HERE, "..", ".."))
sys.path.insert(0, _HERE)
import obs_contract  # noqa: E402

FIXTURES_REL = "src/simulator/crates/mc-player-api/tests/fixtures/learned_mp_v1_encoder_parity.json"

KNOWN_OPS = {
    "scalar", "reduce", "clamp_div", "onehot", "frac", "histogram",
    "per_entity", "count_nonnull", "truthy",
}
KNOWN_AGGS = {"sum", "avg", "count", "sum_len"}


def _fail(msg: str) -> None:
    print(f" FAIL: {msg}")


def validate_schema(schema: dict) -> list[str]:
    errs: list[str] = []
    if "version" not in schema:
        errs.append("missing 'version'")
    obs_dim = schema.get("obs_dim")
    if not isinstance(obs_dim, int) or obs_dim <= 0:
        # Without a usable obs_dim the per-field bounds checks are meaningless.
        errs.append(f"bad obs_dim {obs_dim!r}")
        return errs
    seen: set[int] = set()
    for fld in schema.get("fields", []):
        i = fld.get("i")
        if not isinstance(i, int) or not (0 <= i < obs_dim):
            errs.append(f"field index {i!r} out of [0,{obs_dim})")
        elif i in seen:
            errs.append(f"duplicate field index {i}")
        else:
            seen.add(i)
        if fld.get("op") not in KNOWN_OPS:
            errs.append(f"field {i}: unknown op {fld.get('op')!r}")
        if fld.get("op") == "reduce" and "contains" not in fld and fld.get("agg") not in KNOWN_AGGS:
            errs.append(f"field {i}: unknown agg {fld.get('agg')!r}")
    return errs


def main() -> int:
    print("== obs contract verification (Python half) ==")
    schema = obs_contract.load_schema()
    print(f"schema version={schema.get('version')} obs_dim={schema.get('obs_dim')}")
    errs = validate_schema(schema)
    for e in errs:
        _fail(e)
    if errs:
        print(f"SCHEMA INVALID ({len(errs)} error(s))")
        return 1
    print("schema: well-formed")

    fx_path = os.path.join(_REPO_ROOT, FIXTURES_REL)
    # Use a context manager so the fixture file is closed deterministically.
    with open(fx_path, encoding="utf-8") as fh:
        fx = json.load(fh)
    fixtures = fx["fixtures"] if isinstance(fx, dict) else fx

    nfail = 0
    worst = 0.0
    for i, f in enumerate(fixtures):
        got = obs_contract.encode_observation(f["view"])
        want = np.array(f["obs"], dtype=np.float32)
        if got.shape != want.shape:
            _fail(f"fixture {i}: shape {got.shape} != {want.shape}")
            nfail += 1
            continue
        d = float(np.max(np.abs(got - want)))
        worst = max(worst, d)
        if d > 1e-6:
            nfail += 1
            # Report only the first few failures to keep the output readable.
            if nfail <= 5:
                bad = int(np.argmax(np.abs(got - want)))
                _fail(f"fixture {i}: drift {d:.3g} at obs[{bad}]")
    print(f"parity: {len(fixtures) - nfail}/{len(fixtures)} fixtures match (worst drift {worst:.3g})")
    if nfail:
        print("PYTHON CONTRACT CHECK FAILED")
        return 1
    print("OK: schema + Python interpreter agree with fixtures")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
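
The per-fixture comparison in `main()` reduces to a maximum absolute difference checked against a 1e-6 tolerance. A minimal pure-Python sketch of that check, without NumPy, is shown below. The helper name `max_abs_drift`, the constant `TOL`, and the sample vectors are hypothetical and chosen only for illustration; the real script compares float32 arrays.

```python
TOL = 1e-6  # same threshold the script applies per fixture


def max_abs_drift(got: list[float], want: list[float]) -> float:
    """Largest element-wise absolute difference between two vectors.

    Hypothetical helper; raises ValueError when the lengths differ,
    mirroring the shape check the script performs first.
    """
    if len(got) != len(want):
        raise ValueError(f"shape {len(got)} != {len(want)}")
    return max((abs(g - w) for g, w in zip(got, want)), default=0.0)


if __name__ == "__main__":
    want = [0.25, 0.5, 1.0]
    close = [0.25, 0.5000004, 1.0]  # differs by about 4e-7 in one slot
    far = [0.25, 0.5, 0.0]          # one slot is off by 1.0
    print(max_abs_drift(close, want) <= TOL)  # within tolerance
    print(max_abs_drift(far, want) <= TOL)    # exceeds tolerance
```

Using the maximum rather than a mean means a single drifting slot fails the fixture, which is the behavior the script's `d > 1e-6` test has.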