Grows the obs contract 32->96 as a SCHEMA change, not a dual hand-rewrite. obs_schema.json v2 adds the channels the trained AI needs: economy aggregates (yields/territory/golden-age), tech/culture/civics, military (army-health, experience, posture, equipment), per-city blocks, terrain summary + biome histogram, richer diplomacy, and the 6-wide CLAN ONE-HOT (the clan-conditioning input; -1 generalist = all-zero). Both interpreters gain the same new ops — onehot, frac (nested operands), histogram, per_entity, plus reduce/sum_len, count_nonnull, truthy, where_any (OR), bool->1.0 coercion in scalar reads. Multi-slot ops emit a run of consecutive slots. OBS_DIM 32->96, OBS_SCHEMA_VERSION 1->2. The contract gate earned its keep: it caught a real Python<->Rust divergence (a stale isinstance(dict) guard zeroed counts over string lists like city.buildings) then confirmed the fix byte-exact. Verified on the DO fleet: learned_encoder_parity green (Rust v2 == Python v2, 56 fixtures incl clan one-hot variety -1..5, zero drift); mc-player-api 188/188. Next: learned:clan-v1 controller wiring, then training on the richer obs. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
109 lines
3.6 KiB
Python
109 lines
3.6 KiB
Python
"""Verify the Python half of the shared observation contract.

Checks, in order:

1. The schema (`obs_schema.json`) is well-formed: version + obs_dim present,
   every field index unique and < obs_dim, every op in the known vocabulary,
   and every `reduce` agg in the known vocabulary (unless the field uses
   `contains`).
2. The Python interpreter (`obs_contract`) reproduces the recorded encoder
   parity fixtures to within 1e-6 absolute error per element — i.e. the
   schema actually encodes the contract the fixtures pin.

The Rust half is verified by `cargo test -p mc-player-api learned_parity`
(orchestrated together by `scripts/verify-obs-contract.sh`).

Exit 0 only if the schema + Python interpreter agree with the fixtures.

Run: `python3 -m tooling.rl_self_play.verify_obs_contract` (or directly).
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sys
|
|
|
|
import numpy as np
|
|
|
|
# Directory holding this script; the repo root sits two levels above it.
_HERE = os.path.dirname(os.path.abspath(__file__))
_REPO_ROOT = os.path.abspath(os.path.join(_HERE, os.pardir, os.pardir))

# Put this directory first on sys.path so the sibling `obs_contract` module
# resolves when the file is run directly rather than via `python -m`.
sys.path.insert(0, _HERE)

import obs_contract  # noqa: E402  (must come after the sys.path tweak)
|
|
|
|
# Encoder-parity fixtures, as a path relative to the repository root.
# NOTE(review): the file name says `mp_v1` while the commit describes schema
# v2 — confirm this is the intended fixture set.
FIXTURES_REL = (
    "src/simulator/crates/mc-player-api/tests/fixtures/"
    "learned_mp_v1_encoder_parity.json"
)

# Every `op` value a schema field may use.
KNOWN_OPS = {
    "clamp_div",
    "count_nonnull",
    "frac",
    "histogram",
    "onehot",
    "per_entity",
    "reduce",
    "scalar",
    "truthy",
}

# Aggregations accepted for `reduce` fields (validated only when the field
# has no `contains` key).
KNOWN_AGGS = {"avg", "count", "sum", "sum_len"}
|
|
|
|
|
|
def _fail(msg: str) -> None:
    """Print *msg* to stdout as a single indented `FAIL:` line."""
    line = " FAIL: " + format(msg)
    print(line)
|
|
|
|
|
|
def validate_schema(schema: dict) -> list[str]:
    """Structurally check an observation schema and collect every problem.

    Checks:
      * `version` is present;
      * `obs_dim` is a positive int (``bool`` is rejected even though it
        subclasses ``int``);
      * `fields` is a non-empty list of objects;
      * each field's `i` is an int in ``[0, obs_dim)`` and not repeated;
      * each field's `op` is a string in ``KNOWN_OPS``;
      * for `reduce` fields without a `contains` key, `agg` is a string in
        ``KNOWN_AGGS``.

    Only each field's starting index is checked.
    NOTE(review): multi-slot ops (onehot/histogram/per_entity) emit a run of
    slots; overlap of runs or a run spilling past obs_dim is not checked here.

    Args:
        schema: Parsed schema mapping (presumably what
            ``obs_contract.load_schema()`` returns).

    Returns:
        Human-readable error strings; an empty list means well-formed.
        Validation stops early when `obs_dim` or `fields` is unusable, since
        the per-field checks depend on them.
    """
    errs: list[str] = []
    if "version" not in schema:
        errs.append("missing 'version'")

    obs_dim = schema.get("obs_dim")
    # bool is an int subclass: without the explicit exclusion JSON `true`
    # would be accepted as obs_dim == 1.
    if not isinstance(obs_dim, int) or isinstance(obs_dim, bool) or obs_dim <= 0:
        errs.append(f"bad obs_dim {obs_dim!r}")
        return errs

    fields = schema.get("fields")
    if fields is None:
        errs.append("missing 'fields'")
        return errs
    if not isinstance(fields, list):
        errs.append(f"'fields' must be a list, got {type(fields).__name__}")
        return errs
    if not fields:
        # An obs_dim > 0 with nothing encoding into it is never a valid contract.
        errs.append("schema defines no fields")
        return errs

    seen: set[int] = set()
    for pos, fld in enumerate(fields):
        if not isinstance(fld, dict):
            errs.append(f"fields[{pos}]: expected an object, got {type(fld).__name__}")
            continue

        i = fld.get("i")
        if not isinstance(i, int) or isinstance(i, bool) or not (0 <= i < obs_dim):
            errs.append(f"field index {i!r} out of [0,{obs_dim})")
        elif i in seen:
            errs.append(f"duplicate field index {i}")
        else:
            seen.add(i)

        # The isinstance guards keep an unhashable op/agg (list, dict) from
        # raising TypeError on the set lookup; it is reported instead.
        op = fld.get("op")
        if not isinstance(op, str) or op not in KNOWN_OPS:
            errs.append(f"field {i}: unknown op {op!r}")
        if op == "reduce" and "contains" not in fld:
            agg = fld.get("agg")
            if not isinstance(agg, str) or agg not in KNOWN_AGGS:
                errs.append(f"field {i}: unknown agg {agg!r}")
    return errs
|
|
|
|
|
|
def _load_fixtures(fx_path: str) -> list | None:
    """Read the encoder-parity fixture list from *fx_path*.

    The file may hold either a bare JSON list of fixtures or an object with a
    ``"fixtures"`` list. On any read, parse or shape problem the reason is
    reported via ``_fail`` and ``None`` is returned, so the caller can exit 1
    instead of dying with a traceback.
    """
    try:
        with open(fx_path, encoding="utf-8") as fh:
            fx = json.load(fh)
    except OSError as exc:
        _fail(f"cannot read fixtures {fx_path}: {exc}")
        return None
    except ValueError as exc:  # JSONDecodeError / UnicodeDecodeError
        _fail(f"fixtures {fx_path} are not valid JSON: {exc}")
        return None
    fixtures = fx.get("fixtures") if isinstance(fx, dict) else fx
    if not isinstance(fixtures, list):
        _fail(
            f"fixtures {fx_path}: expected a list of fixtures, "
            f"got {type(fixtures).__name__}"
        )
        return None
    return fixtures


def _max_drift(got: np.ndarray, want: np.ndarray) -> tuple[float, int]:
    """Return ``(largest |got - want|, flat index of it)``.

    NaN differences are treated as infinite drift. NaN compares False against
    any threshold, so without this a NaN in the encoded output would count
    as a match and never raise the reported worst drift. Returns
    ``(0.0, -1)`` for empty inputs.
    """
    diff = np.abs(got - want).ravel()
    if diff.size == 0:
        return 0.0, -1
    diff = np.where(np.isnan(diff), np.inf, diff)
    idx = int(np.argmax(diff))
    return float(diff[idx]), idx


def main() -> int:
    """Run the schema check, then the fixture parity check.

    Returns:
        0 if the schema is well-formed, at least one fixture was loaded, and
        every fixture's encoding has the expected shape and a max absolute
        per-element drift <= 1e-6. Otherwise 1.
    """
    tol = 1e-6  # max tolerated |python - recorded| per observation slot
    max_shown = 5  # cap on per-fixture drift reports to keep output short

    print("== obs contract verification (Python half) ==")
    schema = obs_contract.load_schema()
    print(f"schema version={schema.get('version')} obs_dim={schema.get('obs_dim')}")

    errs = validate_schema(schema)
    for e in errs:
        _fail(e)
    if errs:
        print(f"SCHEMA INVALID ({len(errs)} error(s))")
        return 1
    print("schema: well-formed")

    fixtures = _load_fixtures(os.path.join(_REPO_ROOT, FIXTURES_REL))
    if fixtures is None:
        print("PYTHON CONTRACT CHECK FAILED")
        return 1
    if not fixtures:
        # Zero fixtures would otherwise report "0/0 match" and exit 0.
        _fail("fixture file contains no fixtures; parity would pass vacuously")
        print("PYTHON CONTRACT CHECK FAILED")
        return 1

    nfail = 0
    suppressed = 0
    worst = 0.0
    for i, f in enumerate(fixtures):
        got = np.asarray(obs_contract.encode_observation(f["view"]))
        want = np.array(f["obs"], dtype=np.float32)
        if got.shape != want.shape:
            _fail(f"fixture {i}: shape {got.shape} != {want.shape}")
            nfail += 1
            continue
        d, bad = _max_drift(got, want)
        worst = max(worst, d)
        if d > tol:
            nfail += 1
            if nfail <= max_shown:
                _fail(f"fixture {i}: drift {d:.3g} at obs[{bad}]")
            else:
                suppressed += 1
    if suppressed:
        print(f" ... {suppressed} more drifting fixture(s) not shown")

    print(f"parity: {len(fixtures) - nfail}/{len(fixtures)} fixtures match (worst drift {worst:.3g})")
    if nfail:
        print("PYTHON CONTRACT CHECK FAILED")
        return 1
    print("OK: schema + Python interpreter agree with fixtures")
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s return code to the interpreter as the process exit status.
    sys.exit(main())
|